Unverified Commit a6ef477d by Enkelmann Committed by GitHub

Use conditional branches to restrict value sets of variables (#159)

parent 83593fac
use super::{ use super::{
AbstractDomain, AbstractIdentifier, HasTop, Interval, PointerDomain, RegisterDomain, AbstractDomain, AbstractIdentifier, HasTop, Interval, PointerDomain, RegisterDomain,
SizedDomain, TryToBitvec, TryToInterval, SizedDomain, SpecializeByConditional, TryToBitvec, TryToInterval,
}; };
use crate::intermediate_representation::*; use crate::intermediate_representation::*;
use crate::prelude::*; use crate::prelude::*;
...@@ -68,6 +68,48 @@ impl<T: RegisterDomain> DataDomain<T> { ...@@ -68,6 +68,48 @@ impl<T: RegisterDomain> DataDomain<T> {
} }
} }
impl<T: SpecializeByConditional + RegisterDomain> SpecializeByConditional for DataDomain<T> {
fn add_signed_less_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_signed_less_equal_bound(bound)?))
} else {
Ok(self)
}
}
fn add_unsigned_less_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_unsigned_less_equal_bound(bound)?))
} else {
Ok(self)
}
}
fn add_signed_greater_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_signed_greater_equal_bound(bound)?))
} else {
Ok(self)
}
}
fn add_unsigned_greater_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_unsigned_greater_equal_bound(bound)?))
} else {
Ok(self)
}
}
fn add_not_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_not_equal_bound(bound)?))
} else {
Ok(self)
}
}
}
impl<T: RegisterDomain> SizedDomain for DataDomain<T> { impl<T: RegisterDomain> SizedDomain for DataDomain<T> {
// Return the bitsize of `self`. // Return the bitsize of `self`.
fn bytesize(&self) -> ByteSize { fn bytesize(&self) -> ByteSize {
...@@ -244,6 +286,22 @@ impl<T: RegisterDomain + TryToInterval> TryToInterval for DataDomain<T> { ...@@ -244,6 +286,22 @@ impl<T: RegisterDomain + TryToInterval> TryToInterval for DataDomain<T> {
} }
} }
impl<T: RegisterDomain> std::ops::Add for DataDomain<T> {
type Output = DataDomain<T>;
fn add(self, rhs: Self) -> Self {
self.bin_op(BinOpType::IntAdd, &rhs)
}
}
impl<T: RegisterDomain> std::ops::Sub for DataDomain<T> {
type Output = DataDomain<T>;
fn sub(self, rhs: Self) -> Self {
self.bin_op(BinOpType::IntSub, &rhs)
}
}
impl<T: RegisterDomain + Display> DataDomain<T> { impl<T: RegisterDomain + Display> DataDomain<T> {
/// Get a more compact json-representation of the data domain. /// Get a more compact json-representation of the data domain.
/// Intended for pretty printing, not useable for serialization/deserialization. /// Intended for pretty printing, not useable for serialization/deserialization.
......
...@@ -3,7 +3,7 @@ use std::fmt::Display; ...@@ -3,7 +3,7 @@ use std::fmt::Display;
use crate::intermediate_representation::*; use crate::intermediate_representation::*;
use crate::prelude::*; use crate::prelude::*;
use super::{AbstractDomain, HasTop, RegisterDomain, SizedDomain}; use super::{AbstractDomain, HasTop, RegisterDomain, SizedDomain, SpecializeByConditional};
use super::{TryToBitvec, TryToInterval}; use super::{TryToBitvec, TryToInterval};
mod simple_interval; mod simple_interval;
...@@ -217,6 +217,113 @@ impl IntervalDomain { ...@@ -217,6 +217,113 @@ impl IntervalDomain {
.map(|bitvec| bitvec.into_sign_extend(width).unwrap()), .map(|bitvec| bitvec.into_sign_extend(width).unwrap()),
} }
} }
/// Compute the intersection of two intervals.
/// Return an error if the intersection is empty.
pub fn intersect(&self, other: &Self) -> Result<Self, Error> {
let mut intersected_domain: IntervalDomain =
self.interval.signed_intersect(&other.interval)?.into();
intersected_domain.update_widening_lower_bound(&self.widening_lower_bound);
intersected_domain.update_widening_lower_bound(&other.widening_lower_bound);
intersected_domain.update_widening_upper_bound(&self.widening_upper_bound);
intersected_domain.update_widening_upper_bound(&other.widening_upper_bound);
Ok(intersected_domain)
}
}
impl SpecializeByConditional for IntervalDomain {
fn add_signed_less_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if let Some(old_upper_bound) = &self.widening_upper_bound {
if old_upper_bound.checked_sle(bound).unwrap() {
return Ok(self);
} else if self.interval.end.checked_slt(bound).unwrap() {
self.widening_upper_bound = Some(bound.clone());
return Ok(self);
} else {
self.widening_upper_bound = None;
}
} else if self.interval.end.checked_slt(bound).unwrap() {
self.widening_upper_bound = Some(bound.clone());
return Ok(self);
}
// we already know that the bound is less equal to `self.interval.end`
if self.interval.start.checked_sle(bound).unwrap() {
self.interval.end = bound.clone();
Ok(self)
} else {
Err(anyhow!("Empty interval"))
}
}
fn add_signed_greater_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if let Some(old_lower_bound) = &self.widening_lower_bound {
if old_lower_bound.checked_sge(bound).unwrap() {
return Ok(self);
} else if self.interval.start.checked_sgt(bound).unwrap() {
self.widening_lower_bound = Some(bound.clone());
return Ok(self);
} else {
self.widening_lower_bound = None;
}
} else if self.interval.start.checked_sgt(bound).unwrap() {
self.widening_lower_bound = Some(bound.clone());
return Ok(self);
}
// we already know that the bound is greater equal to `self.interval.start`
if self.interval.end.checked_sge(bound).unwrap() {
self.interval.start = bound.clone();
Ok(self)
} else {
Err(anyhow!("Empty interval"))
}
}
fn add_unsigned_less_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if bound.sign_bit().to_bool() {
if self.interval.end.sign_bit().to_bool() {
self.add_signed_less_equal_bound(bound)
} else if self.interval.start.sign_bit().to_bool() {
Ok(self)
} else {
self.add_signed_greater_equal_bound(&Bitvector::zero(bound.width()))
}
} else {
self = self.add_signed_greater_equal_bound(&Bitvector::zero(bound.width()))?;
self.add_signed_less_equal_bound(bound)
}
}
fn add_unsigned_greater_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if bound.sign_bit().to_bool() {
self = self.add_signed_less_equal_bound(&(-Bitvector::one(bound.width())))?;
self.add_signed_greater_equal_bound(bound)
} else if self.interval.end.checked_slt(bound).unwrap() {
self.add_signed_less_equal_bound(&(-Bitvector::one(bound.width())))
} else if self.interval.start.sign_bit().to_bool() {
Ok(self)
} else {
self.add_signed_greater_equal_bound(bound)
}
}
fn add_not_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if self.interval.start == *bound && self.interval.end == *bound {
return Err(anyhow!("Empty interval"));
}
if self.interval.start.checked_sgt(bound).unwrap() {
self.add_signed_greater_equal_bound(&(bound + &Bitvector::one(bound.width())))
} else if self.interval.start == *bound {
self.interval.start += &Bitvector::one(bound.width());
Ok(self)
} else if self.interval.end.checked_slt(bound).unwrap() {
self.add_signed_less_equal_bound(&(bound - &Bitvector::one(bound.width())))
} else if self.interval.end == *bound {
self.interval.end -= &Bitvector::one(bound.width());
Ok(self)
} else {
Ok(self)
}
}
} }
impl AbstractDomain for IntervalDomain { impl AbstractDomain for IntervalDomain {
......
...@@ -4,74 +4,66 @@ impl IntervalDomain { ...@@ -4,74 +4,66 @@ impl IntervalDomain {
/// Compute the interval of possible results /// Compute the interval of possible results
/// if one adds a value from `self` to a value from `rhs`. /// if one adds a value from `self` to a value from `rhs`.
pub fn add(&self, rhs: &Self) -> Self { pub fn add(&self, rhs: &Self) -> Self {
let interval = self.interval.add(&rhs.interval); let mut interval: IntervalDomain = self.interval.add(&rhs.interval).into();
if interval.is_top() { if interval.is_top() {
interval.into() interval
} else { } else {
let new_lower_bound = if let (Some(self_bound), Some(rhs_bound)) = interval.update_widening_lower_bound(
(&self.widening_lower_bound, &rhs.widening_lower_bound) &self
{ .widening_lower_bound
if self_bound.signed_add_overflow_check(rhs_bound) { .as_ref()
None .and_then(|bound| bound.signed_add_overflow_checked(&rhs.interval.start)),
} else { );
Some(self_bound.clone().into_checked_add(rhs_bound).unwrap()) interval.update_widening_lower_bound(
} &rhs.widening_lower_bound
} else { .as_ref()
None .and_then(|bound| bound.signed_add_overflow_checked(&self.interval.start)),
}; );
let new_upper_bound = if let (Some(self_bound), Some(rhs_bound)) = interval.update_widening_upper_bound(
(&self.widening_upper_bound, &rhs.widening_upper_bound) &self
{ .widening_upper_bound
if self_bound.signed_add_overflow_check(rhs_bound) { .as_ref()
None .and_then(|bound| bound.signed_add_overflow_checked(&rhs.interval.end)),
} else { );
Some(self_bound.clone().into_checked_add(rhs_bound).unwrap()) interval.update_widening_upper_bound(
} &rhs.widening_upper_bound
} else { .as_ref()
None .and_then(|bound| bound.signed_add_overflow_checked(&self.interval.end)),
}; );
IntervalDomain { interval
interval,
widening_upper_bound: new_upper_bound,
widening_lower_bound: new_lower_bound,
}
} }
} }
/// Compute the interval of possible results /// Compute the interval of possible results
/// if one subtracts a value in `rhs` from a value in `self`. /// if one subtracts a value in `rhs` from a value in `self`.
pub fn sub(&self, rhs: &Self) -> Self { pub fn sub(&self, rhs: &Self) -> Self {
let interval = self.interval.sub(&rhs.interval); let mut interval: IntervalDomain = self.interval.sub(&rhs.interval).into();
if interval.is_top() { if interval.is_top() {
interval.into() interval
} else { } else {
let new_lower_bound = if let (Some(self_bound), Some(rhs_bound)) = interval.update_widening_lower_bound(
(&self.widening_lower_bound, &rhs.widening_upper_bound) &self
{ .widening_lower_bound
if self_bound.signed_sub_overflow_check(rhs_bound) { .as_ref()
None .and_then(|bound| bound.signed_sub_overflow_checked(&rhs.interval.end)),
} else { );
Some(self_bound.clone().into_checked_sub(rhs_bound).unwrap()) interval.update_widening_lower_bound(
} &rhs.widening_upper_bound
} else { .as_ref()
None .and_then(|bound| self.interval.start.signed_sub_overflow_checked(bound)),
}; );
let new_upper_bound = if let (Some(self_bound), Some(rhs_bound)) = interval.update_widening_upper_bound(
(&self.widening_upper_bound, &rhs.widening_lower_bound) &self
{ .widening_upper_bound
if self_bound.signed_sub_overflow_check(rhs_bound) { .as_ref()
None .and_then(|bound| bound.signed_sub_overflow_checked(&rhs.interval.start)),
} else { );
Some(self_bound.clone().into_checked_sub(rhs_bound).unwrap()) interval.update_widening_upper_bound(
} &rhs.widening_lower_bound
} else { .as_ref()
None .and_then(|bound| self.interval.end.signed_sub_overflow_checked(bound)),
}; );
IntervalDomain { interval
interval,
widening_upper_bound: new_upper_bound,
widening_lower_bound: new_lower_bound,
}
} }
} }
......
...@@ -55,6 +55,18 @@ impl Interval { ...@@ -55,6 +55,18 @@ impl Interval {
Interval { start, end } Interval { start, end }
} }
/// Compute the intersection of two intervals as intervals of signed integers.
/// Return an error if the intersection is empty.
pub fn signed_intersect(&self, other: &Interval) -> Result<Interval, Error> {
let start = signed_max(&self.start, &other.start);
let end = signed_min(&self.end, &other.end);
if start.checked_sle(&end).unwrap() {
Ok(Interval { start, end })
} else {
Err(anyhow!("Empty interval"))
}
}
/// Return the number of contained values of the interval as an unsigned bitvector. /// Return the number of contained values of the interval as an unsigned bitvector.
/// If the interval is unconstrained, return zero /// If the interval is unconstrained, return zero
/// (since the maximal number of elements is not representable in a bitvector of the same byte size). /// (since the maximal number of elements is not representable in a bitvector of the same byte size).
...@@ -142,30 +154,26 @@ impl Interval { ...@@ -142,30 +154,26 @@ impl Interval {
/// Compute the interval of possible results /// Compute the interval of possible results
/// if one adds a value from `self` to a value from `rhs`. /// if one adds a value from `self` to a value from `rhs`.
pub fn add(&self, rhs: &Interval) -> Interval { pub fn add(&self, rhs: &Interval) -> Interval {
if self.start.signed_add_overflow_check(&rhs.start) if let (Some(start), Some(end)) = (
|| self.end.signed_add_overflow_check(&rhs.end) self.start.signed_add_overflow_checked(&rhs.start),
{ self.end.signed_add_overflow_checked(&rhs.end),
Interval::new_top(self.bytesize()) ) {
Interval { start, end }
} else { } else {
Interval { Interval::new_top(self.bytesize())
start: self.start.clone().into_checked_add(&rhs.start).unwrap(),
end: self.end.clone().into_checked_add(&rhs.end).unwrap(),
}
} }
} }
/// Compute the interval of possible results /// Compute the interval of possible results
/// if one subtracts a value in `rhs` from a value in `self`. /// if one subtracts a value in `rhs` from a value in `self`.
pub fn sub(&self, rhs: &Interval) -> Interval { pub fn sub(&self, rhs: &Interval) -> Interval {
if self.start.signed_sub_overflow_check(&rhs.end) if let (Some(start), Some(end)) = (
|| self.end.signed_sub_overflow_check(&rhs.start) self.start.signed_sub_overflow_checked(&rhs.end),
{ self.end.signed_sub_overflow_checked(&rhs.start),
Interval::new_top(self.bytesize()) ) {
Interval { start, end }
} else { } else {
Interval { Interval::new_top(self.bytesize())
start: self.start.clone().into_checked_sub(&rhs.end).unwrap(),
end: self.end.clone().into_checked_sub(&rhs.start).unwrap(),
}
} }
} }
...@@ -193,6 +201,12 @@ impl Interval { ...@@ -193,6 +201,12 @@ impl Interval {
end: max, end: max,
} }
} }
/// Return `true` if `bitvec` is contained in the interval.
/// Panics if the interval and `bitvec` have different bytesizes.
pub fn contains(&self, bitvec: &Bitvector) -> bool {
self.start.checked_sle(bitvec).unwrap() && self.end.checked_sge(bitvec).unwrap()
}
} }
impl From<Bitvector> for Interval { impl From<Bitvector> for Interval {
......
...@@ -228,18 +228,18 @@ fn add() { ...@@ -228,18 +228,18 @@ fn add() {
let result = lhs.bin_op(BinOpType::IntAdd, &rhs); let result = lhs.bin_op(BinOpType::IntAdd, &rhs);
assert_eq!( assert_eq!(
result, result,
IntervalDomain::mock_with_bounds(None, 0, 7, Some(20)) IntervalDomain::mock_with_bounds(Some(-17), 0, 7, Some(10))
); );
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-121), -120, -120, Some(10)); let lhs = IntervalDomain::mock_i8_with_bounds(Some(-121), -120, -120, Some(10));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-10), -9, 0, Some(10)); let rhs = IntervalDomain::mock_i8_with_bounds(Some(-10), -9, 0, Some(10));
let result = lhs.bin_op(BinOpType::IntAdd, &rhs); let result = lhs.bin_op(BinOpType::IntAdd, &rhs);
assert_eq!(result, IntervalDomain::new_top(ByteSize::new(1))); assert_eq!(result, IntervalDomain::new_top(ByteSize::new(1)));
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-100), 2, 4, Some(100)); let lhs = IntervalDomain::mock_i8_with_bounds(Some(-100), -30, 40, Some(100));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-50), 10, 20, Some(50)); let rhs = IntervalDomain::mock_i8_with_bounds(Some(-100), -30, 20, Some(50));
let result = lhs.bin_op(BinOpType::IntAdd, &rhs); let result = lhs.bin_op(BinOpType::IntAdd, &rhs);
assert_eq!( assert_eq!(
result, result,
IntervalDomain::mock_i8_with_bounds(None, 12, 24, None) IntervalDomain::mock_i8_with_bounds(None, -60, 60, Some(90))
); );
} }
...@@ -250,18 +250,18 @@ fn sub() { ...@@ -250,18 +250,18 @@ fn sub() {
let result = lhs.bin_op(BinOpType::IntSub, &rhs); let result = lhs.bin_op(BinOpType::IntSub, &rhs);
assert_eq!( assert_eq!(
result, result,
IntervalDomain::mock_with_bounds(None, 3, 10, Some(30)) IntervalDomain::mock_with_bounds(Some(-7), 3, 10, Some(13))
); );
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-121), -120, -120, Some(10)); let lhs = IntervalDomain::mock_i8_with_bounds(Some(-121), -120, -120, Some(10));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-10), -9, 9, Some(10)); let rhs = IntervalDomain::mock_i8_with_bounds(Some(-10), -9, 9, Some(10));
let result = lhs.bin_op(BinOpType::IntSub, &rhs); let result = lhs.bin_op(BinOpType::IntSub, &rhs);
assert_eq!(result, IntervalDomain::new_top(ByteSize::new(1))); assert_eq!(result, IntervalDomain::new_top(ByteSize::new(1)));
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-100), 2, 4, Some(100)); let lhs = IntervalDomain::mock_i8_with_bounds(Some(-100), 2, 40, Some(100));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-50), 10, 20, Some(50)); let rhs = IntervalDomain::mock_i8_with_bounds(Some(-50), -30, 3, Some(100));
let result = lhs.bin_op(BinOpType::IntSub, &rhs); let result = lhs.bin_op(BinOpType::IntSub, &rhs);
assert_eq!( assert_eq!(
result, result,
IntervalDomain::mock_i8_with_bounds(None, -18, -6, None) IntervalDomain::mock_i8_with_bounds(Some(-98), -1, 70, Some(90))
); );
} }
...@@ -315,3 +315,191 @@ fn shift_left() { ...@@ -315,3 +315,191 @@ fn shift_left() {
let result = lhs.bin_op(BinOpType::IntLeft, &rhs); let result = lhs.bin_op(BinOpType::IntLeft, &rhs);
assert_eq!(result, IntervalDomain::mock(0, 0)); assert_eq!(result, IntervalDomain::mock(0, 0));
} }
#[test]
fn simple_interval_contains() {
let domain = IntervalDomain::mock(-10, 5);
assert!(!domain.interval.contains(&Bitvector::from_i64(-11)));
assert!(domain.interval.contains(&Bitvector::from_i64(-10)));
assert!(domain.interval.contains(&Bitvector::from_i64(-4)));
assert!(domain.interval.contains(&Bitvector::from_i64(5)));
assert!(!domain.interval.contains(&Bitvector::from_i64(6)));
}
#[test]
fn add_signed_bounds() {
let interval = IntervalDomain::mock_with_bounds(Some(-100), -10, 10, Some(100));
// signed_less_equal
let x = interval
.clone()
.add_signed_less_equal_bound(&Bitvector::from_i64(20));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-100), -10, 10, Some(20))
);
let x = interval
.clone()
.add_signed_less_equal_bound(&Bitvector::from_i64(-5));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-100), -10, -5, None)
);
let x = interval
.clone()
.add_signed_less_equal_bound(&Bitvector::from_i64(-20));
assert!(x.is_err());
//signed_greater_equal
let x = interval
.clone()
.add_signed_greater_equal_bound(&Bitvector::from_i64(20));
assert!(x.is_err());
let x = interval
.clone()
.add_signed_greater_equal_bound(&Bitvector::from_i64(-5));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, -5, 10, Some(100))
);
let x = interval
.clone()
.add_signed_greater_equal_bound(&Bitvector::from_i64(-20));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-20), -10, 10, Some(100))
);
}
#[test]
fn add_unsigned_bounds() {
let positive_interval = IntervalDomain::mock_with_bounds(Some(10), 20, 30, Some(40));
let wrapped_interval = IntervalDomain::mock_with_bounds(Some(-100), -10, 10, Some(100));
let negative_interval = IntervalDomain::mock_with_bounds(Some(-40), -30, -20, Some(-10));
// unsigned_less_equal
let x = positive_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(35));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(10), 20, 30, Some(35))
);
let x = positive_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(15));
assert!(x.is_err());
let x = wrapped_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(35));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, 0, 10, Some(35))
);
let x = wrapped_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(-5));
assert_eq!(x.unwrap(), wrapped_interval); // Cannot remove a subinterval from the domain
let x = negative_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(-25));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-40), -30, -25, None)
);
let x = negative_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(-35));
assert!(x.is_err());
// unsigned_greater_equal
let x = positive_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(25));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, 25, 30, Some(40))
);
let x = positive_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(35));
assert!(x.is_err());
let x = wrapped_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(5));
assert_eq!(x.unwrap(), wrapped_interval);
let x = wrapped_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(35));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-100), -10, -1, None)
);
let x = wrapped_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(-50));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-50), -10, -1, None)
);
let x = negative_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(25));
assert_eq!(x.unwrap(), negative_interval);
let x = negative_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(-25));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, -25, -20, Some(-10))
);
}
#[test]
fn add_not_equal_bounds() {
let interval = IntervalDomain::mock_with_bounds(None, -10, 10, None);
let x = interval
.clone()
.add_not_equal_bound(&Bitvector::from_i64(-20));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-19), -10, 10, None)
);
let x = interval
.clone()
.add_not_equal_bound(&Bitvector::from_i64(-0));
assert_eq!(x.unwrap(), interval);
let x = interval
.clone()
.add_not_equal_bound(&Bitvector::from_i64(20));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, -10, 10, Some(19))
);
let interval = IntervalDomain::mock(5, 5);
let x = interval
.clone()
.add_not_equal_bound(&Bitvector::from_i64(5));
assert!(x.is_err());
let interval = IntervalDomain::mock(5, 6);
let x = interval.add_not_equal_bound(&Bitvector::from_i64(5));
assert_eq!(x.unwrap(), IntervalDomain::mock(6, 6));
}
#[test]
fn intersection() {
let interval1 = IntervalDomain::mock_with_bounds(Some(-100), -10, 10, Some(100));
let interval2 = IntervalDomain::mock_with_bounds(Some(-20), 2, 30, None);
let intersection = interval1.intersect(&interval2).unwrap();
assert_eq!(
intersection,
IntervalDomain::mock_with_bounds(Some(-20), 2, 10, Some(100))
);
assert!(interval1.intersect(&IntervalDomain::mock(50, 55)).is_err());
}
...@@ -123,3 +123,34 @@ pub trait TryToInterval { ...@@ -123,3 +123,34 @@ pub trait TryToInterval {
Ok((interval.start.try_to_i64()?, interval.end.try_to_i64()?)) Ok((interval.start.try_to_i64()?, interval.end.try_to_i64()?))
} }
} }
/// A trait for domains whose values can be restricted by knowing the result of a comparison of it with a known bitvector.
/// The comparison may also be used to add widening hints to the domain.
///
/// Note that the value set represented by the domain after the restriction may be an upper bound,
/// i.e. it is possible that the result still contains values not satisfying the restricting comparison.
pub trait SpecializeByConditional: Sized {
/// Return the restriction of `self` to values satisfying `self <= bound`
/// with `self` and `bound` interpreted as signed integers.
/// Returns an error if no value represented by `self` can satisfy the comparison.
fn add_signed_less_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
/// Return the restriction of `self` to values satisfying `self <= bound`
/// with `self` and `bound` interpreted as unsigned integers.
/// Returns an error if no value represented by `self` can satisfy the comparison.
fn add_unsigned_less_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
/// Return the restriction of `self` to values satisfying `self >= bound`
/// with `self` and `bound` interpreted as signed integers.
/// Returns an error if no value represented by `self` can satisfy the comparison.
fn add_signed_greater_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
/// Return the restriction of `self` to values satisfying `self >= bound`
/// with `self` and `bound` interpreted as unsigned integers.
/// Returns an error if no value represented by `self` can satisfy the comparison.
fn add_unsigned_greater_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
/// Return the restriction of `self` to values satisfying `self != bound`
/// Returns an error if `self` only represents one value for which `self == bound` holds.
fn add_not_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
}
...@@ -165,13 +165,13 @@ impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> { ...@@ -165,13 +165,13 @@ impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> {
// The CallStub Edge value is added to the CallSourceCombinator // The CallStub Edge value is added to the CallSourceCombinator
// The user has the ability to split the node value at the BlkStart return to node // The user has the ability to split the node value at the BlkStart return to node
// to only send specific data along the CallStub Edge to the callsite // to only send specific data along the CallStub Edge to the callsite
Edge::CRCallStub => Some(NodeValue::CallFlowCombinator { Edge::CrCallStub => Some(NodeValue::CallFlowCombinator {
call_stub: self.context.split_call_stub(node_value.unwrap_value()), call_stub: self.context.split_call_stub(node_value.unwrap_value()),
interprocedural_flow: None, interprocedural_flow: None,
}), }),
// The user has the ability to split the node value at the BlkStart return node // The user has the ability to split the node value at the BlkStart return node
// to only send specific data along the ReturnStub Edge to the last BlkEnd node called subroutine // to only send specific data along the ReturnStub Edge to the last BlkEnd node called subroutine
Edge::CRReturnStub => { Edge::CrReturnStub => {
// The subroutine term from which the program returns // The subroutine term from which the program returns
let returned_from_sub = match graph.node_weight(end_node) { let returned_from_sub = match graph.node_weight(end_node) {
Some(Node::BlkEnd { 0: _, 1: sub_term }) => sub_term, Some(Node::BlkEnd { 0: _, 1: sub_term }) => sub_term,
......
...@@ -79,6 +79,7 @@ pub trait Context<'a> { ...@@ -79,6 +79,7 @@ pub trait Context<'a> {
&self, &self,
value: &Self::Value, value: &Self::Value,
condition: &Expression, condition: &Expression,
block_before_condition: &Term<Blk>,
is_true: bool, is_true: bool,
) -> Option<Self::Value>; ) -> Option<Self::Value>;
} }
...@@ -164,11 +165,11 @@ impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> { ...@@ -164,11 +165,11 @@ impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> {
.context .context
.update_call(node_value.unwrap_value(), call, &graph[end_node]) .update_call(node_value.unwrap_value(), call, &graph[end_node])
.map(NodeValue::Value), .map(NodeValue::Value),
Edge::CRCallStub => Some(NodeValue::CallFlowCombinator { Edge::CrCallStub => Some(NodeValue::CallFlowCombinator {
call_stub: Some(node_value.unwrap_value().clone()), call_stub: Some(node_value.unwrap_value().clone()),
interprocedural_flow: None, interprocedural_flow: None,
}), }),
Edge::CRReturnStub => Some(NodeValue::CallFlowCombinator { Edge::CrReturnStub => Some(NodeValue::CallFlowCombinator {
call_stub: None, call_stub: None,
interprocedural_flow: Some(node_value.unwrap_value().clone()), interprocedural_flow: Some(node_value.unwrap_value().clone()),
}), }),
...@@ -201,15 +202,51 @@ impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> { ...@@ -201,15 +202,51 @@ impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> {
.context .context
.update_call_stub(node_value.unwrap_value(), call) .update_call_stub(node_value.unwrap_value(), call)
.map(NodeValue::Value), .map(NodeValue::Value),
Edge::Jump(jump, untaken_conditional) => self Edge::Jump(jump, untaken_conditional) => {
.context let value_after_condition = if let Jmp::CBranch {
.update_jump( target: _,
node_value.unwrap_value(), condition,
jump, } = &jump.term
*untaken_conditional, {
graph[end_node].get_block(), let block = graph[start_node].get_block();
) self.context.specialize_conditional(
.map(NodeValue::Value), node_value.unwrap_value(),
condition,
block,
true,
)
} else if let Some(untaken_conditional_jump) = untaken_conditional {
if let Jmp::CBranch {
target: _,
condition,
} = &untaken_conditional_jump.term
{
let block = graph[start_node].get_block();
self.context.specialize_conditional(
node_value.unwrap_value(),
condition,
block,
false,
)
} else {
panic!("Malformed control flow graph");
}
} else {
Some(node_value.unwrap_value().clone())
};
if let Some(value) = value_after_condition {
self.context
.update_jump(
&value,
jump,
*untaken_conditional,
graph[end_node].get_block(),
)
.map(NodeValue::Value)
} else {
None
}
}
} }
} }
} }
......
...@@ -150,9 +150,9 @@ pub enum Edge<'a> { ...@@ -150,9 +150,9 @@ pub enum Edge<'a> {
/// The edge goes directly from the callsite to the return-to-site inside the caller. /// The edge goes directly from the callsite to the return-to-site inside the caller.
ExternCallStub(&'a Term<Jmp>), ExternCallStub(&'a Term<Jmp>),
/// An artificial edge. See the module-level documentation for more information. /// An artificial edge. See the module-level documentation for more information.
CRCallStub, CrCallStub,
/// An artificial edge. See the module-level documentation for more information. /// An artificial edge. See the module-level documentation for more information.
CRReturnStub, CrReturnStub,
/// An artificial edge to combine intra- and interprocedural data flows at the callsite of calls. /// An artificial edge to combine intra- and interprocedural data flows at the callsite of calls.
/// See the module-level documentation for more information. /// See the module-level documentation for more information.
CallCombine(&'a Term<Jmp>), CallCombine(&'a Term<Jmp>),
...@@ -425,9 +425,9 @@ impl<'a> GraphBuilder<'a> { ...@@ -425,9 +425,9 @@ impl<'a> GraphBuilder<'a> {
return_: (return_from_block, return_from_sub), return_: (return_from_block, return_from_sub),
}); });
self.graph self.graph
.add_edge(*call_node, return_combine_node, Edge::CRCallStub); .add_edge(*call_node, return_combine_node, Edge::CrCallStub);
self.graph self.graph
.add_edge(return_source, return_combine_node, Edge::CRReturnStub); .add_edge(return_source, return_combine_node, Edge::CrReturnStub);
self.graph.add_edge( self.graph.add_edge(
return_combine_node, return_combine_node,
*return_to_node, *return_to_node,
......
...@@ -90,7 +90,7 @@ impl<'a> Context<'a> { ...@@ -90,7 +90,7 @@ impl<'a> Context<'a> {
_ => Bitvector::zero(apint::BitWidth::from(self.project.get_pointer_bytesize())), _ => Bitvector::zero(apint::BitWidth::from(self.project.get_pointer_bytesize())),
}; };
match state_before_return.get_register(&self.project.stack_pointer_register) { match state_before_return.get_register(&self.project.stack_pointer_register) {
Ok(Data::Pointer(pointer)) => { Data::Pointer(pointer) => {
if pointer.targets().len() == 1 { if pointer.targets().len() == 1 {
let (id, offset) = pointer.targets().iter().next().unwrap(); let (id, offset) = pointer.targets().iter().next().unwrap();
if *id != state_before_return.stack_id if *id != state_before_return.stack_id
...@@ -106,11 +106,10 @@ impl<'a> Context<'a> { ...@@ -106,11 +106,10 @@ impl<'a> Context<'a> {
)) ))
} }
} }
Ok(Data::Top(_)) => Err(anyhow!( Data::Top(_) => Err(anyhow!(
"Stack register value lost during function execution" "Stack register value lost during function execution"
)), )),
Ok(Data::Value(_)) => Err(anyhow!("Unexpected stack register value on return")), Data::Value(_) => Err(anyhow!("Unexpected stack register value on return")),
Err(err) => Err(err),
} }
} }
...@@ -258,11 +257,7 @@ impl<'a> Context<'a> { ...@@ -258,11 +257,7 @@ impl<'a> Context<'a> {
/// Check whether the jump is an indirect call whose target evaluates to a *Top* value in the given state. /// Check whether the jump is an indirect call whose target evaluates to a *Top* value in the given state.
fn is_indirect_call_with_top_target(&self, state: &State, call: &Term<Jmp>) -> bool { fn is_indirect_call_with_top_target(&self, state: &State, call: &Term<Jmp>) -> bool {
match &call.term { match &call.term {
Jmp::CallInd { target, .. } Jmp::CallInd { target, .. } => state.eval(target).is_top(),
if state.eval(target).map_or(false, |value| value.is_top()) =>
{
true
}
_ => false, _ => false,
} }
} }
...@@ -280,7 +275,7 @@ impl<'a> Context<'a> { ...@@ -280,7 +275,7 @@ impl<'a> Context<'a> {
new_state: &mut State, new_state: &mut State,
) { ) {
let stack_register = &self.project.stack_pointer_register; let stack_register = &self.project.stack_pointer_register;
let stack_pointer = state_before_call.get_register(stack_register).unwrap(); let stack_pointer = state_before_call.get_register(stack_register);
match self.project.cpu_architecture.as_str() { match self.project.cpu_architecture.as_str() {
"x86" | "x86_64" => { "x86" | "x86_64" => {
let offset = Bitvector::from_u64(stack_register.size.into()) let offset = Bitvector::from_u64(stack_register.size.into())
...@@ -382,7 +377,7 @@ impl<'a> Context<'a> { ...@@ -382,7 +377,7 @@ impl<'a> Context<'a> {
/// Get the offset of the current stack pointer to the base of the current stack frame. /// Get the offset of the current stack pointer to the base of the current stack frame.
fn get_current_stack_offset(&self, state: &State) -> ValueDomain { fn get_current_stack_offset(&self, state: &State) -> ValueDomain {
if let Ok(Data::Pointer(ref stack_pointer)) = if let Data::Pointer(ref stack_pointer) =
state.get_register(&self.project.stack_pointer_register) state.get_register(&self.project.stack_pointer_register)
{ {
if stack_pointer.targets().len() == 1 { if stack_pointer.targets().len() == 1 {
......
...@@ -130,7 +130,7 @@ fn context_problem_implementation() { ...@@ -130,7 +130,7 @@ fn context_problem_implementation() {
// test update_def // test update_def
state = context.update_def(&state, &def).unwrap(); state = context.update_def(&state, &def).unwrap();
let stack_pointer = Data::Pointer(PointerDomain::new(new_id("main", "RSP"), bv(-16))); let stack_pointer = Data::Pointer(PointerDomain::new(new_id("main", "RSP"), bv(-16)));
assert_eq!(state.eval(&Var(register("RSP"))).unwrap(), stack_pointer); assert_eq!(state.eval(&Var(register("RSP"))), stack_pointer);
state = context.update_def(&state, &store_term).unwrap(); state = context.update_def(&state, &store_term).unwrap();
// Test update_call // Test update_call
...@@ -194,10 +194,9 @@ fn context_problem_implementation() { ...@@ -194,10 +194,9 @@ fn context_problem_implementation() {
assert_eq!(return_state.caller_stack_ids, BTreeSet::new()); assert_eq!(return_state.caller_stack_ids, BTreeSet::new());
assert_eq!(return_state.memory, state.memory); assert_eq!(return_state.memory, state.memory);
assert_eq!( assert_eq!(
return_state.get_register(&register("RSP")).unwrap(), return_state.get_register(&register("RSP")),
state state
.get_register(&register("RSP")) .get_register(&register("RSP"))
.unwrap()
.bin_op(BinOpType::IntAdd, &Bitvector::from_i64(8).into()) .bin_op(BinOpType::IntAdd, &Bitvector::from_i64(8).into())
); );
...@@ -207,7 +206,7 @@ fn context_problem_implementation() { ...@@ -207,7 +206,7 @@ fn context_problem_implementation() {
let malloc = call_term("extern_malloc"); let malloc = call_term("extern_malloc");
let mut state_after_malloc = context.update_call_stub(&state, &malloc).unwrap(); let mut state_after_malloc = context.update_call_stub(&state, &malloc).unwrap();
assert_eq!( assert_eq!(
state_after_malloc.get_register(&register("RDX")).unwrap(), state_after_malloc.get_register(&register("RDX")),
Data::Pointer(PointerDomain::new( Data::Pointer(PointerDomain::new(
new_id("call_extern_malloc", "RDX"), new_id("call_extern_malloc", "RDX"),
bv(0) bv(0)
...@@ -215,21 +214,17 @@ fn context_problem_implementation() { ...@@ -215,21 +214,17 @@ fn context_problem_implementation() {
); );
assert_eq!(state_after_malloc.memory.get_num_objects(), 2); assert_eq!(state_after_malloc.memory.get_num_objects(), 2);
assert_eq!( assert_eq!(
state_after_malloc.get_register(&register("RSP")).unwrap(), state_after_malloc.get_register(&register("RSP")),
state state
.get_register(&register("RSP")) .get_register(&register("RSP"))
.unwrap()
.bin_op(BinOpType::IntAdd, &Data::Value(bv(8))) .bin_op(BinOpType::IntAdd, &Data::Value(bv(8)))
); );
assert_eq!( assert_eq!(
state_after_malloc state_after_malloc.get_register(&register("callee_saved_reg")),
.get_register(&register("callee_saved_reg"))
.unwrap(),
Data::Value(bv(13)) Data::Value(bv(13))
); );
assert!(state_after_malloc assert!(state_after_malloc
.get_register(&register("other_reg")) .get_register(&register("other_reg"))
.unwrap()
.is_top()); .is_top());
state_after_malloc.set_register( state_after_malloc.set_register(
...@@ -243,15 +238,10 @@ fn context_problem_implementation() { ...@@ -243,15 +238,10 @@ fn context_problem_implementation() {
let state_after_free = context let state_after_free = context
.update_call_stub(&state_after_malloc, &free) .update_call_stub(&state_after_malloc, &free)
.unwrap(); .unwrap();
assert!(state_after_free assert!(state_after_free.get_register(&register("RDX")).is_top());
.get_register(&register("RDX"))
.unwrap()
.is_top());
assert_eq!(state_after_free.memory.get_num_objects(), 2); assert_eq!(state_after_free.memory.get_num_objects(), 2);
assert_eq!( assert_eq!(
state_after_free state_after_free.get_register(&register("callee_saved_reg")),
.get_register(&register("callee_saved_reg"))
.unwrap(),
Data::Pointer(PointerDomain::new( Data::Pointer(PointerDomain::new(
new_id("call_extern_malloc", "RDX"), new_id("call_extern_malloc", "RDX"),
bv(0) bv(0)
...@@ -262,21 +252,17 @@ fn context_problem_implementation() { ...@@ -262,21 +252,17 @@ fn context_problem_implementation() {
let state_after_other_fn = context.update_call_stub(&state, &other_extern_fn).unwrap(); let state_after_other_fn = context.update_call_stub(&state, &other_extern_fn).unwrap();
assert_eq!( assert_eq!(
state_after_other_fn.get_register(&register("RSP")).unwrap(), state_after_other_fn.get_register(&register("RSP")),
state state
.get_register(&register("RSP")) .get_register(&register("RSP"))
.unwrap()
.bin_op(BinOpType::IntAdd, &Data::Value(bv(8))) .bin_op(BinOpType::IntAdd, &Data::Value(bv(8)))
); );
assert_eq!( assert_eq!(
state_after_other_fn state_after_other_fn.get_register(&register("callee_saved_reg")),
.get_register(&register("callee_saved_reg"))
.unwrap(),
Data::Value(bv(13)) Data::Value(bv(13))
); );
assert!(state_after_other_fn assert!(state_after_other_fn
.get_register(&register("other_reg")) .get_register(&register("other_reg"))
.unwrap()
.is_top()); .is_top());
} }
...@@ -380,10 +366,119 @@ fn update_return() { ...@@ -380,10 +366,119 @@ fn update_return() {
.get_all_object_ids() .get_all_object_ids()
.get(&new_id("caller_caller", "RSP")) .get(&new_id("caller_caller", "RSP"))
.is_some()); .is_some());
assert!(state.get_register(&register("RSP")).is_ok());
let expected_rsp = Data::Pointer(PointerDomain::new( let expected_rsp = Data::Pointer(PointerDomain::new(
new_id("original_caller_id", "RSP"), new_id("original_caller_id", "RSP"),
bv(-8), bv(-8),
)); ));
assert_eq!(state.get_register(&register("RSP")).unwrap(), expected_rsp); assert_eq!(state.get_register(&register("RSP")), expected_rsp);
}
#[test]
fn specialize_conditional() {
use crate::analysis::forward_interprocedural_fixpoint::Context as IpFpContext;
let (project, config) = mock_project();
let graph = crate::analysis::graph::get_program_cfg(&project.program, HashSet::new());
let runtime_memory_image = RuntimeMemoryImage::mock();
let (log_sender, _log_receiver) = crossbeam_channel::unbounded();
let context = Context::new(&project, &runtime_memory_image, &graph, config, log_sender);
let mut state = State::new(&register("RSP"), Tid::new("func"));
state.set_register(&register("RAX"), IntervalDomain::mock(-10, 20).into());
let condition = Expression::Var(Variable::mock("FLAG", 1));
// A complicated way of computing the result of `RAX <= 0`
// and assigning the result to the `FLAG` register.
let defs = vec![
Def::assign("def1", register("RAX"), Expression::Var(register("RAX"))),
Def::assign(
"def_that_should_be_ignored",
Variable::mock("FLAG", 1),
Expression::Const(Bitvector::from_u8(42)),
),
Def::assign(
"def2",
Variable::mock("FLAG_SLESS", 1),
Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntSLess,
rhs: Box::new(Expression::Const(Bitvector::from_u64(0))),
},
),
Def::assign(
"def3",
Variable::mock("FLAG_EQUAL", 1),
Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntEqual,
rhs: Box::new(Expression::Const(Bitvector::from_u64(0))),
},
),
Def::assign(
"def4",
Variable::mock("FLAG_NOTEQUAL", 1),
Expression::BinOp {
lhs: Box::new(Expression::Var(Variable::mock("FLAG_SLESS", 1))),
op: BinOpType::IntNotEqual,
rhs: Box::new(Expression::Const(Bitvector::from_u8(0))),
},
),
Def::assign(
"def5",
Variable::mock("FLAG", 1),
Expression::BinOp {
lhs: Box::new(Expression::Var(Variable::mock("FLAG_EQUAL", 1))),
op: BinOpType::BoolOr,
rhs: Box::new(Expression::Var(Variable::mock("FLAG_NOTEQUAL", 1))),
},
),
];
let block = Term {
tid: Tid::new("block"),
term: Blk {
defs,
jmps: Vec::new(),
indirect_jmp_targets: Vec::new(),
},
};
let result = context
.specialize_conditional(&state, &condition, &block, false)
.unwrap();
assert_eq!(
result.get_register(&Variable::mock("FLAG", 1)),
Bitvector::from_u8(0).into()
);
assert_eq!(
result.get_register(&Variable::mock("FLAG_NOTEQUAL", 1)),
Bitvector::from_u8(0).into()
);
assert_eq!(
result.get_register(&Variable::mock("FLAG_EQUAL", 1)),
Bitvector::from_u8(0).into()
);
assert_eq!(
result.get_register(&Variable::mock("FLAG_SLESS", 1)),
Bitvector::from_u8(0).into()
);
// The result is technically false, since RAX == 0 should be excluded.
// This impreciseness is due to the way that the result is calculated.
assert_eq!(
result.get_register(&register("RAX")),
IntervalDomain::mock(0, 20).into()
);
state.set_register(&register("RAX"), IntervalDomain::mock(0, 20).into());
let result = context
.specialize_conditional(&state, &condition, &block, false)
.unwrap();
assert_eq!(
result.get_register(&register("RAX")),
IntervalDomain::mock_with_bounds(Some(0), 1, 20, None).into()
);
state.set_register(&register("RAX"), IntervalDomain::mock(-20, 0).into());
let result = context.specialize_conditional(&state, &condition, &block, false);
assert!(result.is_none());
} }
use super::*; use super::*;
use std::collections::HashSet;
impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Context<'a> { impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Context<'a> {
type Value = State; type Value = State;
...@@ -43,7 +44,7 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont ...@@ -43,7 +44,7 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
} }
Def::Assign { var, value } => { Def::Assign { var, value } => {
let mut new_state = state.clone(); let mut new_state = state.clone();
self.log_debug(new_state.handle_register_assign(var, value), Some(&def.tid)); new_state.handle_register_assign(var, value);
Some(new_state) Some(new_state)
} }
Def::Load { var, address } => { Def::Load { var, address } => {
...@@ -291,13 +292,55 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont ...@@ -291,13 +292,55 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
} }
/// Update the state with the knowledge that some conditional evaluated to true or false. /// Update the state with the knowledge that some conditional evaluated to true or false.
/// Currently not implemented, this function just returns the state as it is.
fn specialize_conditional( fn specialize_conditional(
&self, &self,
value: &State, value: &State,
_condition: &Expression, condition: &Expression,
_is_true: bool, block_before_condition: &Term<Blk>,
is_true: bool,
) -> Option<State> { ) -> Option<State> {
Some(value.clone()) let mut specialized_state = value.clone();
if specialized_state
.specialize_by_expression_result(condition, Bitvector::from_u8(is_true as u8).into())
.is_err()
{
// State is unsatisfiable
return None;
}
let mut modified_vars: HashSet<Variable> = HashSet::new();
for def in block_before_condition.term.defs.iter().rev() {
match &def.term {
Def::Store { .. } => (),
Def::Load { var, .. } => {
modified_vars.insert(var.clone());
}
Def::Assign {
var,
value: input_expr,
} => {
if !modified_vars.contains(var) {
// Register is not modified again between the `Def` and the end of the block.
modified_vars.insert(var.clone());
if input_expr
.input_vars()
.into_iter()
.find(|input_var| modified_vars.contains(input_var))
.is_none()
{
// Values of input registers did not change between the `Def` and the end of the block.
let expr_result = specialized_state.get_register(var);
if specialized_state
.specialize_by_expression_result(input_expr, expr_result.clone())
.is_err()
{
// State is unsatisfiable
return None;
}
}
}
}
}
}
Some(specialized_state)
} }
} }
...@@ -338,33 +338,21 @@ impl<'a> PointerInference<'a> { ...@@ -338,33 +338,21 @@ impl<'a> PointerInference<'a> {
for jmp in block.term.jmps.iter() { for jmp in block.term.jmps.iter() {
match &jmp.term { match &jmp.term {
Jmp::BranchInd(target_expr) => { Jmp::BranchInd(target_expr) => {
if let Ok(address) = state.eval(&target_expr) { let address = state.eval(&target_expr);
println!( println!(
"{}: Indirect jump to {}", "{}: Indirect jump to {}",
jmp.tid, jmp.tid,
address.to_json_compact() address.to_json_compact()
); );
} else {
println!(
"{}: Indirect jump. Could not compute address",
jmp.tid
);
}
} }
Jmp::CallInd { target, return_ } => { Jmp::CallInd { target, return_ } => {
if let Ok(address) = state.eval(&target) { let address = state.eval(&target);
println!( println!(
"{}: Indirect call to {}. HasReturn: {}", "{}: Indirect call to {}. HasReturn: {}",
jmp.tid, jmp.tid,
address.to_json_compact(), address.to_json_compact(),
return_.is_some() return_.is_some()
); );
} else {
println!(
"{}: Indirect call. Could not compute address",
jmp.tid
);
}
} }
Jmp::Return(_) => { Jmp::Return(_) => {
if !state.caller_stack_ids.is_empty() { if !state.caller_stack_ids.is_empty() {
......
...@@ -258,19 +258,20 @@ impl AbstractObjectInfo { ...@@ -258,19 +258,20 @@ impl AbstractObjectInfo {
/// Get a more compact json-representation of the abstract object. /// Get a more compact json-representation of the abstract object.
/// Intended for pretty printing, not useable for serialization/deserialization. /// Intended for pretty printing, not useable for serialization/deserialization.
pub fn to_json_compact(&self) -> serde_json::Value { pub fn to_json_compact(&self) -> serde_json::Value {
let mut elements = Vec::new(); let mut elements = vec![
elements.push(( (
"is_unique".to_string(), "is_unique".to_string(),
serde_json::Value::String(format!("{}", self.is_unique)), serde_json::Value::String(format!("{}", self.is_unique)),
)); ),
elements.push(( (
"state".to_string(), "state".to_string(),
serde_json::Value::String(format!("{:?}", self.state)), serde_json::Value::String(format!("{:?}", self.state)),
)); ),
elements.push(( (
"type".to_string(), "type".to_string(),
serde_json::Value::String(format!("{:?}", self.type_)), serde_json::Value::String(format!("{:?}", self.type_)),
)); ),
];
let memory = self let memory = self
.memory .memory
.iter() .iter()
......
...@@ -4,13 +4,11 @@ use super::*; ...@@ -4,13 +4,11 @@ use super::*;
impl State { impl State {
/// Get the value of a register or Top() if no value is known. /// Get the value of a register or Top() if no value is known.
/// pub fn get_register(&self, variable: &Variable) -> Data {
/// Returns an error if the variable is not a register.
pub fn get_register(&self, variable: &Variable) -> Result<Data, Error> {
if let Some(data) = self.register.get(variable) { if let Some(data) = self.register.get(variable) {
Ok(data.clone()) data.clone()
} else { } else {
Ok(Data::new_top(variable.size)) Data::new_top(variable.size)
} }
} }
...@@ -28,8 +26,6 @@ impl State { ...@@ -28,8 +26,6 @@ impl State {
} }
/// Set the value of a register. /// Set the value of a register.
///
/// Returns an error if the variable is not a register.
pub fn set_register(&mut self, variable: &Variable, value: Data) { pub fn set_register(&mut self, variable: &Variable, value: Data) {
if !value.is_top() { if !value.is_top() {
self.register.insert(variable.clone(), value); self.register.insert(variable.clone(), value);
...@@ -39,21 +35,8 @@ impl State { ...@@ -39,21 +35,8 @@ impl State {
} }
/// Evaluate expression on the given state and write the result to the target register. /// Evaluate expression on the given state and write the result to the target register.
pub fn handle_register_assign( pub fn handle_register_assign(&mut self, target: &Variable, expression: &Expression) {
&mut self, self.set_register(target, self.eval(expression))
target: &Variable,
expression: &Expression,
) -> Result<(), Error> {
match self.eval(expression) {
Ok(new_value) => {
self.set_register(target, new_value);
Ok(())
}
Err(err) => {
self.set_register(target, Data::new_top(target.size));
Err(err)
}
}
} }
/// Store `value` at the given `address`. /// Store `value` at the given `address`.
...@@ -116,32 +99,19 @@ impl State { ...@@ -116,32 +99,19 @@ impl State {
value: &Data, value: &Data,
global_memory: &RuntimeMemoryImage, global_memory: &RuntimeMemoryImage,
) -> Result<(), Error> { ) -> Result<(), Error> {
match self.eval(address) { let address_data = self.eval(address);
Ok(address_data) => self.store_value(&address_data, value, global_memory), self.store_value(&address_data, value, global_memory)
Err(err) => Err(err),
}
} }
/// Evaluate the store instruction, given by its address and value expressions, /// Evaluate the store instruction, given by its address and value expressions,
/// and modify the state accordingly. /// and modify the state accordingly.
///
/// If an error occurs, the state is still modified before the error is returned.
/// E.g. if the value expression cannot be evaluated,
/// the value at the target address is overwritten with a `Top` value.
pub fn handle_store( pub fn handle_store(
&mut self, &mut self,
address: &Expression, address: &Expression,
value: &Expression, value: &Expression,
global_memory: &RuntimeMemoryImage, global_memory: &RuntimeMemoryImage,
) -> Result<(), Error> { ) -> Result<(), Error> {
match self.eval(value) { self.write_to_address(address, &self.eval(value), global_memory)
Ok(data) => self.write_to_address(address, &data, global_memory),
Err(err) => {
// we still need to write to the target location before reporting the error
self.write_to_address(address, &Data::new_top(value.bytesize()), global_memory)?;
Err(err)
}
}
} }
/// Evaluate the given load instruction and return the data read on success. /// Evaluate the given load instruction and return the data read on success.
...@@ -151,7 +121,7 @@ impl State { ...@@ -151,7 +121,7 @@ impl State {
size: ByteSize, size: ByteSize,
global_memory: &RuntimeMemoryImage, global_memory: &RuntimeMemoryImage,
) -> Result<Data, Error> { ) -> Result<Data, Error> {
let address = self.adjust_pointer_for_read(&self.eval(address)?); let address = self.adjust_pointer_for_read(&self.eval(address));
match address { match address {
Data::Value(global_address) => { Data::Value(global_address) => {
if let Ok(address_bitvector) = global_address.try_to_bitvec() { if let Ok(address_bitvector) = global_address.try_to_bitvec() {
...@@ -233,30 +203,30 @@ impl State { ...@@ -233,30 +203,30 @@ impl State {
} }
/// Evaluate the value of an expression in the current state /// Evaluate the value of an expression in the current state
pub fn eval(&self, expression: &Expression) -> Result<Data, Error> { pub fn eval(&self, expression: &Expression) -> Data {
use Expression::*; use Expression::*;
match expression { match expression {
Var(variable) => self.get_register(&variable), Var(variable) => self.get_register(&variable),
Const(bitvector) => Ok(bitvector.clone().into()), Const(bitvector) => bitvector.clone().into(),
BinOp { op, lhs, rhs } => { BinOp { op, lhs, rhs } => {
if *op == BinOpType::IntXOr && lhs == rhs { if *op == BinOpType::IntXOr && lhs == rhs {
// the result of `x XOR x` is always zero. // the result of `x XOR x` is always zero.
return Ok(Bitvector::zero(apint::BitWidth::from(lhs.bytesize())).into()); return Bitvector::zero(apint::BitWidth::from(lhs.bytesize())).into();
} }
let (left, right) = (self.eval(lhs)?, self.eval(rhs)?); let (left, right) = (self.eval(lhs), self.eval(rhs));
Ok(left.bin_op(*op, &right)) left.bin_op(*op, &right)
} }
UnOp { op, arg } => Ok(self.eval(arg)?.un_op(*op)), UnOp { op, arg } => self.eval(arg).un_op(*op),
Cast { op, size, arg } => Ok(self.eval(arg)?.cast(*op, *size)), Cast { op, size, arg } => self.eval(arg).cast(*op, *size),
Unknown { Unknown {
description: _, description: _,
size, size,
} => Ok(Data::new_top(*size)), } => Data::new_top(*size),
Subpiece { Subpiece {
low_byte, low_byte,
size, size,
arg, arg,
} => Ok(self.eval(arg)?.subpiece(*low_byte, *size)), } => self.eval(arg).subpiece(*low_byte, *size),
} }
} }
...@@ -268,7 +238,7 @@ impl State { ...@@ -268,7 +238,7 @@ impl State {
global_memory: &RuntimeMemoryImage, global_memory: &RuntimeMemoryImage,
) -> Result<Data, Error> { ) -> Result<Data, Error> {
match parameter { match parameter {
Arg::Register(var) => self.eval(&Expression::Var(var.clone())), Arg::Register(var) => Ok(self.eval(&Expression::Var(var.clone()))),
Arg::Stack { offset, size } => self.load_value( Arg::Stack { offset, size } => self.load_value(
&Expression::Var(stack_pointer.clone()).plus_const(*offset), &Expression::Var(stack_pointer.clone()).plus_const(*offset),
*size, *size,
...@@ -281,11 +251,7 @@ impl State { ...@@ -281,11 +251,7 @@ impl State {
pub fn contains_access_of_dangling_memory(&self, def: &Def) -> bool { pub fn contains_access_of_dangling_memory(&self, def: &Def) -> bool {
match def { match def {
Def::Load { address, .. } | Def::Store { address, .. } => { Def::Load { address, .. } | Def::Store { address, .. } => {
if let Ok(pointer) = self.eval(address) { self.memory.is_dangling_pointer(&self.eval(address), true)
self.memory.is_dangling_pointer(&pointer, true)
} else {
false
}
} }
_ => false, _ => false,
} }
......
...@@ -325,6 +325,299 @@ impl State { ...@@ -325,6 +325,299 @@ impl State {
self.register.remove(&register); self.register.remove(&register);
} }
} }
/// Try to restrict the input variables of `expression` on `self`
/// so that `expression` only evaluates to values represented by the given `result`.
///
/// If `expression` cannot evaluate to any value represented by `self`, return an error.
///
/// This function may restrict to upper bounds of possible values
/// if the restriction cannot be made exact,
/// i.e. after calling this function the state may still contain values
/// for which `expression` does not evaluate to values represented by `result`.
pub fn specialize_by_expression_result(
&mut self,
expression: &Expression,
result: Data,
) -> Result<(), Error> {
if let Expression::Var(var) = expression {
match (self.eval(expression), result) {
(Data::Value(old_value), Data::Value(result_value)) => {
self.set_register(var, old_value.intersect(&result_value)?.into())
}
(Data::Top(_), result) => self.set_register(var, result),
_ => (),
}
Ok(())
} else if let Expression::BinOp { op, lhs, rhs } = expression {
self.specialize_by_binop_expression_result(op, lhs, rhs, result)
} else if let Ok(result_bitvec) = result.try_to_bitvec() {
match expression {
Expression::Var(_) => panic!(),
Expression::Const(input_bitvec) => {
if *input_bitvec == result_bitvec {
Ok(())
} else {
Err(anyhow!("Unsatisfiable state"))
}
}
Expression::BinOp { .. } => {
panic!() // Already handled above
}
Expression::UnOp { op, arg } => {
use UnOpType::*;
match op {
IntNegate | BoolNegate | Int2Comp => {
let intermediate_result = result.un_op(*op);
self.specialize_by_expression_result(arg, intermediate_result)
}
_ => Ok(()),
}
}
Expression::Cast { op, size: _, arg } => match op {
CastOpType::IntZExt | CastOpType::IntSExt => {
let intermediate_result = result.subpiece(ByteSize::new(0), arg.bytesize());
self.specialize_by_expression_result(arg, intermediate_result)
}
_ => Ok(()),
},
Expression::Unknown {
description: _,
size: _,
} => Ok(()),
Expression::Subpiece { .. } => Ok(()),
}
} else {
Ok(())
}
}
/// Try to restrict the input variables of the given binary operation
/// so that it only evaluates to the given `result_bitvec`.
fn specialize_by_binop_expression_result(
&mut self,
op: &BinOpType,
lhs: &Expression,
rhs: &Expression,
result: Data,
) -> Result<(), Error> {
match op {
BinOpType::IntAdd => {
if let Ok(bitvec) = self.eval(lhs).try_to_bitvec() {
let intermediate_result = result - bitvec.into();
return self.specialize_by_expression_result(rhs, intermediate_result);
} else if let Ok(bitvec) = self.eval(rhs).try_to_bitvec() {
let intermediate_result = result - bitvec.into();
return self.specialize_by_expression_result(lhs, intermediate_result);
} else {
return Ok(());
}
}
BinOpType::IntSub => {
if let Ok(bitvec) = self.eval(lhs).try_to_bitvec() {
let intermediate_result: Data = Data::from(bitvec) - result;
return self.specialize_by_expression_result(rhs, intermediate_result);
} else if let Ok(bitvec) = self.eval(rhs).try_to_bitvec() {
let intermediate_result = result + bitvec.into();
return self.specialize_by_expression_result(lhs, intermediate_result);
} else {
return Ok(());
}
}
_ => (),
}
if let Ok(result_bitvec) = result.try_to_bitvec() {
match op {
BinOpType::IntXOr | BinOpType::BoolXOr => {
if let Ok(bitvec) = self.eval(lhs).try_to_bitvec() {
self.specialize_by_expression_result(rhs, (result_bitvec ^ &bitvec).into())
} else if let Ok(bitvec) = self.eval(rhs).try_to_bitvec() {
self.specialize_by_expression_result(lhs, (result_bitvec ^ &bitvec).into())
} else {
Ok(())
}
}
BinOpType::IntOr | BinOpType::BoolOr => {
if result_bitvec.is_zero() {
self.specialize_by_expression_result(lhs, result_bitvec.clone().into())?;
self.specialize_by_expression_result(rhs, result_bitvec.into())
} else if self
.eval(lhs)
.try_to_bitvec()
.map_or(false, |bitvec| bitvec.is_zero())
{
self.specialize_by_expression_result(rhs, result_bitvec.into())
} else if self
.eval(rhs)
.try_to_bitvec()
.map_or(false, |bitvec| bitvec.is_zero())
{
self.specialize_by_expression_result(lhs, result_bitvec.into())
} else {
Ok(())
}
}
BinOpType::BoolAnd => {
if !result_bitvec.is_zero() {
self.specialize_by_expression_result(lhs, result_bitvec.clone().into())?;
self.specialize_by_expression_result(rhs, result_bitvec.into())
} else if self
.eval(lhs)
.try_to_bitvec()
.map_or(false, |bitvec| !bitvec.is_zero())
{
self.specialize_by_expression_result(rhs, result_bitvec.into())
} else if self
.eval(rhs)
.try_to_bitvec()
.map_or(false, |bitvec| !bitvec.is_zero())
{
self.specialize_by_expression_result(lhs, result_bitvec.into())
} else {
Ok(())
}
}
BinOpType::IntEqual | BinOpType::IntNotEqual => {
match (op, !result_bitvec.is_zero()) {
(BinOpType::IntEqual, true) | (BinOpType::IntNotEqual, false) => {
// lhs == rhs
if let Ok(bitvec) = self.eval(lhs).try_to_bitvec() {
self.specialize_by_expression_result(rhs, bitvec.into())
} else if let Ok(bitvec) = self.eval(rhs).try_to_bitvec() {
self.specialize_by_expression_result(lhs, bitvec.into())
} else {
Ok(())
}
}
(BinOpType::IntEqual, false) | (BinOpType::IntNotEqual, true) => {
// lhs != rhs
if let Ok(bitvec) = self.eval(lhs).try_to_bitvec() {
let new_result = self.eval(rhs).add_not_equal_bound(&bitvec)?;
self.specialize_by_expression_result(rhs, new_result)
} else if let Ok(bitvec) = self.eval(rhs).try_to_bitvec() {
let new_result = self.eval(lhs).add_not_equal_bound(&bitvec)?;
self.specialize_by_expression_result(lhs, new_result)
} else {
Ok(())
}
}
_ => panic!(),
}
}
BinOpType::IntSLess
| BinOpType::IntLess
| BinOpType::IntLessEqual
| BinOpType::IntSLessEqual => {
use BinOpType::*;
let mut op = *op;
let (mut left_expr, mut right_expr) = (lhs, rhs);
if result_bitvec.is_zero() {
std::mem::swap(&mut left_expr, &mut right_expr);
op = match op {
IntSLess => IntSLessEqual,
IntSLessEqual => IntSLess,
IntLess => IntLessEqual,
IntLessEqual => IntLess,
_ => panic!(),
}
}
self.specialize_by_comparison_op(&op, left_expr, right_expr)
}
_ => {
let original_expression = Expression::BinOp {
lhs: Box::new(lhs.clone()),
op: *op,
rhs: Box::new(rhs.clone()),
};
if let Ok(interval) = self.eval(&original_expression).try_to_interval() {
if !interval.contains(&result_bitvec) {
Err(anyhow!("Unsatisfiable bound"))
} else {
Ok(())
}
} else {
Ok(())
}
}
}
} else {
Ok(())
}
}
/// Try to restrict the input variables of the given comparison operation
/// (signed and unsigned versions of `<` and `<=`)
/// so that the comparison evaluates to `true`.
fn specialize_by_comparison_op(
&mut self,
op: &BinOpType,
lhs: &Expression,
rhs: &Expression,
) -> Result<(), Error> {
use BinOpType::*;
if let Ok(mut lhs_bound) = self.eval(lhs).try_to_bitvec() {
match op {
IntSLess => {
if lhs_bound == Bitvector::signed_max_value(lhs_bound.width()) {
return Err(anyhow!("Unsatisfiable bound"));
}
lhs_bound += &Bitvector::one(lhs_bound.width());
let new_result = self.eval(rhs).add_signed_greater_equal_bound(&lhs_bound)?;
return self.specialize_by_expression_result(rhs, new_result);
}
IntSLessEqual => {
let new_result = self.eval(rhs).add_signed_greater_equal_bound(&lhs_bound)?;
return self.specialize_by_expression_result(rhs, new_result);
}
IntLess => {
if lhs_bound == Bitvector::unsigned_max_value(lhs_bound.width()) {
return Err(anyhow!("Unsatisfiable bound"));
}
lhs_bound += &Bitvector::one(lhs_bound.width());
let new_result = self
.eval(rhs)
.add_unsigned_greater_equal_bound(&lhs_bound)?;
return self.specialize_by_expression_result(rhs, new_result);
}
IntLessEqual => {
let new_result = self
.eval(rhs)
.add_unsigned_greater_equal_bound(&lhs_bound)?;
return self.specialize_by_expression_result(rhs, new_result);
}
_ => panic!(),
}
} else if let Ok(mut rhs_bound) = self.eval(rhs).try_to_bitvec() {
match op {
IntSLess => {
if rhs_bound == Bitvector::signed_min_value(rhs_bound.width()) {
return Err(anyhow!("Unsatisfiable bound"));
}
rhs_bound -= &Bitvector::one(rhs_bound.width());
let new_result = self.eval(lhs).add_signed_less_equal_bound(&rhs_bound)?;
return self.specialize_by_expression_result(lhs, new_result);
}
IntSLessEqual => {
let new_result = self.eval(lhs).add_signed_less_equal_bound(&rhs_bound)?;
return self.specialize_by_expression_result(lhs, new_result);
}
IntLess => {
if rhs_bound == Bitvector::zero(rhs_bound.width()) {
return Err(anyhow!("Unsatisfiable bound"));
}
rhs_bound -= &Bitvector::one(rhs_bound.width());
let new_result = self.eval(lhs).add_unsigned_less_equal_bound(&rhs_bound)?;
return self.specialize_by_expression_result(lhs, new_result);
}
IntLessEqual => {
let new_result = self.eval(lhs).add_unsigned_less_equal_bound(&rhs_bound)?;
return self.specialize_by_expression_result(lhs, new_result);
}
_ => panic!(),
}
}
Ok(())
}
} }
impl AbstractDomain for State { impl AbstractDomain for State {
......
...@@ -158,22 +158,18 @@ fn handle_store() { ...@@ -158,22 +158,18 @@ fn handle_store() {
let mut state = State::new(&register("RSP"), Tid::new("time0")); let mut state = State::new(&register("RSP"), Tid::new("time0"));
let stack_id = new_id("time0", "RSP"); let stack_id = new_id("time0", "RSP");
assert_eq!( assert_eq!(
state.eval(&Var(register("RSP"))).unwrap(), state.eval(&Var(register("RSP"))),
Data::Pointer(PointerDomain::new(stack_id.clone(), bv(0))) Data::Pointer(PointerDomain::new(stack_id.clone(), bv(0)))
); );
state state.handle_register_assign(&register("RSP"), &reg_sub("RSP", 32));
.handle_register_assign(&register("RSP"), &reg_sub("RSP", 32))
.unwrap();
assert_eq!( assert_eq!(
state.eval(&Var(register("RSP"))).unwrap(), state.eval(&Var(register("RSP"))),
Data::Pointer(PointerDomain::new(stack_id.clone(), bv(-32))) Data::Pointer(PointerDomain::new(stack_id.clone(), bv(-32)))
); );
state state.handle_register_assign(&register("RSP"), &reg_add("RSP", -8));
.handle_register_assign(&register("RSP"), &reg_add("RSP", -8))
.unwrap();
assert_eq!( assert_eq!(
state.eval(&Var(register("RSP"))).unwrap(), state.eval(&Var(register("RSP"))),
Data::Pointer(PointerDomain::new(stack_id.clone(), bv(-40))) Data::Pointer(PointerDomain::new(stack_id.clone(), bv(-40)))
); );
...@@ -198,9 +194,7 @@ fn handle_store() { ...@@ -198,9 +194,7 @@ fn handle_store() {
&global_memory, &global_memory,
) )
.unwrap(); .unwrap();
state state.handle_register_assign(&register("RSP"), &reg_sub("RSP", 4));
.handle_register_assign(&register("RSP"), &reg_sub("RSP", 4))
.unwrap();
assert_eq!( assert_eq!(
state state
...@@ -364,22 +358,16 @@ fn remove_and_restore_callee_saved_register() { ...@@ -364,22 +358,16 @@ fn remove_and_restore_callee_saved_register() {
let mut callee_state = state.clone(); let mut callee_state = state.clone();
callee_state.remove_callee_saved_register(&cconv); callee_state.remove_callee_saved_register(&cconv);
assert_eq!( assert_eq!(
callee_state.get_register(&register("RBP")).unwrap(), callee_state.get_register(&register("RBP")),
Data::new_top(ByteSize::new(8)) Data::new_top(ByteSize::new(8))
); );
assert_eq!( assert_eq!(callee_state.get_register(&register("RAX")), value.clone());
callee_state.get_register(&register("RAX")).unwrap(),
value.clone()
);
let other_value: Data = Bitvector::from_u64(13).into(); let other_value: Data = Bitvector::from_u64(13).into();
callee_state.set_register(&register("RAX"), other_value.clone()); callee_state.set_register(&register("RAX"), other_value.clone());
callee_state.restore_callee_saved_register(&state, &cconv, &register("RSP")); callee_state.restore_callee_saved_register(&state, &cconv, &register("RSP"));
assert_eq!(callee_state.get_register(&register("RBP")).unwrap(), value); assert_eq!(callee_state.get_register(&register("RBP")), value);
assert_eq!( assert_eq!(callee_state.get_register(&register("RAX")), other_value);
callee_state.get_register(&register("RAX")).unwrap(),
other_value
);
} }
#[test] #[test]
...@@ -484,3 +472,495 @@ fn global_mem_access() { ...@@ -484,3 +472,495 @@ fn global_mem_access() {
) )
.is_err()); .is_err());
} }
/// Test expression specialization except for binary operations.
#[test]
fn specialize_by_expression_results() {
let mut base_state = State::new(&register("RSP"), Tid::new("func_tid"));
base_state.set_register(
&register("RAX"),
IntervalDomain::new(Bitvector::from_i64(5), Bitvector::from_i64(10)).into(),
);
// Expr = Var(RAX)
let mut state = base_state.clone();
let x = state
.specialize_by_expression_result(&Expression::var("RAX"), Bitvector::from_i64(7).into());
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
Bitvector::from_i64(7).into()
);
let mut state = base_state.clone();
let x = state
.specialize_by_expression_result(&Expression::var("RAX"), Bitvector::from_i64(-20).into());
assert!(x.is_err());
// Expr = Const
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::Const(Bitvector::from_i64(-20)),
Bitvector::from_i64(-20).into(),
);
assert!(x.is_ok());
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::Const(Bitvector::from_i64(5)),
Bitvector::from_i64(-20).into(),
);
assert!(x.is_err());
// Expr = -Var(RAX)
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::var("RAX").un_op(UnOpType::Int2Comp),
Bitvector::from_i64(-7).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
Bitvector::from_i64(7).into()
);
// Expr = IntSExt(Var(EAX))
let mut state = State::new(&register("RSP"), Tid::new("func_tid"));
let eax_register = Variable {
name: "EAX".to_string(),
size: ByteSize::new(4),
is_temp: false,
};
state.set_register(
&eax_register,
IntervalDomain::new(Bitvector::from_i32(-10), Bitvector::from_i32(-5)).into(),
);
let x = state.specialize_by_expression_result(
&Expression::Var(eax_register.clone()).cast(CastOpType::IntSExt),
Bitvector::from_i64(-7).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&eax_register),
Bitvector::from_i32(-7).into()
);
}
/// Test expression specialization for binary operations
/// except equality and inequality operations
#[test]
fn specialize_by_binop() {
let base_state = State::new(&register("RSP"), Tid::new("func_tid"));
// Expr = RAX + Const
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::var("RAX").plus_const(20),
IntervalDomain::mock(5, 7).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock(-15, -13).into()
);
// Expr = RAX - Const
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::var("RAX").minus_const(20),
Bitvector::from_i64(5).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
Bitvector::from_i64(25).into()
);
// Expr = RAX xor Const
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::var("RAX")),
op: BinOpType::IntXOr,
rhs: Box::new(Expression::const_from_i64(3)),
},
Bitvector::from_i64(-1).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
Bitvector::from_i64(-4).into()
);
// Expr = (RAX or RBX == 0)
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::var("RAX")),
op: BinOpType::IntOr,
rhs: Box::new(Expression::var("RBX")),
},
Bitvector::from_i64(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
Bitvector::from_i64(0).into()
);
assert_eq!(
state.get_register(&register("RBX")),
Bitvector::from_i64(0).into()
);
// Expr = (RAX or 0 == Const)
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::var("RAX")),
op: BinOpType::IntOr,
rhs: Box::new(Expression::const_from_i64(0)),
},
Bitvector::from_i64(42).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
Bitvector::from_i64(42).into()
);
// Expr = (FLAG1 bool_and FLAG2 == 1)
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Var(Variable::mock("FLAG1", 1u64))),
op: BinOpType::BoolAnd,
rhs: Box::new(Expression::Var(Variable::mock("FLAG2", 1u64))),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&Variable::mock("FLAG1", 1u64)),
Bitvector::from_u8(1).into()
);
assert_eq!(
state.get_register(&Variable::mock("FLAG2", 1u64)),
Bitvector::from_u8(1).into()
);
// Expr = (FLAG bool_and 1 = Const)
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Const(Bitvector::from_u8(1))),
op: BinOpType::BoolAnd,
rhs: Box::new(Expression::Var(Variable::mock("FLAG", 1u64))),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&Variable::mock("FLAG", 1u64)),
Bitvector::from_u8(0).into()
);
}
/// Test expression specialization for comparison operations `==` and `!=`.
#[test]
fn specialize_by_equality_comparison() {
let mut base_state = State::new(&register("RSP"), Tid::new("func_tid"));
base_state.set_register(&register("RAX"), IntervalDomain::mock(0, 50).into());
// Expr = RAX == Const
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(23)),
op: BinOpType::IntEqual,
rhs: Box::new(Expression::var("RAX")),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
Bitvector::from_i64(23).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(23)),
op: BinOpType::IntNotEqual,
rhs: Box::new(Expression::var("RAX")),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
Bitvector::from_i64(23).into()
);
// Expr = RAX != Const
let mut state = base_state.clone();
state.set_register(&register("RAX"), Bitvector::from_i64(23).into());
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(23)),
op: BinOpType::IntNotEqual,
rhs: Box::new(Expression::var("RAX")),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_err());
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(100)),
op: BinOpType::IntEqual,
rhs: Box::new(Expression::var("RAX")),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock_with_bounds(None, 0, 50, Some(99)).into()
);
}
/// Test expression specialization for signed comparison operations `<` and `<=`.
#[test]
fn specialize_by_signed_comparison_op() {
let mut base_state = State::new(&register("RSP"), Tid::new("func_tid"));
let interval = IntervalDomain::mock(5, 10);
base_state.set_register(&register("RAX"), interval.into());
// Expr = RAX < Const (signed)
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(7)),
op: BinOpType::IntSLess,
rhs: Box::new(Expression::Var(register("RAX"))),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock(8, 10).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(15)),
op: BinOpType::IntSLess,
rhs: Box::new(Expression::Var(register("RAX"))),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock_with_bounds(None, 5, 10, Some(15)).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntSLess,
rhs: Box::new(Expression::Const(Bitvector::signed_min_value(
ByteSize::new(8).into(),
))),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_err());
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntSLess,
rhs: Box::new(Expression::const_from_i64(7)),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock(7, 10).into()
);
// Expr = RAX <= Const (signed)
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(7)),
op: BinOpType::IntSLessEqual,
rhs: Box::new(Expression::Var(register("RAX"))),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock(7, 10).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(15)),
op: BinOpType::IntSLessEqual,
rhs: Box::new(Expression::Var(register("RAX"))),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock_with_bounds(None, 5, 10, Some(14)).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntSLessEqual,
rhs: Box::new(Expression::Const(Bitvector::signed_min_value(
ByteSize::new(8).into(),
))),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_err());
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntSLessEqual,
rhs: Box::new(Expression::const_from_i64(7)),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock(8, 10).into()
);
}
/// Test expression specialization for unsigned comparison operations `<` and `<=`.
#[test]
fn specialize_by_unsigned_comparison_op() {
let mut base_state = State::new(&register("RSP"), Tid::new("func_tid"));
let interval = IntervalDomain::mock(-5, 10);
base_state.set_register(&register("RAX"), interval.into());
// Expr = RAX < Const (unsigned)
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(7)),
op: BinOpType::IntLess,
rhs: Box::new(Expression::Var(register("RAX"))),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock(-5, 10).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(15)),
op: BinOpType::IntLess,
rhs: Box::new(Expression::Var(register("RAX"))),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock_with_bounds(None, 0, 10, Some(15)).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntLess,
rhs: Box::new(Expression::const_from_i64(0)),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_err());
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntLess,
rhs: Box::new(Expression::const_from_i64(-20)),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock_with_bounds(Some(-20), -5, -1, None).into()
);
// Expr = RAX <= Const (unsigned)
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(7)),
op: BinOpType::IntLessEqual,
rhs: Box::new(Expression::Var(register("RAX"))),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock(-5, 10).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::const_from_i64(15)),
op: BinOpType::IntLessEqual,
rhs: Box::new(Expression::Var(register("RAX"))),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock_with_bounds(None, 0, 10, Some(14)).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntLessEqual,
rhs: Box::new(Expression::const_from_i64(0)),
},
Bitvector::from_u8(1).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
Bitvector::from_i64(0).into()
);
let mut state = base_state.clone();
let x = state.specialize_by_expression_result(
&Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntLessEqual,
rhs: Box::new(Expression::const_from_i64(-20)),
},
Bitvector::from_u8(0).into(),
);
assert!(x.is_ok());
assert_eq!(
state.get_register(&register("RAX")),
IntervalDomain::mock_with_bounds(Some(-19), -5, -1, None).into()
);
}
...@@ -197,23 +197,21 @@ impl<'a> Context<'a> { ...@@ -197,23 +197,21 @@ impl<'a> Context<'a> {
for parameter in extern_symbol.parameters.iter() { for parameter in extern_symbol.parameters.iter() {
match parameter { match parameter {
Arg::Register(var) => { Arg::Register(var) => {
if let Ok(data) = pi_state.eval(&Expression::Var(var.clone())) { let data = pi_state.eval(&Expression::Var(var.clone()));
if state.check_if_address_points_to_taint(data, pi_state) { if state.check_if_address_points_to_taint(data, pi_state) {
return true; return true;
}
} }
} }
Arg::Stack { offset, size } => { Arg::Stack { offset, size } => {
if let Ok(stack_address) = pi_state.eval( let stack_address = pi_state.eval(
&Expression::Var(self.project.stack_pointer_register.clone()) &Expression::Var(self.project.stack_pointer_register.clone())
.plus_const(*offset), .plus_const(*offset),
) { );
if state if state
.load_taint_from_memory(&stack_address, *size) .load_taint_from_memory(&stack_address, *size)
.is_tainted() .is_tainted()
{ {
return true; return true;
}
} }
if let Ok(stack_param) = pi_state.eval_parameter_arg( if let Ok(stack_param) = pi_state.eval_parameter_arg(
parameter, parameter,
...@@ -266,6 +264,7 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont ...@@ -266,6 +264,7 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
&self, &self,
state: &State, state: &State,
_condition: &Expression, _condition: &Expression,
_block_before_condition: &Term<Blk>,
_is_true: bool, _is_true: bool,
) -> Option<State> { ) -> Option<State> {
Some(state.clone()) Some(state.clone())
...@@ -332,10 +331,9 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont ...@@ -332,10 +331,9 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
} else if let Some(pi_state) = } else if let Some(pi_state) =
self.get_current_pointer_inference_state(state, &def.tid) self.get_current_pointer_inference_state(state, &def.tid)
{ {
if let Ok(address_data) = pi_state.eval(address) { let address_data = pi_state.eval(address);
let taint = state.load_taint_from_memory(&address_data, var.size); let taint = state.load_taint_from_memory(&address_data, var.size);
new_state.set_register_taint(var, taint); new_state.set_register_taint(var, taint);
}
} else { } else {
new_state.set_register_taint(var, Taint::Top(var.size)); new_state.set_register_taint(var, Taint::Top(var.size));
} }
...@@ -347,10 +345,9 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont ...@@ -347,10 +345,9 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
} else if let Some(pi_state) = } else if let Some(pi_state) =
self.get_current_pointer_inference_state(state, &def.tid) self.get_current_pointer_inference_state(state, &def.tid)
{ {
if let Ok(address_data) = pi_state.eval(address) { let address_data = pi_state.eval(address);
let taint = state.eval(value); let taint = state.eval(value);
new_state.save_taint_to_memory(&address_data, taint); new_state.save_taint_to_memory(&address_data, taint);
}
} else { } else {
// We lost all knowledge about memory pointers. // We lost all knowledge about memory pointers.
// We delete all memory taint to reduce false positives. // We delete all memory taint to reduce false positives.
......
...@@ -101,9 +101,8 @@ impl State { ...@@ -101,9 +101,8 @@ impl State {
if let Some(pi_state) = pi_state { if let Some(pi_state) = pi_state {
let address_exp = let address_exp =
Expression::Var(stack_pointer_register.clone()).plus_const(*offset); Expression::Var(stack_pointer_register.clone()).plus_const(*offset);
if let Ok(address) = pi_state.eval(&address_exp) { let address = pi_state.eval(&address_exp);
state.save_taint_to_memory(&address, Taint::Tainted(*size)); state.save_taint_to_memory(&address, Taint::Tainted(*size));
}
} }
} }
} }
...@@ -369,12 +368,13 @@ impl State { ...@@ -369,12 +368,13 @@ impl State {
} }
memory.push((format!("{}", tid), Value::Object(Map::from_iter(elements)))); memory.push((format!("{}", tid), Value::Object(Map::from_iter(elements))));
} }
let mut state_map = Vec::new(); let state_map = vec![
state_map.push(( (
"register".to_string(), "register".to_string(),
Value::Object(Map::from_iter(register)), Value::Object(Map::from_iter(register)),
)); ),
state_map.push(("memory".to_string(), Value::Object(Map::from_iter(memory)))); ("memory".to_string(), Value::Object(Map::from_iter(memory))),
];
Value::Object(Map::from_iter(state_map)) Value::Object(Map::from_iter(state_map))
} }
...@@ -477,7 +477,7 @@ mod tests { ...@@ -477,7 +477,7 @@ mod tests {
assert_eq!(state.register_taint.get(&register("RSP")), None); assert_eq!(state.register_taint.get(&register("RSP")), None);
let address = Expression::Var(register("RSP")); let address = Expression::Var(register("RSP"));
assert_eq!( assert_eq!(
state.load_taint_from_memory(&pi_state.eval(&address).unwrap(), ByteSize::new(8)), state.load_taint_from_memory(&pi_state.eval(&address), ByteSize::new(8)),
taint taint
); );
} }
......
...@@ -262,9 +262,8 @@ impl<'a> Context<'a> { ...@@ -262,9 +262,8 @@ impl<'a> Context<'a> {
{ {
let address_exp = let address_exp =
Expression::Var(self.project.stack_pointer_register.clone()).plus_const(offset); Expression::Var(self.project.stack_pointer_register.clone()).plus_const(offset);
if let Ok(address) = pi_state.eval(&address_exp) { let address = pi_state.eval(&address_exp);
new_state.save_taint_to_memory(&address, Taint::Tainted(size)); new_state.save_taint_to_memory(&address, Taint::Tainted(size));
}
} }
new_state new_state
} }
......
...@@ -467,13 +467,10 @@ fn creating_pi_def_map() { ...@@ -467,13 +467,10 @@ fn creating_pi_def_map() {
for (def_tid, pi_state) in pi_def_map.iter() { for (def_tid, pi_state) in pi_def_map.iter() {
if *def_tid == def1 { if *def_tid == def1 {
assert_eq!( assert_eq!(pi_state.get_register(&rdi_reg), Data::new_top(rdi_reg.size));
pi_state.get_register(&rdi_reg).unwrap(),
Data::new_top(rdi_reg.size)
);
} else if *def_tid == def2 { } else if *def_tid == def2 {
assert_eq!( assert_eq!(
pi_state.get_register(&rdi_reg).unwrap(), pi_state.get_register(&rdi_reg),
Data::Pointer(PointerDomain::new(stack_id.clone(), bv(-8))) Data::Pointer(PointerDomain::new(stack_id.clone(), bv(-8)))
); );
} }
......
...@@ -113,9 +113,8 @@ impl State { ...@@ -113,9 +113,8 @@ impl State {
if let Some(pi_state) = pi_state { if let Some(pi_state) = pi_state {
let address_exp = let address_exp =
Expression::Var(stack_pointer_register.clone()).plus_const(*offset); Expression::Var(stack_pointer_register.clone()).plus_const(*offset);
if let Ok(address) = pi_state.eval(&address_exp) { let address = pi_state.eval(&address_exp);
state.save_taint_to_memory(&address, Taint::Tainted(*size)); state.save_taint_to_memory(&address, Taint::Tainted(*size));
}
} }
} }
} }
...@@ -208,9 +207,8 @@ impl State { ...@@ -208,9 +207,8 @@ impl State {
Expression::BinOp { .. } => { Expression::BinOp { .. } => {
if let Some(pid_map) = self.pi_def_map.as_ref() { if let Some(pid_map) = self.pi_def_map.as_ref() {
if let Some(pi_state) = pid_map.get(def_tid) { if let Some(pi_state) = pid_map.get(def_tid) {
if let Ok(address) = pi_state.get_register(result) { let address = pi_state.get_register(result);
self.save_taint_to_memory(&address, Taint::Tainted(result.size)); self.save_taint_to_memory(&address, Taint::Tainted(result.size));
}
} }
} }
} }
...@@ -233,11 +231,10 @@ impl State { ...@@ -233,11 +231,10 @@ impl State {
) { ) {
if let Some(pid_map) = self.pi_def_map.as_ref() { if let Some(pid_map) = self.pi_def_map.as_ref() {
if let Some(pi_state) = pid_map.get(def_tid) { if let Some(pi_state) = pid_map.get(def_tid) {
if let Ok(address) = pi_state.eval(target) { let address = pi_state.eval(target);
if self.check_if_address_points_to_taint(address.clone(), &pi_state) { if self.check_if_address_points_to_taint(address.clone(), &pi_state) {
self.taint_def_input_register(value, stack_pointer_register, def_tid); self.taint_def_input_register(value, stack_pointer_register, def_tid);
self.remove_mem_taint_at_target(&address); self.remove_mem_taint_at_target(&address);
}
} }
} }
} }
...@@ -277,12 +274,11 @@ impl State { ...@@ -277,12 +274,11 @@ impl State {
if var.name == stack_pointer_register.name { if var.name == stack_pointer_register.name {
if let Some(pid_map) = self.pi_def_map.as_ref() { if let Some(pid_map) = self.pi_def_map.as_ref() {
if let Some(pi_state) = pid_map.get(def_tid) { if let Some(pi_state) = pid_map.get(def_tid) {
if let Ok(address) = pi_state.get_register(stack_pointer_register) { let address = pi_state.get_register(stack_pointer_register);
self.save_taint_to_memory( self.save_taint_to_memory(
&address, &address,
Taint::Tainted(stack_pointer_register.size), Taint::Tainted(stack_pointer_register.size),
); );
}
} }
} }
} else { } else {
......
...@@ -25,11 +25,11 @@ pub trait BitvectorExtended: Sized { ...@@ -25,11 +25,11 @@ pub trait BitvectorExtended: Sized {
/// Returns an error for non-implemented operations (currently all float-related operations). /// Returns an error for non-implemented operations (currently all float-related operations).
fn bin_op(&self, op: BinOpType, rhs: &Self) -> Result<Self, Error>; fn bin_op(&self, op: BinOpType, rhs: &Self) -> Result<Self, Error>;
/// Returns `true` if adding `self` to `rhs` would result in a signed integer overflow or underflow. /// Returns the result of `self + rhs` if the computation does not result in a signed integer overflow or underflow.
fn signed_add_overflow_check(&self, rhs: &Self) -> bool; fn signed_add_overflow_checked(&self, rhs: &Self) -> Option<Self>;
/// Returns `true` if subtracting `rhs` from `self` would result in a signed integer overflow or underflow. /// Returns the result of `self - rhs` if the computation does not result in a signed integer overflow or underflow.
fn signed_sub_overflow_check(&self, rhs: &Self) -> bool; fn signed_sub_overflow_checked(&self, rhs: &Self) -> Option<Self>;
/// Return the result of multiplying `self` with `rhs` /// Return the result of multiplying `self` with `rhs`
/// and a flag that is set to `true` if the multiplication resulted in a signed integer overflow or underflow. /// and a flag that is set to `true` if the multiplication resulted in a signed integer overflow or underflow.
...@@ -221,23 +221,21 @@ impl BitvectorExtended for Bitvector { ...@@ -221,23 +221,21 @@ impl BitvectorExtended for Bitvector {
} }
} }
/// Returns `true` if adding `self` to `rhs` would result in a signed integer overflow or underflow. /// Returns the result of `self + rhs` if the computation does not result in a signed integer overflow or underflow.
fn signed_add_overflow_check(&self, rhs: &Self) -> bool { fn signed_add_overflow_checked(&self, rhs: &Self) -> Option<Self> {
let result = self.clone().into_checked_add(rhs).unwrap(); let result = self.clone().into_checked_add(rhs).unwrap();
if rhs.sign_bit().to_bool() { match (rhs.sign_bit().to_bool(), self.checked_sle(&result).unwrap()) {
self.checked_sle(&result).unwrap() (true, true) | (false, false) => None,
} else { _ => Some(result),
self.checked_sgt(&result).unwrap()
} }
} }
/// Returns `true` if subtracting `rhs` from `self` would result in a signed integer overflow or underflow. /// Returns the result of `self - rhs` if the computation does not result in a signed integer overflow or underflow.
fn signed_sub_overflow_check(&self, rhs: &Self) -> bool { fn signed_sub_overflow_checked(&self, rhs: &Self) -> Option<Self> {
let result = self.clone().into_checked_sub(rhs).unwrap(); let result = self.clone().into_checked_sub(rhs).unwrap();
if rhs.sign_bit().to_bool() { match (rhs.sign_bit().to_bool(), self.checked_sge(&result).unwrap()) {
self.checked_sge(&result).unwrap() (true, true) | (false, false) => None,
} else { _ => Some(result),
self.checked_slt(&result).unwrap()
} }
} }
...@@ -264,3 +262,36 @@ impl BitvectorExtended for Bitvector { ...@@ -264,3 +262,36 @@ impl BitvectorExtended for Bitvector {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn overflow_checked_add_and_sub() {
let max = Bitvector::signed_max_value(ByteSize::new(8).into());
let min = Bitvector::signed_min_value(ByteSize::new(8).into());
assert_eq!(min.signed_add_overflow_checked(&min), None);
assert_eq!(
min.signed_add_overflow_checked(&max),
Some(-Bitvector::one(ByteSize::new(8).into()))
);
assert_eq!(
max.signed_add_overflow_checked(&min),
Some(-Bitvector::one(ByteSize::new(8).into()))
);
assert_eq!(max.signed_add_overflow_checked(&max), None);
assert_eq!(
min.signed_sub_overflow_checked(&min),
Some(Bitvector::zero(ByteSize::new(8).into()))
);
assert_eq!(min.signed_sub_overflow_checked(&max), None);
assert_eq!(max.signed_sub_overflow_checked(&min), None);
assert_eq!(
max.signed_sub_overflow_checked(&max),
Some(Bitvector::zero(ByteSize::new(8).into()))
);
}
}
...@@ -100,6 +100,22 @@ impl Expression { ...@@ -100,6 +100,22 @@ impl Expression {
} }
} }
/// Return an array of all input variables of the given expression.
/// The array may contain duplicates.
pub fn input_vars(&self) -> Vec<&Variable> {
use Expression::*;
match self {
Var(var) => vec![var],
Const(_) | Unknown { .. } => Vec::new(),
BinOp { op: _, lhs, rhs } => {
let mut vars = lhs.input_vars();
vars.append(&mut rhs.input_vars());
vars
}
UnOp { arg, .. } | Cast { arg, .. } | Subpiece { arg, .. } => arg.input_vars(),
}
}
/// Substitute trivial BinOp-expressions with their results, /// Substitute trivial BinOp-expressions with their results,
/// e.g. substitute `a or a` with `a`. /// e.g. substitute `a or a` with `a`.
/// ///
......
...@@ -4,8 +4,7 @@ use apint::ApInt; ...@@ -4,8 +4,7 @@ use apint::ApInt;
#[cfg(test)] #[cfg(test)]
use super::{CastOpType, Variable}; use super::{CastOpType, Variable};
use super::{BinOpType, Expression}; use super::*;
use crate::prelude::*;
/// ## Helper functions for building expressions /// ## Helper functions for building expressions
impl Expression { impl Expression {
...@@ -51,6 +50,15 @@ impl Expression { ...@@ -51,6 +50,15 @@ impl Expression {
} }
} }
/// Shortcut for creating unary operation expressions.
#[cfg(test)]
pub fn un_op(self, op: UnOpType) -> Expression {
Expression::UnOp {
op,
arg: Box::new(self),
}
}
/// Shortcut for creating an `IntAdd`-expression /// Shortcut for creating an `IntAdd`-expression
pub fn plus(self, rhs: Expression) -> Expression { pub fn plus(self, rhs: Expression) -> Expression {
Expression::BinOp { Expression::BinOp {
...@@ -60,6 +68,16 @@ impl Expression { ...@@ -60,6 +68,16 @@ impl Expression {
} }
} }
/// Shortcut for creating an `IntSub`-expression
#[cfg(test)]
pub fn minus(self, rhs: Expression) -> Expression {
Expression::BinOp {
lhs: Box::new(self),
op: BinOpType::IntSub,
rhs: Box::new(rhs),
}
}
/// Construct an expression that adds a constant value to the given expression. /// Construct an expression that adds a constant value to the given expression.
/// ///
/// The bytesize of the value is automatically adjusted to the bytesize of the given expression. /// The bytesize of the value is automatically adjusted to the bytesize of the given expression.
...@@ -73,4 +91,19 @@ impl Expression { ...@@ -73,4 +91,19 @@ impl Expression {
} }
self.plus(Expression::Const(value)) self.plus(Expression::Const(value))
} }
/// Construct an expression that subtracts a constant value from the given expression.
///
/// The bytesize of the value is automatically adjusted to the bytesize of the given expression.
#[cfg(test)]
pub fn minus_const(self, value: i64) -> Expression {
let bytesize = self.bytesize();
let mut value = Bitvector::from_i64(value);
match u64::from(bytesize) {
size if size > 8 => value.sign_extend(bytesize).unwrap(),
size if size < 8 => value.truncate(bytesize).unwrap(),
_ => (),
}
self.minus(Expression::Const(value))
}
} }
...@@ -172,6 +172,7 @@ impl From<Expression> for IrExpression { ...@@ -172,6 +172,7 @@ impl From<Expression> for IrExpression {
/// Expression Opcodes as parsed from Ghidra /// Expression Opcodes as parsed from Ghidra
#[allow(missing_docs)] #[allow(missing_docs)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ExpressionType { pub enum ExpressionType {
COPY, COPY,
......
...@@ -49,6 +49,7 @@ pub struct Jmp { ...@@ -49,6 +49,7 @@ pub struct Jmp {
/// A jump type mnemonic. /// A jump type mnemonic.
#[allow(missing_docs)] #[allow(missing_docs)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum JmpType { pub enum JmpType {
BRANCH, BRANCH,
...@@ -305,6 +306,7 @@ pub struct Arg { ...@@ -305,6 +306,7 @@ pub struct Arg {
/// The intent (input or output) of a function argument. /// The intent (input or output) of a function argument.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
#[allow(clippy::upper_case_acronyms)]
pub enum ArgIntent { pub enum ArgIntent {
/// The argument is an input parameter. /// The argument is an input parameter.
INPUT, INPUT,
......
...@@ -142,7 +142,8 @@ impl RuntimeMemoryImage { ...@@ -142,7 +142,8 @@ impl RuntimeMemoryImage {
let address = address.try_to_u64().unwrap(); let address = address.try_to_u64().unwrap();
for segment in self.memory_segments.iter() { for segment in self.memory_segments.iter() {
if address >= segment.base_address if address >= segment.base_address
&& address + u64::from(size) <= segment.base_address + segment.bytes.len() as u64 && u64::from(size) <= segment.base_address + segment.bytes.len() as u64
&& address <= segment.base_address + segment.bytes.len() as u64 - u64::from(size)
{ {
if segment.write_flag { if segment.write_flag {
// The segment is writeable, thus we do not know the content at runtime. // The segment is writeable, thus we do not know the content at runtime.
......
...@@ -43,7 +43,7 @@ pub fn is_sink_call_reachable_from_source_call( ...@@ -43,7 +43,7 @@ pub fn is_sink_call_reachable_from_source_call(
// and as long as the edge does not leave the function. // and as long as the edge does not leave the function.
match edge.weight() { match edge.weight() {
Edge::Block Edge::Block
| Edge::CRCallStub | Edge::CrCallStub
| Edge::CallCombine(_) | Edge::CallCombine(_)
| Edge::ReturnCombine(_) | Edge::ReturnCombine(_)
| Edge::Jump(_, _) | Edge::Jump(_, _)
...@@ -53,7 +53,7 @@ pub fn is_sink_call_reachable_from_source_call( ...@@ -53,7 +53,7 @@ pub fn is_sink_call_reachable_from_source_call(
worklist.push(edge.target()) worklist.push(edge.target())
} }
} }
Edge::Call(_) | Edge::CRReturnStub => (), // These edges would leave the function control flow graph. Edge::Call(_) | Edge::CrReturnStub => (), // These edges would leave the function control flow graph.
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment