Commit 2db57c45 by Enkelmann Committed by Thomas Barabosch

Type inference (#14)

* Initial version of Type inference
parent c0716c2c
......@@ -204,16 +204,20 @@ dmypy.json
pyvenv.cfg
pip-selfcheck.json
# End of https://www.gitignore.io/api/c,ocaml,python
# dont upload our real life zoo
test/real_world_samples
test/run_real_world_samples.sh
# End of https://www.gitignore.io/api/c,ocaml,python
.project
.pydevproject
src/cwe_checker.plugin
# Plugin files (generated by bapbuild)
*.plugin
# install files for opam packages (generated by dune)
*.install
test/artificial_samples/dockcross*
......
......@@ -9,6 +9,8 @@
- Added BAP recipe for standard cwe_checker run (PR #9)
- Improved check for CWE-476 (NULL Pointer Dereference) using data flow analysis (PR #11)
- Switched C build system from make to scons (PR #16)
- Added type inference pass (PR #14)
- Added unit tests to test suite (PR #14)
0.1 (2018-10-08)
=====
......
.PHONY: all clean test uninstall
all:
cd src; bapbuild -r -Is checkers,utils -pkgs yojson,unix cwe_checker.plugin; bapbundle install cwe_checker.plugin; cd ..
dune build --profile release
dune install
cd plugins/cwe_checker; make all; cd ../..
cd plugins/cwe_checker_type_inference; make all; cd ../..
cd plugins/cwe_checker_type_inference_print; make all; cd ../..
test:
dune runtest --profile release # TODO: correct all dune linter warnings so that we can remove --profile release
pytest -v
clean:
dune clean
bapbuild -clean
cd test/unit; make clean; cd ../..
cd plugins/cwe_checker; make clean; cd ../..
cd plugins/cwe_checker_type_inference; make clean; cd ../..
cd plugins/cwe_checker_type_inference_print; make clean; cd ../..
uninstall:
bapbundle remove cwe_checker.plugin
dune uninstall
cd plugins/cwe_checker; make uninstall; cd ../..
cd plugins/cwe_checker_type_inference; make uninstall; cd ../..
cd plugins/cwe_checker_type_inference_print; make uninstall; cd ../..
(lang dune 1.6)
(name cwe_checker)
all:
bapbuild -pkgs yojson,unix,ppx_jane,cwe_checker_core cwe_checker.plugin
bapbundle install cwe_checker.plugin
clean:
bapbuild -clean
uninstall:
bapbundle remove cwe_checker.plugin
open Core_kernel.Std
open Core_kernel
open Bap.Std
open Graphlib.Std
open Format
open Yojson.Basic.Util
open Cwe_checker_core
include Self()
......
all:
bapbuild -pkgs yojson,unix,ppx_jane,cwe_checker_core cwe_checker_type_inference.plugin
bapbundle install cwe_checker_type_inference.plugin
clean:
bapbuild -clean
uninstall:
bapbundle remove cwe_checker_type_inference.plugin
open Bap.Std
open Core_kernel
open Cwe_checker_core
let () = Project.register_pass Type_inference.compute_pointer_register
all:
bapbuild -pkgs yojson,unix,ppx_jane,cwe_checker_core cwe_checker_type_inference_print.plugin
bapbundle install cwe_checker_type_inference_print.plugin
clean:
bapbuild -clean
uninstall:
bapbundle remove cwe_checker_type_inference_print.plugin
open Bap.Std
open Core_kernel
open Cwe_checker_core
let main project =
Log_utils.set_log_level Log_utils.DEBUG;
Log_utils.set_output stdout;
Log_utils.color_on ();
let program = Project.program project in
let tid_map = Address_translation.generate_tid_map program in
Type_inference.print_type_info_tags project tid_map
let () = Project.register_pass' main ~deps:["cwe-checker-type-inference"]
Prints the results of the type inference pass of each block.
(option pass cwe-type-inference-print)
open Bap.Std
open Core_kernel
let (+), (-) = Bitvector.(+), Bitvector.(-)
let (>) x y = Bitvector.(>) (Bitvector.signed x) (Bitvector.signed y)
let (<) x y = Bitvector.(<) (Bitvector.signed x) (Bitvector.signed y)
let (>=) x y = Bitvector.(>=) (Bitvector.signed x) (Bitvector.signed y)
let (<=) x y = Bitvector.(<=) (Bitvector.signed x) (Bitvector.signed y)
let (=) x y = Bitvector.(=) x y
type 'a mem_node = {
pos: Bitvector.t; (* address of the element *)
size: Bitvector.t; (* size (in bytes) of the element *)
data: ('a, unit) Result.t;
} [@@deriving bin_io, compare, sexp]
type 'a t = 'a mem_node list [@@deriving bin_io, compare, sexp]
let empty () : 'a t =
[]
(** Return an error mem_node at the given position with the given size. *)
let error_elem ~pos ~size =
{ pos = pos;
size = size;
data = Error ();}
let rec add mem_region elem ~pos ~size =
let () = if pos + size < pos then failwith "[CWE-checker] element out of bounds for mem_region" in
let new_node = {
pos=pos;
size=size;
data=Ok(elem);
} in
match mem_region with
| [] -> new_node :: []
| head :: tail ->
if head.pos + head.size <= pos then
head :: (add tail elem ~pos ~size)
else if pos + size <= head.pos then
new_node :: mem_region
else begin (* head and new node intersect => at the intersection, head gets overwritten and the rest of head gets marked as error. *)
let tail = if head.pos + head.size > pos + size then (* mark the right end of head as error *)
let err = error_elem ~pos:(pos + size) ~size:(head.pos + head.size - (pos + size)) in
err :: tail
else
tail in
let tail = add tail elem ~pos ~size in (* add the new element*)
let tail = if head.pos < pos then (* mark the left end of head as error *)
let err = error_elem ~pos:(head.pos) ~size:(pos - head.pos) in
err :: tail
else
tail in
tail
end
let rec get mem_region pos =
match mem_region with
| [] -> None
| head :: tail ->
if head.pos > pos then
None
else if head.pos = pos then
match head.data with
| Ok(x) -> Some(Ok(x, head.size))
| Error(_) -> Some(Error(()))
else if head.pos + head.size <= pos then
get tail pos
else
Some(Error(())) (* pos intersects some data, but does not equal its starting address*)
(* Helper function. Removes all elements with position <= pos. *)
let rec remove_until mem_region pos =
match mem_region with
| [] -> []
| hd :: tl ->
if hd.pos <= pos then
remove_until tl pos
else
mem_region
let rec remove mem_region ~pos ~size =
let () = if pos + size < pos then failwith "[CWE-checker] element out of bounds for mem_region" in
match mem_region with
| [] -> []
| hd :: tl ->
if hd.pos + hd.size <= pos then
hd :: remove tl pos size
else if pos + size <= hd.pos then
mem_region
else
let mem_region = remove tl pos size in
let mem_region =
if hd.pos + hd.size > pos + size then
error_elem ~pos:(pos + size) ~size:(hd.pos + hd.size - (pos + size)) :: mem_region
else
mem_region in
let mem_region =
if hd.pos < pos then
error_elem ~pos:hd.pos ~size:(pos - hd.pos) :: mem_region
else
mem_region in
mem_region
let rec mark_error mem_region ~pos ~size =
let () = if pos + size < pos then failwith "[CWE-checker] element out of bounds for mem_region" in
match mem_region with
| [] -> (error_elem pos size) :: []
| hd :: tl ->
if hd.pos + hd.size <= pos then
hd :: (mark_error tl pos size)
else if pos + size <= hd.pos then
(error_elem pos size) :: mem_region
else
let start_pos = min pos hd.pos in
let end_pos_plus_one = max (pos + size) (hd.pos + hd.size) in
mark_error tl ~pos:start_pos ~size:(end_pos_plus_one - start_pos)
(* TODO: This is probably a very inefficient implementation in some cases. Write a faster implementation if necessary. *)
let rec merge mem_region1 mem_region2 ~data_merge =
match (mem_region1, mem_region2) with
| (value, [])
| ([], value) -> value
| (hd1 :: tl1, hd2 :: tl2) ->
if hd1.pos + hd1.size <= hd2.pos then
hd1 :: merge tl1 mem_region2 data_merge
else if hd2.pos + hd2.size <= hd1.pos then
hd2 :: merge mem_region1 tl2 data_merge
else if hd1.pos = hd2.pos && hd1.size = hd2.size then
match (hd1.data, hd2.data) with
| (Ok(data1), Ok(data2)) -> begin
match data_merge data1 data2 with
| Some(Ok(value)) -> { hd1 with data = Ok(value) } :: merge tl1 tl2 ~data_merge
| Some(Error(_)) -> {hd1 with data = Error(())} :: merge tl1 tl2 ~data_merge
| None -> merge tl1 tl2 data_merge
end
| _ -> { hd1 with data = Error(()) } :: merge tl1 tl2 ~data_merge
else
let start_pos = min hd1.pos hd2.pos in
let end_pos_plus_one = max (hd1.pos + hd1.size) (hd2.pos + hd2.size) in
let mem_region = merge tl1 tl2 data_merge in
mark_error mem_region ~pos:start_pos ~size:(end_pos_plus_one - start_pos)
let rec equal (mem_region1:'a t) (mem_region2:'a t) ~data_equal : bool =
match (mem_region1, mem_region2) with
| ([], []) -> true
| (hd1 :: tl1, hd2 :: tl2) ->
if hd1.pos = hd2.pos && hd1.size = hd2.size then
match (hd1.data, hd2.data) with
| (Ok(data1), Ok(data2)) when data_equal data1 data2 ->
equal tl1 tl2 data_equal
| (Error(()), Error(())) -> equal tl1 tl2 data_equal
| _ -> false
else
false
| _ -> false
(** contains an abstract memory region data type where you can assign arbitrary data to locations
inside the memory regions. A memory region has no fixed size, so it can be used
for memory regions of variable size like arrays or stacks.
TODO: Right now this data structure is unsuited for elements that get only partially loaded. *)
open Bap.Std
open Core_kernel
type 'a t [@@deriving bin_io, compare, sexp]
(** Get an empty memory region- *)
val empty: unit -> 'a t
(** Add an element to the memory region. If the element intersects existing elements,
the non-overwritten part gets marked as Error *)
val add: 'a t -> 'a -> pos:Bitvector.t -> size:Bitvector.t -> 'a t
(** Mark the memory region between pos (included) and pos+size (excluded) as empty.
If elements get partially removed, mark the non-removed parts as Error *)
val remove: 'a t -> pos:Bitvector.t -> size:Bitvector.t -> 'a t
(** Returns the element and its size at position pos or None, when there is no element at that position.
If pos intersects an element but does not match its starting position, it returns Some(Error(())). *)
val get: 'a t -> Bitvector.t -> (('a * Bitvector.t), unit) Result.t Option.t
(** Merge two memory regions. Elements with the same position and size get merged using
data_merge, other intersecting elements get marked as Error. Note that data_merge
may return None (to remove the elements from the memory region) or Some(Error(_)) to
mark the merged element as error. *)
val merge: 'a t -> 'a t -> data_merge:('a -> 'a -> ('a, 'b) result option) -> 'a t
(** Check whether two memory regions are equal. *)
val equal: 'a t -> 'a t -> data_equal:('a -> 'a -> bool) -> bool
(** Mark an area in the mem_region as containing errors. *)
val mark_error: 'a t -> pos:Bitvector.t -> size:Bitvector.t -> 'a t
(* This file contains analysis passes for type recognition *)
open Bap.Std
open Core_kernel
(** The register type. *)
module Register : sig
type t =
| Pointer
| Data
[@@deriving bin_io, compare, sexp]
end
module TypeInfo : sig
type reg_state = (Register.t, unit) Result.t Var.Map.t [@@deriving bin_io, compare, sexp]
type t = {
stack: Register.t Mem_region.t;
stack_offset: (Bitvector.t, unit) Result.t Option.t;
reg: reg_state;
} [@@deriving bin_io, compare, sexp]
(* Pretty Printer. At the moment, the output is not pretty at all. *)
val pp: Format.formatter -> t -> unit
end
val type_info_tag: TypeInfo.t Value.tag
(** Computes TypeInfo for the given project. Adds tags to each block containing the
TypeInfo at the start of the block. *)
val compute_pointer_register: Project.t -> Project.t
(** Print type info tags. TODO: If this should be used for more than debug purposes,
then the output format should be refactored accordingly. *)
val print_type_info_tags: project:Project.t -> tid_map:word Tid.Map.t -> unit
(** Updates the type info for a single element (Phi/Def/Jmp) of a block. Input
is the type info before execution of the element, output is the type info
after execution of the element. *)
val update_type_info: Blk.elt -> TypeInfo.t -> project:Project.t -> TypeInfo.t
(* functions made public for unit tests: *)
module Test : sig
val update_block_analysis: Blk.t -> TypeInfo.t -> project:Project.t -> TypeInfo.t
end
......@@ -239,13 +239,13 @@ let print_hit tid ~sub ~function_names ~tid_map =
match Jmp.kind jmp with
| Call(call) -> begin
match Call.target call with
| Direct(call_tid) -> Option.is_some (List.find function_names ~f:(fun name ->
if name = (Tid.name call_tid) then begin
| Direct(call_tid) -> Option.is_some (List.find function_names ~f:(fun fn_name ->
if fn_name = (Tid.name call_tid) then begin
Log_utils.warn "[%s] {%s} (NULL Pointer Dereference) There is no check if the return value is NULL at %s (%s)."
name
version
(Address_translation.translate_tid_to_assembler_address_string tid tid_map)
name;
fn_name;
true
end else
false
......@@ -273,7 +273,7 @@ let check_cwe prog proj tid_map symbol_names parameters =
Seq.iter subfunctions ~f:(fun subfn ->
let cfg = Sub.to_cfg subfn in
let cwe_hits = ref [] in
let empty = Map.empty Graphs.Ir.Node.comparator in
let empty = Map.empty (module Graphs.Ir.Node) in
let init = Graphlib.Std.Solution.create empty [] in
let equal = State.equal in
let merge = State.union in
......
(library
(name cwe_checker_core)
(public_name cwe_checker_core)
(libraries
yojson
bap
core_kernel
core)
(preprocess (pps ppx_jane))
)
(include_subdirs unqualified) ; Include all subdirs when looking for source files
(lang dune 1.6)
......@@ -11,5 +11,5 @@ let generate_tid_map prog =
inherit [addr Tid.Map.t] Term.visitor
method enter_term _ t addrs = match Term.get_attr t address with
| None -> addrs
| Some addr -> Map.add addrs ~key:(Term.tid t) ~data:addr
| Some addr -> Map.add_exn addrs ~key:(Term.tid t) ~data:addr
end)#run prog Tid.Map.empty
open Bap.Std
open Core_kernel
let dyn_syms = ref None
let callee_saved_registers = ref None
(** Return a list of registers that are callee-saved.
TODO: At least ARMv7 and PPC have floating point registers that are callee saved. Check their names in bap and then add them. *)
let callee_saved_register_list project =
let arch = Project.arch project in
match arch with
| `x86_64 -> (* System V ABI *)
"RBX" :: "RSP" :: "RBP" :: "R12" :: "R13" :: "R14" :: "R15" :: []
| `x86_64 -> (* Microsoft x64 calling convention *) (* TODO: How to distinguish from System V? For the time being, only use the System V ABI, since it saves less registers. *)
"RBX" :: "RBP" :: "RDI" :: "RSI" :: "RSP" :: "R12" :: "R13" :: "R14" :: "R15" :: []
| `x86 -> (* Both Windows and Linux save the same registers *)
"EBX" :: "ESI" :: "EDI" :: "EBP" :: []
| `armv4 | `armv5 | `armv6 | `armv7
| `armv4eb | `armv5eb | `armv6eb | `armv7eb
| `thumbv4 | `thumbv5 | `thumbv6 | `thumbv7
| `thumbv4eb | `thumbv5eb | `thumbv6eb | `thumbv7eb -> (* ARM 32bit. R13 and SP are both names for the stack pointer. *)
"R4" :: "R5" :: "R6" :: "R7" :: "R8" :: "R9" :: "R10" :: "R11" :: "R13" :: "SP" :: []
| `aarch64 | `aarch64_be -> (* ARM 64bit *) (* TODO: This architecture is not contained in the acceptance tests yet? *)
"X19" :: "X20" :: "X21" :: "X22" :: "X23" :: "X24" :: "X25" :: "X26" :: "X27" :: "X28" :: "X29" :: "SP" :: []
| `ppc (* 32bit PowerPC *) (* TODO: add floating point registers. *) (* TODO: add CR2, CR3, CR4. Test their representation in bap first. *)
| `ppc64 | `ppc64le -> (* 64bit PowerPC *)
"R14" :: "R15" :: "R16" :: "R17" :: "R18" :: "R19" :: "R20" :: "R21" :: "R22" :: "R23" ::
"R24" :: "R25" :: "R26" :: "R27" :: "R28" :: "R29" :: "R30" :: "R31" :: "R1" :: "R2" :: []
| `mips | `mips64 | `mips64el | `mipsel -> (* S8 and FP are the same register. bap uses FP, S8 is left there just in case. *)
"S0" :: "S1" :: "S2" :: "S3" :: "S4" :: "S5" :: "S6" :: "S7" :: "S8" :: "GP" :: "SP" :: "FP" :: []
| _ -> failwith "No calling convention implemented for the given architecture."
let is_callee_saved var project =
match !callee_saved_registers with
| Some(register_set) -> String.Set.mem register_set (Var.name var)
| None ->
callee_saved_registers := Some(String.Set.of_list (callee_saved_register_list project));
String.Set.mem (Option.value_exn !callee_saved_registers) (Var.name var)
(** Parse a line from the dyn-syms output table of readelf. Return the name of a symbol if the symbol is an extern function name. *)
let parse_dyn_sym_line line =
let line = ref (String.strip line) in
let str_list = ref [] in
while Option.is_some (String.rsplit2 !line ~on:' ') do
let (left, right) = Option.value_exn (String.rsplit2 !line ~on:' ') in
line := String.strip left;
str_list := right :: !str_list;
done;
match !str_list with
| _ :: value :: _ :: "FUNC" :: _ :: _ :: _ :: name :: [] -> begin
match ( String.strip ~drop:(fun x -> x = '0') value, String.lsplit2 name ~on:'@') with
| ("", Some(left, _)) -> Some(left)
| ("", None) -> Some(name)
| _ -> None (* The symbol has a nonzero value, so we assume that it is not an extern function symbol. *)
end
| _ -> None
let parse_dyn_syms project =
match !dyn_syms with
| Some(symbol_set) -> symbol_set
| None ->
match Project.get project filename with
| None -> failwith "[CWE-checker] Project has no file name."
| Some(fname) -> begin
let cmd = Format.sprintf "readelf --dyn-syms %s" fname in
try
let in_chan = Unix.open_process_in cmd in
let lines = In_channel.input_lines in_chan in
let () = In_channel.close in_chan in begin
match lines with
| _ :: _ :: _ :: tail -> (* The first three lines are not part of the table *)
let symbol_set = String.Set.of_list (List.filter_map tail ~f:parse_dyn_sym_line) in
dyn_syms := Some(symbol_set);
symbol_set
| _ ->
dyn_syms := Some(String.Set.empty);
String.Set.empty (* *)
end
with
Unix.Unix_error (e,fm,argm) ->
failwith (Format.sprintf "[CWE-checker] Parsing of dynamic symbols failed: %s %s %s" (Unix.error_message e) fm argm)
end
open Bap.Std
open Core_kernel
(** Returns whether a variable is callee saved according to the calling convention
of the target architecture. Should only used for calls to functions outside
of the program, not for calls between functions inside the program. *)
val is_callee_saved: Var.t -> Project.t -> bool
(** Returns a list of those function names that are extern symbols.
TODO: Since we do not do name demangling here, check whether bap name demangling
yields different function names for the symbols. *)
val parse_dyn_syms: Project.t -> String.Set.t
......@@ -147,3 +147,17 @@ let get_program_entry_points program =
let entry_points = Seq.filter subfunctions ~f:(fun subfn -> Term.has_attr subfn Sub.entry_point) in
let main_fn = Seq.filter subfunctions ~f:(fun subfn -> "@main" = Tid.name (Term.tid subfn)) in
Seq.append main_fn entry_points
let stack_register project =
let arch = Project.arch project in
let module Target = (val target_of_arch arch) in
Target.CPU.sp
let flag_register_list project =
let arch = Project.arch project in
let module Target = (val target_of_arch arch) in
Target.CPU.zf :: Target.CPU.cf :: Target.CPU.vf :: Target.CPU.nf :: []
let arch_pointer_size_in_bytes project : int =
let arch = Project.arch project in
Size.in_bytes (Arch.addr_size arch)
......@@ -68,3 +68,15 @@ val extract_direct_call_tid_from_block : Bap.Std.blk Bap.Std.term -> Bap.Std.tid
TODO: The _start entry point usually calls a libc-function which then calls the main function. Since right now only direct
calls are tracked, our graph traversal may never find the main function. For now, we add it by hand to the entry points. *)
val get_program_entry_points : Bap.Std.program Bap.Std.term -> Bap.Std.sub Bap.Std.term Bap.Std.Seq.t
(** Returns the stack register on the architecture of the given project. *)
val stack_register: Bap.Std.Project.t -> Bap.Std.Var.t
(** Returns a list of the known flag registers on the architecture of the given project.
TODO: Right now it only returns flag registers that exist on all architectures.
We should add known architecture dependend flag registers, too. *)
val flag_register_list: Bap.Std.Project.t -> Bap.Std.Var.t list
(** Returns the pointer size in bytes on the architecture of the given project. *)
val arch_pointer_size_in_bytes: Bap.Std.Project.t -> int
all:
bapbundle remove cwe_checker_unit_tests.plugin
bapbuild -r -Is analysis cwe_checker_unit_tests.plugin -pkgs core,alcotest,yojson,unix,ppx_jane,cwe_checker_core
bapbundle install cwe_checker_unit_tests.plugin
bap ../artificial_samples/build/arrays_x64.out --pass=cwe-checker-unit-tests
bapbundle remove cwe_checker_unit_tests.plugin
clean:
bapbuild -clean
open Bap.Std
open Core_kernel
open Cwe_checker_core
let check msg x = Alcotest.(check bool) msg true x
let test_add () : unit =
let bv num = Bitvector.of_int num ~width:32 in
let x = Mem_region.empty () in
let x = Mem_region.add x "Five" ~pos:(bv 3) ~size:(bv 5) in
let x = Mem_region.add x "Seven" ~pos:(bv 9) ~size:(bv 7) in
let x = Mem_region.add x "Three" ~pos:(bv 0) ~size:(bv 3) in
check "add_ok" (Some(Ok("Five", bv 5)) = (Mem_region.get x (bv 3)));
check "add_err" (Some(Error(())) = (Mem_region.get x (bv 1)));
check "add_none" (None = (Mem_region.get x (bv 8)))
let test_minus () =
let bv num = Bitvector.of_int num ~width:32 in
let x = Mem_region.empty () in
let x = Mem_region.add x "One" ~pos:(bv (-8)) ~size:(bv 8) in
check "negative_index" (Some(Ok("One", bv 8)) = Mem_region.get x (Bitvector.unsigned (bv (-8))))
let test_remove () =
let bv num = Bitvector.of_int num ~width:32 in
let x = Mem_region.empty () in
let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in
let x = Mem_region.add x "Two" ~pos:(bv 15) ~size:(bv 11) in
let x = Mem_region.remove x ~pos:(bv 5) ~size:(bv 20) in
check "remove_error_before" (Some(Error()) = Mem_region.get x (bv 4));
check "remove_none1" (None = Mem_region.get x (bv 5));
check "remove_none2" (None = Mem_region.get x (bv 24));
check "remove_error_after1" (Some(Error()) = Mem_region.get x (bv 25));
check "remove_error_after2" (None = Mem_region.get x (bv 26))
let test_mark_error () =
let bv num = Bitvector.of_int num ~width:32 in
let x = Mem_region.empty () in
let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in
let x = Mem_region.mark_error x ~pos:(bv 5) ~size:(bv 10) in
check "mark_error1" (Some(Error()) = Mem_region.get x (bv 0));
check "mark_error2" (Some(Error()) = Mem_region.get x (bv 14));
check "mark_error3" (None = Mem_region.get x (bv 15))
let test_merge () =
let bv num = Bitvector.of_int num ~width:32 in
let x = Mem_region.empty () in
let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in
let x = Mem_region.add x "Two" ~pos:(bv 15) ~size:(bv 5) in
let x = Mem_region.add x "Three" ~pos:(bv 25) ~size:(bv 5) in
let y = Mem_region.empty () in
let y = Mem_region.add y "One" ~pos:(bv 1) ~size:(bv 10) in
let y = Mem_region.add y "Two" ~pos:(bv 15) ~size:(bv 5) in
let y = Mem_region.add y "Four" ~pos:(bv 25) ~size:(bv 5) in
let merge_fn a b = if a = b then Some(Ok(a)) else Some(Error()) in
let z = Mem_region.merge x y ~data_merge:merge_fn in
check "merge_intersect" (Some(Error()) = Mem_region.get z (bv 0));
check "merge_match_ok" (Some(Ok("Two", bv 5)) = Mem_region.get z (bv 15));
check "merge_match_error" (Some(Error()) = Mem_region.get z (bv 25))
let test_equal () =
let bv num = Bitvector.of_int num ~width:32 in
let x = Mem_region.empty () in
let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in
let x = Mem_region.add x "Two" ~pos:(bv 15) ~size:(bv 5) in
let y = Mem_region.empty () in
let y = Mem_region.add y "Two" ~pos:(bv 15) ~size:(bv 5) in
check "equal_no" (false = (Mem_region.equal x y ~data_equal:(fun x y -> x = y)));
let y = Mem_region.add y "One" ~pos:(bv 0) ~size:(bv 10) in
check "equal_yes" (Mem_region.equal x y ~data_equal:(fun x y -> x = y))
let test_around_zero () =
let bv num = Bitvector.of_int num ~width:32 in
let x = Mem_region.empty () in
let x = Mem_region.add x "One" ~pos:(bv (-5)) ~size:(bv 10) in
let x = Mem_region.add x "Two" ~pos:(bv 0) ~size:(bv 10) in
check "around_zero1" (Some(Error()) = Mem_region.get x (bv (-5)));
let x = Mem_region.empty () in
let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in
let x = Mem_region.add x "Two" ~pos:(bv (-5)) ~size:(bv 10) in
check "around_zero2" (Some(Error()) = Mem_region.get x (bv 0));
let x = Mem_region.empty () in
let x = Mem_region.add x "One" ~pos:(bv (-5)) ~size:(bv 20) in
let x = Mem_region.add x "Two" ~pos:(bv 0) ~size:(bv 10) in
check "around_zero3" (Some(Error()) = Mem_region.get x (bv (-5)));
let x = Mem_region.empty () in
let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in
let x = Mem_region.add x "Two" ~pos:(bv (-5)) ~size:(bv 20) in
check "around_zero2" (Some(Error()) = Mem_region.get x (bv 0))
let tests = [
"Add", `Quick, test_add;
"Negative Indices", `Quick, test_minus;
"Remove", `Quick, test_remove;
"Mark_error", `Quick, test_mark_error;
"Merge", `Quick, test_merge;
"Equal", `Quick, test_equal;
"Around Zero", `Quick, test_around_zero;
]
open Bap.Std
open Core_kernel
val tests: unit Alcotest.test_case list
open Bap.Std
open Core_kernel
open Cwe_checker_core
open Type_inference
open Type_inference.Test
let check msg x = Alcotest.(check bool) msg true x
let example_project = ref None
(* TODO: As soon as more pointers than stack pointer are tracked, add more tests! *)
let create_block_from_defs def_list =
let block = Blk.Builder.create () in
let () = List.iter def_list ~f:(fun def -> Blk.Builder.add_def block def) in
Blk.Builder.result block
let start_state stack_register project =
let bv x = Bitvector.of_int x ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8) in
let start_reg = Var.Map.empty in
let start_reg = Map.add_exn start_reg ~key:stack_register ~data:(Ok(Register.Pointer)) in
{ TypeInfo.stack = Mem_region.empty ();
TypeInfo.stack_offset = Some (Ok(bv 0));
TypeInfo.reg = start_reg;
}
let test_update_stack_offset () =
let project = Option.value_exn !example_project in
let bv x = Bitvector.of_int x ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8) in
let stack_register = Symbol_utils.stack_register project in
let fn_start_state = start_state stack_register project in
let def1 = Def.create stack_register (Bil.binop Bil.plus (Bil.var stack_register) (Bil.int (bv 8))) in
let def2 = Def.create stack_register (Bil.binop Bil.minus (Bil.var stack_register) (Bil.int (bv 16))) in
let block = create_block_from_defs [def1; def2] in
let state = update_block_analysis block fn_start_state project in
let () = check "update_stack_offset" (state.TypeInfo.stack_offset = Some(Ok(Bitvector.unsigned (bv (-8))))) in
()
let test_update_reg () =
let project = Option.value_exn !example_project in
let bv x = Bitvector.of_int x ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8) in
let stack_register = Symbol_utils.stack_register project in
let fn_start_state = start_state stack_register project in
let register1 = Var.create "Register1" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in
let register2 = Var.create "Register2" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in
let def1 = Def.create register1 (Bil.binop Bil.AND (Bil.var stack_register) (Bil.int (bv 8))) in
let def2 = Def.create register2 (Bil.binop Bil.XOR (Bil.var register1) (Bil.var stack_register)) in
let block = create_block_from_defs [def1; def2] in
let state = update_block_analysis block fn_start_state project in
let () = check "update_pointer_register" (Var.Map.find state.TypeInfo.reg register1 = Some(Ok(Pointer))) in
let () = check "update_data_register" (Var.Map.find state.TypeInfo.reg register2 = Some(Ok(Data))) in
let def1 = Def.create register1 (Bil.Load (Bil.var register1, Bil.var register2, Bitvector.LittleEndian, `r64) ) in
let block = create_block_from_defs [def1;] in
let state = update_block_analysis block fn_start_state project in
let () = check "add_mem_address_registers" (Var.Map.find state.TypeInfo.reg register2 = Some(Ok(Pointer))) in
()
let test_update_stack () =
let project = Option.value_exn !example_project in
let bv x = Bitvector.of_int x ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8) in
let stack_register = Symbol_utils.stack_register project in
let fn_start_state = start_state stack_register project in
let register1 = Var.create "Register1" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in
let register2 = Var.create "Register2" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in
let mem_reg = Var.create "Mem_reg" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in
let def1 = Def.create register1 (Bil.binop Bil.AND (Bil.var stack_register) (Bil.int (bv 8))) in
let def2 = Def.create mem_reg (Bil.Store ((Bil.var mem_reg), (Bil.binop Bil.PLUS (Bil.var stack_register) (Bil.int (bv (-8)))), (Bil.var stack_register), Bitvector.LittleEndian, `r64)) in
let def3 = Def.create register2 (Bil.Load (Bil.var register2, (Bil.binop Bil.MINUS (Bil.var stack_register) (Bil.int (bv 8))), Bitvector.LittleEndian, `r64) ) in
let block = create_block_from_defs [def1; def2; def3;] in
let state = update_block_analysis block fn_start_state project in
let () = check "write_to_stack" ((Mem_region.get state.TypeInfo.stack (bv (-8))) = Some(Ok(Pointer, bv (Symbol_utils.arch_pointer_size_in_bytes project)))) in
let () = check "load_from_stack" (Var.Map.find state.TypeInfo.reg register2 = Some(Ok(Pointer))) in
()
let tests = [
"Update Stack Offset", `Quick, test_update_stack_offset;
"Update Register", `Quick, test_update_reg;
"Update Stack", `Quick, test_update_stack;
]
open Bap.Std
open Core_kernel
val example_project: Project.t option ref
val tests: unit Alcotest.test_case list
open Bap.Std
open Core_kernel
open Cwe_checker_core
let run_tests project =
Type_inference_test.example_project := Some(project);
Alcotest.run "Unit tests" ~argv:[|"DoNotComplainWhenRunAsABapPlugin";"--color=always";|] [
"Mem_region_tests", Mem_region_test.tests;
"Type_inference_tests", Type_inference_test.tests;
]
let () =
(* Check whether this file is run as an executable (via dune runtest) or
as a bap plugin *)
if Sys.argv.(0) = "bap" then
(* The file was run as a bap plugin. *)
Project.register_pass' run_tests
else
(* The file was run as a standalone executable. Use make to build and run the unit test plugin *)
let () = Sys.chdir (Sys.getenv "PWD" ^ "/test/unit") in
exit (Sys.command "make all")
(executable
(name cwe_checker_unit_tests)
(libraries
alcotest
yojson
bap
cwe_checker_core
core_kernel
core)
(preprocess (pps ppx_jane))
)
(include_subdirs unqualified) ; Include all subdirs when looking for source files
(alias
(name runtest)
(deps cwe_checker_unit_tests.exe)
(action (run %{deps} --color=always)))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment