@@ -123,9 +123,7 @@ def on_train_begin(self, args, state, control, **kwargs):
123
123
)
124
124
self .mllogger .event (
125
125
key = constants .GLOBAL_BATCH_SIZE ,
126
- value = args .per_device_train_batch_size
127
- * args .gradient_accumulation_steps
128
- * os .getenv ("WORLD_SIZE" , 1 ),
126
+ value = self .gbs ,
129
127
)
130
128
self .mllogger .event (
131
129
key = constants .TRAIN_SAMPLES ,
@@ -168,25 +166,25 @@ def on_step_begin(
168
166
self .mllogger .event (
169
167
"train_loss" ,
170
168
value = state .log_history [- 1 ]["loss" ],
171
- metadata = {"step_num " : state .log_history [- 1 ]["step" ]},
169
+ metadata = {"samples_count " : state .log_history [- 1 ]["step" ]* self . gbs },
172
170
)
173
171
control .should_log = True
174
172
175
173
if state .global_step % (state .eval_steps ) == 0 and state .global_step > 0 :
176
174
self .mllogger .end (
177
175
constants .BLOCK_STOP ,
178
176
value = "" ,
179
- metadata = {"step_num " : state .log_history [- 1 ]["step" ]},
177
+ metadata = {"samples_count " : state .log_history [- 1 ]["step" ]* self . gbs },
180
178
)
181
179
self .mllogger .event (
182
180
constants .EVAL_ACCURACY ,
183
181
value = state .log_history [- 1 ]["eval_loss" ],
184
- metadata = {"samples_num " : state .log_history [- 1 ]["step" ]* self .gbs },
182
+ metadata = {"samples_count " : state .log_history [- 1 ]["step" ]* self .gbs },
185
183
)
186
184
self .mllogger .start (
187
185
constants .BLOCK_START ,
188
186
value = "" ,
189
- metadata = {"step_num " : state .log_history [- 1 ]["step" ]},
187
+ metadata = {"samples_count " : state .log_history [- 1 ]["step" ]},
190
188
)
191
189
control .should_log = True
192
190
eval_loss_list = [
@@ -198,7 +196,7 @@ def on_step_begin(
198
196
constants .RUN_STOP ,
199
197
value = eval_loss_list [- 1 ],
200
198
metadata = {
201
- "samples_num " : state .log_history [- 1 ]["step" ]* self .gbs ,
199
+ "samples_count " : state .log_history [- 1 ]["step" ]* self .gbs ,
202
200
"status" : "success" ,
203
201
},
204
202
)
@@ -207,7 +205,7 @@ def on_step_begin(
207
205
self .mllogger .end (
208
206
constants .RUN_STOP ,
209
207
value = eval_loss_list [- 1 ],
210
- metadata = {"step_num " : state .log_history [- 1 ]["step" ], "status" : "fail" },
208
+ metadata = {"samples_count " : state .log_history [- 1 ]["step" ]* self . gbs , "status" : "fail" },
211
209
)
212
210
213
211
return control
0 commit comments