Skip to content

Commit d07a1ee

Browse files
authored
Add support for multi threaded interpreter (#512)
* Add support for multi threaded interpreter * Run end2end tests using multithreaded interpreter
1 parent ca5f8dd commit d07a1ee

File tree

4 files changed

+13
-8
lines changed

4 files changed

+13
-8
lines changed

larq_compute_engine/tests/end2end_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import os
23
import sys
34

45
import larq as lq
@@ -155,7 +156,7 @@ def preprocess(data):
155156

156157

157158
def assert_model_output(model_lce, inputs, outputs):
158-
interpreter = Interpreter(model_lce)
159+
interpreter = Interpreter(model_lce, num_threads=min(os.cpu_count(), 4))
159160
actual_outputs = interpreter.predict(inputs)
160161
np.testing.assert_allclose(actual_outputs, outputs, rtol=0.001, atol=0.25)
161162

larq_compute_engine/tflite/python/interpreter.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class Interpreter:
4242
4343
# Arguments
4444
flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format.
45+
num_threads: The number of threads used by the interpreter.
4546
4647
# Attributes
4748
input_types: Returns a list of input types.
@@ -50,8 +51,10 @@ class Interpreter:
5051
output_shapes: Returns a list of output shapes.
5152
"""
5253

53-
def __init__(self, flatbuffer_model: bytes):
54-
self.interpreter = interpreter_wrapper_lite.LiteInterpreter(flatbuffer_model)
54+
def __init__(self, flatbuffer_model: bytes, num_threads: int = 1):
55+
self.interpreter = interpreter_wrapper_lite.LiteInterpreter(
56+
flatbuffer_model, num_threads
57+
)
5558

5659
@property
5760
def input_types(self) -> list:

larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
class LiteInterpreterWrapper
99
: public InterpreterWrapperBase<tflite::Interpreter> {
1010
public:
11-
LiteInterpreterWrapper(const pybind11::bytes& flatbuffer);
11+
LiteInterpreterWrapper(const pybind11::bytes& flatbuffer,
12+
const int num_threads);
1213
~LiteInterpreterWrapper(){};
1314

1415
private:
@@ -20,7 +21,7 @@ class LiteInterpreterWrapper
2021
};
2122

2223
LiteInterpreterWrapper::LiteInterpreterWrapper(
23-
const pybind11::bytes& flatbuffer) {
24+
const pybind11::bytes& flatbuffer, const int num_threads = 1) {
2425
// Make a copy of the flatbuffer because it can get deallocated after the
2526
// constructor is done
2627
flatbuffer_ = static_cast<std::string>(flatbuffer);
@@ -36,7 +37,7 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
3637
compute_engine::tflite::RegisterLCECustomOps(resolver_.get());
3738

3839
tflite::InterpreterBuilder builder(*model_, *resolver_);
39-
builder(&interpreter_);
40+
builder(&interpreter_, num_threads);
4041
MINIMAL_CHECK(interpreter_ != nullptr);
4142

4243
// Allocate tensor buffers.
@@ -45,7 +46,7 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
4546

4647
PYBIND11_MODULE(interpreter_wrapper_lite, m) {
4748
pybind11::class_<LiteInterpreterWrapper>(m, "LiteInterpreter")
48-
.def(pybind11::init<const pybind11::bytes&>())
49+
.def(pybind11::init<const pybind11::bytes&, const int>())
4950
.def_property("input_types", &LiteInterpreterWrapper::get_input_types,
5051
nullptr)
5152
.def_property("output_types", &LiteInterpreterWrapper::get_output_types,

larq_compute_engine/tflite/tests/interpreter_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_interpreter_multi_input(use_iterator):
4646
expected_output_x = x_np.reshape(16, -1)
4747
expected_output_y = y_np.reshape(16, -1)
4848

49-
interpreter = Interpreter(converter.convert())
49+
interpreter = Interpreter(converter.convert(), num_threads=2)
5050
assert interpreter.input_types == [np.float32, np.float32]
5151
assert interpreter.output_types == [np.float32, np.float32]
5252
assert interpreter.input_shapes == [(1, 24, 24, 2), (1, 24, 24, 1)]

0 commit comments

Comments
 (0)