open Bap.Std
open Core_kernel

(* Addresses and sizes are bitvectors; use bitvector arithmetic on them. *)
let (+), (-) = Bitvector.(+), Bitvector.(-)

(* Ordering comparisons in this module cast both operands with
   [Bitvector.signed] before comparing, so positions are ordered as signed
   values. *)
let (>) x y = Bitvector.(>) (Bitvector.signed x) (Bitvector.signed y)
let (<) x y = Bitvector.(<) (Bitvector.signed x) (Bitvector.signed y)
(* let (>=) x y = Bitvector.(>=) (Bitvector.signed x) (Bitvector.signed y) *)
let (<=) x y = Bitvector.(<=) (Bitvector.signed x) (Bitvector.signed y)
let (=) x y = Bitvector.(=) x y

(* Minimum and maximum under the signed ordering defined above.
   These replace [Word.min]/[Word.max], which follow Bitvector's default
   ordering instead of the explicitly signed comparisons that every other
   bounds check in this module uses. On a tie the first argument is
   returned. *)
let signed_min x y = if y < x then y else x
let signed_max x y = if y > x then y else x

(** One element of a memory region. *)
type 'a mem_node = {
  pos: Bitvector.t; (* address of the element *)
  size: Bitvector.t; (* size (in bytes) of the element *)
  data: ('a, unit) Result.t; (* [Ok value], or [Error ()] for an error node *)
} [@@deriving bin_io, compare, sexp]

(** A memory region is a list of nodes. The functions below insert nodes in
    ascending (signed) position order and resolve overlaps, and they rely on
    their input lists having that shape. *)
type 'a t = 'a mem_node list [@@deriving bin_io, compare, sexp]

(** Return an empty memory region. *)
let empty () : 'a t = []

(** Return an error mem_node at the given position with the given size. *)
let error_elem ~pos ~size =
  { pos; size; data = Error () }

(* Fail if [pos + size] is smaller than [pos] under the signed comparison,
   i.e. the end of the element wrapped around.
   @raise Failure with the message below in that case. *)
let check_bounds ~pos ~size =
  if pos + size < pos then
    failwith "[CWE-checker] element out of bounds for mem_region"

(** [add mem_region elem ~pos ~size] inserts [elem] as an [Ok] node covering
    [pos, pos + size). Where an existing node intersects the new one, the
    intersection is overwritten and the parts of the old node sticking out to
    the left or right are replaced by error nodes.
    @raise Failure if [pos + size < pos]. *)
let rec add mem_region elem ~pos ~size =
  let () = check_bounds ~pos ~size in
  let new_node = { pos; size; data = Ok elem } in
  match mem_region with
  | [] -> [new_node]
  | head :: tail ->
    if head.pos + head.size <= pos then
      (* head lies completely before the new node *)
      head :: add tail elem ~pos ~size
    else if pos + size <= head.pos then
      (* the new node lies completely before head *)
      new_node :: mem_region
    else begin
      (* head and new node intersect => at the intersection, head gets
         overwritten and the rest of head gets marked as error. *)
      let tail =
        if head.pos + head.size > pos + size then
          (* mark the right end of head as error *)
          error_elem
            ~pos:(pos + size)
            ~size:(head.pos + head.size - (pos + size))
          :: tail
        else
          tail in
      (* add the new element; this also resolves intersections with nodes
         further down the list *)
      let tail = add tail elem ~pos ~size in
      if head.pos < pos then
        (* mark the left end of head as error *)
        error_elem ~pos:head.pos ~size:(pos - head.pos) :: tail
      else
        tail
    end

(** [get mem_region pos] looks up the node at address [pos].
    - [None] if no node covers [pos];
    - [Some (Ok (value, size))] if an [Ok] node starts exactly at [pos];
    - [Some (Error ())] if the node starting at [pos] is an error node, or if
      [pos] lies inside a node without being its start address. *)
let rec get mem_region pos =
  match mem_region with
  | [] -> None
  | head :: tail ->
    if head.pos > pos then
      None
    else if head.pos = pos then begin
      match head.data with
      | Ok x -> Some (Ok (x, head.size))
      | Error _ -> Some (Error ())
    end
    else if head.pos + head.size <= pos then
      get tail pos
    else
      (* pos intersects some data, but does not equal its starting address *)
      Some (Error ())

(** [remove mem_region ~pos ~size] clears the range [pos, pos + size). Nodes
    that only partially intersect the range are dropped, and their parts
    outside the range are replaced by error nodes.
    @raise Failure if [pos + size < pos]. *)
let rec remove mem_region ~pos ~size =
  let () = check_bounds ~pos ~size in
  match mem_region with
  | [] -> []
  | hd :: tl ->
    if hd.pos + hd.size <= pos then
      hd :: remove tl ~pos ~size
    else if pos + size <= hd.pos then
      mem_region
    else begin
      (* hd intersects the removed range *)
      let rest = remove tl ~pos ~size in
      let rest =
        if hd.pos + hd.size > pos + size then
          (* the right end of hd survives as an error node *)
          error_elem
            ~pos:(pos + size)
            ~size:(hd.pos + hd.size - (pos + size))
          :: rest
        else
          rest in
      if hd.pos < pos then
        (* the left end of hd survives as an error node *)
        error_elem ~pos:hd.pos ~size:(pos - hd.pos) :: rest
      else
        rest
    end

