//===- Ops.cpp - Standard MLIR Operations ---------------------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" // Pull in all enum type definitions and utility function declarations. #include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc" using namespace mlir; //===----------------------------------------------------------------------===// // StandardOpsDialect Interfaces //===----------------------------------------------------------------------===// namespace { /// This class defines the interface for handling inlining with standard /// operations. struct StdInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// // Analysis Hooks //===--------------------------------------------------------------------===// /// All operations within standard ops can be inlined. 
bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { return true; } //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, Block *newDest) const final { // Only "std.return" needs to be handled here. auto returnOp = dyn_cast<ReturnOp>(op); if (!returnOp) return; // Replace the return with a branch to the dest. OpBuilder builder(op); builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands()); op->erase(); } /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, ArrayRef<Value> valuesToRepl) const final { // Only "std.return" needs to be handled here. auto returnOp = cast<ReturnOp>(op); // Replace the values directly with the return operands. assert(returnOp.getNumOperands() == valuesToRepl.size()); for (const auto &it : llvm::enumerate(returnOp.getOperands())) valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // StandardOpsDialect //===----------------------------------------------------------------------===// /// A custom unary operation printer that omits the "std." prefix from the /// operation names. 
static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumOperands() == 1 && "unary op should have one operand"); assert(op->getNumResults() == 1 && "unary op should have one result"); int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << op->getOperand(0); p.printOptionalAttrDict(op->getAttrs()); p << " : " << op->getOperand(0).getType(); } /// A custom binary operation printer that omits the "std." prefix from the /// operation names. static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumOperands() == 2 && "binary op should have two operands"); assert(op->getNumResults() == 1 && "binary op should have one result"); // If not all the operand and result types are the same, just use the // generic assembly form to avoid omitting information in printing. auto resultType = op->getResult(0).getType(); if (op->getOperand(0).getType() != resultType || op->getOperand(1).getType() != resultType) { p.printGenericOp(op); return; } int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << op->getOperand(0) << ", " << op->getOperand(1); p.printOptionalAttrDict(op->getAttrs()); // Now we can output only one type for all operands and the result. p << " : " << op->getResult(0).getType(); } /// A custom cast operation printer that omits the "std." prefix from the /// operation names. static void printStandardCastOp(Operation *op, OpAsmPrinter &p) { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to " << op->getResult(0).getType(); } /// A custom cast operation verifier. 
template <typename T> static LogicalResult verifyCastOp(T op) { auto opType = op.getOperand().getType(); auto resType = op.getType(); if (!T::areCastCompatible(opType, resType)) return op.emitError("operand type ") << opType << " and result type " << resType << " are cast incompatible"; return success(); } StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations<DmaStartOp, DmaWaitOp, #define GET_OP_LIST #include "mlir/Dialect/StandardOps/Ops.cpp.inc" >(); addInterfaces<StdInlinerInterface>(); } /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { return builder.create<ConstantOp>(loc, type, value); } void mlir::printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &p) { Operation::operand_range operands(begin, end); p << '(' << operands.take_front(numDims) << ')'; if (operands.size() != numDims) p << '[' << operands.drop_front(numDims) << ']'; } // Parses dimension and symbol list, and sets 'numDims' to the number of // dimension operands parsed. // Returns 'false' on success and 'true' on error. ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) { SmallVector<OpAsmParser::OperandType, 8> opInfos; if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) return failure(); // Store number of dimensions for validation by caller. numDims = opInfos.size(); // Parse the optional symbol operands. auto indexTy = parser.getBuilder().getIndexType(); if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::OptionalSquare) || parser.resolveOperands(opInfos, indexTy, operands)) return failure(); return success(); } /// Matches a ConstantIndexOp. 
/// TODO: This should probably just be a general matcher that uses m_Constant /// and checks the operation for an index type. static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() { return detail::op_matcher<ConstantIndexOp>(); } //===----------------------------------------------------------------------===// // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref_cast /// into the root operation directly. static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp()); if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) { operand.set(cast.getOperand()); folded = true; } } return success(folded); } //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) { return constFoldBinaryOp<FloatAttr>( operands, [](APFloat a, APFloat b) { return a + b; }); } //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) { /// addi(x, 0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); return constFoldBinaryOp<IntegerAttr>(operands, [](APInt a, APInt b) { return a + b; }); } //===----------------------------------------------------------------------===// // AllocOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, AllocOp op) { p << "alloc"; // Print dynamic dimension operands. 
MemRefType type = op.getType(); printDimAndSymbolList(op.operand_begin(), op.operand_end(), type.getNumDynamicDims(), p); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); p << " : " << type; } static ParseResult parseAllocOp(OpAsmParser &parser, OperationState &result) { MemRefType type; // Parse the dimension operands and optional symbol operands, followed by a // memref type. unsigned numDimOperands; if (parseDimAndSymbolList(parser, result.operands, numDimOperands) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type)) return failure(); // Check numDynamicDims against number of question marks in memref type. // Note: this check remains here (instead of in verify()), because the // partition between dim operands and symbol operands is lost after parsing. // Verification still checks that the total number of operands matches // the number of symbols in the affine map, plus the number of dynamic // dimensions in the memref. if (numDimOperands != type.getNumDynamicDims()) return parser.emitError(parser.getNameLoc()) << "dimension operand count does not equal memref dynamic dimension " "count"; result.types.push_back(type); return success(); } static LogicalResult verify(AllocOp op) { auto memRefType = op.getResult().getType().dyn_cast<MemRefType>(); if (!memRefType) return op.emitOpError("result must be a memref"); unsigned numSymbols = 0; if (!memRefType.getAffineMaps().empty()) { // Store number of symbols used in affine map (used in subsequent check). AffineMap affineMap = memRefType.getAffineMaps()[0]; numSymbols = affineMap.getNumSymbols(); } // Check that the total number of operands matches the number of symbols in // the affine map, plus the number of dynamic dimensions specified in the // memref type. 
unsigned numDynamicDims = memRefType.getNumDynamicDims(); if (op.getNumOperands() != numDynamicDims + numSymbols) return op.emitOpError( "operand count does not equal dimension plus symbol operand count"); // Verify that all operands are of type Index. for (auto operandType : op.getOperandTypes()) if (!operandType.isIndex()) return op.emitOpError("requires operands to be of type Index"); return success(); } namespace { /// Fold constant dimensions into an alloc operation. struct SimplifyAllocConst : public OpRewritePattern<AllocOp> { using OpRewritePattern<AllocOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(AllocOp alloc, PatternRewriter &rewriter) const override { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. if (llvm::none_of(alloc.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) return matchFailure(); auto memrefType = alloc.getType(); // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. SmallVector<int64_t, 4> newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); SmallVector<Value, 4> newOperands; SmallVector<Value, 4> droppedOperands; unsigned dynamicDimPos = 0; for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { int64_t dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (dimSize != -1) { newShapeConstants.push_back(dimSize); continue; } auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); // Record to check for zero uses later below. droppedOperands.push_back(constantIndexOp); } else { // Dynamic shape dimension not folded; copy operand from old memref. 
newShapeConstants.push_back(-1); newOperands.push_back(alloc.getOperand(dynamicDimPos)); } dynamicDimPos++; } // Create new memref type (which will have fewer dynamic dimensions). auto newMemRefType = MemRefType::get( newShapeConstants, memrefType.getElementType(), memrefType.getAffineMaps(), memrefType.getMemorySpace()); assert(static_cast<int64_t>(newOperands.size()) == newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. auto newAlloc = rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands, IntegerAttr()); // Insert a cast so we have the same type as the old alloc. auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc, alloc.getType()); rewriter.replaceOp(alloc, {resultCast}, droppedOperands); return matchSuccess(); } }; /// Fold alloc operations with no uses. Alloc has side effects on the heap, /// but can still be deleted if it has zero uses. struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> { using OpRewritePattern<AllocOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(AllocOp alloc, PatternRewriter &rewriter) const override { if (alloc.use_empty()) { rewriter.eraseOp(alloc); return matchSuccess(); } return matchFailure(); } }; } // end anonymous namespace. void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context); } //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// namespace { /// Simplify a branch to a block that has a single predecessor. This effectively /// merges the two blocks. 
struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> { using OpRewritePattern<BranchOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(BranchOp op, PatternRewriter &rewriter) const override { // Check that the successor block has a single predecessor. Block *succ = op.getDest(); Block *opParent = op.getOperation()->getBlock(); if (succ == opParent || !has_single_element(succ->getPredecessors())) return matchFailure(); // Merge the successor into the current block and erase the branch. rewriter.mergeBlocks(succ, opParent, op.getOperands()); rewriter.eraseOp(op); return matchSuccess(); } }; } // end anonymous namespace. static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { Block *dest; SmallVector<Value, 4> destOperands; if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); result.addSuccessor(dest, destOperands); return success(); } static void print(OpAsmPrinter &p, BranchOp op) { p << "br "; p.printSuccessorAndUseList(op.getOperation(), 0); } Block *BranchOp::getDest() { return getSuccessor(0); } void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); } void BranchOp::eraseOperand(unsigned index) { getOperation()->eraseSuccessorOperand(0, index); } void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert<SimplifyBrToBlockWithSinglePred>(context); } //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { FlatSymbolRefAttr calleeAttr; FunctionType calleeType; SmallVector<OpAsmParser::OperandType, 4> operands; auto calleeLoc = parser.getNameLoc(); if (parser.parseAttribute(calleeAttr, "callee", result.attributes) || parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || parser.parseOptionalAttrDict(result.attributes) 
|| parser.parseColonType(calleeType) || parser.addTypesToList(calleeType.getResults(), result.types) || parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, result.operands)) return failure(); return success(); } static void print(OpAsmPrinter &p, CallOp op) { p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); p << " : " << op.getCalleeType(); } static LogicalResult verify(CallOp op) { // Check that the callee attribute was specified. auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' symbol reference attribute"); auto fn = op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue()); if (!fn) return op.emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; // Verify that the operand and result types match the callee. auto fnType = fn.getType(); if (fnType.getNumInputs() != op.getNumOperands()) return op.emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) if (op.getOperand(i).getType() != fnType.getInput(i)) return op.emitOpError("operand type mismatch"); if (fnType.getNumResults() != op.getNumResults()) return op.emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) if (op.getResult(i).getType() != fnType.getResult(i)) return op.emitOpError("result type mismatch"); return success(); } FunctionType CallOp::getCalleeType() { SmallVector<Type, 4> resultTypes(getResultTypes()); SmallVector<Type, 8> argTypes(getOperandTypes()); return FunctionType::get(argTypes, resultTypes, getContext()); } //===----------------------------------------------------------------------===// // CallIndirectOp //===----------------------------------------------------------------------===// namespace { /// Fold indirect calls that have a constant function as 
the callee operand. struct SimplifyIndirectCallWithKnownCallee : public OpRewritePattern<CallIndirectOp> { using OpRewritePattern<CallIndirectOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, PatternRewriter &rewriter) const override { // Check that the callee is a constant callee. SymbolRefAttr calledFn; if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) return matchFailure(); // Replace with a direct call. SmallVector<Type, 8> callResults(indirectCall.getResultTypes()); rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults, indirectCall.getArgOperands()); return matchSuccess(); } }; } // end anonymous namespace. static ParseResult parseCallIndirectOp(OpAsmParser &parser, OperationState &result) { FunctionType calleeType; OpAsmParser::OperandType callee; llvm::SMLoc operandsLoc; SmallVector<OpAsmParser::OperandType, 4> operands; return failure( parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(calleeType) || parser.resolveOperand(callee, calleeType, result.operands) || parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, result.operands) || parser.addTypesToList(calleeType.getResults(), result.types)); } static void print(OpAsmPrinter &p, CallIndirectOp op) { p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); p << " : " << op.getCallee().getType(); } static LogicalResult verify(CallIndirectOp op) { // The callee must be a function. auto fnType = op.getCallee().getType().dyn_cast<FunctionType>(); if (!fnType) return op.emitOpError("callee must have function type"); // Verify that the operand and result types match the callee. 
if (fnType.getNumInputs() != op.getNumOperands() - 1) return op.emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) if (op.getOperand(i + 1).getType() != fnType.getInput(i)) return op.emitOpError("operand type mismatch"); if (fnType.getNumResults() != op.getNumResults()) return op.emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) if (op.getResult(i).getType() != fnType.getResult(i)) return op.emitOpError("result type mismatch"); return success(); } void CallIndirectOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert<SimplifyIndirectCallWithKnownCallee>(context); } //===----------------------------------------------------------------------===// // General helpers for comparison ops //===----------------------------------------------------------------------===// // Return the type of the same shape (scalar, vector or tensor) containing i1. 
static Type getCheckedI1SameShape(Builder *build, Type type) { auto i1Type = build->getI1Type(); if (type.isIntOrIndexOrFloat()) return i1Type; if (auto tensorType = type.dyn_cast<RankedTensorType>()) return RankedTensorType::get(tensorType.getShape(), i1Type); if (type.isa<UnrankedTensorType>()) return UnrankedTensorType::get(i1Type); if (auto vectorType = type.dyn_cast<VectorType>()) return VectorType::get(vectorType.getShape(), i1Type); return Type(); } static Type getI1SameShape(Builder *build, Type type) { Type res = getCheckedI1SameShape(build, type); assert(res && "expected type with valid i1 shape"); return res; } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// static void buildCmpIOp(Builder *build, OperationState &result, CmpIPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(build, lhs.getType())); result.addAttribute( CmpIOp::getPredicateAttrName(), build->getI64IntegerAttr(static_cast<int64_t>(predicate))); } static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { SmallVector<OpAsmParser::OperandType, 2> ops; SmallVector<NamedAttribute, 4> attrs; Attribute predicateNameAttr; Type type; if (parser.parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(), attrs) || parser.parseComma() || parser.parseOperandList(ops, 2) || parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || parser.resolveOperands(ops, type, result.operands)) return failure(); if (!predicateNameAttr.isa<StringAttr>()) return parser.emitError(parser.getNameLoc(), "expected string comparison predicate attribute"); // Rewrite string attribute to an enum value. 
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName); if (!predicate.hasValue()) return parser.emitError(parser.getNameLoc()) << "unknown comparison predicate \"" << predicateName << "\""; auto builder = parser.getBuilder(); Type i1Type = getCheckedI1SameShape(&builder, type); if (!i1Type) return parser.emitError(parser.getNameLoc(), "expected type with valid i1 shape"); attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate)); result.attributes = attrs; result.addTypes({i1Type}); return success(); } static void print(OpAsmPrinter &p, CmpIOp op) { p << "cmpi "; Builder b(op.getContext()); auto predicateValue = op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt(); p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue)) << '"' << ", " << op.lhs() << ", " << op.rhs(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); p << " : " << op.lhs().getType(); } // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer // comparison predicates. static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, const APInt &rhs) { switch (predicate) { case CmpIPredicate::eq: return lhs.eq(rhs); case CmpIPredicate::ne: return lhs.ne(rhs); case CmpIPredicate::slt: return lhs.slt(rhs); case CmpIPredicate::sle: return lhs.sle(rhs); case CmpIPredicate::sgt: return lhs.sgt(rhs); case CmpIPredicate::sge: return lhs.sge(rhs); case CmpIPredicate::ult: return lhs.ult(rhs); case CmpIPredicate::ule: return lhs.ule(rhs); case CmpIPredicate::ugt: return lhs.ugt(rhs); case CmpIPredicate::uge: return lhs.uge(rhs); } llvm_unreachable("unknown comparison predicate"); } // Constant folding hook for comparisons. 
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 2 && "cmpi takes two arguments"); auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); if (!lhs || !rhs) return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); } //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// // Returns an array of mnemonics for CmpFPredicates indexed by values thereof. static inline const char *const *getCmpFPredicateNames() { static const char *predicateNames[] = { /*AlwaysFalse*/ "false", /*OEQ*/ "oeq", /*OGT*/ "ogt", /*OGE*/ "oge", /*OLT*/ "olt", /*OLE*/ "ole", /*ONE*/ "one", /*ORD*/ "ord", /*UEQ*/ "ueq", /*UGT*/ "ugt", /*UGE*/ "uge", /*ULT*/ "ult", /*ULE*/ "ule", /*UNE*/ "une", /*UNO*/ "uno", /*AlwaysTrue*/ "true", }; static_assert(std::extent<decltype(predicateNames)>::value == (size_t)CmpFPredicate::NumPredicates, "wrong number of predicate names"); return predicateNames; } // Returns a value of the predicate corresponding to the given mnemonic. // Returns NumPredicates (one-past-end) if there is no such mnemonic. 
CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { return llvm::StringSwitch<CmpFPredicate>(name) .Case("false", CmpFPredicate::AlwaysFalse) .Case("oeq", CmpFPredicate::OEQ) .Case("ogt", CmpFPredicate::OGT) .Case("oge", CmpFPredicate::OGE) .Case("olt", CmpFPredicate::OLT) .Case("ole", CmpFPredicate::OLE) .Case("one", CmpFPredicate::ONE) .Case("ord", CmpFPredicate::ORD) .Case("ueq", CmpFPredicate::UEQ) .Case("ugt", CmpFPredicate::UGT) .Case("uge", CmpFPredicate::UGE) .Case("ult", CmpFPredicate::ULT) .Case("ule", CmpFPredicate::ULE) .Case("une", CmpFPredicate::UNE) .Case("uno", CmpFPredicate::UNO) .Case("true", CmpFPredicate::AlwaysTrue) .Default(CmpFPredicate::NumPredicates); } static void buildCmpFOp(Builder *build, OperationState &result, CmpFPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(build, lhs.getType())); result.addAttribute( CmpFOp::getPredicateAttrName(), build->getI64IntegerAttr(static_cast<int64_t>(predicate))); } static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) { SmallVector<OpAsmParser::OperandType, 2> ops; SmallVector<NamedAttribute, 4> attrs; Attribute predicateNameAttr; Type type; if (parser.parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(), attrs) || parser.parseComma() || parser.parseOperandList(ops, 2) || parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || parser.resolveOperands(ops, type, result.operands)) return failure(); if (!predicateNameAttr.isa<StringAttr>()) return parser.emitError(parser.getNameLoc(), "expected string comparison predicate attribute"); // Rewrite string attribute to an enum value. 
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); auto predicate = CmpFOp::getPredicateByName(predicateName); if (predicate == CmpFPredicate::NumPredicates) return parser.emitError(parser.getNameLoc(), "unknown comparison predicate \"" + predicateName + "\""); auto builder = parser.getBuilder(); Type i1Type = getCheckedI1SameShape(&builder, type); if (!i1Type) return parser.emitError(parser.getNameLoc(), "expected type with valid i1 shape"); attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); result.attributes = attrs; result.addTypes({i1Type}); return success(); } static void print(OpAsmPrinter &p, CmpFOp op) { p << "cmpf "; auto predicateValue = op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt(); assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) && predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) && "unknown predicate index"); p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs() << ", " << op.rhs(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); p << " : " << op.lhs().getType(); } static LogicalResult verify(CmpFOp op) { auto predicateAttr = op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()); if (!predicateAttr) return op.emitOpError("requires an integer attribute named 'predicate'"); auto predicate = predicateAttr.getInt(); if (predicate < (int64_t)CmpFPredicate::FirstValidValue || predicate >= (int64_t)CmpFPredicate::NumPredicates) return op.emitOpError("'predicate' attribute value out of range"); return success(); } // Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point // comparison predicates. 
static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                              const APFloat &rhs) {
  // Classify the comparison once; each predicate below is a combination of
  // these four outcomes.
  auto cmpResult = lhs.compare(rhs);
  const bool isEqual = cmpResult == APFloat::cmpEqual;
  const bool isLess = cmpResult == APFloat::cmpLessThan;
  const bool isGreater = cmpResult == APFloat::cmpGreaterThan;
  const bool isUnordered = cmpResult == APFloat::cmpUnordered;

  switch (predicate) {
  case CmpFPredicate::AlwaysFalse:
    return false;

  // Ordered predicates.
  case CmpFPredicate::OEQ:
    return isEqual;
  case CmpFPredicate::OGT:
    return isGreater;
  case CmpFPredicate::OGE:
    return isGreater || isEqual;
  case CmpFPredicate::OLT:
    return isLess;
  case CmpFPredicate::OLE:
    return isLess || isEqual;
  case CmpFPredicate::ONE:
    return !isUnordered && !isEqual;
  case CmpFPredicate::ORD:
    return !isUnordered;

  // Unordered predicates: also true when the comparison is unordered.
  case CmpFPredicate::UEQ:
    return isUnordered || isEqual;
  case CmpFPredicate::UGT:
    return isUnordered || isGreater;
  case CmpFPredicate::UGE:
    return isUnordered || isGreater || isEqual;
  case CmpFPredicate::ULT:
    return isUnordered || isLess;
  case CmpFPredicate::ULE:
    return isUnordered || isLess || isEqual;
  case CmpFPredicate::UNE:
    return !isEqual;
  case CmpFPredicate::UNO:
    return isUnordered;

  case CmpFPredicate::AlwaysTrue:
    return true;
  default:
    llvm_unreachable("unknown comparison predicate");
  }
}

