Skip to content

Commit c65ee7f

Browse files
committed
Support load models with SparseTensor in signature
1 parent dff11fe commit c65ee7f

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

simple_tensorflow_serving/tensorflow_inference_service.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def inference(self, json_data):
217217
input_op_name = input_item[0]
218218
# Example: "Placeholder_0"
219219
input_tensor_name = input_item[1].name
220+
220221
# Example: {"Placeholder_0": [[1.0], [2.0]], "Placeholder_1:0": [[10, 10, 10, 8, 6, 1, 8, 9, 1], [6, 2, 1, 1, 1, 1, 7, 1, 1]]}
221222
feed_dict_map[input_tensor_name] = input_data[input_op_name]
222223

@@ -225,12 +226,26 @@ def inference(self, json_data):
225226
output_tensor_names = []
226227
output_op_names = []
227228
for output_item in self.model_graph_signature.outputs.items():
228-
# Example: "keys"
229-
output_op_name = output_item[0]
230-
output_op_names.append(output_op_name)
231-
# Example: "Identity:0"
232-
output_tensor_name = output_item[1].name
233-
output_tensor_names.append(output_tensor_name)
229+
#import ipdb;ipdb.set_trace()
230+
231+
if output_item[1].name != "":
232+
# Example: "keys"
233+
output_op_name = output_item[0]
234+
output_op_names.append(output_op_name)
235+
# Example: "Identity:0"
236+
output_tensor_name = output_item[1].name
237+
output_tensor_names.append(output_tensor_name)
238+
elif output_item[1].coo_sparse != None:
239+
# For SparseTensor op, Example: values_tensor_name: "CTCBeamSearchDecoder_1:1", indices_tensor_name: "CTCBeamSearchDecoder_1:0", dense_shape_tensor_name: "CTCBeamSearchDecoder_1:2"
240+
values_tensor_name = output_item[1].coo_sparse.values_tensor_name
241+
indices_tensor_name = output_item[1].coo_sparse.indices_tensor_name
242+
dense_shape_tensor_name = output_item[1].coo_sparse.dense_shape_tensor_name
243+
output_op_names.append("{}_{}".format(output_item[0], "values"))
244+
output_op_names.append("{}_{}".format(output_item[0], "indices"))
245+
output_op_names.append("{}_{}".format(output_item[0], "shape"))
246+
output_tensor_names.append(values_tensor_name)
247+
output_tensor_names.append(indices_tensor_name)
248+
output_tensor_names.append(dense_shape_tensor_name)
234249

235250
# 3. Inference with Session run
236251
if self.verbose:

0 commit comments

Comments
 (0)