Skip to content

Commit d5b3037

Browse files
committed
propagate changes.
1 parent ded2fd6 commit d5b3037

File tree

1 file changed

+74
-17
lines changed

1 file changed

+74
-17
lines changed

src/diffusers/loaders/lora_pipeline.py

+74-17
Original file line numberDiff line numberDiff line change
@@ -1667,9 +1667,10 @@ def save_lora_weights(
16671667
weight_name: str = None,
16681668
save_function: Callable = None,
16691669
safe_serialization: bool = True,
1670+
transformer_lora_adapter_metadata: Optional[dict] = None,
16701671
):
16711672
r"""
1672-
Save the LoRA parameters corresponding to the UNet and text encoder.
1673+
Save the LoRA parameters corresponding to the transformer.
16731674
16741675
Arguments:
16751676
save_directory (`str` or `os.PathLike`):
@@ -1686,15 +1687,20 @@ def save_lora_weights(
16861687
`DIFFUSERS_SAVE_MODE`.
16871688
safe_serialization (`bool`, *optional*, defaults to `True`):
16881689
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
1690+
transformer_lora_adapter_metadata: TODO
16891691
"""
16901692
state_dict = {}
1693+
lora_adapter_metadata = {}
16911694

16921695
if not transformer_lora_layers:
16931696
raise ValueError("You must pass `transformer_lora_layers`.")
16941697

16951698
if transformer_lora_layers:
16961699
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
16971700

