From cf00504ebf3be5eb8d489d502ac132ff9eef8c9b Mon Sep 17 00:00:00 2001 From: nmostafa Date: Tue, 30 Jul 2019 15:25:20 -0700 Subject: [PATCH] Enabling lowering to LLVM dialect based on attribute. Add mlir-opt tests --- include/mlir/StandardOps/Ops.td | 40 +++++++++++++++++-- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 23 ++++++++--- lib/StandardOps/Ops.cpp | 1 + test/LLVMIR/convert-memref-ops.mlir | 19 +++++++++ 4 files changed, 73 insertions(+), 10 deletions(-) diff --git a/include/mlir/StandardOps/Ops.td b/include/mlir/StandardOps/Ops.td index b6bf2cfb40b3..eb23192020d5 100644 --- a/include/mlir/StandardOps/Ops.td +++ b/include/mlir/StandardOps/Ops.td @@ -148,14 +148,28 @@ def AllocOp : Std_Op<"alloc"> { let arguments = (ins Variadic:$value); let results = (outs AnyMemRef); - let builders = [OpBuilder< + let builders = [ + OpBuilder< "Builder *builder, OperationState *result, MemRefType memrefType", [{ - result->types.push_back(memrefType); - }] - >]; + result->types.push_back(memrefType); + }]>, + OpBuilder< + "Builder *builder, OperationState *result, MemRefType memrefType, " + "StringRef callbackName", [{ + result->types.push_back(memrefType); + assert(callbackName != ""); + result->addAttribute("callbackName", builder->getSymbolRefAttr(callbackName)); + }]> + ]; let extraClassDeclaration = [{ MemRefType getType() { return getResult()->getType().cast(); } + StringRef getCallbackName() { + Attribute attr = getAttr("callbackName"); + if (attr == Attribute(nullptr)) + return ""; + return attr.cast().getValue(); + } }]; let hasCanonicalizer = 1; @@ -524,6 +538,24 @@ def DeallocOp : Std_Op<"dealloc"> { let arguments = (ins AnyMemRef:$memref); let hasCanonicalizer = 1; + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value* memref, " + "StringRef callbackName", [{ + result->operands.push_back(memref); + assert (callbackName != ""); + result->addAttribute("callbackName", builder->getSymbolRefAttr(callbackName)); + }]> + ]; + + let extraClassDeclaration = [{ + StringRef getCallbackName() { + Attribute attr = getAttr("callbackName"); + if (attr == Attribute(nullptr)) + return ""; + return attr.cast().getValue(); + } + }]; } def DimOp : Std_Op<"dim", [NoSideEffect]> { diff --git a/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index af8812c8cf4f..70213c8d4ec3 100644 --- a/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -478,14 +478,19 @@ struct AllocOpLowering : public LLVMLegalizationPattern { cumulativeSize, createIndexConstant(rewriter, op->getLoc(), elementSize)}); - // Insert the `malloc` declaration if it is not already present. + // Insert the alloc callback declaration if it is not already present. auto module = op->getParentOfType(); - FuncOp mallocFunc = module.lookupSymbol("malloc"); + StringRef allocFuncName = allocOp.getCallbackName(); + + if (allocFuncName == "") + allocFuncName = "malloc"; + + FuncOp mallocFunc = module.lookupSymbol(allocFuncName); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(getIndexType(), getVoidPtrType()); mallocFunc = - FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType); + FuncOp::create(rewriter.getUnknownLoc(), allocFuncName, mallocType); module.push_back(mallocFunc); } @@ -541,12 +546,18 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { assert(operands.size() == 1 && "dealloc takes one operand"); OperandAdaptor transformed(operands); - // Insert the `free` declaration if it is not already present. + + auto deallocOp = cast(op); + StringRef callbackName = deallocOp.getCallbackName(); + if (callbackName == "") + callbackName = "free"; + + // Insert the free call-back declaration if it is not already present. FuncOp freeFunc = - op->getParentOfType().lookupSymbol("free"); + op->getParentOfType().lookupSymbol(callbackName); if (!freeFunc) { auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); - freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType); + freeFunc = FuncOp::create(rewriter.getUnknownLoc(), callbackName, freeType); op->getParentOfType().push_back(freeFunc); } diff --git a/lib/StandardOps/Ops.cpp b/lib/StandardOps/Ops.cpp index df99f00c1100..b8da03b19649 100644 --- a/lib/StandardOps/Ops.cpp +++ b/lib/StandardOps/Ops.cpp @@ -1218,6 +1218,7 @@ static ParseResult parseDeallocOp(OpAsmParser *parser, OperationState *result) { MemRefType type; return failure(parser->parseOperand(memrefInfo) || + parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type) || parser->resolveOperand(memrefInfo, type, result->operands)); } diff --git a/test/LLVMIR/convert-memref-ops.mlir b/test/LLVMIR/convert-memref-ops.mlir index 31658cf449c4..3c2813346f64 100644 --- a/test/LLVMIR/convert-memref-ops.mlir +++ b/test/LLVMIR/convert-memref-ops.mlir @@ -31,6 +31,25 @@ func @zero_d_dealloc(%arg0: memref) { return } +// CHECK-LABEL: func @alloc_with_callback() -> !llvm<"float*"> { +func @alloc_with_callback() -> memref { +// CHECK-NEXT: %0 = llvm.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %1 = llvm.constant(4 : index) : !llvm.i64 +// CHECK-NEXT: %2 = llvm.mul %0, %1 : !llvm.i64 +// CHECK-NEXT: %3 = llvm.call @myMalloc(%2) : (!llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: %4 = llvm.bitcast %3 : !llvm<"i8*"> to !llvm<"float*"> + %0 = alloc() {callbackName="myMalloc"} : memref + return %0 : memref +} + +// CHECK-LABEL: func @dealloc_with_callback(%arg0: !llvm<"float*">) +func @dealloc_with_callback(%arg0: memref) { + // CHECK-NEXT: %0 = llvm.bitcast %arg0 : !llvm<"float*"> to !llvm<"i8*"> + // CHECK-NEXT: llvm.call @myFree(%0) : (!llvm<"i8*">) -> () + dealloc %arg0 {callbackName="myFree"} : memref + return +} + // CHECK-LABEL: func @mixed_alloc(%arg0: !llvm.i64, %arg1: !llvm.i64) -> !llvm<"{ float*, i64, i64 }"> { func @mixed_alloc(%arg0: index, %arg1: index) -> memref { // CHECK-NEXT: %0 = llvm.constant(42 : index) : !llvm.i64