Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 705a743

Browse files
antiagainsttensorflower-gardener
authored andcommitted
[spirv] Implement inliner interface
We just need to implement a few interface hooks to DialectInlinerInterface and CallOpInterface to gain the benefits of an inliner. :) Right now only supports some trivial cases: * Inlining single block with spv.Return/spv.ReturnValue * Inlining multi block with spv.Return * Inlining spv.selection/spv.loop without return ops More advanced cases will require block argument and Phi support. PiperOrigin-RevId: 275151132
1 parent 288ac1b commit 705a743

File tree

5 files changed

+264
-2
lines changed

5 files changed

+264
-2
lines changed

include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
include "mlir/SPIRV/SPIRVBase.td"
3030
#endif // SPIRV_BASE
3131

32+
#ifdef MLIR_CALLINTERFACES
33+
#else
34+
include "mlir/Analysis/CallInterfaces.td"
35+
#endif // MLIR_CALLINTERFACES
36+
3237
// -----
3338

3439
def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
@@ -151,7 +156,8 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
151156

152157
// -----
153158

154-
def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [InFunctionScope]> {
159+
def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [
160+
InFunctionScope, DeclareOpInterfaceMethods<CallOpInterface>]> {
155161
let summary = "Call a function.";
156162

157163
let description = [{

include/mlir/Dialect/SPIRV/SPIRVStructureOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> {
264264
}
265265

266266
def SPV_ModuleOp : SPV_Op<"module",
267-
[SingleBlockImplicitTerminator<"ModuleEndOp">,
267+
[IsolatedFromAbove,
268+
SingleBlockImplicitTerminator<"ModuleEndOp">,
268269
NativeOpTrait<"SymbolTable">]> {
269270
let summary = "The top-level op that defines a SPIR-V module";
270271

lib/Dialect/SPIRV/SPIRVDialect.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/StandardTypes.h"
1919
#include "mlir/Parser.h"
2020
#include "mlir/Support/StringExtras.h"
21+
#include "mlir/Transforms/InliningUtils.h"
2122
#include "llvm/ADT/DenseMap.h"
2223
#include "llvm/ADT/Sequence.h"
2324
#include "llvm/ADT/StringExtras.h"
@@ -34,6 +35,67 @@ namespace spirv {
3435
using namespace mlir;
3536
using namespace mlir::spirv;
3637

38+
//===----------------------------------------------------------------------===//
39+
// InlinerInterface
40+
//===----------------------------------------------------------------------===//
41+
42+
/// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
43+
static inline bool containsReturn(Region &region) {
44+
return llvm::any_of(region, [](Block &block) {
45+
Operation *terminator = block.getTerminator();
46+
return isa<spirv::ReturnOp>(terminator) ||
47+
isa<spirv::ReturnValueOp>(terminator);
48+
});
49+
}
50+
51+
namespace {
52+
/// This class defines the interface for inlining within the SPIR-V dialect.
53+
struct SPIRVInlinerInterface : public DialectInlinerInterface {
54+
using DialectInlinerInterface::DialectInlinerInterface;
55+
56+
/// Returns true if the given region 'src' can be inlined into the region
57+
/// 'dest' that is attached to an operation registered to the current dialect.
58+
bool isLegalToInline(Operation *op, Region *dest,
59+
BlockAndValueMapping &) const final {
60+
// TODO(antiagainst): Enable inlining structured control flows with return.
61+
if ((isa<spirv::SelectionOp>(op) || isa<spirv::LoopOp>(op)) &&
62+
containsReturn(op->getRegion(0)))
63+
return false;
64+
// TODO(antiagainst): we need to filter OpKill here to avoid inlining it to
65+
// a loop continue construct:
66+
// https://github.com/KhronosGroup/SPIRV-Headers/issues/86
67+
// However OpKill is fragment shader specific and we don't support it yet.
68+
return true;
69+
}
70+
71+
/// Handle the given inlined terminator by replacing it with a new operation
72+
/// as necessary.
73+
void handleTerminator(Operation *op, Block *newDest) const final {
74+
if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
75+
OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
76+
op->erase();
77+
} else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
78+
llvm_unreachable("unimplemented spv.ReturnValue in inliner");
79+
}
80+
}
81+
82+
/// Handle the given inlined terminator by replacing it with a new operation
83+
/// as necessary.
84+
void handleTerminator(Operation *op,
85+
ArrayRef<Value *> valuesToRepl) const final {
86+
// Only spv.ReturnValue needs to be handled here.
87+
auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
88+
if (!retValOp)
89+
return;
90+
91+
// Replace the values directly with the return operands.
92+
assert(valuesToRepl.size() == 1 &&
93+
"spv.ReturnValue expected to only handle one result");
94+
valuesToRepl.front()->replaceAllUsesWith(retValOp.value());
95+
}
96+
};
97+
} // namespace
98+
3799
//===----------------------------------------------------------------------===//
38100
// SPIR-V Dialect
39101
//===----------------------------------------------------------------------===//
@@ -48,6 +110,8 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
48110
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
49111
>();
50112

113+
addInterfaces<SPIRVInlinerInterface>();
114+
51115
// Allow unknown operations because SPIR-V is extensible.
52116
allowUnknownOperations();
53117
}

lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
2323

24+
#include "mlir/Analysis/CallInterfaces.h"
2425
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
2526
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
2627
#include "mlir/IR/Builders.h"
@@ -1199,6 +1200,14 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
11991200
return success();
12001201
}
12011202

1203+
CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
1204+
return getAttrOfType<SymbolRefAttr>(kCallee);
1205+
}
1206+
1207+
Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
1208+
return arguments();
1209+
}
1210+
12021211
//===----------------------------------------------------------------------===//
12031212
// spv.globalVariable
12041213
//===----------------------------------------------------------------------===//
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline)' -mlir-disable-inline-simplify | FileCheck %s
2+
3+
spv.module "Logical" "GLSL450" {
4+
func @callee() {
5+
spv.Return
6+
}
7+
8+
// CHECK-LABEL: func @calling_single_block_ret_func
9+
func @calling_single_block_ret_func() {
10+
// CHECK-NEXT: spv.Return
11+
spv.FunctionCall @callee() : () -> ()
12+
spv.Return
13+
}
14+
}
15+
16+
// -----
17+
18+
spv.module "Logical" "GLSL450" {
19+
func @callee() -> i32 {
20+
%0 = spv.constant 42 : i32
21+
spv.ReturnValue %0 : i32
22+
}
23+
24+
// CHECK-LABEL: func @calling_single_block_retval_func
25+
func @calling_single_block_retval_func() -> i32 {
26+
// CHECK-NEXT: %[[CST:.*]] = spv.constant 42
27+
%0 = spv.FunctionCall @callee() : () -> (i32)
28+
// CHECK-NEXT: spv.ReturnValue %[[CST]]
29+
spv.ReturnValue %0 : i32
30+
}
31+
}
32+
33+
// -----
34+
35+
spv.module "Logical" "GLSL450" {
36+
spv.globalVariable @data bind(0, 0) : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
37+
func @callee() {
38+
%0 = spv._address_of @data : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
39+
%1 = spv.constant 0: i32
40+
%2 = spv.AccessChain %0[%1, %1] : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
41+
spv.Branch ^next
42+
43+
^next:
44+
%3 = spv.constant 42: i32
45+
spv.Store "StorageBuffer" %2, %3 : i32
46+
spv.Return
47+
}
48+
49+
// CHECK-LABEL: func @calling_multi_block_ret_func
50+
func @calling_multi_block_ret_func() {
51+
// CHECK-NEXT: spv._address_of
52+
// CHECK-NEXT: spv.constant 0
53+
// CHECK-NEXT: spv.AccessChain
54+
// CHECK-NEXT: spv.Branch ^bb1
55+
// CHECK-NEXT: ^bb1:
56+
// CHECK-NEXT: spv.constant
57+
// CHECK-NEXT: spv.Store
58+
// CHECK-NEXT: spv.Branch ^bb2
59+
spv.FunctionCall @callee() : () -> ()
60+
// CHECK-NEXT: ^bb2:
61+
// CHECK-NEXT: spv.Return
62+
spv.Return
63+
}
64+
}
65+
66+
// TODO: calling_multi_block_retval_func
67+
68+
// -----
69+
70+
spv.module "Logical" "GLSL450" {
71+
func @callee(%cond : i1) -> () {
72+
spv.selection {
73+
spv.BranchConditional %cond, ^then, ^merge
74+
^then:
75+
spv.Return
76+
^merge:
77+
spv._merge
78+
}
79+
spv.Return
80+
}
81+
82+
// CHECK-LABEL: calling_selection_ret_func
83+
func @calling_selection_ret_func() {
84+
%0 = spv.constant true
85+
// CHECK: spv.FunctionCall
86+
spv.FunctionCall @callee(%0) : (i1) -> ()
87+
spv.Return
88+
}
89+
}
90+
91+
// -----
92+
93+
spv.module "Logical" "GLSL450" {
94+
func @callee(%cond : i1) -> () {
95+
spv.selection {
96+
spv.BranchConditional %cond, ^then, ^merge
97+
^then:
98+
spv.Branch ^merge
99+
^merge:
100+
spv._merge
101+
}
102+
spv.Return
103+
}
104+
105+
// CHECK-LABEL: calling_selection_no_ret_func
106+
func @calling_selection_no_ret_func() {
107+
// CHECK-NEXT: %[[TRUE:.*]] = spv.constant true
108+
%0 = spv.constant true
109+
// CHECK-NEXT: spv.selection
110+
// CHECK-NEXT: spv.BranchConditional %[[TRUE]], ^bb1, ^bb2
111+
// CHECK-NEXT: ^bb1:
112+
// CHECK-NEXT: spv.Branch ^bb2
113+
// CHECK-NEXT: ^bb2:
114+
// CHECK-NEXT: spv._merge
115+
spv.FunctionCall @callee(%0) : (i1) -> ()
116+
spv.Return
117+
}
118+
}
119+
120+
// -----
121+
122+
spv.module "Logical" "GLSL450" {
123+
func @callee(%cond : i1) -> () {
124+
spv.loop {
125+
spv.Branch ^header
126+
^header:
127+
spv.BranchConditional %cond, ^body, ^merge
128+
^body:
129+
spv.Return
130+
^continue:
131+
spv.Branch ^header
132+
^merge:
133+
spv._merge
134+
}
135+
spv.Return
136+
}
137+
138+
// CHECK-LABEL: calling_loop_ret_func
139+
func @calling_loop_ret_func() {
140+
%0 = spv.constant true
141+
// CHECK: spv.FunctionCall
142+
spv.FunctionCall @callee(%0) : (i1) -> ()
143+
spv.Return
144+
}
145+
}
146+
147+
// -----
148+
149+
spv.module "Logical" "GLSL450" {
150+
func @callee(%cond : i1) -> () {
151+
spv.loop {
152+
spv.Branch ^header
153+
^header:
154+
spv.BranchConditional %cond, ^body, ^merge
155+
^body:
156+
spv.Branch ^continue
157+
^continue:
158+
spv.Branch ^header
159+
^merge:
160+
spv._merge
161+
}
162+
spv.Return
163+
}
164+
165+
// CHECK-LABEL: calling_loop_no_ret_func
166+
func @calling_loop_no_ret_func() {
167+
// CHECK-NEXT: %[[TRUE:.*]] = spv.constant true
168+
%0 = spv.constant true
169+
// CHECK-NEXT: spv.loop
170+
// CHECK-NEXT: spv.Branch ^bb1
171+
// CHECK-NEXT: ^bb1:
172+
// CHECK-NEXT: spv.BranchConditional %[[TRUE]], ^bb2, ^bb4
173+
// CHECK-NEXT: ^bb2:
174+
// CHECK-NEXT: spv.Branch ^bb3
175+
// CHECK-NEXT: ^bb3:
176+
// CHECK-NEXT: spv.Branch ^bb1
177+
// CHECK-NEXT: ^bb4:
178+
// CHECK-NEXT: spv._merge
179+
spv.FunctionCall @callee(%0) : (i1) -> ()
180+
spv.Return
181+
}
182+
}

0 commit comments

Comments
 (0)