// Constant folding hook for comparisons.
OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 2 && "cmpf takes two arguments"); auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); // TODO(gcmn) We could actually do some intelligent things if we know only one // of the operands, but it's inf or nan. if (!lhs || !rhs) return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); } //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// namespace { /// cond_br true, ^bb1, ^bb2 -> br ^bb1 /// cond_br false, ^bb1, ^bb2 -> br ^bb2 /// struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { using OpRewritePattern<CondBranchOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { if (matchPattern(condbr.getCondition(), m_NonZero())) { // True branch taken. rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), condbr.getTrueOperands()); return matchSuccess(); } else if (matchPattern(condbr.getCondition(), m_Zero())) { // False branch taken. rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), condbr.getFalseOperands()); return matchSuccess(); } return matchFailure(); } }; } // end anonymous namespace. static ParseResult parseCondBranchOp(OpAsmParser &parser, OperationState &result) { SmallVector<Value, 4> destOperands; Block *dest; OpAsmParser::OperandType condInfo; // Parse the condition. Type int1Ty = parser.getBuilder().getI1Type(); if (parser.parseOperand(condInfo) || parser.parseComma() || parser.resolveOperand(condInfo, int1Ty, result.operands)) { return parser.emitError(parser.getNameLoc(), "expected condition type was boolean (i1)"); } // Parse the true successor. 
if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); result.addSuccessor(dest, destOperands); // Parse the false successor. destOperands.clear(); if (parser.parseComma() || parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); result.addSuccessor(dest, destOperands); return success(); } static void print(OpAsmPrinter &p, CondBranchOp op) { p << "cond_br " << op.getCondition() << ", "; p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); p << ", "; p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); } void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert<SimplifyConstCondBranchPred>(context); } //===----------------------------------------------------------------------===// // Constant*Op //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ConstantOp &op) { p << "constant "; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); if (op.getAttrs().size() > 1) p << ' '; p << op.getValue(); // If the value is a symbol reference, print a trailing type. if (op.getValue().isa<SymbolRefAttr>()) p << " : " << op.getType(); } static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &result) { Attribute valueAttr; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseAttribute(valueAttr, "value", result.attributes)) return failure(); // If the attribute is a symbol reference, then we expect a trailing type. Type type; if (!valueAttr.isa<SymbolRefAttr>()) type = valueAttr.getType(); else if (parser.parseColonType(type)) return failure(); // Add the attribute type to the list. return parser.addTypeToList(type, result.types); } /// The constant op requires an attribute, and furthermore requires that it /// matches the return type. 
static LogicalResult verify(ConstantOp &op) { auto value = op.getValue(); if (!value) return op.emitOpError("requires a 'value' attribute"); auto type = op.getType(); if (!value.getType().isa<NoneType>() && type != value.getType()) return op.emitOpError() << "requires attribute's type (" << value.getType() << ") to match op's return type (" << type << ")"; if (type.isa<IndexType>() || value.isa<BoolAttr>()) return success(); if (auto intAttr = value.dyn_cast<IntegerAttr>()) { // If the type has a known bitwidth we verify that the value can be // represented with the given bitwidth. auto bitwidth = type.cast<IntegerType>().getWidth(); auto intVal = intAttr.getValue(); if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) return op.emitOpError("requires 'value' to be an integer within the " "range of the integer result type"); return success(); } if (type.isa<FloatType>()) { if (!value.isa<FloatAttr>()) return op.emitOpError("requires 'value' to be a floating point constant"); return success(); } if (type.isa<ShapedType>()) { if (!value.isa<ElementsAttr>()) return op.emitOpError("requires 'value' to be a shaped constant"); return success(); } if (type.isa<FunctionType>()) { auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>(); if (!fnAttr) return op.emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. auto fn = op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue()); if (!fn) return op.emitOpError("reference to undefined function 'bar'"); // Check that the referenced function has the correct type. 
if (fn.getType() != type) return op.emitOpError("reference to function with mismatched type"); return success(); } if (type.isa<NoneType>() && value.isa<UnitAttr>()) return success(); return op.emitOpError("unsupported 'value' attribute: ") << value; } OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { assert(operands.empty() && "constant has no operands"); return getValue(); } void ConstantOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { Type type = getType(); if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { IntegerType intTy = type.dyn_cast<IntegerType>(); // Sugar i1 constants with 'true' and 'false'. if (intTy && intTy.getWidth() == 1) return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); // Otherwise, build a complex name with the value and type. SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << 'c' << intCst.getInt(); if (intTy) specialName << '_' << type; setNameFn(getResult(), specialName.str()); } else if (type.isa<FunctionType>()) { setNameFn(getResult(), "f"); } else { setNameFn(getResult(), "cst"); } } /// Returns true if a constant operation can be built with the given value and /// result type. bool ConstantOp::isBuildableWith(Attribute value, Type type) { // SymbolRefAttr can only be used with a function type. if (value.isa<SymbolRefAttr>()) return type.isa<FunctionType>(); // Otherwise, the attribute must have the same type as 'type'. if (value.getType() != type) return false; // Finally, check that the attribute kind is handled. 
return value.isa<BoolAttr>() || value.isa<IntegerAttr>() || value.isa<FloatAttr>() || value.isa<ElementsAttr>() || value.isa<UnitAttr>(); } void ConstantFloatOp::build(Builder *builder, OperationState &result, const APFloat &value, FloatType type) { ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value)); } bool ConstantFloatOp::classof(Operation *op) { return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>(); } /// ConstantIntOp only matches values whose result type is an IntegerType. bool ConstantIntOp::classof(Operation *op) { return ConstantOp::classof(op) && op->getResult(0).getType().isa<IntegerType>(); } void ConstantIntOp::build(Builder *builder, OperationState &result, int64_t value, unsigned width) { Type type = builder->getIntegerType(width); ConstantOp::build(builder, result, type, builder->getIntegerAttr(type, value)); } /// Build a constant int op producing an integer with the specified type, /// which must be an integer type. void ConstantIntOp::build(Builder *builder, OperationState &result, int64_t value, Type type) { assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type"); ConstantOp::build(builder, result, type, builder->getIntegerAttr(type, value)); } /// ConstantIndexOp only matches values whose result type is Index. bool ConstantIndexOp::classof(Operation *op) { return ConstantOp::classof(op) && op->getResult(0).getType().isIndex(); } void ConstantIndexOp::build(Builder *builder, OperationState &result, int64_t value) { Type type = builder->getIndexType(); ConstantOp::build(builder, result, type, builder->getIntegerAttr(type, value)); } //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// namespace { /// Fold Dealloc operations that are deallocating an AllocOp that is only used /// by other Dealloc operations. 
struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { using OpRewritePattern<DeallocOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(DeallocOp dealloc, PatternRewriter &rewriter) const override { // Check that the memref operand's defining operation is an AllocOp. Value memref = dealloc.memref(); if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp())) return matchFailure(); // Check that all of the uses of the AllocOp are other DeallocOps. for (auto *user : memref.getUsers()) if (!isa<DeallocOp>(user)) return matchFailure(); // Erase the dealloc operation. rewriter.eraseOp(dealloc); return matchSuccess(); } }; } // end anonymous namespace. static void print(OpAsmPrinter &p, DeallocOp op) { p << "dealloc " << op.memref() << " : " << op.memref().getType(); } static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType memrefInfo; MemRefType type; return failure(parser.parseOperand(memrefInfo) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands)); } static LogicalResult verify(DeallocOp op) { if (!op.memref().getType().isa<MemRefType>()) return op.emitOpError("operand must be a memref"); return success(); } void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert<SimplifyDeadDealloc>(context); } LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, SmallVectorImpl<OpFoldResult> &results) { /// dealloc(memrefcast) -> dealloc return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // DimOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, DimOp op) { p << "dim " << op.getOperand() << ", " << op.getIndex(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); p << " : " << op.getOperand().getType(); } static ParseResult parseDimOp(OpAsmParser &parser, OperationState 
&result) { OpAsmParser::OperandType operandInfo; IntegerAttr indexAttr; Type type; Type indexType = parser.getBuilder().getIndexType(); return failure( parser.parseOperand(operandInfo) || parser.parseComma() || parser.parseAttribute(indexAttr, indexType, "index", result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(operandInfo, type, result.operands) || parser.addTypeToList(indexType, result.types)); } static LogicalResult verify(DimOp op) { // Check that we have an integer index operand. auto indexAttr = op.getAttrOfType<IntegerAttr>("index"); if (!indexAttr) return op.emitOpError("requires an integer attribute named 'index'"); int64_t index = indexAttr.getValue().getSExtValue(); auto type = op.getOperand().getType(); if (auto tensorType = type.dyn_cast<RankedTensorType>()) { if (index >= tensorType.getRank()) return op.emitOpError("index is out of range"); } else if (auto memrefType = type.dyn_cast<MemRefType>()) { if (index >= memrefType.getRank()) return op.emitOpError("index is out of range"); } else if (type.isa<UnrankedTensorType>()) { // ok, assumed to be in-range. } else { return op.emitOpError("requires an operand with tensor or memref type"); } return success(); } OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { // Constant fold dim when the size along the index referred to is a constant. auto opType = memrefOrTensor().getType(); int64_t indexSize = -1; if (auto tensorType = opType.dyn_cast<RankedTensorType>()) indexSize = tensorType.getShape()[getIndex()]; else if (auto memrefType = opType.dyn_cast<MemRefType>()) indexSize = memrefType.getShape()[getIndex()]; if (!ShapedType::isDynamic(indexSize)) return IntegerAttr::get(IndexType::get(getContext()), indexSize); // Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp. auto memrefType = opType.dyn_cast<MemRefType>(); if (!memrefType) return {}; // The size at getIndex() is now a dynamic size of a memref. 
auto memref = memrefOrTensor().getDefiningOp(); if (auto alloc = dyn_cast_or_null<AllocOp>(memref)) return *(alloc.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(getIndex())); if (auto view = dyn_cast_or_null<ViewOp>(memref)) return *(view.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(getIndex())); // The subview op here is expected to have rank dynamic sizes now. if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) { auto sizes = subview.sizes(); if (!sizes.empty()) return *(sizes.begin() + getIndex()); } /// dim(memrefcast) -> dim if (succeeded(foldMemRefCast(*this))) return getResult(); return {}; } //===----------------------------------------------------------------------===// // SignedDivIOp //===----------------------------------------------------------------------===// OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 2 && "binary operation takes two operands"); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } return a.sdiv_ov(b, overflowOrDiv0); }); return overflowOrDiv0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // UnsignedDivIOp //===----------------------------------------------------------------------===// OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 2 && "binary operation takes two operands"); // Don't fold if it would require a division by zero. bool div0 = false; auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { if (div0 || !b) { div0 = true; return a; } return a.udiv(b); }); return div0 ? 
Attribute() : result; } // --------------------------------------------------------------------------- // DmaStartOp // --------------------------------------------------------------------------- void DmaStartOp::build(Builder *builder, OperationState &result, Value srcMemRef, ValueRange srcIndices, Value destMemRef, ValueRange destIndices, Value numElements, Value tagMemRef, ValueRange tagIndices, Value stride, Value elementsPerStride) { result.addOperands(srcMemRef); result.addOperands(srcIndices); result.addOperands(destMemRef); result.addOperands(destIndices); result.addOperands({numElements, tagMemRef}); result.addOperands(tagIndices); if (stride) result.addOperands({stride, elementsPerStride}); } void DmaStartOp::print(OpAsmPrinter &p) { p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], " << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; if (isStrided()) p << ", " << getStride() << ", " << getNumElementsPerStride(); p.printOptionalAttrDict(getAttrs()); p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() << ", " << getTagMemRef().getType(); } // Parse DmaStartOp. 
// Ex: // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, // %tag[%index], %stride, %num_elt_per_stride : // : memref<3076 x f32, 0>, // memref<1024 x f32, 2>, // memref<1 x i32> // ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcMemRefInfo; SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; OpAsmParser::OperandType dstMemRefInfo; SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; OpAsmParser::OperandType numElementsInfo; OpAsmParser::OperandType tagMemrefInfo; SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; SmallVector<OpAsmParser::OperandType, 2> strideInfo; SmallVector<Type, 3> types; auto indexType = parser.getBuilder().getIndexType(); // Parse and resolve the following list of operands: // *) source memref followed by its indices (in square brackets). // *) destination memref followed by its indices (in square brackets). // *) dma size in KiB. if (parser.parseOperand(srcMemRefInfo) || parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(dstMemRefInfo) || parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseComma() || parser.parseOperand(tagMemrefInfo) || parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) return failure(); // Parse optional stride and elements per stride. 
if (parser.parseTrailingOperandList(strideInfo)) return failure(); bool isStrided = strideInfo.size() == 2; if (!strideInfo.empty() && !isStrided) { return parser.emitError(parser.getNameLoc(), "expected two stride related operands"); } if (parser.parseColonTypeList(types)) return failure(); if (types.size() != 3) return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || parser.resolveOperands(srcIndexInfos, indexType, result.operands) || parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || parser.resolveOperands(dstIndexInfos, indexType, result.operands) || // size should be an index. parser.resolveOperand(numElementsInfo, indexType, result.operands) || parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || // tag indices should be index. parser.resolveOperands(tagIndexInfos, indexType, result.operands)) return failure(); auto memrefType0 = types[0].dyn_cast<MemRefType>(); if (!memrefType0) return parser.emitError(parser.getNameLoc(), "expected source to be of memref type"); auto memrefType1 = types[1].dyn_cast<MemRefType>(); if (!memrefType1) return parser.emitError(parser.getNameLoc(), "expected destination to be of memref type"); auto memrefType2 = types[2].dyn_cast<MemRefType>(); if (!memrefType2) return parser.emitError(parser.getNameLoc(), "expected tag to be of memref type"); if (isStrided) { if (parser.resolveOperands(strideInfo, indexType, result.operands)) return failure(); } // Check that source/destination index list size matches associated rank. 
if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() || static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank()) return parser.emitError(parser.getNameLoc(), "memref rank not equal to indices count"); if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank()) return parser.emitError(parser.getNameLoc(), "tag memref rank not equal to indices count"); return success(); } LogicalResult DmaStartOp::verify() { // DMAs from different memory spaces supported. if (getSrcMemorySpace() == getDstMemorySpace()) return emitOpError("DMA should be between different memory spaces"); if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + getDstMemRefRank() + 3 + 1 && getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + getDstMemRefRank() + 3 + 1 + 2) { return emitOpError("incorrect number of operands"); } return success(); } LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, SmallVectorImpl<OpFoldResult> &results) { /// dma_start(memrefcast) -> dma_start return foldMemRefCast(*this); } // --------------------------------------------------------------------------- // DmaWaitOp // --------------------------------------------------------------------------- void DmaWaitOp::build(Builder *builder, OperationState &result, Value tagMemRef, ValueRange tagIndices, Value numElements) { result.addOperands(tagMemRef); result.addOperands(tagIndices); result.addOperands(numElements); } void DmaWaitOp::print(OpAsmPrinter &p) { p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], " << getNumElements(); p.printOptionalAttrDict(getAttrs()); p << " : " << getTagMemRef().getType(); } // Parse DmaWaitOp. 
// Eg: // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> // ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; Type type; auto indexType = parser.getBuilder().getIndexType(); OpAsmParser::OperandType numElementsInfo; // Parse tag memref, its indices, and dma size. if (parser.parseOperand(tagMemrefInfo) || parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseColonType(type) || parser.resolveOperand(tagMemrefInfo, type, result.operands) || parser.resolveOperands(tagIndexInfos, indexType, result.operands) || parser.resolveOperand(numElementsInfo, indexType, result.operands)) return failure(); auto memrefType = type.dyn_cast<MemRefType>(); if (!memrefType) return parser.emitError(parser.getNameLoc(), "expected tag to be of memref type"); if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank()) return parser.emitError(parser.getNameLoc(), "tag memref rank not equal to indices count"); return success(); } LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, SmallVectorImpl<OpFoldResult> &results) { /// dma_wait(memrefcast) -> dma_wait return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // ExtractElementOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ExtractElementOp op) { p << "extract_element " << op.getAggregate() << '[' << op.getIndices(); p << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getAggregate().getType(); } static ParseResult parseExtractElementOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType aggregateInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; ShapedType type; auto indexTy = parser.getBuilder().getIndexType(); 
return failure( parser.parseOperand(aggregateInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(aggregateInfo, type, result.operands) || parser.resolveOperands(indexInfo, indexTy, result.operands) || parser.addTypeToList(type.getElementType(), result.types)); } static LogicalResult verify(ExtractElementOp op) { auto aggregateType = op.getAggregate().getType().cast<ShapedType>(); // This should be possible with tablegen type constraints if (op.getType() != aggregateType.getElementType()) return op.emitOpError("result type must match element type of aggregate"); // Verify the # indices match if we have a ranked type. if (aggregateType.hasRank() && aggregateType.getRank() != op.getNumOperands() - 1) return op.emitOpError("incorrect number of indices for extract_element"); return success(); } OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) { assert(!operands.empty() && "extract_element takes at least one operand"); // The aggregate operand must be a known constant. Attribute aggregate = operands.front(); if (!aggregate) return {}; // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>()) return splatAggregate.getSplatValue(); // Otherwise, collect the constant indices into the aggregate. SmallVector<uint64_t, 8> indices; for (Attribute indice : llvm::drop_begin(operands, 1)) { if (!indice || !indice.isa<IntegerAttr>()) return {}; indices.push_back(indice.cast<IntegerAttr>().getInt()); } // If this is an elements attribute, query the value at the given indices. 
auto elementsAttr = aggregate.dyn_cast<ElementsAttr>(); if (elementsAttr && elementsAttr.isValidIndex(indices)) return elementsAttr.getValue(indices); return {}; } //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// // Index cast is applicable from index to integer and backwards. bool IndexCastOp::areCastCompatible(Type a, Type b) { return (a.isIndex() && b.isa<IntegerType>()) || (a.isa<IntegerType>() && b.isIndex()); } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, LoadOp op) { p << "load " << op.getMemRef() << '[' << op.getIndices() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; MemRefType type; auto indexTy = parser.getBuilder().getIndexType(); return failure( parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(indexInfo, indexTy, result.operands) || parser.addTypeToList(type.getElementType(), result.types)); } static LogicalResult verify(LoadOp op) { if (op.getType() != op.getMemRefType().getElementType()) return op.emitOpError("result type must match element type of memref"); if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) return op.emitOpError("incorrect number of indices for load"); return success(); } OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { /// load(memrefcast) -> load if (succeeded(foldMemRefCast(*this))) return getResult(); 
return OpFoldResult(); } //===----------------------------------------------------------------------===// // MemRefCastOp //===----------------------------------------------------------------------===// bool MemRefCastOp::areCastCompatible(Type a, Type b) { auto aT = a.dyn_cast<MemRefType>(); auto bT = b.dyn_cast<MemRefType>(); auto uaT = a.dyn_cast<UnrankedMemRefType>(); auto ubT = b.dyn_cast<UnrankedMemRefType>(); if (aT && bT) { if (aT.getElementType() != bT.getElementType()) return false; if (aT.getAffineMaps() != bT.getAffineMaps()) { int64_t aOffset, bOffset; SmallVector<int64_t, 4> aStrides, bStrides; if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || failed(getStridesAndOffset(bT, bStrides, bOffset)) || aStrides.size() != bStrides.size()) return false; // Strides along a dimension/offset are compatible if the value in the // source memref is static and the value in the target memref is the // same. They are also compatible if either one is dynamic (see // description of MemRefCastOp for details). auto checkCompatible = [](int64_t a, int64_t b) { return (a == MemRefType::getDynamicStrideOrOffset() || b == MemRefType::getDynamicStrideOrOffset() || a == b); }; if (!checkCompatible(aOffset, bOffset)) return false; for (auto aStride : enumerate(aStrides)) if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) return false; } if (aT.getMemorySpace() != bT.getMemorySpace()) return false; // They must have the same rank, and any specified dimensions must match. if (aT.getRank() != bT.getRank()) return false; for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); if (aDim != -1 && bDim != -1 && aDim != bDim) return false; } return true; } else { if (!aT && !uaT) return false; if (!bT && !ubT) return false; // Unranked to unranked casting is unsupported if (uaT && ubT) return false; auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); auto bEltType = (bT) ? 
bT.getElementType() : ubT.getElementType(); if (aEltType != bEltType) return false; auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); if (aMemSpace != bMemSpace) return false; return true; } return false; } OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) { return impl::foldCastOp(*this); } //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) { return constFoldBinaryOp<FloatAttr>( operands, [](APFloat a, APFloat b) { return a * b; }); } //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) { /// muli(x, 0) -> 0 if (matchPattern(rhs(), m_Zero())) return rhs(); /// muli(x, 1) -> x if (matchPattern(rhs(), m_One())) return getOperand(0); // TODO: Handle the overflow case. return constFoldBinaryOp<IntegerAttr>(operands, [](APInt a, APInt b) { return a * b; }); } //===----------------------------------------------------------------------===// // PrefetchOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, PrefetchOp op) { p << PrefetchOp::getOperationName() << " " << op.memref() << '['; p.printOperands(op.indices()); p << ']' << ", " << (op.isWrite() ? "write" : "read"); p << ", locality<" << op.localityHint(); p << ">, " << (op.isDataCache() ? 
"data" : "instr"); p.printOptionalAttrDict( op.getAttrs(), /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); p << " : " << op.getMemRefType(); } static ParseResult parsePrefetchOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; IntegerAttr localityHint; MemRefType type; StringRef readOrWrite, cacheType; auto indexTy = parser.getBuilder().getIndexType(); auto i32Type = parser.getBuilder().getIntegerType(32); if (parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseKeyword(&readOrWrite) || parser.parseComma() || parser.parseKeyword("locality") || parser.parseLess() || parser.parseAttribute(localityHint, i32Type, "localityHint", result.attributes) || parser.parseGreater() || parser.parseComma() || parser.parseKeyword(&cacheType) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(indexInfo, indexTy, result.operands)) return failure(); if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) return parser.emitError(parser.getNameLoc(), "rw specifier has to be 'read' or 'write'"); result.addAttribute( PrefetchOp::getIsWriteAttrName(), parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); if (!cacheType.equals("data") && !cacheType.equals("instr")) return parser.emitError(parser.getNameLoc(), "cache type has to be 'data' or 'instr'"); result.addAttribute( PrefetchOp::getIsDataCacheAttrName(), parser.getBuilder().getBoolAttr(cacheType.equals("data"))); return success(); } static LogicalResult verify(PrefetchOp op) { if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) return op.emitOpError("too few indices"); return success(); } LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, SmallVectorImpl<OpFoldResult> &results) { // prefetch(memrefcast) -> prefetch return foldMemRefCast(*this); } 
//===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, RankOp op) { p << "rank " << op.getOperand() << " : " << op.getOperand().getType(); } static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType operandInfo; Type type; Type indexType = parser.getBuilder().getIndexType(); return failure(parser.parseOperand(operandInfo) || parser.parseColonType(type) || parser.resolveOperand(operandInfo, type, result.operands) || parser.addTypeToList(indexType, result.types)); } OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { // Constant fold rank when the rank of the tensor is known. auto type = getOperand().getType(); if (auto tensorType = type.dyn_cast<RankedTensorType>()) return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank()); return IntegerAttr(); } //===----------------------------------------------------------------------===// // SignedRemIOp //===----------------------------------------------------------------------===// OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 2 && "remi_signed takes two operands"); auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); if (!rhs) return {}; auto rhsValue = rhs.getValue(); // x % 1 = 0 if (rhsValue.isOneValue()) return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); // Don't fold if it requires division by zero. 
if (rhsValue.isNullValue()) return {}; auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); if (!lhs) return {}; return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); } //===----------------------------------------------------------------------===// // UnsignedRemIOp //===----------------------------------------------------------------------===// OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 2 && "remi_unsigned takes two operands"); auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); if (!rhs) return {}; auto rhsValue = rhs.getValue(); // x % 1 = 0 if (rhsValue.isOneValue()) return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); // Don't fold if it requires division by zero. if (rhsValue.isNullValue()) return {}; auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); if (!lhs) return {}; return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { SmallVector<OpAsmParser::OperandType, 2> opInfo; SmallVector<Type, 2> types; llvm::SMLoc loc = parser.getCurrentLocation(); return failure(parser.parseOperandList(opInfo) || (!opInfo.empty() && parser.parseColonTypeList(types)) || parser.resolveOperands(opInfo, types, loc, result.operands)); } static void print(OpAsmPrinter &p, ReturnOp op) { p << "return"; if (op.getNumOperands() != 0) p << ' ' << op.getOperands() << " : " << op.getOperandTypes(); } static LogicalResult verify(ReturnOp op) { auto function = cast<FuncOp>(op.getParentOp()); // The operand number and types must match the function signature. 
const auto &results = function.getType().getResults(); if (op.getNumOperands() != results.size()) return op.emitOpError("has ") << op.getNumOperands() << " operands, but enclosing function returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (op.getOperand(i).getType() != results[i]) return op.emitError() << "type of return operand " << i << " (" << op.getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")"; return success(); } //===----------------------------------------------------------------------===// // SIToFPOp //===----------------------------------------------------------------------===// // sitofp is applicable from integer types to float types. bool SIToFPOp::areCastCompatible(Type a, Type b) { return a.isa<IntegerType>() && b.isa<FloatType>(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { SmallVector<OpAsmParser::OperandType, 3> ops; SmallVector<NamedAttribute, 4> attrs; Type type; if (parser.parseOperandList(ops, 3) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type)) return failure(); auto i1Type = getCheckedI1SameShape(&parser.getBuilder(), type); if (!i1Type) return parser.emitError(parser.getNameLoc(), "expected type with valid i1 shape"); SmallVector<Type, 3> types = {i1Type, type, type}; return failure(parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands) || parser.addTypeToList(type, result.types)); } static void print(OpAsmPrinter &p, SelectOp op) { p << "select " << op.getOperands() << " : " << op.getTrueValue().getType(); p.printOptionalAttrDict(op.getAttrs()); } static LogicalResult verify(SelectOp op) { auto trueType = op.getTrueValue().getType(); auto falseType = op.getFalseValue().getType(); if (trueType != falseType) 
return op.emitOpError( "requires 'true' and 'false' arguments to be of the same type"); return success(); } OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) { auto condition = getCondition(); // select true, %0, %1 => %0 if (matchPattern(condition, m_One())) return getTrueValue(); // select false, %0, %1 => %1 if (matchPattern(condition, m_Zero())) return getFalseValue(); return nullptr; } //===----------------------------------------------------------------------===// // SignExtendIOp //===----------------------------------------------------------------------===// static LogicalResult verify(SignExtendIOp op) { // Get the scalar type (which is either directly the type of the operand // or the vector's/tensor's element type. auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); // For now, index is forbidden for the source and the destination type. if (srcType.isa<IndexType>()) return op.emitError() << srcType << " is not a valid operand type"; if (dstType.isa<IndexType>()) return op.emitError() << dstType << " is not a valid result type"; if (srcType.cast<IntegerType>().getWidth() >= dstType.cast<IntegerType>().getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; return success(); } //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, SplatOp op) { p << "splat " << op.getOperand(); p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getType(); } static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType splatValueInfo; ShapedType shapedType; return failure(parser.parseOperand(splatValueInfo) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(shapedType) || parser.resolveOperand(splatValueInfo, 
shapedType.getElementType(), result.operands) || parser.addTypeToList(shapedType, result.types)); } static LogicalResult verify(SplatOp op) { // TODO: we could replace this by a trait. if (op.getOperand().getType() != op.getType().cast<ShapedType>().getElementType()) return op.emitError("operand should be of elemental type of result type"); return success(); } // Constant folding hook for SplatOp. OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 1 && "splat takes one operand"); auto constOperand = operands.front(); if (!constOperand || (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>())) return {}; auto shapedType = getType().cast<ShapedType>(); assert(shapedType.getElementType() == constOperand.getType() && "incorrect input attribute type for folding"); // SplatElementsAttr::get treats single value for second arg as being a splat. return SplatElementsAttr::get(shapedType, {constOperand}); } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, StoreOp op) { p << "store " << op.getValueToStore(); p << ", " << op.getMemRef() << '[' << op.getIndices() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; MemRefType memrefType; auto indexTy = parser.getBuilder().getIndexType(); return failure( parser.parseOperand(storeValueInfo) || parser.parseComma() || parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(memrefType) || parser.resolveOperand(storeValueInfo, memrefType.getElementType(), result.operands) || 
parser.resolveOperand(memrefInfo, memrefType, result.operands) || parser.resolveOperands(indexInfo, indexTy, result.operands)); } static LogicalResult verify(StoreOp op) { // First operand must have same type as memref element type. if (op.getValueToStore().getType() != op.getMemRefType().getElementType()) return op.emitOpError( "first operand must have same type memref element type"); if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) return op.emitOpError("store index operand count not equal to memref rank"); return success(); } LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, SmallVectorImpl<OpFoldResult> &results) { /// store(memrefcast) -> store return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // SubFOp //===----------------------------------------------------------------------===// OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) { return constFoldBinaryOp<FloatAttr>( operands, [](APFloat a, APFloat b) { return a - b; }); } //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) { // subi(x,x) -> 0 if (getOperand(0) == getOperand(1)) return Builder(getContext()).getZeroAttr(getType()); return constFoldBinaryOp<IntegerAttr>(operands, [](APInt a, APInt b) { return a - b; }); } //===----------------------------------------------------------------------===// // AndOp //===----------------------------------------------------------------------===// OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) { /// and(x, 0) -> 0 if (matchPattern(rhs(), m_Zero())) return rhs(); /// and(x,x) -> x if (lhs() == rhs()) return rhs(); return constFoldBinaryOp<IntegerAttr>(operands, [](APInt a, APInt b) { return a & b; }); } //===----------------------------------------------------------------------===// // OrOp 
//===----------------------------------------------------------------------===// OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) { /// or(x, 0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); /// or(x,x) -> x if (lhs() == rhs()) return rhs(); return constFoldBinaryOp<IntegerAttr>(operands, [](APInt a, APInt b) { return a | b; }); } //===----------------------------------------------------------------------===// // XOrOp //===----------------------------------------------------------------------===// OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) { /// xor(x, 0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); /// xor(x,x) -> 0 if (lhs() == rhs()) return Builder(getContext()).getZeroAttr(getType()); return constFoldBinaryOp<IntegerAttr>(operands, [](APInt a, APInt b) { return a ^ b; }); } //===----------------------------------------------------------------------===// // TensorCastOp //===----------------------------------------------------------------------===// bool TensorCastOp::areCastCompatible(Type a, Type b) { auto aT = a.dyn_cast<TensorType>(); auto bT = b.dyn_cast<TensorType>(); if (!aT || !bT) return false; if (aT.getElementType() != bT.getElementType()) return false; return succeeded(verifyCompatibleShape(aT, bT)); } OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) { return impl::foldCastOp(*this); } //===----------------------------------------------------------------------===// // Helpers for Tensor[Load|Store]Op //===----------------------------------------------------------------------===// static Type getTensorTypeFromMemRefType(Builder &b, Type type) { if (auto memref = type.dyn_cast<MemRefType>()) return RankedTensorType::get(memref.getShape(), memref.getElementType()); return b.getNoneType(); } //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, 
TensorLoadOp op) { p << "tensor_load " << op.getOperand(); p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getOperand().getType(); } static ParseResult parseTensorLoadOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType op; Type type; return failure(parser.parseOperand(op) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(op, type, result.operands) || parser.addTypeToList( getTensorTypeFromMemRefType(parser.getBuilder(), type), result.types)); } //===----------------------------------------------------------------------===// // TensorStoreOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, TensorStoreOp op) { p << "tensor_store " << op.tensor() << ", " << op.memref(); p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.memref().getType(); } static ParseResult parseTensorStoreOp(OpAsmParser &parser, OperationState &result) { SmallVector<OpAsmParser::OperandType, 2> ops; Type type; llvm::SMLoc loc = parser.getCurrentLocation(); return failure( parser.parseOperandList(ops, /*requiredOperandCount=*/2) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperands( ops, {getTensorTypeFromMemRefType(parser.getBuilder(), type), type}, loc, result.operands)); } //===----------------------------------------------------------------------===// // TruncateIOp //===----------------------------------------------------------------------===// static LogicalResult verify(TruncateIOp op) { auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); if (srcType.isa<IndexType>()) return op.emitError() << srcType << " is not a valid operand type"; if (dstType.isa<IndexType>()) return op.emitError() << dstType << " is not a valid result type"; if (srcType.cast<IntegerType>().getWidth() <= dstType.cast<IntegerType>().getWidth()) 
return op.emitError("operand type ") << srcType << " must be wider than result type " << dstType; return success(); } //===----------------------------------------------------------------------===// // ViewOp //===----------------------------------------------------------------------===// static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcInfo; SmallVector<OpAsmParser::OperandType, 1> offsetInfo; SmallVector<OpAsmParser::OperandType, 4> sizesInfo; auto indexType = parser.getBuilder().getIndexType(); Type srcType, dstType; llvm::SMLoc offsetLoc; if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) return failure(); if (offsetInfo.size() > 1) return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand"; return failure( parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || parser.resolveOperand(srcInfo, srcType, result.operands) || parser.resolveOperands(offsetInfo, indexType, result.operands) || parser.resolveOperands(sizesInfo, indexType, result.operands) || parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types)); } static void print(OpAsmPrinter &p, ViewOp op) { p << op.getOperationName() << ' ' << op.getOperand(0) << '['; auto dynamicOffset = op.getDynamicOffset(); if (dynamicOffset != nullptr) p.printOperand(dynamicOffset); p << "][" << op.getDynamicSizes() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getOperand(0).getType() << " to " << op.getType(); } Value ViewOp::getDynamicOffset() { int64_t offset; SmallVector<int64_t, 4> strides; auto result = succeeded(mlir::getStridesAndOffset(getType(), strides, offset)); assert(result); if (result && offset == MemRefType::getDynamicStrideOrOffset()) return getOperand(1); return nullptr; } static LogicalResult 
verifyDynamicStrides(MemRefType memrefType, ArrayRef<int64_t> strides) { ArrayRef<int64_t> shape = memrefType.getShape(); unsigned rank = memrefType.getRank(); assert(rank == strides.size()); bool dynamicStrides = false; for (int i = rank - 2; i >= 0; --i) { // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. if (ShapedType::isDynamic(shape[i + 1])) dynamicStrides = true; // If stride at dim 'i' is not dynamic, return error. if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) return failure(); } return success(); } static LogicalResult verify(ViewOp op) { auto baseType = op.getOperand(0).getType().cast<MemRefType>(); auto viewType = op.getResult().getType().cast<MemRefType>(); // The base memref should have identity layout map (or none). if (baseType.getAffineMaps().size() > 1 || (baseType.getAffineMaps().size() == 1 && !baseType.getAffineMaps()[0].isIdentity())) return op.emitError("unsupported map for base memref type ") << baseType; // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != viewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and view memref type " << viewType; // Verify that the result memref type has a strided layout map. int64_t offset; SmallVector<int64_t, 4> strides; if (failed(getStridesAndOffset(viewType, strides, offset))) return op.emitError("result type ") << viewType << " is not strided"; // Verify that we have the correct number of operands for the result type. unsigned memrefOperandCount = 1; unsigned numDynamicDims = viewType.getNumDynamicDims(); unsigned dynamicOffsetCount = offset == MemRefType::getDynamicStrideOrOffset() ? 
1 : 0; if (op.getNumOperands() != memrefOperandCount + numDynamicDims + dynamicOffsetCount) return op.emitError("incorrect number of operands for type ") << viewType; // Verify dynamic strides symbols were added to correct dimensions based // on dynamic sizes. if (failed(verifyDynamicStrides(viewType, strides))) return op.emitError("incorrect dynamic strides in view memref type ") << viewType; return success(); } namespace { struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { using OpRewritePattern<ViewOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { // Return if none of the operands are constants. if (llvm::none_of(viewOp.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) return matchFailure(); // Get result memref type. auto memrefType = viewOp.getType(); if (memrefType.getAffineMaps().size() > 1) return matchFailure(); auto map = memrefType.getAffineMaps().empty() ? AffineMap::getMultiDimIdentityMap(memrefType.getRank(), rewriter.getContext()) : memrefType.getAffineMaps()[0]; // Get offset from old memref view type 'memRefType'. int64_t oldOffset; SmallVector<int64_t, 4> oldStrides; if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) return matchFailure(); SmallVector<Value, 4> newOperands; SmallVector<Value, 4> droppedOperands; // Fold dynamic offset operand if it is produced by a constant. auto dynamicOffset = viewOp.getDynamicOffset(); int64_t newOffset = oldOffset; unsigned dynamicOffsetOperandCount = 0; if (dynamicOffset != nullptr) { auto *defOp = dynamicOffset.getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { // Dynamic offset will be folded into the map. newOffset = constantIndexOp.getValue(); droppedOperands.push_back(dynamicOffset); } else { // Unable to fold dynamic offset. Add it to 'newOperands' list. 
newOperands.push_back(dynamicOffset); dynamicOffsetOperandCount = 1; } } // Fold any dynamic dim operands which are produced by a constant. SmallVector<int64_t, 4> newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); unsigned dynamicDimPos = viewOp.getDynamicSizesOperandStart(); unsigned rank = memrefType.getRank(); for (unsigned dim = 0, e = rank; dim < e; ++dim) { int64_t dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (!ShapedType::isDynamic(dimSize)) { newShapeConstants.push_back(dimSize); continue; } auto *defOp = viewOp.getOperand(dynamicDimPos).getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); // Record to check for zero uses later below. droppedOperands.push_back(constantIndexOp); } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(dimSize); newOperands.push_back(viewOp.getOperand(dynamicDimPos)); } dynamicDimPos++; } // Compute new strides based on 'newShapeConstants'. SmallVector<int64_t, 4> newStrides(rank); newStrides[rank - 1] = 1; bool dynamicStrides = false; for (int i = rank - 2; i >= 0; --i) { if (ShapedType::isDynamic(newShapeConstants[i + 1])) dynamicStrides = true; if (dynamicStrides) newStrides[i] = MemRefType::getDynamicStrideOrOffset(); else newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1]; } // Regenerate strided layout map with 'newStrides' and 'newOffset'. map = makeStridedLinearLayoutMap(newStrides, newOffset, rewriter.getContext()); // Create new memref type with constant folded dims and/or offset/strides. 
auto newMemRefType = MemRefType::get(newShapeConstants, memrefType.getElementType(), {map}, memrefType.getMemorySpace()); (void)dynamicOffsetOperandCount; // unused in opt mode assert(static_cast<int64_t>(newOperands.size()) == dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims()); // Create new ViewOp. auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), newOperands); // Insert a cast so we have the same type as the old memref type. rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp, newViewOp, viewOp.getType()); return matchSuccess(); } }; struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { using OpRewritePattern<ViewOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { Value memrefOperand = viewOp.getOperand(0); MemRefCastOp memrefCastOp = dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp()); if (!memrefCastOp) return matchFailure(); Value allocOperand = memrefCastOp.getOperand(); AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp()); if (!allocOp) return matchFailure(); rewriter.replaceOpWithNewOp<ViewOp>(memrefOperand, viewOp, viewOp.getType(), allocOperand, viewOp.operands()); return matchSuccess(); } }; } // end anonymous namespace void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); } //===----------------------------------------------------------------------===// // SubViewOp //===----------------------------------------------------------------------===// // Returns a MemRefType with dynamic sizes and offset and the same stride as the // `memRefType` passed as argument. // TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep // sizes and offset static. 
static Type inferSubViewResultType(MemRefType memRefType) { auto rank = memRefType.getRank(); int64_t offset; SmallVector<int64_t, 4> strides; Type elementType = memRefType.getElementType(); auto res = getStridesAndOffset(memRefType, strides, offset); assert(succeeded(res) && "SubViewOp expected strided memref type"); (void)res; // Assume sizes and offset are fully dynamic for now until canonicalization // occurs on the ranges. Typed strides don't change though. offset = MemRefType::getDynamicStrideOrOffset(); // Overwrite strides because verifier will not pass. // TODO(b/144419106): don't force degrade the strides to fully dynamic. for (auto &stride : strides) stride = MemRefType::getDynamicStrideOrOffset(); auto stridedLayout = makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize); return MemRefType::get(sizes, elementType, stridedLayout, memRefType.getMemorySpace()); } void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, Type resultType, ArrayRef<NamedAttribute> attrs) { if (!resultType) resultType = inferSubViewResultType(source.getType().cast<MemRefType>()); auto segmentAttr = b->getI32VectorAttr( {1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()), static_cast<int32_t>(strides.size())}); build(b, result, resultType, source, offsets, sizes, strides, segmentAttr); result.addAttributes(attrs); } void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, Value source) { build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, resultType); } static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcInfo; SmallVector<OpAsmParser::OperandType, 4> offsetsInfo; SmallVector<OpAsmParser::OperandType, 4> sizesInfo; SmallVector<OpAsmParser::OperandType, 4> stridesInfo; auto indexType = 
parser.getBuilder().getIndexType(); Type srcType, dstType; if (parser.parseOperand(srcInfo) || parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) || parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) { return failure(); } auto builder = parser.getBuilder(); result.addAttribute( SubViewOp::getOperandSegmentSizeAttr(), builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()), static_cast<int32_t>(sizesInfo.size()), static_cast<int32_t>(stridesInfo.size())})); return failure( parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || parser.resolveOperand(srcInfo, srcType, result.operands) || parser.resolveOperands(offsetsInfo, indexType, result.operands) || parser.resolveOperands(sizesInfo, indexType, result.operands) || parser.resolveOperands(stridesInfo, indexType, result.operands) || parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types)); } static void print(OpAsmPrinter &p, SubViewOp op) { p << op.getOperationName() << ' ' << op.getOperand(0) << '[' << op.offsets() << "][" << op.sizes() << "][" << op.strides() << ']'; SmallVector<StringRef, 1> elidedAttrs = { SubViewOp::getOperandSegmentSizeAttr()}; p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); p << " : " << op.getOperand(0).getType() << " to " << op.getType(); } static LogicalResult verify(SubViewOp op) { auto baseType = op.getBaseMemRefType().cast<MemRefType>(); auto subViewType = op.getType(); // The rank of the base and result subview must match. if (baseType.getRank() != subViewType.getRank()) { return op.emitError( "expected rank of result type to match rank of base type "); } // The base memref and the view memref should be in the same memory space. 
if (baseType.getMemorySpace() != subViewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and subview memref type " << subViewType; // Verify that the base memref type has a strided layout map. int64_t baseOffset; SmallVector<int64_t, 4> baseStrides; if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset))) return op.emitError("base type ") << subViewType << " is not strided"; // Verify that the result memref type has a strided layout map. int64_t subViewOffset; SmallVector<int64_t, 4> subViewStrides; if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset))) return op.emitError("result type ") << subViewType << " is not strided"; // Num offsets should either be zero or rank of memref. if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) { return op.emitError("expected number of dynamic offsets specified to match " "the rank of the result type ") << subViewType; } // Num sizes should either be zero or rank of memref. if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) { return op.emitError("expected number of dynamic sizes specified to match " "the rank of the result type ") << subViewType; } // Num strides should either be zero or rank of memref. if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) { return op.emitError("expected number of dynamic strides specified to match " "the rank of the result type ") << subViewType; } // Verify that if the shape of the subview type is static, then sizes are not // dynamic values, and vice versa. if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) || (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) { return op.emitError("invalid to specify dynamic sizes when subview result " "type is statically shaped and viceversa"); } // Verify that if dynamic sizes are specified, then the result memref type // have full dynamic dimensions. 
if (op.getNumSizes() > 0) { if (llvm::any_of(subViewType.getShape(), [](int64_t dim) { return dim != ShapedType::kDynamicSize; })) { // TODO: This is based on the assumption that number of size arguments are // either 0, or the rank of the result type. It is possible to have more // fine-grained verification where only particular dimensions are // dynamic. That probably needs further changes to the shape op // specification. return op.emitError("expected shape of result type to be fully dynamic " "when sizes are specified"); } } // Verify that if dynamic offsets are specified or base memref has dynamic // offset or base memref has dynamic strides, then the subview offset is // dynamic. if ((op.getNumOffsets() > 0 || baseOffset == MemRefType::getDynamicStrideOrOffset() || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset())) && subViewOffset != MemRefType::getDynamicStrideOrOffset()) { return op.emitError( "expected result memref layout map to have dynamic offset"); } // For now, verify that if dynamic strides are specified, then all the result // memref type have dynamic strides. if (op.getNumStrides() > 0) { if (llvm::any_of(subViewStrides, [](int64_t stride) { return stride != MemRefType::getDynamicStrideOrOffset(); })) { return op.emitError("expected result type to have dynamic strides"); } } // If any of the base memref has dynamic stride, then the corresponding // stride of the subview must also have dynamic stride. 
assert(baseStrides.size() == subViewStrides.size()); for (auto stride : enumerate(baseStrides)) { if (stride.value() == MemRefType::getDynamicStrideOrOffset() && subViewStrides[stride.index()] != MemRefType::getDynamicStrideOrOffset()) { return op.emitError( "expected result type to have dynamic stride along a dimension if " "the base memref type has dynamic stride along that dimension"); } } return success(); } raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) { return os << "range " << range.offset << ":" << range.size << ":" << range.stride; } SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() { SmallVector<Range, 8> res; unsigned rank = getType().getRank(); res.reserve(rank); for (unsigned i = 0; i < rank; ++i) res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i), *(strides().begin() + i)}); return res; } LogicalResult SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) { // If the strides are dynamic return failure. if (getNumStrides()) return failure(); // When static, the stride operands can be retrieved by taking the strides of // the result of the subview op, and dividing the strides of the base memref. 
int64_t resultOffset, baseOffset; SmallVector<int64_t, 2> resultStrides, baseStrides; if (failed( getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || failed(getStridesAndOffset(getType(), resultStrides, resultOffset))) return failure(); assert(static_cast<int64_t>(resultStrides.size()) == getType().getRank() && baseStrides.size() == resultStrides.size() && "base and result memrefs must have the same rank"); assert(!llvm::is_contained(resultStrides, MemRefType::getDynamicStrideOrOffset()) && "strides of subview op must be static, when there are no dynamic " "strides specified"); staticStrides.resize(getType().getRank()); for (auto resultStride : enumerate(resultStrides)) { auto baseStride = baseStrides[resultStride.index()]; // The result stride is expected to be a multiple of the base stride. Abort // if that is not the case. if (resultStride.value() < baseStride || resultStride.value() % baseStride != 0) return failure(); staticStrides[resultStride.index()] = resultStride.value() / baseStride; } return success(); } namespace { /// Pattern to rewrite a subview op with constant size arguments. class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> { public: using OpRewritePattern<SubViewOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { MemRefType subViewType = subViewOp.getType(); // Follow all or nothing approach for shapes for now. If all the operands // for sizes are constants then fold it into the type of the result memref. 
if (subViewType.hasStaticShape() || llvm::any_of(subViewOp.sizes(), [](Value operand) { return !matchPattern(operand, m_ConstantIndex()); })) { return matchFailure(); } SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes()); for (auto size : llvm::enumerate(subViewOp.sizes())) { auto defOp = size.value().getDefiningOp(); assert(defOp); staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue(); } MemRefType newMemRefType = MemRefType::get( staticShape, subViewType.getElementType(), subViewType.getAffineMaps(), subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create<SubViewOp>( subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), ArrayRef<Value>(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp<MemRefCastOp>( subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType()); return matchSuccess(); } }; // Pattern to rewrite a subview op with constant stride arguments. class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> { public: using OpRewritePattern<SubViewOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { if (subViewOp.getNumStrides() == 0) { return matchFailure(); } // Follow all or nothing approach for strides for now. If all the operands // for strides are constants then fold it into the strides of the result // memref. 
int64_t baseOffset, resultOffset; SmallVector<int64_t, 4> baseStrides, resultStrides; MemRefType subViewType = subViewOp.getType(); if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, baseOffset)) || failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || llvm::any_of(subViewOp.strides(), [](Value stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides()); for (auto stride : llvm::enumerate(subViewOp.strides())) { auto defOp = stride.value().getDefiningOp(); assert(defOp); assert(baseStrides[stride.index()] > 0); staticStrides[stride.index()] = cast<ConstantIndexOp>(defOp).getValue() * baseStrides[stride.index()]; } AffineMap layoutMap = makeStridedLinearLayoutMap( staticStrides, resultOffset, rewriter.getContext()); MemRefType newMemRefType = MemRefType::get(subViewType.getShape(), subViewType.getElementType(), layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create<SubViewOp>( subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), subViewOp.sizes(), ArrayRef<Value>(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp<MemRefCastOp>( subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType()); return matchSuccess(); } }; // Pattern to rewrite a subview op with constant offset arguments. class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> { public: using OpRewritePattern<SubViewOp>::OpRewritePattern; PatternMatchResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { if (subViewOp.getNumOffsets() == 0) { return matchFailure(); } // Follow all or nothing approach for offsets for now. If all the operands // for offsets are constants then fold it into the offset of the result // memref. 
int64_t baseOffset, resultOffset; SmallVector<int64_t, 4> baseStrides, resultStrides; MemRefType subViewType = subViewOp.getType(); if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, baseOffset)) || failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || baseOffset == MemRefType::getDynamicStrideOrOffset() || llvm::any_of(subViewOp.offsets(), [](Value stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } auto staticOffset = baseOffset; for (auto offset : llvm::enumerate(subViewOp.offsets())) { auto defOp = offset.value().getDefiningOp(); assert(defOp); assert(baseStrides[offset.index()] > 0); staticOffset += cast<ConstantIndexOp>(defOp).getValue() * baseStrides[offset.index()]; } AffineMap layoutMap = makeStridedLinearLayoutMap( resultStrides, staticOffset, rewriter.getContext()); MemRefType newMemRefType = MemRefType::get(subViewType.getShape(), subViewType.getElementType(), layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create<SubViewOp>( subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(), subViewOp.sizes(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. 
rewriter.replaceOpWithNewOp<MemRefCastOp>( subViewOp.offsets(), subViewOp, newSubViewOp, subViewOp.getType()); return matchSuccess(); } }; } // end anonymous namespace void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder, SubViewOpOffsetFolder>(context); } //===----------------------------------------------------------------------===// // ZeroExtendIOp //===----------------------------------------------------------------------===// static LogicalResult verify(ZeroExtendIOp op) { auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); if (srcType.isa<IndexType>()) return op.emitError() << srcType << " is not a valid operand type"; if (dstType.isa<IndexType>()) return op.emitError() << dstType << " is not a valid result type"; if (srcType.cast<IntegerType>().getWidth() >= dstType.cast<IntegerType>().getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; return success(); } //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// bool FPExtOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast<FloatType>()) if (auto fb = b.dyn_cast<FloatType>()) return fa.getWidth() < fb.getWidth(); return false; } //===----------------------------------------------------------------------===// // FPTruncOp //===----------------------------------------------------------------------===// bool FPTruncOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast<FloatType>()) if (auto fb = b.dyn_cast<FloatType>()) return fa.getWidth() > fb.getWidth(); return false; } //===----------------------------------------------------------------------===// // TableGen'd op method definitions 
//===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/StandardOps/Ops.cpp.inc"