//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/SPIRVOps.h"

#include "mlir/Analysis/CallInterfaces.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/bit.h"

using namespace mlir;

// String constants for op attribute names. Several of them are referenced by
// the custom parsers, printers, and verifiers further down in this file.
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kInitializerAttrName[] = "initializer";
static constexpr const char kInterfaceAttrName[] = "interface";
static constexpr const char kMemoryScopeAttrName[] = "memory_scope";
static constexpr const char kSemanticsAttrName[] = "semantics";
static constexpr const char kSpecConstAttrName[] = "spec_const";
static constexpr const char kSpecIdAttrName[] = "spec_id";
static constexpr const char kTypeAttrName[] = "type";
static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
static constexpr const char kVariableAttrName[] = "variable";

//===----------------------------------------------------------------------===//
// Common utility functions
//===----------------------------------------------------------------------===//

/// Reads the integer value carried by `op` into `indexValue`. Succeeds only
/// when `op` is a spirv::ConstantOp whose value is an IntegerAttr; otherwise
/// fails and leaves `indexValue` untouched.
static LogicalResult extractValueFromConstOp(Operation *op,
                                             int32_t &indexValue) {
  if (auto constantOp = dyn_cast<spirv::ConstantOp>(op)) {
    auto constantValue = constantOp.value();
    if (auto intAttr = constantValue.dyn_cast<IntegerAttr>()) {
      indexValue = intAttr.getInt();
      return success();
    }
  }
  return failure();
}

/// Builds a string array attribute by applying `stringifyFn` to every entry of
/// `enumValues`. Returns a null attribute when the list is empty.
template <typename Ty>
static ArrayAttr
getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
                           function_ref<StringRef(Ty)> stringifyFn) {
  if (enumValues.empty())
    return nullptr;

  SmallVector<StringRef, 1> names;
  names.reserve(enumValues.size());
  for (Ty enumValue : enumValues)
    names.push_back(stringifyFn(enumValue));
  return builder.getStrArrayAttr(names);
}

/// Parses an enum case written as a string attribute and stores the
/// symbolized case in `value`. `attrName` is only used in the diagnostics and
/// for the scratch attribute list; nothing is recorded on an operation state.
template <typename EnumClass>
static ParseResult
parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
                   StringRef attrName = spirv::attributeName<EnumClass>()) {
  auto startLoc = parser.getCurrentLocation();
  Attribute parsedAttr;
  SmallVector<NamedAttribute, 1> scratchAttrs;
  if (parser.parseAttribute(parsedAttr, parser.getBuilder().getNoneType(),
                            attrName, scratchAttrs))
    return failure();

  // The enum case must be spelled as a string literal.
  auto strAttr = parsedAttr.dyn_cast<StringAttr>();
  if (!strAttr)
    return parser.emitError(startLoc, "expected ")
           << attrName << " attribute specified as string";

  auto symbolized = spirv::symbolizeEnum<EnumClass>()(strAttr.getValue());
  if (!symbolized)
    return parser.emitError(startLoc, "invalid ")
           << attrName << " attribute specification: " << parsedAttr;

  value = symbolized.getValue();
  return success();
}

/// Parses an enum case written as a string attribute, stores it in `value`,
/// and records it on `state` under `attrName` as a 32-bit integer attribute.
/// Note that the string-parsing overload is invoked with its default attribute
/// name, so `attrName` only selects the name the attribute is stored under.
template <typename EnumClass>
static ParseResult
parseEnumAttribute(EnumClass &value, OpAsmParser &parser, OperationState &state,
                   StringRef attrName = spirv::attributeName<EnumClass>()) {
  if (parseEnumAttribute(value, parser))
    return failure();

  auto encodedValue =
      parser.getBuilder().getI32IntegerAttr(llvm::bit_cast<int32_t>(value));
  state.addAttribute(attrName, encodedValue);
  return success();
}

/// Parses the optional `[ "<memory-access>" (, <alignment>)? ]` suffix and
/// records the parsed attributes on `state`.
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
                                               OperationState &state) {
  // The attribute list is optional and starts with '['. Without it there is
  // nothing to do.
  if (parser.parseOptionalLSquare())
    return success();

  spirv::MemoryAccess access;
  if (parseEnumAttribute(access, parser, state))
    return failure();

  // An `Aligned` access must be followed by a comma and the integer alignment.
  const bool isAligned =
      spirv::bitEnumContains(access, spirv::MemoryAccess::Aligned);
  if (isAligned) {
    Attribute alignment;
    Type i32Type = parser.getBuilder().getIntegerType(32);
    if (parser.parseComma() ||
        parser.parseAttribute(alignment, i32Type, kAlignmentAttrName,
                              state.attributes))
      return failure();
  }
  return parser.parseRSquare();
}

/// Prints the optional `["<memory-access>"(, <alignment>)?]` suffix of a
/// load/store op. The names of every attribute handled here, plus the storage
/// class attribute name, are appended to `elidedAttrs`.
template <typename LoadStoreOpTy>
static void
printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
                           SmallVectorImpl<StringRef> &elidedAttrs) {
  auto memAccess = loadStoreOp.memory_access();
  if (memAccess) {
    elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";

    // The alignment is only printed inside the memory access brackets.
    auto alignment = loadStoreOp.alignment();
    if (alignment) {
      elidedAttrs.push_back(kAlignmentAttrName);
      printer << ", " << alignment;
    }
    printer << "]";
  }
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
}

/// Verifies the bit widths of a cast op's first operand and first result.
/// With `requireSameBitWidth` set the (element) bit widths must match;
/// otherwise they must differ.
static LogicalResult verifyCastOp(Operation *op,
                                  bool requireSameBitWidth = true) {
  Type srcType = op->getOperand(0).getType();
  Type dstType = op->getResult(0).getType();

  // ODS checks that result type and operand type have the same shape, so a
  // vector operand implies a vector result; compare their element types.
  if (auto srcVecType = srcType.dyn_cast<VectorType>()) {
    srcType = srcVecType.getElementType();
    dstType = dstType.cast<VectorType>().getElementType();
  }

  const bool sameWidth =
      srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();

  if (requireSameBitWidth && !sameWidth)
    return op->emitOpError(
               "expected the same bit widths for operand type and result "
               "type, but provided ")
           << srcType << " and " << dstType;

  if (!requireSameBitWidth && sameWidth)
    return op->emitOpError(
               "expected the different bit widths for operand type and result "
               "type, but provided ")
           << srcType << " and " << dstType;

  return success();
}

/// Verifies that the alignment attribute is present exactly when the memory
/// access attribute contains `Aligned`. ODS already checks the attribute
/// values themselves.
template <typename LoadStoreOpTy>
static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
  auto *op = loadStoreOp.getOperation();
  const bool hasAlignment = static_cast<bool>(op->getAttr(kAlignmentAttrName));

  auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
  if (!memAccessAttr) {
    // Without a memory access attribute no alignment may be given either.
    if (hasAlignment)
      return loadStoreOp.emitOpError(
          "invalid alignment specification without aligned memory access "
          "specification");
    return success();
  }

  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
  if (!memAccess)
    return loadStoreOp.emitOpError("invalid memory access specifier: ")
           << memAccessVal;

  const bool isAligned =
      spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned);
  if (isAligned && !hasAlignment)
    return loadStoreOp.emitOpError("missing alignment value");
  if (!isAligned && hasAlignment)
    return loadStoreOp.emitOpError(
        "invalid alignment specification with non-aligned memory access "
        "specification");
  return success();
}

/// Verifies that at most one of the mutually exclusive ordering bits is set in
/// the op's memory semantics.
template <typename BarrierOp>
static LogicalResult verifyMemorySemantics(BarrierOp op) {
  // According to the SPIR-V specification:
  // "Despite being a mask and allowing multiple bits to be combined, it is
  // invalid for more than one of these four bits to be set: Acquire, Release,
  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
  // Release semantics is done by setting the AcquireRelease bit, not by setting
  // two bits."
  auto exclusiveBits = spirv::MemorySemantics::Acquire |
                       spirv::MemorySemantics::Release |
                       spirv::MemorySemantics::AcquireRelease |
                       spirv::MemorySemantics::SequentiallyConsistent;
  auto selectedBits = op.memory_semantics() & exclusiveBits;

  if (llvm::countPopulation(static_cast<uint32_t>(selectedBits)) > 1)
    return op.emitError("expected at most one of these four memory constraints "
                        "to be set: `Acquire`, `Release`,"
                        "`AcquireRelease` or `SequentiallyConsistent`");
  return success();
}

/// Verifies that the type of `val` equals the pointee type of `ptr`. ODS
/// already guarantees that `ptr` has spirv::PointerType.
template <typename LoadStoreOpTy>
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
                                                   Value val) {
  // TODO(ravishankarm): Check that the value type satisfies restrictions of
  // SPIR-V OpLoad/OpStore operations
  Type pointeeType = ptr.getType().cast<spirv::PointerType>().getPointeeType();
  if (val.getType() == pointeeType)
    return success();
  return op.emitOpError("mismatch in result type and pointer type");
}

/// Parses the optional decorations of a variable: either
/// `bind(<descriptor-set>, <binding>)` or `<built-in-keyword>(<string>)`,
/// followed by an optional attribute dictionary. Everything parsed is added to
/// `state.attributes`.
static ParseResult parseVariableDecorations(OpAsmParser &parser,
                                            OperationState &state) {
  // Attribute names are the snake_case spelling of the decoration names.
  auto snakeCaseName = [](spirv::Decoration decoration) {
    return convertToSnakeCase(stringifyDecoration(decoration));
  };
  auto builtInName = snakeCaseName(spirv::Decoration::BuiltIn);

  if (succeeded(parser.parseOptionalKeyword("bind"))) {
    // Descriptor binding: two 32-bit integers in parentheses.
    Attribute set, binding;
    auto descriptorSetName = snakeCaseName(spirv::Decoration::DescriptorSet);
    auto bindingName = snakeCaseName(spirv::Decoration::Binding);
    Type i32Type = parser.getBuilder().getIntegerType(32);
    if (parser.parseLParen() ||
        parser.parseAttribute(set, i32Type, descriptorSetName,
                              state.attributes) ||
        parser.parseComma() ||
        parser.parseAttribute(binding, i32Type, bindingName,
                              state.attributes) ||
        parser.parseRParen())
      return failure();
  } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
    // Built-in: a string attribute in parentheses.
    StringAttr builtIn;
    if (parser.parseLParen() ||
        parser.parseAttribute(builtIn, builtInName, state.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Whatever remains is a plain attribute dictionary.
  return parser.parseOptionalAttrDict(state.attributes);
}

/// Prints the decorations of a variable op: `bind(<set>, <binding>)` when both
/// integer attributes are present, the built-in string attribute when
/// present, and finally the remaining attributes minus those in `elidedAttrs`.
///
/// NOTE(review): the names pushed into `elidedAttrs` refer to locals of this
/// function; if convertToSnakeCase returns an owning string they do not
/// outlive this call. Confirm callers do not read `elidedAttrs` afterwards.
static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
                                     SmallVectorImpl<StringRef> &elidedAttrs) {
  auto snakeCaseName = [](spirv::Decoration decoration) {
    return convertToSnakeCase(stringifyDecoration(decoration));
  };

  // Optional descriptor binding; printed only when both halves exist.
  auto descriptorSetName = snakeCaseName(spirv::Decoration::DescriptorSet);
  auto bindingName = snakeCaseName(spirv::Decoration::Binding);
  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
  if (descriptorSet && binding) {
    elidedAttrs.push_back(descriptorSetName);
    elidedAttrs.push_back(bindingName);
    printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
            << ")";
  }

  // Optional built-in decoration.
  auto builtInName = snakeCaseName(spirv::Decoration::BuiltIn);
  auto builtIn = op->getAttrOfType<StringAttr>(builtInName);
  if (builtIn) {
    printer << " " << builtInName << "(\"" << builtIn.getValue() << "\")";
    elidedAttrs.push_back(builtInName);
  }

  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

// Extracts an element from the given `composite` by following the given
// `indices`. Returns a null Attribute if error happens: a null `composite`,
// an out-of-range index into an array attribute, or a composite that is
// neither an ElementsAttr nor an ArrayAttr while indices remain.
static Attribute extractCompositeElement(Attribute composite,
                                         ArrayRef<unsigned> indices) {
  // Check that given composite is a constant.
  if (!composite)
    return {};
  // Return composite itself if we reach the end of the index chain.
  if (indices.empty())
    return composite;

  if (auto vector = composite.dyn_cast<ElementsAttr>()) {
    assert(indices.size() == 1 && "must have exactly one index for a vector");
    return vector.getValue({indices[0]});
  }

  if (auto array = composite.dyn_cast<ArrayAttr>()) {
    // `indices` is known to be non-empty here. Guard the element access so an
    // out-of-range index yields the documented null result instead of
    // indexing past the end of the array.
    ArrayRef<Attribute> elements = array.getValue();
    if (indices[0] >= elements.size())
      return {};
    return extractCompositeElement(elements[indices[0]], indices.drop_front());
  }

  return {};
}

// Returns the bit width of `type`. Handles pointers, scalar int/float types,
// and vectors of int/float elements; any other type is unreachable.
static unsigned getBitWidth(Type type) {
  // Just return 64 bits for pointer types for now.
  // TODO: Make sure no caller relies on the actual pointer width value.
  if (type.isa<spirv::PointerType>())
    return 64;

  if (type.isIntOrFloat())
    return type.getIntOrFloatBitWidth();

  if (auto vectorType = type.dyn_cast<VectorType>()) {
    // A vector occupies element-count times the element width.
    Type elementType = vectorType.getElementType();
    assert(elementType.isIntOrFloat());
    return vectorType.getNumElements() * elementType.getIntOrFloatBitWidth();
  }
  llvm_unreachable("unhandled bit width computation for type");
}

/// Descends into `type` one level per entry of `indices` and returns the
/// element type that is reached. On failure (no indices, a non-composite type
/// along the way, or an out-of-bounds index) reports through `emitErrorFn` and
/// returns a null type.
static Type
getElementType(Type type, ArrayRef<int32_t> indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  if (indices.empty()) {
    emitErrorFn("expected at least one index for spv.CompositeExtract");
    return nullptr;
  }

  for (int32_t index : indices) {
    auto compositeType = type.dyn_cast<spirv::CompositeType>();
    if (!compositeType) {
      emitErrorFn("cannot extract from non-composite type ")
          << type << " with index " << index;
      return nullptr;
    }

    const bool inBounds =
        index >= 0 &&
        static_cast<uint64_t>(index) < compositeType.getNumElements();
    if (!inBounds) {
      emitErrorFn("index ") << index << " out of bounds for " << type;
      return nullptr;
    }
    type = compositeType.getElementType(index);
  }
  return type;
}

/// Same as the overload taking integer indices, but reads them from
/// `indices`, which must be a non-empty ArrayAttr of IntegerAttr. Reports
/// through `emitErrorFn` and returns a null type otherwise.
static Type
getElementType(Type type, Attribute indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  auto indexArray = indices.dyn_cast<ArrayAttr>();
  if (!indexArray) {
    emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
    return nullptr;
  }
  if (indexArray.size() == 0) {
    emitErrorFn("expected at least one index for spv.CompositeExtract");
    return nullptr;
  }

  // Unpack the attributes into plain integers, rejecting non-integers.
  SmallVector<int32_t, 2> indexVals;
  for (auto element : indexArray) {
    if (auto intElement = element.dyn_cast<IntegerAttr>()) {
      indexVals.push_back(intElement.getInt());
      continue;
    }
    emitErrorFn("expected an 32-bit integer for index, but found '")
        << element << "'";
    return nullptr;
  }
  return getElementType(type, indexVals, emitErrorFn);
}

/// Convenience overload that reports errors at `loc`.
static Type getElementType(Type type, Attribute indices, Location loc) {
  return getElementType(type, indices,
                        [&](StringRef message) -> InFlightDiagnostic {
                          return ::mlir::emitError(loc, message);
                        });
}

/// Convenience overload that reports errors through `parser` at `loc`.
static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
                           llvm::SMLoc loc) {
  return getElementType(type, indices,
                        [&](StringRef message) -> InFlightDiagnostic {
                          return parser.emitError(loc, message);
                        });
}

/// Returns true if `block` holds exactly one operation and that operation is
/// a `spv._merge` op.
static inline bool isMergeBlock(Block &block) {
  if (block.empty())
    return false;
  // More than one operation disqualifies the block.
  if (std::next(block.begin()) != block.end())
    return false;
  return isa<spirv::MergeOp>(block.front());
}

//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//

// The included canonicalization definitions are wrapped in an anonymous
// namespace so that they stay local to this translation unit.
namespace {
#include "SPIRVCanonicalization.inc"
}

//===----------------------------------------------------------------------===//
// Common parsers and printers
//===----------------------------------------------------------------------===//

/// Parses `<base>, <offset>, <count> : <base-type>, <offset-type>,
/// <count-type>`. The result type is the base type.
static ParseResult parseBitFieldExtractOp(OpAsmParser &parser,
                                          OperationState &state) {
  auto operandsLoc = parser.getCurrentLocation();
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type baseType, offsetType, countType;

  if (parser.parseOperandList(operands, 3) || parser.parseColon() ||
      parser.parseType(baseType) || parser.parseComma() ||
      parser.parseType(offsetType) || parser.parseComma() ||
      parser.parseType(countType))
    return failure();

  if (parser.resolveOperands(operands, {baseType, offsetType, countType},
                             operandsLoc, state.operands))
    return failure();

  state.addTypes(baseType);
  return success();
}

/// Prints the op name, its operands, and the types of all operands.
static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) {
  printer << op->getName() << ' ' << op->getOperands();
  printer << " : " << op->getOperandTypes();
}

