Skip to content

Commit 9007181

Browse files
authored
Add tests for MLIR passes (#248)
1 parent 8b083f9 commit 9007181

File tree

7 files changed

+325
-0
lines changed

7 files changed

+325
-0
lines changed

.github/workflows/unittests.yml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,29 @@ jobs:
5353
run: ./bazelisk run larq_compute_engine/tflite/tests:cc_tests_aarch64_qemu --config=aarch64
5454

5555
MLIR:
56+
runs-on: macos-latest
57+
if: "!contains(github.event.head_commit.message, 'ci-skip')"
58+
steps:
59+
- uses: actions/checkout@v2
60+
- name: Install Bazelisk
61+
run: |
62+
curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.3.0/bazelisk-darwin-amd64 > bazelisk
63+
chmod +x bazelisk
64+
- name: Install GNU find
65+
run: |
66+
HOMEBREW_NO_AUTO_UPDATE=1 brew install findutils
67+
ln -sf /usr/local/bin/gfind /usr/local/bin/find
68+
- uses: actions/setup-python@v1
69+
with:
70+
python-version: 3.7
71+
- name: Configure Bazel
72+
run: ./configure.sh
73+
- name: Install pip dependencies
74+
run: pip install numpy six --no-cache-dir
75+
- name: Run FileCheck tests
76+
run: ./bazelisk test larq_compute_engine/mlir/tests:all --test_output=all --distinct_host_configuration=false
77+
78+
MLIR_Python:
5679
runs-on: macos-latest
5780
if: "!contains(github.event.head_commit.message, 'ci-skip')"
5881
steps:

larq_compute_engine/mlir/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ cc_library(
136136
],
137137
)
138138

