|
166 | 166 | " super().__init__()\n",
|
167 | 167 | " \n",
|
168 | 168 | " # new PL attributes:\n",
|
169 |
| - " \n", |
170 | 169 | " if parse_version(torchmetrics_version) > parse_version(\"0.8\"):\n",
|
171 | 170 | " self.train_acc = Accuracy(task=\"multiclass\", num_classes=10)\n",
|
172 | 171 | " self.valid_acc = Accuracy(task=\"multiclass\", num_classes=10)\n",
|
|
201 | 200 | " self.log(\"train_loss\", loss, prog_bar=True)\n",
|
202 | 201 | " return loss\n",
|
203 | 202 | "\n",
|
204 |
| - " def training_epoch_end(self, outs):\n", |
205 |
| - " self.log(\"train_acc\", self.train_acc.compute())\n", |
206 |
| - " self.train_acc.reset()\n", |
207 |
| - " \n", |
208 |
| - " def validation_step(self, batch, batch_idx):\n", |
209 |
| - " x, y = batch\n", |
210 |
| - " logits = self(x)\n", |
211 |
| - " loss = nn.functional.cross_entropy(logits, y)\n", |
212 |
| - " preds = torch.argmax(logits, dim=1)\n", |
213 |
| - " self.valid_acc.update(preds, y)\n", |
214 |
| - " self.log(\"valid_loss\", loss, prog_bar=True)\n", |
215 |
| - " return loss\n", |
216 |
| - " \n", |
217 |
| - " def validation_epoch_end(self, outs):\n", |
218 |
| - " self.log(\"valid_acc\", self.valid_acc.compute(), prog_bar=True)\n", |
219 |
| - " self.valid_acc.reset()\n", |
| 203 | + " # Conditionally define epoch end methods based on PyTorch Lightning version\n", |
| 204 | + " if parse_version(pl.__version__) >= parse_version(\"2.0\"):\n", |
| 205 | + " # For PyTorch Lightning 2.0 and above\n", |
| 206 | + " def on_training_epoch_end(self):\n", |
| 207 | + " self.log(\"train_acc\", self.train_acc.compute())\n", |
| 208 | + " self.train_acc.reset()\n", |
| 209 | + "\n", |
| 210 | + " def on_validation_epoch_end(self):\n", |
| 211 | + " self.log(\"valid_acc\", self.valid_acc.compute())\n", |
| 212 | + " self.valid_acc.reset()\n", |
| 213 | + "\n", |
| 214 | + " def on_test_epoch_end(self):\n", |
| 215 | + " self.log(\"test_acc\", self.test_acc.compute())\n", |
| 216 | + " self.test_acc.reset()\n", |
220 | 217 | "\n",
|
| 218 | + " else:\n", |
| 219 | + " # For PyTorch Lightning < 2.0\n", |
| 220 | + " def training_epoch_end(self, outs):\n", |
| 221 | + " self.log(\"train_acc\", self.train_acc.compute())\n", |
| 222 | + " self.train_acc.reset()\n", |
| 223 | + "\n", |
| 224 | + " def validation_epoch_end(self, outs):\n", |
| 225 | + " self.log(\"valid_acc\", self.valid_acc.compute())\n", |
| 226 | + " self.valid_acc.reset()\n", |
| 227 | + "\n", |
| 228 | + " def test_epoch_end(self, outs):\n", |
| 229 | + " self.log(\"test_acc\", self.test_acc.compute())\n", |
| 230 | + " self.test_acc.reset()\n", |
| 231 | + " \n", |
221 | 232 | " def test_step(self, batch, batch_idx):\n",
|
222 | 233 | " x, y = batch\n",
|
223 | 234 | " logits = self(x)\n",
|
|
1101 | 1112 | "name": "python",
|
1102 | 1113 | "nbconvert_exporter": "python",
|
1103 | 1114 | "pygments_lexer": "ipython3",
|
1104 |
| - "version": "3.10.6" |
| 1115 | + "version": "3.10.14" |
1105 | 1116 | }
|
1106 | 1117 | },
|
1107 | 1118 | "nbformat": 4,
|
|
0 commit comments