Skip to content

Commit 14cb6fe

Browse files
committed
update PyTorch Lightning usage
1 parent 90924df commit 14cb6fe

File tree

1 file changed

+29
-18
lines changed

1 file changed

+29
-18
lines changed

ch13/ch13_part3_lightning.ipynb

+29-18
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@
166166
" super().__init__()\n",
167167
" \n",
168168
" # new PL attributes:\n",
169-
" \n",
170169
" if parse_version(torchmetrics_version) > parse_version(\"0.8\"):\n",
171170
" self.train_acc = Accuracy(task=\"multiclass\", num_classes=10)\n",
172171
" self.valid_acc = Accuracy(task=\"multiclass\", num_classes=10)\n",
@@ -201,23 +200,35 @@
201200
" self.log(\"train_loss\", loss, prog_bar=True)\n",
202201
" return loss\n",
203202
"\n",
204-
" def training_epoch_end(self, outs):\n",
205-
" self.log(\"train_acc\", self.train_acc.compute())\n",
206-
" self.train_acc.reset()\n",
207-
" \n",
208-
" def validation_step(self, batch, batch_idx):\n",
209-
" x, y = batch\n",
210-
" logits = self(x)\n",
211-
" loss = nn.functional.cross_entropy(logits, y)\n",
212-
" preds = torch.argmax(logits, dim=1)\n",
213-
" self.valid_acc.update(preds, y)\n",
214-
" self.log(\"valid_loss\", loss, prog_bar=True)\n",
215-
" return loss\n",
216-
" \n",
217-
" def validation_epoch_end(self, outs):\n",
218-
" self.log(\"valid_acc\", self.valid_acc.compute(), prog_bar=True)\n",
219-
" self.valid_acc.reset()\n",
203+
" # Conditionally define epoch end methods based on PyTorch Lightning version\n",
204+
" if parse_version(pl.__version__) >= parse_version(\"2.0\"):\n",
205+
" # For PyTorch Lightning 2.0 and above\n",
206+
" def on_training_epoch_end(self):\n",
207+
" self.log(\"train_acc\", self.train_acc.compute())\n",
208+
" self.train_acc.reset()\n",
209+
"\n",
210+
" def on_validation_epoch_end(self):\n",
211+
" self.log(\"valid_acc\", self.valid_acc.compute())\n",
212+
" self.valid_acc.reset()\n",
213+
"\n",
214+
" def on_test_epoch_end(self):\n",
215+
" self.log(\"test_acc\", self.test_acc.compute())\n",
216+
" self.test_acc.reset()\n",
220217
"\n",
218+
" else:\n",
219+
" # For PyTorch Lightning < 2.0\n",
220+
" def training_epoch_end(self, outs):\n",
221+
" self.log(\"train_acc\", self.train_acc.compute())\n",
222+
" self.train_acc.reset()\n",
223+
"\n",
224+
" def validation_epoch_end(self, outs):\n",
225+
" self.log(\"valid_acc\", self.valid_acc.compute())\n",
226+
" self.valid_acc.reset()\n",
227+
"\n",
228+
" def test_epoch_end(self, outs):\n",
229+
" self.log(\"test_acc\", self.test_acc.compute())\n",
230+
" self.test_acc.reset()\n",
231+
" \n",
221232
" def test_step(self, batch, batch_idx):\n",
222233
" x, y = batch\n",
223234
" logits = self(x)\n",
@@ -1101,7 +1112,7 @@
11011112
"name": "python",
11021113
"nbconvert_exporter": "python",
11031114
"pygments_lexer": "ipython3",
1104-
"version": "3.10.6"
1115+
"version": "3.10.14"
11051116
}
11061117
},
11071118
"nbformat": 4,

0 commit comments

Comments
 (0)