@@ -35,7 +35,7 @@ def compute(self) -> torch.Tensor: # noqa: E301
35
35
36
36
37
37
class CallbackProtocol (Protocol ):
38
- def __call__ (self , module : "TrainableModule" , epoch : int , logs : Logs ) -> None :
38
+ def __call__ (self , module : "TrainableModule" , epoch : int , logs : Logs ) -> Literal [ "stop_training" ] | None :
39
39
...
40
40
41
41
@@ -138,6 +138,8 @@ class TrainableModule(torch.nn.Module):
138
138
from time import time as _time
139
139
from tqdm import tqdm as _tqdm
140
140
141
+ STOP_TRAINING : Literal ["stop_training" ] = "stop_training"
142
+
141
143
def __init__ (self , module : torch .nn .Module | None = None , device : torch .device | str | None = None ):
142
144
"""Initialize the module, optionally with an existing PyTorch module.
143
145
@@ -295,17 +297,18 @@ def fit(
295
297
- `epochs` is the number of epochs to train;
296
298
- `dev` is an optional development dataset;
297
299
- `callbacks` is a list of callbacks to call after every epoch, each implementing
298
- the CallbackProtocol with arguments `self`, `epoch`, and `logs` (note that the
299
- module is set to evaluation mode before calling each callback);
300
+ the CallbackProtocol with arguments `self`, `epoch`, and `logs`, possibly returning
301
+ `TrainableModule.STOP_TRAINING` to stop the training (note that the module is set
302
+ to evaluation mode before calling each callback);
300
303
- `log_graph` controls whether to log the model graph to TensorBoard;
301
304
- `console` controls the console verbosity: 0 for silent, 1 for epoch logs, 2 for
302
305
additional only-when-writing-to-console progress bar, 3 for persistent progress bar.
303
306
The method returns a dictionary of logs from the training and optionally dev evaluation,
304
307
and sets the model to evaluation mode after training.
305
308
"""
306
309
assert self .loss_tracker is not None , "The TrainableModule has not been configured, run configure first."
307
- logs , epochs = {}, self .epoch + epochs
308
- while self .epoch < epochs :
310
+ logs , epochs , stop_training = {}, self .epoch + epochs , False
311
+ while self .epoch < epochs and not stop_training :
309
312
self .epoch += 1
310
313
self .train ()
311
314
self .loss_tracker .reset ()
@@ -328,7 +331,7 @@ def fit(
328
331
if dev is not None :
329
332
logs |= {f"dev_{ k } " : v for k , v in self .eval ().evaluate (dev , log_as = None ).items ()}
330
333
for callback in callbacks :
331
- callback (self .eval (), self .epoch , logs )
334
+ stop_training = callback (self .eval (), self .epoch , logs ) == self . STOP_TRAINING or stop_training
332
335
self .log_metrics (logs , epochs , self ._time () - start , console )
333
336
self .eval ()
334
337
return logs
0 commit comments