package internal;

import term.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import bil.ExecutionType;
import ghidra.program.model.block.CodeBlock;
import ghidra.program.model.block.CodeBlockReferenceIterator;
import ghidra.program.model.pcode.PcodeOp;
import ghidra.util.exception.CancelledException;

/**
 * Static helpers that turn Ghidra pcode jump operations into block and jump terms.
 *
 * <p>When a jump occurs inside an instruction's pcode, the affected blocks are split and
 * artificial jumps are added so the control flow between them stays connected. All methods
 * read and mutate the static state held in {@code PcodeBlockData}.
 */
public final class JumpProcessing {

    /**
     * Mnemonics of the pcode operations treated as jumps.
     *
     * <p>The list stays a mutable {@link ArrayList}, as before. It is built without double-brace
     * initialization, which would create an anonymous {@code ArrayList} subclass.
     */
    public static List<String> jumps = new ArrayList<String>(Arrays.asList(
            "BRANCH",
            "CBRANCH",
            "BRANCHIND",
            "CALL",
            "CALLIND",
            "CALLOTHER",
            "RETURN"));

    // private constructor for non-instantiable classes
    private JumpProcessing() {
        throw new UnsupportedOperationException();
    }

    /**
     * Processes a jump pcode operation by checking where it occurs in the current pcode block.
     *
     * <p>A jump followed by more pcode operations of the same block is handled as a jump inside
     * the pcode block. Otherwise it is handled as a jump at the end of the pcode block. The
     * current block is the last entry of {@code PcodeBlockData.blocks}.
     *
     * @param mnemonic pcode mnemonic of the jump (not used by the current processing logic)
     * @param numberOfPcodeOps number of pcode operations in the current pcode block
     * @return {@code true} if a non-call jump occurred inside the pcode block (an intra
     *         instruction jump), {@code false} otherwise
     */
    public static Boolean processJump(String mnemonic, int numberOfPcodeOps) {
        Term<Blk> currentBlock = PcodeBlockData.blocks.get(PcodeBlockData.blocks.size() - 1);

        // The jump is not the last pcode op of the block if its index is below the last index.
        if (PcodeBlockData.pcodeIndex < numberOfPcodeOps - 1) {
            return processJumpInPcodeBlock(currentBlock);
        }

        processJumpAtEndOfPcodeBlocks(currentBlock);
        return false;
    }

    /**
     * Processes a jump that is the last operation of a pcode block.
     *
     * <p>If the jump is a RETURN inside an artificially created call-return block (TID ending in
     * {@code "_r"}), a RETURN jump is added to that block. Otherwise the pending defs and the
     * jump terms are appended to the current block.
     *
     * @param currentBlock the block term that was current when the jump was encountered
     */
    private static void processJumpAtEndOfPcodeBlocks(Term<Blk> currentBlock) {
        // Case 1: jump at the end of the pcode group but not at the end of the Ghidra generated
        // block, for an instruction without delay slots. Create a block for the next assembly
        // instruction.
        boolean followedByInstructionInBlock =
                PcodeBlockData.instructionIndex < PcodeBlockData.numberOfInstructionsInBlock - 1;
        if (followedByInstructionInBlock && PcodeBlockData.instruction.getDelaySlotDepth() == 0) {
            String fallThroughAddress = PcodeBlockData.instruction.getFallThrough().toString();
            PcodeBlockData.blocks.add(TermCreator.createBlkTerm(fallThroughAddress, null));
        }
        // Case 2: jump at the last pcode op of the last instruction in the Ghidra generated block.
        // If Case 1 applied, 'currentBlock' is now the second to last block, because the newly
        // added block belongs to the next instruction.
        if (PcodeBlockData.pcodeOp.getOpcode() == PcodeOp.RETURN && currentBlock.getTid().getId().endsWith("_r")) {
            // NOTE(review): unlike every other path in this class, this early return neither adds
            // PcodeBlockData.temporaryDefStorage to the block nor clears it. Confirm the caller
            // handles the pending defs as intended.
            redirectCallReturn(currentBlock);
            return;
        }
        Blk block = currentBlock.getTerm();
        block.addMultipleDefs(PcodeBlockData.temporaryDefStorage);
        block.addMultipleJumps(TermCreator.createJmpTerm(false));
        PcodeBlockData.temporaryDefStorage.clear();
    }

