//===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the MLIR SPIR-V module to SPIR-V binary serialization. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/ADT/TypeSwitch.h" #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/RegionGraphTraits.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StringExtras.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "spirv-serialization" using namespace mlir; /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into /// the given `binary` vector. static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op, ArrayRef<uint32_t> operands) { uint32_t wordCount = 1 + operands.size(); binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); binary.append(operands.begin(), operands.end()); return success(); } /// A pre-order depth-first visitor function for processing basic blocks. /// /// Visits the basic blocks starting from the given `headerBlock` in pre-order /// depth-first manner and calls `blockHandler` on each block. Skips handling /// blocks in the `skipBlocks` list. 
If `skipHeader` is true, `blockHandler` /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s /// successors. /// /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order /// of blocks in a function must satisfy the rule that blocks appear before /// all blocks they dominate." This can be achieved by a pre-order CFG /// traversal algorithm. To make the serialization output more logical and /// readable to human, we perform depth-first CFG traversal and delay the /// serialization of the merge block and the continue block, if exists, until /// after all other blocks have been processed. static LogicalResult visitInPrettyBlockOrder( Block *headerBlock, function_ref<LogicalResult(Block *)> blockHandler, bool skipHeader = false, ArrayRef<Block *> skipBlocks = {}) { llvm::df_iterator_default_set<Block *, 4> doneBlocks; doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) { if (skipHeader && block == headerBlock) continue; if (failed(blockHandler(block))) return failure(); } return success(); } /// Returns the last structured control flow op's merge block if the given /// `block` contains any structured control flow op. Otherwise returns nullptr. static Block *getLastStructuredControlFlowOpMergeBlock(Block *block) { for (Operation &op : llvm::reverse(block->getOperations())) { if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) return selectionOp.getMergeBlock(); if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) return loopOp.getMergeBlock(); } return nullptr; } namespace { /// A SPIR-V module serializer. /// /// A SPIR-V binary module is a single linear stream of instructions; each /// instruction is composed of 32-bit words with the layout: /// /// | <word-count>|<opcode> | <operand> | <operand> | ... | /// | <------ word -------> | <-- word --> | <-- word --> | ... 
| /// /// For the first word, the 16 high-order bits are the word count of the /// instruction, the 16 low-order bits are the opcode enumerant. The /// instructions then belong to different sections, which must be laid out in /// the particular order as specified in "2.4 Logical Layout of a Module" of /// the SPIR-V spec. class Serializer { public: /// Creates a serializer for the given SPIR-V `module`. explicit Serializer(spirv::ModuleOp module); /// Serializes the remembered SPIR-V module. LogicalResult serialize(); /// Collects the final SPIR-V `binary`. void collect(SmallVectorImpl<uint32_t> &binary); /// (For debugging) prints each value and its corresponding result <id>. void printValueIDMap(raw_ostream &os); private: // Note that there are two main categories of methods in this class: // * process*() methods are meant to fully serialize a SPIR-V module entity // (header, type, op, etc.). They update internal vectors containing // different binary sections. They are not meant to be called except the // top-level serialization loop. // * prepare*() methods are meant to be helpers that prepare for serializing // certain entity. They may or may not update internal vectors containing // different binary sections. They are meant to be called among themselves // or by other process*() methods for subtasks. //===--------------------------------------------------------------------===// // <id> //===--------------------------------------------------------------------===// // Note that it is illegal to use id <0> in SPIR-V binary module. Various // methods in this class, if using SPIR-V word (uint32_t) as interface, // check or return id <0> to indicate error in processing. /// Consumes the next unused <id>. This method will never return 0. 
uint32_t getNextID() { return nextID++; } //===--------------------------------------------------------------------===// // Module structure //===--------------------------------------------------------------------===// uint32_t getSpecConstID(StringRef constName) const { return specConstIDMap.lookup(constName); } uint32_t getVariableID(StringRef varName) const { return globalVarIDMap.lookup(varName); } uint32_t getFunctionID(StringRef fnName) const { return funcIDMap.lookup(fnName); } /// Gets the <id> for the function with the given name. Assigns the next /// available <id> if the function haven't been deserialized. uint32_t getOrCreateFunctionID(StringRef fnName); void processCapability(); void processExtension(); void processMemoryModel(); LogicalResult processConstantOp(spirv::ConstantOp op); LogicalResult processSpecConstantOp(spirv::SpecConstantOp op); /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA /// value to use with other operations. The SPIR-V spec recommends that /// OpUndef be generated at module level. The serialization generates an /// OpUndef for each type needed at module level. LogicalResult processUndefOp(spirv::UndefOp op); /// Emit OpName for the given `resultID`. LogicalResult processName(uint32_t resultID, StringRef name); /// Processes a SPIR-V function op. 
LogicalResult processFuncOp(FuncOp op); LogicalResult processVariableOp(spirv::VariableOp op); /// Process a SPIR-V GlobalVariableOp LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); /// Process attributes that translate to decorations on the result <id> LogicalResult processDecoration(Location loc, uint32_t resultID, NamedAttribute attr); template <typename DType> LogicalResult processTypeDecoration(Location loc, DType type, uint32_t resultId) { return emitError(loc, "unhandled decoration for type:") << type; } /// Process member decoration LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberIndex, spirv::Decoration decorationType, ArrayRef<uint32_t> values = {}); //===--------------------------------------------------------------------===// // Types //===--------------------------------------------------------------------===// uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); } Type getVoidType() { return mlirBuilder.getNoneType(); } bool isVoidType(Type type) const { return type.isa<NoneType>(); } /// Returns true if the given type is a pointer type to a struct in Uniform or /// StorageBuffer storage class. bool isInterfaceStructPtrType(Type type) const; /// Main dispatch method for serializing a type. The result <id> of the /// serialized type will be returned as `typeID`. LogicalResult processType(Location loc, Type type, uint32_t &typeID); /// Method for preparing basic SPIR-V type serialization. Returns the type's /// opcode and operands for the instruction via `typeEnum` and `operands`. 
LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, SmallVectorImpl<uint32_t> &operands); LogicalResult prepareFunctionType(Location loc, FunctionType type, spirv::Opcode &typeEnum, SmallVectorImpl<uint32_t> &operands); //===--------------------------------------------------------------------===// // Constant //===--------------------------------------------------------------------===// uint32_t getConstantID(Attribute value) const { return constIDMap.lookup(value); } /// Main dispatch method for processing a constant with the given `constType` /// and `valueAttr`. `constType` is needed here because we can interpret the /// `valueAttr` as a different type than the type of `valueAttr` itself; for /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType /// constants. uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr); /// Prepares array attribute serialization. This method emits corresponding /// OpConstant* and returns the result <id> associated with it. Returns 0 if /// failed. uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr); /// Prepares bool/int/float DenseElementsAttr serialization. This method /// iterates the DenseElementsAttr to construct the constant array, and /// returns the result <id> associated with it. Returns 0 if failed. Note /// that the size of `index` must match the rank. /// TODO(hanchung): Consider to enhance splat elements cases. For splat cases, /// we don't need to loop over all elements, especially when the splat value /// is zero. We can use OpConstantNull when the value is zero. uint32_t prepareDenseElementsConstant(Location loc, Type constType, DenseElementsAttr valueAttr, int dim, MutableArrayRef<uint64_t> index); /// Prepares scalar attribute serialization. This method emits corresponding /// OpConstant* and returns the result <id> associated with it. 
Returns 0 if /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is /// true, then the constant will be serialized as a specialization constant. uint32_t prepareConstantScalar(Location loc, Attribute valueAttr, bool isSpec = false); uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec = false); uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec = false); uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec = false); //===--------------------------------------------------------------------===// // Control flow //===--------------------------------------------------------------------===// /// Returns the result <id> for the given block. uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); } /// Returns the result <id> for the given block. If no <id> has been assigned, /// assigns the next available <id> uint32_t getOrCreateBlockID(Block *block); /// Processes the given `block` and emits SPIR-V instructions for all ops /// inside. Does not emit OpLabel for this block if `omitLabel` is true. /// `actionBeforeTerminator` is a callback that will be invoked before /// handling the terminator op. It can be used to inject the Op*Merge /// instruction if this is a SPIR-V selection/loop header block. LogicalResult processBlock(Block *block, bool omitLabel = false, function_ref<void()> actionBeforeTerminator = nullptr); /// Emits OpPhi instructions for the given block if it has block arguments. 
LogicalResult emitPhiForBlockArguments(Block *block); LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); LogicalResult processLoopOp(spirv::LoopOp loopOp); LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); LogicalResult processBranchOp(spirv::BranchOp branchOp); //===--------------------------------------------------------------------===// // Operations //===--------------------------------------------------------------------===// LogicalResult encodeExtensionInstruction(Operation *op, StringRef extensionSetName, uint32_t opcode, ArrayRef<uint32_t> operands); uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); } LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp); /// Main dispatch method for serializing an operation. LogicalResult processOperation(Operation *op); /// Method to dispatch to the serialization function for an operation in /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec. /// This is auto-generated from ODS. Dispatch is handled for all operations /// in SPIR-V dialect that have hasOpcode == 1. LogicalResult dispatchToAutogenSerialization(Operation *op); /// Method to serialize an operation in the SPIR-V dialect that is a mirror of /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode == /// 1 and autogenSerialization == 1 in ODS. template <typename OpTy> LogicalResult processOp(OpTy op) { return op.emitError("unsupported op serialization"); } //===--------------------------------------------------------------------===// // Utilities //===--------------------------------------------------------------------===// /// Emits an OpDecorate instruction to decorate the given `target` with the /// given `decoration`. LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration, ArrayRef<uint32_t> params = {}); private: /// The SPIR-V module to be serialized. 
spirv::ModuleOp module; /// An MLIR builder for getting MLIR constructs. mlir::Builder mlirBuilder; /// The next available result <id>. uint32_t nextID = 1; // The following are for different SPIR-V instruction sections. They follow // the logical layout of a SPIR-V module. SmallVector<uint32_t, 4> capabilities; SmallVector<uint32_t, 0> extensions; SmallVector<uint32_t, 0> extendedSets; SmallVector<uint32_t, 3> memoryModel; SmallVector<uint32_t, 0> entryPoints; SmallVector<uint32_t, 4> executionModes; // TODO(antiagainst): debug instructions SmallVector<uint32_t, 0> names; SmallVector<uint32_t, 0> decorations; SmallVector<uint32_t, 0> typesGlobalValues; SmallVector<uint32_t, 0> functions; /// `functionHeader` contains all the instructions that must be in the first /// block in the function, and `functionBody` contains the rest. After /// processing FuncOp, the encoded instructions of a function are appended to /// `functions`. An example of instructions in `functionHeader` in order: /// OpFunction ... /// OpFunctionParameter ... /// OpFunctionParameter ... /// OpLabel ... /// OpVariable ... /// OpVariable ... SmallVector<uint32_t, 0> functionHeader; SmallVector<uint32_t, 0> functionBody; /// Map from type used in SPIR-V module to their <id>s. DenseMap<Type, uint32_t> typeIDMap; /// Map from constant values to their <id>s. DenseMap<Attribute, uint32_t> constIDMap; /// Map from specialization constant names to their <id>s. llvm::StringMap<uint32_t> specConstIDMap; /// Map from GlobalVariableOps name to <id>s. llvm::StringMap<uint32_t> globalVarIDMap; /// Map from FuncOps name to <id>s. llvm::StringMap<uint32_t> funcIDMap; /// Map from blocks to their <id>s. DenseMap<Block *, uint32_t> blockIDMap; /// Map from the Type to the <id> that represents undef value of that type. DenseMap<Type, uint32_t> undefValIDMap; /// Map from results of normal operations to their <id>s. DenseMap<Value, uint32_t> valueIDMap; /// Map from extended instruction set name to <id>s. 
llvm::StringMap<uint32_t> extendedInstSetIDMap; /// Map from values used in OpPhi instructions to their offset in the /// `functions` section. /// /// When processing a block with arguments, we need to emit OpPhi /// instructions to record the predecessor block <id>s and the values they /// send to the block in question. But it's not guaranteed all values are /// visited and thus assigned result <id>s. So we need this list to capture /// the offsets into `functions` where a value is used so that we can fix it /// up later after processing all the blocks in a function. /// /// More concretely, say if we are visiting the following blocks: /// /// ```mlir /// ^phi(%arg0: i32): /// ... /// ^parent1: /// ... /// spv.Branch ^phi(%val0: i32) /// ^parent2: /// ... /// spv.Branch ^phi(%val1: i32) /// ``` /// /// When we are serializing the `^phi` block, we need to emit at the beginning /// of the block OpPhi instructions which has the following parameters: /// /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1 /// id-for-%val1 id-for-^parent2 /// /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit /// all the blocks twice and use the first visit to assign an <id> to each /// value. But it's paying the overheads just for OpPhi emission. Instead, /// we still visit the blocks once for emission. When we emit the OpPhi /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1. /// At the same time, we record their offsets in the emitted binary (which is /// placed inside `functions`) here. And then after emitting all blocks, we /// replace the dummy <id> 0 with the real result <id> by overwriting /// `functions[offset]`. 
DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues; }; } // namespace Serializer::Serializer(spirv::ModuleOp module) : module(module), mlirBuilder(module.getContext()) {} LogicalResult Serializer::serialize() { LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); if (failed(module.verify())) return failure(); // TODO(antiagainst): handle the other sections processCapability(); processExtension(); processMemoryModel(); // Iterate over the module body to serialize it. Assumptions are that there is // only one basic block in the moduleOp for (auto &op : module.getBlock()) { if (failed(processOperation(&op))) { return failure(); } } LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); return success(); } void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + extensions.size() + extendedSets.size() + memoryModel.size() + entryPoints.size() + executionModes.size() + decorations.size() + typesGlobalValues.size() + functions.size(); binary.clear(); binary.reserve(moduleSize); spirv::appendModuleHeader(binary, nextID); binary.append(capabilities.begin(), capabilities.end()); binary.append(extensions.begin(), extensions.end()); binary.append(extendedSets.begin(), extendedSets.end()); binary.append(memoryModel.begin(), memoryModel.end()); binary.append(entryPoints.begin(), entryPoints.end()); binary.append(executionModes.begin(), executionModes.end()); binary.append(names.begin(), names.end()); binary.append(decorations.begin(), decorations.end()); binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); } void Serializer::printValueIDMap(raw_ostream &os) { os << "\n= Value <id> Map =\n\n"; for (auto valueIDPair : valueIDMap) { Value val = valueIDPair.first; os << " " << val << " " << "id = " << valueIDPair.second << ' '; if (auto *op = val.getDefiningOp()) { os << "from op '" << op->getName() << "'"; } else if 
(auto arg = val.dyn_cast<BlockArgument>()) { Block *block = arg.getOwner(); os << "from argument of block " << block << ' '; os << " in op '" << block->getParentOp()->getName() << "'"; } os << '\n'; } } //===----------------------------------------------------------------------===// // Module structure //===----------------------------------------------------------------------===// uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { auto funcID = funcIDMap.lookup(fnName); if (!funcID) { funcID = getNextID(); funcIDMap[fnName] = funcID; } return funcID; } void Serializer::processCapability() { auto caps = module.getAttrOfType<ArrayAttr>("capabilities"); if (!caps) return; for (auto cap : caps.getValue()) { auto capStr = cap.cast<StringAttr>().getValue(); auto capVal = spirv::symbolizeCapability(capStr); encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, {static_cast<uint32_t>(*capVal)}); } } void Serializer::processExtension() { auto exts = module.getAttrOfType<ArrayAttr>("extensions"); if (!exts) return; SmallVector<uint32_t, 16> extName; for (auto ext : exts.getValue()) { auto extStr = ext.cast<StringAttr>().getValue(); extName.clear(); spirv::encodeStringLiteralInto(extName, extStr); encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); } } void Serializer::processMemoryModel() { uint32_t mm = module.getAttrOfType<IntegerAttr>("memory_model").getInt(); uint32_t am = module.getAttrOfType<IntegerAttr>("addressing_model").getInt(); encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); } LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { valueIDMap[op.getResult()] = resultID; return success(); } return failure(); } LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), /*isSpec=*/true)) { // Emit the OpDecorate 
instruction for SpecId. if (auto specID = op.getAttrOfType<IntegerAttr>("spec_id")) { auto val = static_cast<uint32_t>(specID.getInt()); emitDecoration(resultID, spirv::Decoration::SpecId, {val}); } specConstIDMap[op.sym_name()] = resultID; return processName(resultID, op.sym_name()); } return failure(); } LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { auto undefType = op.getType(); auto &id = undefValIDMap[undefType]; if (!id) { id = getNextID(); uint32_t typeID = 0; if (failed(processType(op.getLoc(), undefType, typeID)) || failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, {typeID, id}))) { return failure(); } } valueIDMap[op.getResult()] = id; return success(); } LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, NamedAttribute attr) { auto attrName = attr.first.strref(); auto decorationName = mlir::convertToCamelCase(attrName, true); auto decoration = spirv::symbolizeDecoration(decorationName); if (!decoration) { return emitError( loc, "non-argument attributes expected to have snake-case-ified " "decoration name, unhandled attribute with name : ") << attrName; } SmallVector<uint32_t, 1> args; switch (decoration.getValue()) { case spirv::Decoration::DescriptorSet: case spirv::Decoration::Binding: if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) { args.push_back(intAttr.getValue().getZExtValue()); break; } return emitError(loc, "expected integer attribute for ") << attrName; case spirv::Decoration::BuiltIn: if (auto strAttr = attr.second.dyn_cast<StringAttr>()) { auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); if (enumVal) { args.push_back(static_cast<uint32_t>(enumVal.getValue())); break; } return emitError(loc, "invalid ") << attrName << " attribute " << strAttr.getValue(); } return emitError(loc, "expected string attribute for ") << attrName; default: return emitError(loc, "unhandled decoration ") << decorationName; } return emitDecoration(resultID, decoration.getValue(), args); 
}

