Unverified commit a6ef477d by Enkelmann, committed by GitHub

Use conditional branches to restrict value sets of variables (#159)

parent 83593fac
use super::{
AbstractDomain, AbstractIdentifier, HasTop, Interval, PointerDomain, RegisterDomain,
SizedDomain, TryToBitvec, TryToInterval,
SizedDomain, SpecializeByConditional, TryToBitvec, TryToInterval,
};
use crate::intermediate_representation::*;
use crate::prelude::*;
......@@ -68,6 +68,48 @@ impl<T: RegisterDomain> DataDomain<T> {
}
}
impl<T: SpecializeByConditional + RegisterDomain> SpecializeByConditional for DataDomain<T> {
fn add_signed_less_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_signed_less_equal_bound(bound)?))
} else {
Ok(self)
}
}
fn add_unsigned_less_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_unsigned_less_equal_bound(bound)?))
} else {
Ok(self)
}
}
fn add_signed_greater_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_signed_greater_equal_bound(bound)?))
} else {
Ok(self)
}
}
fn add_unsigned_greater_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_unsigned_greater_equal_bound(bound)?))
} else {
Ok(self)
}
}
fn add_not_equal_bound(self, bound: &Bitvector) -> Result<Self, Error> {
if let Self::Value(value) = self {
Ok(Self::Value(value.add_not_equal_bound(bound)?))
} else {
Ok(self)
}
}
}
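
Illustration (not part of this commit): the impl above restricts only the `Value` variant and passes pointers and `Top` values through unchanged, since a numeric bound says nothing about pointer targets. A minimal standalone sketch of the same delegation pattern, with `i64` pairs standing in for the crate's `Bitvector`-based intervals:

// Sketch only: simplified stand-ins for DataDomain/IntervalDomain.
#[derive(Debug, Clone, PartialEq)]
enum Data {
    Value((i64, i64)), // a signed interval
    Pointer,           // pointer variant: not restricted by value bounds
}

impl Data {
    // Mirrors add_signed_less_equal_bound: delegate to the contained value,
    // leave the non-value variant untouched.
    fn add_signed_less_equal_bound(self, bound: i64) -> Result<Self, String> {
        match self {
            Data::Value((start, end)) => {
                if start > bound {
                    Err("Empty interval".into())
                } else {
                    Ok(Data::Value((start, end.min(bound))))
                }
            }
            other => Ok(other),
        }
    }
}

fn main() {
    let val = Data::Value((-10, 20));
    assert_eq!(val.add_signed_less_equal_bound(5), Ok(Data::Value((-10, 5))));
    assert_eq!(Data::Pointer.add_signed_less_equal_bound(5), Ok(Data::Pointer));
}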
impl<T: RegisterDomain> SizedDomain for DataDomain<T> {
// Return the size in bytes of `self`.
fn bytesize(&self) -> ByteSize {
......@@ -244,6 +286,22 @@ impl<T: RegisterDomain + TryToInterval> TryToInterval for DataDomain<T> {
}
}
impl<T: RegisterDomain> std::ops::Add for DataDomain<T> {
type Output = DataDomain<T>;
fn add(self, rhs: Self) -> Self {
self.bin_op(BinOpType::IntAdd, &rhs)
}
}
impl<T: RegisterDomain> std::ops::Sub for DataDomain<T> {
type Output = DataDomain<T>;
fn sub(self, rhs: Self) -> Self {
self.bin_op(BinOpType::IntSub, &rhs)
}
}
impl<T: RegisterDomain + Display> DataDomain<T> {
/// Get a more compact json-representation of the data domain.
/// Intended for pretty printing, not usable for serialization/deserialization.
......
......@@ -3,7 +3,7 @@ use std::fmt::Display;
use crate::intermediate_representation::*;
use crate::prelude::*;
use super::{AbstractDomain, HasTop, RegisterDomain, SizedDomain};
use super::{AbstractDomain, HasTop, RegisterDomain, SizedDomain, SpecializeByConditional};
use super::{TryToBitvec, TryToInterval};
mod simple_interval;
......@@ -217,6 +217,113 @@ impl IntervalDomain {
.map(|bitvec| bitvec.into_sign_extend(width).unwrap()),
}
}
/// Compute the intersection of two intervals.
/// Return an error if the intersection is empty.
pub fn intersect(&self, other: &Self) -> Result<Self, Error> {
let mut intersected_domain: IntervalDomain =
self.interval.signed_intersect(&other.interval)?.into();
intersected_domain.update_widening_lower_bound(&self.widening_lower_bound);
intersected_domain.update_widening_lower_bound(&other.widening_lower_bound);
intersected_domain.update_widening_upper_bound(&self.widening_upper_bound);
intersected_domain.update_widening_upper_bound(&other.widening_upper_bound);
Ok(intersected_domain)
}
}
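
Illustration (not part of this commit): how the four update calls above can combine the widening hints, assuming each update keeps the tighter of the existing and the offered hint, which matches the `intersection` test further below; `i64` stands in for `Bitvector`:

// Sketch: merging optional lower widening hints; the tighter hint wins.
fn merge_lower_hint(a: Option<i64>, b: Option<i64>) -> Option<i64> {
    match (a, b) {
        (Some(x), Some(y)) => Some(x.max(y)), // the larger lower hint is tighter
        (hint, None) => hint,
        (None, hint) => hint,
    }
}

fn main() {
    // Mirrors the `intersection` test below: hints -100 and -20 merge to -20.
    assert_eq!(merge_lower_hint(Some(-100), Some(-20)), Some(-20));
    assert_eq!(merge_lower_hint(None, Some(-20)), Some(-20));
}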
impl SpecializeByConditional for IntervalDomain {
fn add_signed_less_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if let Some(old_upper_bound) = &self.widening_upper_bound {
if old_upper_bound.checked_sle(bound).unwrap() {
return Ok(self);
} else if self.interval.end.checked_slt(bound).unwrap() {
self.widening_upper_bound = Some(bound.clone());
return Ok(self);
} else {
self.widening_upper_bound = None;
}
} else if self.interval.end.checked_slt(bound).unwrap() {
self.widening_upper_bound = Some(bound.clone());
return Ok(self);
}
// we already know that the bound is less than or equal to `self.interval.end`
if self.interval.start.checked_sle(bound).unwrap() {
self.interval.end = bound.clone();
Ok(self)
} else {
Err(anyhow!("Empty interval"))
}
}
fn add_signed_greater_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if let Some(old_lower_bound) = &self.widening_lower_bound {
if old_lower_bound.checked_sge(bound).unwrap() {
return Ok(self);
} else if self.interval.start.checked_sgt(bound).unwrap() {
self.widening_lower_bound = Some(bound.clone());
return Ok(self);
} else {
self.widening_lower_bound = None;
}
} else if self.interval.start.checked_sgt(bound).unwrap() {
self.widening_lower_bound = Some(bound.clone());
return Ok(self);
}
// we already know that the bound is greater than or equal to `self.interval.start`
if self.interval.end.checked_sge(bound).unwrap() {
self.interval.start = bound.clone();
Ok(self)
} else {
Err(anyhow!("Empty interval"))
}
}
fn add_unsigned_less_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if bound.sign_bit().to_bool() {
if self.interval.end.sign_bit().to_bool() {
self.add_signed_less_equal_bound(bound)
} else if self.interval.start.sign_bit().to_bool() {
Ok(self)
} else {
self.add_signed_greater_equal_bound(&Bitvector::zero(bound.width()))
}
} else {
self = self.add_signed_greater_equal_bound(&Bitvector::zero(bound.width()))?;
self.add_signed_less_equal_bound(bound)
}
}
fn add_unsigned_greater_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if bound.sign_bit().to_bool() {
self = self.add_signed_less_equal_bound(&(-Bitvector::one(bound.width())))?;
self.add_signed_greater_equal_bound(bound)
} else if self.interval.end.checked_slt(bound).unwrap() {
self.add_signed_less_equal_bound(&(-Bitvector::one(bound.width())))
} else if self.interval.start.sign_bit().to_bool() {
Ok(self)
} else {
self.add_signed_greater_equal_bound(bound)
}
}
fn add_not_equal_bound(mut self, bound: &Bitvector) -> Result<Self, Error> {
if self.interval.start == *bound && self.interval.end == *bound {
return Err(anyhow!("Empty interval"));
}
if self.interval.start.checked_sgt(bound).unwrap() {
self.add_signed_greater_equal_bound(&(bound + &Bitvector::one(bound.width())))
} else if self.interval.start == *bound {
self.interval.start += &Bitvector::one(bound.width());
Ok(self)
} else if self.interval.end.checked_slt(bound).unwrap() {
self.add_signed_less_equal_bound(&(bound - &Bitvector::one(bound.width())))
} else if self.interval.end == *bound {
self.interval.end -= &Bitvector::one(bound.width());
Ok(self)
} else {
Ok(self)
}
}
}
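
Why the unsigned methods above branch on sign bits (illustration only): for two's-complement values the unsigned order differs from the signed order exactly at the sign boundary, so a bound with a set sign bit is a large unsigned value that every nonnegative value trivially satisfies. A small standalone demonstration with `i8`/`u8` in place of `Bitvector`:

// Sketch: for 8-bit values the unsigned order is
// 0 < 1 < ... < 127 < 128 (= -128 signed) < ... < 255 (= -1 signed).
fn main() {
    let as_unsigned = |x: i8| x as u8;
    // A signed-negative bound like -5 is a large unsigned value (251), so
    // every nonnegative value satisfies `value <=_u bound`.
    assert!(as_unsigned(10) <= as_unsigned(-5));
    // A signed-nonnegative bound like 35 is only satisfied (unsigned) by
    // values in [0, 35], i.e. by signed-nonnegative values up to 35.
    assert!(as_unsigned(10) <= as_unsigned(35));
    assert!(!(as_unsigned(-10) <= as_unsigned(35)));
}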
impl AbstractDomain for IntervalDomain {
......
......@@ -4,74 +4,66 @@ impl IntervalDomain {
/// Compute the interval of possible results
/// if one adds a value from `self` to a value from `rhs`.
pub fn add(&self, rhs: &Self) -> Self {
let interval = self.interval.add(&rhs.interval);
let mut interval: IntervalDomain = self.interval.add(&rhs.interval).into();
if interval.is_top() {
interval.into()
interval
} else {
let new_lower_bound = if let (Some(self_bound), Some(rhs_bound)) =
(&self.widening_lower_bound, &rhs.widening_lower_bound)
{
if self_bound.signed_add_overflow_check(rhs_bound) {
None
} else {
Some(self_bound.clone().into_checked_add(rhs_bound).unwrap())
}
} else {
None
};
let new_upper_bound = if let (Some(self_bound), Some(rhs_bound)) =
(&self.widening_upper_bound, &rhs.widening_upper_bound)
{
if self_bound.signed_add_overflow_check(rhs_bound) {
None
} else {
Some(self_bound.clone().into_checked_add(rhs_bound).unwrap())
}
} else {
None
};
IntervalDomain {
interval,
widening_upper_bound: new_upper_bound,
widening_lower_bound: new_lower_bound,
}
interval.update_widening_lower_bound(
&self
.widening_lower_bound
.as_ref()
.and_then(|bound| bound.signed_add_overflow_checked(&rhs.interval.start)),
);
interval.update_widening_lower_bound(
&rhs.widening_lower_bound
.as_ref()
.and_then(|bound| bound.signed_add_overflow_checked(&self.interval.start)),
);
interval.update_widening_upper_bound(
&self
.widening_upper_bound
.as_ref()
.and_then(|bound| bound.signed_add_overflow_checked(&rhs.interval.end)),
);
interval.update_widening_upper_bound(
&rhs.widening_upper_bound
.as_ref()
.and_then(|bound| bound.signed_add_overflow_checked(&self.interval.end)),
);
interval
}
}
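
Illustration (not part of this commit) of the bound propagation above, with `i64` in place of `Bitvector`: a sound lower hint for `a + b` is the lower hint of one operand plus the interval start of the other, and the code tries both combinations, keeping the tighter result (assuming the update functions behave as sketched after `intersect` above):

// Sketch: two sound candidates for the lower widening hint of a + b.
fn main() {
    let (a_start, a_lower_hint) = (0i64, Some(-10i64));
    let (b_start, b_lower_hint) = (3i64, Some(-5i64));
    let candidate_a = a_lower_hint.map(|hint| hint + b_start); // Some(-7)
    let candidate_b = b_lower_hint.map(|hint| hint + a_start); // Some(-5)
    // Keep the tighter (larger) candidate:
    let merged = match (candidate_a, candidate_b) {
        (Some(x), Some(y)) => Some(x.max(y)),
        (hint, None) | (None, hint) => hint,
    };
    assert_eq!(merged, Some(-5));
}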
/// Compute the interval of possible results
/// if one subtracts a value in `rhs` from a value in `self`.
pub fn sub(&self, rhs: &Self) -> Self {
let interval = self.interval.sub(&rhs.interval);
let mut interval: IntervalDomain = self.interval.sub(&rhs.interval).into();
if interval.is_top() {
interval.into()
interval
} else {
let new_lower_bound = if let (Some(self_bound), Some(rhs_bound)) =
(&self.widening_lower_bound, &rhs.widening_upper_bound)
{
if self_bound.signed_sub_overflow_check(rhs_bound) {
None
} else {
Some(self_bound.clone().into_checked_sub(rhs_bound).unwrap())
}
} else {
None
};
let new_upper_bound = if let (Some(self_bound), Some(rhs_bound)) =
(&self.widening_upper_bound, &rhs.widening_lower_bound)
{
if self_bound.signed_sub_overflow_check(rhs_bound) {
None
} else {
Some(self_bound.clone().into_checked_sub(rhs_bound).unwrap())
}
} else {
None
};
IntervalDomain {
interval,
widening_upper_bound: new_upper_bound,
widening_lower_bound: new_lower_bound,
}
interval.update_widening_lower_bound(
&self
.widening_lower_bound
.as_ref()
.and_then(|bound| bound.signed_sub_overflow_checked(&rhs.interval.end)),
);
interval.update_widening_lower_bound(
&rhs.widening_upper_bound
.as_ref()
.and_then(|bound| self.interval.start.signed_sub_overflow_checked(bound)),
);
interval.update_widening_upper_bound(
&self
.widening_upper_bound
.as_ref()
.and_then(|bound| bound.signed_sub_overflow_checked(&rhs.interval.start)),
);
interval.update_widening_upper_bound(
&rhs.widening_lower_bound
.as_ref()
.and_then(|bound| self.interval.end.signed_sub_overflow_checked(bound)),
);
interval
}
}
......
......@@ -55,6 +55,18 @@ impl Interval {
Interval { start, end }
}
/// Compute the intersection of two intervals as intervals of signed integers.
/// Return an error if the intersection is empty.
pub fn signed_intersect(&self, other: &Interval) -> Result<Interval, Error> {
let start = signed_max(&self.start, &other.start);
let end = signed_min(&self.end, &other.end);
if start.checked_sle(&end).unwrap() {
Ok(Interval { start, end })
} else {
Err(anyhow!("Empty interval"))
}
}
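
A standalone `i64` sketch of the intersection logic above (assuming the crate's `signed_max`/`signed_min` helpers are the signed analogues of `max`/`min` on bitvectors):

// Sketch: signed interval intersection with plain i64 endpoints.
fn signed_intersect(a: (i64, i64), b: (i64, i64)) -> Result<(i64, i64), String> {
    let start = a.0.max(b.0);
    let end = a.1.min(b.1);
    if start <= end {
        Ok((start, end))
    } else {
        Err("Empty interval".into())
    }
}

fn main() {
    assert_eq!(signed_intersect((-10, 10), (2, 30)), Ok((2, 10)));
    assert!(signed_intersect((-10, 10), (50, 55)).is_err());
}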
/// Return the number of contained values of the interval as an unsigned bitvector.
/// If the interval is unconstrained, return zero
/// (since the maximal number of elements is not representable in a bitvector of the same byte size).
......@@ -142,30 +154,26 @@ impl Interval {
/// Compute the interval of possible results
/// if one adds a value from `self` to a value from `rhs`.
pub fn add(&self, rhs: &Interval) -> Interval {
if self.start.signed_add_overflow_check(&rhs.start)
|| self.end.signed_add_overflow_check(&rhs.end)
{
Interval::new_top(self.bytesize())
if let (Some(start), Some(end)) = (
self.start.signed_add_overflow_checked(&rhs.start),
self.end.signed_add_overflow_checked(&rhs.end),
) {
Interval { start, end }
} else {
Interval {
start: self.start.clone().into_checked_add(&rhs.start).unwrap(),
end: self.end.clone().into_checked_add(&rhs.end).unwrap(),
}
Interval::new_top(self.bytesize())
}
}
/// Compute the interval of possible results
/// if one subtracts a value in `rhs` from a value in `self`.
pub fn sub(&self, rhs: &Interval) -> Interval {
if self.start.signed_sub_overflow_check(&rhs.end)
|| self.end.signed_sub_overflow_check(&rhs.start)
{
Interval::new_top(self.bytesize())
if let (Some(start), Some(end)) = (
self.start.signed_sub_overflow_checked(&rhs.end),
self.end.signed_sub_overflow_checked(&rhs.start),
) {
Interval { start, end }
} else {
Interval {
start: self.start.clone().into_checked_sub(&rhs.end).unwrap(),
end: self.end.clone().into_checked_sub(&rhs.start).unwrap(),
}
Interval::new_top(self.bytesize())
}
}
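
For intuition, the same pattern with `i64` endpoints: the refactor replaces "pre-check overflow, then add unchecked" with a single checked operation, degrading to `Top` when either endpoint wraps (illustration only, not part of the commit):

// Sketch: checked interval addition; a None from either endpoint means the
// result wraps, so the whole interval degrades to Top (here: the full range).
fn interval_add(a: (i64, i64), b: (i64, i64)) -> (i64, i64) {
    match (a.0.checked_add(b.0), a.1.checked_add(b.1)) {
        (Some(start), Some(end)) => (start, end),
        _ => (i64::MIN, i64::MAX), // Top
    }
}

fn main() {
    assert_eq!(interval_add((-10, 10), (5, 7)), (-5, 17));
    assert_eq!(interval_add((i64::MAX - 1, i64::MAX), (1, 2)), (i64::MIN, i64::MAX));
}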
......@@ -193,6 +201,12 @@ impl Interval {
end: max,
}
}
/// Return `true` if `bitvec` is contained in the interval.
/// Panics if the interval and `bitvec` have different bytesizes.
pub fn contains(&self, bitvec: &Bitvector) -> bool {
self.start.checked_sle(bitvec).unwrap() && self.end.checked_sge(bitvec).unwrap()
}
}
impl From<Bitvector> for Interval {
......
......@@ -228,18 +228,18 @@ fn add() {
let result = lhs.bin_op(BinOpType::IntAdd, &rhs);
assert_eq!(
result,
IntervalDomain::mock_with_bounds(None, 0, 7, Some(20))
IntervalDomain::mock_with_bounds(Some(-17), 0, 7, Some(10))
);
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-121), -120, -120, Some(10));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-10), -9, 0, Some(10));
let result = lhs.bin_op(BinOpType::IntAdd, &rhs);
assert_eq!(result, IntervalDomain::new_top(ByteSize::new(1)));
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-100), 2, 4, Some(100));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-50), 10, 20, Some(50));
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-100), -30, 40, Some(100));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-100), -30, 20, Some(50));
let result = lhs.bin_op(BinOpType::IntAdd, &rhs);
assert_eq!(
result,
IntervalDomain::mock_i8_with_bounds(None, 12, 24, None)
IntervalDomain::mock_i8_with_bounds(None, -60, 60, Some(90))
);
}
......@@ -250,18 +250,18 @@ fn sub() {
let result = lhs.bin_op(BinOpType::IntSub, &rhs);
assert_eq!(
result,
IntervalDomain::mock_with_bounds(None, 3, 10, Some(30))
IntervalDomain::mock_with_bounds(Some(-7), 3, 10, Some(13))
);
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-121), -120, -120, Some(10));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-10), -9, 9, Some(10));
let result = lhs.bin_op(BinOpType::IntSub, &rhs);
assert_eq!(result, IntervalDomain::new_top(ByteSize::new(1)));
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-100), 2, 4, Some(100));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-50), 10, 20, Some(50));
let lhs = IntervalDomain::mock_i8_with_bounds(Some(-100), 2, 40, Some(100));
let rhs = IntervalDomain::mock_i8_with_bounds(Some(-50), -30, 3, Some(100));
let result = lhs.bin_op(BinOpType::IntSub, &rhs);
assert_eq!(
result,
IntervalDomain::mock_i8_with_bounds(None, -18, -6, None)
IntervalDomain::mock_i8_with_bounds(Some(-98), -1, 70, Some(90))
);
}
......@@ -315,3 +315,191 @@ fn shift_left() {
let result = lhs.bin_op(BinOpType::IntLeft, &rhs);
assert_eq!(result, IntervalDomain::mock(0, 0));
}
#[test]
fn simple_interval_contains() {
let domain = IntervalDomain::mock(-10, 5);
assert!(!domain.interval.contains(&Bitvector::from_i64(-11)));
assert!(domain.interval.contains(&Bitvector::from_i64(-10)));
assert!(domain.interval.contains(&Bitvector::from_i64(-4)));
assert!(domain.interval.contains(&Bitvector::from_i64(5)));
assert!(!domain.interval.contains(&Bitvector::from_i64(6)));
}
#[test]
fn add_signed_bounds() {
let interval = IntervalDomain::mock_with_bounds(Some(-100), -10, 10, Some(100));
// signed_less_equal
let x = interval
.clone()
.add_signed_less_equal_bound(&Bitvector::from_i64(20));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-100), -10, 10, Some(20))
);
let x = interval
.clone()
.add_signed_less_equal_bound(&Bitvector::from_i64(-5));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-100), -10, -5, None)
);
let x = interval
.clone()
.add_signed_less_equal_bound(&Bitvector::from_i64(-20));
assert!(x.is_err());
// signed_greater_equal
let x = interval
.clone()
.add_signed_greater_equal_bound(&Bitvector::from_i64(20));
assert!(x.is_err());
let x = interval
.clone()
.add_signed_greater_equal_bound(&Bitvector::from_i64(-5));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, -5, 10, Some(100))
);
let x = interval
.clone()
.add_signed_greater_equal_bound(&Bitvector::from_i64(-20));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-20), -10, 10, Some(100))
);
}
#[test]
fn add_unsigned_bounds() {
let positive_interval = IntervalDomain::mock_with_bounds(Some(10), 20, 30, Some(40));
let wrapped_interval = IntervalDomain::mock_with_bounds(Some(-100), -10, 10, Some(100));
let negative_interval = IntervalDomain::mock_with_bounds(Some(-40), -30, -20, Some(-10));
// unsigned_less_equal
let x = positive_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(35));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(10), 20, 30, Some(35))
);
let x = positive_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(15));
assert!(x.is_err());
let x = wrapped_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(35));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, 0, 10, Some(35))
);
let x = wrapped_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(-5));
assert_eq!(x.unwrap(), wrapped_interval); // Cannot remove a subinterval from the domain
let x = negative_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(-25));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-40), -30, -25, None)
);
let x = negative_interval
.clone()
.add_unsigned_less_equal_bound(&Bitvector::from_i64(-35));
assert!(x.is_err());
// unsigned_greater_equal
let x = positive_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(25));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, 25, 30, Some(40))
);
let x = positive_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(35));
assert!(x.is_err());
let x = wrapped_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(5));
assert_eq!(x.unwrap(), wrapped_interval);
let x = wrapped_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(35));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-100), -10, -1, None)
);
let x = wrapped_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(-50));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-50), -10, -1, None)
);
let x = negative_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(25));
assert_eq!(x.unwrap(), negative_interval);
let x = negative_interval
.clone()
.add_unsigned_greater_equal_bound(&Bitvector::from_i64(-25));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, -25, -20, Some(-10))
);
}
#[test]
fn add_not_equal_bounds() {
let interval = IntervalDomain::mock_with_bounds(None, -10, 10, None);
let x = interval
.clone()
.add_not_equal_bound(&Bitvector::from_i64(-20));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(Some(-19), -10, 10, None)
);
let x = interval
.clone()
.add_not_equal_bound(&Bitvector::from_i64(-0));
assert_eq!(x.unwrap(), interval);
let x = interval
.clone()
.add_not_equal_bound(&Bitvector::from_i64(20));
assert_eq!(
x.unwrap(),
IntervalDomain::mock_with_bounds(None, -10, 10, Some(19))
);
let interval = IntervalDomain::mock(5, 5);
let x = interval
.clone()
.add_not_equal_bound(&Bitvector::from_i64(5));
assert!(x.is_err());
let interval = IntervalDomain::mock(5, 6);
let x = interval.add_not_equal_bound(&Bitvector::from_i64(5));
assert_eq!(x.unwrap(), IntervalDomain::mock(6, 6));
}
#[test]
fn intersection() {
let interval1 = IntervalDomain::mock_with_bounds(Some(-100), -10, 10, Some(100));
let interval2 = IntervalDomain::mock_with_bounds(Some(-20), 2, 30, None);
let intersection = interval1.intersect(&interval2).unwrap();
assert_eq!(
intersection,
IntervalDomain::mock_with_bounds(Some(-20), 2, 10, Some(100))
);
assert!(interval1.intersect(&IntervalDomain::mock(50, 55)).is_err());
}
......@@ -123,3 +123,34 @@ pub trait TryToInterval {
Ok((interval.start.try_to_i64()?, interval.end.try_to_i64()?))
}
}
/// A trait for domains whose value sets can be restricted by knowing the result of a comparison with a known bitvector.
/// The comparison may also be used to add widening hints to the domain.
///
/// Note that the value set represented by the domain after the restriction may only be an upper bound,
/// i.e. it is possible that the result still contains some values that do not satisfy the restricting comparison.
pub trait SpecializeByConditional: Sized {
/// Return the restriction of `self` to values satisfying `self <= bound`
/// with `self` and `bound` interpreted as signed integers.
/// Returns an error if no value represented by `self` can satisfy the comparison.
fn add_signed_less_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
/// Return the restriction of `self` to values satisfying `self <= bound`
/// with `self` and `bound` interpreted as unsigned integers.
/// Returns an error if no value represented by `self` can satisfy the comparison.
fn add_unsigned_less_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
/// Return the restriction of `self` to values satisfying `self >= bound`
/// with `self` and `bound` interpreted as signed integers.
/// Returns an error if no value represented by `self` can satisfy the comparison.
fn add_signed_greater_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
/// Return the restriction of `self` to values satisfying `self >= bound`
/// with `self` and `bound` interpreted as unsigned integers.
/// Returns an error if no value represented by `self` can satisfy the comparison.
fn add_unsigned_greater_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
/// Return the restriction of `self` to values satisfying `self != bound`.
/// Returns an error if `bound` is the only value represented by `self`.
fn add_not_equal_bound(self, bound: &Bitvector) -> Result<Self, Error>;
}
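
For illustration, how a fixpoint context might use such a trait after a conditional branch; this is a hypothetical, simplified stand-in for the crate's trait, with `i64` bounds instead of `Bitvector`:

// Standalone sketch: restricting a range after a taken/untaken branch.
trait SpecializeByConditional: Sized {
    fn add_signed_less_equal_bound(self, bound: i64) -> Result<Self, String>;
    fn add_signed_greater_equal_bound(self, bound: i64) -> Result<Self, String>;
}

#[derive(Debug, PartialEq)]
struct Range { start: i64, end: i64 }

impl SpecializeByConditional for Range {
    fn add_signed_less_equal_bound(self, bound: i64) -> Result<Self, String> {
        if self.start > bound { return Err("unsatisfiable".into()); }
        Ok(Range { start: self.start, end: self.end.min(bound) })
    }
    fn add_signed_greater_equal_bound(self, bound: i64) -> Result<Self, String> {
        if self.end < bound { return Err("unsatisfiable".into()); }
        Ok(Range { start: self.start.max(bound), end: self.end })
    }
}

fn main() {
    // Branch `x <= 5` taken: restrict upwards; not taken: `x >= 6` holds.
    let x = Range { start: -10, end: 20 };
    assert_eq!(x.add_signed_less_equal_bound(5).unwrap(), Range { start: -10, end: 5 });
    let x = Range { start: -10, end: 20 };
    assert_eq!(x.add_signed_greater_equal_bound(6).unwrap(), Range { start: 6, end: 20 });
    // An Err means the branch is unreachable for this state.
    assert!((Range { start: 8, end: 9 }).add_signed_less_equal_bound(5).is_err());
}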
......@@ -165,13 +165,13 @@ impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> {
// The CallStub Edge value is added to the CallSourceCombinator.
// The user has the ability to split the node value at the BlkStart return-to node
// to only send specific data along the CallStub Edge to the callsite.
Edge::CRCallStub => Some(NodeValue::CallFlowCombinator {
Edge::CrCallStub => Some(NodeValue::CallFlowCombinator {
call_stub: self.context.split_call_stub(node_value.unwrap_value()),
interprocedural_flow: None,
}),
// The user has the ability to split the node value at the BlkStart return-to node
// to only send specific data along the ReturnStub Edge to the last BlkEnd node of the called subroutine.
Edge::CRReturnStub => {
Edge::CrReturnStub => {
// The subroutine term from which the program returns
let returned_from_sub = match graph.node_weight(end_node) {
Some(Node::BlkEnd { 0: _, 1: sub_term }) => sub_term,
......
......@@ -79,6 +79,7 @@ pub trait Context<'a> {
&self,
value: &Self::Value,
condition: &Expression,
block_before_condition: &Term<Blk>,
is_true: bool,
) -> Option<Self::Value>;
}
......@@ -164,11 +165,11 @@ impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> {
.context
.update_call(node_value.unwrap_value(), call, &graph[end_node])
.map(NodeValue::Value),
Edge::CRCallStub => Some(NodeValue::CallFlowCombinator {
Edge::CrCallStub => Some(NodeValue::CallFlowCombinator {
call_stub: Some(node_value.unwrap_value().clone()),
interprocedural_flow: None,
}),
Edge::CRReturnStub => Some(NodeValue::CallFlowCombinator {
Edge::CrReturnStub => Some(NodeValue::CallFlowCombinator {
call_stub: None,
interprocedural_flow: Some(node_value.unwrap_value().clone()),
}),
......@@ -201,15 +202,51 @@ impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> {
.context
.update_call_stub(node_value.unwrap_value(), call)
.map(NodeValue::Value),
Edge::Jump(jump, untaken_conditional) => self
.context
.update_jump(
node_value.unwrap_value(),
jump,
*untaken_conditional,
graph[end_node].get_block(),
)
.map(NodeValue::Value),
Edge::Jump(jump, untaken_conditional) => {
let value_after_condition = if let Jmp::CBranch {
target: _,
condition,
} = &jump.term
{
let block = graph[start_node].get_block();
self.context.specialize_conditional(
node_value.unwrap_value(),
condition,
block,
true,
)
} else if let Some(untaken_conditional_jump) = untaken_conditional {
if let Jmp::CBranch {
target: _,
condition,
} = &untaken_conditional_jump.term
{
let block = graph[start_node].get_block();
self.context.specialize_conditional(
node_value.unwrap_value(),
condition,
block,
false,
)
} else {
panic!("Malformed control flow graph");
}
} else {
Some(node_value.unwrap_value().clone())
};
if let Some(value) = value_after_condition {
self.context
.update_jump(
&value,
jump,
*untaken_conditional,
graph[end_node].get_block(),
)
.map(NodeValue::Value)
} else {
None
}
}
}
}
}
......
......@@ -150,9 +150,9 @@ pub enum Edge<'a> {
/// The edge goes directly from the callsite to the return-to-site inside the caller.
ExternCallStub(&'a Term<Jmp>),
/// An artificial edge. See the module-level documentation for more information.
CRCallStub,
CrCallStub,
/// An artificial edge. See the module-level documentation for more information.
CRReturnStub,
CrReturnStub,
/// An artificial edge to combine intra- and interprocedural data flows at the callsite of calls.
/// See the module-level documentation for more information.
CallCombine(&'a Term<Jmp>),
......@@ -425,9 +425,9 @@ impl<'a> GraphBuilder<'a> {
return_: (return_from_block, return_from_sub),
});
self.graph
.add_edge(*call_node, return_combine_node, Edge::CRCallStub);
.add_edge(*call_node, return_combine_node, Edge::CrCallStub);
self.graph
.add_edge(return_source, return_combine_node, Edge::CRReturnStub);
.add_edge(return_source, return_combine_node, Edge::CrReturnStub);
self.graph.add_edge(
return_combine_node,
*return_to_node,
......
......@@ -90,7 +90,7 @@ impl<'a> Context<'a> {
_ => Bitvector::zero(apint::BitWidth::from(self.project.get_pointer_bytesize())),
};
match state_before_return.get_register(&self.project.stack_pointer_register) {
Ok(Data::Pointer(pointer)) => {
Data::Pointer(pointer) => {
if pointer.targets().len() == 1 {
let (id, offset) = pointer.targets().iter().next().unwrap();
if *id != state_before_return.stack_id
......@@ -106,11 +106,10 @@ impl<'a> Context<'a> {
))
}
}
Ok(Data::Top(_)) => Err(anyhow!(
Data::Top(_) => Err(anyhow!(
"Stack register value lost during function execution"
)),
Ok(Data::Value(_)) => Err(anyhow!("Unexpected stack register value on return")),
Err(err) => Err(err),
Data::Value(_) => Err(anyhow!("Unexpected stack register value on return")),
}
}
......@@ -258,11 +257,7 @@ impl<'a> Context<'a> {
/// Check whether the jump is an indirect call whose target evaluates to a *Top* value in the given state.
fn is_indirect_call_with_top_target(&self, state: &State, call: &Term<Jmp>) -> bool {
match &call.term {
Jmp::CallInd { target, .. }
if state.eval(target).map_or(false, |value| value.is_top()) =>
{
true
}
Jmp::CallInd { target, .. } => state.eval(target).is_top(),
_ => false,
}
}
......@@ -280,7 +275,7 @@ impl<'a> Context<'a> {
new_state: &mut State,
) {
let stack_register = &self.project.stack_pointer_register;
let stack_pointer = state_before_call.get_register(stack_register).unwrap();
let stack_pointer = state_before_call.get_register(stack_register);
match self.project.cpu_architecture.as_str() {
"x86" | "x86_64" => {
let offset = Bitvector::from_u64(stack_register.size.into())
......@@ -382,7 +377,7 @@ impl<'a> Context<'a> {
/// Get the offset of the current stack pointer to the base of the current stack frame.
fn get_current_stack_offset(&self, state: &State) -> ValueDomain {
if let Ok(Data::Pointer(ref stack_pointer)) =
if let Data::Pointer(ref stack_pointer) =
state.get_register(&self.project.stack_pointer_register)
{
if stack_pointer.targets().len() == 1 {
......
......@@ -130,7 +130,7 @@ fn context_problem_implementation() {
// test update_def
state = context.update_def(&state, &def).unwrap();
let stack_pointer = Data::Pointer(PointerDomain::new(new_id("main", "RSP"), bv(-16)));
assert_eq!(state.eval(&Var(register("RSP"))).unwrap(), stack_pointer);
assert_eq!(state.eval(&Var(register("RSP"))), stack_pointer);
state = context.update_def(&state, &store_term).unwrap();
// Test update_call
......@@ -194,10 +194,9 @@ fn context_problem_implementation() {
assert_eq!(return_state.caller_stack_ids, BTreeSet::new());
assert_eq!(return_state.memory, state.memory);
assert_eq!(
return_state.get_register(&register("RSP")).unwrap(),
return_state.get_register(&register("RSP")),
state
.get_register(&register("RSP"))
.unwrap()
.bin_op(BinOpType::IntAdd, &Bitvector::from_i64(8).into())
);
......@@ -207,7 +206,7 @@ fn context_problem_implementation() {
let malloc = call_term("extern_malloc");
let mut state_after_malloc = context.update_call_stub(&state, &malloc).unwrap();
assert_eq!(
state_after_malloc.get_register(&register("RDX")).unwrap(),
state_after_malloc.get_register(&register("RDX")),
Data::Pointer(PointerDomain::new(
new_id("call_extern_malloc", "RDX"),
bv(0)
......@@ -215,21 +214,17 @@ fn context_problem_implementation() {
);
assert_eq!(state_after_malloc.memory.get_num_objects(), 2);
assert_eq!(
state_after_malloc.get_register(&register("RSP")).unwrap(),
state_after_malloc.get_register(&register("RSP")),
state
.get_register(&register("RSP"))
.unwrap()
.bin_op(BinOpType::IntAdd, &Data::Value(bv(8)))
);
assert_eq!(
state_after_malloc
.get_register(&register("callee_saved_reg"))
.unwrap(),
state_after_malloc.get_register(&register("callee_saved_reg")),
Data::Value(bv(13))
);
assert!(state_after_malloc
.get_register(&register("other_reg"))
.unwrap()
.is_top());
state_after_malloc.set_register(
......@@ -243,15 +238,10 @@ fn context_problem_implementation() {
let state_after_free = context
.update_call_stub(&state_after_malloc, &free)
.unwrap();
assert!(state_after_free
.get_register(&register("RDX"))
.unwrap()
.is_top());
assert!(state_after_free.get_register(&register("RDX")).is_top());
assert_eq!(state_after_free.memory.get_num_objects(), 2);
assert_eq!(
state_after_free
.get_register(&register("callee_saved_reg"))
.unwrap(),
state_after_free.get_register(&register("callee_saved_reg")),
Data::Pointer(PointerDomain::new(
new_id("call_extern_malloc", "RDX"),
bv(0)
......@@ -262,21 +252,17 @@ fn context_problem_implementation() {
let state_after_other_fn = context.update_call_stub(&state, &other_extern_fn).unwrap();
assert_eq!(
state_after_other_fn.get_register(&register("RSP")).unwrap(),
state_after_other_fn.get_register(&register("RSP")),
state
.get_register(&register("RSP"))
.unwrap()
.bin_op(BinOpType::IntAdd, &Data::Value(bv(8)))
);
assert_eq!(
state_after_other_fn
.get_register(&register("callee_saved_reg"))
.unwrap(),
state_after_other_fn.get_register(&register("callee_saved_reg")),
Data::Value(bv(13))
);
assert!(state_after_other_fn
.get_register(&register("other_reg"))
.unwrap()
.is_top());
}
......@@ -380,10 +366,119 @@ fn update_return() {
.get_all_object_ids()
.get(&new_id("caller_caller", "RSP"))
.is_some());
assert!(state.get_register(&register("RSP")).is_ok());
let expected_rsp = Data::Pointer(PointerDomain::new(
new_id("original_caller_id", "RSP"),
bv(-8),
));
assert_eq!(state.get_register(&register("RSP")).unwrap(), expected_rsp);
assert_eq!(state.get_register(&register("RSP")), expected_rsp);
}
#[test]
fn specialize_conditional() {
use crate::analysis::forward_interprocedural_fixpoint::Context as IpFpContext;
let (project, config) = mock_project();
let graph = crate::analysis::graph::get_program_cfg(&project.program, HashSet::new());
let runtime_memory_image = RuntimeMemoryImage::mock();
let (log_sender, _log_receiver) = crossbeam_channel::unbounded();
let context = Context::new(&project, &runtime_memory_image, &graph, config, log_sender);
let mut state = State::new(&register("RSP"), Tid::new("func"));
state.set_register(&register("RAX"), IntervalDomain::mock(-10, 20).into());
let condition = Expression::Var(Variable::mock("FLAG", 1));
// A complicated way of computing the result of `RAX <= 0`
// and assigning the result to the `FLAG` register.
let defs = vec![
Def::assign("def1", register("RAX"), Expression::Var(register("RAX"))),
Def::assign(
"def_that_should_be_ignored",
Variable::mock("FLAG", 1),
Expression::Const(Bitvector::from_u8(42)),
),
Def::assign(
"def2",
Variable::mock("FLAG_SLESS", 1),
Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntSLess,
rhs: Box::new(Expression::Const(Bitvector::from_u64(0))),
},
),
Def::assign(
"def3",
Variable::mock("FLAG_EQUAL", 1),
Expression::BinOp {
lhs: Box::new(Expression::Var(register("RAX"))),
op: BinOpType::IntEqual,
rhs: Box::new(Expression::Const(Bitvector::from_u64(0))),
},
),
Def::assign(
"def4",
Variable::mock("FLAG_NOTEQUAL", 1),
Expression::BinOp {
lhs: Box::new(Expression::Var(Variable::mock("FLAG_SLESS", 1))),
op: BinOpType::IntNotEqual,
rhs: Box::new(Expression::Const(Bitvector::from_u8(0))),
},
),
Def::assign(
"def5",
Variable::mock("FLAG", 1),
Expression::BinOp {
lhs: Box::new(Expression::Var(Variable::mock("FLAG_EQUAL", 1))),
op: BinOpType::BoolOr,
rhs: Box::new(Expression::Var(Variable::mock("FLAG_NOTEQUAL", 1))),
},
),
];
let block = Term {
tid: Tid::new("block"),
term: Blk {
defs,
jmps: Vec::new(),
indirect_jmp_targets: Vec::new(),
},
};
let result = context
.specialize_conditional(&state, &condition, &block, false)
.unwrap();
assert_eq!(
result.get_register(&Variable::mock("FLAG", 1)),
Bitvector::from_u8(0).into()
);
assert_eq!(
result.get_register(&Variable::mock("FLAG_NOTEQUAL", 1)),
Bitvector::from_u8(0).into()
);
assert_eq!(
result.get_register(&Variable::mock("FLAG_EQUAL", 1)),
Bitvector::from_u8(0).into()
);
assert_eq!(
result.get_register(&Variable::mock("FLAG_SLESS", 1)),
Bitvector::from_u8(0).into()
);
// The result is imprecise, since RAX == 0 should have been excluded.
// This imprecision is due to the way the result is computed.
assert_eq!(
result.get_register(&register("RAX")),
IntervalDomain::mock(0, 20).into()
);
state.set_register(&register("RAX"), IntervalDomain::mock(0, 20).into());
let result = context
.specialize_conditional(&state, &condition, &block, false)
.unwrap();
assert_eq!(
result.get_register(&register("RAX")),
IntervalDomain::mock_with_bounds(Some(0), 1, 20, None).into()
);
state.set_register(&register("RAX"), IntervalDomain::mock(-20, 0).into());
let result = context.specialize_conditional(&state, &condition, &block, false);
assert!(result.is_none());
}
use super::*;
use std::collections::HashSet;
impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Context<'a> {
type Value = State;
......@@ -43,7 +44,7 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
}
Def::Assign { var, value } => {
let mut new_state = state.clone();
self.log_debug(new_state.handle_register_assign(var, value), Some(&def.tid));
new_state.handle_register_assign(var, value);
Some(new_state)
}
Def::Load { var, address } => {
......@@ -291,13 +292,55 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
}
/// Update the state with the knowledge that some conditional evaluated to true or false.
/// Returns `None` if the state is unsatisfiable under the condition.
fn specialize_conditional(
&self,
value: &State,
_condition: &Expression,
_is_true: bool,
condition: &Expression,
block_before_condition: &Term<Blk>,
is_true: bool,
) -> Option<State> {
Some(value.clone())
let mut specialized_state = value.clone();
if specialized_state
.specialize_by_expression_result(condition, Bitvector::from_u8(is_true as u8).into())
.is_err()
{
// State is unsatisfiable
return None;
}
let mut modified_vars: HashSet<Variable> = HashSet::new();
for def in block_before_condition.term.defs.iter().rev() {
match &def.term {
Def::Store { .. } => (),
Def::Load { var, .. } => {
modified_vars.insert(var.clone());
}
Def::Assign {
var,
value: input_expr,
} => {
if !modified_vars.contains(var) {
// Register is not modified again between the `Def` and the end of the block.
modified_vars.insert(var.clone());
if input_expr
.input_vars()
.into_iter()
.find(|input_var| modified_vars.contains(input_var))
.is_none()
{
// Values of input registers did not change between the `Def` and the end of the block.
let expr_result = specialized_state.get_register(var);
if specialized_state
.specialize_by_expression_result(input_expr, expr_result.clone())
.is_err()
{
// State is unsatisfiable
return None;
}
}
}
}
}
}
Some(specialized_state)
}
}
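
The reverse walk above carries the key reasoning: a def's result may be propagated backwards into its input expression only while neither its output register nor any of its input registers is overwritten later in the block. A standalone sketch with string register names (hypothetical helper, not the crate's API):

use std::collections::HashSet;

// Sketch: walk defs backwards; once a register is known to be overwritten
// later in the block, neither it nor expressions reading it may be
// specialized further. Each def is (output_register, input_registers).
fn back_propagate(defs: &[(&str, Vec<&str>)]) -> Vec<usize> {
    let mut modified: HashSet<&str> = HashSet::new();
    let mut specializable = Vec::new();
    for (i, (out, inputs)) in defs.iter().enumerate().rev() {
        if !modified.contains(out) {
            modified.insert(*out);
            if inputs.iter().all(|v| !modified.contains(v)) {
                specializable.push(i); // safe to propagate the result backwards
            }
        }
    }
    specializable
}

fn main() {
    // FLAG := FLAG_EQUAL | FLAG_NOTEQUAL; the earlier FLAG := 42 is shadowed.
    let defs = vec![
        ("FLAG", vec![]), // ignored: FLAG is overwritten later in the block
        ("FLAG_EQUAL", vec!["RAX"]),
        ("FLAG", vec!["FLAG_EQUAL", "FLAG_NOTEQUAL"]),
    ];
    assert_eq!(back_propagate(&defs), vec![2, 1]);
}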
......@@ -338,33 +338,21 @@ impl<'a> PointerInference<'a> {
for jmp in block.term.jmps.iter() {
match &jmp.term {
Jmp::BranchInd(target_expr) => {
if let Ok(address) = state.eval(&target_expr) {
println!(
"{}: Indirect jump to {}",
jmp.tid,
address.to_json_compact()
);
} else {
println!(
"{}: Indirect jump. Could not compute address",
jmp.tid
);
}
let address = state.eval(&target_expr);
println!(
"{}: Indirect jump to {}",
jmp.tid,
address.to_json_compact()
);
}
Jmp::CallInd { target, return_ } => {
if let Ok(address) = state.eval(&target) {
println!(
"{}: Indirect call to {}. HasReturn: {}",
jmp.tid,
address.to_json_compact(),
return_.is_some()
);
} else {
println!(
"{}: Indirect call. Could not compute address",
jmp.tid
);
}
let address = state.eval(&target);
println!(
"{}: Indirect call to {}. HasReturn: {}",
jmp.tid,
address.to_json_compact(),
return_.is_some()
);
}
Jmp::Return(_) => {
if !state.caller_stack_ids.is_empty() {
......
......@@ -258,19 +258,20 @@ impl AbstractObjectInfo {
/// Get a more compact json-representation of the abstract object.
/// Intended for pretty printing, not usable for serialization/deserialization.
pub fn to_json_compact(&self) -> serde_json::Value {
let mut elements = Vec::new();
elements.push((
"is_unique".to_string(),
serde_json::Value::String(format!("{}", self.is_unique)),
));
elements.push((
"state".to_string(),
serde_json::Value::String(format!("{:?}", self.state)),
));
elements.push((
"type".to_string(),
serde_json::Value::String(format!("{:?}", self.type_)),
));
let mut elements = vec![
(
"is_unique".to_string(),
serde_json::Value::String(format!("{}", self.is_unique)),
),
(
"state".to_string(),
serde_json::Value::String(format!("{:?}", self.state)),
),
(
"type".to_string(),
serde_json::Value::String(format!("{:?}", self.type_)),
),
];
let memory = self
.memory
.iter()
......
......@@ -4,13 +4,11 @@ use super::*;
impl State {
/// Get the value of a register or Top() if no value is known.
///
/// Returns an error if the variable is not a register.
pub fn get_register(&self, variable: &Variable) -> Result<Data, Error> {
pub fn get_register(&self, variable: &Variable) -> Data {
if let Some(data) = self.register.get(variable) {
Ok(data.clone())
data.clone()
} else {
Ok(Data::new_top(variable.size))
Data::new_top(variable.size)
}
}
......@@ -28,8 +26,6 @@ impl State {
}
/// Set the value of a register.
///
/// Returns an error if the variable is not a register.
pub fn set_register(&mut self, variable: &Variable, value: Data) {
if !value.is_top() {
self.register.insert(variable.clone(), value);
......@@ -39,21 +35,8 @@ impl State {
}
/// Evaluate expression on the given state and write the result to the target register.
pub fn handle_register_assign(
&mut self,
target: &Variable,
expression: &Expression,
) -> Result<(), Error> {
match self.eval(expression) {
Ok(new_value) => {
self.set_register(target, new_value);
Ok(())
}
Err(err) => {
self.set_register(target, Data::new_top(target.size));
Err(err)
}
}
pub fn handle_register_assign(&mut self, target: &Variable, expression: &Expression) {
self.set_register(target, self.eval(expression))
}
/// Store `value` at the given `address`.
......@@ -116,32 +99,19 @@ impl State {
value: &Data,
global_memory: &RuntimeMemoryImage,
) -> Result<(), Error> {
match self.eval(address) {
Ok(address_data) => self.store_value(&address_data, value, global_memory),
Err(err) => Err(err),
}
let address_data = self.eval(address);
self.store_value(&address_data, value, global_memory)
}
/// Evaluate the store instruction, given by its address and value expressions,
/// and modify the state accordingly.
///
/// If an error occurs, the state is still modified before the error is returned.
/// E.g. if the value expression cannot be evaluated,
/// the value at the target address is overwritten with a `Top` value.
pub fn handle_store(
&mut self,
address: &Expression,
value: &Expression,
global_memory: &RuntimeMemoryImage,
) -> Result<(), Error> {
match self.eval(value) {
Ok(data) => self.write_to_address(address, &data, global_memory),
Err(err) => {
// we still need to write to the target location before reporting the error
self.write_to_address(address, &Data::new_top(value.bytesize()), global_memory)?;
Err(err)
}
}
self.write_to_address(address, &self.eval(value), global_memory)
}
/// Evaluate the given load instruction and return the data read on success.
......@@ -151,7 +121,7 @@ impl State {
size: ByteSize,
global_memory: &RuntimeMemoryImage,
) -> Result<Data, Error> {
let address = self.adjust_pointer_for_read(&self.eval(address)?);
let address = self.adjust_pointer_for_read(&self.eval(address));
match address {
Data::Value(global_address) => {
if let Ok(address_bitvector) = global_address.try_to_bitvec() {
......@@ -233,30 +203,30 @@ impl State {
}
/// Evaluate the value of an expression in the current state
pub fn eval(&self, expression: &Expression) -> Result<Data, Error> {
pub fn eval(&self, expression: &Expression) -> Data {
use Expression::*;
match expression {
Var(variable) => self.get_register(&variable),
Const(bitvector) => Ok(bitvector.clone().into()),
Const(bitvector) => bitvector.clone().into(),
BinOp { op, lhs, rhs } => {
if *op == BinOpType::IntXOr && lhs == rhs {
// the result of `x XOR x` is always zero.
return Ok(Bitvector::zero(apint::BitWidth::from(lhs.bytesize())).into());
return Bitvector::zero(apint::BitWidth::from(lhs.bytesize())).into();
}
let (left, right) = (self.eval(lhs)?, self.eval(rhs)?);
Ok(left.bin_op(*op, &right))
let (left, right) = (self.eval(lhs), self.eval(rhs));
left.bin_op(*op, &right)
}
UnOp { op, arg } => Ok(self.eval(arg)?.un_op(*op)),
Cast { op, size, arg } => Ok(self.eval(arg)?.cast(*op, *size)),
UnOp { op, arg } => self.eval(arg).un_op(*op),
Cast { op, size, arg } => self.eval(arg).cast(*op, *size),
Unknown {
description: _,
size,
} => Ok(Data::new_top(*size)),
} => Data::new_top(*size),
Subpiece {
low_byte,
size,
arg,
} => Ok(self.eval(arg)?.subpiece(*low_byte, *size)),
} => self.eval(arg).subpiece(*low_byte, *size),
}
}
......@@ -268,7 +238,7 @@ impl State {
global_memory: &RuntimeMemoryImage,
) -> Result<Data, Error> {
match parameter {
Arg::Register(var) => self.eval(&Expression::Var(var.clone())),
Arg::Register(var) => Ok(self.eval(&Expression::Var(var.clone()))),
Arg::Stack { offset, size } => self.load_value(
&Expression::Var(stack_pointer.clone()).plus_const(*offset),
*size,
......@@ -281,11 +251,7 @@ impl State {
pub fn contains_access_of_dangling_memory(&self, def: &Def) -> bool {
match def {
Def::Load { address, .. } | Def::Store { address, .. } => {
if let Ok(pointer) = self.eval(address) {
self.memory.is_dangling_pointer(&pointer, true)
} else {
false
}
self.memory.is_dangling_pointer(&self.eval(address), true)
}
_ => false,
}
......
......@@ -197,23 +197,21 @@ impl<'a> Context<'a> {
for parameter in extern_symbol.parameters.iter() {
match parameter {
Arg::Register(var) => {
if let Ok(data) = pi_state.eval(&Expression::Var(var.clone())) {
if state.check_if_address_points_to_taint(data, pi_state) {
return true;
}
let data = pi_state.eval(&Expression::Var(var.clone()));
if state.check_if_address_points_to_taint(data, pi_state) {
return true;
}
}
Arg::Stack { offset, size } => {
if let Ok(stack_address) = pi_state.eval(
let stack_address = pi_state.eval(
&Expression::Var(self.project.stack_pointer_register.clone())
.plus_const(*offset),
) {
if state
.load_taint_from_memory(&stack_address, *size)
.is_tainted()
{
return true;
}
);
if state
.load_taint_from_memory(&stack_address, *size)
.is_tainted()
{
return true;
}
if let Ok(stack_param) = pi_state.eval_parameter_arg(
parameter,
......@@ -266,6 +264,7 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
&self,
state: &State,
_condition: &Expression,
_block_before_condition: &Term<Blk>,
_is_true: bool,
) -> Option<State> {
Some(state.clone())
......@@ -332,10 +331,9 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
} else if let Some(pi_state) =
self.get_current_pointer_inference_state(state, &def.tid)
{
if let Ok(address_data) = pi_state.eval(address) {
let taint = state.load_taint_from_memory(&address_data, var.size);
new_state.set_register_taint(var, taint);
}
let address_data = pi_state.eval(address);
let taint = state.load_taint_from_memory(&address_data, var.size);
new_state.set_register_taint(var, taint);
} else {
new_state.set_register_taint(var, Taint::Top(var.size));
}
......@@ -347,10 +345,9 @@ impl<'a> crate::analysis::forward_interprocedural_fixpoint::Context<'a> for Cont
} else if let Some(pi_state) =
self.get_current_pointer_inference_state(state, &def.tid)
{
if let Ok(address_data) = pi_state.eval(address) {
let taint = state.eval(value);
new_state.save_taint_to_memory(&address_data, taint);
}
let address_data = pi_state.eval(address);
let taint = state.eval(value);
new_state.save_taint_to_memory(&address_data, taint);
} else {
// We lost all knowledge about memory pointers.
// We delete all memory taint to reduce false positives.
......
......@@ -101,9 +101,8 @@ impl State {
if let Some(pi_state) = pi_state {
let address_exp =
Expression::Var(stack_pointer_register.clone()).plus_const(*offset);
if let Ok(address) = pi_state.eval(&address_exp) {
state.save_taint_to_memory(&address, Taint::Tainted(*size));
}
let address = pi_state.eval(&address_exp);
state.save_taint_to_memory(&address, Taint::Tainted(*size));
}
}
}
......@@ -369,12 +368,13 @@ impl State {
}
memory.push((format!("{}", tid), Value::Object(Map::from_iter(elements))));
}
let mut state_map = Vec::new();
state_map.push((
"register".to_string(),
Value::Object(Map::from_iter(register)),
));
state_map.push(("memory".to_string(), Value::Object(Map::from_iter(memory))));
let state_map = vec![
(
"register".to_string(),
Value::Object(Map::from_iter(register)),
),
("memory".to_string(), Value::Object(Map::from_iter(memory))),
];
Value::Object(Map::from_iter(state_map))
}
......@@ -477,7 +477,7 @@ mod tests {
assert_eq!(state.register_taint.get(&register("RSP")), None);
let address = Expression::Var(register("RSP"));
assert_eq!(
state.load_taint_from_memory(&pi_state.eval(&address).unwrap(), ByteSize::new(8)),
state.load_taint_from_memory(&pi_state.eval(&address), ByteSize::new(8)),
taint
);
}
......
......@@ -262,9 +262,8 @@ impl<'a> Context<'a> {
{
let address_exp =
Expression::Var(self.project.stack_pointer_register.clone()).plus_const(offset);
if let Ok(address) = pi_state.eval(&address_exp) {
new_state.save_taint_to_memory(&address, Taint::Tainted(size));
}
let address = pi_state.eval(&address_exp);
new_state.save_taint_to_memory(&address, Taint::Tainted(size));
}
new_state
}
......
......@@ -467,13 +467,10 @@ fn creating_pi_def_map() {
for (def_tid, pi_state) in pi_def_map.iter() {
if *def_tid == def1 {
assert_eq!(
pi_state.get_register(&rdi_reg).unwrap(),
Data::new_top(rdi_reg.size)
);
assert_eq!(pi_state.get_register(&rdi_reg), Data::new_top(rdi_reg.size));
} else if *def_tid == def2 {
assert_eq!(
pi_state.get_register(&rdi_reg).unwrap(),
pi_state.get_register(&rdi_reg),
Data::Pointer(PointerDomain::new(stack_id.clone(), bv(-8)))
);
}
......
......@@ -113,9 +113,8 @@ impl State {
if let Some(pi_state) = pi_state {
let address_exp =
Expression::Var(stack_pointer_register.clone()).plus_const(*offset);
if let Ok(address) = pi_state.eval(&address_exp) {
state.save_taint_to_memory(&address, Taint::Tainted(*size));
}
let address = pi_state.eval(&address_exp);
state.save_taint_to_memory(&address, Taint::Tainted(*size));
}
}
}
......@@ -208,9 +207,8 @@ impl State {
Expression::BinOp { .. } => {
if let Some(pid_map) = self.pi_def_map.as_ref() {
if let Some(pi_state) = pid_map.get(def_tid) {
if let Ok(address) = pi_state.get_register(result) {
self.save_taint_to_memory(&address, Taint::Tainted(result.size));
}
let address = pi_state.get_register(result);
self.save_taint_to_memory(&address, Taint::Tainted(result.size));
}
}
}
......@@ -233,11 +231,10 @@ impl State {
) {
if let Some(pid_map) = self.pi_def_map.as_ref() {
if let Some(pi_state) = pid_map.get(def_tid) {
if let Ok(address) = pi_state.eval(target) {
if self.check_if_address_points_to_taint(address.clone(), &pi_state) {
self.taint_def_input_register(value, stack_pointer_register, def_tid);
self.remove_mem_taint_at_target(&address);
}
let address = pi_state.eval(target);
if self.check_if_address_points_to_taint(address.clone(), &pi_state) {
self.taint_def_input_register(value, stack_pointer_register, def_tid);
self.remove_mem_taint_at_target(&address);
}
}
}
......@@ -277,12 +274,11 @@ impl State {
if var.name == stack_pointer_register.name {
if let Some(pid_map) = self.pi_def_map.as_ref() {
if let Some(pi_state) = pid_map.get(def_tid) {
if let Ok(address) = pi_state.get_register(stack_pointer_register) {
self.save_taint_to_memory(
&address,
Taint::Tainted(stack_pointer_register.size),
);
}
let address = pi_state.get_register(stack_pointer_register);
self.save_taint_to_memory(
&address,
Taint::Tainted(stack_pointer_register.size),
);
}
}
} else {
......
......@@ -25,11 +25,11 @@ pub trait BitvectorExtended: Sized {
/// Returns an error for non-implemented operations (currently all float-related operations).
fn bin_op(&self, op: BinOpType, rhs: &Self) -> Result<Self, Error>;
/// Returns `true` if adding `self` to `rhs` would result in a signed integer overflow or underflow.
fn signed_add_overflow_check(&self, rhs: &Self) -> bool;
/// Returns the result of `self + rhs` if the computation does not result in a signed integer overflow or underflow.
fn signed_add_overflow_checked(&self, rhs: &Self) -> Option<Self>;
/// Returns `true` if subtracting `rhs` from `self` would result in a signed integer overflow or underflow.
fn signed_sub_overflow_check(&self, rhs: &Self) -> bool;
/// Returns the result of `self - rhs` if the computation does not result in a signed integer overflow or underflow.
fn signed_sub_overflow_checked(&self, rhs: &Self) -> Option<Self>;
/// Return the result of multiplying `self` with `rhs`
/// and a flag that is set to `true` if the multiplication resulted in a signed integer overflow or underflow.
......@@ -221,23 +221,21 @@ impl BitvectorExtended for Bitvector {
}
}
/// Returns `true` if adding `self` to `rhs` would result in a signed integer overflow or underflow.
fn signed_add_overflow_check(&self, rhs: &Self) -> bool {
/// Returns the result of `self + rhs` if the computation does not result in a signed integer overflow or underflow.
fn signed_add_overflow_checked(&self, rhs: &Self) -> Option<Self> {
let result = self.clone().into_checked_add(rhs).unwrap();
if rhs.sign_bit().to_bool() {
self.checked_sle(&result).unwrap()
} else {
self.checked_sgt(&result).unwrap()
match (rhs.sign_bit().to_bool(), self.checked_sle(&result).unwrap()) {
(true, true) | (false, false) => None,
_ => Some(result),
}
}
/// Returns `true` if subtracting `rhs` from `self` would result in a signed integer overflow or underflow.
fn signed_sub_overflow_check(&self, rhs: &Self) -> bool {
/// Returns the result of `self - rhs` if the computation does not result in a signed integer overflow or underflow.
fn signed_sub_overflow_checked(&self, rhs: &Self) -> Option<Self> {
let result = self.clone().into_checked_sub(rhs).unwrap();
if rhs.sign_bit().to_bool() {
self.checked_sge(&result).unwrap()
} else {
self.checked_slt(&result).unwrap()
match (rhs.sign_bit().to_bool(), self.checked_sge(&result).unwrap()) {
(true, true) | (false, false) => None,
_ => Some(result),
}
}
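
The rewritten checks rely on a standard identity: after a wrapping add, signed overflow occurred iff the comparison `self <= result` disagrees with the sign of `rhs`. A standalone `i8` sketch (illustration only) mirroring the match above:

// Sketch: overflow detection after a wrapping add.
fn signed_add_overflow_checked(lhs: i8, rhs: i8) -> Option<i8> {
    let result = lhs.wrapping_add(rhs);
    match (rhs < 0, lhs <= result) {
        // A negative rhs must decrease the value;
        // a nonnegative rhs must not decrease it.
        (true, true) | (false, false) => None,
        _ => Some(result),
    }
}

fn main() {
    assert_eq!(signed_add_overflow_checked(100, 50), None); // wraps to -106
    assert_eq!(signed_add_overflow_checked(100, 27), Some(127));
    assert_eq!(signed_add_overflow_checked(-100, -50), None); // wraps to 106
}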
......@@ -264,3 +262,36 @@ impl BitvectorExtended for Bitvector {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn overflow_checked_add_and_sub() {
let max = Bitvector::signed_max_value(ByteSize::new(8).into());
let min = Bitvector::signed_min_value(ByteSize::new(8).into());
assert_eq!(min.signed_add_overflow_checked(&min), None);
assert_eq!(
min.signed_add_overflow_checked(&max),
Some(-Bitvector::one(ByteSize::new(8).into()))
);
assert_eq!(
max.signed_add_overflow_checked(&min),
Some(-Bitvector::one(ByteSize::new(8).into()))
);
assert_eq!(max.signed_add_overflow_checked(&max), None);
assert_eq!(
min.signed_sub_overflow_checked(&min),
Some(Bitvector::zero(ByteSize::new(8).into()))
);
assert_eq!(min.signed_sub_overflow_checked(&max), None);
assert_eq!(max.signed_sub_overflow_checked(&min), None);
assert_eq!(
max.signed_sub_overflow_checked(&max),
Some(Bitvector::zero(ByteSize::new(8).into()))
);
}
}
......@@ -100,6 +100,22 @@ impl Expression {
}
}
/// Return an array of all input variables of the given expression.
/// The array may contain duplicates.
pub fn input_vars(&self) -> Vec<&Variable> {
use Expression::*;
match self {
Var(var) => vec![var],
Const(_) | Unknown { .. } => Vec::new(),
BinOp { op: _, lhs, rhs } => {
let mut vars = lhs.input_vars();
vars.append(&mut rhs.input_vars());
vars
}
UnOp { arg, .. } | Cast { arg, .. } | Subpiece { arg, .. } => arg.input_vars(),
}
}
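
A standalone miniature of the recursion above (hypothetical `Expr` type, not the crate's `Expression`), showing why duplicates can appear in the result:

// Sketch: collecting input variables of an expression tree.
enum Expr {
    Var(String),
    Const(i64),
    BinOp(Box<Expr>, Box<Expr>),
}

impl Expr {
    fn input_vars(&self) -> Vec<&String> {
        match self {
            Expr::Var(var) => vec![var],
            Expr::Const(_) => Vec::new(),
            Expr::BinOp(lhs, rhs) => {
                let mut vars = lhs.input_vars();
                vars.append(&mut rhs.input_vars());
                vars // may contain duplicates, e.g. for `x + x`
            }
        }
    }
}

fn main() {
    // `RAX + (RAX + 1)` mentions RAX twice; duplicates are intentional.
    let expr = Expr::BinOp(
        Box::new(Expr::Var("RAX".into())),
        Box::new(Expr::BinOp(
            Box::new(Expr::Var("RAX".into())),
            Box::new(Expr::Const(1)),
        )),
    );
    let names: Vec<&str> = expr.input_vars().iter().map(|s| s.as_str()).collect();
    assert_eq!(names, vec!["RAX", "RAX"]);
}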
/// Substitute trivial BinOp-expressions with their results,
/// e.g. substitute `a or a` with `a`.
///
......
......@@ -4,8 +4,7 @@ use apint::ApInt;
#[cfg(test)]
use super::{CastOpType, Variable};
use super::{BinOpType, Expression};
use crate::prelude::*;
use super::*;
/// ## Helper functions for building expressions
impl Expression {
......@@ -51,6 +50,15 @@ impl Expression {
}
}
/// Shortcut for creating unary operation expressions.
#[cfg(test)]
pub fn un_op(self, op: UnOpType) -> Expression {
Expression::UnOp {
op,
arg: Box::new(self),
}
}
/// Shortcut for creating an `IntAdd`-expression
pub fn plus(self, rhs: Expression) -> Expression {
Expression::BinOp {
......@@ -60,6 +68,16 @@ impl Expression {
}
}
/// Shortcut for creating an `IntSub`-expression
#[cfg(test)]
pub fn minus(self, rhs: Expression) -> Expression {
Expression::BinOp {
lhs: Box::new(self),
op: BinOpType::IntSub,
rhs: Box::new(rhs),
}
}
/// Construct an expression that adds a constant value to the given expression.
///
/// The bytesize of the value is automatically adjusted to the bytesize of the given expression.
......@@ -73,4 +91,19 @@ impl Expression {
}
self.plus(Expression::Const(value))
}
/// Construct an expression that subtracts a constant value from the given expression.
///
/// The bytesize of the value is automatically adjusted to the bytesize of the given expression.
#[cfg(test)]
pub fn minus_const(self, value: i64) -> Expression {
let bytesize = self.bytesize();
let mut value = Bitvector::from_i64(value);
match u64::from(bytesize) {
size if size > 8 => value.sign_extend(bytesize).unwrap(),
size if size < 8 => value.truncate(bytesize).unwrap(),
_ => (),
}
self.minus(Expression::Const(value))
}
}
......@@ -172,6 +172,7 @@ impl From<Expression> for IrExpression {
/// Expression Opcodes as parsed from Ghidra
#[allow(missing_docs)]
#[allow(non_camel_case_types)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ExpressionType {
COPY,
......
......@@ -49,6 +49,7 @@ pub struct Jmp {
/// A jump type mnemonic.
#[allow(missing_docs)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum JmpType {
BRANCH,
......@@ -305,6 +306,7 @@ pub struct Arg {
/// The intent (input or output) of a function argument.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
#[allow(clippy::upper_case_acronyms)]
pub enum ArgIntent {
/// The argument is an input parameter.
INPUT,
......
......@@ -142,7 +142,8 @@ impl RuntimeMemoryImage {
let address = address.try_to_u64().unwrap();
for segment in self.memory_segments.iter() {
if address >= segment.base_address
&& address + u64::from(size) <= segment.base_address + segment.bytes.len() as u64
&& u64::from(size) <= segment.base_address + segment.bytes.len() as u64
&& address <= segment.base_address + segment.bytes.len() as u64 - u64::from(size)
{
if segment.write_flag {
// The segment is writeable, thus we do not know the content at runtime.
......
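Why the bounds check above was reordered (illustration with hypothetical numbers): in release builds `address + size` wraps in `u64` arithmetic, so the old comparison could spuriously accept an out-of-bounds access near the top of the address space:

// Sketch: the failure mode avoided by the reordered comparison.
fn main() {
    let base: u64 = 0x1000;
    let len: u64 = 0x100;
    let size: u64 = 0x10;
    let address: u64 = u64::MAX - 0x8; // address + size wraps to a tiny value
    let old_check = address >= base && address.wrapping_add(size) <= base + len;
    let new_check =
        address >= base && size <= base + len && address <= base + len - size;
    assert!(old_check); // accepted: an out-of-bounds read would follow
    assert!(!new_check); // correctly rejected
}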
......@@ -43,7 +43,7 @@ pub fn is_sink_call_reachable_from_source_call(
// and as long as the edge does not leave the function.
match edge.weight() {
Edge::Block
| Edge::CRCallStub
| Edge::CrCallStub
| Edge::CallCombine(_)
| Edge::ReturnCombine(_)
| Edge::Jump(_, _)
......@@ -53,7 +53,7 @@ pub fn is_sink_call_reachable_from_source_call(
worklist.push(edge.target())
}
}
Edge::Call(_) | Edge::CRReturnStub => (), // These edges would leave the function control flow graph.
Edge::Call(_) | Edge::CrReturnStub => (), // These edges would leave the function control flow graph.
}
}
}
......