/// Verifies that the first operand and the result have the same type.
static LogicalResult verifyBitFieldExtractOp(Operation *op) {
  Type baseType = op->getOperand(0).getType();
  Type resultType = op->getResult(0).getType();
  if (baseType == resultType)
    return success();

  return op->emitError("expected the same type for the first operand and "
                       "result, but provided ")
         << baseType << " and " << resultType;
}

// Parses an atomic update op: two string enum attributes (the memory scope,
// then the memory semantics), the operand list, and `: <pointer type>`. If the
// update op does not take a value (like AtomicIIncrement) `hasValue` must be
// false. The value operand, when present, and the result both have the
// pointee type of the pointer.
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
                                       OperationState &state, bool hasValue) {
  spirv::Scope scope;
  spirv::MemorySemantics semantics;
  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
  Type type;
  llvm::SMLoc loc;
  if (parseEnumAttribute(scope, parser, state, kMemoryScopeAttrName) ||
      parseEnumAttribute(semantics, parser, state, kSemanticsAttrName) ||
      parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
      parser.getCurrentLocation(&loc) || parser.parseColonType(type))
    return failure();

  // Only the pointer type is spelled out in the assembly.
  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return parser.emitError(loc, "expected pointer type");

  SmallVector<Type, 2> operandTypes;
  operandTypes.push_back(ptrType);
  if (hasValue)
    operandTypes.push_back(ptrType.getPointeeType());
  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
                             state.operands))
    return failure();
  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
}

// Prints an atomic update op: the op name, the quoted memory scope, the quoted
// memory semantics, the operands, and the type of the first operand.
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
  printer << op->getName();

  auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
  auto scope = static_cast<spirv::Scope>(scopeAttr.getInt());
  printer << " \"" << spirv::stringifyScope(scope) << "\"";

  auto semanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
  auto semantics =
      static_cast<spirv::MemorySemantics>(semanticsAttr.getInt());
  printer << " \"" << spirv::stringifyMemorySemantics(semantics) << "\" "
          << op->getOperands() << " : " << op->getOperand(0).getType();
}

// Verifies an atomic update op: the pointer operand must point to an integer,
// and a second operand, if present, must have that pointee type.
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
  auto pointeeType = op->getOperand(0)
                         .getType()
                         .cast<spirv::PointerType>()
                         .getPointeeType();
  if (!pointeeType.isa<IntegerType>())
    return op->emitOpError(
               "pointer operand must point to an integer value, found ")
           << pointeeType;

  // Ops without a value operand have nothing further to check.
  if (op->getNumOperands() <= 1)
    return success();

  auto valueType = op->getOperand(1).getType();
  if (valueType == pointeeType)
    return success();

  return op->emitOpError("expected value to have the same type as the "
                         "pointer operand's pointee type ")
         << pointeeType << ", but found " << valueType;
}

// Parses an op that has no inputs and no outputs; only an optional attribute
// dictionary is accepted.
static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
  return parser.parseOptionalAttrDict(state.attributes);
}

// Prints an op that has no inputs and no outputs: its name followed by its
// attribute dictionary, if any.
static void printNoIOOp(Operation *op, OpAsmPrinter &printer) {
  auto attrs = op->getAttrs();
  printer << op->getName();
  printer.printOptionalAttrDict(attrs);
}

/// Parses `<operand> : <type>`, where `<type>` is both the operand type and
/// the result type.
static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
  OpAsmParser::OperandType operand;
  Type type;
  if (parser.parseOperand(operand) || parser.parseColonType(type))
    return failure();
  if (parser.resolveOperands(operand, type, state.operands))
    return failure();

  state.addTypes(type);
  return success();
}

/// Prints the op name, its first operand, and that operand's type.
static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
  Value operand = unaryOp->getOperand(0);
  printer << unaryOp->getName() << ' ' << operand << " : "
          << operand.getType();
}

/// Result of a logical op must be a scalar or vector of boolean type: returns
/// i1 for a non-vector operand type, and a vector of i1 with the same number
/// of elements for a vector operand type.
static Type getUnaryOpResultType(Builder &builder, Type operandType) {
  Type boolType = builder.getIntegerType(1);
  auto vecType = operandType.dyn_cast<VectorType>();
  if (!vecType)
    return boolType;
  return VectorType::get(vecType.getNumElements(), boolType);
}

/// Parses `<operand> : <operand-type>`. The result type is the boolean
/// counterpart of the operand type.
static ParseResult parseLogicalUnaryOp(OpAsmParser &parser,
                                       OperationState &state) {
  OpAsmParser::OperandType operand;
  Type operandType;
  if (parser.parseOperand(operand) || parser.parseColonType(operandType))
    return failure();
  if (parser.resolveOperand(operand, operandType, state.operands))
    return failure();

  state.addTypes(getUnaryOpResultType(parser.getBuilder(), operandType));
  return success();
}

/// Parses `<lhs>, <rhs> : <operand-type>`; both operands share the type. The
/// result type is the boolean counterpart of the operand type.
static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
                                        OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operands;
  Type operandType;
  if (parser.parseOperandList(operands, 2) ||
      parser.parseColonType(operandType))
    return failure();
  if (parser.resolveOperands(operands, operandType, result.operands))
    return failure();

  result.addTypes(getUnaryOpResultType(parser.getBuilder(), operandType));
  return success();
}

/// Prints the op name, all operands, and the type of the first operand.
static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
  printer << logicalOp->getName() << ' ' << logicalOp->getOperands();
  printer << " : " << logicalOp->getOperand(0).getType();
}

/// Parses `<base>, <shift> : <base-type>, <shift-type>`. The result type is
/// the base type.
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
  auto operandsLoc = parser.getCurrentLocation();
  SmallVector<OpAsmParser::OperandType, 2> operands;
  Type baseType, shiftType;

  if (parser.parseOperandList(operands, 2) || parser.parseColon() ||
      parser.parseType(baseType) || parser.parseComma() ||
      parser.parseType(shiftType))
    return failure();

  if (parser.resolveOperands(operands, {baseType, shiftType}, operandsLoc,
                             state.operands))
    return failure();

  state.addTypes(baseType);
  return success();
}

/// Prints `<name> <base>, <shift> : <base-type>, <shift-type>`.
static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
  Value base = op->getOperand(0);
  Value shift = op->getOperand(1);
  printer << op->getName() << ' ' << base << ", " << shift;
  printer << " : " << base.getType() << ", " << shift.getType();
}

/// Verifies that the first operand and the result have the same type.
static LogicalResult verifyShiftOp(Operation *op) {
  Type baseType = op->getOperand(0).getType();
  Type resultType = op->getResult(0).getType();
  if (baseType == resultType)
    return success();

  return op->emitError("expected the same type for the first operand and "
                       "result, but provided ")
         << baseType << " and " << resultType;
}

//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//

/// Computes the result pointer type of an access chain that starts at a value
/// of `type` and applies `indices` in order. The result points to the selected
/// element type and keeps the storage class of `type`. On failure an error is
/// emitted at `baseLoc` and a null type is returned.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
  if (indices.empty()) {
    emitError(baseLoc, "'spv.AccessChain' op expected at least "
                       "one index ");
    return nullptr;
  }

  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType) {
    emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
                       "to composite type, but provided ")
        << type;
    return nullptr;
  }

  auto resultType = ptrType.getPointeeType();
  auto resultStorageClass = ptrType.getStorageClass();
  // Element index used for the current level. It is declared outside the loop,
  // so the non-composite diagnostic below reports the value left over from the
  // previous iteration (0 on the first one).
  int32_t index = 0;

  for (auto indexSSA : indices) {
    // Every index must step into a composite type.
    auto cType = resultType.dyn_cast<spirv::CompositeType>();
    if (!cType) {
      emitError(baseLoc,
                "'spv.AccessChain' op cannot extract from non-composite type ")
          << resultType << " with index " << index;
      return nullptr;
    }
    // For composites other than structs the index value is not inspected; the
    // element type is looked up with index 0.
    index = 0;
    if (resultType.isa<spirv::StructType>()) {
      // Struct members are selected by value, so the index has to be defined
      // by an operation from which a constant integer can be extracted.
      Operation *op = indexSSA.getDefiningOp();
      if (!op) {
        emitError(baseLoc, "'spv.AccessChain' op index must be an "
                           "integer spv.constant to access "
                           "element of spv.struct");
        return nullptr;
      }

      // TODO(denis0x0D): this should be relaxed to allow
      // integer literals of other bitwidths.
      if (failed(extractValueFromConstOp(op, index))) {
        emitError(baseLoc,
                  "'spv.AccessChain' index must be an integer spv.constant to "
                  "access element of spv.struct, but provided ")
            << op->getName();
        return nullptr;
      }
      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
        emitError(baseLoc, "'spv.AccessChain' op index ")
            << index << " out of bounds for " << resultType;
        return nullptr;
      }
    }
    resultType = cType.getElementType(index);
  }
  return spirv::PointerType::get(resultType, resultStorageClass);
}

/// Builds an access chain whose result type is deduced from the type of
/// `basePtr` and from `indices`. Asserts that the deduction succeeds.
void spirv::AccessChainOp::build(Builder *builder, OperationState &state,
                                 Value basePtr, ValueRange indices) {
  auto resultType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultType, basePtr, indices);
}

/// Parses `<base-ptr>[<indices>] : <base-ptr-type>`. All indices are resolved
/// as 32-bit integers and the result type is deduced from the base pointer
/// type and the indices.
static ParseResult parseAccessChainOp(OpAsmParser &parser,
                                      OperationState &state) {
  OpAsmParser::OperandType basePtr;
  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
  Type basePtrType;
  // TODO(denis0x0D): according to the spec an index may be of any integer
  // type; figure out how to use resolveOperand with a range of types and do
  // not fail on first attempt.
  Type indexType = parser.getBuilder().getIntegerType(32);

  if (parser.parseOperand(basePtr) ||
      parser.parseOperandList(indexOperands, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(basePtrType) ||
      parser.resolveOperand(basePtr, basePtrType, state.operands) ||
      parser.resolveOperands(indexOperands, indexType, state.operands))
    return failure();

  // The first resolved operand is the base pointer; the rest are the indices.
  auto resultType = getElementPtrType(
      basePtrType, llvm::makeArrayRef(state.operands).drop_front(),
      state.location);
  if (!resultType)
    return failure();

  state.addTypes(resultType);
  return success();
}

/// Prints `<name> <base-ptr>[<indices>] : <base-ptr-type>`.
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
  auto basePtr = op.base_ptr();
  printer << spirv::AccessChainOp::getOperationName() << ' ' << basePtr << '['
          << op.indices() << "] : " << basePtr.getType();
}

/// Verifies that the result type of the access chain is the pointer type
/// deduced from the base pointer type and the indices.
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
  SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
                                accessChainOp.indices().end());
  // Errors in the base pointer type or the indices are reported by the
  // deduction itself.
  auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
                                      indices, accessChainOp.getLoc());
  if (!resultType) {
    return failure();
  }

  auto providedResultType =
      accessChainOp.getType().dyn_cast<spirv::PointerType>();
  if (!providedResultType) {
    // Report the actual result type: `providedResultType` is null on this
    // path and would not tell the user what was provided.
    return accessChainOp.emitOpError(
               "result type must be a pointer, but provided ")
           << accessChainOp.getType();
  }

  if (resultType != providedResultType) {
    return accessChainOp.emitOpError("invalid result type: expected ")
           << resultType << ", but provided " << providedResultType;
  }

  return success();
}

namespace {

/// Folds an `spirv::AccessChainOp` whose base pointer is produced by another
/// `spirv::AccessChainOp` into a single op that starts at the parent's base
/// pointer and applies the parent's indices followed by its own.
struct CombineChainedAccessChain
    : public OpRewritePattern<spirv::AccessChainOp> {
  using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
                                     PatternRewriter &rewriter) const override {
    // Only applies when the base pointer itself comes from an access chain.
    Operation *baseDef = accessChainOp.base_ptr().getDefiningOp();
    auto parentOp = dyn_cast_or_null<spirv::AccessChainOp>(baseDef);
    if (!parentOp)
      return matchFailure();

    // Parent indices first, then the indices of the op being rewritten.
    SmallVector<Value, 4> combinedIndices(parentOp.indices());
    auto childIndices = accessChainOp.indices();
    combinedIndices.append(childIndices.begin(), childIndices.end());

    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
        accessChainOp, parentOp.base_ptr(), combinedIndices);
    return matchSuccess();
  }
};
} // end anonymous namespace

/// Adds the CombineChainedAccessChain pattern to `results`.
void spirv::AccessChainOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<CombineChainedAccessChain>(context);
}

//===----------------------------------------------------------------------===//
// spv._address_of
//===----------------------------------------------------------------------===//

/// Builds an address-of op that refers to `var` by symbol and has the type of
/// `var` as its result type.
void spirv::AddressOfOp::build(Builder *builder, OperationState &state,
                               spirv::GlobalVariableOp var) {
  auto varRef = builder->getSymbolRefAttr(var);
  build(builder, state, var.type(), varRef);
}

