Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 8bb5582

Browse files
River707tensorflower-gardener
authored andcommitted
Add (parse|print)OptionalAttrDictWithKeyword hooks to simplify parsing attribute dictionaries with regions.
Many operations with regions add an additional 'attributes' prefix when printing the attribute dictionary to differentiate it from the region body. This leads to duplicated logic for detecting when to actually print the attribute dictionary. PiperOrigin-RevId: 278747681
1 parent fa12ec1 commit 8bb5582

File tree

6 files changed

+51
-45
lines changed

6 files changed

+51
-45
lines changed

include/mlir/IR/OpImplementation.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ class OpAsmPrinter {
7878
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
7979
ArrayRef<StringRef> elidedAttrs = {}) = 0;
8080

81+
/// If the specified operation has attributes, print out an attribute
82+
/// dictionary prefixed with 'attributes'.
83+
virtual void
84+
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
85+
ArrayRef<StringRef> elidedAttrs = {}) = 0;
86+
8187
/// Print the entire operation with the default generic assembly form.
8288
virtual void printGenericOp(Operation *op) = 0;
8389

@@ -342,6 +348,11 @@ class OpAsmParser {
342348
virtual ParseResult
343349
parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) = 0;
344350

351+
/// Parse a named dictionary into 'result' if the `attributes` keyword is
352+
/// present.
353+
virtual ParseResult
354+
parseOptionalAttrDictWithKeyword(SmallVectorImpl<NamedAttribute> &result) = 0;
355+
345356
//===--------------------------------------------------------------------===//
346357
// Identifier Parsing
347358
//===--------------------------------------------------------------------===//

lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,10 +1628,8 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
16281628
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
16291629
return failure();
16301630

1631-
if (succeeded(parser.parseOptionalKeyword("attributes"))) {
1632-
if (parser.parseOptionalAttrDict(state.attributes))
1633-
return failure();
1634-
}
1631+
if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
1632+
return failure();
16351633

16361634
spirv::ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location);
16371635
return success();
@@ -1657,19 +1655,7 @@ static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
16571655

16581656
printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
16591657
/*printBlockTerminators=*/false);
1660-
1661-
bool printAttrDict =
1662-
elidedAttrs.size() != 2 ||
1663-
llvm::any_of(op->getAttrs(), [&addressingModelAttrName,
1664-
&memoryModelAttrName](NamedAttribute attr) {
1665-
return attr.first != addressingModelAttrName &&
1666-
attr.first != memoryModelAttrName;
1667-
});
1668-
1669-
if (printAttrDict) {
1670-
printer << " attributes";
1671-
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
1672-
}
1658+
printer.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
16731659
}
16741660

16751661
static LogicalResult verify(spirv::ModuleOp moduleOp) {

lib/IR/AsmPrinter.cpp

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,8 @@ class ModulePrinter {
421421

422422
protected:
423423
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
424-
ArrayRef<StringRef> elidedAttrs = {});
424+
ArrayRef<StringRef> elidedAttrs = {},
425+
bool withKeyword = false);
425426
void printTrailingLocation(Location loc);
426427
void printLocationInternal(LocationAttr loc, bool pretty = false);
427428
void printDenseElementsAttr(DenseElementsAttr attr);
@@ -1327,27 +1328,26 @@ void ModulePrinter::printIntegerSet(IntegerSet set) {
13271328
//===----------------------------------------------------------------------===//
13281329

13291330
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1330-
ArrayRef<StringRef> elidedAttrs) {
1331+
ArrayRef<StringRef> elidedAttrs,
1332+
bool withKeyword) {
13311333
// If there are no attributes, then there is nothing to be done.
13321334
if (attrs.empty())
13331335
return;
13341336

13351337
// Filter out any attributes that shouldn't be included.
1336-
SmallVector<NamedAttribute, 8> filteredAttrs;
1337-
for (auto attr : attrs) {
1338-
// If the caller has requested that this attribute be ignored, then drop it.
1339-
if (llvm::any_of(elidedAttrs,
1340-
[&](StringRef elided) { return attr.first.is(elided); }))
1341-
continue;
1342-
1343-
// Otherwise add it to our filteredAttrs list.
1344-
filteredAttrs.push_back(attr);
1345-
}
1338+
SmallVector<NamedAttribute, 8> filteredAttrs(
1339+
llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
1340+
return !llvm::is_contained(elidedAttrs, attr.first.strref());
1341+
}));
13461342

