Skip to content

Commit 61b513b

Browse files
committed
Signed-off-by: YunLiu <[email protected]>
1 parent b2d1c0a commit 61b513b

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

generation/maisi/scripts/diff_model_train.py

+16 -6
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,7 @@ def train_one_epoch(
193193
device: torch.device,
194194
logger: logging.Logger,
195195
local_rank: int,
196+
amp: bool = True,
196197
) -> torch.Tensor:
197198
"""
198199
Train the model for one epoch.
@@ -212,6 +213,7 @@ def train_one_epoch(
212213
device (torch.device): Device to use for training.
213214
logger (logging.Logger): Logger for logging information.
214215
local_rank (int): Local rank for distributed training.
216+
amp (bool): Use automatic mixed precision training.
215217
216218
Returns:
217219
torch.Tensor: Training loss for the epoch.
@@ -237,7 +239,7 @@ def train_one_epoch(
237239

238240
optimizer.zero_grad(set_to_none=True)
239241

240-
with autocast("cuda", enabled=True):
242+
with autocast("cuda", enabled=amp):
241243
noise = torch.randn(
242244
(num_images_per_batch, 4, images.size(-3), images.size(-2), images.size(-1)), device=device
243245
)
@@ -256,9 +258,13 @@ def train_one_epoch(
256258

257259
loss = loss_pt(noise_pred.float(), noise.float())
258260

259-
scaler.scale(loss).backward()
260-
scaler.step(optimizer)
261-
scaler.update()
261+
if amp:
262+
scaler.scale(loss).backward()
263+
scaler.step(optimizer)
264+
scaler.update()
265+
else:
266+
loss.backward()
267+
optimizer.step()
262268

263269
lr_scheduler.step()
264270

@@ -312,14 +318,16 @@ def save_checkpoint(
312318
)
313319

314320

315-
def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
321+
def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True) -> None:
316322
"""
317323
Main function to train a diffusion model.
318324
319325
Args:
320326
env_config_path (str): Path to the environment configuration file.
321327
model_config_path (str): Path to the model configuration file.
322328
model_def_path (str): Path to the model definition file.
329+
num_gpus (int): Number of GPUs to use for training.
330+
amp (bool): Use automatic mixed precision training.
323331
"""
324332
args = load_config(env_config_path, model_config_path, model_def_path)
325333
local_rank, world_size, device = initialize_distributed(num_gpus)
@@ -392,6 +400,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
392400
device,
393401
logger,
394402
local_rank,
403+
amp=amp,
395404
)
396405

397406
loss_torch = loss_torch.tolist()
@@ -431,6 +440,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
431440
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
432441
)
433442
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
443+
parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")
434444

435445
args = parser.parse_args()
436-
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)
446+
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)

0 commit comments

Comments
 (0)