18
18
from unittest import mock
19
19
20
20
from google .protobuf import text_format
21
+ import numpy as np
21
22
import tensorflow as tf
22
23
23
24
from tensorboard .plugins .hparams import _keras
24
25
from tensorboard .plugins .hparams import metadata
25
26
from tensorboard .plugins .hparams import plugin_data_pb2
26
27
from tensorboard .plugins .hparams import summary_v2 as hp
27
28
28
- # Stay on Keras 2 for now: https://github.com/keras-team/keras/issues/18467.
29
- version_fn = getattr (tf .keras , "version" , None )
30
- if version_fn and version_fn ().startswith ("3." ):
31
- import tf_keras as keras # Keras 2
32
- else :
33
- keras = tf .keras # Keras 2
34
-
35
- tf .compat .v1 .enable_eager_execution ()
36
-
37
29
38
30
class CallbackTest (tf .test .TestCase ):
39
31
def setUp (self ):
@@ -46,12 +38,12 @@ def _initialize_model(self, writer):
46
38
"optimizer" : "adam" ,
47
39
HP_DENSE_NEURONS : 8 ,
48
40
}
49
- self .model = keras .models .Sequential (
41
+ self .model = tf . keras .models .Sequential (
50
42
[
51
- keras .layers .Dense (
43
+ tf . keras .layers .Dense (
52
44
self .hparams [HP_DENSE_NEURONS ], input_shape = (1 ,)
53
45
),
54
- keras .layers .Dense (1 , activation = "sigmoid" ),
46
+ tf . keras .layers .Dense (1 , activation = "sigmoid" ),
55
47
]
56
48
)
57
49
self .model .compile (loss = "mse" , optimizer = self .hparams ["optimizer" ])
@@ -69,7 +61,11 @@ def mock_time():
69
61
initial_time = mock_time .time
70
62
with mock .patch ("time.time" , mock_time ):
71
63
self ._initialize_model (writer = self .logdir )
72
- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
64
+ self .model .fit (
65
+ x = tf .constant ([(1 ,)]),
66
+ y = tf .constant ([(2 ,)]),
67
+ callbacks = [self .callback ],
68
+ )
73
69
final_time = mock_time .time
74
70
75
71
files = os .listdir (self .logdir )
@@ -142,7 +138,11 @@ def test_explicit_writer(self):
142
138
filename_suffix = ".magic" ,
143
139
)
144
140
self ._initialize_model (writer = writer )
145
- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
141
+ self .model .fit (
142
+ x = tf .constant ([(1 ,)]),
143
+ y = tf .constant ([(2 ,)]),
144
+ callbacks = [self .callback ],
145
+ )
146
146
147
147
files = os .listdir (self .logdir )
148
148
self .assertEqual (len (files ), 1 , files )
@@ -158,15 +158,27 @@ def test_non_eager_failure(self):
158
158
with self .assertRaisesRegex (
159
159
RuntimeError , "only supported in TensorFlow eager mode"
160
160
):
161
- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
161
+ self .model .fit (
162
+ x = np .ones ((10 , 10 )),
163
+ y = np .ones ((10 , 10 )),
164
+ callbacks = [self .callback ],
165
+ )
162
166
163
167
def test_reuse_failure (self ):
164
168
self ._initialize_model (writer = self .logdir )
165
- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
169
+ self .model .fit (
170
+ x = tf .constant ([(1 ,)]),
171
+ y = tf .constant ([(2 ,)]),
172
+ callbacks = [self .callback ],
173
+ )
166
174
with self .assertRaisesRegex (
167
175
RuntimeError , "cannot be reused across training sessions"
168
176
):
169
- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
177
+ self .model .fit (
178
+ x = tf .constant ([(1 ,)]),
179
+ y = tf .constant ([(2 ,)]),
180
+ callbacks = [self .callback ],
181
+ )
170
182
171
183
def test_invalid_writer (self ):
172
184
with self .assertRaisesRegex (
0 commit comments