
Commit 44f6b85

sayakpaul and DN6 authored
[Core] refactor transformer_2d forward logic into meaningful conditions. (#7489)
* refactor transformer_2d forward logic into meaningful conditions.
* Empty-Commit
* fix: _operate_on_patched_inputs
* fix: _operate_on_patched_inputs
* check
* fix: patch output computation block.
* fix: _operate_on_patched_inputs.
* remove print.
* move operations to blocks.
* more readability neats.
* empty commit
* Apply suggestions from code review

  Co-authored-by: Dhruv Nair <[email protected]>

* Revert "Apply suggestions from code review"

  This reverts commit 12178b1.

---------

Co-authored-by: Dhruv Nair <[email protected]>
1 parent ac7ff7d commit 44f6b85
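
The public forward signature is untouched by this change; only its body is split into per-input-mode _operate_on_* and _get_output_for_* helpers (see the diff below). As a quick sanity check, here is a minimal sketch that exercises the continuous-input route through the refactored code, assuming a diffusers checkout that contains this commit; the tiny configuration values are made up for illustration and not taken from any real checkpoint.

import torch
from diffusers.models.transformers.transformer_2d import Transformer2DModel

# Hypothetical, tiny config: setting in_channels without patch_size or
# num_vector_embeds selects the continuous-input path, i.e.
# _operate_on_continuous_inputs / _get_output_for_continuous_inputs.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=1,
    norm_num_groups=2,  # must divide in_channels for the GroupNorm in self.norm
)

sample = torch.randn(1, 4, 8, 8)  # (batch, channels, height, width)
with torch.no_grad():
    out = model(sample)

print(out.sample.shape)  # torch.Size([1, 4, 8, 8]); the residual connection preserves the shape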

File tree

1 file changed (+110 -68)


src/diffusers/models/transformers/transformer_2d.py

+110 -68
@@ -402,41 +402,18 @@ def forward(
 
         # 1. Input
         if self.is_input_continuous:
-            batch, _, height, width = hidden_states.shape
+            batch_size, _, height, width = hidden_states.shape
             residual = hidden_states
-
-            hidden_states = self.norm(hidden_states)
-            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            else:
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
-
+            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
             height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states = self.pos_embed(hidden_states)
-
-            if self.adaln_single is not None:
-                if self.use_additional_conditions and added_cond_kwargs is None:
-                    raise ValueError(
-                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                    )
-                batch_size = hidden_states.shape[0]
-                timestep, embedded_timestep = self.adaln_single(
-                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-                )
+            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+            )
 
         # 2. Blocks
-        if self.is_input_patches and self.caption_projection is not None:
-            batch_size = hidden_states.shape[0]
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
 
@@ -474,51 +451,116 @@ def custom_forward(*inputs):
 
         # 3. Output
         if self.is_input_continuous:
-            if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states)
-            else:
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
-            output = hidden_states + residual
+            output = self._get_output_for_continuous_inputs(
+                hidden_states=hidden_states,
+                residual=residual,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                inner_dim=inner_dim,
+            )
         elif self.is_input_vectorized:
-            hidden_states = self.norm_out(hidden_states)
-            logits = self.out(hidden_states)
-            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
-            logits = logits.permute(0, 2, 1)
+            output = self._get_output_for_vectorized_inputs(hidden_states)
+        elif self.is_input_patches:
+            output = self._get_output_for_patched_inputs(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                class_labels=class_labels,
+                embedded_timestep=embedded_timestep,
+                height=height,
+                width=width,
+            )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
+
+    def _operate_on_continuous_inputs(self, hidden_states):
+        batch, _, height, width = hidden_states.shape
+        hidden_states = self.norm(hidden_states)
+
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        return hidden_states, inner_dim
 
-            # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
+    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+        batch_size = hidden_states.shape[0]
+        hidden_states = self.pos_embed(hidden_states)
+        embedded_timestep = None
 
-        if self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+        if self.adaln_single is not None:
+            if self.use_additional_conditions and added_cond_kwargs is None:
+                raise ValueError(
+                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                 )
-                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
-                hidden_states = self.proj_out_2(hidden_states)
-            elif self.config.norm_type == "ada_norm_single":
-                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states)
-                # Modulation
-                hidden_states = hidden_states * (1 + scale) + shift
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.squeeze(1)
-
-            # unpatchify
-            if self.adaln_single is None:
-                height = width = int(hidden_states.shape[1] ** 0.5)
-            hidden_states = hidden_states.reshape(
-                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            timestep, embedded_timestep = self.adaln_single(
+                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
             )
-            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
-            output = hidden_states.reshape(
-                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+        if not self.use_linear_projection:
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            )
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             )
 
-        if not return_dict:
-            return (output,)
+        output = hidden_states + residual
+        return output
 
-        return Transformer2DModelOutput(sample=output)
+    def _get_output_for_vectorized_inputs(self, hidden_states):
+        hidden_states = self.norm_out(hidden_states)
+        logits = self.out(hidden_states)
+        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+        # log(p(x_0))
+        output = F.log_softmax(logits.double(), dim=1).float()
+        return output
+
+    def _get_output_for_patched_inputs(
+        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+    ):
+        if self.config.norm_type != "ada_norm_single":
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = self.proj_out_2(hidden_states)
+        elif self.config.norm_type == "ada_norm_single":
+            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states)
+            # Modulation
+            hidden_states = hidden_states * (1 + scale) + shift
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        if self.adaln_single is None:
+            height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+        )
+        return output
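
The patched-input route can be exercised the same way as the continuous-input sketch near the top of this page; the configuration below is a made-up, DiT-style toy, not taken from this commit or any checkpoint. With norm_type="ada_norm_zero" the non-ada_norm_single branch of _get_output_for_patched_inputs runs, so timestep and class_labels must be supplied.

import torch
from diffusers.models.transformers.transformer_2d import Transformer2DModel

# Hypothetical DiT-style config: patch_size selects the patched-input path,
# i.e. _operate_on_patched_inputs / _get_output_for_patched_inputs.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=1,
    sample_size=8,              # 8x8 latent grid -> (8 // patch_size) ** 2 = 16 tokens
    patch_size=2,
    norm_type="ada_norm_zero",
    num_embeds_ada_norm=10,     # number of class labels for the AdaLayerNormZero embedding
)

latents = torch.randn(1, 4, 8, 8)      # (batch, channels, height, width)
timestep = torch.tensor([3])
class_labels = torch.tensor([1])

with torch.no_grad():
    out = model(latents, timestep=timestep, class_labels=class_labels)

print(out.sample.shape)  # torch.Size([1, 4, 8, 8]) after unpatchify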
