//! Creating and computing backward interprocedural fixpoint problems.
//!
//! # General notes
//!
//! This module supports computation of fixpoint problems on the control flow graphs generated by the `graph` module.
//!
//! To compute a generalized fixpoint problem,
//! first construct a context object implementing the `Context` trait.
//! Use it to construct a `Computation` object, e.g. via `create_computation`.
//! The `Computation` object provides the necessary methods for the actual fixpoint computation.

use super::fixpoint::Context as GeneralFPContext;
use super::graph::*;
use super::interprocedural_fixpoint_generic::*;
use crate::intermediate_representation::*;
use petgraph::graph::EdgeIndex;
use std::marker::PhantomData;

/// The context for a backward interprocedural fixpoint computation.
///
/// Basically, a `Context` object needs to contain a reference to the actual graph,
/// a method for merging node values,
/// and methods for computing the edge transitions for each different edge type.
///
/// All trait methods take `&self`, so any context information stored in the implementing type
/// is accessible to them.
///
/// All edge transition functions can return `None` to indicate that no information flows through the edge.
/// For example, this can be used to indicate edges that can never be taken.
pub trait Context<'a> {
    /// The type of the values computed for the nodes of the graph.
    /// Values must be comparable for equality and cloneable.
    type Value: PartialEq + Eq + Clone;

    /// Get a reference to the graph that the fixpoint is computed on.
    /// The return value is expected to be the reversed CFG.
    fn get_graph(&self) -> &Graph<'a>;

    /// Merge two node values.
    fn merge(&self, value1: &Self::Value, value2: &Self::Value) -> Self::Value;

    /// Transition function for `Def` terms.
    /// The transition function for a basic block is computed
    /// by iteratively applying this function to the starting value for each `Def` term in the basic block.
    /// Since the computation runs backwards, the `Def` terms of a block are visited in reverse order.
    /// The iteration short-circuits and returns `None` if `update_def` returns `None` at any point.
    fn update_def(&self, value: &Self::Value, def: &Term<Def>) -> Option<Self::Value>;

    /// Transition function for (conditional and unconditional) `Jmp` terms.
    ///
    /// `value_after_jump` is the value coming in from the jump target,
    /// `jump` is the jump term belonging to the edge and `jumpsite` is the block containing the jump.
    /// NOTE(review): `untaken_conditional` presumably holds the conditional jump of the jumpsite
    /// that was *not* taken along this edge (if any) — confirm against the graph construction.
    fn update_jumpsite(
        &self,
        value_after_jump: &Self::Value,
        jump: &Term<Jmp>,
        untaken_conditional: Option<&Term<Jmp>>,
        jumpsite: &Term<Blk>,
    ) -> Option<Self::Value>;

    /// Transition function for in-program calls.
    /// The target value is coming in via the call edge from the BlkStart node of the called subroutine and
    /// the return_value is coming in via the call stub edge from the returned-to node of the caller.
    /// Either value may be `None` if no information arrived along the corresponding edge.
    fn update_callsite(
        &self,
        target_value: Option<&Self::Value>,
        return_value: Option<&Self::Value>,
        caller_sub: &Term<Sub>,
        call: &Term<Jmp>,
        return_: &Term<Jmp>,
    ) -> Option<Self::Value>;

    /// Transition function for call stub split.
    /// Has access to the value at the ReturnCombine node and
    /// decides which data is transferred along the Call Stub Edge.
    fn split_call_stub(&self, combined_value: &Self::Value) -> Option<Self::Value>;

    /// Transition function for return stub split.
    /// Has access to the value at the ReturnCombine node and
    /// decides which data is transferred along the Return Stub Edge
    /// into the subroutine `returned_from_sub` that the program returns from.
    fn split_return_stub(
        &self,
        combined_value: &Self::Value,
        returned_from_sub: &Term<Sub>,
    ) -> Option<Self::Value>;

    /// Transition function for calls to functions not contained in the binary.
    /// In the forward CFG the corresponding edge goes from the callsite to the returned-to block;
    /// in the reversed graph used here it points the other way.
    fn update_call_stub(
        &self,
        value_after_call: &Self::Value,
        call: &Term<Jmp>,
    ) -> Option<Self::Value>;

    /// This function is used to refine the value using the information on which branch was taken on a conditional jump.
    /// NOTE(review): not called by the generic edge transition in this file — presumably intended
    /// as a helper for implementations of `update_jumpsite`; confirm.
    fn specialize_conditional(
        &self,
        value_after_jump: &Self::Value,
        condition: &Expression,
        is_true: bool,
    ) -> Option<Self::Value>;
}

impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> {
    type EdgeLabel = Edge<'a>;
    type NodeLabel = Node<'a>;
    type NodeValue = NodeValue<T::Value>;

    /// Get a reference to the underlying graph.
    fn get_graph(&self) -> &Graph<'a> {
        self.context.get_graph()
    }

    /// Merge two values using the merge function from the interprocedural context object.
    ///
    /// Plain values are merged directly, `CallFlowCombinator`s are merged component-wise.
    ///
    /// # Panics
    ///
    /// Panics if the two values are of different variants,
    /// which indicates a malformed control flow graph.
    fn merge(&self, val1: &Self::NodeValue, val2: &Self::NodeValue) -> Self::NodeValue {
        use NodeValue::*;
        match (val1, val2) {
            (Value(value1), Value(value2)) => Value(self.context.merge(value1, value2)),
            (
                CallFlowCombinator {
                    call_stub: call1,
                    interprocedural_flow: target1,
                },
                CallFlowCombinator {
                    call_stub: call2,
                    interprocedural_flow: target2,
                },
            ) => CallFlowCombinator {
                call_stub: merge_option(call1, call2, |v1, v2| self.context.merge(v1, v2)),
                interprocedural_flow: merge_option(target1, target2, |v1, v2| {
                    self.context.merge(v1, v2)
                }),
            },
            _ => panic!("Malformed CFG in fixpoint computation"),
        }
    }

    /// Backward edge transition function.
    ///
    /// Applies the transition function from the interprocedural context object
    /// that corresponds to the type of the provided edge.
    /// Returns `None` if no information flows along the edge.
    ///
    /// # Panics
    ///
    /// Panics if the node value variant or the edge endpoints do not fit the edge type,
    /// which indicates a malformed control flow graph.
    fn update_edge(
        &self,
        node_value: &Self::NodeValue,
        edge: EdgeIndex,
    ) -> Option<Self::NodeValue> {
        let graph = self.context.get_graph();
        let (start_node, end_node) = graph.edge_endpoints(edge).unwrap();

        match graph.edge_weight(edge).unwrap() {
            // The analysis runs backwards, so the `Def` terms of the block
            // have to be processed in reverse order.
            Edge::Block => {
                let block_term = graph.node_weight(start_node).unwrap().get_block();
                let value = node_value.unwrap_value();
                let defs = &block_term.term.defs;
                let end_val = defs.iter().rev().try_fold(value.clone(), |accum, def| {
                    self.context.update_def(&accum, def)
                });
                end_val.map(NodeValue::Value)
            }
            Edge::ReturnCombine(_) => {
                Some(Self::NodeValue::Value(node_value.unwrap_value().clone()))
            }
            // The Call Edge value is added to the `CallFlowCombinator`.
            // The end node will be the callsite node and the node_value parameter is the value at the
            // called subroutine's BlkStart node.
            Edge::Call(_) => Some(NodeValue::CallFlowCombinator {
                call_stub: None,
                interprocedural_flow: Some(node_value.unwrap_value().clone()),
            }),
            // The CallStub Edge value is added to the `CallFlowCombinator`.
            // The user has the ability to split the node value at the BlkStart return to node
            // to only send specific data along the CallStub Edge to the callsite.
            Edge::CRCallStub => Some(NodeValue::CallFlowCombinator {
                call_stub: self.context.split_call_stub(node_value.unwrap_value()),
                interprocedural_flow: None,
            }),
            // The user has the ability to split the node value at the BlkStart return node
            // to only send specific data along the ReturnStub Edge to the last BlkEnd node
            // of the called subroutine.
            Edge::CRReturnStub => {
                // The subroutine term from which the program returns
                let returned_from_sub = match graph.node_weight(end_node) {
                    Some(Node::BlkEnd(_, sub_term)) => sub_term,
                    _ => panic!("Malformed Control flow graph"),
                };
                self.context
                    .split_return_stub(node_value.unwrap_value(), returned_from_sub)
                    .map(NodeValue::Value)
            }
            // The CallCombine Edge merges the values coming in from the CallStub Edge and Call Edge.
            // It also gives the user access to the call and return term.
            Edge::CallCombine(return_term) => match node_value {
                NodeValue::Value(_) => panic!("Unexpected interprocedural fixpoint graph state"),
                NodeValue::CallFlowCombinator {
                    call_stub,
                    interprocedural_flow,
                } => {
                    let (call_block, caller_sub) = match graph.node_weight(start_node) {
                        Some(Node::CallSource {
                            source: (call_block, call_sub),
                            ..
                        }) => (call_block, call_sub),
                        _ => panic!("Malformed Control flow graph"),
                    };
                    // NOTE(review): assumes the call term is the first jump of the call block.
                    let call_term = &call_block.term.jmps[0];
                    self.context
                        .update_callsite(
                            interprocedural_flow.as_ref(),
                            call_stub.as_ref(),
                            caller_sub,
                            call_term,
                            return_term,
                        )
                        .map(NodeValue::Value)
                }
            },
            Edge::ExternCallStub(call) => self
                .context
                .update_call_stub(node_value.unwrap_value(), call)
                .map(NodeValue::Value),
            Edge::Jump(jump, untaken_conditional) => self
                .context
                .update_jumpsite(
                    node_value.unwrap_value(),
                    jump,
                    *untaken_conditional,
                    graph[end_node].get_block(),
                )
                .map(NodeValue::Value),
        }
    }
}

