From d173113030e741bd1e6325ed69416c5ce7ee8bde Mon Sep 17 00:00:00 2001
From: Yunzhou Liu
Date: Thu, 3 Apr 2025 08:58:08 -0700
Subject: [PATCH 1/3] further removes one token count computation and updates all MLLOG tags

---
 .../nemo/callbacks.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/large_language_model_pretraining/nemo/callbacks.py b/large_language_model_pretraining/nemo/callbacks.py
index ba813fe2a..efd84ba80 100644
--- a/large_language_model_pretraining/nemo/callbacks.py
+++ b/large_language_model_pretraining/nemo/callbacks.py
@@ -124,11 +124,11 @@ def log_metrics(self, metrics, step):
         assert step_time <= self.train_step_time_atol, f"Logged train step time ({step_time}) is slower than tolerable ({self.train_step_time_atol}). "
 
     def log_validation_loss(self, metrics, step):
-        consumed_tokens = (step - self.init_global_step) * self.gbs * self.seq_len
+        consumed_tokens = step * self.gbs
 
         loss = metrics[self.val_loss_key]
 
-        mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={'epoch_num': consumed_tokens})
+        mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={constants.SAMPLES_COUNT: consumed_tokens})
 
         if not self.is_target_reached and loss <= self.target:
             self.is_target_reached = True
@@ -183,14 +183,14 @@ def set_success_status(self):
 
     @rank_zero_only
     def on_train_epoch_start(self, trainer, pl_module):
-        mllogger.start(key=constants.EPOCH_START, metadata={'epoch_num': self.consumed_tokens(trainer)})
-        mllogger.start(key=constants.BLOCK_START, metadata={"epoch_num": self.consumed_tokens(trainer)})
+        mllogger.start(key=constants.EPOCH_START, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
+        mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
 
         return super().on_train_epoch_start(trainer, pl_module)
 
     @rank_zero_only
     def on_train_epoch_end(self, trainer, pl_module):
-        mllogger.end(key=constants.EPOCH_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.EPOCH_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
         return super().on_train_epoch_end(trainer, pl_module)
 
     def on_train_end(self, trainer, pl_module):
@@ -202,12 +202,12 @@ def on_train_end(self, trainer, pl_module):
 
     @rank_zero_only
     def on_validation_start(self, trainer, pl_module):
-        mllogger.end(key=constants.BLOCK_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
-        mllogger.start(key=constants.EVAL_START, metadata={'epoch_num': self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.BLOCK_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
+        mllogger.start(key=constants.EVAL_START, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
 
         return super().on_validation_start(trainer, pl_module)
 
     def on_validation_end(self, trainer, pl_module):
-        mllogger.end(key=constants.EVAL_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.EVAL_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
         for logger in trainer.loggers:
             if isinstance(logger, MetricsLogger):
@@ -216,7 +216,7 @@ def on_validation_end(self, trainer, pl_module):
         self.set_success_status()
 
         if not trainer.should_stop:
-            mllogger.start(key=constants.BLOCK_START, metadata={"epoch_num": self.consumed_tokens(trainer)})
+            mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
 
         return super().on_validation_end(trainer, pl_module)
 
@@ -234,4 +234,4 @@ def on_train_start(self, trainer, pl_module):
             mllogger.event(key=key, value=value)
 
         mllogger.end(key=constants.INIT_STOP)
-        mllogger.start(key=constants.RUN_START)
\ No newline at end of file
+        mllogger.start(key=constants.RUN_START)

From 795a7617e1e26c0c0de300e434727f7998a5f9a4 Mon Sep 17 00:00:00 2001
From: Yunzhou Liu
Date: Fri, 4 Apr 2025 12:44:44 -0700
Subject: [PATCH 2/3] updates the function names and train_samples

---
 .../nemo/callbacks.py | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/large_language_model_pretraining/nemo/callbacks.py b/large_language_model_pretraining/nemo/callbacks.py
index efd84ba80..7d4ad2439 100644
--- a/large_language_model_pretraining/nemo/callbacks.py
+++ b/large_language_model_pretraining/nemo/callbacks.py
@@ -124,11 +124,11 @@ def log_metrics(self, metrics, step):
         assert step_time <= self.train_step_time_atol, f"Logged train step time ({step_time}) is slower than tolerable ({self.train_step_time_atol}). "
 
     def log_validation_loss(self, metrics, step):
-        consumed_tokens = step * self.gbs
+        consumed_samples = step * self.gbs
 
         loss = metrics[self.val_loss_key]
 
-        mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={constants.SAMPLES_COUNT: consumed_tokens})
+        mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={constants.SAMPLES_COUNT: consumed_samples})
 
         if not self.is_target_reached and loss <= self.target:
             self.is_target_reached = True
@@ -146,7 +146,7 @@ def version(self):
         return 1
 
 ### MLPerf callbacks
-def compute_consumed_mllog_tokens(trainer, init_global_step, global_batch_size, seq_length):
+def compute_consumed_mllog_samples(trainer, init_global_step, global_batch_size, seq_length):
     consumed_samples = (
         trainer.global_step * global_batch_size
     )
@@ -174,8 +174,8 @@ def __init__(
         self.status = constants.ABORTED
         self.configs = configs
 
-    def consumed_tokens(self, trainer):
-        return compute_consumed_mllog_tokens(trainer, self.init_global_step, self.gbs, self.seq_len)
+    def consumed_samples(self, trainer):
+        return compute_consumed_mllog_samples(trainer, self.init_global_step, self.gbs, self.seq_len)
 
     def set_success_status(self):
         self.status = constants.SUCCESS
@@ -183,31 +183,31 @@ def set_success_status(self):
 
     @rank_zero_only
     def on_train_epoch_start(self, trainer, pl_module):
-        mllogger.start(key=constants.EPOCH_START, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
-        mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
+        mllogger.start(key=constants.EPOCH_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
+        mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
 
         return super().on_train_epoch_start(trainer, pl_module)
 
     @rank_zero_only
     def on_train_epoch_end(self, trainer, pl_module):
-        mllogger.end(key=constants.EPOCH_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.EPOCH_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
         return super().on_train_epoch_end(trainer, pl_module)
 
     def on_train_end(self, trainer, pl_module):
         # for every occurrences, run on all ranks to allow sync
         barrier()
         mllogger.end(key=constants.RUN_STOP, metadata={"status": self.status})
-        mllogger.event(key="trained_samples", value=self.consumed_tokens(trainer))
+        mllogger.event(key="train_samples", value=self.consumed_samples(trainer))
         return super().on_train_end(trainer, pl_module)
 
     @rank_zero_only
     def on_validation_start(self, trainer, pl_module):
-        mllogger.end(key=constants.BLOCK_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
-        mllogger.start(key=constants.EVAL_START, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.BLOCK_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
+        mllogger.start(key=constants.EVAL_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
 
         return super().on_validation_start(trainer, pl_module)
 
     def on_validation_end(self, trainer, pl_module):
-        mllogger.end(key=constants.EVAL_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.EVAL_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
         for logger in trainer.loggers:
             if isinstance(logger, MetricsLogger):
@@ -216,7 +216,7 @@ def on_validation_end(self, trainer, pl_module):
         self.set_success_status()
 
         if not trainer.should_stop:
-            mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_tokens(trainer)})
+            mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
 
         return super().on_validation_end(trainer, pl_module)

From 29df3fc5395c01b90b9162d182c80be2fd31851e Mon Sep 17 00:00:00 2001
From: Yunzhou Liu
Date: Fri, 4 Apr 2025 12:45:21 -0700
Subject: [PATCH 3/3] updates the decay schedule

---
 large_language_model_pretraining/nemo/pretrain_llama31.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/large_language_model_pretraining/nemo/pretrain_llama31.py b/large_language_model_pretraining/nemo/pretrain_llama31.py
index 4defa6d5e..65df8a8a9 100644
--- a/large_language_model_pretraining/nemo/pretrain_llama31.py
+++ b/large_language_model_pretraining/nemo/pretrain_llama31.py
@@ -396,7 +396,7 @@ def get_parser() -> argparse.ArgumentParser:
         constants.OPT_END_LR: pretrain.optim.lr_scheduler.min_lr,
         constants.OPT_LR_WARMUP_STEPS: pretrain.optim.lr_scheduler.warmup_steps,
         constants.OPT_LR_DECAY_STEPS: pretrain.trainer.max_steps - pretrain.optim.lr_scheduler.warmup_steps,
-        constants.OPT_LR_DECAY_SCHEDULE: "cosine with linear warmups",
+        constants.OPT_LR_DECAY_SCHEDULE: "cosine with linear warmup",
     }
 
     # Override config for MLPerf
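Note on the samples accounting these patches converge on: consumed samples are derived purely from the trainer's global step and the global batch size (sequence length no longer enters the count), and that value is logged under the SAMPLES_COUNT metadata key. A minimal standalone sketch of the arithmetic, with illustrative names and numbers only (not the repo's or NeMo's API):

# Illustrative sketch only; mirrors the computation the patches settle on.
def compute_consumed_samples(global_step: int, global_batch_size: int) -> int:
    # One optimizer step consumes one global batch, so samples = steps * GBS.
    return int(global_step * global_batch_size)

# Example with hypothetical values: after 100 optimizer steps at a global
# batch size of 256, the value logged as SAMPLES_COUNT would be 25600.
assert compute_consumed_samples(100, 256) == 25600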