use super::{AbstractDomain, HasBitSize, HasTop, RegisterDomain};
use crate::bil::*;
use crate::prelude::*;

/// The `BitvectorDomain` is a simple abstract domain describing a bitvector of known length.
///
/// As values it can only assume a known bitvector or *Top(bitsize)*.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
pub enum BitvectorDomain {
    /// An unknown value. The payload is the bitsize of the unknown value.
    Top(BitSize),
    /// A known, concrete value. Its bitsize is the width of the contained bitvector.
    Value(Bitvector),
}

impl AbstractDomain for BitvectorDomain {
    /// Merge `self` with `other`.
    ///
    /// Equal values are kept as they are.
    /// Everything else collapses to a *Top* value with the bitsize of `self`.
    fn merge(&self, other: &Self) -> Self {
        if self != other {
            return self.top();
        }
        self.clone()
    }

    /// Return `true` if and only if `self` is a *Top* value.
    fn is_top(&self) -> bool {
        match self {
            Self::Top(_) => true,
            Self::Value(_) => false,
        }
    }
}

impl HasTop for BitvectorDomain {
    /// Build the *Top* element whose bitsize matches the bitsize of `self`.
    fn top(&self) -> BitvectorDomain {
        let bitsize = self.bitsize();
        Self::Top(bitsize)
    }
}

impl HasBitSize for BitvectorDomain {
    /// Return the bitsize of `self`.
    fn bitsize(&self) -> BitSize {
        use BitvectorDomain::*;
        match self {
            Top(bitsize) => *bitsize,
            Value(bitvec) => bitvec.width().to_usize() as u16,
        }
    }
}

impl RegisterDomain for BitvectorDomain {
    /// Get a *Top* element with the given bitsize.
    fn new_top(bitsize: BitSize) -> BitvectorDomain {
        BitvectorDomain::Top(bitsize)
    }

    /// Evaluate the given binary operation.
    ///
    /// If both operands are known values, the result is computed exactly,
    /// with the following special cases:
    /// - Division and remainder operations with a zero divisor evaluate to *Top*
    ///   (with the bitsize of `self`) instead of panicking.
    /// - A logical shift by at least the width of `self` evaluates to zero.
    ///   An arithmetic right shift by at least the width of `self` evaluates to zero
    ///   for non-negative values and to minus one for negative values.
    ///   Shift amounts that are too large to be represented as a `u64` are treated the same way.
    ///
    /// If at least one operand is *Top*, the result is *Top*:
    /// with bitsize 1 for comparison operations and with the bitsize of `self` otherwise.
    ///
    /// # Panics
    ///
    /// For non-shift operations, this function will panic if the operands have different bitsizes.
    fn bin_op(&self, op: BinOpType, rhs: &Self) -> Self {
        use BinOpType::*;

        /// Return the shift amount if it is strictly smaller than `width`.
        ///
        /// `None` covers both shifts by at least `width` bits
        /// and shift amounts that cannot be represented as a `u64`.
        fn shift_below_width(amount: &Bitvector, width: usize) -> Option<usize> {
            let limit = width as u64;
            match amount.try_to_u64() {
                // The cast is lossless, because the value is smaller than a `usize` value.
                Ok(value) if value < limit => Some(value as usize),
                _ => None,
            }
        }

        match op {
            LSHIFT | RSHIFT | ARSHIFT => (),
            _ => assert_eq!(self.bitsize(), rhs.bitsize()),
        }
        // Fallback for operations without a defined result (i.e. division by zero).
        let undefined = || BitvectorDomain::new_top(self.bitsize());
        match (self, rhs) {
            (BitvectorDomain::Value(lhs_bitvec), BitvectorDomain::Value(rhs_bitvec)) => match op {
                PLUS => BitvectorDomain::Value(lhs_bitvec + rhs_bitvec),
                MINUS => BitvectorDomain::Value(lhs_bitvec - rhs_bitvec),
                TIMES => BitvectorDomain::Value(lhs_bitvec * rhs_bitvec),
                // The operand widths were checked above,
                // so the checked division/remainder operations are expected to fail
                // only for a zero divisor. The result is unknown in that case.
                DIVIDE => lhs_bitvec
                    .clone()
                    .into_checked_udiv(rhs_bitvec)
                    .map(BitvectorDomain::Value)
                    .unwrap_or_else(|_| undefined()),
                SDIVIDE => lhs_bitvec
                    .clone()
                    .into_checked_sdiv(rhs_bitvec)
                    .map(BitvectorDomain::Value)
                    .unwrap_or_else(|_| undefined()),
                MOD => lhs_bitvec
                    .clone()
                    .into_checked_urem(rhs_bitvec)
                    .map(BitvectorDomain::Value)
                    .unwrap_or_else(|_| undefined()),
                SMOD => lhs_bitvec
                    .clone()
                    .into_checked_srem(rhs_bitvec)
                    .map(BitvectorDomain::Value)
                    .unwrap_or_else(|_| undefined()),
                LSHIFT => match shift_below_width(rhs_bitvec, lhs_bitvec.width().to_usize()) {
                    Some(shift_amount) => BitvectorDomain::Value(
                        lhs_bitvec.clone().into_checked_shl(shift_amount).unwrap(),
                    ),
                    // All bits are shifted out.
                    None => BitvectorDomain::Value(Bitvector::zero(lhs_bitvec.width())),
                },
                RSHIFT => match shift_below_width(rhs_bitvec, lhs_bitvec.width().to_usize()) {
                    Some(shift_amount) => BitvectorDomain::Value(
                        lhs_bitvec.clone().into_checked_lshr(shift_amount).unwrap(),
                    ),
                    // All bits are shifted out.
                    None => BitvectorDomain::Value(Bitvector::zero(lhs_bitvec.width())),
                },
                ARSHIFT => match shift_below_width(rhs_bitvec, lhs_bitvec.width().to_usize()) {
                    Some(shift_amount) => BitvectorDomain::Value(
                        lhs_bitvec.clone().into_checked_ashr(shift_amount).unwrap(),
                    ),
                    None => {
                        // Every bit of the result is a copy of the sign bit.
                        let signed_bitvec = apint::Int::from(lhs_bitvec.clone());
                        if signed_bitvec.is_negative() {
                            let minus_one = Bitvector::zero(lhs_bitvec.width())
                                - &Bitvector::one(lhs_bitvec.width());
                            BitvectorDomain::Value(minus_one)
                        } else {
                            BitvectorDomain::Value(Bitvector::zero(lhs_bitvec.width()))
                        }
                    }
                },
                AND => BitvectorDomain::Value(lhs_bitvec & rhs_bitvec),
                OR => BitvectorDomain::Value(lhs_bitvec | rhs_bitvec),
                XOR => BitvectorDomain::Value(lhs_bitvec ^ rhs_bitvec),
                EQ => {
                    assert_eq!(lhs_bitvec.width(), rhs_bitvec.width());
                    BitvectorDomain::Value(Bitvector::from(lhs_bitvec == rhs_bitvec))
                }
                NEQ => {
                    assert_eq!(lhs_bitvec.width(), rhs_bitvec.width());
                    BitvectorDomain::Value(Bitvector::from(lhs_bitvec != rhs_bitvec))
                }
                LT => BitvectorDomain::Value(Bitvector::from(
                    lhs_bitvec.checked_ult(rhs_bitvec).unwrap(),
                )),
                LE => BitvectorDomain::Value(Bitvector::from(
                    lhs_bitvec.checked_ule(rhs_bitvec).unwrap(),
                )),
                SLT => BitvectorDomain::Value(Bitvector::from(
                    lhs_bitvec.checked_slt(rhs_bitvec).unwrap(),
                )),
                SLE => BitvectorDomain::Value(Bitvector::from(
                    lhs_bitvec.checked_sle(rhs_bitvec).unwrap(),
                )),
            },
            _ => match op {
                PLUS | MINUS | TIMES | DIVIDE | SDIVIDE | MOD | SMOD | LSHIFT | RSHIFT
                | ARSHIFT | AND | OR | XOR => BitvectorDomain::new_top(self.bitsize()),
                // Comparison results are always one bit wide.
                EQ | NEQ | LT | LE | SLT | SLE => BitvectorDomain::new_top(1),
            },
        }
    }

    /// Evaluate the given unary operation.
    ///
    /// For a *Top* value the result is *Top* with the same bitsize.
    fn un_op(&self, op: UnOpType) -> Self {
        use UnOpType::*;
        if let BitvectorDomain::Value(bitvec) = self {
            match op {
                NEG => BitvectorDomain::Value(-bitvec),
                NOT => BitvectorDomain::Value(bitvec.clone().into_bitnot()),
            }
        } else {
            BitvectorDomain::new_top(self.bitsize())
        }
    }

