use crate::intermediate_representation::BinOpType as IrBinOpType;
use crate::intermediate_representation::ByteSize;
use crate::intermediate_representation::CastOpType as IrCastOpType;
use crate::intermediate_representation::Expression as IrExpression;
use crate::intermediate_representation::UnOpType as IrUnOpType;
use apint::Width;
use serde::{Deserialize, Serialize};

pub mod variable;
pub use variable::*;

pub type Bitvector = apint::ApInt;

pub type BitSize = u16;

impl From<BitSize> for ByteSize {
    /// Convert to `ByteSize`, while always rounding up to the nearest full byte.
    fn from(bitsize: BitSize) -> ByteSize {
        ((bitsize as u64 + 7) / 8).into()
    }
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
pub enum Expression {
    Var(Variable),
    Const(Bitvector),
    Load {
        memory: Box<Expression>,
        address: Box<Expression>,
        endian: Endianness,
        size: BitSize,
    },
    Store {
        memory: Box<Expression>,
        address: Box<Expression>,
        value: Box<Expression>,
        endian: Endianness,
        size: BitSize,
    },
    BinOp {
        op: BinOpType,
        lhs: Box<Expression>,
        rhs: Box<Expression>,
    },
    UnOp {
        op: UnOpType,
        arg: Box<Expression>,
    },
    Cast {
        kind: CastType,
        width: BitSize,
        arg: Box<Expression>,
    },
    Let {
        var: Variable,
        bound_exp: Box<Expression>,
        body_exp: Box<Expression>,
    },
    Unknown {
        description: String,
        type_: Type,
    },
    IfThenElse {
        condition: Box<Expression>,
        true_exp: Box<Expression>,
        false_exp: Box<Expression>,
    },
    Extract {
        low_bit: BitSize,
        high_bit: BitSize,
        arg: Box<Expression>,
    },
    Concat {
        left: Box<Expression>,
        right: Box<Expression>,
    },
}

impl Expression {
    /// Resolve all let-bindings inside an expression to create an equivalent expression without usage of let-bindings.
    pub fn replace_let_bindings(&mut self) {
        use Expression::*;
        match self {
            Var(_) | Const(_) | Unknown { .. } => (),
            Load {
                memory, address, ..
            } => {
                memory.replace_let_bindings();
                address.replace_let_bindings();
            }
            Store {
                memory,
                address,
                value,
                ..
            } => {
                memory.replace_let_bindings();
                address.replace_let_bindings();
                value.replace_let_bindings();
            }
            BinOp { op: _, lhs, rhs } => {
                lhs.replace_let_bindings();
                rhs.replace_let_bindings();
            }
            UnOp { op: _, arg } => arg.replace_let_bindings(),
            Cast {
                kind: _,
                width: _,
                arg,
            } => arg.replace_let_bindings(),
            Let {
                var,
                bound_exp,
                body_exp,
            } => {
                let to_replace = Expression::Var(var.clone());
                body_exp.replace_let_bindings();
                body_exp.substitute(&to_replace, bound_exp);
                *self = *body_exp.clone();
            }
            IfThenElse {
                condition,
                true_exp,
                false_exp,
            } => {
                condition.replace_let_bindings();
                true_exp.replace_let_bindings();
                false_exp.replace_let_bindings();
            }
            Extract {
                low_bit: _,
                high_bit: _,
                arg,
            } => arg.replace_let_bindings(),
            Concat { left, right } => {
                left.replace_let_bindings();
                right.replace_let_bindings();
            }
        }
    }

    /// Substitutes all subexpressions equal to `to_replace` with the expression `replace_with`.
    fn substitute(&mut self, to_replace: &Expression, replace_with: &Expression) {
        use Expression::*;
        if self == to_replace {
            *self = replace_with.clone();
        } else {
            match self {
                Var(_) | Const(_) | Unknown { .. } => (),
                Load {
                    memory, address, ..
                } => {
                    memory.substitute(to_replace, replace_with);
                    address.substitute(to_replace, replace_with);
                }
                Store {
                    memory,
                    address,
                    value,
                    ..
                } => {
                    memory.substitute(to_replace, replace_with);
                    address.substitute(to_replace, replace_with);
                    value.substitute(to_replace, replace_with);
                }
                BinOp { op: _, lhs, rhs } => {
                    lhs.substitute(to_replace, replace_with);
                    rhs.substitute(to_replace, replace_with);
                }
                UnOp { op: _, arg } => arg.substitute(to_replace, replace_with),
                Cast {
                    kind: _,
                    width: _,
                    arg,
                } => arg.substitute(to_replace, replace_with),
                Let {
                    var: _,
                    bound_exp,
                    body_exp,
                } => {
                    bound_exp.substitute(to_replace, replace_with);
                    body_exp.substitute(to_replace, replace_with);
                }
                IfThenElse {
                    condition,
                    true_exp,
                    false_exp,
                } => {
                    condition.substitute(to_replace, replace_with);
                    true_exp.substitute(to_replace, replace_with);
                    false_exp.substitute(to_replace, replace_with);
                }
                Extract {
                    low_bit: _,
                    high_bit: _,
                    arg,
                } => arg.substitute(to_replace, replace_with),
                Concat { left, right } => {
                    left.substitute(to_replace, replace_with);
                    right.substitute(to_replace, replace_with);
                }
            }
        }
    }