/// Parses `<symbol-ref> : <pointer-type>`. The symbol reference is stored as
/// the variable attribute and the pointer type becomes the result type.
static ParseResult parseAddressOfOp(OpAsmParser &parser,
                                    OperationState &state) {
  FlatSymbolRefAttr varRefAttr;
  Type type;
  if (parser.parseAttribute(varRefAttr, Type(), kVariableAttrName,
                            state.attributes) ||
      parser.parseColonType(type))
    return failure();

  if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
    state.addTypes(ptrType);
    return success();
  }
  return parser.emitError(parser.getCurrentLocation(),
                          "expected spv.ptr type");
}

/// Prints the op name, the referenced variable's symbol name, and the result
/// pointer type separated by " : ".
static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter &printer) {
  printer << spirv::AddressOfOp::getOperationName();

  // Print symbol name.
  printer << ' ';
  printer.printSymbolName(addressOfOp.variable());

  // Print the type.
  printer << " : " << addressOfOp.pointer().getType();
}

/// Verifies that the op lives inside an spv.module, that the referenced symbol
/// resolves to an spv.globalVariable in that module, and that the result type
/// equals the variable's declared type.
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
  auto moduleOp = addressOfOp.getParentOfType<spirv::ModuleOp>();
  // getParentOfType yields a null op when there is no enclosing spv.module;
  // report that instead of looking up a symbol on a null op.
  if (!moduleOp) {
    return addressOfOp.emitOpError(
        "must appear in a function inside 'spv.module'");
  }

  auto varOp =
      moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
  if (!varOp) {
    return addressOfOp.emitOpError("expected spv.globalVariable symbol");
  }
  if (addressOfOp.pointer().getType() != varOp.type()) {
    return addressOfOp.emitOpError(
        "result type mismatch with the referenced global variable's type");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// spv.AtomicCompareExchangeWeak
//===----------------------------------------------------------------------===//

/// Parses the memory scope, the equal and unequal memory semantics, three
/// operands, and a trailing `: <pointer-type>`. The operands are resolved as
/// (pointer, pointee, pointee) and the result type is the pointee type.
static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
                                                    OperationState &state) {
  // Enum attributes come first, in this fixed order.
  spirv::Scope memoryScope;
  if (parseEnumAttribute(memoryScope, parser, state, kMemoryScopeAttrName))
    return failure();

  spirv::MemorySemantics equalSemantics, unequalSemantics;
  if (parseEnumAttribute(equalSemantics, parser, state,
                         kEqualSemanticsAttrName) ||
      parseEnumAttribute(unequalSemantics, parser, state,
                         kUnequalSemanticsAttrName))
    return failure();

  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
  if (parser.parseOperandList(operandInfo, 3))
    return failure();

  // Location captured before the type so the error points at it.
  auto typeLoc = parser.getCurrentLocation();
  Type type;
  if (parser.parseColonType(type))
    return failure();

  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return parser.emitError(typeLoc, "expected pointer type");

  Type pointeeType = ptrType.getPointeeType();
  if (parser.resolveOperands(operandInfo, {ptrType, pointeeType, pointeeType},
                             parser.getNameLoc(), state.operands))
    return failure();

  return parser.addTypeToList(pointeeType, state.types);
}

/// Prints the op name, the three quoted enum attributes (memory scope, equal
/// semantics, unequal semantics), the operands, and the pointer operand type.
static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
                  OpAsmPrinter &printer) {
  printer << spirv::AtomicCompareExchangeWeakOp::getOperationName();

  // Each enum attribute is printed as a quoted string.
  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \"";
  printer << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \"";
  printer << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" ";

  printer << atomOp.getOperands();
  printer << " : " << atomOp.pointer().getType();
}

/// Verifies that the value operand, the comparator operand, and the pointee
/// type of the pointer operand all have the same type as the op result.
static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
  // According to the spec:
  // "The type of Value must be the same as Result Type. The type of the value
  // pointed to by Pointer must be the same as Result Type. This type must also
  // match the type of Comparator."
  Type resultType = atomOp.getType();

  Type valueType = atomOp.value().getType();
  if (resultType != valueType)
    return atomOp.emitOpError("value operand must have the same type as the op "
                              "result, but found ")
           << valueType << " vs " << resultType;

  Type comparatorType = atomOp.comparator().getType();
  if (resultType != comparatorType)
    return atomOp.emitOpError(
               "comparator operand must have the same type as the op "
               "result, but found ")
           << comparatorType << " vs " << resultType;

  auto pointerType = atomOp.pointer().getType().cast<spirv::PointerType>();
  Type pointeeType = pointerType.getPointeeType();
  if (resultType != pointeeType)
    return atomOp.emitOpError(
               "pointer operand's pointee type must have the same "
               "as the op result type, but found ")
           << pointeeType << " vs " << resultType;

  // TODO(antiagainst): Unequal cannot be set to Release or Acquire and Release.
  // In addition, Unequal cannot be set to a stronger memory-order than Equal.

  return success();
}

//===----------------------------------------------------------------------===//
// spv.BitcastOp
//===----------------------------------------------------------------------===//

/// Verifies that operand and result types differ, that they are either both
/// pointers or both non-pointers, and that their bitwidths match.
static LogicalResult verify(spirv::BitcastOp bitcastOp) {
  // TODO: The SPIR-V spec validation rules are different for different
  // versions.
  auto srcType = bitcastOp.operand().getType();
  auto dstType = bitcastOp.result().getType();
  if (srcType == dstType)
    return bitcastOp.emitError(
        "result type must be different from operand type");

  // Mixing pointer and non-pointer types is not handled.
  const bool srcIsPointer = srcType.isa<spirv::PointerType>();
  const bool dstIsPointer = dstType.isa<spirv::PointerType>();
  if (srcIsPointer && !dstIsPointer)
    return bitcastOp.emitError(
        "unhandled bit cast conversion from pointer type to non-pointer type");
  if (!srcIsPointer && dstIsPointer)
    return bitcastOp.emitError(
        "unhandled bit cast conversion from non-pointer type to pointer type");

  auto srcBitWidth = getBitWidth(srcType);
  auto dstBitWidth = getBitWidth(dstType);
  if (srcBitWidth == dstBitWidth)
    return success();

  return bitcastOp.emitOpError("mismatch in result type bitwidth ")
         << dstBitWidth << " and operand type bitwidth " << srcBitWidth;
}

/// Registers the canonicalization patterns for spv.Bitcast; currently this is
/// only the ConvertChainedBitcast pattern.
void spirv::BitcastOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertChainedBitcast>(context);
}

//===----------------------------------------------------------------------===//
// spv.BitFieldInsert
//===----------------------------------------------------------------------===//

/// Parses four operands followed by `: <base-type>, <offset-type>,
/// <count-type>`. The first two operands are resolved with the base type,
/// which is also the result type.
static ParseResult parseBitFieldInsertOp(OpAsmParser &parser,
                                         OperationState &state) {
  auto loc = parser.getCurrentLocation();

  SmallVector<OpAsmParser::OperandType, 4> operandInfo;
  if (parser.parseOperandList(operandInfo, 4) || parser.parseColon())
    return failure();

  // Three comma-separated types: base, offset, count.
  Type baseType, offsetType, countType;
  if (parser.parseType(baseType) || parser.parseComma() ||
      parser.parseType(offsetType) || parser.parseComma() ||
      parser.parseType(countType))
    return failure();

  if (parser.resolveOperands(operandInfo,
                             {baseType, baseType, offsetType, countType}, loc,
                             state.operands))
    return failure();

  state.addTypes(baseType);
  return success();
}

/// Prints the op name, all operands, and the base, offset, and count operand
/// types separated by commas.
static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
                  OpAsmPrinter &printer) {
  printer << spirv::BitFieldInsertOp::getOperationName() << ' ';
  printer << bitFieldInsertOp.getOperands() << " : ";

  // The insert operand type is not printed.
  printer << bitFieldInsertOp.base().getType() << ", ";
  printer << bitFieldInsertOp.offset().getType() << ", ";
  printer << bitFieldInsertOp.count().getType();
}

/// Verifies that the base operand, the insert operand, and the result all have
/// the same type.
static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) {
  auto baseType = bitFieldOp.base().getType();
  auto insertType = bitFieldOp.insert().getType();
  auto resultType = bitFieldOp.getResult().getType();

  if (baseType == insertType && baseType == resultType)
    return success();

  return bitFieldOp.emitError("expected the same type for the base operand, "
                              "insert operand, and "
                              "result, but provided ")
         << baseType << ", " << insertType << " and " << resultType;
}

//===----------------------------------------------------------------------===//
// spv.BranchOp
//===----------------------------------------------------------------------===//

/// Parses a single successor block together with its operand list and records
/// it on the operation state.
static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) {
  Block *target = nullptr;
  SmallVector<Value, 4> targetOperands;
  if (failed(parser.parseSuccessorAndUseList(target, targetOperands)))
    return failure();

  state.addSuccessor(target, targetOperands);
  return success();
}

/// Prints the op name followed by its only successor and that successor's
/// operands.
static void print(spirv::BranchOp branchOp, OpAsmPrinter &printer) {
  Operation *op = branchOp.getOperation();
  printer << spirv::BranchOp::getOperationName() << ' ';
  printer.printSuccessorAndUseList(op, /*index=*/0);
}

/// Verifies that the branch has exactly one successor.
static LogicalResult verify(spirv::BranchOp branchOp) {
  auto *op = branchOp.getOperation();
  // The diagnostic must be returned; otherwise the error is emitted but the
  // verifier still reports success.
  if (op->getNumSuccessors() != 1)
    return branchOp.emitOpError("must have exactly one successor");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//

/// Parses the condition operand (resolved as i1), an optional
/// `[<true-weight>, <false-weight>]` pair of i32 branch weights, and then the
/// true and false successors, each preceded by a comma.
static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
                                            OperationState &state) {
  auto &builder = parser.getBuilder();

  // Parse the condition.
  OpAsmParser::OperandType condInfo;
  Type boolTy = builder.getI1Type();
  if (parser.parseOperand(condInfo) ||
      parser.resolveOperand(condInfo, boolTy, state.operands))
    return failure();

  // Parse the optional branch weights.
  if (succeeded(parser.parseOptionalLSquare())) {
    IntegerAttr trueWeight, falseWeight;
    // Scratch list: the individual weights are only kept via the array
    // attribute added below.
    SmallVector<NamedAttribute, 2> weights;

    auto i32Type = builder.getIntegerType(32);
    if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
        parser.parseComma() ||
        parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
        parser.parseRSquare())
      return failure();

    state.addAttribute(kBranchWeightAttrName,
                       builder.getArrayAttr({trueWeight, falseWeight}));
  }

  // Parses `, <successor>(<operands>)` and appends it to the state.
  auto parseCommaAndSuccessor = [&]() -> ParseResult {
    Block *target = nullptr;
    SmallVector<Value, 4> targetOperands;
    if (parser.parseComma() ||
        parser.parseSuccessorAndUseList(target, targetOperands))
      return failure();
    state.addSuccessor(target, targetOperands);
    return success();
  };

  // Parse the true branch.
  if (parseCommaAndSuccessor())
    return failure();

  // Parse the false branch.
  if (parseCommaAndSuccessor())
    return failure();

  return success();
}

/// Prints the condition, the branch weights in square brackets when present,
/// and then the true and false successors with their operands.
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
  Operation *op = branchOp.getOperation();

  printer << spirv::BranchConditionalOp::getOperationName() << ' ';
  printer << branchOp.condition();

  if (auto weights = branchOp.branch_weights()) {
    // Weights are printed as bare integers without their type.
    auto printWeight = [&](Attribute weight) {
      printer << weight.cast<IntegerAttr>().getInt();
    };
    printer << " [";
    interleaveComma(weights->getValue(), printer, printWeight);
    printer << "]";
  }

  printer << ", ";
  printer.printSuccessorAndUseList(op, spirv::BranchConditionalOp::kTrueIndex);
  printer << ", ";
  printer.printSuccessorAndUseList(op,
                                   spirv::BranchConditionalOp::kFalseIndex);
}

/// Verifies that the op has exactly two successors and, when branch weights
/// are present, that there are exactly two of them and not both are zero.
static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
  if (branchOp.getOperation()->getNumSuccessors() != 2)
    return branchOp.emitOpError("must have exactly two successors");

  // Branch weights are optional; nothing more to check without them.
  auto weights = branchOp.branch_weights();
  if (!weights)
    return success();

  if (weights->getValue().size() != 2)
    return branchOp.emitOpError("must have exactly two branch weights");

  auto isZeroWeight = [](Attribute attr) {
    return attr.cast<IntegerAttr>().getValue().isNullValue();
  };
  if (llvm::all_of(*weights, isZeroWeight))
    return branchOp.emitOpError("branch weights cannot both be zero");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.CompositeConstruct
//===----------------------------------------------------------------------===//

/// Parses an operand list followed by `: <composite-type>`. The number of
/// operands must equal the number of elements of the composite type, and each
/// operand is resolved with the corresponding element type.
static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
                                             OperationState &state) {
  auto loc = parser.getCurrentLocation();

  SmallVector<OpAsmParser::OperandType, 4> operands;
  Type type;
  if (parser.parseOperandList(operands) || parser.parseColonType(type))
    return failure();

  auto cType = type.dyn_cast<spirv::CompositeType>();
  if (!cType)
    return parser.emitError(
               loc, "result type must be a composite type, but provided ")
           << type;

  if (operands.size() != cType.getNumElements())
    return parser.emitError(loc, "has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided " << operands.size();

  // TODO: Add support for constructing a vector type from the vector operands.
  // According to the spec: "for constructing a vector, the operands may
  // also be vectors with the same component type as the Result Type component
  // type".
  SmallVector<Type, 4> elementTypes;
  elementTypes.reserve(cType.getNumElements());
  for (uint32_t index = 0, e = cType.getNumElements(); index != e; ++index)
    elementTypes.push_back(cType.getElementType(index));

  state.addTypes(type);
  return parser.resolveOperands(operands, elementTypes, loc, state.operands);
}

/// Prints the op name, the constituent operands, and the result type.
static void print(spirv::CompositeConstructOp compositeConstructOp,
                  OpAsmPrinter &printer) {
  printer << spirv::CompositeConstructOp::getOperationName() << " ";
  printer << compositeConstructOp.constituents();
  printer << " : " << compositeConstructOp.getResult().getType();
}

/// Verifies that the number of constituents equals the number of elements of
/// the composite result type and that each constituent's type equals the
/// element type at the same position.
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
  auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
  SmallVector<Value, 4> constituents(compositeConstructOp.constituents());

  if (constituents.size() != cType.getNumElements())
    return compositeConstructOp.emitError(
               "has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided "
           << constituents.size();

  // Compare every constituent against the element type at its position.
  for (uint32_t index = 0, e = constituents.size(); index != e; ++index) {
    Type expectedType = cType.getElementType(index);
    Type actualType = constituents[index].getType();
    if (actualType == expectedType)
      continue;
    return compositeConstructOp.emitError(
               "operand type mismatch: expected operand type ")
           << expectedType << ", but provided " << actualType;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//

/// Convenience builder that computes the result type from the composite type
/// and the given indices. Leaves `state` untouched when getElementType returns
/// a null type.
void spirv::CompositeExtractOp::build(Builder *builder, OperationState &state,
                                      Value composite,
                                      ArrayRef<int32_t> indices) {
  auto indexAttr = builder->getI32ArrayAttr(indices);
  if (auto elementType =
          getElementType(composite.getType(), indexAttr, state.location))
    build(builder, state, elementType, composite, indexAttr);
}

/// Parses the composite operand, the indices attribute, and
/// `: <composite-type>`. The result type is derived from the composite type
/// and the indices via getElementType.
static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
                                           OperationState &state) {
  OpAsmParser::OperandType compositeInfo;
  if (parser.parseOperand(compositeInfo))
    return failure();

  // Record where the indices attribute starts; the location is handed to
  // getElementType below.
  llvm::SMLoc attrLocation;
  Attribute indicesAttr;
  if (parser.getCurrentLocation(&attrLocation) ||
      parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes))
    return failure();

  Type compositeType;
  if (parser.parseColonType(compositeType) ||
      parser.resolveOperand(compositeInfo, compositeType, state.operands))
    return failure();

  if (Type resultType =
          getElementType(compositeType, indicesAttr, parser, attrLocation)) {
    state.addTypes(resultType);
    return success();
  }
  return failure();
}

/// Prints the op name, the composite operand immediately followed by the
/// indices attribute, and the composite operand's type.
static void print(spirv::CompositeExtractOp compositeExtractOp,
                  OpAsmPrinter &printer) {
  printer << spirv::CompositeExtractOp::getOperationName() << ' ';
  printer << compositeExtractOp.composite();
  printer << compositeExtractOp.indices();
  printer << " : " << compositeExtractOp.composite().getType();
}

/// Verifies that the op's result type equals the element type obtained by
/// applying the indices to the composite operand's type.
static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
  auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
  auto expectedType = getElementType(compExOp.composite().getType(),
                                     indicesArrayAttr, compExOp.getLoc());
  // A null type signals that getElementType could not compute the type.
  if (!expectedType)
    return failure();

  if (expectedType == compExOp.getType())
    return success();

  return compExOp.emitOpError("invalid result type: expected ")
         << expectedType << " but provided " << compExOp.getType();
}

