//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// EnumsGen generates common utility functions for enums.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

using llvm::formatv;
using llvm::isDigit;
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
using mlir::tblgen::EnumAttr;
using mlir::tblgen::EnumAttrCase;

// Turns `str` into something usable as a C++ identifier by prefixing an
// underscore when it begins with a digit; otherwise the string is returned
// as-is.
static std::string makeIdentifier(StringRef str) {
  const bool leadingDigit =
      !str.empty() && isDigit(static_cast<unsigned char>(str.front()));
  if (!leadingDigit)
    return str.str();
  return "_" + str.str();
}

static void emitEnumClass(const Record &enumDef, StringRef enumName,
                          StringRef underlyingType, StringRef description,
                          const std::vector<EnumAttrCase> &enumerants,
                          raw_ostream &os) {
  os << "// " << description << "\n";
  os << "enum class " << enumName;

  if (!underlyingType.empty())
    os << " : " << underlyingType;
  os << " {\n";

  for (const auto &enumerant : enumerants) {
    auto symbol = makeIdentifier(enumerant.getSymbol());
    auto value = enumerant.getValue();
    if (value >= 0) {
      os << formatv("  {0} = {1},\n", symbol, value);
    } else {
      os << formatv("  {0},\n", symbol);
    }
  }
  os << "};\n\n";
}

// Emits an `llvm::DenseMapInfo` specialization for the enum so it can be used
// as a DenseMap key. Keys and hashes are delegated to the DenseMapInfo of the
// underlying storage type. If `underlyingType` is empty, that storage type is
// spelled as `std::underlying_type<...>::type`.
static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
                             StringRef cppNamespace, raw_ostream &os) {
  // Fully qualified enum name, e.g. "ns::MyEnum".
  std::string qualName = formatv("{0}::{1}", cppNamespace, enumName).str();
  std::string storageType =
      underlyingType.empty()
          ? formatv("std::underlying_type<{0}>::type", qualName).str()
          : underlyingType;

  // {0} is the qualified enum name, {1} the storage type.
  static const char *const kDenseMapInfoTemplate = R"(
namespace llvm {
template<> struct DenseMapInfo<{0}> {{
  using StorageInfo = llvm::DenseMapInfo<{1}>;

  static inline {0} getEmptyKey() {{
    return static_cast<{0}>(StorageInfo::getEmptyKey());
  }

  static inline {0} getTombstoneKey() {{
    return static_cast<{0}>(StorageInfo::getTombstoneKey());
  }

  static unsigned getHashValue(const {0} &val) {{
    return StorageInfo::getHashValue(static_cast<{1}>(val));
  }

  static bool isEqual(const {0} &lhs, const {0} &rhs) {{
    return lhs == rhs;
  }
};
})";
  os << formatv(kDenseMapInfoTemplate, qualName, storageType) << "\n\n";
}

static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
  auto enumerants = enumAttr.getAllCases();

  unsigned maxEnumVal = 0;
  for (const auto &enumerant : enumerants) {
    int64_t value = enumerant.getValue();
    // Avoid generating the max value function if there is an enumerant without
    // explicit value.
    if (value < 0)
      return;

    maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
  }

  // Emit the function to return the max enum value
  os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
  os << formatv("  return {0};\n", maxEnumVal);
  os << "}\n\n";
}

// Returns the first EnumAttrCase whose value is zero (the "no bits set" case),
// or llvm::None if no such case exists.
static llvm::Optional<EnumAttrCase>
getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
  auto zeroCase = llvm::find_if(
      cases, [](const EnumAttrCase &c) { return c.getValue() == 0; });
  if (zeroCase == cases.end())
    return llvm::None;
  return *zeroCase;
}

// Emits the following inline function for bit enums:
//
// inline <enum-type> operator|(<enum-type> a, <enum-type> b);
// inline <enum-type> operator&(<enum-type> a, <enum-type> b);
// inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
static void emitOperators(const Record &enumDef, raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef enumName = enumAttr.getEnumClassName();
  std::string underlyingType = enumAttr.getUnderlyingType();
  os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
     << formatv("  return static_cast<{0}>("
                "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
                enumName, underlyingType)
     << "}\n";
  os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
     << formatv("  return static_cast<{0}>("
                "static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
                enumName, underlyingType)
     << "}\n";
  os << formatv(
            "inline bool bitEnumContains({0} bits, {0} bit) {{\n"
            "  return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
            enumName, underlyingType)
     << "}\n";
}

static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef enumName = enumAttr.getEnumClassName();
  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
  StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
  auto enumerants = enumAttr.getAllCases();

  os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
                symToStrFnRetType);
  os << "  switch (val) {\n";
  for (const auto &enumerant : enumerants) {
    auto symbol = enumerant.getSymbol();
    os << formatv("    case {0}::{1}: return \"{2}\";\n", enumName,
                  makeIdentifier(symbol), symbol);
  }
  os << "  }\n";
  os << "  return \"\";\n";
  os << "}\n\n";
}