    /// Extract a sub-bitvector out of a bitvector
    ///
    /// The result consists of the bits from `low_bit` up to and including `high_bit`,
    /// i.e. it has bitsize `high_bit - low_bit + 1`.
    /// For a *Top* value the result is *Top* with that bitsize.
    ///
    /// # Panics
    ///
    /// Panics if the underlying shift or truncation of a known value fails.
    fn extract(&self, low_bit: BitSize, high_bit: BitSize) -> Self {
        if let BitvectorDomain::Value(bitvec) = self {
            BitvectorDomain::Value(
                bitvec
                    .clone()
                    // Move the lowest requested bit to position zero ...
                    .into_checked_lshr(low_bit as usize)
                    .unwrap()
                    // ... and cut off everything above the highest requested bit.
                    .into_truncate((high_bit - low_bit + 1) as usize)
                    .unwrap(),
            )
        } else {
            BitvectorDomain::new_top(high_bit - low_bit + 1)
        }
    }

    /// Perform a size-changing cast on a bitvector.
    ///
    /// The result has bitsize `width`:
    /// - `UNSIGNED` zero-extends the value.
    /// - `SIGNED` sign-extends the value.
    /// - `HIGH` keeps the `width` most significant bits.
    /// - `LOW` keeps the `width` least significant bits.
    ///
    /// For a *Top* value the result is *Top* with bitsize `width`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying extension, shift or truncation of a known value fails.
    fn cast(&self, kind: CastType, width: BitSize) -> Self {
        if let BitvectorDomain::Value(bitvec) = self {
            use CastType::*;
            match kind {
                UNSIGNED => {
                    BitvectorDomain::Value(bitvec.clone().into_zero_extend(width as usize).unwrap())
                }
                SIGNED => {
                    BitvectorDomain::Value(bitvec.clone().into_sign_extend(width as usize).unwrap())
                }
                HIGH => BitvectorDomain::Value(
                    bitvec
                        .clone()
                        // Shift the upper `width` bits down before truncating.
                        .into_checked_lshr((self.bitsize() - width) as usize)
                        .unwrap()
                        .into_truncate(width as usize)
                        .unwrap(),
                ),
                LOW => {
                    BitvectorDomain::Value(bitvec.clone().into_truncate(width as usize).unwrap())
                }
            }
        } else {
            BitvectorDomain::new_top(width)
        }
    }

    /// Concatenate two bitvectors.
    ///
    /// The bits of `self` form the upper bits and the bits of `other` the lower bits of the result.
    /// The bitsize of the result is the sum of both bitsizes.
    /// If at least one operand is *Top*, the result is *Top*.
    fn concat(&self, other: &Self) -> Self {
        match (self, other) {
            (BitvectorDomain::Value(left_bitvec), BitvectorDomain::Value(right_bitvec)) => {
                let new_bitwidth = (self.bitsize() + other.bitsize()) as usize;
                let upper_bits = left_bitvec
                    .clone()
                    .into_zero_extend(new_bitwidth)
                    .unwrap()
                    .into_checked_shl(other.bitsize() as usize)
                    .unwrap();
                let lower_bits = right_bitvec.clone().into_zero_extend(new_bitwidth).unwrap();
                BitvectorDomain::Value(upper_bits | &lower_bits)
            }
            _ => BitvectorDomain::new_top(self.bitsize() + other.bitsize()),
        }
    }
}

impl std::ops::Add for BitvectorDomain {
    type Output = BitvectorDomain;

    /// Add two values by evaluating the `PLUS` binary operation.
    fn add(self, rhs: Self) -> Self {
        use crate::bil::BinOpType::PLUS;
        self.bin_op(PLUS, &rhs)
    }
}

impl std::ops::Sub for BitvectorDomain {
    type Output = BitvectorDomain;

    /// Subtract `rhs` from `self` by evaluating the `MINUS` binary operation.
    fn sub(self, rhs: Self) -> Self {
        use crate::bil::BinOpType::MINUS;
        self.bin_op(MINUS, &rhs)
    }
}

impl std::ops::Neg for BitvectorDomain {
    type Output = BitvectorDomain;

    /// Negate the value by evaluating the `NEG` unary operation.
    fn neg(self) -> Self {
        use crate::bil::UnOpType::NEG;
        self.un_op(NEG)
    }
}

impl std::convert::From<Bitvector> for BitvectorDomain {
    fn from(bitvector: Bitvector) -> BitvectorDomain {
        BitvectorDomain::Value(bitvector)
    }
}

impl std::convert::TryFrom<&BitvectorDomain> for Bitvector {
    type Error = ();

    /// Get a copy of the contained bitvector if the value is known.
    /// Returns an error for *Top* values.
    fn try_from(bitvec_domain: &BitvectorDomain) -> Result<Bitvector, ()> {
        if let BitvectorDomain::Value(known) = bitvec_domain {
            Ok(known.clone())
        } else {
            Err(())
        }
    }
}