/// Folds the extraction by converting the indices attribute to unsigned
/// integers and delegating to extractCompositeElement on the (possibly
/// constant) composite operand.
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");

  SmallVector<unsigned, 8> indexVector;
  for (Attribute attr : indices())
    indexVector.push_back(
        static_cast<unsigned>(attr.cast<IntegerAttr>().getInt()));

  return extractCompositeElement(operands[0], indexVector);
}

//===----------------------------------------------------------------------===//
// spv.CompositeInsert
//===----------------------------------------------------------------------===//

/// Parses two operands (object, composite), the indices attribute, and
/// `: <object-type> into <composite-type>`. The composite type is also the
/// result type.
static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
                                          OperationState &state) {
  auto loc = parser.getCurrentLocation();

  SmallVector<OpAsmParser::OperandType, 2> operands;
  Attribute indicesAttr;
  if (parser.parseOperandList(operands, 2) ||
      parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes))
    return failure();

  Type objectType, compositeType;
  if (parser.parseColonType(objectType) ||
      parser.parseKeywordType("into", compositeType))
    return failure();

  if (parser.resolveOperands(operands, {objectType, compositeType}, loc,
                             state.operands))
    return failure();

  return parser.addTypesToList(compositeType, state.types);
}

/// Verifies that the object operand's type equals the element type selected by
/// the indices in the composite type, and that the result type equals the
/// composite operand's type.
static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
  Type compositeType = compositeInsertOp.composite().getType();

  auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
  auto expectedObjectType = getElementType(compositeType, indicesArrayAttr,
                                           compositeInsertOp.getLoc());
  // A null type signals that getElementType could not compute the type.
  if (!expectedObjectType)
    return failure();

  Type actualObjectType = compositeInsertOp.object().getType();
  if (expectedObjectType != actualObjectType)
    return compositeInsertOp.emitOpError("object operand type should be ")
           << expectedObjectType << ", but found " << actualObjectType;

  Type resultType = compositeInsertOp.getType();
  if (compositeType != resultType)
    return compositeInsertOp.emitOpError("result type should be the same as "
                                         "the composite type, but found ")
           << compositeType << " vs " << resultType;

  return success();
}

/// Prints the object operand, the composite operand immediately followed by
/// the indices attribute, and `: <object-type> into <composite-type>`.
static void print(spirv::CompositeInsertOp compositeInsertOp,
                  OpAsmPrinter &printer) {
  printer << spirv::CompositeInsertOp::getOperationName() << " ";
  printer << compositeInsertOp.object() << ", ";
  printer << compositeInsertOp.composite() << compositeInsertOp.indices();
  printer << " : " << compositeInsertOp.object().getType();
  printer << " into " << compositeInsertOp.composite().getType();
}

//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//

/// Parses the value attribute. When the attribute's type is a none type or a
/// tensor type, a trailing `: <type>` is required and used as the result type;
/// otherwise the attribute's own type is the result type.
static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
  Attribute value;
  if (parser.parseAttribute(value, kValueAttrName, state.attributes))
    return failure();

  Type type = value.getType();
  const bool needsExplicitType =
      type.isa<NoneType>() || type.isa<TensorType>();
  if (needsExplicitType && parser.parseColonType(type))
    return failure();

  return parser.addTypeToList(type, state.types);
}

/// Prints the op name and the value attribute; the result type is appended
/// only when it is an spv.array type.
static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
  printer << spirv::ConstantOp::getOperationName() << ' ';
  printer << constOp.value();

  Type resultType = constOp.getType();
  if (resultType.isa<spirv::ArrayType>())
    printer << " : " << resultType;
}

/// Verifies that the value attribute is consistent with the result type:
///  - bool/integer/float attributes must have exactly the result type;
///  - dense/sparse elements attributes must either have the result type or the
///    result must be a (possibly nested) spv.array whose innermost element
///    type and total element count match the attribute;
///  - array attributes require an spv.array result whose element type matches
///    the type of every element of the attribute.
/// Any other attribute kind is rejected.
static LogicalResult verify(spirv::ConstantOp constOp) {
  auto opType = constOp.getType();
  auto value = constOp.value();
  auto valueType = value.getType();

  // ODS already generates checks to make sure the result type is valid. We just
  // need to additionally check that the value's attribute type is consistent
  // with the result type.
  switch (value.getKind()) {
  case StandardAttributes::Bool:
  case StandardAttributes::Integer:
  case StandardAttributes::Float: {
    if (valueType != opType)
      return constOp.emitOpError("result type (")
             << opType << ") does not match value type (" << valueType << ")";
    return success();
  }
  case StandardAttributes::DenseElements:
  case StandardAttributes::SparseElements: {
    if (valueType == opType)
      break;
    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
    if (!arrayType) {
      return constOp.emitOpError(
          "must have spv.array result type for array value");
    }
    // The result is used unconditionally below, so use cast<> (which asserts)
    // rather than an unchecked dyn_cast<>. Elements attributes are expected to
    // carry a shaped type.
    auto shapedType = valueType.cast<ShapedType>();

    // Flatten nested arrays: multiply up the element counts and drill down to
    // the innermost element type. Accumulate in 64 bits since the count is
    // compared against ShapedType::getNumElements() below.
    int64_t numElements = arrayType.getNumElements();
    auto opElemType = arrayType.getElementType();
    while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
      numElements *= t.getNumElements();
      opElemType = t.getElementType();
    }
    if (!opElemType.isIntOrFloat()) {
      return constOp.emitOpError("only support nested array result type");
    }

    auto valueElemType = shapedType.getElementType();
    if (valueElemType != opElemType) {
      return constOp.emitOpError("result element type (")
             << opElemType << ") does not match value element type ("
             << valueElemType << ")";
    }

    if (numElements != shapedType.getNumElements()) {
      return constOp.emitOpError("result number of elements (")
             << numElements << ") does not match value number of elements ("
             << shapedType.getNumElements() << ")";
    }
  } break;
  case StandardAttributes::Array: {
    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
    if (!arrayType)
      return constOp.emitOpError(
          "must have spv.array result type for array value");
    auto elemType = arrayType.getElementType();
    for (auto element : value.cast<ArrayAttr>().getValue()) {
      if (element.getType() != elemType)
        return constOp.emitOpError("has array element whose type (")
               << element.getType()
               << ") does not match the result element type (" << elemType
               << ')';
    }
  } break;
  default:
    return constOp.emitOpError("cannot have value of type ") << valueType;
  }

  return success();
}

/// Folds the constant to its value attribute. The op takes no operands, which
/// is asserted.
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "spv.constant has no operands");
  return value();
}

/// Returns true if a constant op can be built with the given result type: the
/// type must be a valid SPIR-V type and, if its kind lies in the SPIR-V
/// dialect type range, it must be an spv.array type.
bool spirv::ConstantOp::isBuildableWith(Type type) {
  // Must be valid SPIR-V type first.
  if (!SPIRVDialect::isValidType(type))
    return false;

  const bool isSPIRVDialectType =
      type.getKind() >= Type::FIRST_SPIRV_TYPE &&
      type.getKind() <= spirv::TypeKind::LAST_SPIRV_TYPE;
  if (!isSPIRVDialectType)
    return true;

  // TODO(antiagainst): support constant struct
  return type.isa<spirv::ArrayType>();
}

/// Creates a constant op holding zero of the given type. Only integer types
/// are implemented: a width of 1 uses a `false` bool attribute, other widths
/// use an integer attribute with value 0.
spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
                                             OpBuilder *builder) {
  auto intType = type.dyn_cast<IntegerType>();
  if (!intType)
    llvm_unreachable("unimplemented types for ConstantOp::getZero()");

  const unsigned width = intType.getWidth();
  if (width != 1)
    return builder->create<spirv::ConstantOp>(
        loc, type, builder->getIntegerAttr(type, APInt(width, 0)));

  return builder->create<spirv::ConstantOp>(loc, type,
                                            builder->getBoolAttr(false));
}

/// Creates a constant op holding one of the given type. Only integer types
/// are implemented: a width of 1 uses a `true` bool attribute, other widths
/// use an integer attribute with value 1.
spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
                                            OpBuilder *builder) {
  auto intType = type.dyn_cast<IntegerType>();
  if (!intType)
    llvm_unreachable("unimplemented types for ConstantOp::getOne()");

  const unsigned width = intType.getWidth();
  if (width != 1)
    return builder->create<spirv::ConstantOp>(
        loc, type, builder->getIntegerAttr(type, APInt(width, 1)));

  return builder->create<spirv::ConstantOp>(loc, type,
                                            builder->getBoolAttr(true));
}

//===----------------------------------------------------------------------===//
// spv.ControlBarrier
//===----------------------------------------------------------------------===//

/// Parses three comma-separated enum attributes: the execution scope, the
/// memory scope, and the memory semantics.
static ParseResult parseControlBarrierOp(OpAsmParser &parser,
                                         OperationState &state) {
  spirv::Scope executionScope;
  if (parseEnumAttribute(executionScope, parser, state,
                         kExecutionScopeAttrName) ||
      parser.parseComma())
    return failure();

  spirv::Scope memoryScope;
  if (parseEnumAttribute(memoryScope, parser, state, kMemoryScopeAttrName) ||
      parser.parseComma())
    return failure();

  spirv::MemorySemantics memorySemantics;
  return parseEnumAttribute(memorySemantics, parser, state);
}

/// Prints the op name followed by the execution scope, memory scope, and
/// memory semantics as comma-separated quoted strings.
static void print(spirv::ControlBarrierOp op, OpAsmPrinter &printer) {
  printer << spirv::ControlBarrierOp::getOperationName();
  printer << " \"" << stringifyScope(op.execution_scope()) << "\", \"";
  printer << stringifyScope(op.memory_scope()) << "\", \"";
  printer << stringifyMemorySemantics(op.memory_semantics()) << "\"";
}

//===----------------------------------------------------------------------===//
// spv.EntryPoint
//===----------------------------------------------------------------------===//

/// Convenience builder that wraps the execution model in an i32 integer
/// attribute, the function in a symbol reference attribute, and the interface
/// variables in an array attribute.
void spirv::EntryPointOp::build(Builder *builder, OperationState &state,
                                spirv::ExecutionModel executionModel,
                                FuncOp function,
                                ArrayRef<Attribute> interfaceVars) {
  auto modelAttr =
      builder->getI32IntegerAttr(static_cast<int32_t>(executionModel));
  auto fnAttr = builder->getSymbolRefAttr(function);
  auto interfaceAttr = builder->getArrayAttr(interfaceVars);
  build(builder, state, modelAttr, fnAttr, interfaceAttr);
}

/// Parses the execution model enum attribute, the entry function symbol, and
/// an optional comma-separated list of interface variable symbols. The
/// interface variables are stored as one array attribute, which is added even
/// when the list is empty.
static ParseResult parseEntryPointOp(OpAsmParser &parser,
                                     OperationState &state) {
  spirv::ExecutionModel execModel;
  SmallVector<Attribute, 4> interfaceVars;

  FlatSymbolRefAttr fn;
  if (parseEnumAttribute(execModel, parser, state) ||
      parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
    return failure();
  }

  if (!parser.parseOptionalComma()) {
    // Parse the interface variables
    do {
      // The name of the interface variable attribute isn't important: it is
      // parsed into a scratch list and only the value is kept.
      auto attrName = "var_symbol";
      FlatSymbolRefAttr var;
      SmallVector<NamedAttribute, 1> attrs;
      if (parser.parseAttribute(var, Type(), attrName, attrs)) {
        return failure();
      }
      interfaceVars.push_back(var);
    } while (!parser.parseOptionalComma());
  }
  state.addAttribute(kInterfaceAttrName,
                     parser.getBuilder().getArrayAttr(interfaceVars));
  return success();
}

/// Prints the quoted execution model, the entry function symbol name, and,
/// when non-empty, the comma-separated interface variables.
static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
  printer << spirv::EntryPointOp::getOperationName() << " \"";
  printer << stringifyExecutionModel(entryPointOp.execution_model()) << "\" ";
  printer.printSymbolName(entryPointOp.fn());

  // Nothing more to print without interface variables.
  auto interfaceVars = entryPointOp.interface().getValue();
  if (interfaceVars.empty())
    return;

  printer << ", ";
  interleaveComma(interfaceVars, printer);
}

/// Op-level verifier for spv.EntryPoint. It performs no checks itself and
/// always succeeds.
static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
  // verification.
  return success();
}

//===----------------------------------------------------------------------===//
// spv.ExecutionMode
//===----------------------------------------------------------------------===//

/// Convenience builder that wraps the function in a symbol reference
/// attribute, the execution mode in an i32 integer attribute, and the
/// parameters in an i32 array attribute.
void spirv::ExecutionModeOp::build(Builder *builder, OperationState &state,
                                   FuncOp function,
                                   spirv::ExecutionMode executionMode,
                                   ArrayRef<int32_t> params) {
  auto fnAttr = builder->getSymbolRefAttr(function);
  auto modeAttr =
      builder->getI32IntegerAttr(static_cast<int32_t>(executionMode));
  auto paramsAttr = builder->getI32ArrayAttr(params);
  build(builder, state, fnAttr, modeAttr, paramsAttr);
}

