Commit ece3d15

[Llama 3.1] Updates MLLOG tags (#790)

* Further removes one token-count computation and updates all MLLOG tags
* Updates the function names and the train_samples key
* Updates the decay schedule string

1 parent 637c82f · commit ece3d15

File tree

2 files changed (+15, -15 lines)

large_language_model_pretraining/nemo/callbacks.py

Lines changed: 14 additions & 14 deletions
@@ -124,11 +124,11 @@ def log_metrics(self, metrics, step):
         assert step_time <= self.train_step_time_atol, f"Logged train step time ({step_time}) is slower than tolerable ({self.train_step_time_atol}). "
 
     def log_validation_loss(self, metrics, step):
-        consumed_tokens = (step - self.init_global_step) * self.gbs * self.seq_len
+        consumed_samples = step * self.gbs
 
         loss = metrics[self.val_loss_key]
 
-        mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={'epoch_num': consumed_tokens})
+        mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={constants.SAMPLES_COUNT: consumed_samples})
 
         if not self.is_target_reached and loss <= self.target:
             self.is_target_reached = True
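This hunk changes what log_validation_loss reports: a sample count replaces a token count, and the offset by the run's initial step drops out. A minimal sketch of the arithmetic, with illustrative numbers that are not from the commit:

    # Illustrative values only (hypothetical, not from the repository).
    gbs = 288             # global batch size
    seq_len = 8192        # sequence length
    init_global_step = 10
    step = 100

    # Before this commit: tokens consumed since the run's initial step.
    consumed_tokens = (step - init_global_step) * gbs * seq_len
    print(consumed_tokens)   # 212336640

    # After this commit: samples consumed from step 0; neither seq_len
    # nor the initial-step offset enters the computation.
    consumed_samples = step * gbs
    print(consumed_samples)  # 28800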
@@ -146,7 +146,7 @@ def version(self):
         return 1
 
 ### MLPerf callbacks
-def compute_consumed_mllog_tokens(trainer, init_global_step, global_batch_size, seq_length):
+def compute_consumed_mllog_samples(trainer, init_global_step, global_batch_size, seq_length):
     consumed_samples = (
         trainer.global_step * global_batch_size
     )
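After the rename, the helper derives its result from the global step and global batch size alone; init_global_step and seq_length remain in the signature but, per the hunk above, no longer enter the computation. A sketch with a stub trainer (the stub, the numbers, and the assumed return are hypothetical):

    # Stub standing in for a PyTorch Lightning Trainer (hypothetical).
    class StubTrainer:
        global_step = 100

    def compute_consumed_mllog_samples(trainer, init_global_step, global_batch_size, seq_length):
        # init_global_step and seq_length are unused after this commit.
        consumed_samples = trainer.global_step * global_batch_size
        return consumed_samples  # assumed return; the hunk cuts off here

    print(compute_consumed_mllog_samples(StubTrainer(), 10, 288, 8192))  # 28800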
@@ -174,40 +174,40 @@ def __init__(
         self.status = constants.ABORTED
         self.configs = configs
 
-    def consumed_tokens(self, trainer):
-        return compute_consumed_mllog_tokens(trainer, self.init_global_step, self.gbs, self.seq_len)
+    def consumed_samples(self, trainer):
+        return compute_consumed_mllog_samples(trainer, self.init_global_step, self.gbs, self.seq_len)
 
     def set_success_status(self):
         self.status = constants.SUCCESS
         self.is_target_reached = True
 
     @rank_zero_only
     def on_train_epoch_start(self, trainer, pl_module):
-        mllogger.start(key=constants.EPOCH_START, metadata={'epoch_num': self.consumed_tokens(trainer)})
-        mllogger.start(key=constants.BLOCK_START, metadata={"epoch_num": self.consumed_tokens(trainer)})
+        mllogger.start(key=constants.EPOCH_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
+        mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
 
         return super().on_train_epoch_start(trainer, pl_module)
 
     @rank_zero_only
     def on_train_epoch_end(self, trainer, pl_module):
-        mllogger.end(key=constants.EPOCH_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.EPOCH_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
         return super().on_train_epoch_end(trainer, pl_module)
 
     def on_train_end(self, trainer, pl_module):
         # for every occurrences, run on all ranks to allow sync
         barrier()
         mllogger.end(key=constants.RUN_STOP, metadata={"status": self.status})
-        mllogger.event(key="trained_samples", value=self.consumed_tokens(trainer))
+        mllogger.event(key="train_samples", value=self.consumed_samples(trainer))
         return super().on_train_end(trainer, pl_module)
 
     @rank_zero_only
     def on_validation_start(self, trainer, pl_module):
-        mllogger.end(key=constants.BLOCK_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
-        mllogger.start(key=constants.EVAL_START, metadata={'epoch_num': self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.BLOCK_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
+        mllogger.start(key=constants.EVAL_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
         return super().on_validation_start(trainer, pl_module)
 
     def on_validation_end(self, trainer, pl_module):
-        mllogger.end(key=constants.EVAL_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.EVAL_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
 
         for logger in trainer.loggers:
             if isinstance(logger, MetricsLogger):
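For readers unfamiliar with MLPerf logging: the calls above are interval events (start/end) and point-in-time events (event) from the mlperf_logging package, and this commit moves their metadata key from the literal 'epoch_num' to constants.SAMPLES_COUNT. A minimal sketch, assuming mlperf_logging is installed; the sample count and loss value are placeholders:

    from mlperf_logging import mllog
    from mlperf_logging.mllog import constants

    mllogger = mllog.get_mllogger()
    consumed_samples = 28_800  # placeholder

    # Interval events bracket a span of training; the sample count rides
    # along in the metadata under constants.SAMPLES_COUNT.
    mllogger.start(key=constants.BLOCK_START,
                   metadata={constants.SAMPLES_COUNT: consumed_samples})
    mllogger.end(key=constants.BLOCK_STOP,
                 metadata={constants.SAMPLES_COUNT: consumed_samples})

    # Point-in-time events record a single value, e.g. the validation loss.
    mllogger.event(key=constants.EVAL_ACCURACY, value=0.25,
                   metadata={constants.SAMPLES_COUNT: consumed_samples})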
@@ -216,7 +216,7 @@ def on_validation_end(self, trainer, pl_module):
                 self.set_success_status()
 
         if not trainer.should_stop:
-            mllogger.start(key=constants.BLOCK_START, metadata={"epoch_num": self.consumed_tokens(trainer)})
+            mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
 
         return super().on_validation_end(trainer, pl_module)

@@ -234,4 +234,4 @@ def on_train_start(self, trainer, pl_module):
             mllogger.event(key=key, value=value)
 
         mllogger.end(key=constants.INIT_STOP)
-        mllogger.start(key=constants.RUN_START)
+        mllogger.start(key=constants.RUN_START)

large_language_model_pretraining/nemo/pretrain_llama31.py

Lines changed: 1 addition & 1 deletion
@@ -396,7 +396,7 @@ def get_parser() -> argparse.ArgumentParser:
         constants.OPT_END_LR: pretrain.optim.lr_scheduler.min_lr,
         constants.OPT_LR_WARMUP_STEPS: pretrain.optim.lr_scheduler.warmup_steps,
         constants.OPT_LR_DECAY_STEPS: pretrain.trainer.max_steps - pretrain.optim.lr_scheduler.warmup_steps,
-        constants.OPT_LR_DECAY_SCHEDULE: "cosine with linear warmups",
+        constants.OPT_LR_DECAY_SCHEDULE: "cosine with linear warmup",
     }
 
     # Override config for MLPerf
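The corrected string names NeMo's warmup-then-cosine schedule, whose decay span matches the OPT_LR_DECAY_STEPS value computed above as max_steps - warmup_steps. A generic sketch of such a schedule (the function and its parameters are illustrative, not NeMo's implementation):

    import math

    def lr_at(step, max_lr, min_lr, warmup_steps, max_steps):
        """Linear warmup to max_lr, then cosine decay to min_lr (illustrative)."""
        if step < warmup_steps:
            return max_lr * step / warmup_steps
        progress = (step - warmup_steps) / (max_steps - warmup_steps)
        return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

    print(lr_at(0,    3e-4, 3e-6, 100, 1000))  # 0.0 (warmup start)
    print(lr_at(100,  3e-4, 3e-6, 100, 1000))  # 3e-4 (peak after warmup)
    print(lr_at(1000, 3e-4, 3e-6, 100, 1000))  # 3e-06 (decayed to min_lr)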