/// Emits an OpName instruction associating `name` with `resultID`.
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
  assert(!name.empty() && "unexpected empty string for OpName");

  SmallVector<uint32_t, 4> nameOperands;
  nameOperands.push_back(resultID);
  if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
    return failure();
  }
  return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
}

namespace {
template <>
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
    Location loc, spirv::ArrayType type, uint32_t resultID) {
  if (type.hasLayout()) {
    // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
    return emitDecoration(resultID, spirv::Decoration::ArrayStride,
                          {static_cast<uint32_t>(type.getArrayStride())});
  }
  return success();
}

/// Emits OpMemberDecorate for member `memberIndex` of struct `structID`,
/// followed by the optional decoration literals in `values`.
LogicalResult
Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
                                    spirv::Decoration decorationType,
                                    ArrayRef<uint32_t> values) {
  SmallVector<uint32_t, 4> args(
      {structID, memberIndex, static_cast<uint32_t>(decorationType)});
  if (!values.empty()) {
    args.append(values.begin(), values.end());
  }
  return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
                               args);
}
} // namespace

/// Serializes the SPIR-V function `op` and appends it to `functions`.
///
/// OpFunction, OpFunctionParameter and the entry block's OpLabel are written
/// to `functionHeader`. Per the member docs, `functionBody` receives the
/// remaining instructions. After all blocks are processed, OpPhi operands
/// recorded in `deferredPhiValues` are patched inside `functionBody`, and both
/// buffers are moved into `functions`. Returns failure on any error.
LogicalResult Serializer::processFuncOp(FuncOp op) {
  LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
  assert(functionHeader.empty() && functionBody.empty());

  // Reject multiple results before serializing the function type:
  // prepareFunctionType asserts on more than one result, which would fire
  // ahead of this diagnostic.
  auto resultTypes = op.getType().getResults();
  if (resultTypes.size() > 1) {
    return op.emitError("cannot serialize function with multiple return types");
  }

  uint32_t fnTypeID = 0;
  // Generate type of the function.
  if (failed(processType(op.getLoc(), op.getType(), fnTypeID))) {
    return failure();
  }

  // Add the function definition.
  SmallVector<uint32_t, 4> operands;
  uint32_t resTypeID = 0;
  if (failed(processType(op.getLoc(),
                         (resultTypes.empty() ? getVoidType() : resultTypes[0]),
                         resTypeID))) {
    return failure();
  }
  operands.push_back(resTypeID);
  auto funcID = getOrCreateFunctionID(op.getName());
  operands.push_back(funcID);
  // TODO : Support other function control options.
  operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
  operands.push_back(fnTypeID);
  if (failed(encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction,
                                   operands))) {
    return failure();
  }

  // Add function name.
  if (failed(processName(funcID, op.getName()))) {
    return failure();
  }

  // Declare the parameters.
  for (auto arg : op.getArguments()) {
    uint32_t argTypeID = 0;
    if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
      return failure();
    }
    auto argValueID = getNextID();
    valueIDMap[arg] = argValueID;
    if (failed(encodeInstructionInto(functionHeader,
                                     spirv::Opcode::OpFunctionParameter,
                                     {argTypeID, argValueID}))) {
      return failure();
    }
  }

  // Process the body.
  if (op.isExternal()) {
    return op.emitError("external function is unhandled");
  }

  // Some instructions (e.g., OpVariable) in a function must be in the first
  // block in the function. These instructions will be put in functionHeader.
  // Thus, we put the label in functionHeader first, and omit it from the first
  // block.
  if (failed(encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
                                   {getOrCreateBlockID(&op.front())}))) {
    return failure();
  }
  if (failed(processBlock(&op.front(), /*omitLabel=*/true))) {
    return failure();
  }
  if (failed(visitInPrettyBlockOrder(
          &op.front(), [&](Block *block) { return processBlock(block); },
          /*skipHeader=*/true))) {
    return failure();
  }

  // There might be OpPhi instructions who have value references needing to fix.
  for (auto deferredValue : deferredPhiValues) {
    Value value = deferredValue.first;
    uint32_t id = getValueID(value);
    LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
                            << " to id = " << id << '\n');
    assert(id && "OpPhi references undefined value!");
    for (size_t offset : deferredValue.second)
      functionBody[offset] = id;
  }
  deferredPhiValues.clear();

  LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
                          << "' --\n");
  // Insert OpFunctionEnd.
  if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd,
                                   {}))) {
    return failure();
  }

  functions.append(functionHeader.begin(), functionHeader.end());
  functions.append(functionBody.begin(), functionBody.end());
  functionHeader.clear();
  functionBody.clear();

  return success();
}

