Skip to content

Commit ec3d582

Browse files
authored
[train_dreambooth_lora_flux_advanced] Add LANCZOS as the default interpolation mode for image resizing (#11472)
* [train_controlnet_sdxl] Add LANCZOS as the default interpolation mode for image resizing * [train_dreambooth_lora_flux_advanced] Add LANCZOS as the default interpolation mode for image resizing
1 parent ed6cf52 commit ec3d582

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,15 @@ def parse_args(input_args=None):
770770
),
771771
)
772772
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
773+
parser.add_argument(
774+
"--image_interpolation_mode",
775+
type=str,
776+
default="lanczos",
777+
choices=[
778+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
779+
],
780+
help="The image interpolation method to use for resizing images.",
781+
)
773782

774783
if input_args is not None:
775784
args = parser.parse_args(input_args)
@@ -1034,7 +1043,10 @@ def __init__(
10341043
self.instance_images.extend(itertools.repeat(img, repeats))
10351044

10361045
self.pixel_values = []
1037-
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
1046+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
1047+
if interpolation is None:
1048+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
1049+
train_resize = transforms.Resize(size, interpolation=interpolation)
10381050
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
10391051
train_flip = transforms.RandomHorizontalFlip(p=1.0)
10401052
train_transforms = transforms.Compose(
@@ -1078,7 +1090,7 @@ def __init__(
10781090

10791091
self.image_transforms = transforms.Compose(
10801092
[
1081-
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
1093+
transforms.Resize(size, interpolation=interpolation),
10821094
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
10831095
transforms.ToTensor(),
10841096
transforms.Normalize([0.5], [0.5]),

0 commit comments

Comments
 (0)