Skip to content

Commit 4b74994

Browse files
authored
Add bitpack order argument to packbits_tensor (#337)
1 parent b013bdc commit 4b74994

File tree

4 files changed

+18
-14
lines changed

4 files changed

+18
-14
lines changed

larq_compute_engine/core/packbits_utils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ int GetPackedTensorSize(const RuntimeShape& shape) {
2424

2525
// Convenience function for bitpacking a tensor along its last dimension
2626
// and updating the tensor shape
27-
template <class T, class TBitpacked>
27+
template <BitpackOrder bitpack_order, class T, class TBitpacked>
2828
inline void packbits_tensor(const RuntimeShape& in_shape, const T* in_data,
2929
const std::int32_t zero_point,
3030
RuntimeShape& out_shape, TBitpacked* out_data) {
@@ -35,8 +35,8 @@ inline void packbits_tensor(const RuntimeShape& in_shape, const T* in_data,
3535

3636
{
3737
gemmlowp::ScopedProfilingLabel label("Packbits");
38-
ce::core::packbits_matrix<ce::core::BitpackOrder::Optimized>(
39-
in_data, rows, cols, out_data, zero_point);
38+
ce::core::packbits_matrix<bitpack_order>(in_data, rows, cols, out_data,
39+
zero_point);
4040
}
4141

4242
out_shape.ReplaceWith(dims, in_shape.DimsData());

larq_compute_engine/tflite/kernels/bconv2d.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -775,9 +775,9 @@ void EvalRef(TfLiteContext* context, TfLiteNode* node,
775775
} else {
776776
TfLiteTensor* packed_input =
777777
GetTemporary(context, node, params->packed_input_index);
778-
ce::core::packbits_tensor(input_shape, input_data, input->params.zero_point,
779-
packed_input_shape,
780-
GetTensorData<TBitpacked>(packed_input));
778+
ce::core::packbits_tensor<ce::core::BitpackOrder::Canonical>(
779+
input_shape, input_data, input->params.zero_point, packed_input_shape,
780+
GetTensorData<TBitpacked>(packed_input));
781781
packed_input_data = GetTensorData<TBitpacked>(packed_input);
782782
}
783783

larq_compute_engine/tflite/kernels/bconv2d_impl.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,9 @@ inline void BConv2D(
145145
} else {
146146
// The input tensor has this shape which we bitpack along the channels
147147
// dimension [batch, input height, input width, channels].
148-
ce::core::packbits_tensor(input_shape, input_data, params.input_offset,
149-
packed_input_shape, packed_input_data);
148+
ce::core::packbits_tensor<ce::core::BitpackOrder::Optimized>(
149+
input_shape, input_data, params.input_offset, packed_input_shape,
150+
packed_input_data);
150151
im2col_input_data = packed_input_data;
151152
}
152153
im2col<TBitpacked>(params, packed_input_shape, im2col_input_data,
@@ -166,8 +167,9 @@ inline void BConv2D(
166167
// The RHS tensor has this shape which we bitpack along the last dimension
167168
// [batch, output_height, output_width, k * bitwidth]
168169
RuntimeShape packed_input_shape;
169-
ce::core::packbits_tensor(result_shape, result_data, params.input_offset,
170-
packed_input_shape, packed_input_data);
170+
ce::core::packbits_tensor<ce::core::BitpackOrder::Optimized>(
171+
result_shape, result_data, params.input_offset, packed_input_shape,
172+
packed_input_data);
171173
rhs_data = packed_input_data;
172174

173175
k = packed_input_shape.Dims(3);

larq_compute_engine/tflite/tests/bconv2d_test.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,9 @@ void set_lce_op_input(const RuntimeShape& input_shape,
341341
std::vector<std::int32_t> input_data_bp(
342342
core::GetPackedTensorSize<std::int32_t>(input_shape));
343343
RuntimeShape output_shape;
344-
core::packbits_tensor(input_shape, input_data.data(), zero_point,
345-
output_shape, input_data_bp.data());
344+
core::packbits_tensor<ce::core::BitpackOrder::Canonical>(
345+
input_shape, input_data.data(), zero_point, output_shape,
346+
input_data_bp.data());
346347
m_lce.SetInput(input_data_bp);
347348
}
348349

@@ -358,8 +359,9 @@ void test_lce_op_output(const std::vector<std::int32_t>& lce_output_data,
358359
std::vector<std::int32_t> builtin_output_data_bp(
359360
core::GetPackedTensorSize<std::int32_t>(out_shape));
360361
RuntimeShape packed_shape;
361-
core::packbits_tensor(out_shape, builtin_output_data.data(), zero_point,
362-
packed_shape, builtin_output_data_bp.data());
362+
core::packbits_tensor<ce::core::BitpackOrder::Canonical>(
363+
out_shape, builtin_output_data.data(), zero_point, packed_shape,
364+
builtin_output_data_bp.data());
363365

364366
// We need the outputs here to be bit-exact, so don't allow for floating
365367
// point imprecision.

0 commit comments

Comments
 (0)