impl std::fmt::Display for BitvectorDomain {
    /// Print known values as zero-padded hexadecimal numbers and unknown values as `Top`,
    /// each followed by the bitsize of the value.
    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::Value(bitvector) => {
                let width = bitvector.width().to_usize();
                write!(formatter, "0x{:016x}:i{:?}", bitvector, width)
            }
            Self::Top(bitsize) => write!(formatter, "Top:i{}", bitsize),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A known 64-bit value.
    fn bv(value: i64) -> BitvectorDomain {
        BitvectorDomain::Value(Bitvector::from_i64(value))
    }

    /// A known 32-bit value.
    fn bv32(value: i32) -> BitvectorDomain {
        BitvectorDomain::Value(Bitvector::from_i32(value))
    }

    /// A known 8-bit value.
    fn bv8(value: u8) -> BitvectorDomain {
        BitvectorDomain::Value(Bitvector::from_u8(value))
    }

    /// A known 1-bit value, as produced by comparison operations.
    fn bit(flag: bool) -> BitvectorDomain {
        BitvectorDomain::Value(Bitvector::from_bit(flag))
    }

    #[test]
    fn bitvector_domain_as_value_domain() {
        use crate::bil::BinOpType::*;
        use crate::bil::CastType::*;
        use crate::bil::UnOpType::*;
        let lhs = bv(16);
        let rhs = bv(8);

        // Arithmetic operations
        assert_eq!(lhs.bin_op(PLUS, &rhs), bv(24));
        assert_eq!(lhs.bin_op(MINUS, &rhs), bv(8));
        assert_eq!(lhs.bin_op(TIMES, &rhs), bv(128));
        assert_eq!(lhs.bin_op(DIVIDE, &rhs), bv(2));
        assert_eq!(lhs.bin_op(SDIVIDE, &rhs), bv(2));
        assert_eq!(lhs.bin_op(MOD, &rhs), bv(0));
        assert_eq!(lhs.bin_op(SMOD, &rhs), bv(0));

        // Shifts
        let two = bv(2);
        assert_eq!(lhs.bin_op(LSHIFT, &two), bv(64));
        assert_eq!(lhs.bin_op(RSHIFT, &two), bv(4));
        assert_eq!(lhs.bin_op(ARSHIFT, &two), bv(4));

        // Bitwise operations
        assert_eq!(lhs.bin_op(AND, &rhs), bv(0));
        assert_eq!(lhs.bin_op(OR, &rhs), bv(24));
        assert_eq!(lhs.bin_op(XOR, &rhs), bv(24));

        // Comparisons yield one-bit values
        assert_eq!(lhs.bin_op(EQ, &bv(16)), bit(true));
        assert_eq!(lhs.bin_op(NEQ, &bv(16)), bit(false));

        // Unary operations
        assert_eq!(lhs.un_op(NEG), bv(-16));
        assert_eq!(bv(0).un_op(NOT), bv(-1));

        // Extraction of the lower and the upper half
        assert_eq!(lhs.extract(0, 31), bv32(16));
        assert_eq!(lhs.extract(32, 63), bv32(0));

        // Cast to the upper half
        assert_eq!(bv32(2), bv(2 << 32).cast(HIGH, 32));

        // Concatenation of two 32-bit values to one 64-bit value
        assert_eq!(bv32(-1).concat(&bv32(-1)), bv(-1));
    }

    #[test]
    fn bitvector_domain_as_abstract_domain() {
        let seventeen = bv(17);
        let top = BitvectorDomain::new_top(64);

        assert_eq!(seventeen.merge(&bv(17)), bv(17));
        assert_eq!(seventeen.merge(&bv(16)), top);
        assert!(!seventeen.is_top());
        assert!(top.is_top());
    }

    #[test]
    fn arshift() {
        use crate::bil::BinOpType::ARSHIFT;
        let positive_x = bv(31);
        let negative_x = bv(-31);
        let shift_3 = bv8(3);
        // Larger than the width of the shifted values.
        let shift_70 = bv8(70);

        assert_eq!(positive_x.bin_op(ARSHIFT, &shift_3), bv(3));
        assert_eq!(positive_x.bin_op(ARSHIFT, &shift_70), bv(0));
        assert_eq!(negative_x.bin_op(ARSHIFT, &shift_3), bv(-4));
        assert_eq!(negative_x.bin_op(ARSHIFT, &shift_70), bv(-1));
    }
}