#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // quant::UniformQuantizedType
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {

struct DetectionPostProcess
    : public PassWrapper<DetectionPostProcess,
                         OperationPass<mlir::func::FuncOp>> {
  llvm::StringRef getArgument() const final {
    return "detection-postprocess-int";
  }
  llvm::StringRef getDescription() const final {
    return "Make detection postprocessing op run with int8 input";
  }
  void runOnOperation() override;
};

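// Rewrite pattern: match a 'TFLite_Detection_PostProcess' custom op whose
// three inputs (boxes, scores, anchors) are each produced by a tfl.dequantize
// op, drop those dequantize ops so the custom op consumes the int8 tensors
// directly, and re-insert a single dequantize on the scores output so that
// existing float consumers of the scores remain valid.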
struct RemoveDequantizeBeforePostProcess : public OpRewritePattern<CustomOp> {
  using OpRewritePattern<CustomOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CustomOp detection_op,
                                PatternRewriter& rewriter) const override {
    // ----------------- matching part -----------------

    // Match the custom op code to 'TFLite_Detection_PostProcess'
    auto custom_code = detection_op.custom_code().str();
    if (custom_code != "TFLite_Detection_PostProcess") {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "op 'tfl.custom' attribute 'custom_code' failed to satisfy "
                "constraint: constant attribute TFLite_Detection_PostProcess";
      });
    }

    // Check the number of inputs and outputs of the detection op
    auto original_detection_inputs = detection_op.input();
    if (original_detection_inputs.size() != 3) {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "expected 3 inputs for the detection op";
      });
    }
    auto original_detection_outputs = detection_op.output();
    if (original_detection_outputs.size() != 4) {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "expected 4 outputs for the original detection op";
      });
    }

    // Match that dequantization happens just before the detection op
    auto boxes_input_op = original_detection_inputs[0].getDefiningOp();
    auto original_boxes_dequantize_op =
        llvm::dyn_cast_or_null<DequantizeOp>(boxes_input_op);
    if (!original_boxes_dequantize_op) {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "expected dequantization before the op for the 'boxes' input";
      });
    }
    auto scores_input_op = original_detection_inputs[1].getDefiningOp();
    auto original_scores_dequantize_op =
        llvm::dyn_cast_or_null<DequantizeOp>(scores_input_op);
    if (!original_scores_dequantize_op) {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "expected dequantization before the op for the 'scores' input";
      });
    }
    auto anchors_input_op = original_detection_inputs[2].getDefiningOp();
    auto original_anchors_dequantize_op =
        llvm::dyn_cast_or_null<DequantizeOp>(anchors_input_op);
    if (!original_anchors_dequantize_op) {
      return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
        diag << "expected dequantization before the op for the 'anchors' input";
      });
    }

    // Verify the output types of the current detection op:
    // Output type #0: [int32] detection boxes (scaled by 2048)
    // Output type #1: [int32] detection class IDs
    // Output type #2: [float32] detection scores
    // Output type #3: [int32] number of detections
    auto output_data_types = SmallVector<Type, 4>{
        rewriter.getIntegerType(32),
        rewriter.getIntegerType(32),
        rewriter.getF32Type(),
        rewriter.getIntegerType(32),
    };
    for (auto i = 0; i < 4; ++i) {
      auto original_type =
          original_detection_outputs[i].getType().cast<ShapedType>();
      auto original_data_type = original_type.getElementType();
      if (output_data_types[i] != original_data_type) {
        return rewriter.notifyMatchFailure(detection_op, [&](Diagnostic& diag) {
          diag << "unexpected output type of the op";
        });
      }
    }

    // ----------------- re-write part -----------------

    // Get the original inputs (before dequantization)
    auto boxes_input = original_boxes_dequantize_op.input();
    auto scores_input = original_scores_dequantize_op.input();
    auto anchors_input = original_anchors_dequantize_op.input();

    // Set the new detection inputs
    auto new_detection_inputs =
        SmallVector<Value, 3>{boxes_input, scores_input, anchors_input};

    // Set the 4 output types [boxes, classes, scores, num_detections]:
    // Output type #0: [int32] detection boxes (scaled by 2048)
    // Output type #1: [int32] detection class IDs
    // Output type #2: [int8 quantized] detection scores
    // Output type #3: [int32] number of detections
    // All as before, except for output #2 (float -> int8 quantized)
    auto scores_type = scores_input.getType()
                           .cast<ShapedType>()
                           .getElementType()
                           .cast<quant::UniformQuantizedType>();
    const auto scores_zp = scores_type.getZeroPoint();
    const auto scores_scale = scores_type.getScale();
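    // quant::UniformQuantizedType::get(flags, storageType, expressedType,
    // scale, zeroPoint, storageTypeMin, storageTypeMax): passing `true` sets
    // the Signed flag, so the scores keep the input's scale and zero-point but
    // are stored as signed int8 in [-128, 127].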
    output_data_types[2] = quant::UniformQuantizedType::get(
        true, rewriter.getIntegerType(8), rewriter.getF32Type(), scores_scale,
        scores_zp, -128, 127);

    // Set for all the outputs: data-type (as set above) and shape (as before)
    auto new_op_output_types = SmallVector<Type, 4>{};
    for (auto i = 0; i < 4; ++i) {
      auto value = original_detection_outputs[i];
      auto shape = value.getType().cast<ShapedType>().getShape();
      auto new_output_type = RankedTensorType::get(shape, output_data_types[i]);
      new_op_output_types.push_back(new_output_type);
    }

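    // Note: the original op's custom_option attribute is reused unchanged
    // below, so the new op keeps the same post-processing configuration as
    // the original.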
    // Add a new detection op (with int8 input and int8/int32 output)
    auto new_detection_op = rewriter.create<CustomOp>(
        detection_op->getLoc(), new_op_output_types, new_detection_inputs,
        std::string{"TFLite_Detection_PostProcess"},
        detection_op.custom_option());

    // Add the 4 outputs: boxes, classes, scores, num_detections
    auto new_outputs = SmallVector<Value, 4>{};

    // Output #0: [int32] detection boxes (scaled by 2048)
    new_outputs.push_back(new_detection_op.output()[0]);

    // Output #1: [int32] detection class IDs
    new_outputs.push_back(new_detection_op.output()[1]);

    // Output #2: [int8 quantized] detection scores, dequantized back to float
    // so that existing float consumers of the scores stay valid
    auto new_dequantize_op = rewriter.create<DequantizeOp>(
        detection_op->getLoc(), original_detection_outputs[2].getType(),
        new_detection_op.output()[2]);
    new_outputs.push_back(new_dequantize_op.output());

    // Output #3: [int32] number of detections
    new_outputs.push_back(new_detection_op.output()[3]);

    // Final re-write: replace the original op with the new detection op
    // followed by dequantization of the scores
    rewriter.replaceOp(detection_op, new_outputs);
    return success();
  }
};

void DetectionPostProcess::runOnOperation() {
  auto* ctx = &getContext();
  RewritePatternSet patterns(ctx);
  auto func = getOperation();

  patterns.add<RemoveDequantizeBeforePostProcess>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

// Creates an instance of the TensorFlow Lite dialect DetectionPostProcess
// pass.
std::unique_ptr<OperationPass<func::FuncOp>>
QuantizeDetectionPostProcessPass() {
  return std::make_unique<DetectionPostProcess>();
}
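
// A minimal usage sketch (assumptions: the caller already has an
// `MLIRContext context` and a `ModuleOp module`, and includes
// "mlir/Pass/PassManager.h"):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::TFL::QuantizeDetectionPostProcessPass());
//   if (failed(pm.run(module))) {
//     // handle pass failure
//   }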

static PassRegistration<DetectionPostProcess> pass;

}  // namespace TFL
}  // namespace mlir