Commit 878cb23

Fix weight bitpacking which could lead to non-deterministic behaviour (#377) (#378)
1 parent: 53c9946
File tree: 2 files changed, +13 -5 lines

  larq_compute_engine/mlir/tests/optimize.mlir
  larq_compute_engine/mlir/transforms/optimize.cc

larq_compute_engine/mlir/tests/optimize.mlir

Lines changed: 1 addition & 0 deletions
@@ -150,6 +150,7 @@ func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16xf
   %0 = "tf.LceBconv2d"(%arg0, %cst, %arg1, %arg2) {activation = "NONE", channels_in = 3 : i32, filter_format = "OHWI", padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   return %0 : tensor<256x30x30x16xf32>

+  // CHECK: %cst = constant dense<0> : tensor<16x3x3x1xi32>
   // CHECK: %0 = "tf.LceBconv2d"(%arg0, %cst, %arg1, %arg2) {activation = "NONE", channels_in = 3 : i32, data_format = "NHWC", dilations = [1, 1, 1, 1], filter_format = "OHWI_PACKED", pad_values = 0 : i32, padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<16x3x3x1xi32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   // CHECK-NEXT: return %0
 }
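
The added CHECK line reflects how the pass bitpacks the filter along the channels-in dimension: the 3 float input channels fit into a single 32-bit word, so the 16x3x3x3xf32 constant becomes a 16x3x3x1xi32 constant. Below is a minimal C++ sketch of that shape arithmetic, mirroring the packed_channels expression in optimize.cc further down; the helper name is illustrative and not part of the repository.

#include <cassert>

// Ceil-divide the channels-in dimension by the packed word bit width,
// as the `packed_channels` computation in Bitpack() does.
int PackedChannels(int unpacked_channels, int bitwidth = 32) {
  return (unpacked_channels + bitwidth - 1) / bitwidth;
}

int main() {
  // 16x3x3x3 f32 filter -> 16x3x3x1 i32 filter, matching the CHECK line above.
  assert(PackedChannels(3) == 1);
  // 64 input channels would need two 32-bit words per filter position.
  assert(PackedChannels(64) == 2);
  return 0;
}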

larq_compute_engine/mlir/transforms/optimize.cc

Lines changed: 12 additions & 5 deletions
@@ -36,8 +36,9 @@ bool IsConstantValue(Attribute values, float expected_value) {

 bool IsConv2DFilter(Attribute filter) {
   if (!filter.isa<DenseElementsAttr>()) return false;
-  if (filter.getType().cast<ShapedType>().getShape().size() != 4) return false;
-  return true;
+  auto filter_type = filter.getType().cast<ShapedType>();
+  return filter_type.getElementType().isF32() &&
+         filter_type.getShape().size() == 4;
 }

 DenseElementsAttr Bitpack(PatternRewriter& builder, Attribute x) {
@@ -53,11 +54,17 @@ DenseElementsAttr Bitpack(PatternRewriter& builder, Attribute x) {
   int packed_channels = (unpacked_channels + bitwidth - 1) / bitwidth;

   std::vector<PackedType> new_values(num_rows * packed_channels);
+  std::vector<float> old_values(num_rows * unpacked_channels);
+
+  int i = 0;
+  for (float x : dense_elements_iter) {
+    old_values[i++] = x;
+  }
+  assert(i == num_rows * unpacked_channels);

-  const float* in_ptr = &(*dense_elements_iter.begin());
   using namespace compute_engine::core;
-  packbits_matrix<BitpackOrder::Canonical>(in_ptr, num_rows, unpacked_channels,
-                                           new_values.data());
+  packbits_matrix<BitpackOrder::Canonical>(
+      old_values.data(), num_rows, unpacked_channels, new_values.data());

   RankedTensorType out_tensor_type =
       RankedTensorType::get({shape[0], shape[1], shape[2], packed_channels},
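
As the commit title suggests, the removed code was the likely source of the non-determinism: `&(*dense_elements_iter.begin())` takes a raw pointer out of the attribute's value iterator, and that pointer need not refer to a stable, contiguous float buffer, so packbits_matrix could read stale or unrelated memory. The new code first materialises every element into an owned std::vector<float> and only then hands a raw pointer to the packing routine. (The stricter IsConv2DFilter check additionally requires an f32 element type, presumably so that filters already bitpacked into i32 are no longer matched.) Below is a self-contained sketch of this copy-then-pack pattern, with hypothetical names standing in for the real attribute values and for packbits_matrix; it is not the LCE implementation.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for packbits_matrix: packs the sign bits of one row of
// `unpacked_channels` floats into 32-bit words.
void PackRowSignBits(const float* in, int unpacked_channels, uint32_t* out) {
  for (int c = 0; c < unpacked_channels; ++c) {
    if (in[c] < 0.0f) out[c / 32] |= (1u << (c % 32));
  }
}

int main() {
  const int num_rows = 2, unpacked_channels = 3, bitwidth = 32;
  const int packed_channels = (unpacked_channels + bitwidth - 1) / bitwidth;

  // Stand-in for the attribute's value range: the real iterator yields floats
  // one at a time, so the safe approach is to copy them into an owned buffer
  // rather than keep a raw pointer into the iterator.
  std::vector<float> attribute_values = {0.5f, -1.0f, 2.0f, -0.25f, -0.5f, 1.0f};

  std::vector<float> old_values(num_rows * unpacked_channels);
  int i = 0;
  for (float x : attribute_values) old_values[i++] = x;
  assert(i == num_rows * unpacked_channels);

  std::vector<uint32_t> new_values(num_rows * packed_channels, 0);
  for (int r = 0; r < num_rows; ++r) {
    PackRowSignBits(old_values.data() + r * unpacked_channels,
                    unpacked_channels, new_values.data() + r * packed_channels);
  }
  assert(new_values[0] == 0b010u);  // row 0: only channel 1 is negative
  assert(new_values[1] == 0b011u);  // row 1: channels 0 and 1 are negative
  return 0;
}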
