@@ -124,11 +124,11 @@ def log_metrics(self, metrics, step):
            assert step_time <= self.train_step_time_atol, f"Logged train step time ({step_time}) is slower than tolerable ({self.train_step_time_atol})."

    def log_validation_loss(self, metrics, step):
-        consumed_tokens = (step - self.init_global_step) * self.gbs * self.seq_len
+        consumed_samples = step * self.gbs

        loss = metrics[self.val_loss_key]

-        mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={'epoch_num': consumed_tokens})
+        mllogger.event(key=constants.EVAL_ACCURACY, value=loss, metadata={constants.SAMPLES_COUNT: consumed_samples})

        if not self.is_target_reached and loss <= self.target:
            self.is_target_reached = True
@@ -146,7 +146,7 @@ def version(self):
        return 1

### MLPerf callbacks
-def compute_consumed_mllog_tokens(trainer, init_global_step, global_batch_size, seq_length):
+def compute_consumed_mllog_samples(trainer, init_global_step, global_batch_size, seq_length):
    consumed_samples = (
        trainer.global_step * global_batch_size
    )
@@ -174,40 +174,40 @@ def __init__(
        self.status = constants.ABORTED
        self.configs = configs

-    def consumed_tokens(self, trainer):
-        return compute_consumed_mllog_tokens(trainer, self.init_global_step, self.gbs, self.seq_len)
+    def consumed_samples(self, trainer):
+        return compute_consumed_mllog_samples(trainer, self.init_global_step, self.gbs, self.seq_len)

    def set_success_status(self):
        self.status = constants.SUCCESS
        self.is_target_reached = True

    @rank_zero_only
    def on_train_epoch_start(self, trainer, pl_module):
-        mllogger.start(key=constants.EPOCH_START, metadata={'epoch_num': self.consumed_tokens(trainer)})
-        mllogger.start(key=constants.BLOCK_START, metadata={"epoch_num": self.consumed_tokens(trainer)})
+        mllogger.start(key=constants.EPOCH_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
+        mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})

        return super().on_train_epoch_start(trainer, pl_module)

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
-        mllogger.end(key=constants.EPOCH_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.EPOCH_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
        return super().on_train_epoch_end(trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        # for every occurrence, run on all ranks to allow sync
        barrier()
        mllogger.end(key=constants.RUN_STOP, metadata={"status": self.status})
-        mllogger.event(key="trained_samples", value=self.consumed_tokens(trainer))
+        mllogger.event(key="train_samples", value=self.consumed_samples(trainer))
        return super().on_train_end(trainer, pl_module)

    @rank_zero_only
    def on_validation_start(self, trainer, pl_module):
-        mllogger.end(key=constants.BLOCK_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
-        mllogger.start(key=constants.EVAL_START, metadata={'epoch_num': self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.BLOCK_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
+        mllogger.start(key=constants.EVAL_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})
        return super().on_validation_start(trainer, pl_module)

    def on_validation_end(self, trainer, pl_module):
-        mllogger.end(key=constants.EVAL_STOP, metadata={'epoch_num': self.consumed_tokens(trainer)})
+        mllogger.end(key=constants.EVAL_STOP, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})

        for logger in trainer.loggers:
            if isinstance(logger, MetricsLogger):
@@ -216,7 +216,7 @@ def on_validation_end(self, trainer, pl_module):
                    self.set_success_status()

        if not trainer.should_stop:
-            mllogger.start(key=constants.BLOCK_START, metadata={"epoch_num": self.consumed_tokens(trainer)})
+            mllogger.start(key=constants.BLOCK_START, metadata={constants.SAMPLES_COUNT: self.consumed_samples(trainer)})

        return super().on_validation_end(trainer, pl_module)

@@ -234,4 +234,4 @@ def on_train_start(self, trainer, pl_module):
            mllogger.event(key=key, value=value)

        mllogger.end(key=constants.INIT_STOP)
-        mllogger.start(key=constants.RUN_START)
+        mllogger.start(key=constants.RUN_START)
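Note on the change above: validation and block events are now tagged with constants.SAMPLES_COUNT (samples consumed so far, i.e. step * global_batch_size) instead of an 'epoch_num' derived from consumed tokens. Below is a minimal, standalone sketch of how this sample-based accounting maps onto mlperf-logging calls; it assumes the mlperf_logging package is installed, uses illustrative numbers, and is not the repository's own wiring.

# Standalone sketch, not repository code; concrete values are illustrative.
from mlperf_logging import mllog
from mlperf_logging.mllog import constants

mllogger = mllog.get_mllogger()

global_batch_size = 256   # corresponds to self.gbs in the diff
global_step = 100         # the `step` passed to log_validation_loss
val_loss = 2.77           # metrics[self.val_loss_key]

# Sample-based accounting: sequence length no longer enters the count.
consumed_samples = global_step * global_batch_size

# Validation loss event keyed by samples_count rather than epoch_num.
mllogger.event(
    key=constants.EVAL_ACCURACY,
    value=val_loss,
    metadata={constants.SAMPLES_COUNT: consumed_samples},
)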