Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 131f54a

Browse files
w-xinyi and tensorflower-gardener
authored and committed
fix eval hooks with distribute strategy
PiperOrigin-RevId: 314226004
1 parent 079a14e commit 131f54a

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

tensorflow_estimator/python/estimator/distribute_strategy_estimator_integration_test.py

Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,13 @@
2323
from absl.testing import parameterized
2424
import numpy as np
2525
import tensorflow as tf
26+
from tensorflow.python.data.ops import dataset_ops
2627
from tensorflow.python.distribute import combinations
2728
from tensorflow.python.distribute import strategy_combinations
29+
from tensorflow.python.training import basic_session_run_hooks
30+
from tensorflow.python.training import training_util
31+
from tensorflow_estimator.python.estimator import estimator as estimator_lib
32+
from tensorflow_estimator.python.estimator import model_fn as model_fn_lib
2833
from tensorflow_estimator.python.estimator import run_config
2934
from tensorflow_estimator.python.estimator import training
3035
from tensorflow_estimator.python.estimator.canned import dnn_linear_combined
@@ -50,6 +55,61 @@ def input_fn():
5055

5156
return input_fn
5257

58+
@combinations.generate(
59+
combinations.combine(
60+
mode=['graph'],
61+
distribution=[
62+
strategy_combinations.one_device_strategy,
63+
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
64+
strategy_combinations.mirrored_strategy_with_two_gpus
65+
],
66+
use_train_and_evaluate=[True, False]))
67+
def test_estimator_with_strategy_hooks(self, distribution,
68+
use_train_and_evaluate):
69+
config = run_config.RunConfig(eval_distribute=distribution)
70+
71+
def _input_map_fn(tensor):
72+
return {'feature': tensor}, tensor
73+
74+
def input_fn():
75+
return dataset_ops.Dataset.from_tensors(
76+
[1.]).repeat(10).batch(5).map(_input_map_fn)
77+
78+
def model_fn(features, labels, mode):
79+
del features, labels
80+
global_step = training_util.get_global_step()
81+
if mode == model_fn_lib.ModeKeys.TRAIN:
82+
train_hook1 = basic_session_run_hooks.StepCounterHook(
83+
every_n_steps=1, output_dir=self.get_temp_dir())
84+
train_hook2 = tf.compat.v1.test.mock.MagicMock(
85+
wraps=tf.compat.v1.train.SessionRunHook(),
86+
spec=tf.compat.v1.train.SessionRunHook)
87+
return model_fn_lib.EstimatorSpec(
88+
mode,
89+
loss=tf.constant(1.),
90+
train_op=global_step.assign_add(1),
91+
training_hooks=[train_hook1, train_hook2])
92+
if mode == model_fn_lib.ModeKeys.EVAL:
93+
eval_hook1 = basic_session_run_hooks.StepCounterHook(
94+
every_n_steps=1, output_dir=self.get_temp_dir())
95+
eval_hook2 = tf.compat.v1.test.mock.MagicMock(
96+
wraps=tf.compat.v1.train.SessionRunHook(),
97+
spec=tf.compat.v1.train.SessionRunHook)
98+
return model_fn_lib.EstimatorSpec(
99+
mode=mode,
100+
loss=tf.constant(1.),
101+
evaluation_hooks=[eval_hook1, eval_hook2])
102+
num_steps = 10
103+
estimator = estimator_lib.EstimatorV2(
104+
model_fn=model_fn, model_dir=self.get_temp_dir(), config=config)
105+
if use_train_and_evaluate:
106+
training.train_and_evaluate(
107+
estimator, training.TrainSpec(input_fn, max_steps=num_steps),
108+
training.EvalSpec(input_fn))
109+
else:
110+
estimator.train(input_fn, steps=num_steps)
111+
estimator.evaluate(input_fn, steps=num_steps)
112+
53113
@combinations.generate(
54114
combinations.combine(
55115
mode=['graph'],

tensorflow_estimator/python/estimator/estimator.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1619,8 +1619,16 @@ def step_fn(ctx, inputs):
16191619

16201620
scaffold = _combine_distributed_scaffold(grouped_estimator_spec.scaffold,
16211621
self._eval_distribution)
1622-
evaluation_hooks = self._eval_distribution.experimental_local_results(
1623-
grouped_estimator_spec.evaluation_hooks)[0]
1622+
1623+
def get_hooks_from_the_first_device(per_device_hooks):
1624+
return [
1625+
self._eval_distribution.experimental_local_results(per_device_hook)[0]
1626+
for per_device_hook in per_device_hooks
1627+
]
1628+
1629+
evaluation_hooks = get_hooks_from_the_first_device(
1630+
grouped_estimator_spec.evaluation_hooks)
1631+
16241632
return (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict)
16251633

16261634
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,

0 commit comments

Comments
 (0)