static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef enumName = enumAttr.getEnumClassName();
  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
  StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
  StringRef separator = enumDef.getValueAsString("separator");
  auto enumerants = enumAttr.getAllCases();
  auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);

  os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
                symToStrFnRetType);

  os << formatv("  auto val = static_cast<{0}>(symbol);\n",
                enumAttr.getUnderlyingType());
  if (allBitsUnsetCase) {
    os << "  // Special case for all bits unset.\n";
    os << formatv("  if (val == 0) return \"{0}\";\n\n",
                  allBitsUnsetCase->getSymbol());
  }
  os << "  llvm::SmallVector<llvm::StringRef, 2> strs;\n";
  for (const auto &enumerant : enumerants) {
    // Skip the special enumerant for None.
    if (auto val = enumerant.getValue())
      os << formatv("  if ({0}u & val) {{ strs.push_back(\"{1}\"); "
                    "val &= ~{0}u; }\n",
                    val, enumerant.getSymbol());
  }
  // If we have unknown bit set, return an empty string to signal errors.
  os << "\n  if (val) return \"\";\n";
  os << formatv("  return llvm::join(strs, \"{0}\");\n", separator);

  os << "}\n\n";
}

static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef enumName = enumAttr.getEnumClassName();
  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
  auto enumerants = enumAttr.getAllCases();

  os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
                strToSymFnName);
  os << formatv("  return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
                enumName);
  for (const auto &enumerant : enumerants) {
    auto symbol = enumerant.getSymbol();
    os << formatv("      .Case(\"{1}\", {0}::{2})\n", enumName, symbol,
                  makeIdentifier(symbol));
  }
  os << "      .Default(llvm::None);\n";
  os << "}\n";
}

static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef enumName = enumAttr.getEnumClassName();
  std::string underlyingType = enumAttr.getUnderlyingType();
  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
  StringRef separator = enumDef.getValueAsString("separator");
  auto enumerants = enumAttr.getAllCases();
  auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);

  os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
                strToSymFnName);

  if (allBitsUnsetCase) {
    os << "  // Special case for all bits unset.\n";
    StringRef caseSymbol = allBitsUnsetCase->getSymbol();
    os << formatv("  if (str == \"{1}\") return {0}::{2};\n\n", enumName,
                  caseSymbol, makeIdentifier(caseSymbol));
  }

  // Split the string to get symbols for all the bits.
  os << "  llvm::SmallVector<llvm::StringRef, 2> symbols;\n";
  os << formatv("  str.split(symbols, \"{0}\");\n\n", separator);

  os << formatv("  {0} val = 0;\n", underlyingType);
  os << "  for (auto symbol : symbols) {\n";

  // Convert each symbol to the bit ordinal and set the corresponding bit.
  os << formatv(
      "    auto bit = llvm::StringSwitch<llvm::Optional<{0}>>(symbol)\n",
      underlyingType);
  for (const auto &enumerant : enumerants) {
    // Skip the special enumerant for None.
    if (auto val = enumerant.getValue())
      os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
                              val);
  }
  os.indent(6) << ".Default(llvm::None);\n";

  os << "    if (bit) { val |= *bit; } else { return llvm::None; }\n";
  os << "  }\n";

  os << formatv("  return static_cast<{0}>(val);\n", enumName);
  os << "}\n\n";
}

static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
                                            raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef enumName = enumAttr.getEnumClassName();
  std::string underlyingType = enumAttr.getUnderlyingType();
  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
  auto enumerants = enumAttr.getAllCases();

  // Avoid generating the underlying value to symbol conversion function if
  // there is an enumerant without explicit value.
  if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
        return enumerant.getValue() < 0;
      }))
    return;

  os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
                underlyingToSymFnName,
                underlyingType.empty() ? std::string("unsigned")
                                       : underlyingType)
     << "  switch (value) {\n";
  for (const auto &enumerant : enumerants) {
    auto symbol = enumerant.getSymbol();
    auto value = enumerant.getValue();
    os << formatv("  case {0}: return {1}::{2};\n", value, enumName,
                  makeIdentifier(symbol));
  }
  os << "  default: return llvm::None;\n"
     << "  }\n"
     << "}\n\n";
}

