|
135 | 135 | "\n",
|
136 | 136 | "def create_model():\n",
|
137 | 137 | " return tf.keras.models.Sequential([\n",
|
138 |
| - " tf.keras.layers.Flatten(input_shape=(28, 28), name='layers_flatten'),\n", |
| 138 | + " tf.keras.layers.Input(shape=(28, 28), name='layers_input'),\n", |
| 139 | + " tf.keras.layers.Flatten(name='layers_flatten'),\n", |
139 | 140 | " tf.keras.layers.Dense(512, activation='relu', name='layers_dense'),\n",
|
140 | 141 | " tf.keras.layers.Dropout(0.2, name='layers_dropout'),\n",
|
141 | 142 | " tf.keras.layers.Dense(10, activation='softmax', name='layers_dense_2')\n",
|
|
452 | 453 | " test_accuracy.result()*100))\n",
|
453 | 454 | "\n",
|
454 | 455 | " # Reset metrics every epoch\n",
|
455 |
| - " train_loss.reset_states()\n", |
456 |
| - " test_loss.reset_states()\n", |
457 |
| - " train_accuracy.reset_states()\n", |
458 |
| - " test_accuracy.reset_states()" |
| 456 | + " train_loss.reset_state()\n", |
| 457 | + " test_loss.reset_state()\n", |
| 458 | + " train_accuracy.reset_state()\n", |
| 459 | + " test_accuracy.reset_state()" |
459 | 460 | ]
|
460 | 461 | },
|
461 | 462 | {
|
|
0 commit comments