
[Llama 3.1] Updates MLLOG tags #790


Merged · 3 commits · Apr 8, 2025
28 changes: 14 additions & 14 deletions large_language_model_pretraining/nemo/callbacks.py
@@ -124,11 +124,11 @@ def log_metrics(self, metrics, step):
assert step_time <= self.train_step_time_atol, f"Logged train step time ({step_time}) is slower than tolerable ({self.train_step_time_atol}). "

def log_validation_loss(self, metrics, step):
- consumed_tokens = (step - self.init_global_step) * self.gbs * self.seq_len
+ consumed_samples = step * self.gbs

loss = metrics[self.val_loss_key]

- mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={'epoch_num': consumed_tokens})
+ mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={constants.SAMPLES_COUNT: consumed_samples})

if not self.is_target_reached and loss <= self.target:
self.is_target_reached = True
@@ -146,7 +146,7 @@ def version(self):
return 1

### MLPerf callbacks
- def compute_consumed_mllog_tokens(trainer, init_global_step, global_batch_size, seq_length):
+ def compute_consumed_mllog_samples(trainer, init_global_step, global_batch_size, seq_length):
consumed_samples = (
trainer.global_step * global_batch_size
)
@@ -174,40 +174,40 @@ def __init__(
self.status = constants.ABORTED
self.configs = configs

- def consumed_tokens(self, trainer):
- return compute_consumed_mllog_tokens(trainer, self.init_global_step, self.gbs, self.seq_len)
+ def consumed_samples(self, trainer):
+ return compute_consumed_mllog_samples(trainer, self.init_global_step, self.gbs, self.seq_len)

def set_success_status(self):
self.status = constants.SUCCESS
self.is_target_reached = True

@rank_zero_only
def on_train_epoch_start(self, trainer, pl_module):
- mllogger.start(key=constants.EPOCH_START, metadata={'epoch_num': self.consumed_tokens(trainer)})
- mllogger.start(key=constants.BLOCK_START, metadata={"epoch_num": self.consumed_tokens(trainer)})
+ mllogger.start(key=constants.EPOCH_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
+ mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})

return super().on_train_epoch_start(trainer, pl_module)

@rank_zero_only
def on_train_epoch_end(self, trainer, pl_module):
- mllogger.end(key=constants.EPOCH_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
+ mllogger.end(key=constants.EPOCH_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
return super().on_train_epoch_end(trainer, pl_module)

def on_train_end(self, trainer, pl_module):
# for every occurrence, run on all ranks to allow sync
barrier()
mllogger.end(key=constants.RUN_STOP, metadata={"status": self.status})
mllogger.event(key="trained_samples", value=self.consumed_tokens(trainer))
mllogger.event(key="train_samples", value=self.consumed_samples(trainer))
return super().on_train_end(trainer, pl_module)

@rank_zero_only
def on_validation_start(self, trainer, pl_module):
- mllogger.end(key=constants.BLOCK_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
- mllogger.start(key=constants.EVAL_START, metadata={'epoch_num': self.consumed_tokens(trainer)})
+ mllogger.end(key=constants.BLOCK_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
+ mllogger.start(key=constants.EVAL_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
return super().on_validation_start(trainer, pl_module)

def on_validation_end(self, trainer, pl_module):
- mllogger.end(key=constants.EVAL_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
+ mllogger.end(key=constants.EVAL_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})

for logger in trainer.loggers:
if isinstance(logger, MetricsLogger):
@@ -216,7 +216,7 @@ def on_validation_end(self, trainer, pl_module):
self.set_success_status()

if not trainer.should_stop:
- mllogger.start(key=constants.BLOCK_START, metadata={"epoch_num": self.consumed_tokens(trainer)})
+ mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})

return super().on_validation_end(trainer, pl_module)

@@ -234,4 +234,4 @@ def on_train_start(self, trainer, pl_module):
mllogger.event(key=key, value=value)

mllogger.end(key=constants.INIT_STOP)
- mllogger.start(key=constants.RUN_START)
+ mllogger.start(key=constants.RUN_START)
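
In short, the callbacks now report progress as a sample count under the standard MLLog SAMPLES_COUNT key instead of a token count under an ad-hoc 'epoch_num' key. Below is a minimal sketch of the resulting logging call, assuming the standard mlperf_logging mllog API that these callbacks wrap; all numeric values are illustrative placeholders, not the reference configuration.

from mlperf_logging import mllog
from mlperf_logging.mllog import constants

mllogger = mllog.get_mllogger()

# Illustrative stand-ins for trainer state (hypothetical values).
global_batch_size = 1152                       # self.gbs in the callback
step = 400                                     # trainer.global_step
consumed_samples = step * global_batch_size    # note: no seq_len factor anymore

# Old tag:  metadata={'epoch_num': consumed_tokens}
# New tag:  metadata keyed by constants.SAMPLES_COUNT
mllogger.event(key=constants.EVAL_ACCURACY,
               value=2.75,                     # validation loss, illustrative
               metadata={constants.SAMPLES_COUNT: consumed_samples})
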
@@ -396,7 +396,7 @@ def get_parser() -> argparse.ArgumentParser:
constants.OPT_END_LR: pretrain.optim.lr_scheduler.min_lr,
constants.OPT_LR_WARMUP_STEPS: pretrain.optim.lr_scheduler.warmup_steps,
constants.OPT_LR_DECAY_STEPS: pretrain.trainer.max_steps - pretrain.optim.lr_scheduler.warmup_steps,
- constants.OPT_LR_DECAY_SCHEDULE: "cosine with linear warmups",
+ constants.OPT_LR_DECAY_SCHEDULE: "cosine with linear warmup",
}

# Override config for MLPerf
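
The hunk above, from a second file in this PR, only corrects the OPT_LR_DECAY_SCHEDULE string ("warmups" to "warmup"). For context, here is a hedged sketch of how these scheduler hyperparameters end up as MLLOG events via the on_train_start loop shown in callbacks.py; the numbers are illustrative, not the reference values.

from mlperf_logging import mllog
from mlperf_logging.mllog import constants

mllogger = mllog.get_mllogger()

# Illustrative hyperparameters; the real ones come from the NeMo recipe
# (pretrain.optim.lr_scheduler and pretrain.trainer).
max_steps, warmup_steps = 10_000, 2_000
configs = {
    constants.OPT_END_LR: 8.0e-7,
    constants.OPT_LR_WARMUP_STEPS: warmup_steps,
    constants.OPT_LR_DECAY_STEPS: max_steps - warmup_steps,
    constants.OPT_LR_DECAY_SCHEDULE: "cosine with linear warmup",
}

for key, value in configs.items():
    mllogger.event(key=key, value=value)
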