    /**
     * Processes a jump inside a pcode block.
     *
     * <p>The jump is handled either as an intra instruction jump or, for CALL, CALLIND and
     * CALLOTHER, as a call-return pair. The pending defs are cleared afterwards.
     *
     * @param currentBlock the block term that was current when the jump was encountered
     * @return {@code true} if an intra instruction jump (a non-call jump) was processed
     */
    private static boolean processJumpInPcodeBlock(Term<Blk> currentBlock) {
        boolean intraInstructionJumpOccurred = !isCall();
        if (intraInstructionJumpOccurred) {
            handleIntraInstructionJump(currentBlock.getTerm());
        } else {
            handleCallReturnPair(currentBlock);
        }
        PcodeBlockData.temporaryDefStorage.clear();
        return intraInstructionJumpOccurred;
    }

    /**
     * Handles a jump that occurs inside the pcode of a single assembly instruction.
     *
     * <p>If the current instruction is not the first of the Ghidra generated block
     * ({@code instructionIndex > 0}) and the current block already holds defs or jumps, the code
     * does two things. It adds an artificial jump from the previous instruction (its fall-from
     * address) to the current instruction. It then creates a new block for the current
     * instruction, so that the artificial jump has a valid target TID and the current defs and
     * jumps are isolated in their own block. Otherwise the pending defs and jumps are added to
     * the current block directly. The block may still be empty, for example when the current
     * instruction is only preceded by a nop.
     *
     * <p>Afterwards a block is always created for the pcode operations that follow the jump
     * within the same instruction.
     *
     * <p>Example: one basic block generated by Ghidra that is split into 4 basic blocks for
     * proper analysis. It includes the case where a jump from the previous assembly instruction
     * has to be added. Each assembly instruction is marked with [In], and blocks are separated
     * by dashed lines. The target rows of the branches are placeholders for the actual target
     * TIDs.
     *
     * <pre>
     * 1. [In]  ...                                         ...
     * 2.       RDI = COPY RDX                              RDI = COPY RDX
     * 3.                                                   BRANCH [row 5.]
     * 4.                                                   ---------------
     * 5. [In]  $U2360:1 = INT_EQUAL RCX, 0:8               $U2360:1 = INT_EQUAL RCX, 0:8
     * 6.       CBRANCH *[ram]0x1006e1:8, $U2360            CBRANCH *[ram]0x1006e1:8, $U2360
     * 7.                                                   BRANCH [row 9.]
     * 8.                                                   ---------------
     * 9.       RCX = INT_SUB RCX, 1:8                      RCX = INT_SUB RCX, 1:8
     * 10.      ...                                 ---->   ...
     * 11.      $U2380:1 = BOOL_NEGATE ZF                   $U2380:1 = BOOL_NEGATE ZF
     * 12.      CBRANCH *[ram]0x1006df:8, $U2380            CBRANCH *[ram]0x1006df:8, $U2380
     * 13.                                                  BRANCH [row 15.]
     * 14.                                                  ---------------
     * 15. [In] RAX = COPY RCX                              RAX = COPY RCX
     * 16.      ...                                         ...
     * </pre>
     *
     * @param currentBlock the current block
     */
    private static void handleIntraInstructionJump(Blk currentBlock) {
        boolean blockHasContent = !currentBlock.getDefs().isEmpty() || !currentBlock.getJmps().isEmpty();
        if (PcodeBlockData.instructionIndex > 0 && blockHasContent) {
            addBranchToCurrentBlock(currentBlock, PcodeBlockData.instruction.getFallFrom().toString(), PcodeBlockData.instruction.getAddress().toString());
            createNewBlockForIntraInstructionJump();
        } else {
            currentBlock.addMultipleDefs(PcodeBlockData.temporaryDefStorage);
            currentBlock.addMultipleJumps(TermCreator.createJmpTerm(true));
        }
        // Create a block for the pcode instructions after the intra jump (NOT for the next
        // assembly instruction). If the last block ends in a pair of jumps (CBRANCH + BRANCH),
        // advance the pcode index by 1 so that the next intra block gets the correct index.
        if (PcodeBlockData.blocks.get(PcodeBlockData.blocks.size() - 1).getTerm().getJmps().size() == 2) {
            PcodeBlockData.pcodeIndex += 1;
        }
        PcodeBlockData.blocks.add(TermCreator.createBlkTerm(PcodeBlockData.instruction.getAddress().toString(), String.valueOf(PcodeBlockData.pcodeIndex + 1)));
    }