/// Parses the function attribute, the execution mode enum attribute, and zero
/// or more comma-prefixed i32 integer values. The values are stored as one i32
/// array attribute, which is added even when no values are given.
static ParseResult parseExecutionModeOp(OpAsmParser &parser,
                                        OperationState &state) {
  spirv::ExecutionMode execMode;
  Attribute fn;
  if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
      parseEnumAttribute(execMode, parser, state)) {
    return failure();
  }

  SmallVector<int32_t, 4> values;
  Type i32Type = parser.getBuilder().getIntegerType(32);
  while (!parser.parseOptionalComma()) {
    // Scratch list: only the parsed value is kept.
    SmallVector<NamedAttribute, 1> attr;
    // Parse directly into an IntegerAttr so that a non-integer literal is
    // reported by the parser rather than tripping a cast<> assertion.
    IntegerAttr value;
    if (parser.parseAttribute(value, i32Type, "value", attr)) {
      return failure();
    }
    values.push_back(value.getInt());
  }
  state.addAttribute(kValuesAttrName,
                     parser.getBuilder().getI32ArrayAttr(values));
  return success();
}

/// Prints the function name prefixed with '@', the quoted execution mode, and,
/// when present, the comma-separated integer values.
static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
  printer << spirv::ExecutionModeOp::getOperationName();
  printer << " @" << execModeOp.fn();
  printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode())
          << "\"";

  auto values = execModeOp.values();
  if (values.size()) {
    // Values are printed as bare integers without their type.
    auto printValue = [&](Attribute a) {
      printer << a.cast<IntegerAttr>().getInt();
    };
    printer << ", ";
    interleaveComma(values, printer, printValue);
  }
}

//===----------------------------------------------------------------------===//
// spv.FunctionCall
//===----------------------------------------------------------------------===//

/// Parses the callee symbol, a parenthesized operand list, and
/// `: <function-type>`. The function type may have at most one result; its
/// results become the op's result types and its inputs are used to resolve the
/// operands.
static ParseResult parseFunctionCallOp(OpAsmParser &parser,
                                       OperationState &state) {
  auto loc = parser.getNameLoc();

  FlatSymbolRefAttr calleeAttr;
  if (parser.parseAttribute(calleeAttr, kCallee, state.attributes))
    return failure();

  SmallVector<OpAsmParser::OperandType, 4> operands;
  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
    return failure();

  FunctionType type;
  if (parser.parseColonType(type))
    return failure();

  auto funcType = type.dyn_cast<FunctionType>();
  if (!funcType)
    return parser.emitError(loc, "expected function type, but provided ")
           << type;

  if (funcType.getNumResults() > 1)
    return parser.emitError(loc, "expected callee function to have 0 or 1 "
                                 "result, but provided ")
           << funcType.getNumResults();

  // Results are registered before the operands are resolved.
  if (parser.addTypesToList(funcType.getResults(), state.types))
    return failure();
  return failure(failed(parser.resolveOperands(operands, funcType.getInputs(),
                                               loc, state.operands)));
}

/// Prints the callee attribute, the parenthesized arguments, and a function
/// type assembled from the op's operand and result types.
static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
  // Rebuild the function type from the op's own operand/result types.
  SmallVector<Type, 4> inputTypes(functionCallOp.getOperandTypes());
  SmallVector<Type, 1> outputTypes(functionCallOp.getResultTypes());
  Type functionType =
      FunctionType::get(inputTypes, outputTypes, functionCallOp.getContext());

  printer << spirv::FunctionCallOp::getOperationName() << ' ';
  printer << functionCallOp.getAttr(kCallee);
  printer << '(' << functionCallOp.arguments() << ") : ";
  printer << functionType;
}

/// Verifies spv.FunctionCall: the op must live inside a spv.module that
/// defines the callee as a FuncOp, and the op's operands and results must match
/// the callee's function type in both count and type.
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
  auto calleeName = functionCallOp.callee();

  auto enclosingModule = functionCallOp.getParentOfType<spirv::ModuleOp>();
  if (!enclosingModule)
    return functionCallOp.emitOpError(
        "must appear in a function inside 'spv.module'");

  auto calleeFn = enclosingModule.lookupSymbol<FuncOp>(calleeName);
  if (!calleeFn)
    return functionCallOp.emitOpError("callee function '")
           << calleeName << "' not found in 'spv.module'";

  auto calleeType = calleeFn.getType();
  const unsigned numCallResults = functionCallOp.getNumResults();
  const unsigned numCallOperands = functionCallOp.getNumOperands();
  const unsigned numCalleeInputs = calleeType.getNumInputs();

  if (numCallResults > 1)
    return functionCallOp.emitOpError(
               "expected callee function to have 0 or 1 result, but provided ")
           << numCallResults;

  if (numCalleeInputs != numCallOperands)
    return functionCallOp.emitOpError(
               "has incorrect number of operands for callee: expected ")
           << numCalleeInputs << ", but provided " << numCallOperands;

  // Operand types are compared position by position.
  for (uint32_t i = 0; i != numCalleeInputs; ++i) {
    Type expectedType = calleeType.getInput(i);
    Type providedType = functionCallOp.getOperand(i).getType();
    if (providedType == expectedType)
      continue;
    return functionCallOp.emitOpError(
               "operand type mismatch: expected operand type ")
           << expectedType << ", but provided " << providedType
           << " for operand number " << i;
  }

  const unsigned numCalleeResults = calleeType.getNumResults();
  if (numCalleeResults != numCallResults)
    return functionCallOp.emitOpError(
               "has incorrect number of results has for callee: expected ")
           << numCalleeResults << ", but provided " << numCallResults;

  if (numCallResults == 0)
    return success();

  Type expectedResult = calleeType.getResult(0);
  Type providedResult = functionCallOp.getResult(0).getType();
  if (providedResult != expectedResult)
    return functionCallOp.emitOpError("result type mismatch: expected ")
           << expectedResult << ", but provided " << providedResult;

  return success();
}

/// Returns the callee of this call as the symbol reference stored in the
/// `callee` attribute.
CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
  auto calleeRef = getAttrOfType<SymbolRefAttr>(kCallee);
  return calleeRef;
}

/// Returns the operands passed to the callee, i.e. the op's `arguments`.
Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
  Operation::operand_range callArgs = arguments();
  return callArgs;
}

//===----------------------------------------------------------------------===//
// spv.globalVariable
//===----------------------------------------------------------------------===//

/// Builds a spv.globalVariable named `name` of the given `type` without an
/// initializer, and attaches `descriptorSet` and `binding` as i32 attributes
/// under the names the dialect assigns to the DescriptorSet and Binding
/// decorations.
void spirv::GlobalVariableOp::build(Builder *builder, OperationState &state,
                                    Type type, StringRef name,
                                    unsigned descriptorSet, unsigned binding) {
  auto typeAttr = TypeAttr::get(type);
  auto nameAttr = builder->getStringAttr(name);
  build(builder, state, typeAttr, nameAttr, nullptr);

  auto descriptorSetAttr = builder->getI32IntegerAttr(descriptorSet);
  state.addAttribute(
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
      descriptorSetAttr);

  auto bindingAttr = builder->getI32IntegerAttr(binding);
  state.addAttribute(
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
      bindingAttr);
}

/// Builds a spv.globalVariable named `name` of the given `type` without an
/// initializer, and attaches the stringified `builtin` as a string attribute
/// under the name the dialect assigns to the BuiltIn decoration.
void spirv::GlobalVariableOp::build(Builder *builder, OperationState &state,
                                    Type type, StringRef name,
                                    spirv::BuiltIn builtin) {
  auto typeAttr = TypeAttr::get(type);
  auto nameAttr = builder->getStringAttr(name);
  build(builder, state, typeAttr, nameAttr, nullptr);

  auto builtinAttr = builder->getStringAttr(spirv::stringifyBuiltIn(builtin));
  state.addAttribute(
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
      builtinAttr);
}

/// Parses the custom form of spv.globalVariable:
///
///   <symbol-name> (`initializer` `(` <symbol-ref> `)`)? <decorations>
///   `:` <type>
///
/// The trailing type must be a spv.ptr type; it is recorded as the op's type
/// attribute.
static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
                                         OperationState &state) {
  // The variable's symbol name comes first.
  StringAttr symName;
  if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // The initializer clause is optional and introduced by its keyword.
  if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
    FlatSymbolRefAttr initializerRef;
    if (parser.parseLParen())
      return failure();
    if (parser.parseAttribute(initializerRef, Type(), kInitializerAttrName,
                              state.attributes))
      return failure();
    if (parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, state))
    return failure();

  // Remember where the type starts so the diagnostic points at it.
  auto typeLoc = parser.getCurrentLocation();
  Type varType;
  if (parser.parseColonType(varType))
    return failure();
  if (!varType.isa<spirv::PointerType>())
    return parser.emitError(typeLoc, "expected spv.ptr type");

  state.addAttribute(kTypeAttrName, TypeAttr::get(varType));
  return success();
}

/// Prints spv.globalVariable: op name, symbol name, the optional
/// `initializer(@sym)` clause, the variable decorations, and the type.
/// Attributes that are printed explicitly (plus the storage class attribute)
/// are collected in an elision list handed to the decoration printer.
static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elidedAttrs;
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());

  // Op name followed by the variable's symbol name.
  printer << spirv::GlobalVariableOp::getOperationName() << ' ';
  printer.printSymbolName(varOp.sym_name());
  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());

  // The initializer clause only shows up when the attribute is present.
  if (auto initializer = varOp.initializer()) {
    printer << " " << kInitializerAttrName << '(';
    printer.printSymbolName(initializer.getValue());
    printer << ')';
    elidedAttrs.push_back(kInitializerAttrName);
  }

  elidedAttrs.push_back(kTypeAttrName);
  printVariableDecorations(varOp.getOperation(), printer, elidedAttrs);
  printer << " : " << varOp.type();
}

/// Verifies spv.globalVariable: the storage class must not be Generic, and an
/// initializer, when present, must name a spv.specConstant or
/// spv.globalVariable found in the enclosing spv.module.
static LogicalResult verify(spirv::GlobalVariableOp varOp) {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  if (varOp.storageClass() == spirv::StorageClass::Generic)
    return varOp.emitOpError("storage class cannot be 'Generic'");

  auto initRef = varOp.getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName);
  if (!initRef)
    return success();

  auto enclosingModule = varOp.getParentOfType<spirv::ModuleOp>();
  auto *initDef = enclosingModule.lookupSymbol(initRef.getValue());

  // TODO: Currently only variable initialization with specialization
  // constants and other variables is supported. They could be normal
  // constants in the module scope as well.
  bool isSupportedInit = initDef && (isa<spirv::GlobalVariableOp>(initDef) ||
                                     isa<spirv::SpecConstantOp>(initDef));
  if (!isSupportedInit)
    return varOp.emitOpError("initializer must be result of a "
                             "spv.specConstant or spv.globalVariable op");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformBallotOp
//===----------------------------------------------------------------------===//

/// Parses the custom form of spv.GroupNonUniformBallot:
///
///   <execution-scope> <predicate-operand> `:` <result-type>
///
/// The predicate operand is always resolved as an i1 value.
static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
                                                OperationState &state) {
  IntegerType predicateType = parser.getBuilder().getI1Type();

  spirv::Scope scope;
  if (parseEnumAttribute(scope, parser, state, kExecutionScopeAttrName))
    return failure();

  OpAsmParser::OperandType predicate;
  if (parser.parseOperand(predicate))
    return failure();

  Type ballotType;
  if (parser.parseColonType(ballotType))
    return failure();

  if (parser.resolveOperand(predicate, predicateType, state.operands))
    return failure();

  return parser.addTypeToList(ballotType, state.types);
}

/// Prints spv.GroupNonUniformBallot: op name, quoted execution scope, the
/// predicate operand, and the result type.
static void print(spirv::GroupNonUniformBallotOp ballotOp,
                  OpAsmPrinter &printer) {
  printer << spirv::GroupNonUniformBallotOp::getOperationName();
  printer << " \"" << stringifyScope(ballotOp.execution_scope()) << "\" ";
  printer << ballotOp.predicate();
  printer << " : " << ballotOp.getType();
}

/// Verifies spv.GroupNonUniformBallot: only the Workgroup and Subgroup
/// execution scopes are accepted.
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
  // TODO(antiagainst): check the result integer type's signedness bit is 0.

  spirv::Scope scope = ballotOp.execution_scope();
  bool isAllowedScope =
      scope == spirv::Scope::Workgroup || scope == spirv::Scope::Subgroup;
  if (isAllowedScope)
    return success();

  return ballotOp.emitOpError(
      "execution scope must be 'Workgroup' or 'Subgroup'");
}

//===----------------------------------------------------------------------===//
// spv.IAdd
//===----------------------------------------------------------------------===//

/// Folds spv.IAdd: returns the first operand when the second one matches a
/// constant zero, otherwise attempts to constant-fold the two operand
/// attributes with APInt addition.
OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "spv.IAdd expects two operands");

  // x + 0 = x
  Value rhs = operand2();
  if (matchPattern(rhs, m_Zero()))
    return operand1();

  // According to the SPIR-V spec:
  //
  // The resulting value will equal the low-order N bits of the correct result
  // R, where N is the component width and R is computed with enough precision
  // to avoid overflow and underflow.
  auto addBits = [](APInt lhsBits, APInt rhsBits) { return lhsBits + rhsBits; };
  return constFoldBinaryOp<IntegerAttr>(operands, addBits);
}

//===----------------------------------------------------------------------===//
// spv.IMul
//===----------------------------------------------------------------------===//

/// Folds spv.IMul: returns the second operand when it matches a constant zero,
/// the first operand when the second matches a constant one, and otherwise
/// attempts to constant-fold the two operand attributes with APInt
/// multiplication.
OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "spv.IMul expects two operands");

  Value rhs = operand2();
  // x * 0 == 0
  if (matchPattern(rhs, m_Zero()))
    return rhs;
  // x * 1 = x
  if (matchPattern(rhs, m_One()))
    return operand1();

  // According to the SPIR-V spec:
  //
  // The resulting value will equal the low-order N bits of the correct result
  // R, where N is the component width and R is computed with enough precision
  // to avoid overflow and underflow.
  auto mulBits = [](APInt lhsBits, APInt rhsBits) { return lhsBits * rhsBits; };
  return constFoldBinaryOp<IntegerAttr>(operands, mulBits);
}

//===----------------------------------------------------------------------===//
// spv.ISub
//===----------------------------------------------------------------------===//

/// Folds spv.ISub: `x - x` becomes a zero constant of the result type, and
/// otherwise the two operand attributes are constant-folded with APInt
/// subtraction.
OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "spv.ISub expects two operands");

  // x - x = 0
  //
  // The zero is built with getZeroAttr rather than getIntegerAttr because the
  // latter needs the bit width of the type itself and therefore asserts on
  // vector result types. If no zero attribute can be built for the result
  // type, fall through to the generic constant folder instead of crashing.
  if (operand1() == operand2())
    if (Attribute zero = Builder(getContext()).getZeroAttr(getType()))
      return zero;

  // According to the SPIR-V spec:
  //
  // The resulting value will equal the low-order N bits of the correct result
  // R, where N is the component width and R is computed with enough precision
  // to avoid overflow and underflow.
  return constFoldBinaryOp<IntegerAttr>(operands,
                                        [](APInt a, APInt b) { return a - b; });
}

//===----------------------------------------------------------------------===//
// spv.LoadOp
//===----------------------------------------------------------------------===//

/// Builds a spv.Load whose result type is the pointee type of `basePtr`.
/// `basePtr` must have a spv.ptr type (enforced by the cast).
void spirv::LoadOp::build(Builder *builder, OperationState &state,
                          Value basePtr, IntegerAttr memory_access,
                          IntegerAttr alignment) {
  Type loadedType =
      basePtr.getType().cast<spirv::PointerType>().getPointeeType();
  build(builder, state, loadedType, basePtr, memory_access, alignment);
}

