 # limitations under the License.

 import inspect
-import math
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
 import PIL
 import torch
-import torch.fft as fft
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor
@@ -130,81 +128,16 @@ def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_sca
     return coef


-def _get_freeinit_freq_filter(
-    shape: Tuple[int, ...],
-    device: Union[str, torch.dtype],
-    filter_type: str,
-    order: float,
-    spatial_stop_frequency: float,
-    temporal_stop_frequency: float,
-) -> torch.Tensor:
-    r"""Returns the FreeInit filter based on filter type and other input conditions."""
-
-    time, height, width = shape[-3], shape[-2], shape[-1]
-    mask = torch.zeros(shape)
-
-    if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
-        return mask
-
-    if filter_type == "butterworth":
-
-        def retrieve_mask(x):
-            return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
-    elif filter_type == "gaussian":
-
-        def retrieve_mask(x):
-            return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
-    elif filter_type == "ideal":
-
-        def retrieve_mask(x):
-            return 1 if x <= spatial_stop_frequency * 2 else 0
-    else:
-        raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
-
-    for t in range(time):
-        for h in range(height):
-            for w in range(width):
-                d_square = (
-                    ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
-                    + (2 * h / height - 1) ** 2
-                    + (2 * w / width - 1) ** 2
-                )
-                mask[..., t, h, w] = retrieve_mask(d_square)
-
-    return mask.to(device)
-
-
-def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
-    r"""Noise reinitialization."""
-    # FFT
-    x_freq = fft.fftn(x, dim=(-3, -2, -1))
-    x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
-    noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
-    noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
-
-    # frequency mix
-    HPF = 1 - LPF
-    x_freq_low = x_freq * LPF
-    noise_freq_high = noise_freq * HPF
-    x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain
-
-    # IFFT
-    x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
-    x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
-
-    return x_mixed
-
-
 @dataclass
 class PIAPipelineOutput(BaseOutput):
     r"""
     Output class for PIAPipeline.

     Args:
         frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
-            Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
-            NumPy array of shape `(batch_size, num_frames, channels, height, width,
-            Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
+            Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
+            NumPy array of shape `(batch_size, num_frames, channels, height, width,
+            Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
     """

     frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
@@ -788,7 +721,8 @@ def __call__(
                 The input image to be used for video generation.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
-            strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated video.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -979,8 +913,10 @@ def __call__(
                     latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
                 )

+            self._num_timesteps = len(timesteps)
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
+
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
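
Two notes for readers of this diff. First, the new `self._num_timesteps = len(timesteps)` bookkeeping in the last hunk sizes the progress bar by the actual number of timesteps, which can be smaller than `num_inference_steps` when `strength` truncates the schedule. Second, the two deleted module-level helpers implemented FreeInit-style noise reinitialization: `_get_freeinit_freq_filter` builds a 3D low-pass mask over the frame/height/width axes, and `_freq_mix_3d` keeps the low-frequency band of the current latents while taking the high-frequency band from freshly sampled noise. Where that logic lives after this commit is not visible in the diff. The sketch below restates the mixing step as a standalone function so the removed behaviour stays documented; the function name, the tensor shapes, and the crude box-shaped low-pass mask are illustrative choices for this sketch, not part of the pipeline's API.

import torch
import torch.fft as fft


def freeinit_mix(latents: torch.Tensor, noise: torch.Tensor, lpf: torch.Tensor) -> torch.Tensor:
    """Blend low frequencies of `latents` with high frequencies of `noise`, as `_freq_mix_3d` did."""
    dims = (-3, -2, -1)  # frames, height, width
    latents_freq = fft.fftshift(fft.fftn(latents, dim=dims), dim=dims)
    noise_freq = fft.fftshift(fft.fftn(noise, dim=dims), dim=dims)
    mixed_freq = latents_freq * lpf + noise_freq * (1 - lpf)  # low-pass + high-pass mix
    return fft.ifftn(fft.ifftshift(mixed_freq, dim=dims), dim=dims).real


# Illustrative shapes: (batch, channels, frames, height, width) video latents.
latents = torch.randn(1, 4, 16, 64, 64)
noise = torch.randn_like(latents)

# A crude "ideal" low-pass mask keeping only the central quarter of each of the
# last three (frequency-shifted) axes; the deleted helper also built smoother
# gaussian and butterworth variants of this mask.
lpf = torch.zeros_like(latents)
lpf[..., 6:10, 24:40, 24:40] = 1.0

mixed = freeinit_mix(latents, noise, lpf)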