139+
cc_binary(
140+
name = "lce-tf-opt",
141+
deps = [
142+
":lce_tfl_passes",
143+
"@local_config_mlir//:MlirOptMain",
144+
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
145+
],
146+
)
147+
139148
pybind_extension(
140149
name = "_graphdef_tfl_flatbuffer",
141150
srcs = ["python/graphdef_tfl_flatbuffer.cc"],

larq_compute_engine/mlir/tests/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
load("//larq_compute_engine/mlir/tests:lit_test.bzl", "lce_lit_test_suite")
2+
3+
package(
4+
default_visibility = ["//visibility:public"],
5+
licenses = ["notice"], # Apache 2.0
6+
)
7+
8+
lce_lit_test_suite(
9+
name = "lit",
10+
srcs = glob(["*.mlir"]),
11+
data = [
12+
"//larq_compute_engine/mlir:lce-tf-opt",
13+
"@llvm//:FileCheck",
14+
],
15+
)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2019 Google LLC
2+
# Modifications copyright (C) 2020 Larq Contributors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Copied from https://github.com/google/iree/blob/master/iree/lit_test.bzl
17+
18+
"""Bazel macros for running lit tests."""
19+
20+
def lce_lit_test(
21+
name,
22+
test_file,
23+
data,
24+
size = "small",
25+
driver = "//larq_compute_engine/mlir/tests:run_lit.sh",
26+
**kwargs):
27+
"""Creates a lit test from the specified source file.
28+
29+
Args:
30+
name: name of the generated test suite.
31+
test_file: the test file with the lit test
32+
data: binaries used in the lit tests.
33+
size: size of the tests.
34+
driver: the shell runner for the lit tests.
35+
**kwargs: Any additional arguments that will be passed to the underlying sh_test.
36+
"""
37+
native.sh_test(
38+
name = name,
39+
srcs = [driver],
40+
size = size,
41+
data = data + [test_file],
42+
args = ["$(location %s)" % (test_file,)],
43+
**kwargs
44+
)
45+
46+
def lce_lit_test_suite(
47+
name,
48+
data,
49+
srcs,
50+
size = "small",
51+
driver = "//larq_compute_engine/mlir/tests:run_lit.sh",
52+
tags = [],
53+
**kwargs):
54+
"""Creates one lit test per source file and a test suite that bundles them.
55+
56+
Args:
57+
name: name of the generated test suite.
58+
data: binaries used in the lit tests.
59+
srcs: test file sources.
60+
size: size of the tests.
61+
driver: the shell runner for the lit tests.
62+
tags: tags to apply to the test. Note that as in standard test suites, manual
63+
is treated specially and will also apply to the test suite itself.
64+
**kwargs: Any additional arguments that will be passed to the underlying tests and test_suite.
65+
"""
66+
tests = []
67+
for test_file in srcs:
68+
# It's generally good practice to prefix any generated names with the
69+
# macro name, but we're trying to match the style of the names that are
70+
# used for LLVM internally.
71+
test_name = "%s.test" % (test_file)
72+
lce_lit_test(
73+
name = test_name,
74+
test_file = test_file,
75+
size = size,
76+
data = data,
77+
driver = driver,
78+
tags = tags,
79+
**kwargs
80+
)
81+
tests.append(test_name)
82+
83+
native.test_suite(
84+
name = name,
85+
tests = tests,
86+
# Note that only the manual tag really has any effect here. Others are
87+
# used for test suite filtering, but all tests are passed the same tags.
88+
tags = tags,
89+
# If there are kwargs that need to be passed here which only apply to
90+
# the generated tests and not to test_suite, they should be extracted
91+
# into separate named arguments.
92+
**kwargs
93+
)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: lce-tf-opt %s -tfl-optimize-lce | FileCheck %s
2+
3+
// CHECK-LABEL: fuseAddIntoBConv2d
4+
func @fuseAddIntoBConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
5+
%cst = constant dense<1.5> : tensor<16xf32>
6+
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
7+
%0 = "tf.LqceBconv2d64"(%arg0, %arg1, %arg2, %cst_0) {data_format = "NHWC", dilations = [1, 1, 1, 1], filter_format = "OHWI", padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
8+
%1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
9+
return %1 : tensor<256x30x30x16xf32>
10+
11+
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
12+
// CHECK: %0 = "tf.LqceBconv2d64"(%arg0, %arg1, %arg2, %cst)
13+
}
14+
15+
16+
// CHECK-LABEL: fuseSubIntoBConv2d
17+
func @fuseSubIntoBConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
18+
%cst = constant dense<0.5> : tensor<16xf32>
19+
%cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
20+
%0 = "tf.LqceBconv2d64"(%arg0, %arg1, %arg2, %cst_0) {data_format = "NHWC", dilations = [1, 1, 1, 1], filter_format = "OHWI", padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
21+
%1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
22+
return %1 : tensor<256x30x30x16xf32>
23+
24+
// CHECK: %cst = constant dense<[5.000000e-01, 1.500000e+00, 2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01]> : tensor<16xf32>
25+
// CHECK: %0 = "tf.LqceBconv2d64"(%arg0, %arg1, %arg2, %cst)
26+
}
27+
28+
// CHECK-LABEL: @fuseDivIntoBConv2d
29+
func @fuseDivIntoBConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
30+
%cst = constant dense<0.5> : tensor<16xf32>
31+
%cst_0 = constant dense<1.5> : tensor<16xf32>
32+
%cst_1 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
33+
%0 = "tf.LqceBconv2d64"(%arg0, %arg1, %cst_1, %cst_0) {data_format = "NHWC", dilations = [1, 1, 1, 1], filter_format = "OHWI", padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
34+
%1 = "tfl.div"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
35+
return %1 : tensor<256x30x30x16xf32>
36+
37+
// CHECK: %cst = constant dense<[2.000000e+00, 4.000000e+00, 6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01, 1.400000e+01, 1.600000e+01, 1.800000e+01, 2.000000e+01, 2.200000e+01, 2.400000e+01, 2.600000e+01, 2.800000e+01, 3.000000e+01, 3.200000e+01]> : tensor<16xf32>
38+
// CHECK: %cst_0 = constant dense<3.000000e+00> : tensor<16xf32>
39+
// CHECK: %0 = "tf.LqceBconv2d64"(%arg0, %arg1, %cst, %cst_0)
40+
}
41+
42+
// CHECK-LABEL: @fuseMulIntoBConv2d
43+
func @fuseMulIntoBConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x30x30x16xf32> {
44+
%cst = constant dense<2.0> : tensor<16xf32>
45+
%cst_0 = constant dense<1.5> : tensor<16xf32>
46+
%cst_1 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
47+
%0 = "tf.LqceBconv2d64"(%arg0, %arg1, %cst_1, %cst_0) {data_format = "NHWC", dilations = [1, 1, 1, 1], filter_format = "OHWI", padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
48+
%1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
49+
return %1 : tensor<256x30x30x16xf32>
50+
51+
// CHECK: %cst = constant dense<[2.000000e+00, 4.000000e+00, 6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01, 1.400000e+01, 1.600000e+01, 1.800000e+01, 2.000000e+01, 2.200000e+01, 2.400000e+01, 2.600000e+01, 2.800000e+01, 3.000000e+01, 3.200000e+01]> : tensor<16xf32>
52+
// CHECK: %cst_0 = constant dense<3.000000e+00> : tensor<16xf32>
53+
// CHECK: %0 = "tf.LqceBconv2d64"(%arg0, %arg1, %cst, %cst_0)
54+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: lce-tf-opt %s -tfl-prepare-lce | FileCheck %s --dump-input-on-failure
2+
3+
func @bsign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
4+
%0 = "tf.Sign"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
5+
%cst = constant dense<0.1> : tensor<f32>
6+
%2 = "tf.AddV2"(%0, %cst) : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
7+
%3 = "tf.Sign"(%2) : (tensor<8x16xf32>) -> tensor<8x16xf32>
8+
return %3 : tensor<8x16xf32>
9+
// CHECK-LABEL: bsign
10+
// CHECK: %0 = "tf.LqceBsign"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
11+
// CHECK: return %0
12+
}
13+
14+
func @bconv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
15+
%cst = constant dense<[[[[1.0, -1.0], [1.0, 1.0]], [[-1.0, 1.0], [-1.0, 1.0]]]]> : tensor<1x2x2x2xf32>
16+
%0 = "tf.LqceBsign"(%arg0) : (tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32>
17+
%1 = "tf.Conv2D"(%0, %cst) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x112x112x2xf32>
18+
return %1 : tensor<1x112x112x2xf32>
19+
// CHECK-LABEL: bconv2d
20+
// CHECK: %[[CST1:.*]] = constant dense<-2.000000e+00> : tensor<2xf32>
21+
// CHECK: %[[CST2:.*]] = constant dense<4.000000e+00> : tensor<2xf32>
22+
// CHECK: %[[TRP:.*]] = "tf.Transpose"
23+
// CHECK: %[[CONV:.*]] = "tf.LqceBconv2d64"(%arg0, %[[TRP]], %[[CST1]], %[[CST2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], filter_format = "OHWI", padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x112x112x2xf32>, tensor<2x1x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
24+
// CHECK: return %[[CONV]]
25+
}
26+
27+
func @notbconv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
28+
%cst = constant dense<0.5> : tensor<1x2x2x2xf32>
29+
%0 = "tf.LqceBsign"(%arg0) : (tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32>
30+
%1 = "tf.Conv2D"(%0, %cst) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x112x112x2xf32>
31+
return %1 : tensor<1x112x112x2xf32>
32+
// CHECK-LABEL: notbconv2d
33+
// CHECK: %0 = "tf.LqceBsign"(%arg0) : (tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32>
34+
// CHECK: %1 = "tf.Conv2D"
35+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#!/usr/bin/env bash
2+
3+
# Copyright 2019 Google LLC
4+
# Modifications copyright (C) 2020 Larq Contributors.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# https://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
# Copied from https://github.com/google/iree/blob/master/iree/tools/run_lit.sh
19+
20+
set -e
21+
set -o pipefail
22+
23+
if [ -z "${RUNFILES_DIR}" ]; then
24+
# Some versions of bazel do not set RUNFILES_DIR. Instead they just cd
25+
# into the directory.
26+
RUNFILES_DIR="$PWD"
27+
fi
28+
29+
# Detect whether cygwin/msys2 paths need to be translated.
30+
set +e # Ignore errors if not found.
31+
cygpath="$(which cygpath 2>/dev/null)"
32+
set -e
33+
34+
function find_executables() {
35+
set -e
36+
local p="$1"
37+
if [ -z "$cygpath" ]; then
38+
# For non-windows, use the perm based executable check, which has been
39+
# supported by find for a very long time.
40+
find "${p}" -xtype f -perm /u=x,g=x,o=x -print
41+
else
42+
# For windows, always use the newer -executable find predicate (which is
43+
# not supported by ancient versions of find).
44+
find "${p}" -xtype f -executable -print
45+
fi
46+
}
47+
48+
# Bazel helpfully puts all data deps in the ${RUNFILES_DIR}, but
49+
# it unhelpfully preserves the nesting with no way to reason about
50+
# it generically. run_lit expects that anything passed in the runfiles
51+
# can be found on the path for execution. So we just iterate over the
52+
# entries in the MANIFEST and extend the PATH.
53+
SUBPATH=""
54+
for runfile_path in $(find_executables "${RUNFILES_DIR}"); do
55+
# Prepend so that local things override.
56+
EXEDIR="$(dirname ${runfile_path})"
57+
if ! [ -z "$cygpath" ]; then
58+
EXEDIR="$($cygpath -u "$EXEDIR")"
59+
fi
60+
SUBPATH="${EXEDIR}:$SUBPATH"
61+
done
62+
63+
echo "run_lit.sh: $1"
64+
echo "PWD=$(pwd)"
65+
66+
# For each "// RUN:" line, run the command.
67+
runline_matches="$(egrep "^// RUN: " "$1")"
68+
if [ -z "$runline_matches" ]; then
69+
echo "!!! No RUN lines found in test"
70+
exit 1
71+
fi
72+
73+
echo "$runline_matches" | while read -r runline
74+
do
75+
echo "RUNLINE: $runline"
76+
match="${runline%%// RUN: *}"
77+
command="${runline##// RUN: }"
78+
if [ -z "${command}" ]; then
79+
echo "ERROR: Could not extract command from runline"
80+
exit 1
81+
fi
82+
83+
# Substitute any embedded '%s' with the file name.
84+
full_command="${command//\%s/$1}"
85+
86+
# Run it.
87+
export PATH="$SUBPATH"
88+
echo "RUNNING TEST: $full_command"
89+
echo "----------------"
90+
if eval "$full_command"; then
91+
echo "--- COMPLETE ---"
92+
else
93+
echo "!!! ERROR EVALUATING: $full_command"
94+
exit 1
95+
fi
96+
done

0 commit comments

Comments
 (0)