/// Parses the custom form of spv.Load:
///
///   <storage-class> <pointer-operand> <memory-access>? <attr-dict>?
///   `:` <element-type>
///
/// The pointer operand is resolved as a spv.ptr built from the parsed element
/// type and storage class; the element type is the result type.
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) {
  // The storage class comes first; it is needed to rebuild the pointer type.
  spirv::StorageClass storageClass;
  if (parseEnumAttribute(storageClass, parser))
    return failure();

  OpAsmParser::OperandType pointer;
  if (parser.parseOperand(pointer))
    return failure();

  if (parseMemoryAccessAttributes(parser, state))
    return failure();
  if (parser.parseOptionalAttrDict(state.attributes))
    return failure();

  Type loadedType;
  if (parser.parseColon() || parser.parseType(loadedType))
    return failure();

  auto pointerType = spirv::PointerType::get(loadedType, storageClass);
  if (parser.resolveOperand(pointer, pointerType, state.operands))
    return failure();

  state.addTypes(loadedType);
  return success();
}

/// Prints spv.Load: op name, quoted storage class of the pointer operand, the
/// pointer operand, the memory access attribute, the remaining attribute
/// dictionary, and the result type.
static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
  auto pointerType = loadOp.ptr().getType().cast<spirv::PointerType>();
  StringRef storageClassName =
      stringifyStorageClass(pointerType.getStorageClass());

  printer << spirv::LoadOp::getOperationName();
  printer << " \"" << storageClassName << "\" " << loadOp.ptr();

  // The memory access printer fills in the attributes it has already handled.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(loadOp, printer, elidedAttrs);

  printer.printOptionalAttrDict(loadOp.getOperation()->getAttrs(), elidedAttrs);
  printer << " : " << loadOp.getType();
}

/// Verifies spv.Load by checking the pointer/value type pairing and then the
/// memory access attribute.
static LogicalResult verify(spirv::LoadOp loadOp) {
  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
  // type with fixed size; i.e., it cannot be, nor include, any
  // OpTypeRuntimeArray types."
  auto typeCheck =
      verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(), loadOp.value());
  if (failed(typeCheck))
    return failure();

  return verifyMemoryAccessAttribute(loadOp);
}

//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//

/// Registers the canonicalization patterns for spv.LogicalNot into `results`.
void spirv::LogicalNotOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertLogicalNotOfIEqual>(context);
  results.insert<ConvertLogicalNotOfINotEqual>(context);
  results.insert<ConvertLogicalNotOfLogicalEqual>(context);
  results.insert<ConvertLogicalNotOfLogicalNotEqual>(context);
}

//===----------------------------------------------------------------------===//
// spv.loop
//===----------------------------------------------------------------------===//

/// Builds a spv.loop with `loop_control` set to LoopControl::None and a single
/// region added to the state.
void spirv::LoopOp::build(Builder *builder, OperationState &state) {
  auto noControl = static_cast<uint32_t>(spirv::LoopControl::None);
  state.addAttribute("loop_control", builder->getI32IntegerAttr(noControl));
  state.addRegion();
}

/// Parses spv.loop: the custom form is just the region. The `loop_control`
/// attribute is always set to LoopControl::None.
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
  // TODO(antiagainst): support loop control properly
  auto noControl = static_cast<uint32_t>(spirv::LoopControl::None);
  state.addAttribute("loop_control",
                     parser.getBuilder().getI32IntegerAttr(noControl));

  Region *loopBody = state.addRegion();
  return parser.parseRegion(*loopBody, /*arguments=*/{}, /*argTypes=*/{});
}

/// Prints spv.loop: the op name followed by its region, without entry block
/// arguments and with block terminators.
static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
  printer << spirv::LoopOp::getOperationName();

  Region &loopBody = loopOp.getOperation()->getRegion(0);
  printer.printRegion(loopBody, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}

/// Returns true when `srcBlock` holds exactly one operation and that operation
/// is a `spv.Branch` whose first successor is `dstBlock`.
static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
  // A block with any number of ops other than one can never qualify.
  if (has_single_element(srcBlock))
    if (auto branch = dyn_cast<spirv::BranchOp>(srcBlock.back()))
      return branch.getSuccessor(0) == &dstBlock;
  return false;
}

/// Verifies the block layout of a spv.loop region. An empty region is
/// accepted; otherwise the region must consist of an entry block, a loop
/// header block, a loop continue block (second to last) and a merge block
/// (last), with only the entry and continue blocks branching to the header.
static LogicalResult verify(spirv::LoopOp loopOp) {
  // The blocks are required to follow this layout:
  //
  //                     +-------------+
  //                     | entry block |
  //                     +-------------+
  //                            |
  //                            v
  //                     +-------------+
  //                     | loop header | <-----+
  //                     +-------------+       |
  //                                           |
  //                           ...             |
  //                          \ | /            |
  //                            v              |
  //                    +---------------+      |
  //                    | loop continue | -----+
  //                    +---------------+
  //
  //                           ...
  //                          \ | /
  //                            v
  //                     +-------------+
  //                     | merge block |
  //                     +-------------+

  auto &body = loopOp.getOperation()->getRegion(0);

  // An empty region is a tolerated degenerate case, which optimizations can
  // produce.
  if (body.empty())
    return success();

  // Last block: the merge block.
  if (!isMergeBlock(body.back()))
    return loopOp.emitOpError(
        "last block must be the merge block with only one 'spv._merge' op");

  const auto blocksEnd = body.end();
  const auto headerIt = std::next(body.begin());
  if (headerIt == blocksEnd)
    return loopOp.emitOpError(
        "must have an entry block branching to the loop header block");

  const auto thirdIt = std::next(headerIt);
  if (thirdIt == blocksEnd)
    return loopOp.emitOpError(
        "must have a loop header block branched from the entry block");

  // First block: entry. Second block: loop header.
  Block &entry = body.front();
  Block &header = *headerIt;
  if (!hasOneBranchOpTo(entry, header))
    return loopOp.emitOpError(
        "entry block must only have one 'spv.Branch' op to the second block");

  if (std::next(thirdIt) == blocksEnd)
    return loopOp.emitOpError(
        "requires a loop continue block branching to the loop header block");

  // Second to last block: loop continue. It needs a back edge to the header.
  const auto continueIt = std::prev(blocksEnd, 2);
  Block &cont = *continueIt;
  bool hasBackEdge = false;
  for (unsigned i = 0, e = cont.getNumSuccessors(); i != e; ++i) {
    if (cont.getSuccessor(i) == &header) {
      hasBackEdge = true;
      break;
    }
  }
  if (!hasBackEdge)
    return loopOp.emitOpError("second to last block must be the loop continue "
                              "block that branches to the loop header block");

  // Blocks strictly between the header and the continue block must not branch
  // to the header.
  for (auto &block : llvm::make_range(thirdIt, continueIt)) {
    for (unsigned i = 0, e = block.getNumSuccessors(); i != e; ++i) {
      if (block.getSuccessor(i) != &header)
        continue;
      return loopOp.emitOpError("can only have the entry and loop continue "
                                "block branching to the loop header block");
    }
  }

  return success();
}

/// Returns the first block of the loop region, which is the entry block.
/// The region must not be empty.
Block *spirv::LoopOp::getEntryBlock() {
  Region &loopBody = body();
  assert(!loopBody.empty() && "op region should not be empty!");
  Block &entry = loopBody.front();
  return &entry;
}

/// Returns the second block of the loop region, which is the loop header
/// block. The region must contain at least two blocks.
Block *spirv::LoopOp::getHeaderBlock() {
  assert(!body().empty() && "op region should not be empty!");
  // The second block is the loop header block. Dereferencing it is only valid
  // when the region actually has a second block, so check that explicitly
  // instead of relying on the non-empty check alone.
  auto headerIt = std::next(body().begin());
  assert(headerIt != body().end() &&
         "op region should have at least two blocks!");
  return &*headerIt;
}

/// Returns the second to last block of the loop region, which is the loop
/// continue block. The region must contain at least two blocks.
Block *spirv::LoopOp::getContinueBlock() {
  assert(!body().empty() && "op region should not be empty!");
  // Stepping back two blocks from the end is only valid when the region has at
  // least two blocks, so check that explicitly instead of relying on the
  // non-empty check alone.
  assert(std::next(body().begin()) != body().end() &&
         "op region should have at least two blocks!");
  // The second to last block is the loop continue block.
  return &*std::prev(body().end(), 2);
}

/// Returns the last block of the loop region, which is the loop merge block.
/// The region must not be empty.
Block *spirv::LoopOp::getMergeBlock() {
  Region &loopBody = body();
  assert(!loopBody.empty() && "op region should not be empty!");
  Block &merge = loopBody.back();
  return &merge;
}

void spirv::LoopOp::addEntryAndMergeBlock() {
  assert(body().empty() && "entry and merge block already exist");
  body().push_back(new Block());
  auto *mergeBlock = new Block();
  body().push_back(mergeBlock);
  OpBuilder builder(mergeBlock);

  // Add a spv._merge op into the merge block.
  builder.create<spirv::MergeOp>(getLoc());
}

//===----------------------------------------------------------------------===//
// spv._merge
//===----------------------------------------------------------------------===//

/// Verifies spv._merge: its parent must be a spv.selection or spv.loop, and it
/// must be the terminator of the last block of the region that contains it.
static LogicalResult verify(spirv::MergeOp mergeOp) {
  auto *parent = mergeOp.getParentOp();
  bool hasStructuredParent =
      parent && (isa<spirv::SelectionOp>(parent) || isa<spirv::LoopOp>(parent));
  if (!hasStructuredParent)
    return mergeOp.emitOpError(
        "expected parent op to be 'spv.selection' or 'spv.loop'");

  Block &lastBlock = mergeOp.getParentRegion()->back();
  if (lastBlock.getTerminator() == mergeOp.getOperation())
    return success();

  return mergeOp.emitOpError(
      "can only be used in the last block of 'spv.selection' or 'spv.loop'");
}

//===----------------------------------------------------------------------===//
// spv.MemoryBarrier
//===----------------------------------------------------------------------===//

/// Parses the custom form of spv.MemoryBarrier:
///
///   <memory-scope> `,` <memory-semantics>
static ParseResult parseMemoryBarrierOp(OpAsmParser &parser,
                                        OperationState &state) {
  spirv::Scope scope;
  if (parseEnumAttribute(scope, parser, state, kMemoryScopeAttrName))
    return failure();

  if (parser.parseComma())
    return failure();

  spirv::MemorySemantics semantics;
  if (parseEnumAttribute(semantics, parser, state))
    return failure();

  return success();
}

/// Prints spv.MemoryBarrier: op name, quoted memory scope, a comma, and the
/// quoted memory semantics.
static void print(spirv::MemoryBarrierOp op, OpAsmPrinter &printer) {
  printer << spirv::MemoryBarrierOp::getOperationName();
  printer << " \"" << stringifyScope(op.memory_scope()) << "\", \"";
  printer << stringifyMemorySemantics(op.memory_semantics()) << "\"";
}

//===----------------------------------------------------------------------===//
// spv.module
//===----------------------------------------------------------------------===//

/// Builds a spv.module with a single region and ensures that region has a
/// terminator.
void spirv::ModuleOp::build(Builder *builder, OperationState &state) {
  Region *moduleBody = state.addRegion();
  ensureTerminator(*moduleBody, *builder, state.location);
}

// TODO(ravishankarm): This is only here for resolving some dependency outside
// of mlir. Remove once it is done.
/// Builds a spv.module from already-constructed addressing model and memory
/// model attributes, then delegates to the default builder for the region.
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
                            IntegerAttr addressing_model,
                            IntegerAttr memory_model) {
  // Record both model attributes before the region is created.
  state.addAttribute("addressing_model", addressing_model);
  state.addAttribute("memory_model", memory_model);

  build(builder, state);
}

/// Builds a spv.module from enum-typed addressing and memory models. The
/// `capabilities` and `extensions` attributes are only attached when the
/// corresponding list is non-empty, and `extended_instruction_sets` only when
/// it is non-null. The region is created by the default builder.
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
                            spirv::AddressingModel addressing_model,
                            spirv::MemoryModel memory_model,
                            ArrayRef<spirv::Capability> capabilities,
                            ArrayRef<spirv::Extension> extensions,
                            ArrayAttr extended_instruction_sets) {
  // Both models are stored as i32 integer attributes.
  auto addressingAttr =
      builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model));
  state.addAttribute("addressing_model", addressingAttr);

  auto memoryAttr =
      builder->getI32IntegerAttr(static_cast<int32_t>(memory_model));
  state.addAttribute("memory_model", memoryAttr);

  if (!capabilities.empty()) {
    auto capabilitiesAttr = getStrArrayAttrForEnumList<spirv::Capability>(
        *builder, capabilities, spirv::stringifyCapability);
    state.addAttribute("capabilities", capabilitiesAttr);
  }

  if (!extensions.empty()) {
    auto extensionsAttr = getStrArrayAttrForEnumList<spirv::Extension>(
        *builder, extensions, spirv::stringifyExtension);
    state.addAttribute("extensions", extensionsAttr);
  }

  if (extended_instruction_sets)
    state.addAttribute("extended_instruction_sets", extended_instruction_sets);

  build(builder, state);
}

/// Parses the custom form of spv.module:
///
///   <addressing-model> <memory-model> <region> (`attributes` <attr-dict>)?
///
/// A terminator is added to the region after parsing if it lacks one.
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
  Region *moduleBody = state.addRegion();

  // The two models are mandatory and come before the region.
  spirv::AddressingModel addressingModel;
  if (parseEnumAttribute(addressingModel, parser, state))
    return failure();

  spirv::MemoryModel memoryModel;
  if (parseEnumAttribute(memoryModel, parser, state))
    return failure();

  if (parser.parseRegion(*moduleBody, /*arguments=*/{}, /*argTypes=*/{}) ||
      parser.parseOptionalAttrDictWithKeyword(state.attributes))
    return failure();

  spirv::ModuleOp::ensureTerminator(*moduleBody, parser.getBuilder(),
                                    state.location);
  return success();
}

/// Prints spv.module: op name, the two quoted models (when both attributes
/// exist), the region without entry block arguments or terminators, and the
/// remaining attributes behind the `attributes` keyword.
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
  printer << spirv::ModuleOp::getOperationName();

  auto addressingName = spirv::attributeName<spirv::AddressingModel>();
  auto memoryName = spirv::attributeName<spirv::MemoryModel>();

  // The compact form is used only when both model attributes are present.
  // Otherwise they are left in the generic attribute dictionary, which helps
  // when debugging an ill-formed ModuleOp.
  SmallVector<StringRef, 2> elidedAttrs;
  bool hasBothModels =
      moduleOp.getAttr(addressingName) && moduleOp.getAttr(memoryName);
  if (hasBothModels) {
    printer << " \""
            << spirv::stringifyAddressingModel(moduleOp.addressing_model())
            << "\" \"";
    printer << spirv::stringifyMemoryModel(moduleOp.memory_model()) << '"';
    elidedAttrs.push_back(addressingName);
    elidedAttrs.push_back(memoryName);
  }

  printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/false);
  printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs);
}

