Commit 58ec387

Fix default ranges in saved model converter (#671)
1 parent 95199a7 commit 58ec387

1 file changed, +21 -0 lines changed

larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc

Lines changed: 21 additions & 0 deletions
@@ -150,6 +150,27 @@ pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer(
     quant_specs.input_ranges.push_back({llvm::None, llvm::None});
   }
   if (!default_ranges.is_none()) {
+    // When there are no Quantize nodes in the graph then in the PrepareQuantize
+    // pass the variables `eager_quantize` and subsequently `infer_tensor_range`
+    // are set to false:
+    // https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc#L360-L366
+    // This means that the PrepareQuantize pass does *not* infer the int8
+    // range of weight tensors. The DefaultQuantParamsPass will then set the
+    // quantization stats of those weight tensors to this per-tensor default
+    // range instead of proper per-channel ranges.
+    // The tflite/tfmicro kernels can handle per-tensor weight quantization, but
+    // for some private passes we desire per-channel quantization.
+    // To make `infer_tensor_range` become true we simply set
+    // `post_training_quantization` to true here.
+    // Alternatively to this solution, we could set
+    // `quant_specs.target_func = "serving_default";`
+    // and set the `input_ranges` to some fixed values. In that case, the
+    // PrepareQuantize pass would first insert Quantization ops at the input
+    // here:
+    // https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc#L172
+    // https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc#L202
+    quant_specs.post_training_quantization = true;
+
     quant_specs.default_ranges =
         default_ranges.cast<std::pair<double, double>>();
   }
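
The flag interaction described in the comment above can be sketched as a small standalone model. This is only an illustration, not the TensorFlow code: `QuantSpecsModel`, `InferTensorRange`, `MakeSpecs`, and `CheckModel` are hypothetical names, and the condition is assumed to reduce to "post-training quantization is on, or the graph already contains Quantize ops". The linked PrepareQuantize source may check further flags that are left out here.

```cpp
#include <cassert>
#include <optional>
#include <utility>

// Hypothetical, simplified stand-in for the quantization specs. Only the two
// fields touched by this commit are modeled.
struct QuantSpecsModel {
  bool post_training_quantization = false;
  std::optional<std::pair<double, double>> default_ranges;
};

// Assumed simplification of the PrepareQuantize decision: tensor ranges are
// inferred when the graph already has Quantize ops (`eager_quantize`) or when
// post-training quantization is requested.
bool InferTensorRange(const QuantSpecsModel& specs,
                      bool graph_has_quantize_ops) {
  const bool eager_quantize = graph_has_quantize_ops;
  return specs.post_training_quantization || eager_quantize;
}

// Mirrors the shape of the converter change: when default ranges are given,
// post_training_quantization is switched on as well.
QuantSpecsModel MakeSpecs(
    std::optional<std::pair<double, double>> default_ranges) {
  QuantSpecsModel specs;
  if (default_ranges.has_value()) {
    specs.post_training_quantization = true;
    specs.default_ranges = default_ranges;
  }
  return specs;
}

// Self-check covering the cases discussed in the commit comment.
void CheckModel() {
  // Default ranges set, no Quantize ops: ranges are still inferred.
  assert(InferTensorRange(MakeSpecs(std::make_pair(-3.0, 3.0)), false));
  // No default ranges, no Quantize ops: nothing triggers inference.
  assert(!InferTensorRange(MakeSpecs(std::nullopt), false));
  // Quantize ops present: inference is on regardless of default ranges.
  assert(InferTensorRange(MakeSpecs(std::nullopt), true));
}
```

Under this model, a graph without Quantize ops only gets inferred weight ranges when default ranges are passed to the converter, which is the case the commit targets.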