13471343
// If there are no attributes left to print after filtering, then we're done.
13481344
if (filteredAttrs.empty())
13491345
return;
13501346

1347+
// Print the 'attributes' keyword if necessary.
1348+
if (withKeyword)
1349+
os << " attributes ";
1350+
13511351
// Otherwise, print them all out in braces.
13521352
os << " {";
13531353
interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
@@ -1389,8 +1389,14 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
13891389

13901390
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
13911391
ArrayRef<StringRef> elidedAttrs = {}) override {
1392-
return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
1393-
};
1392+
ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
1393+
}
1394+
void printOptionalAttrDictWithKeyword(
1395+
ArrayRef<NamedAttribute> attrs,
1396+
ArrayRef<StringRef> elidedAttrs = {}) override {
1397+
ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
1398+
/*withKeyword=*/true);
1399+
}
13941400

13951401
enum { nameSentinel = ~0U };
13961402

lib/IR/FunctionSupport.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,8 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
183183
<< (errorMessage.empty() ? "" : ": ") << errorMessage;
184184

185185
// If function attributes are present, parse them.
186-
if (succeeded(parser.parseOptionalKeyword("attributes")))
187-
if (parser.parseOptionalAttrDict(result.attributes))
188-
return failure();
186+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
187+
return failure();
189188

190189
// Add the attributes to the function arguments.
191190
SmallString<8> attrNameBuf;

lib/IR/Module.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) {
4848
result.attributes);
4949

5050
// If module attributes are present, parse them.
51-
if (succeeded(parser.parseOptionalKeyword("attributes")))
52-
if (parser.parseOptionalAttrDict(result.attributes))
53-
return failure();
51+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
52+
return failure();
5453

5554
// Parse the module body.
5655
auto *body = result.addRegion();
@@ -65,18 +64,14 @@ ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) {
6564
void ModuleOp::print(OpAsmPrinter &p) {
6665
p << "module";
6766

68-
Optional<StringRef> name = getName();
69-
if (name) {
67+
if (Optional<StringRef> name = getName()) {
7068
p << ' ';
7169
p.printSymbolName(*name);
7270
}
7371

7472
// Print the module attributes.
75-
auto attrs = getAttrs();
76-
if (!attrs.empty() && !(attrs.size() == 1 && name)) {
77-
p << " attributes";
78-
p.printOptionalAttrDict(attrs, {mlir::SymbolTable::getSymbolAttrName()});
79-
}
73+
p.printOptionalAttrDictWithKeyword(getAttrs(),
74+
{mlir::SymbolTable::getSymbolAttrName()});
8075

8176
// Print the region.
8277
p.printRegion(getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,

lib/Parser/Parser.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1533,7 +1533,7 @@ Attribute Parser::parseAttribute(Type type) {
15331533
///
15341534
ParseResult
15351535
Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
1536-
if (!consumeIf(Token::l_brace))
1536+
if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
15371537
return failure();
15381538

15391539
auto parseElt = [&]() -> ParseResult {
@@ -3874,6 +3874,15 @@ class CustomOpAsmParser : public OpAsmParser {
38743874
return parser.parseAttributeDict(result);
38753875
}
38763876

3877+
/// Parse a named dictionary into 'result' if the `attributes` keyword is
3878+
/// present.
3879+
ParseResult parseOptionalAttrDictWithKeyword(
3880+
SmallVectorImpl<NamedAttribute> &result) override {
3881+
if (failed(parseOptionalKeyword("attributes")))
3882+
return success();
3883+
return parser.parseAttributeDict(result);
3884+
}
3885+
38773886
//===--------------------------------------------------------------------===//
38783887
// Identifier Parsing
38793888
//===--------------------------------------------------------------------===//

0 commit comments

Comments
 (0)