@@ -96,6 +96,21 @@ class Estimator(object):
96
96
constructor enforces this). Subclasses should use `model_fn` to configure
97
97
the base class, and may add methods implementing specialized functionality.
98
98
99
+ See [estimators](https://tensorflow.org/guide/estimators) for more
100
+ information.
101
+
102
+ To warm-start an `Estimator`:
103
+
104
+ ```python
105
+ estimator = tf.estimator.DNNClassifier(
106
+ feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
107
+ hidden_units=[1024, 512, 256],
108
+ warm_start_from="/path/to/checkpoint/dir")
109
+ ```
110
+
111
+ For more details on warm-start configuration, see
112
+ `tf.estimator.WarmStartSettings`.
113
+
99
114
@compatibility(eager)
100
115
Calling methods of `Estimator` will work while eager execution is enabled.
101
116
However, the `model_fn` and `input_fn` is not executed eagerly, `Estimator`
@@ -114,42 +129,29 @@ def __init__(self,
114
129
warm_start_from = None ):
115
130
"""Constructs an `Estimator` instance.
116
131
117
- See [estimators](https://tensorflow.org/guide/estimators) for more
118
- information.
119
-
120
- To warm-start an `Estimator`:
121
-
122
- ```python
123
- estimator = tf.estimator.DNNClassifier(
124
- feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
125
- hidden_units=[1024, 512, 256],
126
- warm_start_from="/path/to/checkpoint/dir")
127
- ```
128
132
129
- For more details on warm-start configuration, see
130
- `tf.estimator.WarmStartSettings`.
131
133
132
134
Args:
133
135
model_fn: Model function. Follows the signature:
134
- `features` -- This is the first item returned from the `input_fn`
136
+ * `features` -- This is the first item returned from the `input_fn`
135
137
passed to `train`, `evaluate`, and `predict`. This should be a
136
138
single `tf.Tensor` or `dict` of same.
137
- `labels` -- This is the second item returned from the `input_fn`
139
+ * `labels` -- This is the second item returned from the `input_fn`
138
140
passed to `train`, `evaluate`, and `predict`. This should be a
139
141
single `tf.Tensor` or `dict` of same (for multi-head models). If
140
142
mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will be
141
143
passed. If the `model_fn`'s signature does not accept `mode`, the
142
144
`model_fn` must still be able to handle `labels=None`.
143
- `mode` -- Optional. Specifies if this is training, evaluation or
145
+ * `mode` -- Optional. Specifies if this is training, evaluation or
144
146
prediction. See `tf.estimator.ModeKeys`.
145
147
`params` -- Optional `dict` of hyperparameters. Will receive what is
146
148
passed to Estimator in `params` parameter. This allows to configure
147
149
Estimators from hyper parameter tuning.
148
- `config` -- Optional `estimator.RunConfig` object. Will receive what
150
+ * `config` -- Optional `estimator.RunConfig` object. Will receive what
149
151
is passed to Estimator as its `config` parameter, or a default
150
152
value. Allows setting up things in your `model_fn` based on
151
153
configuration such as `num_ps_replicas`, or `model_dir`.
152
- Returns -- `tf.estimator.EstimatorSpec`
154
+ * Returns -- `tf.estimator.EstimatorSpec`
153
155
model_dir: Directory to save model parameters, graph and etc. This can
154
156
also be used to load checkpoints from the directory into an estimator to
155
157
continue training a previously saved model. If `PathLike` object, the
@@ -560,12 +562,12 @@ def predict(self,
560
562
https://tensorflow.org/guide/premade_estimators#create_input_functions)
561
563
for more information. The function should construct and return one of
562
564
the following:
563
- `tf.data.Dataset` object -- Outputs of `Dataset` object must have
564
- same constraints as below.
565
- features -- A `tf.Tensor` or a dictionary of string feature name to
566
- `Tensor`. features are consumed by `model_fn`. They should satisfy
567
- the expectation of `model_fn` from inputs. * A tuple, in which case
568
- the first item is extracted as features.
565
+ * `tf.data.Dataset` object -- Outputs of `Dataset` object must have
566
+ same constraints as below.
567
+ * features -- A `tf.Tensor` or a dictionary of string feature name to
568
+ `Tensor`. features are consumed by `model_fn`. They should satisfy
569
+ the expectation of `model_fn` from inputs. * A tuple, in which case
570
+ the first item is extracted as features.
569
571
predict_keys: list of `str`, name of the keys to predict. It is used if
570
572
the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
571
573
`predict_keys` is used then rest of the predictions will be filtered
0 commit comments