/// Verifies spv.module:
///  - every op in the module block is either from the module's own dialect or
///    a non-external FuncOp whose body only contains ops from that dialect;
///  - each spv.EntryPoint names an existing function, lists only
///    spv.globalVariable symbols in its interface, and is unique per
///    (function, execution model) pair;
///  - the `capabilities` and `extensions` attributes only hold known names.
static LogicalResult verify(spirv::ModuleOp moduleOp) {
  using EntryPointKey = std::pair<FuncOp, spirv::ExecutionModel>;

  auto &moduleOperation = *moduleOp.getOperation();
  auto *spvDialect = moduleOperation.getDialect();
  auto &moduleBlock = moduleOperation.getRegion(0).front();
  DenseMap<EntryPointKey, spirv::EntryPointOp> seenEntryPoints;
  SymbolTable symbols(moduleOp);

  for (auto &topLevelOp : moduleBlock) {
    // Ops from other dialects are only allowed if they are functions.
    if (topLevelOp.getDialect() != spvDialect) {
      auto funcOp = dyn_cast<FuncOp>(topLevelOp);
      if (!funcOp)
        return topLevelOp.emitError(
            "'spv.module' can only contain func and spv.* ops");

      if (funcOp.isExternal())
        return topLevelOp.emitError(
            "'spv.module' cannot contain external functions");

      for (auto &block : funcOp) {
        for (auto &bodyOp : block) {
          if (bodyOp.getDialect() == spvDialect)
            continue;

          if (isa<FuncOp>(bodyOp))
            return bodyOp.emitError(
                "'spv.module' cannot contain nested functions");

          return bodyOp.emitError(
              "functions in 'spv.module' can only contain spv.* ops");
        }
      }
      continue;
    }

    // For EntryPoint ops, check that the function and execution model pair is
    // not duplicated. The interface is also checked here to refer only to
    // global variables, which keeps that check cheap.
    auto entryPointOp = dyn_cast<spirv::EntryPointOp>(topLevelOp);
    if (!entryPointOp)
      continue;

    auto entryFn = symbols.lookup<FuncOp>(entryPointOp.fn());
    if (!entryFn)
      return entryPointOp.emitError("function '")
             << entryPointOp.fn() << "' not found in 'spv.module'";

    if (auto interface = entryPointOp.interface()) {
      for (Attribute varRef : interface) {
        auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
        if (!varSymRef)
          return entryPointOp.emitError(
                     "expected symbol reference for interface "
                     "specification instead of '")
                 << varRef;

        auto variableOp =
            symbols.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
        if (!variableOp)
          return entryPointOp.emitError("expected spv.globalVariable "
                                        "symbol reference instead of'")
                 << varSymRef << "'";
      }
    }

    EntryPointKey key(entryFn, entryPointOp.execution_model());
    if (seenEntryPoints.count(key))
      return entryPointOp.emitError("duplicate of a previous EntryPointOp");
    seenEntryPoints[key] = entryPointOp;
  }

  // Verify capabilities. ODS already guarantees that we have an array of
  // string attributes.
  if (auto capabilities = moduleOp.getAttrOfType<ArrayAttr>("capabilities")) {
    for (auto capability : capabilities.getValue()) {
      auto capabilityName = capability.cast<StringAttr>().getValue();
      if (spirv::symbolizeCapability(capabilityName))
        continue;
      return moduleOp.emitOpError("uses unknown capability: ")
             << capabilityName;
    }
  }

  // Verify extensions. ODS already guarantees that we have an array of
  // string attributes.
  if (auto extensions = moduleOp.getAttrOfType<ArrayAttr>("extensions")) {
    for (auto extension : extensions.getValue()) {
      auto extensionName = extension.cast<StringAttr>().getValue();
      if (spirv::symbolizeExtension(extensionName))
        continue;
      return moduleOp.emitOpError("uses unknown extension: ") << extensionName;
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv._reference_of
//===----------------------------------------------------------------------===//

/// Parses the custom form of spv._reference_of:
///
///   <spec-const-symbol-ref> `:` <result-type>
static ParseResult parseReferenceOfOp(OpAsmParser &parser,
                                      OperationState &state) {
  FlatSymbolRefAttr specConstRef;
  if (parser.parseAttribute(specConstRef, Type(), kSpecConstAttrName,
                            state.attributes))
    return failure();

  Type resultType;
  if (parser.parseColonType(resultType))
    return failure();

  return parser.addTypeToList(resultType, state.types);
}

/// Prints spv._reference_of: op name, the referenced symbol name, and the
/// result type.
static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) {
  printer << spirv::ReferenceOfOp::getOperationName();
  printer << ' ';
  printer.printSymbolName(referenceOfOp.spec_const());

  Type resultType = referenceOfOp.reference().getType();
  printer << " : " << resultType;
}

/// Verifies spv._reference_of: the op must be nested in a spv.module that
/// defines the referenced symbol as a spv.specConstant, and the op's result
/// type must equal the type of that constant's default value.
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
  auto moduleOp = referenceOfOp.getParentOfType<spirv::ModuleOp>();
  // getParentOfType yields a null handle when no enclosing spv.module exists.
  // Report that as a verification failure instead of looking up a symbol
  // through a null module.
  if (!moduleOp)
    return referenceOfOp.emitOpError("must appear inside 'spv.module'");

  auto specConstOp =
      moduleOp.lookupSymbol<spirv::SpecConstantOp>(referenceOfOp.spec_const());
  if (!specConstOp) {
    return referenceOfOp.emitOpError("expected spv.specConstant symbol");
  }
  if (referenceOfOp.reference().getType() !=
      specConstOp.default_value().getType()) {
    return referenceOfOp.emitOpError("result type mismatch with the referenced "
                                     "specialization constant's type");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// spv.Return
//===----------------------------------------------------------------------===//

/// Verifies spv.Return: the enclosing function must not declare any results.
static LogicalResult verify(spirv::ReturnOp returnOp) {
  auto enclosingFn = returnOp.getParentOfType<FuncOp>();
  auto resultCount = enclosingFn.getType().getNumResults();
  if (resultCount == 0)
    return success();

  // Pluralize "value" when the function returns more than one result.
  const char *pluralSuffix = resultCount > 1 ? "s" : "";
  return returnOp.emitOpError("cannot be used in functions returning value")
         << pluralSuffix;
}

//===----------------------------------------------------------------------===//
// spv.ReturnValue
//===----------------------------------------------------------------------===//

/// Parses the custom form of spv.ReturnValue:
///
///   <operand> `:` <type>
static ParseResult parseReturnValueOp(OpAsmParser &parser,
                                      OperationState &state) {
  OpAsmParser::OperandType returned;
  if (parser.parseOperand(returned))
    return failure();

  Type returnedType;
  if (parser.parseColonType(returnedType))
    return failure();

  if (parser.resolveOperand(returned, returnedType, state.operands))
    return failure();
  return success();
}

/// Prints spv.ReturnValue: op name, the returned value, and its type.
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
  auto returned = retValOp.value();
  printer << spirv::ReturnValueOp::getOperationName() << ' ' << returned;
  printer << " : " << returned.getType();
}

/// Verifies spv.ReturnValue: the enclosing function must declare exactly one
/// result, and that result's type must equal the type of the returned value.
static LogicalResult verify(spirv::ReturnValueOp retValOp) {
  auto enclosingFn = retValOp.getParentOfType<FuncOp>();
  auto fnType = enclosingFn.getType();

  auto resultCount = fnType.getNumResults();
  if (resultCount != 1)
    return retValOp.emitOpError(
               "returns 1 value but enclosing function requires ")
           << resultCount << " results";

  auto returnedType = retValOp.value().getType();
  auto declaredType = fnType.getResult(0);
  if (returnedType == declaredType)
    return success();

  return retValOp.emitOpError(" return value's type (")
         << returnedType << ") mismatch with function's result type ("
         << declaredType << ")";
}

//===----------------------------------------------------------------------===//
// spv.Select
//===----------------------------------------------------------------------===//

/// Builds a spv.Select whose result type is the type of `trueValue`.
void spirv::SelectOp::build(Builder *builder, OperationState &state, Value cond,
                            Value trueValue, Value falseValue) {
  Type resultType = trueValue.getType();
  build(builder, state, resultType, cond, trueValue, falseValue);
}

/// Parses the custom form of spv.Select:
///
///   <condition> `,` <true-value> `,` <false-value>
///   `:` <condition-type> `,` <object-type>
///
/// Both value operands are resolved with the object type, which is also the
/// result type.
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
  // Diagnostics about the type list are reported at the start of the operands.
  auto startLoc = parser.getCurrentLocation();

  OpAsmParser::OperandType condition;
  if (parser.parseOperand(condition) || parser.parseComma())
    return failure();

  SmallVector<OpAsmParser::OperandType, 2> objects;
  if (parser.parseOperandList(objects, 2))
    return failure();

  SmallVector<Type, 2> trailingTypes;
  if (parser.parseColonTypeList(trailingTypes))
    return failure();

  if (trailingTypes.size() != 2)
    return parser.emitError(
        startLoc,
        "need exactly two trailing types for select condition and object");

  Type conditionType = trailingTypes[0];
  Type objectType = trailingTypes[1];
  if (parser.resolveOperand(condition, conditionType, state.operands))
    return failure();
  if (parser.resolveOperands(objects, objectType, state.operands))
    return failure();

  return parser.addTypesToList(objectType, state.types);
}

// Prints `spv.Select` in its custom assembly form:
//   <op-name> <operands> : <condition-type>, <result-type>
static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
  printer << spirv::SelectOp::getOperationName() << " " << op.getOperands();
  printer << " : " << op.condition().getType();
  printer << ", " << op.result().getType();
}

// Verifies `spv.Select`:
//  * both value operands must have the result type;
//  * when the condition is a vector, the result must be a vector with the
//    same number of elements as the condition.
static LogicalResult verify(spirv::SelectOp op) {
  auto resultType = op.result().getType();
  if (resultType != op.true_value().getType())
    return op.emitOpError("result type and true value type must be the same");
  if (resultType != op.false_value().getType())
    return op.emitOpError("result type and false value type must be the same");

  // A non-vector condition imposes no further constraints here.
  auto conditionVecType = op.condition().getType().dyn_cast<VectorType>();
  if (!conditionVecType)
    return success();

  auto resultVecType = resultType.dyn_cast<VectorType>();
  if (!resultVecType)
    return op.emitOpError("result expected to be of vector type when "
                          "condition is of vector type");

  if (conditionVecType.getNumElements() != resultVecType.getNumElements())
    return op.emitOpError("result should have the same number of elements as "
                          "the condition when condition is of vector type");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//

// Parses `spv.selection`, which in its custom form is just the op name
// followed by a region. The `selection_control` attribute is always set to
// `None` because the syntax does not carry it yet.
static ParseResult parseSelectionOp(OpAsmParser &parser,
                                    OperationState &state) {
  // TODO(antiagainst): support selection control properly
  auto noneControl = static_cast<uint32_t>(spirv::SelectionControl::None);
  state.addAttribute("selection_control",
                     parser.getBuilder().getI32IntegerAttr(noneControl));

  // The region takes no entry block arguments.
  Region *bodyRegion = state.addRegion();
  return parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{});
}

// Prints `spv.selection` as the op name followed by its region. Entry block
// arguments are omitted and block terminators are printed.
static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
  printer << spirv::SelectionOp::getOperationName();

  Region &bodyRegion = selectionOp.getOperation()->getRegion(0);
  printer.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}

// Verifies the block structure of `spv.selection`. An empty region is
// accepted; otherwise the last block must be a merge block and there must be
// at least one more block in front of it to act as the header.
static LogicalResult verify(spirv::SelectionOp selectionOp) {
  // The blocks are expected to be laid out as follows:
  //
  //                     +--------------+
  //                     | header block |
  //                     +--------------+
  //                          / | \
  //                           ...
  //
  //
  //         +---------+   +---------+   +---------+
  //         | case #0 |   | case #1 |   | case #2 |  ...
  //         +---------+   +---------+   +---------+
  //
  //
  //                           ...
  //                          \ | /
  //                            v
  //                     +-------------+
  //                     | merge block |
  //                     +-------------+

  Region &bodyRegion = selectionOp.getOperation()->getRegion(0);

  // An empty region is a degenerate case that optimizations may produce;
  // accept it.
  if (bodyRegion.empty())
    return success();

  // The merge block is always the last block of the region.
  if (!isMergeBlock(bodyRegion.back())) {
    return selectionOp.emitOpError(
        "last block must be the merge block with only one 'spv._merge' op");
  }

  // With only a single block there is nothing left to serve as the header.
  const bool hasSingleBlock = std::next(bodyRegion.begin()) == bodyRegion.end();
  if (hasSingleBlock)
    return selectionOp.emitOpError("must have a selection header block");

  return success();
}

// Returns the selection header block, i.e. the first block of the region.
// The region must not be empty.
Block *spirv::SelectionOp::getHeaderBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  return &region.front();
}

// Returns the selection merge block, i.e. the last block of the region.
// The region must not be empty.
Block *spirv::SelectionOp::getMergeBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  return &region.back();
}

void spirv::SelectionOp::addMergeBlock() {
  assert(body().empty() && "entry and merge block already exist");
  auto *mergeBlock = new Block();
  body().push_back(mergeBlock);
  OpBuilder builder(mergeBlock);

  // Add a spv._merge op into the merge block.
  builder.create<spirv::MergeOp>(getLoc());
}

namespace {
// Blocks from the given `spv.selection` operation must satisfy the following
// layout:
//
//       +-----------------------------------------------+
//       | header block                                  |
//       | spv.BranchConditionalOp %cond, ^case0, ^case1 |
//       +-----------------------------------------------+
//                            /   \
//                             ...
//
//
//   +------------------------+    +------------------------+
//   | case #0                |    | case #1                |
//   | spv.Store %ptr %value0 |    | spv.Store %ptr %value1 |
//   | spv.Branch ^merge      |    | spv.Branch ^merge      |
//   +------------------------+    +------------------------+
//
//
//                             ...
//                            \   /
//                              v
//                       +-------------+
//                       | merge block |
//                       +-------------+
//
struct ConvertSelectionOpToSelect
    : public OpRewritePattern<spirv::SelectionOp> {
  using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;

  // Replaces a `spv.selection` whose two branches only store to the same
  // pointer by a `spv.Select` of the two stored values followed by a single
  // `spv.Store`. Fails to match when the region does not have the expected
  // four-block shape.
  PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
                                     PatternRewriter &rewriter) const override {
    Operation *selectionOperation = selectionOp.getOperation();
    Region &region = selectionOperation->getRegion(0);

    // The verifier accepts an empty region; there is nothing to rewrite then.
    if (region.empty())
      return matchFailure();

    // Exactly four blocks are expected: header, `true` block, `false` block
    // and merge block.
    if (std::distance(region.begin(), region.end()) != 4)
      return matchFailure();

    Block *header = selectionOp.getHeaderBlock();
    if (!onlyContainsBranchConditionalOp(header))
      return matchFailure();
    auto condBranchOp = cast<spirv::BranchConditionalOp>(header->front());

    Block *thenBlock = condBranchOp.getSuccessor(0);
    Block *elseBlock = condBranchOp.getSuccessor(1);
    Block *merge = selectionOp.getMergeBlock();
    if (!canCanonicalizeSelection(thenBlock, elseBlock, merge))
      return matchFailure();

    // Both branches were checked to start with a store to the same pointer
    // with identical attributes, so the `true` branch's store supplies the
    // destination and the attributes for the new store.
    auto thenValue = getSrcValue(thenBlock);
    auto elseValue = getSrcValue(elseBlock);
    auto destPtr = getDstPtr(thenBlock);
    auto storeAttrs =
        cast<spirv::StoreOp>(thenBlock->front()).getOperation()->getAttrs();

    auto select = rewriter.create<spirv::SelectOp>(
        selectionOp.getLoc(), thenValue.getType(), condBranchOp.condition(),
        thenValue, elseValue);
    rewriter.create<spirv::StoreOp>(select.getLoc(), destPtr,
                                    select.getResult(), storeAttrs);

    // `spv.selection` is not needed anymore.
    rewriter.eraseOp(selectionOperation);
    return matchSuccess();
  }

private:
  // Checks that the given blocks follow these rules:
  // 1. Each conditional block consists of two operations, the first operation
  //    is a `spv.Store` and the last operation is a `spv.Branch`.
  // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
  // 3. Control flow goes into the given merge block from the given
  //    conditional blocks.
  PatternMatchResult canCanonicalizeSelection(Block *trueBlock,
                                              Block *falseBlock,
                                              Block *mergeBlock) const;