// Emits the underlying-value -> enum conversion for a bit enum. The generated
// function returns the zero-valued case for 0 (if such a case exists). It
// returns llvm::None if `value` has any bit outside the union of all non-zero
// case values. Otherwise it returns `value` cast to the enum.
static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
                                            raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef enumName = enumAttr.getEnumClassName();
  std::string underlyingType = enumAttr.getUnderlyingType();
  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
  auto enumerants = enumAttr.getAllCases();
  auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);

  // Match the parameter type used for this function's declaration in
  // emitEnumDecl, which falls back to `unsigned` when no underlying type is
  // specified.
  std::string paramType =
      underlyingType.empty() ? std::string("unsigned") : underlyingType;

  os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
                underlyingToSymFnName, paramType);
  if (allBitsUnsetCase) {
    os << "  // Special case for all bits unset.\n";
    os << formatv("  if (value == 0) return {0}::{1};\n\n", enumName,
                  makeIdentifier(allBitsUnsetCase->getSymbol()));
  }

  // OR together the values of all non-zero cases to form the mask of bits
  // that may legally be set.
  llvm::SmallVector<std::string, 8> values;
  for (const auto &enumerant : enumerants) {
    if (auto val = enumerant.getValue())
      values.push_back(formatv("{0}u", val));
  }
  // With no non-zero case no bit is valid. Use `0u` so the generated code does
  // not contain the ill-formed expression `~()`.
  std::string validMask =
      values.empty() ? std::string("0u") : llvm::join(values, " | ");

  // Cast the mask to the parameter type before complementing it.
  // `~(1u | 2u)` is only as wide as `unsigned`, so on a wider parameter type
  // the high bits would escape the check.
  os << formatv("  if (value & ~static_cast<{0}>({1})) return llvm::None;\n",
                paramType, validMask);
  os << formatv("  return static_cast<{0}>(value);\n", enumName);
  os << "}\n";
}

static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef enumName = enumAttr.getEnumClassName();
  StringRef cppNamespace = enumAttr.getCppNamespace();
  std::string underlyingType = enumAttr.getUnderlyingType();
  StringRef description = enumAttr.getDescription();
  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
  StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
  auto enumerants = enumAttr.getAllCases();

  llvm::SmallVector<StringRef, 2> namespaces;
  llvm::SplitString(cppNamespace, namespaces, "::");

  for (auto ns : namespaces)
    os << "namespace " << ns << " {\n";

  // Emit the enum class definition
  emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);

  // Emit conversion function declarations
  if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
        return enumerant.getValue() >= 0;
      })) {
    os << formatv(
        "llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
        underlyingType.empty() ? std::string("unsigned") : underlyingType);
  }
  os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
  os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
                strToSymFnName);

  if (enumAttr.isBitEnum()) {
    emitOperators(enumDef, os);
  } else {
    emitMaxValueFn(enumDef, os);
  }

  for (auto ns : llvm::reverse(namespaces))
    os << "} // namespace " << ns << "\n";

  // Emit DenseMapInfo for this enum class
  emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
}

static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
  llvm::emitSourceFileHeader("Enum Utility Declarations", os);

  auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
  for (const auto *def : defs)
    emitEnumDecl(*def, os);

  return false;
}

static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
  EnumAttr enumAttr(enumDef);
  StringRef cppNamespace = enumAttr.getCppNamespace();

  llvm::SmallVector<StringRef, 2> namespaces;
  llvm::SplitString(cppNamespace, namespaces, "::");

  for (auto ns : namespaces)
    os << "namespace " << ns << " {\n";

  if (enumAttr.isBitEnum()) {
    emitSymToStrFnForBitEnum(enumDef, os);
    emitStrToSymFnForBitEnum(enumDef, os);
    emitUnderlyingToSymFnForBitEnum(enumDef, os);
  } else {
    emitSymToStrFnForIntEnum(enumDef, os);
    emitStrToSymFnForIntEnum(enumDef, os);
    emitUnderlyingToSymFnForIntEnum(enumDef, os);
  }

  for (auto ns : llvm::reverse(namespaces))
    os << "} // namespace " << ns << "\n";
  os << "\n";
}

static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
  llvm::emitSourceFileHeader("Enum Utility Definitions", os);

  auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
  for (const auto *def : defs)
    emitEnumDef(*def, os);

  return false;
}

// Registers the `gen-enum-decls` generator with mlir-tblgen; it emits the enum
// utility declarations.
static mlir::GenRegistration
    genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
                 emitEnumDecls);

// Registers the `gen-enum-defs` generator with mlir-tblgen; it emits the enum
// utility definitions.
static mlir::GenRegistration
    genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
                emitEnumDefs);