//===--- MakeSmartPtrCheck.cpp - clang-tidy--------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "MakeSharedCheck.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Lex/Lexer.h"
#include "clang/Lex/Preprocessor.h"

using namespace clang::ast_matchers;

namespace clang {
namespace tidy {
namespace modernize {

namespace {

constexpr char StdMemoryHeader[] = "memory";
constexpr char ConstructorCall[] = "constructorCall";
constexpr char ResetCall[] = "resetCall";
constexpr char NewExpression[] = "newExpression";

std::string GetNewExprName(const CXXNewExpr *NewExpr,
                           const SourceManager &SM,
                           const LangOptions &Lang) {
  StringRef WrittenName = Lexer::getSourceText(
      CharSourceRange::getTokenRange(
          NewExpr->getAllocatedTypeSourceInfo()->getTypeLoc().getSourceRange()),
      SM, Lang);
  if (NewExpr->isArray()) {
    return (WrittenName + "[]").str();
  }
  return WrittenName.str();
}

} // namespace

// Bound-node id for the smart pointer's element type. registerMatchers()
// requires the new-expression's pointee to equal the node bound under this id
// (presumably bound by getSmartPointerTypeMatcher(), which is defined outside
// this file), and check() reads it back as a QualType.
const char MakeSmartPtrCheck::PointerType[] = "pointerType";

// Reads the check's options:
//   IncludeStyle               - include sorting style (local or global,
//                                default "llvm").
//   MakeSmartPtrFunctionHeader - header providing the make-function
//                                (default "memory").
//   MakeSmartPtrFunction       - make-function to suggest; defaults to the
//                                MakeSmartPtrFunctionName passed in by the
//                                concrete check.
//   IgnoreMacros               - skip code inside macros (local or global,
//                                default true).
MakeSmartPtrCheck::MakeSmartPtrCheck(StringRef Name,
                                     ClangTidyContext* Context,
                                     StringRef MakeSmartPtrFunctionName)
    : ClangTidyCheck(Name, Context),
      IncludeStyle(utils::IncludeSorter::parseIncludeStyle(
          Options.getLocalOrGlobal("IncludeStyle", "llvm"))),
      MakeSmartPtrFunctionHeader(
          Options.get("MakeSmartPtrFunctionHeader", StdMemoryHeader)),
      MakeSmartPtrFunctionName(
          Options.get("MakeSmartPtrFunction", MakeSmartPtrFunctionName)),
      IgnoreMacros(Options.getLocalOrGlobal("IgnoreMacros", true)) {}

// Writes back every option read in the constructor, under the same key names.
void MakeSmartPtrCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
  Options.store(Opts, "IncludeStyle", IncludeStyle);
  Options.store(Opts, "MakeSmartPtrFunctionHeader", MakeSmartPtrFunctionHeader);
  Options.store(Opts, "MakeSmartPtrFunction", MakeSmartPtrFunctionName);
  Options.store(Opts, "IgnoreMacros", IgnoreMacros);
}

// The check only runs on C++11 and later.
bool MakeSmartPtrCheck::isLanguageVersionSupported(
    const LangOptions &LangOpts) const {
  const bool IsCxx11OrNewer = LangOpts.CPlusPlus11;
  return IsCxx11OrNewer;
}

// Creates the include inserter used by insertHeader() and hooks it into the
// preprocessor. Nothing is set up for unsupported language modes, because
// registerMatchers() registers no matchers there either.
void MakeSmartPtrCheck::registerPPCallbacks(const SourceManager &SM,
                                            Preprocessor *PP,
                                            Preprocessor *ModuleExpanderPP) {
  if (!isLanguageVersionSupported(getLangOpts()))
    return;

  Inserter = std::make_unique<utils::IncludeInserter>(SM, getLangOpts(),
                                                       IncludeStyle);
  PP->addPPCallbacks(Inserter->CreatePPCallbacks());
}

void MakeSmartPtrCheck::registerMatchers(ast_matchers::MatchFinder *Finder) {
  if (!isLanguageVersionSupported(getLangOpts()))
    return;

  // Calling make_smart_ptr from within a member function of a type with a
  // private or protected constructor would be ill-formed, so new-expressions
  // that invoke a non-public constructor are not matched.
  auto CanCallCtor = unless(has(ignoringImpCasts(
      cxxConstructExpr(hasDeclaration(decl(unless(isPublic())))))));

  // A new-expression whose pointee type is (canonically) the type bound under
  // the PointerType id.
  auto NewOfBoundPointee =
      cxxNewExpr(hasType(pointsTo(qualType(
                     hasCanonicalType(equalsBoundNode(PointerType))))),
                 CanCallCtor)
          .bind(NewExpression);

  // smart_ptr<T>(new T(...)) outside of template instantiations.
  auto SmartPtrFromNew =
      cxxConstructExpr(hasType(getSmartPointerTypeMatcher()),
                       argumentCountIs(1), hasArgument(0, NewOfBoundPointee),
                       unless(isInTemplateInstantiation()))
          .bind(ConstructorCall);
  Finder->addMatcher(
      cxxBindTemporaryExpr(has(ignoringParenImpCasts(SmartPtrFromNew))), this);

  // P.reset(new T(...)) / P->reset(new T(...)) outside of template
  // instantiations.
  auto ResetWithNew =
      cxxMemberCallExpr(
          thisPointerType(getSmartPointerTypeMatcher()),
          callee(cxxMethodDecl(hasName("reset"))),
          hasArgument(0, cxxNewExpr(CanCallCtor).bind(NewExpression)),
          unless(isInTemplateInstantiation()))
          .bind(ResetCall);
  Finder->addMatcher(ResetWithNew, this);
}

void MakeSmartPtrCheck::check(const MatchFinder::MatchResult &Result) {
  // Terminology: 'smart_ptr' stands for std::shared_ptr, std::unique_ptr or
  // another smart pointer; 'make_smart_ptr' stands for std::make_shared,
  // std::make_unique or another function that creates a smart_ptr.
  const auto *New = Result.Nodes.getNodeAs<CXXNewExpr>(NewExpression);

  // Placement new cannot be expressed through make_smart_ptr.
  if (New->getNumPlacementArgs() != 0)
    return;

  // `new auto(1)` has no written type to pass as the template argument.
  if (New->getType()->getPointeeType()->getContainedAutoType())
    return;

  // Be conservative with arrays created without any initializer, e.g.
  //    P.reset(new int[5]) // the fix would be P = std::make_unique<int[]>(5)
  // The fix changes semantics: the elements would become value-initialized,
  // which may be unexpected and may cause a performance regression.
  if (New->isArray() && !New->hasInitializer())
    return;

  SourceManager &SM = *Result.SourceManager;
  if (const auto *Construct =
          Result.Nodes.getNodeAs<CXXConstructExpr>(ConstructorCall)) {
    checkConstruct(SM, Result.Context, Construct,
                   Result.Nodes.getNodeAs<QualType>(PointerType), New);
    return;
  }
  if (const auto *Reset =
          Result.Nodes.getNodeAs<CXXMemberCallExpr>(ResetCall))
    checkReset(SM, Result.Context, Reset, New);
}

// Diagnoses `smart_ptr<T>(new T(...))` and, outside macros, rewrites it to
// `make_smart_ptr<T>(...)`. \p Type is not used by this function.
void MakeSmartPtrCheck::checkConstruct(SourceManager &SM, ASTContext *Ctx,
                                       const CXXConstructExpr *Construct,
                                       const QualType *Type,
                                       const CXXNewExpr *New) {
  const SourceLocation CallBegin = Construct->getExprLoc();
  const bool IsMacroExpansion = CallBegin.isMacroID();
  if (IgnoreMacros && IsMacroExpansion)
    return;

  // Spelling of the smart pointer type: from the start of the expression up to
  // (not including) the opening '(' or '{'.
  bool TextInvalid = false;
  const StringRef TypeSpelling = Lexer::getSourceText(
      CharSourceRange::getCharRange(
          CallBegin, Construct->getParenOrBraceRange().getBegin()),
      SM, getLangOpts(), &TextInvalid);
  if (TextInvalid)
    return;

  auto Diag = diag(CallBegin, "use %0 instead") << MakeSmartPtrFunctionName;

  // Inside macros only the warning is emitted. Otherwise stop here when the
  // new-expression itself cannot be rewritten.
  if (IsMacroExpansion || !replaceNew(Diag, New, SM, Ctx))
    return;

  // Everything before the template's '<' is replaced by the make-function.
  // When there is no '<' (the template argument is hidden in an alias), the
  // template argument is inserted explicitly.
  const size_t LAnglePos = TypeSpelling.find('<');
  SourceLocation NameEnd;
  if (LAnglePos != StringRef::npos) {
    NameEnd = CallBegin.getLocWithOffset(LAnglePos);
  } else {
    NameEnd = CallBegin.getLocWithOffset(TypeSpelling.size());
    Diag << FixItHint::CreateInsertion(
        NameEnd, "<" + GetNewExprName(New, SM, getLangOpts()) + ">");
  }
  Diag << FixItHint::CreateReplacement(
      CharSourceRange::getCharRange(CallBegin, NameEnd),
      MakeSmartPtrFunctionName);

  // A brace-initialized smart_ptr becomes a parenthesized function call.
  if (Construct->isListInitialization()) {
    const SourceRange Braces = Construct->getParenOrBraceRange();
    const SourceLocation LBrace = Braces.getBegin();
    const SourceLocation RBrace = Braces.getEnd();
    Diag << FixItHint::CreateReplacement(
        CharSourceRange::getCharRange(LBrace, LBrace.getLocWithOffset(1)),
        "(");
    Diag << FixItHint::CreateReplacement(
        CharSourceRange::getCharRange(RBrace, RBrace.getLocWithOffset(1)),
        ")");
  }

  insertHeader(Diag, SM.getFileID(CallBegin));
}

// Diagnoses `P.reset(new T(...))` / `P->reset(new T(...))` and, outside
// macros, rewrites it to `P = make_smart_ptr<T>(...)` (or `*P = ...`).
void MakeSmartPtrCheck::checkReset(SourceManager &SM, ASTContext *Ctx,
                                   const CXXMemberCallExpr *Reset,
                                   const CXXNewExpr *New) {
  const auto *Member = cast<MemberExpr>(Reset->getCallee());
  const SourceLocation ObjectBegin = Member->getBeginLoc();
  const bool IsMacroExpansion = ObjectBegin.isMacroID();
  if (IgnoreMacros && IsMacroExpansion)
    return;

  // The "." or "->" can be missing, e.g. when reset() is called directly from
  // within a subclass of the smart pointer; such calls are skipped.
  const SourceLocation OperatorLoc = Member->getOperatorLoc();
  if (OperatorLoc.isInvalid())
    return;

  auto Diag = diag(Reset->getExprLoc(), "use %0 instead")
              << MakeSmartPtrFunctionName;

  // Inside macros only the warning is emitted. Otherwise stop here when the
  // new-expression itself cannot be rewritten.
  if (IsMacroExpansion || !replaceNew(Diag, New, SM, Ctx))
    return;

  // `.reset` / `->reset` (through the end of the `reset` token) becomes
  // ` = make_smart_ptr<T>`.
  const SourceLocation MemberEnd =
      Lexer::getLocForEndOfToken(Member->getEndLoc(), 0, SM, getLangOpts());
  Diag << FixItHint::CreateReplacement(
      CharSourceRange::getCharRange(OperatorLoc, MemberEnd),
      (llvm::Twine(" = ") + MakeSmartPtrFunctionName + "<" +
       GetNewExprName(New, SM, getLangOpts()) + ">")
          .str());

  // With `->` the object expression is a pointer, so assign through `*`.
  if (Member->isArrow())
    Diag << FixItHint::CreateInsertion(ObjectBegin, "*");

  insertHeader(Diag, SM.getFileID(OperatorLoc));
}

/// Attaches fix-its to \p Diag that turn \p New (together with any parentheses
/// wrapping it) into the argument list of the make-function:
///   new T(args) / new T{args}  ->  args
///   new Aggregate{a, b}        ->  Aggregate{a, b}
///   new T[N]()                 ->  N
///
/// Returns false, before adding any fix-it, when no reliable rewrite is known:
/// invalid source locations, braced-init-list arguments, list-initialized
/// arrays, or aggregates whose copy/move constructor is deleted or not public.
bool MakeSmartPtrCheck::replaceNew(DiagnosticBuilder &Diag,
                                   const CXXNewExpr *New, SourceManager &SM,
                                   ASTContext *Ctx) {
  // Climbs through ParenExpr parents so `(new T(...))` is rewritten as a whole.
  auto SkipParensParents = [&](const Expr *E) {
    for (const Expr *OldE = nullptr; E != OldE;) {
      OldE = E;
      for (const auto &Node : Ctx->getParents(*E)) {
        if (const Expr *Parent = Node.get<ParenExpr>()) {
          E = Parent;
          break;
        }
      }
    }
    return E;
  };

  SourceRange NewRange = SkipParensParents(New)->getSourceRange();
  SourceLocation NewStart = NewRange.getBegin();
  SourceLocation NewEnd = NewRange.getEnd();

  // Skip when the source location of the new expression is invalid.
  if (NewStart.isInvalid() || NewEnd.isInvalid())
    return false;

  // Written array size of an array new-expression; empty for non-arrays.
  std::string ArraySizeExpr;
  if (const auto* ArraySize = New->getArraySize().getValueOr(nullptr)) {
    ArraySizeExpr = Lexer::getSourceText(CharSourceRange::getTokenRange(
                                             ArraySize->getSourceRange()),
                                         SM, getLangOpts())
                        .str();
  }
  // Returns true if the given constructor expression has any braced-init-list
  // argument, e.g.
  //   Foo({1, 2}, 1) => true
  //   Foo(Bar{1, 2}) => true
  //   Foo(1) => false
  //   Foo{1} => false
  auto HasListInitializedArgument = [](const CXXConstructExpr *CE) {
    for (const auto *Arg : CE->arguments()) {
      Arg = Arg->IgnoreImplicit();

      if (isa<CXXStdInitializerListExpr>(Arg) || isa<InitListExpr>(Arg))
        return true;
      // Check whether we implicitly construct a class from a
      // std::initializer_list.
      if (const auto *CEArg = dyn_cast<CXXConstructExpr>(Arg)) {
        // Strip the elidable move constructor, it is present in the AST for
        // C++11/14, e.g. Foo(Bar{1, 2}), the move constructor is around the
        // init-list constructor.
        if (CEArg->isElidable()) {
          if (const auto *TempExp = CEArg->getArg(0)) {
            if (const auto *UnwrappedCE =
                    dyn_cast<CXXConstructExpr>(TempExp->IgnoreImplicit()))
              CEArg = UnwrappedCE;
          }
        }
        if (CEArg->isStdInitListInitialization())
          return true;
      }
    }
    return false;
  };
  switch (New->getInitializationStyle()) {
  case CXXNewExpr::NoInit: {
    if (ArraySizeExpr.empty()) {
      Diag << FixItHint::CreateRemoval(SourceRange(NewStart, NewEnd));
    } else {
      // New array expression without written initializer:
      //   smart_ptr<Foo[]>(new Foo[5]);
      Diag << FixItHint::CreateReplacement(SourceRange(NewStart, NewEnd),
                                           ArraySizeExpr);
    }
    break;
  }
  case CXXNewExpr::CallInit: {
    // FIXME: Add fixes for constructors with parameters that can be created
    // with a C++11 braced-init-list (e.g. std::vector, std::map).
    // Unlike ordinary cases, braced list can not be deduced in
    // std::make_smart_ptr, we need to specify the type explicitly in the fixes:
    //   struct S { S(std::initializer_list<int>, int); };
    //   struct S2 { S2(std::vector<int>); };
    //   struct S3 { S3(S2, int); };
    //   smart_ptr<S>(new S({1, 2, 3}, 1));  // C++98 call-style initialization
    //   smart_ptr<S>(new S({}, 1));
    //   smart_ptr<S2>(new S2({1})); // implicit conversion:
    //                               //   std::initializer_list => std::vector
    //   smart_ptr<S3>(new S3({1, 2}, 3));
    // The above samples have to be replaced with:
    //   std::make_smart_ptr<S>(std::initializer_list<int>({1, 2, 3}), 1);
    //   std::make_smart_ptr<S>(std::initializer_list<int>({}), 1);
    //   std::make_smart_ptr<S2>(std::vector<int>({1}));
    //   std::make_smart_ptr<S3>(S2{1, 2}, 3);
    if (const auto *CE = New->getConstructExpr()) {
      if (HasListInitializedArgument(CE))
        return false;
    }
    if (ArraySizeExpr.empty()) {
      // Drop `new T(` and the closing `)`, keeping only the arguments.
      SourceRange InitRange = New->getDirectInitRange();
      Diag << FixItHint::CreateRemoval(
          SourceRange(NewStart, InitRange.getBegin()));
      Diag << FixItHint::CreateRemoval(SourceRange(InitRange.getEnd(), NewEnd));
    }
    else {
      // New array expression with default/value initialization:
      //   smart_ptr<Foo[]>(new int[5]());
      //   smart_ptr<Foo[]>(new Foo[5]());
      Diag << FixItHint::CreateReplacement(SourceRange(NewStart, NewEnd),
                                           ArraySizeExpr);
    }
    break;
  }
  case CXXNewExpr::ListInit: {
    // List-initialized arrays, e.g. smart_ptr<int[]>(new int[3]{1, 2, 3}):
    // make_smart_ptr<T[]> only receives an element count and cannot forward a
    // braced element list. The aggregate branch below would keep the written
    // type and braced list as the call argument, which yields ill-formed code,
    // so no fix is offered.
    if (New->isArray())
      return false;

    // Range of the substring that we do not want to remove.
    SourceRange InitRange;
    if (const auto *NewConstruct = New->getConstructExpr()) {
      if (NewConstruct->isStdInitListInitialization() ||
          HasListInitializedArgument(NewConstruct)) {
        // FIXME: Add fixes for direct initialization with the initializer-list
        // constructor. Similar to the above CallInit case, the type has to be
        // specified explicitly in the fixes.
        //   struct S { S(std::initializer_list<int>); };
        //   struct S2 { S2(S, int); };
        //   smart_ptr<S>(new S{1, 2, 3});  // C++11 direct list-initialization
        //   smart_ptr<S>(new S{});  // use initializer-list constructor
        //   smart_ptr<S2>(new S2{ {1,2}, 3 }); // have a list-initialized arg
        // The above cases have to be replaced with:
        //   std::make_smart_ptr<S>(std::initializer_list<int>({1, 2, 3}));
        //   std::make_smart_ptr<S>(std::initializer_list<int>({}));
        //   std::make_smart_ptr<S2>(S{1, 2}, 3);
        return false;
      } else {
        // Direct initialization with ordinary constructors.
        //   struct S { S(int x); S(); };
        //   smart_ptr<S>(new S{5});
        //   smart_ptr<S>(new S{}); // use default constructor
        // The arguments in the initialization list are going to be forwarded to
        // the constructor, so this has to be replaced with:
        //   std::make_smart_ptr<S>(5);
        //   std::make_smart_ptr<S>();
        InitRange = SourceRange(
            NewConstruct->getParenOrBraceRange().getBegin().getLocWithOffset(1),
            NewConstruct->getParenOrBraceRange().getEnd().getLocWithOffset(-1));
      }
    } else {
      // Aggregate initialization.
      //   smart_ptr<Pair>(new Pair{first, second});
      // Has to be replaced with:
      //   smart_ptr<Pair>(Pair{first, second});
      //
      // The fix (std::make_unique) needs to see copy/move constructor of
      // Pair. If we found any invisible (private or protected) or deleted
      // copy/move constructor, we stop generating fixes -- as the C++ rule is
      // complicated and we are less certain about the correct fixes.
      if (const CXXRecordDecl *RD = New->getType()->getPointeeCXXRecordDecl()) {
        if (llvm::any_of(RD->ctors(), [](const CXXConstructorDecl *Ctor) {
              return Ctor->isCopyOrMoveConstructor() &&
                     (Ctor->isDeleted() || Ctor->getAccess() == AS_private ||
                      Ctor->getAccess() == AS_protected);
            })) {
          return false;
        }
      }
      InitRange = SourceRange(
          New->getAllocatedTypeSourceInfo()->getTypeLoc().getBeginLoc(),
          New->getInitializer()->getSourceRange().getEnd());
    }
    Diag << FixItHint::CreateRemoval(
        CharSourceRange::getCharRange(NewStart, InitRange.getBegin()));
    Diag << FixItHint::CreateRemoval(
        SourceRange(InitRange.getEnd().getLocWithOffset(1), NewEnd));
    break;
  }
  }
  return true;
}

// Adds a fix-it including MakeSmartPtrFunctionHeader into file \p FD, unless
// the header option is empty or the inserter produces no insertion.
void MakeSmartPtrCheck::insertHeader(DiagnosticBuilder &Diag, FileID FD) {
  if (MakeSmartPtrFunctionHeader.empty())
    return;

  // Only the default "memory" header is included with angle brackets.
  const bool IsAngled = MakeSmartPtrFunctionHeader == StdMemoryHeader;
  auto IncludeFixit = Inserter->CreateIncludeInsertion(
      FD, MakeSmartPtrFunctionHeader, IsAngled);
  if (!IncludeFixit)
    return;
  Diag << *IncludeFixit;
}

} // namespace modernize
} // namespace tidy
} // namespace clang