Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Configurable call-back method name for AllocOp and DeallocOp #55

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions include/mlir/StandardOps/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,28 @@ def AllocOp : Std_Op<"alloc"> {
let arguments = (ins Variadic<Index>:$value);
let results = (outs AnyMemRef);

let builders = [OpBuilder<
let builders = [
OpBuilder<
"Builder *builder, OperationState *result, MemRefType memrefType", [{
result->types.push_back(memrefType);
}]
>];
result->types.push_back(memrefType);
}]>,
OpBuilder<
"Builder *builder, OperationState *result, MemRefType memrefType, "
"StringRef callbackName", [{
result->types.push_back(memrefType);
assert(callbackName != "");
result->addAttribute("callbackName", builder->getSymbolRefAttr(callbackName));
}]>
];

let extraClassDeclaration = [{
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
StringRef getCallbackName() {
Attribute attr = getAttr("callbackName");
if (attr == Attribute(nullptr))
return "";
return attr.cast<mlir::StringAttr>().getValue();
}
}];

let hasCanonicalizer = 1;
Expand Down Expand Up @@ -524,6 +538,24 @@ def DeallocOp : Std_Op<"dealloc"> {
let arguments = (ins AnyMemRef:$memref);

let hasCanonicalizer = 1;

let builders = [OpBuilder<
"Builder *builder, OperationState *result, Value* memref, "
"StringRef callbackName", [{
result->operands.push_back(memref);
assert (callbackName != "");
result->addAttribute("callbackName", builder->getSymbolRefAttr(callbackName));
}]>
];

let extraClassDeclaration = [{
StringRef getCallbackName() {
Attribute attr = getAttr("callbackName");
if (attr == Attribute(nullptr))
return "";
return attr.cast<mlir::StringAttr>().getValue();
}
}];
}

def DimOp : Std_Op<"dim", [NoSideEffect]> {
Expand Down
23 changes: 17 additions & 6 deletions lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,14 +478,19 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
cumulativeSize,
createIndexConstant(rewriter, op->getLoc(), elementSize)});

// Insert the `malloc` declaration if it is not already present.
// Insert the alloc callback declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
StringRef allocFuncName = allocOp.getCallbackName();

if (allocFuncName == "")
allocFuncName = "malloc";

FuncOp mallocFunc = module.lookupSymbol<FuncOp>(allocFuncName);
if (!mallocFunc) {
auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
mallocFunc =
FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
FuncOp::create(rewriter.getUnknownLoc(), allocFuncName, mallocType);
module.push_back(mallocFunc);
}

Expand Down Expand Up @@ -541,12 +546,18 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);

// Insert the `free` declaration if it is not already present.

auto deallocOp = cast<DeallocOp>(op);
StringRef callbackName = deallocOp.getCallbackName();
if (callbackName == "")
callbackName = "free";

// Insert the free call-back declaration if it is not already present.
FuncOp freeFunc =
op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(callbackName);
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), callbackName, freeType);
op->getParentOfType<ModuleOp>().push_back(freeFunc);
}

Expand Down
1 change: 1 addition & 0 deletions lib/StandardOps/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,7 @@ static ParseResult parseDeallocOp(OpAsmParser *parser, OperationState *result) {
MemRefType type;

return failure(parser->parseOperand(memrefInfo) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands));
}
Expand Down
19 changes: 19 additions & 0 deletions test/LLVMIR/convert-memref-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ func @zero_d_dealloc(%arg0: memref<f32>) {
return
}

// CHECK-LABEL: func @alloc_with_callback() -> !llvm<"float*"> {
func @alloc_with_callback() -> memref<f32> {
// CHECK-NEXT: %0 = llvm.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %1 = llvm.constant(4 : index) : !llvm.i64
// CHECK-NEXT: %2 = llvm.mul %0, %1 : !llvm.i64
// CHECK-NEXT: %3 = llvm.call @myMalloc(%2) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: %4 = llvm.bitcast %3 : !llvm<"i8*"> to !llvm<"float*">
%0 = alloc() {callbackName="myMalloc"} : memref<f32>
return %0 : memref<f32>
}

// CHECK-LABEL: func @dealloc_with_callback(%arg0: !llvm<"float*">)
func @dealloc_with_callback(%arg0: memref<f32>) {
// CHECK-NEXT: %0 = llvm.bitcast %arg0 : !llvm<"float*"> to !llvm<"i8*">
// CHECK-NEXT: llvm.call @myFree(%0) : (!llvm<"i8*">) -> ()
dealloc %arg0 {callbackName="myFree"} : memref<f32>
return
}

// CHECK-LABEL: func @mixed_alloc(%arg0: !llvm.i64, %arg1: !llvm.i64) -> !llvm<"{ float*, i64, i64 }"> {
func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %0 = llvm.constant(42 : index) : !llvm.i64
Expand Down