    /**
     * Handles a call-return pair.
     *
     * <p>The pending defs are added to the current block. A return block (TID suffix "r") is
     * created for the current instruction address, and the call's return label is pointed at it.
     * The call jump is appended to the current block and the return block becomes the new last
     * block.
     *
     * @param currentBlock the current block term
     */
    private static void handleCallReturnPair(Term<Blk> currentBlock) {
        currentBlock.getTerm().addMultipleDefs(PcodeBlockData.temporaryDefStorage);
        // Only the first jump term produced for the call is used.
        Term<Jmp> jump = TermCreator.createJmpTerm(false).get(0);
        Term<Blk> returnBlock = TermCreator.createBlkTerm(PcodeBlockData.instruction.getAddress().toString(), "r");
        jump.getTerm().getCall().setReturn_(new Label(new Tid(returnBlock.getTid().getId(), returnBlock.getTid().getAddress())));
        currentBlock.getTerm().addJmp(jump);
        PcodeBlockData.blocks.add(returnBlock);
    }

    /**
     * Checks whether the current pcode operation is a call.
     *
     * @return {@code true} for CALL, CALLIND and CALLOTHER operations
     */
    private static boolean isCall() {
        switch (PcodeBlockData.pcodeOp.getOpcode()) {
            case PcodeOp.CALL:
            case PcodeOp.CALLIND:
            case PcodeOp.CALLOTHER:
                return true;
            default:
                return false;
        }
    }

    /**
     * Fixes the control flow by adding the missing jump between artificially generated blocks.
     *
     * <p>A block is split when a pcode jump operation occurs inside a pcode block. In that case a
     * jump is added at the end of the last block to the block of the next assembly instruction.
     *
     * @param intraInstructionJumpOccured whether a jump occurred within the pcode block
     *        (must not be {@code null}; it is unboxed)
     */
    public static void fixControlFlowWhenIntraInstructionJumpOccured(Boolean intraInstructionJumpOccured) {
        if (intraInstructionJumpOccured) {
            Term<Blk> lastBlock = PcodeBlockData.blocks.get(PcodeBlockData.blocks.size() - 1);
            addMissingJumpAfterInstructionSplit(lastBlock);
        }
    }

    /**
     * Adds the missing jump after a Ghidra generated block has been split, so that the control
     * flow between the blocks is maintained.
     *
     * <p>The pending defs are added to {@code lastBlock}, followed by a branch from the current
     * instruction to its fall-through address. A new block for the fall-through address is then
     * appended and the pending defs are cleared.
     *
     * @param lastBlock last block before the split
     */
    public static void addMissingJumpAfterInstructionSplit(Term<Blk> lastBlock) {
        lastBlock.getTerm().addMultipleDefs(PcodeBlockData.temporaryDefStorage);
        String fallThroughAddress = PcodeBlockData.instruction.getFallThrough().toString();
        addBranchToCurrentBlock(lastBlock.getTerm(), PcodeBlockData.instruction.getAddress().toString(), fallThroughAddress);
        PcodeBlockData.blocks.add(TermCreator.createBlkTerm(fallThroughAddress, null));
        PcodeBlockData.temporaryDefStorage.clear();
    }

    /**
     * Creates a new block holding the pending defs of the current assembly instruction and the
     * jump terms of the intra jump, and appends it to {@code PcodeBlockData.blocks}.
     */
    private static void createNewBlockForIntraInstructionJump() {
        // Blocks split from one instruction share its address, so a block that does not start at
        // pcode index 0 is distinguished by the pcode index of its first def.
        String pcodeIndexSuffix = null;
        if (!PcodeBlockData.temporaryDefStorage.isEmpty()) {
            int nextBlockStartIndex = PcodeBlockData.temporaryDefStorage.get(0).getTerm().getPcodeIndex();
            if (nextBlockStartIndex != 0) {
                pcodeIndexSuffix = String.valueOf(nextBlockStartIndex);
            }
        }
        Term<Blk> newBlock = TermCreator.createBlkTerm(PcodeBlockData.instruction.getAddress().toString(), pcodeIndexSuffix);
        newBlock.getTerm().addMultipleDefs(PcodeBlockData.temporaryDefStorage);
        newBlock.getTerm().addMultipleJumps(TermCreator.createJmpTerm(true));
        PcodeBlockData.blocks.add(newBlock);
    }

