Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Fix the wrong computation of dynamic strides for lowering AllocOp to LLVM #338

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions include/mlir/Pass/PassOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,19 @@ class PassOptions : protected llvm::cl::SubCommand {

/// Utility methods for printing option values.
template <typename DataT>
static void printOptionValue(raw_ostream &os,
GenericOptionParser<DataT> &parser,
const DataT &value) {
static void printValue(raw_ostream &os, GenericOptionParser<DataT> &parser,
const DataT &value) {
if (Optional<StringRef> argStr = parser.findArgStrForValue(value))
os << argStr;
else
llvm_unreachable("unknown data value for option");
}
template <typename DataT, typename ParserT>
static void printOptionValue(raw_ostream &os, ParserT &parser,
const DataT &value) {
static void printValue(raw_ostream &os, ParserT &parser, const DataT &value) {
os << value;
}
template <typename ParserT>
static void printOptionValue(raw_ostream &os, ParserT &parser,
const bool &value) {
static void printValue(raw_ostream &os, ParserT &parser, const bool &value) {
os << (value ? StringRef("true") : StringRef("false"));
}

Expand Down Expand Up @@ -129,7 +126,7 @@ class PassOptions : protected llvm::cl::SubCommand {
/// Print the name and value of this option to the given stream.
void print(raw_ostream &os) final {
os << this->ArgStr << '=';
printOptionValue(os, this->getParser(), this->getValue());
printValue(os, this->getParser(), this->getValue());
}

/// Copy the value from the given option into this one.
Expand Down Expand Up @@ -172,7 +169,7 @@ class PassOptions : protected llvm::cl::SubCommand {
void print(raw_ostream &os) final {
os << this->ArgStr << '=';
auto printElementFn = [&](const DataType &value) {
printOptionValue(os, this->getParser(), value);
printValue(os, this->getParser(), value);
};
interleave(*this, os, printElementFn, ",");
}
Expand Down
7 changes: 4 additions & 3 deletions lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,14 +1054,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Iterate strides in reverse order, compute runningStride and strideValues.
auto nStrides = strides.size();
SmallVector<Value, 4> strideValues(nStrides, nullptr);
for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) {
for (auto indexedStride : llvm::enumerate(strides)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, enumerate is useless here, the body of the loop never accesses indexedStride.value(). Could you just rewrite the loop to iterate on index to remove the confusion that led to the bug in the first place?

int64_t index = nStrides - 1 - indexedStride.index();
if (strides[index] == MemRefType::getDynamicStrideOrOffset())
// Identity layout map is enforced in the match function, so we compute:
// `runningStride *= sizes[index]`
// `runningStride *= sizes[index + 1]`
runningStride =
runningStride
? rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[index])
? rewriter.create<LLVM::MulOp>(loc, runningStride,
sizes[index + 1])
: createIndexConstant(rewriter, loc, 1);
else
runningStride = createIndexConstant(rewriter, loc, strides[index]);
Expand Down
6 changes: 3 additions & 3 deletions test/Conversion/StandardToLLVM/convert-memref-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: %[[st2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mul %{{.*}}, %[[c42]] : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[M]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[c42]] : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[c42]], %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
Expand Down Expand Up @@ -142,7 +142,7 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[M]] : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
Expand Down