    /// Compute the bitsize of the value that the expression computes.
    /// Return zero for `Store` expressions.
    pub fn bitsize(&self) -> BitSize {
        use Expression::*;
        match self {
            Var(var) => var.bitsize().unwrap(),
            Const(bitvector) => bitvector.width().to_usize() as u16,
            Load { size, .. } => *size,
            Store { .. } => 0,
            BinOp { op, lhs, rhs: _ } => {
                use BinOpType::*;
                match op {
                    EQ | NEQ | LT | LE | SLT | SLE => 1,
                    _ => lhs.bitsize(),
                }
            }
            UnOp { arg, .. } => arg.bitsize(),
            Cast { width, .. } => *width,
            Let { .. } => panic!(),
            Unknown {
                description: _,
                type_,
            } => type_.bitsize().unwrap(),
            IfThenElse { true_exp, .. } => true_exp.bitsize(),
            Extract {
                low_bit, high_bit, ..
            } => high_bit - low_bit,
            Concat { left, right } => left.bitsize() + right.bitsize(),
        }
    }
}

impl From<Expression> for IrExpression {
    /// Convert a BAP IR expression to an internal IR expression.
    /// Panics on expressions that are not expressions in the internal IR.
    /// Replaces `IfThenElse` expressions with `Unknown` expressions (thus losing some information).
    fn from(expr: Expression) -> IrExpression {
        use Expression::*;
        match expr {
            Var(var) => IrExpression::Var(var.into()),
            Const(bitvector) => {
                // The internal IR expects everything to be byte-sized, so we have to extend the bitvector if necessary.
                let size: ByteSize = bitvector.width().into();
                IrExpression::Const(
                    bitvector
                        .into_zero_extend(apint::BitWidth::from(size))
                        .unwrap(),
                )
            }
            Load { .. } | Store { .. } | Let { .. } => panic!(),
            IfThenElse { true_exp, .. } => IrExpression::Unknown {
                description: "BAP-IfThenElse-expression".into(),
                size: true_exp.bitsize().into(),
            },
            Unknown { description, type_ } => IrExpression::Unknown {
                description,
                size: type_.bitsize().unwrap().into(),
            },
            BinOp { op, lhs, rhs } => IrExpression::BinOp {
                op: op.into(),
                lhs: Box::new(IrExpression::from(*lhs)),
                rhs: Box::new(IrExpression::from(*rhs)),
            },
            UnOp { op, arg } => IrExpression::UnOp {
                op: op.into(),
                arg: Box::new(IrExpression::from(*arg)),
            },
            Cast { kind, width, arg } => {
                use CastType::*;
                match kind {
                    UNSIGNED => {
                        if width % 8 == 0 {
                            IrExpression::Cast {
                                arg: Box::new(IrExpression::from(*arg)),
                                op: IrCastOpType::IntZExt,
                                size: width.into(),
                            }
                        } else {
                            IrExpression::Unknown {
                                description: serde_json::to_string(&Cast { kind, width, arg })
                                    .unwrap(),
                                size: width.into(),
                            }
                        }
                    }
                    SIGNED => {
                        if width % 8 == 0 {
                            IrExpression::Cast {
                                arg: Box::new(IrExpression::from(*arg)),
                                op: IrCastOpType::IntSExt,
                                size: width.into(),
                            }
                        } else {
                            IrExpression::Unknown {
                                description: serde_json::to_string(&Cast { kind, width, arg })
                                    .unwrap(),
                                size: width.into(),
                            }
                        }
                    }
                    HIGH => {
                        if width == 1 {
                            IrExpression::BinOp {
                                op: IrBinOpType::IntSLess,
                                lhs: Box::new(IrExpression::Const(Bitvector::zero(
                                    (arg.bitsize() as usize).into(),
                                ))),
                                rhs: Box::new(IrExpression::from(*arg)),
                            }
                        } else if width % 8 == 0 {
                            let low_byte = (arg.bitsize() - width).into();
                            IrExpression::Subpiece {
                                arg: Box::new(IrExpression::from(*arg)),
                                low_byte,
                                size: width.into(),
                            }
                        } else {
                            IrExpression::Unknown {
                                description: serde_json::to_string(&Cast { kind, width, arg })
                                    .unwrap(),
                                size: width.into(),
                            }
                        }
                    }
                    LOW => {
                        if width == 1 {
                            IrExpression::Subpiece {
                                low_byte: ByteSize::new(0),
                                size: ByteSize::new(1),
                                arg: Box::new(IrExpression::BinOp {
                                    op: IrBinOpType::IntAnd,
                                    lhs: Box::new(IrExpression::Const(Bitvector::one(
                                        (arg.bitsize() as usize).into(),
                                    ))),
                                    rhs: Box::new(IrExpression::from(*arg)),
                                }),
                            }
                        } else if width % 8 == 0 {
                            IrExpression::Subpiece {
                                arg: Box::new(IrExpression::from(*arg)),
                                low_byte: (0 as u64).into(),
                                size: width.into(),
                            }
                        } else {
                            IrExpression::Unknown {
                                description: serde_json::to_string(&Cast { kind, width, arg })
                                    .unwrap(),
                                size: width.into(),
                            }
                        }
                    }
                }
            }
            Extract {
                low_bit,
                high_bit,
                arg,
            } => {
                if low_bit % 8 == 0 && (high_bit + 1) % 8 == 0 {
                    IrExpression::Subpiece {
                        size: (high_bit - low_bit + 1).into(),
                        low_byte: low_bit.into(),
                        arg: Box::new(IrExpression::from(*arg)),
                    }
                } else {
                    IrExpression::Unknown {
                        description: serde_json::to_string(&Extract {
                            low_bit,
                            high_bit,
                            arg,
                        })
                        .unwrap(),
                        size: (high_bit - low_bit + 1).into(),
                    }
                }
            }
            Concat { left, right } => IrExpression::BinOp {
                op: IrBinOpType::Piece,
                lhs: Box::new(IrExpression::from(*left)),
                rhs: Box::new(IrExpression::from(*right)),
            },
        }
    }
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum CastType {
    UNSIGNED,
    SIGNED,
    HIGH,
    LOW,
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum BinOpType {
    PLUS,
    MINUS,
    TIMES,
    DIVIDE,
    SDIVIDE,
    MOD,
    SMOD,
    LSHIFT,
    RSHIFT,
    ARSHIFT,
    AND,
    OR,
    XOR,
    EQ,
    NEQ,
    LT,
    LE,
    SLT,
    SLE,
}

impl From<BinOpType> for IrBinOpType {
    /// Translate binary operation types.
    fn from(op: BinOpType) -> IrBinOpType {
        use BinOpType::*;
        use IrBinOpType::*;
        match op {
            PLUS => IntAdd,
            MINUS => IntSub,
            TIMES => IntMult,
            DIVIDE => IntDiv,
            SDIVIDE => IntSDiv,
            MOD => IntRem,
            SMOD => IntSRem,
            LSHIFT => IntLeft,
            RSHIFT => IntRight,
            ARSHIFT => IntSRight,
            AND => IntAnd,
            OR => IntOr,
            XOR => IntXOr,
            EQ => IntEqual,
            NEQ => IntNotEqual,
            LT => IntLess,
            LE => IntLessEqual,
            SLT => IntSLess,
            SLE => IntSLessEqual,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum UnOpType {
    NEG,
    NOT,
}

impl From<UnOpType> for IrUnOpType {
    /// Translate unary operation types.
    fn from(op: UnOpType) -> IrUnOpType {
        use UnOpType::*;
        match op {
            NEG => IrUnOpType::Int2Comp,
            NOT => IrUnOpType::IntNegate,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum Endianness {
    LittleEndian,
    BigEndian,
}

#[cfg(test)]
mod tests {
    use super::*;

    fn register(name: &str) -> Variable {
        Variable {
            name: name.into(),
            type_: Type::Immediate(64),
            is_temp: false,
        }
    }

    #[test]
    fn variant_deserialization() {
        let string = "\"UNSIGNED\"";
        assert_eq!(CastType::UNSIGNED, serde_json::from_str(string).unwrap());
        let string = "\"NEG\"";
        assert_eq!(UnOpType::NEG, serde_json::from_str(string).unwrap());
    }

    #[test]
    fn bitvector_deserialization() {
        let bitv = Bitvector::from_u64(234);
        let string = serde_json::to_string(&bitv).unwrap();
        println!("{}", string);
        println!("{:?}", bitv);
        let string = "{\"digits\":[234],\"width\":[64]}";
        assert_eq!(bitv, serde_json::from_str(string).unwrap());
    }

    #[test]
    fn expression_deserialization() {
        let string = "{\"BinOp\":{\"lhs\":{\"Const\":{\"digits\":[234],\"width\":[8]}},\"op\":\"PLUS\",\"rhs\":{\"Const\":{\"digits\":[234],\"width\":[8]}}}}";
        let bitv = Bitvector::from_u8(234);
        let exp = Expression::BinOp {
            op: BinOpType::PLUS,
            lhs: Box::new(Expression::Const(bitv.clone())),
            rhs: Box::new(Expression::Const(bitv)),
        };
        println!("{}", serde_json::to_string(&exp).unwrap());
        assert_eq!(exp, serde_json::from_str(string).unwrap())
    }

    #[test]
    fn replace_let_bindings() {
        let mut source_exp = Expression::Let {
            var: register("x"),
            bound_exp: Box::new(Expression::Const(Bitvector::from_u64(12))),
            body_exp: Box::new(Expression::BinOp {
                op: BinOpType::PLUS,
                lhs: Box::new(Expression::Var(register("x"))),
                rhs: Box::new(Expression::Const(Bitvector::from_u64(42))),
            }),
        };
        let target_exp = Expression::BinOp {
            op: BinOpType::PLUS,
            lhs: Box::new(Expression::Const(Bitvector::from_u64(12))),
            rhs: Box::new(Expression::Const(Bitvector::from_u64(42))),
        };

        source_exp.replace_let_bindings();
        assert_eq!(source_exp, target_exp);
    }
}