(** [mark_error mem_region ~pos ~size] replaces the range [pos, pos + size)
    by a single error node. Every existing node that intersects the range is
    absorbed, and the error node grows to cover the absorbed node entirely.
    @raise Failure if [pos + size < pos]. *)
let rec mark_error mem_region ~pos ~size =
  let () = check_bounds ~pos ~size in
  match mem_region with
  | [] -> [error_elem ~pos ~size]
  | hd :: tl ->
    if hd.pos + hd.size <= pos then
      hd :: mark_error tl ~pos ~size
    else if pos + size <= hd.pos then
      error_elem ~pos ~size :: mem_region
    else begin
      (* hd intersects the range: widen the range to the union of both and
         continue with the rest of the list. The union bounds use the same
         signed ordering as the intersection tests above. *)
      let start_pos = signed_min pos hd.pos in
      let end_pos_plus_one = signed_max (pos + size) (hd.pos + hd.size) in
      mark_error tl ~pos:start_pos ~size:(end_pos_plus_one - start_pos)
    end

(* TODO: This is probably a very inefficient implementation in some cases.
   Write a faster implementation if necessary. *)
(** [merge mem_region1 mem_region2 ~data_merge] combines two regions.
    - If one region is exhausted, the remainder of the other is kept as is.
    - Nodes that do not intersect any node of the other region are kept.
    - Two [Ok] nodes with equal position and size are combined with
      [data_merge]: [Some (Ok v)] keeps a node with value [v],
      [Some (Error _)] yields an error node, and [None] drops the node.
    - Nodes with equal position and size where at least one is an error node
      yield an error node.
    - Nodes that intersect without being equal in position and size are both
      dropped, and the union of their ranges is marked as error. *)
let rec merge mem_region1 mem_region2 ~data_merge =
  match (mem_region1, mem_region2) with
  | (value, [])
  | ([], value) -> value
  | (hd1 :: tl1, hd2 :: tl2) ->
    if hd1.pos + hd1.size <= hd2.pos then
      hd1 :: merge tl1 mem_region2 ~data_merge
    else if hd2.pos + hd2.size <= hd1.pos then
      hd2 :: merge mem_region1 tl2 ~data_merge
    else if hd1.pos = hd2.pos && hd1.size = hd2.size then begin
      match (hd1.data, hd2.data) with
      | (Ok data1, Ok data2) -> begin
          match data_merge data1 data2 with
          | Some (Ok value) ->
            { hd1 with data = Ok value } :: merge tl1 tl2 ~data_merge
          | Some (Error _) ->
            { hd1 with data = Error () } :: merge tl1 tl2 ~data_merge
          | None -> merge tl1 tl2 ~data_merge
        end
      | _ -> { hd1 with data = Error () } :: merge tl1 tl2 ~data_merge
    end
    else begin
      (* The heads intersect but are not identical in position and size.
         The union bounds use the same signed ordering as the tests above. *)
      let start_pos = signed_min hd1.pos hd2.pos in
      let end_pos_plus_one =
        signed_max (hd1.pos + hd1.size) (hd2.pos + hd2.size) in
      let mem_region = merge tl1 tl2 ~data_merge in
      mark_error mem_region ~pos:start_pos ~size:(end_pos_plus_one - start_pos)
    end

(** [equal mem_region1 mem_region2 ~data_equal] is [true] iff both regions
    have the same number of nodes and corresponding nodes agree in position
    and size and are either both error nodes or both [Ok] with values that
    [data_equal] considers equal. *)
let rec equal (mem_region1:'a t) (mem_region2:'a t) ~data_equal : bool =
  match (mem_region1, mem_region2) with
  | ([], []) -> true
  | (hd1 :: tl1, hd2 :: tl2) ->
    if hd1.pos = hd2.pos && hd1.size = hd2.size then begin
      match (hd1.data, hd2.data) with
      | (Ok data1, Ok data2) when data_equal data1 data2 ->
        equal tl1 tl2 ~data_equal
      | (Error (), Error ()) -> equal tl1 tl2 ~data_equal
      | _ -> false
    end
    else
      false
  | _ -> false

(** [map_data mem_region ~f] applies [f] to the value of every [Ok] node.
    Positions, sizes and error nodes are left unchanged. *)
let map_data (mem_region: 'a t) ~(f: 'a -> 'b) : 'b t =
  List.map mem_region ~f:(fun mem_node ->
      { pos = mem_node.pos;
        size = mem_node.size;
        data = Result.map mem_node.data ~f })

(** [list_data mem_region] returns the values of all [Ok] nodes in list
    order; error nodes are skipped. *)
let list_data (mem_region: 'a t) : 'a List.t =
  List.filter_map mem_region ~f:(fun mem_node ->
      match mem_node.data with
      | Ok value -> Some value
      | Error _ -> None)

(** [list_data_pos mem_region] returns [(position, value)] for every [Ok]
    node in list order; error nodes are skipped. *)
let list_data_pos (mem_region: 'a t) : (Bitvector.t * 'a) List.t =
  List.filter_map mem_region ~f:(fun mem_node ->
      match mem_node.data with
      | Ok value -> Some (mem_node.pos, value)
      | Error _ -> None)