@@ -193,6 +193,7 @@ def train_one_epoch(
     device: torch.device,
     logger: logging.Logger,
     local_rank: int,
+    amp: bool = True,
 ) -> torch.Tensor:
     """
     Train the model for one epoch.
@@ -212,6 +213,7 @@ def train_one_epoch(
         device (torch.device): Device to use for training.
         logger (logging.Logger): Logger for logging information.
         local_rank (int): Local rank for distributed training.
+        amp (bool): Use automatic mixed precision training.

     Returns:
         torch.Tensor: Training loss for the epoch.
@@ -237,7 +239,7 @@ def train_one_epoch(

         optimizer.zero_grad(set_to_none=True)

-        with autocast("cuda", enabled=True):
+        with autocast("cuda", enabled=amp):
             noise = torch.randn(
                 (num_images_per_batch, 4, images.size(-3), images.size(-2), images.size(-1)), device=device
             )
@@ -256,9 +258,13 @@ def train_one_epoch(

             loss = loss_pt(noise_pred.float(), noise.float())

-        scaler.scale(loss).backward()
-        scaler.step(optimizer)
-        scaler.update()
+        if amp:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()

         lr_scheduler.step()

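For context, the hunk above is the standard way to make AMP optional in a PyTorch training step: autocast controls only the precision of the forward pass, while GradScaler protects the float16 backward pass from gradient underflow, so the two must be switched off together. Below is a minimal, self-contained sketch of the same pattern, not the MAISI script itself: it assumes PyTorch 2.3+ (torch.amp.GradScaler with a device argument) and a CUDA device, and it substitutes a placeholder linear model and MSE loss for the diffusion noise-prediction objective.

import torch
from torch.amp import GradScaler, autocast

def train_step(model, optimizer, scaler, inputs, target, amp=True):
    # Mirrors the hunk above: autocast gates forward-pass precision,
    # and the scaler is only exercised when AMP is enabled.
    optimizer.zero_grad(set_to_none=True)
    with autocast("cuda", enabled=amp):
        loss = torch.nn.functional.mse_loss(model(inputs), target)
    if amp:
        scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
        scaler.step(optimizer)  # unscales grads; skips the step if inf/nan appear
        scaler.update()  # adapts the scale factor for the next iteration
    else:
        loss.backward()  # plain full-precision path
        optimizer.step()
    return loss.detach()

device = torch.device("cuda")
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler("cuda")
x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)
print(train_step(model, optimizer, scaler, x, y, amp=True).item())

Note that, as in the PR, the scaler can be constructed unconditionally; when amp is False the branch simply never touches it.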
@@ -312,14 +318,16 @@ def save_checkpoint(
     )


-def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
+def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True) -> None:
     """
     Main function to train a diffusion model.

     Args:
         env_config_path (str): Path to the environment configuration file.
         model_config_path (str): Path to the model configuration file.
         model_def_path (str): Path to the model definition file.
+        num_gpus (int): Number of GPUs to use for training.
+        amp (bool): Use automatic mixed precision training.
     """
     args = load_config(env_config_path, model_config_path, model_def_path)
     local_rank, world_size, device = initialize_distributed(num_gpus)
@@ -392,6 +400,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
             device,
             logger,
             local_rank,
+            amp=amp,
         )

     loss_torch = loss_torch.tolist()
@@ -431,6 +440,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
         "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
     )
     parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
+    parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")

     args = parser.parse_args()
-    diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)
+    diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)
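The new --no_amp flag uses store_false with dest="amp", so AMP stays the opt-out default: parsing no flags yields args.amp == True, and passing --no_amp flips it to False. A standalone check of that semantics (argparse only, independent of the training script):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")

print(parser.parse_args([]).amp)  # True: store_false implies default=True
print(parser.parse_args(["--no_amp"]).amp)  # False: the flag flips its dest, not its own name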