Skip to content

Commit 34746d7

Browse files
committed
Allow the callbacks to stop the training.
The training can be stopped by returning TrainableModule.STOP_TRAINING. If there are multiple callbacks, all callbacks are always called every epoch, even if one of them returned the value; the training is stopped only once the complete epoch processing has finished.
1 parent 433aec2 commit 34746d7

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

labs/npfl138/trainable_module.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def compute(self) -> torch.Tensor: # noqa: E301
3535

3636

3737
class CallbackProtocol(Protocol):
38-
def __call__(self, module: "TrainableModule", epoch: int, logs: Logs) -> None:
38+
def __call__(self, module: "TrainableModule", epoch: int, logs: Logs) -> Literal["stop_training"] | None:
3939
...
4040

4141

@@ -138,6 +138,8 @@ class TrainableModule(torch.nn.Module):
138138
from time import time as _time
139139
from tqdm import tqdm as _tqdm
140140

141+
STOP_TRAINING: Literal["stop_training"] = "stop_training"
142+
141143
def __init__(self, module: torch.nn.Module | None = None, device: torch.device | str | None = None):
142144
"""Initialize the module, optionally with an existing PyTorch module.
143145
@@ -295,17 +297,18 @@ def fit(
295297
- `epochs` is the number of epochs to train;
296298
- `dev` is an optional development dataset;
297299
- `callbacks` is a list of callbacks to call after every epoch, each implementing
298-
the CallbackProtocol with arguments `self`, `epoch`, and `logs` (note that the
299-
module is set to evaluation mode before calling each callback);
300+
the CallbackProtocol with arguments `self`, `epoch`, and `logs`, possibly returning
301+
`TrainableModule.STOP_TRAINING` to stop the training (note that the module is set
302+
to evaluation mode before calling each callback);
300303
- `log_graph` controls whether to log the model graph to TensorBoard;
301304
- `console` controls the console verbosity: 0 for silent, 1 for epoch logs, 2 for
302305
additional only-when-writing-to-console progress bar, 3 for persistent progress bar.
303306
The method returns a dictionary of logs from the training and optionally dev evaluation,
304307
and sets the model to evaluation mode after training.
305308
"""
306309
assert self.loss_tracker is not None, "The TrainableModule has not been configured, run configure first."
307-
logs, epochs = {}, self.epoch + epochs
308-
while self.epoch < epochs:
310+
logs, epochs, stop_training = {}, self.epoch + epochs, False
311+
while self.epoch < epochs and not stop_training:
309312
self.epoch += 1
310313
self.train()
311314
self.loss_tracker.reset()
@@ -328,7 +331,7 @@ def fit(
328331
if dev is not None:
329332
logs |= {f"dev_{k}": v for k, v in self.eval().evaluate(dev, log_as=None).items()}
330333
for callback in callbacks:
331-
callback(self.eval(), self.epoch, logs)
334+
stop_training = callback(self.eval(), self.epoch, logs) == self.STOP_TRAINING or stop_training
332335
self.log_metrics(logs, epochs, self._time() - start, console)
333336
self.eval()
334337
return logs

0 commit comments

Comments
 (0)