/// Wrapper that turns a backward interprocedural fixpoint context (an implementor of `Context`)
/// into a context usable by the general fixpoint algorithm.
pub struct GeneralizedContext<'a, T: Context<'a>> {
    /// The wrapped interprocedural context object.
    context: T,
    /// Marker that ties the graph lifetime `'a` to this struct without storing a graph.
    _phantom_graph_reference: PhantomData<Graph<'a>>,
}

impl<'a, T: Context<'a>> GeneralizedContext<'a, T> {
    /// Wrap the given interprocedural context object into a generalized context.
    pub fn new(context: T) -> Self {
        Self {
            _phantom_graph_reference: PhantomData,
            context,
        }
    }

    /// Return a reference to the wrapped interprocedural context object.
    pub fn get_context(&self) -> &T {
        let GeneralizedContext { context, .. } = self;
        context
    }
}

/// Create a new fixpoint computation for a backward interprocedural fixpoint problem.
///
/// The given context object is wrapped into a `GeneralizedContext`.
/// If `default_value` is `Some`, it is wrapped as `NodeValue::Value`
/// and handed to the computation as the default value for nodes.
pub fn create_computation<'a, T: Context<'a>>(
    problem: T,
    default_value: Option<T::Value>,
) -> super::fixpoint::Computation<GeneralizedContext<'a, T>> {
    let default_node_value = default_value.map(NodeValue::Value);
    super::fixpoint::Computation::new(GeneralizedContext::new(problem), default_node_value)
}

#[cfg(test)]
pub mod tests;

#[cfg(test)]
pub mod mock_context;