/// Serializes a function-local spv.Variable as an OpVariable in
/// `functionHeader`, since OpVariable must appear in the function's first
/// block. Then emits OpDecorate for each remaining (non storage class)
/// attribute.
LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
  SmallVector<uint32_t, 4> operands;
  SmallVector<StringRef, 2> elidedAttrs;
  uint32_t resultID = 0;
  uint32_t resultTypeID = 0;
  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
    return failure();
  }
  operands.push_back(resultTypeID);
  resultID = getNextID();
  valueIDMap[op.getResult()] = resultID;
  operands.push_back(resultID);
  auto storageClassAttr =
      op.getAttr(spirv::attributeName<spirv::StorageClass>());
  if (storageClassAttr) {
    operands.push_back(static_cast<uint32_t>(
        storageClassAttr.cast<IntegerAttr>().getValue().getZExtValue()));
  }
  // The storage class is an OpVariable operand, not a decoration.
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
  for (auto arg : op.getODSOperands(0)) {
    auto argID = getValueID(arg);
    if (!argID) {
      return emitError(op.getLoc(), "operand 0 has a use before def");
    }
    operands.push_back(argID);
  }
  if (failed(encodeInstructionInto(
          functionHeader, spirv::getOpcode<spirv::VariableOp>(), operands))) {
    return failure();
  }
  for (auto attr : op.getAttrs()) {
    if (llvm::any_of(elidedAttrs,
                     [&](StringRef elided) { return attr.first.is(elided); })) {
      continue;
    }
    if (failed(processDecoration(op.getLoc(), resultID, attr))) {
      return failure();
    }
  }
  return success();
}

