@@ -402,41 +402,18 @@ def forward(
 
         # 1. Input
         if self.is_input_continuous:
-            batch, _, height, width = hidden_states.shape
+            batch_size, _, height, width = hidden_states.shape
             residual = hidden_states
-
-            hidden_states = self.norm(hidden_states)
-            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            else:
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
-
+            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
             height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states = self.pos_embed(hidden_states)
-
-            if self.adaln_single is not None:
-                if self.use_additional_conditions and added_cond_kwargs is None:
-                    raise ValueError(
-                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                    )
-                batch_size = hidden_states.shape[0]
-                timestep, embedded_timestep = self.adaln_single(
-                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-                )
+            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+            )
 
         # 2. Blocks
-        if self.is_input_patches and self.caption_projection is not None:
-            batch_size = hidden_states.shape[0]
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
 
@@ -474,51 +451,116 @@ def custom_forward(*inputs):
 
         # 3. Output
         if self.is_input_continuous:
-            if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states)
-            else:
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
-            output = hidden_states + residual
+            output = self._get_output_for_continuous_inputs(
+                hidden_states=hidden_states,
+                residual=residual,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                inner_dim=inner_dim,
+            )
         elif self.is_input_vectorized:
-            hidden_states = self.norm_out(hidden_states)
-            logits = self.out(hidden_states)
-            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
-            logits = logits.permute(0, 2, 1)
+            output = self._get_output_for_vectorized_inputs(hidden_states)
+        elif self.is_input_patches:
+            output = self._get_output_for_patched_inputs(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                class_labels=class_labels,
+                embedded_timestep=embedded_timestep,
+                height=height,
+                width=width,
+            )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
+
+    def _operate_on_continuous_inputs(self, hidden_states):
+        batch, _, height, width = hidden_states.shape
+        hidden_states = self.norm(hidden_states)
+
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        return hidden_states, inner_dim
 
-            # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
+    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+        batch_size = hidden_states.shape[0]
+        hidden_states = self.pos_embed(hidden_states)
+        embedded_timestep = None
 
-        if self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+        if self.adaln_single is not None:
+            if self.use_additional_conditions and added_cond_kwargs is None:
+                raise ValueError(
+                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                 )
-                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
-                hidden_states = self.proj_out_2(hidden_states)
-            elif self.config.norm_type == "ada_norm_single":
-                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states)
-                # Modulation
-                hidden_states = hidden_states * (1 + scale) + shift
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.squeeze(1)
-
-            # unpatchify
-            if self.adaln_single is None:
-                height = width = int(hidden_states.shape[1] ** 0.5)
-            hidden_states = hidden_states.reshape(
-                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            timestep, embedded_timestep = self.adaln_single(
+                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
             )
-            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
-            output = hidden_states.reshape(
-                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+        if not self.use_linear_projection:
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            )
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             )
 
-        if not return_dict:
-            return (output,)
+        output = hidden_states + residual
+        return output
 
-        return Transformer2DModelOutput(sample=output)
+    def _get_output_for_vectorized_inputs(self, hidden_states):
+        hidden_states = self.norm_out(hidden_states)
+        logits = self.out(hidden_states)
+        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+        # log(p(x_0))
+        output = F.log_softmax(logits.double(), dim=1).float()
+        return output
+
+    def _get_output_for_patched_inputs(
+        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+    ):
+        if self.config.norm_type != "ada_norm_single":
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = self.proj_out_2(hidden_states)
+        elif self.config.norm_type == "ada_norm_single":
+            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states)
+            # Modulation
+            hidden_states = hidden_states * (1 + scale) + shift
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        if self.adaln_single is None:
+            height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+        )
+        return output
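
For reference, here is a minimal sketch of how the refactored `forward` is exercised end to end. The toy config below (2 heads, head dim 8, 4 input channels, 2 norm groups) is an illustrative assumption rather than anything taken from this PR; any config with `in_channels` set and no `patch_size` takes the continuous-input branch, i.e. `_operate_on_continuous_inputs` followed by `_get_output_for_continuous_inputs`.

```python
import torch

from diffusers import Transformer2DModel

# Toy config (assumed values, not from this PR), small enough to run on CPU.
# in_channels set + patch_size left unset -> is_input_continuous is True.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=1,
    norm_num_groups=2,  # must evenly divide in_channels for the GroupNorm
)

sample = torch.randn(1, 4, 8, 8)  # (batch, channels, height, width)
with torch.no_grad():
    # forward() routes through _operate_on_continuous_inputs, the transformer
    # blocks, then _get_output_for_continuous_inputs (which adds the residual).
    output = model(sample).sample

print(output.shape)  # torch.Size([1, 4, 8, 8]): the residual path preserves the input shape
```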