    /**
     * Adds an artificial BRANCH (GOTO) jump to the given block.
     *
     * <p>The jump's pcode index is the pcode index of the block's last def plus 1, or 1 if the
     * block has no defs.
     *
     * @param currentBlock block to which the jump is appended
     * @param jumpSiteAddress address of the instruction the jump is attributed to
     * @param gotoAddress address of the target block
     */
    public static void addBranchToCurrentBlock(Blk currentBlock, String jumpSiteAddress, String gotoAddress) {
        // NOTE(review): a block without defs but with existing jumps also gets index 1; confirm
        // this cannot collide with an existing jump TID.
        int artificialJmpIndex = 1;
        if (!currentBlock.getDefs().isEmpty()) {
            artificialJmpIndex = currentBlock.getDefs().get(currentBlock.getDefs().size() - 1).getTerm().getPcodeIndex() + 1;
        }
        Tid jmpTid = new Tid(String.format("instr_%s_%s", jumpSiteAddress, artificialJmpIndex), jumpSiteAddress);
        Tid gotoTid = new Tid(String.format("blk_%s", gotoAddress), gotoAddress);
        currentBlock.addJmp(new Term<Jmp>(jmpTid, new Jmp(ExecutionType.JmpType.GOTO, "BRANCH", new Label(gotoTid), artificialJmpIndex)));
    }

    /**
     * Adds a RETURN jump, with TID {@code instr_<address>_0_r}, to the artificially created
     * return block of a call.
     *
     * @param currentBlock the return block term
     */
    private static void redirectCallReturn(Term<Blk> currentBlock) {
        String instructionAddress = PcodeBlockData.instruction.getAddress().toString();
        Tid jmpTid = new Tid(String.format("instr_%s_%s_r", instructionAddress, 0), instructionAddress);
        Term<Jmp> ret = new Term<Jmp>(jmpTid, new Jmp(ExecutionType.JmpType.RETURN, PcodeBlockData.pcodeOp.getMnemonic(), TermCreator.createLabel(null), 0));
        currentBlock.getTerm().addJmp(ret);
    }

    /**
     * Adds an artificial jump when the latest generated block term ends on a definition.
     *
     * <p>If the block ends on a def and the code block has a destination, a branch to the first
     * destination is added. The branch is attributed to the address of the block's last def.
     *
     * @param lastBlockTerm latest generated block term
     * @param currentBlock code block from which the block term was generated
     */
    public static void handlePossibleDefinitionAtEndOfBlock(Term<Blk> lastBlockTerm, CodeBlock currentBlock) {
        if (HelperFunctions.lastInstructionIsDef(lastBlockTerm)) {
            String destinationAddress = getGotoAddressForDestination(currentBlock);
            if (destinationAddress != null) {
                String instrAddress = lastBlockTerm.getTerm().getDefs().get(lastBlockTerm.getTerm().getDefs().size() - 1).getTid().getAddress();
                addBranchToCurrentBlock(lastBlockTerm.getTerm(), instrAddress, destinationAddress);
            }
        }
    }

    /**
     * Returns the address of the first destination of a code block, if one exists.
     *
     * @param currentBlock code block whose destinations are queried
     * @return the first destination address as a string, or {@code null} if there is none or
     *         the query was cancelled
     */
    private static String getGotoAddressForDestination(CodeBlock currentBlock) {
        try {
            CodeBlockReferenceIterator destinations = currentBlock.getDestinations(HelperFunctions.monitor);
            if (destinations.hasNext()) {
                return destinations.next().getDestinationAddress().toString();
            }
        } catch (CancelledException e) {
            // Best effort: report and fall through to "no destination".
            System.out.printf("Could not retrieve destinations for codeBlock at: %s\n", currentBlock.getFirstStartAddress());
        }
        return null;
    }
}