1701+
if transformer_lora_adapter_metadata:
1702+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
1703+
16981704
# Save the model
16991705
cls.write_lora_layers(
17001706
state_dict=state_dict,
@@ -1703,6 +1709,7 @@ def save_lora_weights(
17031709
weight_name=weight_name,
17041710
save_function=save_function,
17051711
safe_serialization=safe_serialization,
1712+
lora_adapter_metadata=lora_adapter_metadata,
17061713
)
17071714

17081715
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -2985,9 +2992,10 @@ def save_lora_weights(
29852992
weight_name: str = None,
29862993
save_function: Callable = None,
29872994
safe_serialization: bool = True,
2995+
transformer_lora_adapter_metadata: Optional[dict] = None,
29882996
):
29892997
r"""
2990-
Save the LoRA parameters corresponding to the UNet and text encoder.
2998+
Save the LoRA parameters corresponding to the transformer.
29912999
29923000
Arguments:
29933001
save_directory (`str` or `os.PathLike`):
@@ -3004,15 +3012,20 @@ def save_lora_weights(
30043012
`DIFFUSERS_SAVE_MODE`.
30053013
safe_serialization (`bool`, *optional*, defaults to `True`):
30063014
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3015+
transformer_lora_adapter_metadata: TODO
30073016
"""
30083017
state_dict = {}
3018+
lora_adapter_metadata = {}
30093019

30103020
if not transformer_lora_layers:
30113021
raise ValueError("You must pass `transformer_lora_layers`.")
30123022

30133023
if transformer_lora_layers:
30143024
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
30153025

3026+
if transformer_lora_adapter_metadata:
3027+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
3028+
30163029
# Save the model
30173030
cls.write_lora_layers(
30183031
state_dict=state_dict,
@@ -3021,6 +3034,7 @@ def save_lora_weights(
30213034
weight_name=weight_name,
30223035
save_function=save_function,
30233036
safe_serialization=safe_serialization,
3037+
lora_adapter_metadata=lora_adapter_metadata,
30243038
)
30253039

30263040
def fuse_lora(
@@ -3302,9 +3316,10 @@ def save_lora_weights(
33023316
weight_name: str = None,
33033317
save_function: Callable = None,
33043318
safe_serialization: bool = True,
3319+
transformer_lora_adapter_metadata: Optional[dict] = None,
33053320
):
33063321
r"""
3307-
Save the LoRA parameters corresponding to the UNet and text encoder.
3322+
Save the LoRA parameters corresponding to the transformer.
33083323
33093324
Arguments:
33103325
save_directory (`str` or `os.PathLike`):
@@ -3321,15 +3336,20 @@ def save_lora_weights(
33213336
`DIFFUSERS_SAVE_MODE`.
33223337
safe_serialization (`bool`, *optional*, defaults to `True`):
33233338
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3339+
transformer_lora_adapter_metadata: TODO
33243340
"""
33253341
state_dict = {}
3342+
lora_adapter_metadata = {}
33263343

33273344
if not transformer_lora_layers:
33283345
raise ValueError("You must pass `transformer_lora_layers`.")
33293346

33303347
if transformer_lora_layers:
33313348
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
33323349

3350+
if transformer_lora_adapter_metadata:
3351+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
3352+
33333353
# Save the model
33343354
cls.write_lora_layers(
33353355
state_dict=state_dict,
@@ -3338,6 +3358,7 @@ def save_lora_weights(
33383358
weight_name=weight_name,
33393359
save_function=save_function,
33403360
safe_serialization=safe_serialization,
3361+
lora_adapter_metadata=lora_adapter_metadata,
33413362
)
33423363

33433364
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -3621,9 +3642,10 @@ def save_lora_weights(
36213642
weight_name: str = None,
36223643
save_function: Callable = None,
36233644
safe_serialization: bool = True,
3645+
transformer_lora_adapter_metadata: Optional[dict] = None,
36243646
):
36253647
r"""
3626-
Save the LoRA parameters corresponding to the UNet and text encoder.
3648+
Save the LoRA parameters corresponding to the transformer.
36273649
36283650
Arguments:
36293651
save_directory (`str` or `os.PathLike`):
@@ -3640,15 +3662,20 @@ def save_lora_weights(
36403662
`DIFFUSERS_SAVE_MODE`.
36413663
safe_serialization (`bool`, *optional*, defaults to `True`):
36423664
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3665+
transformer_lora_adapter_metadata: TODO
36433666
"""
36443667
state_dict = {}
3668+
lora_adapter_metadata = {}
36453669

36463670
if not transformer_lora_layers:
36473671
raise ValueError("You must pass `transformer_lora_layers`.")
36483672

36493673
if transformer_lora_layers:
36503674
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
36513675

3676+
if transformer_lora_adapter_metadata:
3677+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
3678+
36523679
# Save the model
36533680
cls.write_lora_layers(
36543681
state_dict=state_dict,
@@ -3657,6 +3684,7 @@ def save_lora_weights(
36573684
weight_name=weight_name,
36583685
save_function=save_function,
36593686
safe_serialization=safe_serialization,
3687+
lora_adapter_metadata=lora_adapter_metadata,
36603688
)
36613689

36623690
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -3940,9 +3968,10 @@ def save_lora_weights(
39403968
weight_name: str = None,
39413969
save_function: Callable = None,
39423970
safe_serialization: bool = True,
3971+
transformer_lora_adapter_metadata: Optional[dict] = None,
39433972
):
39443973
r"""
3945-
Save the LoRA parameters corresponding to the UNet and text encoder.
3974+
Save the LoRA parameters corresponding to the transformer.
39463975
39473976
Arguments:
39483977
save_directory (`str` or `os.PathLike`):
@@ -3959,15 +3988,20 @@ def save_lora_weights(
39593988
`DIFFUSERS_SAVE_MODE`.
39603989
safe_serialization (`bool`, *optional*, defaults to `True`):
39613990
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
3991+
transformer_lora_adapter_metadata: TODO
39623992
"""
39633993
state_dict = {}
3994+
lora_adapter_metadata = {}
39643995

39653996
if not transformer_lora_layers:
39663997
raise ValueError("You must pass `transformer_lora_layers`.")
39673998

39683999
if transformer_lora_layers:
39694000
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
39704001

4002+
if transformer_lora_adapter_metadata:
4003+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
4004+
39714005
# Save the model
39724006
cls.write_lora_layers(
39734007
state_dict=state_dict,
@@ -3976,6 +4010,7 @@ def save_lora_weights(
39764010
weight_name=weight_name,
39774011
save_function=save_function,
39784012
safe_serialization=safe_serialization,
4013+
lora_adapter_metadata=lora_adapter_metadata,
39794014
)
39804015

39814016
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4262,9 +4297,10 @@ def save_lora_weights(
42624297
weight_name: str = None,
42634298
save_function: Callable = None,
42644299
safe_serialization: bool = True,
4300+
transformer_lora_adapter_metadata: Optional[dict] = None,
42654301
):
42664302
r"""
4267-
Save the LoRA parameters corresponding to the UNet and text encoder.
4303+
Save the LoRA parameters corresponding to the transformer.
42684304
42694305
Arguments:
42704306
save_directory (`str` or `os.PathLike`):
@@ -4281,15 +4317,20 @@ def save_lora_weights(
42814317
`DIFFUSERS_SAVE_MODE`.
42824318
safe_serialization (`bool`, *optional*, defaults to `True`):
42834319
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4320+
transformer_lora_adapter_metadata: TODO
42844321
"""
42854322
state_dict = {}
4323+
lora_adapter_metadata = {}
42864324

42874325
if not transformer_lora_layers:
42884326
raise ValueError("You must pass `transformer_lora_layers`.")
42894327

42904328
if transformer_lora_layers:
42914329
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
42924330

4331+
if transformer_lora_adapter_metadata:
4332+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
4333+
42934334
# Save the model
42944335
cls.write_lora_layers(
42954336
state_dict=state_dict,
@@ -4298,6 +4339,7 @@ def save_lora_weights(
42984339
weight_name=weight_name,
42994340
save_function=save_function,
43004341
safe_serialization=safe_serialization,
4342+
lora_adapter_metadata=lora_adapter_metadata,
43014343
)
43024344

43034345
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4585,9 +4627,10 @@ def save_lora_weights(
45854627
weight_name: str = None,
45864628
save_function: Callable = None,
45874629
safe_serialization: bool = True,
4630+
transformer_lora_adapter_metadata: Optional[dict] = None,
45884631
):
45894632
r"""
4590-
Save the LoRA parameters corresponding to the UNet and text encoder.
4633+
Save the LoRA parameters corresponding to the transformer.
45914634
45924635
Arguments:
45934636
save_directory (`str` or `os.PathLike`):
@@ -4604,15 +4647,20 @@ def save_lora_weights(
46044647
`DIFFUSERS_SAVE_MODE`.
46054648
safe_serialization (`bool`, *optional*, defaults to `True`):
46064649
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4650+
transformer_lora_adapter_metadata: TODO
46074651
"""
46084652
state_dict = {}
4653+
lora_adapter_metadata = {}
46094654

46104655
if not transformer_lora_layers:
46114656
raise ValueError("You must pass `transformer_lora_layers`.")
46124657

46134658
if transformer_lora_layers:
46144659
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
46154660

4661+
if transformer_lora_adapter_metadata:
4662+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
4663+
46164664
# Save the model
46174665
cls.write_lora_layers(
46184666
state_dict=state_dict,
@@ -4621,6 +4669,7 @@ def save_lora_weights(
46214669
weight_name=weight_name,
46224670
save_function=save_function,
46234671
safe_serialization=safe_serialization,
4672+
lora_adapter_metadata=lora_adapter_metadata,
46244673
)
46254674

46264675
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -4890,13 +4939,7 @@ def load_lora_weights(
48904939
@classmethod
48914940
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
48924941
def load_lora_into_transformer(
4893-
cls,
4894-
state_dict,
4895-
transformer,
4896-
adapter_name=None,
4897-
_pipeline=None,
4898-
low_cpu_mem_usage=False,
4899-
hotswap: bool = False,
4942+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
49004943
):
49014944
"""
49024945
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4946,7 +4989,7 @@ def save_lora_weights(
49464989
transformer_lora_adapter_metadata: Optional[dict] = None,
49474990
):
49484991
r"""
4949-
Save the LoRA parameters corresponding to the UNet and text encoder.
4992+
Save the LoRA parameters corresponding to the transformer.
49504993
49514994
Arguments:
49524995
save_directory (`str` or `os.PathLike`):
@@ -5269,9 +5312,10 @@ def save_lora_weights(
52695312
weight_name: str = None,
52705313
save_function: Callable = None,
52715314
safe_serialization: bool = True,
5315+
transformer_lora_adapter_metadata: Optional[dict] = None,
52725316
):
52735317
r"""
5274-
Save the LoRA parameters corresponding to the UNet and text encoder.
5318+
Save the LoRA parameters corresponding to the transformer.
52755319
52765320
Arguments:
52775321
save_directory (`str` or `os.PathLike`):
@@ -5288,15 +5332,20 @@ def save_lora_weights(
52885332
`DIFFUSERS_SAVE_MODE`.
52895333
safe_serialization (`bool`, *optional*, defaults to `True`):
52905334
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
5335+
transformer_lora_adapter_metadata: TODO
52915336
"""
52925337
state_dict = {}
5338+
lora_adapter_metadata = {}
52935339

52945340
if not transformer_lora_layers:
52955341
raise ValueError("You must pass `transformer_lora_layers`.")
52965342

52975343
if transformer_lora_layers:
52985344
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
52995345

5346+
if transformer_lora_adapter_metadata:
5347+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
5348+
53005349
# Save the model
53015350
cls.write_lora_layers(
53025351
state_dict=state_dict,
@@ -5305,6 +5354,7 @@ def save_lora_weights(
53055354
weight_name=weight_name,
53065355
save_function=save_function,
53075356
safe_serialization=safe_serialization,
5357+
lora_adapter_metadata=lora_adapter_metadata,
53085358
)
53095359

53105360
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -5588,9 +5638,10 @@ def save_lora_weights(
55885638
weight_name: str = None,
55895639
save_function: Callable = None,
55905640
safe_serialization: bool = True,
5641+
transformer_lora_adapter_metadata: Optional[dict] = None,
55915642
):
55925643
r"""
5593-
Save the LoRA parameters corresponding to the UNet and text encoder.
5644+
Save the LoRA parameters corresponding to the transformer.
55945645
55955646
Arguments:
55965647
save_directory (`str` or `os.PathLike`):
@@ -5607,15 +5658,20 @@ def save_lora_weights(
56075658
`DIFFUSERS_SAVE_MODE`.
56085659
safe_serialization (`bool`, *optional*, defaults to `True`):
56095660
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
5661+
transformer_lora_adapter_metadata: TODO
56105662
"""
56115663
state_dict = {}
5664+
lora_adapter_metadata = {}
56125665

56135666
if not transformer_lora_layers:
56145667
raise ValueError("You must pass `transformer_lora_layers`.")
56155668

56165669
if transformer_lora_layers:
56175670
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
56185671

5672+
if transformer_lora_adapter_metadata:
5673+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
5674+
56195675
# Save the model
56205676
cls.write_lora_layers(
56215677
state_dict=state_dict,
@@ -5624,6 +5680,7 @@ def save_lora_weights(
56245680
weight_name=weight_name,
56255681
save_function=save_function,
56265682
safe_serialization=safe_serialization,
5683+
lora_adapter_metadata=lora_adapter_metadata,
56275684
)
56285685

56295686
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora

0 commit comments

Comments
 (0)