
Commit 3636990

[refactor] Fix FreeInit behaviour (#7410)
* fix freeinit impl
* fix progress bar
* fix progress bar and remove old code
* fix num_inference_steps==1 case for freeinit by at least running 1 step when fast sampling is enabled
1 parent 9613576 commit 3636990

4 files changed (+42 -105 lines)


src/diffusers/pipelines/animatediff/pipeline_animatediff.py (+1 -1)

@@ -792,7 +792,7 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

         # 8. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
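Why this matters: when FreeInit's coarse-to-fine fast sampling is enabled, `_apply_free_init` shortens the scheduler's timestep list on early iterations, so `len(timesteps)` can be smaller than the requested `num_inference_steps`. Sizing the progress bar from `self._num_timesteps` (set from `len(timesteps)` before the loop) keeps the bar in step with the iterations that actually run. A minimal sketch of the intended behavior, with illustrative numbers rather than values taken from the diff:

```python
# Illustrative only: per-iteration step counts under FreeInit fast sampling.
num_inference_steps = 20  # requested by the user
num_iters = 3             # FreeInit iterations (self._free_init_num_iters)

for k in range(num_iters):
    # mirrors the formula in free_init_utils.py below
    steps_this_iter = max(1, int(num_inference_steps / num_iters * (k + 1)))
    print(f"FreeInit iteration {k}: progress bar total = {steps_this_iter}")
# iteration 0 -> 6, iteration 1 -> 13, iteration 2 -> 20
```

With the old `total=num_inference_steps`, every bar was sized at 20, so the first two iterations appeared stuck at 30% and 65% when their loops had actually finished.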

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py (+1 -1)

@@ -944,7 +944,7 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

         # 8. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
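Both AnimateDiff pipelines hit this code path only when FreeInit is turned on through the `FreeInitMixin` API. A hedged usage sketch (the checkpoint names are illustrative examples, not part of this commit):

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Illustrative checkpoints; any AnimateDiff-compatible pair works.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# enable_free_init comes from FreeInitMixin (free_init_utils.py); with
# use_fast_sampling=True the per-iteration schedules described below apply.
pipe.enable_free_init(num_iters=3, use_fast_sampling=True)
output = pipe(prompt="a panda surfing", num_inference_steps=20)
```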

src/diffusers/pipelines/free_init_utils.py (+31 -30)

@@ -146,39 +146,40 @@ def _apply_free_init(
     ):
         if free_init_iteration == 0:
             self._free_init_initial_noise = latents.detach().clone()
-            return latents, self.scheduler.timesteps
-
-        latent_shape = latents.shape
-
-        free_init_filter_shape = (1, *latent_shape[1:])
-        free_init_freq_filter = self._get_free_init_freq_filter(
-            shape=free_init_filter_shape,
-            device=device,
-            filter_type=self._free_init_method,
-            order=self._free_init_order,
-            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
-            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
-        )
-
-        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
-        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
-
-        z_t = self.scheduler.add_noise(
-            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
-        ).to(dtype=torch.float32)
-
-        z_rand = randn_tensor(
-            shape=latent_shape,
-            generator=generator,
-            device=device,
-            dtype=torch.float32,
-        )
-        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
-        latents = latents.to(dtype)
+        else:
+            latent_shape = latents.shape
+
+            free_init_filter_shape = (1, *latent_shape[1:])
+            free_init_freq_filter = self._get_free_init_freq_filter(
+                shape=free_init_filter_shape,
+                device=device,
+                filter_type=self._free_init_method,
+                order=self._free_init_order,
+                spatial_stop_frequency=self._free_init_spatial_stop_frequency,
+                temporal_stop_frequency=self._free_init_temporal_stop_frequency,
+            )
+
+            current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
+            diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
+
+            z_t = self.scheduler.add_noise(
+                original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
+            ).to(dtype=torch.float32)
+
+            z_rand = randn_tensor(
+                shape=latent_shape,
+                generator=generator,
+                device=device,
+                dtype=torch.float32,
+            )
+            latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
+            latents = latents.to(dtype)

         # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
         if self._free_init_use_fast_sampling:
-            num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+            num_inference_steps = max(
+                1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+            )
             self.scheduler.set_timesteps(num_inference_steps, device=device)

         return latents, self.scheduler.timesteps
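Two behavioral fixes here. First, iteration 0 no longer returns early, so the fast-sampling block now also truncates the schedule for the first FreeInit pass instead of running it at full length. Second, the `max(1, ...)` guard is the `num_inference_steps == 1` fix named in the commit message: plain `int()` truncation schedules zero steps for early iterations whenever `num_inference_steps < self._free_init_num_iters`. A quick check with illustrative values:

```python
# Illustrative only: step counts before and after the max(1, ...) guard.
num_inference_steps, num_iters = 1, 3

for k in range(num_iters):
    truncated = int(num_inference_steps / num_iters * (k + 1))
    guarded = max(1, truncated)
    print(f"iteration {k}: without guard = {truncated}, with guard = {guarded}")
# iteration 0: without guard = 0, with guard = 1
# iteration 1: without guard = 0, with guard = 1
# iteration 2: without guard = 1, with guard = 1
```

Without the guard, the first two iterations would ask the scheduler for an empty (or invalid) timestep schedule and run no denoising steps at all.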

src/diffusers/pipelines/pia/pipeline_pia.py (+9 -73)

@@ -13,14 +13,12 @@
 # limitations under the License.

 import inspect
-import math
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
 import PIL
 import torch
-import torch.fft as fft
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor

@@ -130,81 +128,16 @@ def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale
     return coef


-def _get_freeinit_freq_filter(
-    shape: Tuple[int, ...],
-    device: Union[str, torch.dtype],
-    filter_type: str,
-    order: float,
-    spatial_stop_frequency: float,
-    temporal_stop_frequency: float,
-) -> torch.Tensor:
-    r"""Returns the FreeInit filter based on filter type and other input conditions."""
-
-    time, height, width = shape[-3], shape[-2], shape[-1]
-    mask = torch.zeros(shape)
-
-    if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
-        return mask
-
-    if filter_type == "butterworth":
-
-        def retrieve_mask(x):
-            return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
-
-    elif filter_type == "gaussian":
-
-        def retrieve_mask(x):
-            return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
-
-    elif filter_type == "ideal":
-
-        def retrieve_mask(x):
-            return 1 if x <= spatial_stop_frequency * 2 else 0
-
-    else:
-        raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
-
-    for t in range(time):
-        for h in range(height):
-            for w in range(width):
-                d_square = (
-                    ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
-                    + (2 * h / height - 1) ** 2
-                    + (2 * w / width - 1) ** 2
-                )
-                mask[..., t, h, w] = retrieve_mask(d_square)
-
-    return mask.to(device)
-
-
-def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
-    r"""Noise reinitialization."""
-    # FFT
-    x_freq = fft.fftn(x, dim=(-3, -2, -1))
-    x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
-    noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
-    noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
-
-    # frequency mix
-    HPF = 1 - LPF
-    x_freq_low = x_freq * LPF
-    noise_freq_high = noise_freq * HPF
-    x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain
-
-    # IFFT
-    x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
-    x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
-
-    return x_mixed
-
-
 @dataclass
 class PIAPipelineOutput(BaseOutput):
     r"""
     Output class for PIAPipeline.

     Args:
         frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
-        Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
-        NumPy array of shape `(batch_size, num_frames, channels, height, width,
-        Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
+            Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
+            NumPy array of shape `(batch_size, num_frames, channels, height, width,
+            Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
     """

     frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]

@@ -788,7 +721,8 @@ def __call__(
             The input image to be used for video generation.
         prompt (`str` or `List[str]`, *optional*):
             The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
-        strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1.
+        strength (`float`, *optional*, defaults to 1.0):
+            Indicates extent to transform the reference `image`. Must be between 0 and 1.
         height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
             The height in pixels of the generated video.
         width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):

@@ -979,8 +913,10 @@ def __call__(
                 latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
             )

+            self._num_timesteps = len(timesteps)
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
+
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
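PIA can drop its private `_get_freeinit_freq_filter` and `_freq_mix_3d` helpers because the same logic now lives in `FreeInitMixin` (`free_init_utils.py`), which the pipeline reaches through `self._apply_free_init`. A condensed sketch of the outer loop these hunks are wired into (names taken from the diffs above; control flow abbreviated, so treat it as an approximation rather than the exact pipeline code):

```python
# Condensed sketch of a FreeInit-enabled denoising driver.
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1

for free_init_iter in range(num_free_init_iters):
    if self.free_init_enabled:
        # Re-noise the previous result, keep its low frequencies, and mix in
        # fresh high-frequency noise; may also shorten the schedule when
        # fast sampling is enabled.
        latents, timesteps = self._apply_free_init(
            latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
        )

    self._num_timesteps = len(timesteps)
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    with self.progress_bar(total=self._num_timesteps) as progress_bar:
        for i, t in enumerate(timesteps):
            ...  # standard denoising step
            progress_bar.update()
```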