LogicalResult
Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
  // Get TypeID.
uint32_t resultTypeID = 0; SmallVector<StringRef, 4> elidedAttrs; if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { return failure(); } if (isInterfaceStructPtrType(varOp.type())) { auto structType = varOp.type() .cast<spirv::PointerType>() .getPointeeType() .cast<spirv::StructType>(); if (failed( emitDecoration(getTypeID(structType), spirv::Decoration::Block))) { return varOp.emitError("cannot decorate ") << structType << " with Block decoration"; } } elidedAttrs.push_back("type"); SmallVector<uint32_t, 4> operands; operands.push_back(resultTypeID); auto resultID = getNextID(); // Encode the name. auto varName = varOp.sym_name(); elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); if (failed(processName(resultID, varName))) { return failure(); } globalVarIDMap[varName] = resultID; operands.push_back(resultID); // Encode StorageClass. operands.push_back(static_cast<uint32_t>(varOp.storageClass())); // Encode initialization. if (auto initializer = varOp.initializer()) { auto initializerID = getVariableID(initializer.getValue()); if (!initializerID) { return emitError(varOp.getLoc(), "invalid usage of undefined variable as initializer"); } operands.push_back(initializerID); elidedAttrs.push_back("initializer"); } if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands))) { elidedAttrs.push_back("initializer"); return failure(); } // Encode decorations. 
for (auto attr : varOp.getAttrs()) { if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return attr.first.is(elided); })) { continue; } if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { return failure(); } } return success(); } //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// bool Serializer::isInterfaceStructPtrType(Type type) const { if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { auto storageClass = ptrType.getStorageClass(); if (storageClass == spirv::StorageClass::Uniform || storageClass == spirv::StorageClass::StorageBuffer) { return ptrType.getPointeeType().isa<spirv::StructType>(); } } return false; } LogicalResult Serializer::processType(Location loc, Type type, uint32_t &typeID) { typeID = getTypeID(type); if (typeID) { return success(); } typeID = getNextID(); SmallVector<uint32_t, 4> operands; operands.push_back(typeID); auto typeEnum = spirv::Opcode::OpTypeVoid; if ((type.isa<FunctionType>() && succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum, operands))) || succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands))) { typeIDMap[type] = typeID; return encodeInstructionInto(typesGlobalValues, typeEnum, operands); } return failure(); } LogicalResult Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, SmallVectorImpl<uint32_t> &operands) { if (isVoidType(type)) { typeEnum = spirv::Opcode::OpTypeVoid; return success(); } if (auto intType = type.dyn_cast<IntegerType>()) { if (intType.getWidth() == 1) { typeEnum = spirv::Opcode::OpTypeBool; return success(); } typeEnum = spirv::Opcode::OpTypeInt; operands.push_back(intType.getWidth()); // TODO(antiagainst): support unsigned integers operands.push_back(1); return success(); } if (auto floatType = type.dyn_cast<FloatType>()) { typeEnum = spirv::Opcode::OpTypeFloat; 
operands.push_back(floatType.getWidth()); return success(); } if (auto vectorType = type.dyn_cast<VectorType>()) { uint32_t elementTypeID = 0; if (failed(processType(loc, vectorType.getElementType(), elementTypeID))) { return failure(); } typeEnum = spirv::Opcode::OpTypeVector; operands.push_back(elementTypeID); operands.push_back(vectorType.getNumElements()); return success(); } if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) { typeEnum = spirv::Opcode::OpTypeArray; uint32_t elementTypeID = 0; if (failed(processType(loc, arrayType.getElementType(), elementTypeID))) { return failure(); } operands.push_back(elementTypeID); if (auto elementCountID = prepareConstantInt( loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { operands.push_back(elementCountID); } return processTypeDecoration(loc, arrayType, resultID); } if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { uint32_t pointeeTypeID = 0; if (failed(processType(loc, ptrType.getPointeeType(), pointeeTypeID))) { return failure(); } typeEnum = spirv::Opcode::OpTypePointer; operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); return success(); } if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { uint32_t elementTypeID = 0; if (failed(processType(loc, runtimeArrayType.getElementType(), elementTypeID))) { return failure(); } operands.push_back(elementTypeID); typeEnum = spirv::Opcode::OpTypeRuntimeArray; return success(); } if (auto structType = type.dyn_cast<spirv::StructType>()) { bool hasLayout = structType.hasLayout(); for (auto elementIndex : llvm::seq<uint32_t>(0, structType.getNumElements())) { uint32_t elementTypeID = 0; if (failed(processType(loc, structType.getElementType(elementIndex), elementTypeID))) { return failure(); } operands.push_back(elementTypeID); if (hasLayout) { // Decorate each struct member with an offset if (failed(processMemberDecoration( resultID, elementIndex, spirv::Decoration::Offset, 
static_cast<uint32_t>(structType.getOffset(elementIndex))))) { return emitError(loc, "cannot decorate ") << elementIndex << "-th member of " << structType << " with its offset"; } } } SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; structType.getMemberDecorations(memberDecorations); for (auto &memberDecoration : memberDecorations) { if (failed(processMemberDecoration(resultID, memberDecoration.first, memberDecoration.second))) { return emitError(loc, "cannot decorate ") << memberDecoration.first << "-th member of " << structType << " with " << stringifyDecoration(memberDecoration.second); } } typeEnum = spirv::Opcode::OpTypeStruct; return success(); } // TODO(ravishankarm) : Handle other types. return emitError(loc, "unhandled type in serialization: ") << type; } LogicalResult Serializer::prepareFunctionType(Location loc, FunctionType type, spirv::Opcode &typeEnum, SmallVectorImpl<uint32_t> &operands) { typeEnum = spirv::Opcode::OpTypeFunction; assert(type.getNumResults() <= 1 && "serialization supports only a single return value"); uint32_t resultID = 0; if (failed(processType( loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), resultID))) { return failure(); } operands.push_back(resultID); for (auto &res : type.getInputs()) { uint32_t argTypeID = 0; if (failed(processType(loc, res, argTypeID))) { return failure(); } operands.push_back(argTypeID); } return success(); } //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// uint32_t Serializer::prepareConstant(Location loc, Type constType, Attribute valueAttr) { if (auto id = prepareConstantScalar(loc, valueAttr)) { return id; } // This is a composite literal. We need to handle each component separately // and then emit an OpConstantComposite for the whole. 
if (auto id = getConstantID(valueAttr)) { return id; } uint32_t typeID = 0; if (failed(processType(loc, constType, typeID))) { return 0; } uint32_t resultID = 0; if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) { int rank = attr.getType().dyn_cast<ShapedType>().getRank(); SmallVector<uint64_t, 4> index(rank); resultID = prepareDenseElementsConstant(loc, constType, attr, /*dim=*/0, index); } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { resultID = prepareArrayConstant(loc, constType, arrayAttr); } if (resultID == 0) { emitError(loc, "cannot serialize attribute: ") << valueAttr; return 0; } constIDMap[valueAttr] = resultID; return resultID; } uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, ArrayAttr attr) { uint32_t typeID = 0; if (failed(processType(loc, constType, typeID))) { return 0; } uint32_t resultID = getNextID(); SmallVector<uint32_t, 4> operands = {typeID, resultID}; operands.reserve(attr.size() + 2); auto elementType = constType.cast<spirv::ArrayType>().getElementType(); for (Attribute elementAttr : attr) { if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { operands.push_back(elementID); } else { return 0; } } spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; encodeInstructionInto(typesGlobalValues, opcode, operands); return resultID; } // TODO(hanchung): Turn the below function into iterative function, instead of // recursive function. uint32_t Serializer::prepareDenseElementsConstant(Location loc, Type constType, DenseElementsAttr valueAttr, int dim, MutableArrayRef<uint64_t> index) { auto shapedType = valueAttr.getType().dyn_cast<ShapedType>(); assert(dim <= shapedType.getRank()); if (shapedType.getRank() == dim) { if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { return attr.getType().getElementType().isInteger(1) ? 
prepareConstantBool(loc, attr.getValue<BoolAttr>(index)) : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index)); } if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { return prepareConstantFp(loc, attr.getValue<FloatAttr>(index)); } return 0; } uint32_t typeID = 0; if (failed(processType(loc, constType, typeID))) { return 0; } uint32_t resultID = getNextID(); SmallVector<uint32_t, 4> operands = {typeID, resultID}; operands.reserve(shapedType.getDimSize(dim) + 2); auto elementType = constType.cast<spirv::CompositeType>().getElementType(0); for (int i = 0; i < shapedType.getDimSize(dim); ++i) { index[dim] = i; if (auto elementID = prepareDenseElementsConstant( loc, elementType, valueAttr, dim + 1, index)) { operands.push_back(elementID); } else { return 0; } } spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; encodeInstructionInto(typesGlobalValues, opcode, operands); return resultID; } uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, bool isSpec) { if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { return prepareConstantFp(loc, floatAttr, isSpec); } if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { return prepareConstantInt(loc, intAttr, isSpec); } if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { return prepareConstantBool(loc, boolAttr, isSpec); } return 0; } uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec) { if (!isSpec) { // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(boolAttr)) { return id; } } // Process the type for this bool literal uint32_t typeID = 0; if (failed(processType(loc, boolAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); auto opcode = boolAttr.getValue() ? (isSpec ? spirv::Opcode::OpSpecConstantTrue : spirv::Opcode::OpConstantTrue) : (isSpec ? 
spirv::Opcode::OpSpecConstantFalse : spirv::Opcode::OpConstantFalse); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); if (!isSpec) { constIDMap[boolAttr] = resultID; } return resultID; } uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec) { if (!isSpec) { // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(intAttr)) { return id; } } // Process the type for this integer literal uint32_t typeID = 0; if (failed(processType(loc, intAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); APInt value = intAttr.getValue(); unsigned bitwidth = value.getBitWidth(); bool isSigned = value.isSignedIntN(bitwidth); auto opcode = isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; // According to SPIR-V spec, "When the type's bit width is less than 32-bits, // the literal's value appears in the low-order bits of the word, and the // high-order bits must be 0 for a floating-point type, or 0 for an integer // type with Signedness of 0, or sign extended when Signedness is 1." if (bitwidth == 32 || bitwidth == 16) { uint32_t word = 0; if (isSigned) { word = static_cast<int32_t>(value.getSExtValue()); } else { word = static_cast<uint32_t>(value.getZExtValue()); } encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } // According to SPIR-V spec: "When the type's bit width is larger than one // word, the literal’s low-order words appear first." 
else if (bitwidth == 64) { struct DoubleWord { uint32_t word1; uint32_t word2; } words; if (isSigned) { words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); } else { words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); } encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); } else { std::string valueStr; llvm::raw_string_ostream rss(valueStr); value.print(rss, /*isSigned=*/false); emitError(loc, "cannot serialize ") << bitwidth << "-bit integer literal: " << rss.str(); return 0; } if (!isSpec) { constIDMap[intAttr] = resultID; } return resultID; } uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec) { if (!isSpec) { // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(floatAttr)) { return id; } } // Process the type for this float literal uint32_t typeID = 0; if (failed(processType(loc, floatAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); APFloat value = floatAttr.getValue(); APInt intValue = value.bitcastToAPInt(); auto opcode = isSpec ? 
spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; if (&value.getSemantics() == &APFloat::IEEEsingle()) { uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { struct DoubleWord { uint32_t word1; uint32_t word2; } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { uint32_t word = static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } else { std::string valueStr; llvm::raw_string_ostream rss(valueStr); value.print(rss); emitError(loc, "cannot serialize ") << floatAttr.getType() << "-typed float literal: " << rss.str(); return 0; } if (!isSpec) { constIDMap[floatAttr] = resultID; } return resultID; } //===----------------------------------------------------------------------===// // Control flow //===----------------------------------------------------------------------===// uint32_t Serializer::getOrCreateBlockID(Block *block) { if (uint32_t id = getBlockID(block)) return id; return blockIDMap[block] = getNextID(); } LogicalResult Serializer::processBlock(Block *block, bool omitLabel, function_ref<void()> actionBeforeTerminator) { LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); LLVM_DEBUG(block->print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); if (!omitLabel) { uint32_t blockID = getOrCreateBlockID(block); LLVM_DEBUG(llvm::dbgs() << "[block] " << block << " (id = " << blockID << ")\n"); // Emit OpLabel for this block. encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); } // Emit OpPhi instructions for block arguments, if any. 
if (failed(emitPhiForBlockArguments(block))) return failure(); // Process each op in this block except the terminator. for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { if (failed(processOperation(&op))) return failure(); } // Process the terminator. if (actionBeforeTerminator) actionBeforeTerminator(); if (failed(processOperation(&block->back()))) return failure(); return success(); } LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { // Nothing to do if this block has no arguments or it's the entry block, which // always has the same arguments as the function signature. if (block->args_empty() || block->isEntryBlock()) return success(); // If the block has arguments, we need to create SPIR-V OpPhi instructions. // A SPIR-V OpPhi instruction is of the syntax: // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair // So we need to collect all predecessor blocks and the arguments they send // to this block. SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors; for (Block *predecessor : block->getPredecessors()) { auto *terminator = predecessor->getTerminator(); // Check whether this predecessor block contains a structured control flow // op. If so, the structured control flow op will be serialized to multiple // SPIR-V blocks. The branch op jumping to the OpPhi's block then resides in // the last structured control flow op's merge block. if (auto *merge = getLastStructuredControlFlowOpMergeBlock(predecessor)) predecessor = merge; if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { predecessors.emplace_back(predecessor, branchOp.operand_begin()); } else { return terminator->emitError("unimplemented terminator for Phi creation"); } } // Then create OpPhi instruction for each of the block argument. 
for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { BlockArgument arg = block->getArgument(argIndex); // Get the type <id> and result <id> for this OpPhi instruction. uint32_t phiTypeID = 0; if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) return failure(); uint32_t phiID = getNextID(); LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' << arg << " (id = " << phiID << ")\n"); SmallVector<uint32_t, 8> phiArgs; phiArgs.push_back(phiTypeID); phiArgs.push_back(phiID); for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { Value value = *(predecessors[predIndex].second + argIndex); uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId << ") value " << value << ' '); // Each pair is a value <id> ... uint32_t valueId = getValueID(value); if (valueId == 0) { // The op generating this value hasn't been visited yet so we don't have // an <id> assigned yet. Record this to fix up later. LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); deferredPhiValues[value].push_back(functionBody.size() + 1 + phiArgs.size()); } else { LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); } phiArgs.push_back(valueId); // ... and a parent block <id>. phiArgs.push_back(predBlockId); } encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); valueIDMap[arg] = phiID; } return success(); } LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { // Assign <id>s to all blocks so that branches inside the SelectionOp can // resolve properly. auto &body = selectionOp.body(); for (Block &block : body) getOrCreateBlockID(&block); auto *headerBlock = selectionOp.getHeaderBlock(); auto *mergeBlock = selectionOp.getMergeBlock(); auto mergeID = getBlockID(mergeBlock); // Emit the selection header block, which dominates all other blocks, first. 
// We need to emit an OpSelectionMerge instruction before the selection header // block's terminator. auto emitSelectionMerge = [&]() { // TODO(antiagainst): properly support selection control here encodeInstructionInto( functionBody, spirv::Opcode::OpSelectionMerge, {mergeID, static_cast<uint32_t>(spirv::SelectionControl::None)}); }; // For structured selection, we cannot have blocks in the selection construct // branching to the selection header block. Entering the selection (and // reaching the selection header) must be from the block containing the // spv.selection op. If there are ops ahead of the spv.selection op in the // block, we can "merge" them into the selection header. So here we don't need // to emit a separate block; just continue with the existing block. if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge))) return failure(); // Process all blocks with a depth-first visitor starting from the header // block. The selection header block and merge block are skipped by this // visitor. if (failed(visitInPrettyBlockOrder( headerBlock, [&](Block *block) { return processBlock(block); }, /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) return failure(); // There is nothing to do for the merge block in the selection, which just // contains a spv._merge op, itself. But we need to have an OpLabel // instruction to start a new SPIR-V block for ops following this SelectionOp. // The block should use the <id> for the merge block. return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); } LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // Assign <id>s to all blocks so that branches inside the LoopOp can resolve // properly. We don't need to assign for the entry block, which is just for // satisfying MLIR region's structural requirement. 
auto &body = loopOp.body(); for (Block &block : llvm::make_range(std::next(body.begin(), 1), body.end())) { getOrCreateBlockID(&block); } auto *headerBlock = loopOp.getHeaderBlock(); auto *continueBlock = loopOp.getContinueBlock(); auto *mergeBlock = loopOp.getMergeBlock(); auto headerID = getBlockID(headerBlock); auto continueID = getBlockID(continueBlock); auto mergeID = getBlockID(mergeBlock); // This LoopOp is in some MLIR block with preceding and following ops. In the // binary format, it should reside in separate SPIR-V blocks from its // preceding and following ops. So we need to emit unconditional branches to // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow // afterwards. encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); // We omit the LoopOp's entry block and start serialization from the loop // header block. The entry block should not contain any additional ops other // than a single spv.Branch that jumps to the loop header block. However, // the spv.Branch can contain additional block arguments. Those block // arguments must come from out of the loop using implicit capture. We will // need to query the <id> for the value sent and the <id> for the incoming // parent block. For the latter, we need to make sure this block is // registered. The value sent should come from the block this loop resides in. blockIDMap[loopOp.getEntryBlock()] = getBlockID(loopOp.getOperation()->getBlock()); // Emit the loop header block, which dominates all other blocks, first. We // need to emit an OpLoopMerge instruction before the loop header block's // terminator. 
auto emitLoopMerge = [&]() { // TODO(antiagainst): properly support loop control here encodeInstructionInto( functionBody, spirv::Opcode::OpLoopMerge, {mergeID, continueID, static_cast<uint32_t>(spirv::LoopControl::None)}); }; if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) return failure(); // Process all blocks with a depth-first visitor starting from the header // block. The loop header block, loop continue block, and loop merge block are // skipped by this visitor and handled later in this function. if (failed(visitInPrettyBlockOrder( headerBlock, [&](Block *block) { return processBlock(block); }, /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) return failure(); // We have handled all other blocks. Now get to the loop continue block. if (failed(processBlock(continueBlock))) return failure(); // There is nothing to do for the merge block in the loop, which just contains // a spv._merge op, itself. But we need to have an OpLabel instruction to // start a new SPIR-V block for ops following this LoopOp. The block should // use the <id> for the merge block. 
return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); } LogicalResult Serializer::processBranchConditionalOp( spirv::BranchConditionalOp condBranchOp) { auto conditionID = getValueID(condBranchOp.condition()); auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; if (auto weights = condBranchOp.branch_weights()) { for (auto val : weights->getValue()) arguments.push_back(val.cast<IntegerAttr>().getInt()); } return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, arguments); } LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {getOrCreateBlockID(branchOp.getTarget())}); } //===----------------------------------------------------------------------===// // Operation //===----------------------------------------------------------------------===// LogicalResult Serializer::encodeExtensionInstruction( Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, ArrayRef<uint32_t> operands) { // Check if the extension has been imported. auto &setID = extendedInstSetIDMap[extensionSetName]; if (!setID) { setID = getNextID(); SmallVector<uint32_t, 16> importOperands; importOperands.push_back(setID); if (failed( spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || failed(encodeInstructionInto( extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { return failure(); } } // The first two operands are the result type <id> and result <id>. The set // <id> and the opcode need to be insert after this. 
if (operands.size() < 2) { return op->emitError("extended instructions must have a result encoding"); } SmallVector<uint32_t, 8> extInstOperands; extInstOperands.reserve(operands.size() + 2); extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); extInstOperands.push_back(setID); extInstOperands.push_back(extensionOpcode); extInstOperands.append(std::next(operands.begin(), 2), operands.end()); return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, extInstOperands); } LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { auto varName = addressOfOp.variable(); auto variableID = getVariableID(varName); if (!variableID) { return addressOfOp.emitError("unknown result <id> for variable ") << varName; } valueIDMap[addressOfOp.pointer()] = variableID; return success(); } LogicalResult Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { auto constName = referenceOfOp.spec_const(); auto constID = getSpecConstID(constName); if (!constID) { return referenceOfOp.emitError( "unknown result <id> for specialization constant ") << constName; } valueIDMap[referenceOfOp.reference()] = constID; return success(); } LogicalResult Serializer::processOperation(Operation *opInst) { LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); // First dispatch the ops that do not directly mirror an instruction from // the SPIR-V spec. 
return TypeSwitch<Operation *, LogicalResult>(opInst) .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) .Case([&](spirv::BranchConditionalOp op) { return processBranchConditionalOp(op); }) .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) .Case([&](FuncOp op) { return processFuncOp(op); }) .Case([&](spirv::GlobalVariableOp op) { return processGlobalVariableOp(op); }) .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) .Case([&](spirv::ModuleEndOp) { return success(); }) .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) // Then handle all the ops that directly mirror SPIR-V instructions with // auto-generated methods. .Default( [&](Operation *op) { return dispatchToAutogenSerialization(op); }); } namespace { template <> LogicalResult Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { SmallVector<uint32_t, 4> operands; // Add the ExecutionModel. operands.push_back(static_cast<uint32_t>(op.execution_model())); // Add the function <id>. auto funcID = getFunctionID(op.fn()); if (!funcID) { return op.emitError("missing <id> for function ") << op.fn() << "; function needs to be defined before spv.EntryPoint is " "serialized"; } operands.push_back(funcID); // Add the name of the function. spirv::encodeStringLiteralInto(operands, op.fn()); // Add the interface values. if (auto interface = op.interface()) { for (auto var : interface.getValue()) { auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue()); if (!id) { return op.emitError("referencing undefined global variable." "spv.EntryPoint is at the end of spv.module. 
All " "referenced variables should already be defined"); } operands.push_back(id); } } return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); } template <> LogicalResult Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) { StringRef argNames[] = {"execution_scope", "memory_scope", "memory_semantics"}; SmallVector<uint32_t, 3> operands; for (auto argName : argNames) { auto argIntAttr = op.getAttrOfType<IntegerAttr>(argName); auto operand = prepareConstantInt(op.getLoc(), argIntAttr); if (!operand) { return failure(); } operands.push_back(operand); } return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, operands); } template <> LogicalResult Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { SmallVector<uint32_t, 4> operands; // Add the function <id>. auto funcID = getFunctionID(op.fn()); if (!funcID) { return op.emitError("missing <id> for function ") << op.fn() << "; function needs to be serialized before ExecutionModeOp is " "serialized"; } operands.push_back(funcID); // Add the ExecutionMode. operands.push_back(static_cast<uint32_t>(op.execution_mode())); // Serialize values if any. 
auto values = op.values(); if (values) { for (auto &intVal : values.getValue()) { operands.push_back(static_cast<uint32_t>( intVal.cast<IntegerAttr>().getValue().getZExtValue())); } } return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, operands); } template <> LogicalResult Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) { StringRef argNames[] = {"memory_scope", "memory_semantics"}; SmallVector<uint32_t, 2> operands; for (auto argName : argNames) { auto argIntAttr = op.getAttrOfType<IntegerAttr>(argName); auto operand = prepareConstantInt(op.getLoc(), argIntAttr); if (!operand) { return failure(); } operands.push_back(operand); } return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands); } template <> LogicalResult Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { auto funcName = op.callee(); uint32_t resTypeID = 0; SmallVector<Type, 1> resultTypes(op.getResultTypes()); if (failed(processType(op.getLoc(), (resultTypes.empty() ? getVoidType() : resultTypes[0]), resTypeID))) { return failure(); } auto funcID = getOrCreateFunctionID(funcName); auto funcCallID = getNextID(); SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; for (auto value : op.arguments()) { auto valueID = getValueID(value); assert(valueID && "cannot find a value for spv.FunctionCall"); operands.push_back(valueID); } if (!resultTypes.empty()) { valueIDMap[op.getResult(0)] = funcCallID; } return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); } // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various Serializer::processOp<...>() specializations. 
#define GET_SERIALIZATION_FNS #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" } // namespace LogicalResult Serializer::emitDecoration(uint32_t target, spirv::Decoration decoration, ArrayRef<uint32_t> params) { uint32_t wordCount = 3 + params.size(); decorations.push_back( spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); decorations.push_back(target); decorations.push_back(static_cast<uint32_t>(decoration)); decorations.append(params.begin(), params.end()); return success(); } LogicalResult spirv::serialize(spirv::ModuleOp module, SmallVectorImpl<uint32_t> &binary) { Serializer serializer(module); if (failed(serializer.serialize())) return failure(); LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs())); serializer.collect(binary); return success(); }