@@ -131,26 +131,25 @@ def __init__(self,
131
131
132
132
Args:
133
133
model_fn: Model function. Follows the signature:
134
- * Args:
135
- * `features`: This is the first item returned from the `input_fn`
136
- passed to `train`, `evaluate`, and `predict`. This should be a
137
- single `tf.Tensor` or `dict` of same.
138
- * `labels`: This is the second item returned from the `input_fn`
139
- passed to `train`, `evaluate`, and `predict`. This should be a
140
- single `tf.Tensor` or `dict` of same (for multi-head models). If
141
- mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will be
142
- passed. If the `model_fn`'s signature does not accept `mode`, the
143
- `model_fn` must still be able to handle `labels=None`.
144
- * `mode`: Optional. Specifies if this is training, evaluation or
145
- prediction. See `tf.estimator.ModeKeys`.
146
- * `params`: Optional `dict` of hyperparameters. Will receive what is
147
- passed to Estimator in `params` parameter. This allows to configure
148
- Estimators from hyper parameter tuning.
149
- * `config`: Optional `estimator.RunConfig` object. Will receive what
150
- is passed to Estimator as its `config` parameter, or a default
151
- value. Allows setting up things in your `model_fn` based on
152
- configuration such as `num_ps_replicas`, or `model_dir`.
153
- * Returns: `tf.estimator.EstimatorSpec`
134
+ `features` -- This is the first item returned from the `input_fn`
135
+ passed to `train`, `evaluate`, and `predict`. This should be a
136
+ single `tf.Tensor` or `dict` of same.
137
+ `labels` -- This is the second item returned from the `input_fn`
138
+ passed to `train`, `evaluate`, and `predict`. This should be a
139
+ single `tf.Tensor` or `dict` of same (for multi-head models). If
140
+ mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will be
141
+ passed. If the `model_fn`'s signature does not accept `mode`, the
142
+ `model_fn` must still be able to handle `labels=None`.
143
+ `mode` -- Optional. Specifies if this is training, evaluation or
144
+ prediction. See `tf.estimator.ModeKeys`.
145
+ `params` -- Optional `dict` of hyperparameters. Will receive what is
146
+ passed to Estimator in `params` parameter. This allows to configure
147
+ Estimators from hyper parameter tuning.
148
+ `config` -- Optional `estimator.RunConfig` object. Will receive what
149
+ is passed to Estimator as its `config` parameter, or a default
150
+ value. Allows setting up things in your `model_fn` based on
151
+ configuration such as `num_ps_replicas`, or `model_dir`.
152
+ Returns -- `tf.estimator.EstimatorSpec`
154
153
model_dir: Directory to save model parameters, graph and etc. This can
155
154
also be used to load checkpoints from the directory into an estimator to
156
155
continue training a previously saved model. If `PathLike` object, the
@@ -559,11 +558,11 @@ def predict(self,
559
558
(`tf.errors.OutOfRangeError` or `StopIteration`). See [Premade
560
559
Estimators](
561
560
https://tensorflow.org/guide/premade_estimators#create_input_functions)
562
- for more information. The function should construct and return one of
561
+ for more information. The function should construct and return one of
563
562
the following:
564
- * A `tf.data.Dataset` object: Outputs of `Dataset` object must have
563
+ `tf.data.Dataset` object -- Outputs of `Dataset` object must have
565
564
same constraints as below.
566
- * features: A `tf.Tensor` or a dictionary of string feature name to
565
+ features -- A `tf.Tensor` or a dictionary of string feature name to
567
566
`Tensor`. features are consumed by `model_fn`. They should satisfy
568
567
the expectation of `model_fn` from inputs. * A tuple, in which case
569
568
the first item is extracted as features.
0 commit comments