  // Returns true if `block` holds exactly one operation and that operation is
  // a `spv.BranchConditional`.
  bool onlyContainsBranchConditionalOp(Block *block) const {
    if (std::next(block->begin()) != block->end())
      return false;
    return isa<spirv::BranchConditionalOp>(block->front());
  }

  // Returns true if both stores carry the same attribute dictionary.
  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
    auto lhsAttrs = lhs.getOperation()->getAttrList().getDictionary();
    auto rhsAttrs = rhs.getOperation()->getAttrList().getDictionary();
    return lhsAttrs == rhsAttrs;
  }

  // Checks that the given type is valid for `spv.SelectOp`.
  // According to the SPIR-V spec:
  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
  // Starting with version 1.4, Result Type can additionally be a composite type
  // other than a vector."
  // Only scalar and vector types are accepted here.
  bool isValidType(Type type) const {
    if (spirv::SPIRVDialect::isValidScalarType(type))
      return true;
    return type.isa<VectorType>();
  }

  // Returns the value stored by the `spv.Store` at the front of `block`.
  // The first operation of `block` must be a `spv.Store`.
  Value getSrcValue(Block *block) const {
    return cast<spirv::StoreOp>(block->front()).value();
  }

  // Returns the pointer written by the `spv.Store` at the front of `block`.
  // The first operation of `block` must be a `spv.Store`.
  Value getDstPtr(Block *block) const {
    return cast<spirv::StoreOp>(block->front()).ptr();
  }
};

// Returns matchSuccess() when both conditional blocks consist of a `spv.Store`
// followed by a `spv.Branch` to `mergeBlock`, the two stores write through the
// same pointer with identical attributes, and the stored value has a type
// accepted by isValidType(). Returns matchFailure() otherwise.
PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
    Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
  // Each conditional block must consist of exactly two operations.
  if (std::distance(trueBlock->begin(), trueBlock->end()) != 2)
    return matchFailure();
  if (std::distance(falseBlock->begin(), falseBlock->end()) != 2)
    return matchFailure();

  // The first operation must be a store and the second one a branch.
  auto trueStore = dyn_cast<spirv::StoreOp>(trueBlock->front());
  auto trueBranch = dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
  auto falseStore = dyn_cast<spirv::StoreOp>(falseBlock->front());
  auto falseBranch =
      dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
  const bool hasExpectedOps =
      trueStore && trueBranch && falseStore && falseBranch;
  if (!hasExpectedOps)
    return matchFailure();

  // Both stores must target the same pointer with the same attributes, and
  // the stored value must have a type usable by `spv.Select`.
  if (trueStore.ptr() != falseStore.ptr())
    return matchFailure();
  if (!isSameAttrList(trueStore, falseStore))
    return matchFailure();
  if (!isValidType(trueStore.value().getType()))
    return matchFailure();

  // Both branches must jump to the merge block.
  if (trueBranch.getOperation()->getSuccessor(0) != mergeBlock)
    return matchFailure();
  if (falseBranch.getOperation()->getSuccessor(0) != mergeBlock)
    return matchFailure();

  return matchSuccess();
}
} // end anonymous namespace

// Registers the canonicalization patterns for `spv.selection` into `results`.
// Currently this is only ConvertSelectionOpToSelect.
void spirv::SelectionOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertSelectionOpToSelect>(context);
}

//===----------------------------------------------------------------------===//
// spv.specConstant
//===----------------------------------------------------------------------===//

// Parses `spv.specConstant` from its custom assembly form:
//   <symbol-name> [ spec_id ( <integer-attr> ) ] = <default-value-attr>
static ParseResult parseSpecConstantOp(OpAsmParser &parser,
                                       OperationState &state) {
  // The symbol name is stored under the standard symbol attribute name.
  StringAttr symbolName;
  if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // The `spec_id(...)` clause is optional.
  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
    IntegerAttr specId;
    if (parser.parseLParen())
      return failure();
    if (parser.parseAttribute(specId, kSpecIdAttrName, state.attributes))
      return failure();
    if (parser.parseRParen())
      return failure();
  }

  // The default value follows the `=` sign.
  if (parser.parseEqual())
    return failure();
  Attribute defaultValue;
  if (parser.parseAttribute(defaultValue, kDefaultValueAttrName,
                            state.attributes))
    return failure();

  return success();
}

// Prints `spv.specConstant` in its custom assembly form:
//   <op-name> <symbol-name> [spec_id(<id>)] = <default-value>
static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
  printer << spirv::SpecConstantOp::getOperationName() << ' ';
  printer.printSymbolName(constOp.sym_name());

  // The spec id clause is only emitted when the attribute is present.
  auto specIdAttr = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specIdAttr)
    printer << ' ' << kSpecIdAttrName << '(' << specIdAttr.getInt() << ')';

  printer << " = " << constOp.default_value();
}

// Verifies `spv.specConstant`: an optional spec id must not be negative, and
// the default value must be a bool, integer or float attribute whose type is
// accepted by the SPIR-V dialect.
static LogicalResult verify(spirv::SpecConstantOp constOp) {
  auto specIdAttr = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specIdAttr && specIdAttr.getValue().isNegative())
    return constOp.emitOpError("SpecId cannot be negative");

  auto defaultValue = constOp.default_value();

  // Reject every attribute kind other than the three scalar ones.
  switch (defaultValue.getKind()) {
  default:
    return constOp.emitOpError(
        "default value can only be a bool, integer, or float scalar");
  case StandardAttributes::Bool:
  case StandardAttributes::Integer:
  case StandardAttributes::Float:
    break;
  }

  // Make sure bitwidth is allowed.
  if (!spirv::SPIRVDialect::isValidType(defaultValue.getType()))
    return constOp.emitOpError("default value bitwidth disallowed");

  return success();
}

//===----------------------------------------------------------------------===//
// spv.StoreOp
//===----------------------------------------------------------------------===//

// Parses `spv.Store`: a storage class, two operands (pointer and value),
// optional memory access attributes, and a trailing `: <value-type>`. The
// pointer operand's type is rebuilt from the value type and the storage class.
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) {
  // Location used when resolving the operands below.
  auto startLoc = parser.getCurrentLocation();

  // Storage class specification comes first.
  spirv::StorageClass storageClass;
  if (parseEnumAttribute(storageClass, parser))
    return failure();

  // Then the pointer and value operands, followed by the memory access
  // attributes.
  SmallVector<OpAsmParser::OperandType, 2> ptrAndValue;
  if (parser.parseOperandList(ptrAndValue, 2) ||
      parseMemoryAccessAttributes(parser, state))
    return failure();

  // Finally the type of the stored value.
  Type valueType;
  if (parser.parseColon() || parser.parseType(valueType))
    return failure();

  auto pointerType = spirv::PointerType::get(valueType, storageClass);
  if (parser.resolveOperands(ptrAndValue, {pointerType, valueType}, startLoc,
                             state.operands))
    return failure();

  return success();
}

// Prints `spv.Store` in its custom assembly form:
//   <op-name> "<storage-class>" <ptr>, <value> <memory-access> : <value-type>
// followed by any attributes not already covered by the memory access part.
static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
  // The storage class is taken from the pointer operand's type.
  auto pointerType = storeOp.ptr().getType().cast<spirv::PointerType>();
  StringRef storageClassName =
      stringifyStorageClass(pointerType.getStorageClass());

  printer << spirv::StoreOp::getOperationName() << " \"" << storageClassName
          << "\" " << storeOp.ptr() << ", " << storeOp.value();

  // Names of the attributes handled by the memory access printer are
  // collected so they can be skipped in the attribute dictionary below.
  SmallVector<StringRef, 4> alreadyPrinted;
  printMemoryAccessAttribute(storeOp, printer, alreadyPrinted);

  printer << " : " << storeOp.value().getType();
  printer.printOptionalAttrDict(storeOp.getOperation()->getAttrs(),
                                alreadyPrinted);
}

// Verifies `spv.Store`: the pointer/value types must be compatible and the
// memory access attributes must be well-formed.
static LogicalResult verify(spirv::StoreOp storeOp) {
  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
  // OpTypePointer whose Type operand is the same as the type of Object."
  if (succeeded(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
                                              storeOp.value())))
    return verifyMemoryAccessAttribute(storeOp);

  return failure();
}

//===----------------------------------------------------------------------===//
// spv.SubgroupBallotKHROp
//===----------------------------------------------------------------------===//

// Parses `spv.SubgroupBallotKHR` from its custom assembly form:
//   <predicate> : <result-type>
// The predicate operand is always resolved as `i1`.
static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser,
                                            OperationState &state) {
  IntegerType predicateType = parser.getBuilder().getI1Type();

  OpAsmParser::OperandType predicateInfo;
  if (parser.parseOperand(predicateInfo))
    return failure();

  Type ballotType;
  if (parser.parseColonType(ballotType))
    return failure();

  if (parser.resolveOperand(predicateInfo, predicateType, state.operands))
    return failure();

  return parser.addTypeToList(ballotType, state.types);
}

// Prints `spv.SubgroupBallotKHR` in its custom assembly form:
//   <op-name> <predicate> : <result-type>
static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) {
  printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ';
  printer << ballotOp.predicate();
  printer << " : " << ballotOp.getType();
}

//===----------------------------------------------------------------------===//
// spv.Undef
//===----------------------------------------------------------------------===//

// Parses `spv.Undef`, whose custom form is only `: <result-type>`.
static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &state) {
  Type resultType;
  if (parser.parseColonType(resultType))
    return failure();

  state.addTypes(resultType);
  return success();
}

// Prints `spv.Undef` in its custom assembly form: <op-name> : <result-type>
static void print(spirv::UndefOp undefOp, OpAsmPrinter &printer) {
  printer << spirv::UndefOp::getOperationName();
  printer << " : " << undefOp.getType();
}

//===----------------------------------------------------------------------===//
// spv.Unreachable
//===----------------------------------------------------------------------===//

// Verifies `spv.Unreachable`. Only the cheap checks are implemented: the op
// is rejected when it sits in an entry block; every other placement is
// currently accepted.
static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
  Block *parentBlock = unreachableOp.getOperation()->getBlock();

  // Fast track: an entry block is always reachable, so the op is invalid
  // there.
  if (parentBlock->isEntryBlock())
    return unreachableOp.emitOpError("cannot be used in reachable block");

  // A block without predecessors cannot be reached, so the op is valid.
  if (parentBlock->hasNoPredecessors())
    return success();

  // TODO(antiagainst): further verification needs to analyze reachability from
  // the entry block.

  return success();
}

//===----------------------------------------------------------------------===//
// spv.Variable
//===----------------------------------------------------------------------===//

// Parses `spv.Variable` from its custom assembly form:
//   [ init ( <operand> ) ] <decorations> : <pointer-type>
// The storage class attribute is derived from the parsed pointer type, and
// the optional initializer is resolved against the pointee type.
static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
  // Optional `init(<operand>)` clause.
  Optional<OpAsmParser::OperandType> initOperand;
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    initOperand = OpAsmParser::OperandType();
    if (parser.parseLParen() || parser.parseOperand(*initOperand) ||
        parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, state))
    return failure();

  // Result pointer type; remember where it starts for the diagnostic below.
  if (parser.parseColon())
    return failure();
  auto typeLoc = parser.getCurrentLocation();
  Type parsedType;
  if (parser.parseType(parsedType))
    return failure();

  auto ptrType = parsedType.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return parser.emitError(typeLoc, "expected spv.ptr type");
  state.addTypes(ptrType);

  // The initializer, if any, must have the pointee type.
  if (initOperand && parser.resolveOperand(*initOperand,
                                           ptrType.getPointeeType(),
                                           state.operands))
    return failure();

  // Record the pointer's storage class as an attribute on the op.
  auto storageClassBits = llvm::bit_cast<int32_t>(ptrType.getStorageClass());
  auto attr = parser.getBuilder().getI32IntegerAttr(storageClassBits);
  state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);

  return success();
}

// Prints `spv.Variable` in its custom assembly form:
//   <op-name> [init(<initializer>)] <decorations> : <pointer-type>
// The storage class attribute is elided since it is implied by the type.
static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
  printer << spirv::VariableOp::getOperationName();

  // The initializer is the op's only possible operand.
  const bool hasInitializer = varOp.getNumOperands() != 0;
  if (hasInitializer)
    printer << " init(" << varOp.initializer() << ")";

  SmallVector<StringRef, 4> elidedAttrs{
      spirv::attributeName<spirv::StorageClass>()};
  printVariableDecorations(varOp, printer, elidedAttrs);

  printer << " : " << varOp.getType();
}

// Verifies `spv.Variable`:
//  * the storage class must be Function and must match the result pointer's
//    storage class;
//  * an initializer, if present, must be produced by a constant, reference-of
//    or address-of op;
//  * descriptor set, binding and built-in decorations are rejected.
static LogicalResult verify(spirv::VariableOp varOp) {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  auto storageClass = varOp.storage_class();
  if (storageClass != spirv::StorageClass::Function)
    return varOp.emitOpError(
        "can only be used to model function-level variables. Use "
        "spv.globalVariable for module-level variables.");

  auto resultPtrType = varOp.pointer().getType().cast<spirv::PointerType>();
  if (storageClass != resultPtrType.getStorageClass())
    return varOp.emitOpError(
        "storage class must match result pointer's storage class");

  if (varOp.getNumOperands() != 0) {
    // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
    // a global (module scope) OpVariable instruction".
    Operation *initOp = varOp.getOperand(0).getDefiningOp();
    const bool hasAllowedProducer =
        initOp && (isa<spirv::ConstantOp>(initOp) ||    // for normal constant
                   isa<spirv::ReferenceOfOp>(initOp) || // for spec constant
                   isa<spirv::AddressOfOp>(initOp));
    if (!hasAllowedProducer)
      return varOp.emitOpError("initializer must be the result of a "
                               "constant or spv.globalVariable op");
  }

  // TODO(antiagainst): generate these strings using ODS.
  Operation *op = varOp.getOperation();
  for (auto decoration :
       {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
        spirv::Decoration::BuiltIn}) {
    // The attribute name is the snake_case spelling of the decoration.
    auto attrName = convertToSnakeCase(stringifyDecoration(decoration));
    if (op->getAttr(attrName))
      return varOp.emitOpError("cannot have '")
             << attrName << "' attribute (only allowed in spv.globalVariable)";
  }

  return success();
}

namespace mlir {
namespace spirv {

// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc"

// TableGen'erated operation definitions.
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"

// TableGen'erated operation availability interface implementations.
#include "mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc"

} // namespace spirv
} // namespace mlir