23
23
from absl .testing import parameterized
24
24
import numpy as np
25
25
import tensorflow as tf
26
+ from tensorflow .python .data .ops import dataset_ops
26
27
from tensorflow .python .distribute import combinations
27
28
from tensorflow .python .distribute import strategy_combinations
29
+ from tensorflow .python .training import basic_session_run_hooks
30
+ from tensorflow .python .training import training_util
31
+ from tensorflow_estimator .python .estimator import estimator as estimator_lib
32
+ from tensorflow_estimator .python .estimator import model_fn as model_fn_lib
28
33
from tensorflow_estimator .python .estimator import run_config
29
34
from tensorflow_estimator .python .estimator import training
30
35
from tensorflow_estimator .python .estimator .canned import dnn_linear_combined
@@ -50,6 +55,61 @@ def input_fn():
50
55
51
56
return input_fn
52
57
58
+ @combinations .generate (
59
+ combinations .combine (
60
+ mode = ['graph' ],
61
+ distribution = [
62
+ strategy_combinations .one_device_strategy ,
63
+ strategy_combinations .mirrored_strategy_with_gpu_and_cpu ,
64
+ strategy_combinations .mirrored_strategy_with_two_gpus
65
+ ],
66
+ use_train_and_evaluate = [True , False ]))
67
+ def test_estimator_with_strategy_hooks (self , distribution ,
68
+ use_train_and_evaluate ):
69
+ config = run_config .RunConfig (eval_distribute = distribution )
70
+
71
+ def _input_map_fn (tensor ):
72
+ return {'feature' : tensor }, tensor
73
+
74
+ def input_fn ():
75
+ return dataset_ops .Dataset .from_tensors (
76
+ [1. ]).repeat (10 ).batch (5 ).map (_input_map_fn )
77
+
78
+ def model_fn (features , labels , mode ):
79
+ del features , labels
80
+ global_step = training_util .get_global_step ()
81
+ if mode == model_fn_lib .ModeKeys .TRAIN :
82
+ train_hook1 = basic_session_run_hooks .StepCounterHook (
83
+ every_n_steps = 1 , output_dir = self .get_temp_dir ())
84
+ train_hook2 = tf .compat .v1 .test .mock .MagicMock (
85
+ wraps = tf .compat .v1 .train .SessionRunHook (),
86
+ spec = tf .compat .v1 .train .SessionRunHook )
87
+ return model_fn_lib .EstimatorSpec (
88
+ mode ,
89
+ loss = tf .constant (1. ),
90
+ train_op = global_step .assign_add (1 ),
91
+ training_hooks = [train_hook1 , train_hook2 ])
92
+ if mode == model_fn_lib .ModeKeys .EVAL :
93
+ eval_hook1 = basic_session_run_hooks .StepCounterHook (
94
+ every_n_steps = 1 , output_dir = self .get_temp_dir ())
95
+ eval_hook2 = tf .compat .v1 .test .mock .MagicMock (
96
+ wraps = tf .compat .v1 .train .SessionRunHook (),
97
+ spec = tf .compat .v1 .train .SessionRunHook )
98
+ return model_fn_lib .EstimatorSpec (
99
+ mode = mode ,
100
+ loss = tf .constant (1. ),
101
+ evaluation_hooks = [eval_hook1 , eval_hook2 ])
102
+ num_steps = 10
103
+ estimator = estimator_lib .EstimatorV2 (
104
+ model_fn = model_fn , model_dir = self .get_temp_dir (), config = config )
105
+ if use_train_and_evaluate :
106
+ training .train_and_evaluate (
107
+ estimator , training .TrainSpec (input_fn , max_steps = num_steps ),
108
+ training .EvalSpec (input_fn ))
109
+ else :
110
+ estimator .train (input_fn , steps = num_steps )
111
+ estimator .evaluate (input_fn , steps = num_steps )
112
+
53
113
@combinations .generate (
54
114
combinations .combine (
55
115
mode = ['graph' ],
0 commit comments