Skip to content

Commit 43e197b

Browse files
authored
Add quantized INT8 detection-postprocess op pass to the converter (#777)
1 parent 3fc2658 commit 43e197b

File tree

6 files changed

+245
-0
lines changed

6 files changed

+245
-0
lines changed

larq_compute_engine/mlir/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,22 @@ cc_library(
394394
alwayslink = 1,
395395
)
396396

397+
# MLIR pass library: rewrites the TFLite detection-postprocess custom op so it
# runs on quantized INT8 inputs (see transforms/detection_postprocess.cc).
cc_library(
    name = "detection_postprocess_transform",
    srcs = [
        "transforms/detection_postprocess.cc",
    ],
    hdrs = [
        "transforms/passes.h",
    ],
    deps = [
        "//larq_compute_engine/mlir:larq_compute_engine",
        "@llvm-project//mlir:FuncDialect",
        "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
    ],
    # Keep the pass registration (static initializer) even if no symbol is
    # referenced directly.
    alwayslink = 1,
)
412+
397413
cc_library(
398414
name = "fuse_padding",
399415
srcs = [
@@ -433,6 +449,7 @@ cc_library(
433449
"tf_tfl_passes.h",
434450
],
435451
deps = [
452+
":detection_postprocess_transform",
436453
":fuse_padding",
437454
":larq_compute_engine_bitpack_weights",
438455
":larq_compute_engine_legalize_tflite",

larq_compute_engine/mlir/tests/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ lce_lit_test_suite(
1414
],
1515
)
1616

17+
# Convenience alias so the whole lit suite can be run as the ":all" target.
test_suite(
    name = "all",
    tests = [
        ":lit",
    ],
)
23+
1724
cc_test(
1825
name = "lce_ops_options_test",
1926
srcs = ["lce_ops_options_test.cc"],
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: lce-tf-opt %s -detection-postprocess-int -verify-diagnostics | FileCheck %s

// Checks that the detection-postprocess-int pass absorbs the three
// `tfl.dequantize` ops into the `TFLite_Detection_PostProcess` custom op (so
// the op consumes the int8 tensors directly), and that only the float
// `scores` output (#2) gets a new `tfl.dequantize` inserted after it.

// CHECK-LABEL: detection_postprocess_int
func.func @detection_postprocess_int(%arg0: tensor<1x10x4x!quant.uniform<i8:f32, 2.343750e-02>>, %arg1: tensor<1x10x1x!quant.uniform<i8:f32, 2.343750e-02>>, %arg2: tensor<10x4x!quant.uniform<i8:f32, 2.343750e-02>>) -> (tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20xf32>, tensor<1xi32>) {
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x10x4x!quant.uniform<i8:f32, 2.343750e-02>>) -> tensor<1x10x4xf32>
  %1 = "tfl.dequantize"(%arg1) : (tensor<1x10x1x!quant.uniform<i8:f32, 2.343750e-02>>) -> tensor<1x10x1xf32>
  %2 = "tfl.dequantize"(%arg2) : (tensor<10x4x!quant.uniform<i8:f32, 2.343750e-02>>) -> tensor<10x4xf32>
  %3:4 = "tfl.custom"(%0, %1, %2) {custom_code = "TFLite_Detection_PostProcess", custom_option = #tfl<const_bytes : "0x6D61785F646574656374696F6E73006D61785F636C61737365735F7065725F646574656374696F6E006E756D5F636C6173736573006E6D735F73636F72655F7468726573686F6C64006E6D735F696F755F7468726573686F6C6400795F7363616C6500785F7363616C6500685F7363616C6500775F7363616C65007573655F726567756C61725F6E6D73000A217E8E465B681720313A00000C000000010000000A0000000000803F01000000140000000000003F9A9959BF01000000010000000000803F0000803F0000803F0E06060E0E06060E0E0E322601">} : (tensor<1x10x4xf32>, tensor<1x10x1xf32>, tensor<10x4xf32>) -> (tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20xf32>, tensor<1xi32>)
  return %3#0, %3#1, %3#2, %3#3 : tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20xf32>, tensor<1xi32> // boxes, classes, scores, num_detections

// CHECK: %3:4 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "TFLite_Detection_PostProcess", custom_option = #tfl<const_bytes : "0x6D61785F646574656374696F6E73006D61785F636C61737365735F7065725F646574656374696F6E006E756D5F636C6173736573006E6D735F73636F72655F7468726573686F6C64006E6D735F696F755F7468726573686F6C6400795F7363616C6500785F7363616C6500685F7363616C6500775F7363616C65007573655F726567756C61725F6E6D73000A217E8E465B681720313A00000C000000010000000A0000000000803F01000000140000000000003F9A9959BF01000000010000000000803F0000803F0000803F0E06060E0E06060E0E0E322601">} : (tensor<1x10x4x!quant.uniform<i8:f32, 2.343750e-02>>, tensor<1x10x1x!quant.uniform<i8:f32, 2.343750e-02>>, tensor<10x4x!quant.uniform<i8:f32, 2.343750e-02>>) -> (tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20x!quant.uniform<i8:f32, 2.343750e-02>>, tensor<1xi32>)
// CHECK-NEXT: %4 = "tfl.dequantize"(%3#2) : (tensor<1x20x!quant.uniform<i8:f32, 2.343750e-02>>) -> tensor<1x20xf32>
// CHECK-NEXT: return %3#0, %3#1, %4, %3#3 : tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20xf32>, tensor<1xi32>
}

larq_compute_engine/mlir/tf_tfl_passes.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,16 @@ const char kTFLiteDataLayout[] = "NHWC";
2525
namespace {
2626
void AddQuantizationPasses(const mlir::quant::QuantizationSpecs& quant_specs,
2727
mlir::OpPassManager& pass_manager) {
28+
// PrepareQuantizePass adds Quantize->Dequantize pairs at *every* float tensor
29+
// even if it did not have a fake quantization node, or if that fake
30+
// quantization node was folded away (which would happen for weights).
2831
pass_manager.addNestedPass<mlir::func::FuncOp>(
2932
mlir::TFL::CreatePrepareQuantizePass(quant_specs));
33+
34+
// The LCEQuantizePass is similar to 'step 1' of the TFL QuantizePass below.
3035
pass_manager.addNestedPass<mlir::func::FuncOp>(
3136
mlir::TFL::CreateLCEQuantizePass());
37+
3238
if (quant_specs.default_ranges.first.hasValue() ||
3339
quant_specs.default_ranges.second.hasValue()) {
3440
pass_manager.addNestedPass<mlir::func::FuncOp>(
@@ -39,6 +45,18 @@ void AddQuantizationPasses(const mlir::quant::QuantizationSpecs& quant_specs,
3945
pass_manager.addNestedPass<mlir::func::FuncOp>(
4046
mlir::TFL::CreateLCEQuantizePass());
4147
}
48+
49+
// This absorbs Dequantize ops into the postprocessing op when possible,
50+
// similar to 'step 1' of the TFL QuantizePass below.
51+
pass_manager.addNestedPass<mlir::func::FuncOp>(
52+
mlir::TFL::QuantizeDetectionPostProcessPass());
53+
54+
// QuantizePass does two things:
55+
// 1. For TFLite ops with quantize traits, the Dequantize is absorbed
56+
// into the input of the op, and in certain cases per-channel quantization
57+
// is applied.
58+
// 2. Afterwards, any remaining Quantize->Dequantize pairs with constant input
59+
// are *removed*.
4260
pass_manager.addNestedPass<mlir::func::FuncOp>(
4361
mlir::TFL::CreateQuantizePass());
4462
bool emit_quant_adaptor_ops =
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
#include "larq_compute_engine/mlir/ir/lce_ops.h"
2+
#include "mlir/Dialect/Func/IR/FuncOps.h"
3+
#include "mlir/IR/PatternMatch.h"
4+
#include "mlir/IR/TypeUtilities.h"
5+
#include "mlir/Pass/Pass.h"
6+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
7+
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
8+
9+
namespace mlir {
10+
namespace TFL {
11+
12+
// Function pass that rewrites a float `TFLite_Detection_PostProcess` custom
// op so it consumes quantized INT8 inputs directly (the actual rewrite is the
// RemoveDequantizeBeforePostProcess pattern in this file).
struct DetectionPostProcess
    : public PassWrapper<DetectionPostProcess,
                         OperationPass<mlir::func::FuncOp>> {
  // Command-line flag used to select this pass, e.g. in lce-tf-opt.
  llvm::StringRef getArgument() const final {
    return "detection-postprocess-int";
  }
  llvm::StringRef getDescription() const final {
    return "Make detection postprocessing op run with int8 input";
  }
  void runOnOperation() override;
};
23+
24+
// Absorbs the three `tfl.dequantize` ops in front of a float
// `TFLite_Detection_PostProcess` custom op so the op consumes the quantized
// INT8 tensors directly. The float `scores` output (#2) becomes an
// int8-quantized output followed by a single `tfl.dequantize`, so the types
// seen by the rest of the graph are unchanged.
struct RemoveDequantizeBeforePostProcess : public OpRewritePattern<CustomOp> {
  using OpRewritePattern<CustomOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CustomOp detection_op,
                                PatternRewriter& rewriter) const override {
    // ----------------- matching part -----------------

    // Match the custom op code to 'TFLite_Detection_PostProcess'
    auto custom_code = detection_op.custom_code().str();
    if (custom_code != "TFLite_Detection_PostProcess") {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "op 'tfl.custom' attribute 'custom_code' failed to satisfy "
                "constraint: constant attribute TFLite_Detection_PostProcess";
      });
    }

    // Check the number of inputs and outputs of the detection op
    auto original_detection_inputs = detection_op.input();
    if (original_detection_inputs.size() != 3) {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "expected 3 inputs for the detection op";
      });
    }
    auto original_detection_outputs = detection_op.output();
    if (original_detection_outputs.size() != 4) {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "expected 4 outputs for the original detection op";
      });
    }

    // Match that dequantization happens just before the detection op, for
    // each of the three inputs: boxes, scores, anchors.
    const char* input_names[] = {"boxes", "scores", "anchors"};
    SmallVector<DequantizeOp, 3> input_dequantize_ops;
    for (auto i = 0; i < 3; ++i) {
      auto dequantize_op = llvm::dyn_cast_or_null<DequantizeOp>(
          original_detection_inputs[i].getDefiningOp());
      if (!dequantize_op) {
        return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
          diag << "expected dequantization before the op for the '"
               << input_names[i] << "' input";
        });
      }
      input_dequantize_ops.push_back(dequantize_op);
    }

    // Verify the output types of the current detection op:
    // Output type #0: [int32] detection boxes (scaled by 2048)
    // Output type #1: [int32] detection class IDs
    // Output type #2: [float32] detection scores
    // Output type #3: [int32] number of detections
    auto output_data_types = SmallVector<Type, 4>{
        rewriter.getIntegerType(32),
        rewriter.getIntegerType(32),
        rewriter.getF32Type(),
        rewriter.getIntegerType(32),
    };
    for (auto i = 0; i < 4; ++i) {
      auto original_type =
          original_detection_outputs[i].getType().cast<ShapedType>();
      auto original_data_type = original_type.getElementType();
      if (output_data_types[i] != original_data_type) {
        return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
          diag << "unexpected output type of the op";
        });
      }
    }

    // The new op inherits the quantization parameters of the quantized
    // 'scores' input, so that input must be per-tensor uniform-quantized.
    // `dyn_cast` (rather than `cast`) makes the pattern fail gracefully on
    // e.g. per-axis quantized input instead of asserting.
    auto scores_input = input_dequantize_ops[1].input();
    auto scores_type = scores_input.getType()
                           .cast<ShapedType>()
                           .getElementType()
                           .dyn_cast<quant::UniformQuantizedType>();
    if (!scores_type) {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "expected per-tensor uniform quantization for the 'scores' "
                "input";
      });
    }

    // ----------------- re-write part -----------------

    // Use the original (still-quantized) inputs of the dequantize ops as the
    // inputs of the new detection op.
    auto new_detection_inputs = SmallVector<Value, 3>{
        input_dequantize_ops[0].input(),  // boxes
        scores_input,                     // scores
        input_dequantize_ops[2].input(),  // anchors
    };

    // Set the 4 output types [boxes, classes, scores, num_detections]:
    // Output type #0: [int32] detection boxes (scaled by 2048)
    // Output type #1: [int32] detection class IDs
    // Output type #2: [int8 quantized] detection scores
    // Output type #3: [int32] number of detections
    // All as before, except for output #2 (float -> int8 quantized), which
    // reuses the scale and zero-point of the 'scores' input.
    output_data_types[2] = quant::UniformQuantizedType::get(
        true, rewriter.getIntegerType(8), rewriter.getF32Type(),
        scores_type.getScale(), scores_type.getZeroPoint(), -128, 127);

    // Set for all the outputs: data-type (as set above) and shape (as before)
    auto new_op_output_types = SmallVector<Type, 4>{};
    for (auto i = 0; i < 4; ++i) {
      auto shape =
          original_detection_outputs[i].getType().cast<ShapedType>().getShape();
      new_op_output_types.push_back(
          RankedTensorType::get(shape, output_data_types[i]));
    }

    // Add a new detection op (with int8 input and int8/int32 output)
    auto new_detection_op = rewriter.create<CustomOp>(
        detection_op->getLoc(), new_op_output_types, new_detection_inputs,
        std::string{"TFLite_Detection_PostProcess"},
        detection_op.custom_option());

    // Dequantize the new int8 'scores' output (#2) back to float, restoring
    // the original float result type for downstream users.
    auto new_dequantize_op = rewriter.create<DequantizeOp>(
        detection_op->getLoc(), original_detection_outputs[2].getType(),
        new_detection_op.output()[2]);

    // Final re-write: boxes, classes, dequantized scores, num_detections
    auto new_outputs = SmallVector<Value, 4>{
        new_detection_op.output()[0], new_detection_op.output()[1],
        new_dequantize_op.output(), new_detection_op.output()[3]};
    rewriter.replaceOp(detection_op, new_outputs);
    return success();
  }
};
167+
168+
void DetectionPostProcess::runOnOperation() {
169+
auto* ctx = &getContext();
170+
RewritePatternSet patterns(ctx);
171+
auto func = getOperation();
172+
173+
patterns.add<RemoveDequantizeBeforePostProcess>(ctx);
174+
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
175+
}
176+
177+
// Creates an instance of the TensorFlow dialect DetectionPostProcess pass.
178+
std::unique_ptr<OperationPass<func::FuncOp>>
179+
QuantizeDetectionPostProcessPass() {
180+
return std::make_unique<DetectionPostProcess>();
181+
}
182+
183+
static PassRegistration<DetectionPostProcess> pass;
184+
185+
} // namespace TFL
186+
} // namespace mlir

larq_compute_engine/mlir/transforms/passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> CreateLCEQuantizePass();
2929
// Creates an instance of LegalizeLCE pass.
3030
std::unique_ptr<OperationPass<func::FuncOp>> CreateLegalizeLCEPass();
3131

32+
// Creates an instance of the TensorFlow dialect DetectionPostProcess pass.
33+
std::unique_ptr<OperationPass<func::FuncOp>> QuantizeDetectionPostProcessPass();
34+
3235
// Creates an instance of the FusePadding pass.
3336
std::unique_ptr<OperationPass<func::FuncOp>> CreateFusePaddingPass();
3437

0 commit comments

Comments
 (0)