Skip to content

Commit e4356d6

Browse files
yiyixuxualexisrollandstevhliu
authored
add a "Community Scripts" section (#7358)
* add * add tiling * fix * fix * fix * give community script its own readme * Update examples/community/README_community_scripts.md Co-authored-by: Steven Liu <[email protected]> * Update examples/community/README_community_scripts.md Co-authored-by: Steven Liu <[email protected]> * Update examples/community/README_community_scripts.md Co-authored-by: Steven Liu <[email protected]> * Update examples/community/README_community_scripts.md --------- Co-authored-by: Alexis Rolland <[email protected]> Co-authored-by: Steven Liu <[email protected]>
1 parent 8244146 commit e4356d6

File tree

2 files changed

+240
-6
lines changed

2 files changed

+240
-6
lines changed

examples/community/README.md

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
# Community Examples
1+
# Community Pipeline Examples
22

33
> **For more information about community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).**
44
5-
**Community** examples consist of both inference and training examples that have been added by the community.
6-
Please have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste ready code example that you can try out.
7-
If a community doesn't work as expected, please open an issue and ping the author on it.
5+
**Community pipeline** examples consist pipelines that have been added by the community.
6+
Please have a look at the following tables to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste ready code example that you can try out.
7+
If a community pipeline doesn't work as expected, please open an issue and ping the author on it.
8+
9+
Please also check out our [Community Scripts](https://github.com/huggingface/diffusers/blob/main/examples/community/README_community_scripts.md) examples for tips and tricks that you can use with diffusers without having to run a community pipeline.
810

911
| Example | Description | Code Example | Colab | Author |
1012
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
@@ -1887,7 +1889,7 @@ In the above code, the `prompt2` is appended to the `prompt`, which is more than
18871889

18881890
For more results, checkout [PR #6114](https://github.com/huggingface/diffusers/pull/6114).
18891891

1890-
## Example Images Mixing (with CoCa)
1892+
### Example Images Mixing (with CoCa)
18911893
```python
18921894
import requests
18931895
from io import BytesIO
@@ -2934,7 +2936,7 @@ pipe(prompt =prompt, rp_args = rp_args)
29342936

29352937
The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.
29362938

2937-
## Diffusion Posterior Sampling Pipeline
2939+
### Diffusion Posterior Sampling Pipeline
29382940
* Reference paper
29392941
```
29402942
@article{chung2022diffusion,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Community Scripts
2+
3+
**Community scripts** consist of inference examples using Diffusers pipelines that have been added by the community.
4+
Please have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste code example that you can try out.
5+
If a community script doesn't work as expected, please open an issue and ping the author on it.
6+
7+
| Example | Description | Code Example | Colab | Author |
8+
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
9+
| Using IP-Adapter with negative noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) | | [Álvaro Somoza](https://github.com/asomoza)|
10+
| asymmetric tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#asymmetric-tiling ) | | [alexisrolland](https://github.com/alexisrolland)|
11+
12+
13+
## Example usages
14+
15+
### IP Adapter Negative Noise
16+
17+
Diffusers pipelines are fully integrated with IP-Adapter, which allows you to prompt the diffusion model with an image. However, it does not support negative image prompts (there is no `negative_ip_adapter_image` argument) the same way it supports negative text prompts. When you pass an `ip_adapter_image,` it will create a zero-filled tensor as a negative image. This script shows you how to create a negative noise from `ip_adapter_image` and use it to significantly improve the generation quality while preserving the composition of images.
18+
19+
[cubiq](https://github.com/cubiq) initially developed this feature in his [repository](https://github.com/cubiq/ComfyUI_IPAdapter_plus). The community script was contributed by [asomoza](https://github.com/Somoza). You can find more details about this experimentation [this discussion](https://github.com/huggingface/diffusers/discussions/7167)
20+
21+
IP-Adapter without negative noise
22+
|source|result|
23+
|---|---|
24+
|![20240229150812](https://github.com/huggingface/diffusers/assets/5442875/901d8bd8-7a59-4fe7-bda1-a0e0d6c7dffd)|![20240229163923_normal](https://github.com/huggingface/diffusers/assets/5442875/3432e25a-ece6-45f4-a3f4-fca354f40b5b)|
25+
26+
IP-Adapter with negative noise
27+
|source|result|
28+
|---|---|
29+
|![20240229150812](https://github.com/huggingface/diffusers/assets/5442875/901d8bd8-7a59-4fe7-bda1-a0e0d6c7dffd)|![20240229163923](https://github.com/huggingface/diffusers/assets/5442875/736fd15a-36ba-40c0-a7d8-6ec1ac26f788)|
30+
31+
```python
32+
import torch
33+
34+
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, StableDiffusionXLPipeline
35+
from diffusers.models import ImageProjection
36+
from diffusers.utils import load_image
37+
38+
39+
def encode_image(
40+
image_encoder,
41+
feature_extractor,
42+
image,
43+
device,
44+
num_images_per_prompt,
45+
output_hidden_states=None,
46+
negative_image=None,
47+
):
48+
dtype = next(image_encoder.parameters()).dtype
49+
50+
if not isinstance(image, torch.Tensor):
51+
image = feature_extractor(image, return_tensors="pt").pixel_values
52+
53+
image = image.to(device=device, dtype=dtype)
54+
if output_hidden_states:
55+
image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2]
56+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
57+
58+
if negative_image is None:
59+
uncond_image_enc_hidden_states = image_encoder(
60+
torch.zeros_like(image), output_hidden_states=True
61+
).hidden_states[-2]
62+
else:
63+
if not isinstance(negative_image, torch.Tensor):
64+
negative_image = feature_extractor(negative_image, return_tensors="pt").pixel_values
65+
negative_image = negative_image.to(device=device, dtype=dtype)
66+
uncond_image_enc_hidden_states = image_encoder(negative_image, output_hidden_states=True).hidden_states[-2]
67+
68+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
69+
return image_enc_hidden_states, uncond_image_enc_hidden_states
70+
else:
71+
image_embeds = image_encoder(image).image_embeds
72+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
73+
uncond_image_embeds = torch.zeros_like(image_embeds)
74+
75+
return image_embeds, uncond_image_embeds
76+
77+
78+
@torch.no_grad()
79+
def prepare_ip_adapter_image_embeds(
80+
unet,
81+
image_encoder,
82+
feature_extractor,
83+
ip_adapter_image,
84+
do_classifier_free_guidance,
85+
device,
86+
num_images_per_prompt,
87+
ip_adapter_negative_image=None,
88+
):
89+
if not isinstance(ip_adapter_image, list):
90+
ip_adapter_image = [ip_adapter_image]
91+
92+
if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
93+
raise ValueError(
94+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
95+
)
96+
97+
image_embeds = []
98+
for single_ip_adapter_image, image_proj_layer in zip(
99+
ip_adapter_image, unet.encoder_hid_proj.image_projection_layers
100+
):
101+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
102+
single_image_embeds, single_negative_image_embeds = encode_image(
103+
image_encoder,
104+
feature_extractor,
105+
single_ip_adapter_image,
106+
device,
107+
1,
108+
output_hidden_state,
109+
negative_image=ip_adapter_negative_image,
110+
)
111+
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
112+
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
113+
114+
if do_classifier_free_guidance:
115+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
116+
single_image_embeds = single_image_embeds.to(device)
117+
118+
image_embeds.append(single_image_embeds)
119+
120+
return image_embeds
121+
122+
123+
vae = AutoencoderKL.from_pretrained(
124+
"madebyollin/sdxl-vae-fp16-fix",
125+
torch_dtype=torch.float16,
126+
).to("cuda")
127+
128+
pipeline = StableDiffusionXLPipeline.from_pretrained(
129+
"RunDiffusion/Juggernaut-XL-v9",
130+
torch_dtype=torch.float16,
131+
vae=vae,
132+
variant="fp16",
133+
).to("cuda")
134+
135+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
136+
pipeline.scheduler.config.use_karras_sigmas = True
137+
138+
pipeline.load_ip_adapter(
139+
"h94/IP-Adapter",
140+
subfolder="sdxl_models",
141+
weight_name="ip-adapter-plus_sdxl_vit-h.safetensors",
142+
image_encoder_folder="models/image_encoder",
143+
)
144+
pipeline.set_ip_adapter_scale(0.7)
145+
146+
ip_image = load_image("source.png")
147+
negative_ip_image = load_image("noise.png")
148+
149+
image_embeds = prepare_ip_adapter_image_embeds(
150+
unet=pipeline.unet,
151+
image_encoder=pipeline.image_encoder,
152+
feature_extractor=pipeline.feature_extractor,
153+
ip_adapter_image=[[ip_image]],
154+
do_classifier_free_guidance=True,
155+
device="cuda",
156+
num_images_per_prompt=1,
157+
ip_adapter_negative_image=negative_ip_image,
158+
)
159+
160+
161+
prompt = "cinematic photo of a cyborg in the city, 4k, high quality, intricate, highly detailed"
162+
negative_prompt = "blurry, smooth, plastic"
163+
164+
image = pipeline(
165+
prompt=prompt,
166+
negative_prompt=negative_prompt,
167+
ip_adapter_image_embeds=image_embeds,
168+
guidance_scale=6.0,
169+
num_inference_steps=25,
170+
generator=torch.Generator(device="cpu").manual_seed(1556265306),
171+
).images[0]
172+
173+
image.save("result.png")
174+
```
175+
176+
### Asymmetric Tiling
177+
Stable Diffusion is not trained to generate seamless textures. However, you can use this simple script to add tiling to your generation. This script is contributed by [alexisrolland](https://github.com/alexisrolland). See more details in the [this issue](https://github.com/huggingface/diffusers/issues/556)
178+
179+
180+
|Generated|Tiled|
181+
|---|---|
182+
|![20240313003235_573631814](https://github.com/huggingface/diffusers/assets/5442875/eca174fb-06a4-464e-a3a7-00dbb024543e)|![wall](https://github.com/huggingface/diffusers/assets/5442875/b4aa774b-2a6a-4316-a8eb-8f30b5f4d024)|
183+
184+
185+
```py
186+
import torch
187+
from typing import Optional
188+
from diffusers import StableDiffusionPipeline
189+
from diffusers.models.lora import LoRACompatibleConv
190+
191+
def seamless_tiling(pipeline, x_axis, y_axis):
192+
def asymmetric_conv2d_convforward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
193+
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
194+
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
195+
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
196+
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
197+
return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)
198+
x_mode = 'circular' if x_axis else 'constant'
199+
y_mode = 'circular' if y_axis else 'constant'
200+
targets = [pipeline.vae, pipeline.text_encoder, pipeline.unet]
201+
convolution_layers = []
202+
for target in targets:
203+
for module in target.modules():
204+
if isinstance(module, torch.nn.Conv2d):
205+
convolution_layers.append(module)
206+
for layer in convolution_layers:
207+
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
208+
layer.lora_layer = lambda * x: 0
209+
layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)
210+
return pipeline
211+
212+
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True)
213+
pipeline.enable_model_cpu_offload()
214+
prompt = ["texture of a red brick wall"]
215+
seed = 123456
216+
generator = torch.Generator(device='cuda').manual_seed(seed)
217+
218+
pipeline = seamless_tiling(pipeline=pipeline, x_axis=True, y_axis=True)
219+
image = pipeline(
220+
prompt=prompt,
221+
width=512,
222+
height=512,
223+
num_inference_steps=20,
224+
guidance_scale=7,
225+
num_images_per_prompt=1,
226+
generator=generator
227+
).images[0]
228+
seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False)
229+
230+
torch.cuda.empty_cache()
231+
image.save('image.png')
232+
```

0 commit comments

Comments
 (0)