//===- FxpMathConfig.cpp - Reference fixed point config -------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines a TargetConfiguration for reference fixed-point math // quantization scheme based on the FxpMathOps (plus a small category of // extension ops that can be added from other dialects). // //===----------------------------------------------------------------------===// #include "mlir/Quantizer/Configurations/FxpMathConfig.h" #include "mlir/Dialect/FxpMathOps/FxpMathOps.h" #include "mlir/Dialect/QuantOps/QuantOps.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" #include "mlir/Quantizer/Support/Metadata.h" #include "mlir/Quantizer/Support/Statistics.h" #include "mlir/Quantizer/Support/UniformConstraints.h" using namespace mlir; using namespace mlir::quantizer; using namespace mlir::fxpmath; using namespace mlir::quant; using namespace std::placeholders; namespace { struct FxpMathTargetConfigImpl : public FxpMathTargetConfig { FxpMathTargetConfigImpl(SolverContext &context) : FxpMathTargetConfig(context) { Builder b(&context.getMlirContext()); IntegerType i8Type = b.getIntegerType(8); IntegerType i16Type = b.getIntegerType(16); IntegerType i32Type = b.getIntegerType(32); q8 = addCandidateType( AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr, std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()), CandidateQuantizedType::Scheme::UniformPerLayer); q16 = addCandidateType( AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr, std::numeric_limits<int16_t>::min(), std::numeric_limits<int16_t>::max()), CandidateQuantizedType::Scheme::UniformPerLayer); q32ExplicitFixedPoint = addCandidateType( AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr, std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max()), CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale); // Op handlers. addOpHandler<ConstantOp>( std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2)); addOpHandler<ReturnOp>( std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2)); addOpHandler<quant::StatisticsOp>( std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2)); // FxpMathOps. addOpHandler<RealAddEwOp>( std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2)); addOpHandler<RealMulEwOp>( std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2)); addOpHandler<RealMatMulOp>( std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2)); addOpHandler<RealMatMulBiasOp>( std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2)); // Require stats ops. addRequireStatsOp<RealAddEwOp>(); addRequireStatsOp<RealSubEwOp>(); addRequireStatsOp<RealDivEwOp>(); addRequireStatsOp<RealMulEwOp>(); addRequireStatsOp<RealMatMulOp>(); addRequireStatsOp<RealMatMulBiasOp>(); } bool isHandledType(Type t) const final { if (t.isa<FloatType>()) return true; return (t.isa<VectorType>() || t.isa<TensorType>()) && t.cast<ShapedType>().getElementType().isa<FloatType>(); } void finalizeAnchors(CAGSlice &cag) const override { cag.enumerateImpliedConnections( [&](CAGAnchorNode *from, CAGAnchorNode *to) { UniformConstraintsBuilder(cag).coupleAnchors(from, to); }); } void addValueIdentityOpByName(StringRef opName) override { addOpHandlerByName( opName, std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2)); } void handleValueIdentity(Operation *op, CAGSlice &cag) const { assert(op->getNumResults() == 1); if (!isHandledType(op->getResult(0).getType())) return; auto resultNode = cag.getResultAnchor(op, 0); resultNode->setTypeTransformRule( CAGAnchorNode::TypeTransformRule::DirectStorage); for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) { if (!isHandledType(op->getOperand(opIdx).getType())) continue; auto operandNode = cag.getOperandAnchor(op, opIdx); operandNode->setTypeTransformRule( CAGAnchorNode::TypeTransformRule::DirectStorage); UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode); } } void handleConstant(Operation *op, CAGSlice &cag) const { if (!isHandledType(op->getResult(0).getType())) return; auto resultNode = cag.getResultAnchor(op, 0); resultNode->setTypeTransformRule( CAGAnchorNode::TypeTransformRule::ExpressedOnly); Attribute valueAttr; if (!matchPattern(op, m_Constant(&valueAttr))) { return; } AttributeTensorStatistics stats(valueAttr); TensorAxisStatistics layerStats; if (!stats.get(layerStats)) { op->emitOpError("could not compute statistics"); return; } UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats); } void handleTerminal(Operation *op, CAGSlice &cag) const { if (!isHandledType(op->getOperand(0).getType())) return; auto operandNode = cag.getOperandAnchor(op, 0); operandNode->setTypeTransformRule( CAGAnchorNode::TypeTransformRule::ExpressedOnly); } void handleStats(Operation *op, CAGSlice &cag) const { if (!isHandledType(op->getResult(0).getType())) return; auto argNode = cag.getOperandAnchor(op, 0); auto resultNode = cag.getResultAnchor(op, 0); UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode); TensorAxisStatistics layerStats; auto statsOp = cast<quant::StatisticsOp>(op); auto layerStatsAttr = statsOp.layerStats(); layerStats.minValue = layerStatsAttr.getValue<FloatAttr>(0).getValueAsDouble(); layerStats.maxValue = layerStatsAttr.getValue<FloatAttr>(1).getValueAsDouble(); UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats); } void handleAdd(Operation *op, CAGSlice &cag) const { if (!isHandledType(op->getResult(0).getType())) return; auto lhs = cag.getOperandAnchor(op, 0); auto rhs = cag.getOperandAnchor(op, 1); auto resultNode = cag.getResultAnchor(op, 0); // Add supports 8/16 bit math. llvm::SmallBitVector disableMask = getCandidateTypeDisabledExceptMask({q8, q16}); lhs->getUniformMetadata().disabledCandidateTypes = disableMask; rhs->getUniformMetadata().disabledCandidateTypes = disableMask; resultNode->getUniformMetadata().disabledCandidateTypes = disableMask; // NOTE: We couple the add such that the scale/zeroPoint match between // both args and the result. This is overly constrained in that it is // possible to write efficient add kernels with a bit more freedom (i.e. // zeroPoints can vary, scales can differ by a power of two, etc). // However, fully coupled yields the simples solutions on the fast path. // Further efficiency can be had by constraining the zeroPoint to 0, but // there isn't a constraint for this yet (and there are tradeoffs). UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode); UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode); addRealMathOptionalConstraints(op, resultNode, cag); } void handleMul(Operation *op, CAGSlice &cag) const { if (!isHandledType(op->getResult(0).getType())) return; auto lhs = cag.getOperandAnchor(op, 0); auto rhs = cag.getOperandAnchor(op, 1); auto resultNode = cag.getResultAnchor(op, 0); // Mul supports 8/16 bit math. llvm::SmallBitVector disableMask = getCandidateTypeDisabledExceptMask({q8, q16}); lhs->getUniformMetadata().disabledCandidateTypes = disableMask; rhs->getUniformMetadata().disabledCandidateTypes = disableMask; resultNode->getUniformMetadata().disabledCandidateTypes = disableMask; addRealMathOptionalConstraints(op, resultNode, cag); } void handleMatMul(Operation *op, CAGSlice &cag) const { if (!isHandledType(op->getResult(0).getType())) return; auto lhs = cag.getOperandAnchor(op, 0); auto rhs = cag.getOperandAnchor(op, 1); auto resultNode = cag.getResultAnchor(op, 0); // Mul supports 8/16 bit math. llvm::SmallBitVector disableMask = getCandidateTypeDisabledExceptMask({q8, q16}); lhs->getUniformMetadata().disabledCandidateTypes = disableMask; rhs->getUniformMetadata().disabledCandidateTypes = disableMask; resultNode->getUniformMetadata().disabledCandidateTypes = disableMask; addRealMathOptionalConstraints(op, resultNode, cag); } void handleMatMulBias(Operation *op, CAGSlice &cag) const { if (!isHandledType(op->getResult(0).getType())) return; auto lhs = cag.getOperandAnchor(op, 0); auto rhs = cag.getOperandAnchor(op, 1); auto bias = cag.getOperandAnchor(op, 2); bias->getUniformMetadata().disabledCandidateTypes = getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint}); auto resultNode = cag.getResultAnchor(op, 0); UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias); // Mul supports 8/16 bit math. llvm::SmallBitVector disableMask = getCandidateTypeDisabledExceptMask({q8, q16}); lhs->getUniformMetadata().disabledCandidateTypes = disableMask; rhs->getUniformMetadata().disabledCandidateTypes = disableMask; resultNode->getUniformMetadata().disabledCandidateTypes = disableMask; addRealMathOptionalConstraints(op, resultNode, cag); } void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor, CAGSlice &cag) const { // TODO: It would be nice if these all extended some base trait instead // of requiring name lookup. auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min"); auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max"); if (clampMinAttr || clampMaxAttr) { auto nan = APFloat::getQNaN(APFloat::IEEEdouble()); auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan; auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan; UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax); } } unsigned q8; unsigned q16; unsigned q32ExplicitFixedPoint; }; } // anonymous namespace std::unique_ptr<FxpMathTargetConfig> FxpMathTargetConfig::create(SolverContext &context) { return std::make_unique<FxpMathTargetConfigImpl>(context); }