open Core_kernel.Std
open Bap.Std

(** A symbol name together with the tid of the program subroutine that
    carries that name, when one was found (see [find_symbol]). *)
type symbol =
  {
    address : tid option; (** tid of the subroutine named [name], if any *)
    name : string; (** the symbol name that was looked up *)
  }


(** [find_symbol program name] returns the tid of the first subroutine of
    [program] whose name equals [name], or [None] if no subroutine has
    that name. *)
let find_symbol program name =
  let has_name sub = Sub.name sub = name in
  Term.enum sub_t program
  |> Seq.find ~f:has_name
  |> Option.map ~f:Term.tid

(** [build_symbols symbol_names prog] resolves every name in [symbol_names]
    with [find_symbol] and keeps only the names that resolve to a subroutine
    of [prog]. Every returned record has [address = Some _], and the input
    order is preserved. *)
let build_symbols symbol_names prog =
  List.filter_map symbol_names ~f:(fun name ->
      match find_symbol prog name with
      | Some tid -> Some {address = Some tid; name}
      | None -> None)

(** [get_symbol tid symbols] returns the first element of [symbols] whose
    [address] is [Some tid]. Symbols without an address never match. *)
let get_symbol tid symbols =
  let located_at = function
    | {address = Some address; _} -> tid = address
    | {address = None; _} -> false
  in
  List.find symbols ~f:located_at

(** [get_symbol_name_from_jmp jmp symbols] returns the name of the symbol in
    [symbols] targeted by the direct call [jmp].

    The lookup goes through [get_symbol], so symbols whose [address] is
    [None] are skipped instead of aborting the search.

    @raise Assert_failure if [jmp] is not a call, if the call target is not
    direct, or if no symbol in [symbols] has the target's tid. *)
let get_symbol_name_from_jmp jmp symbols =
  match Jmp.kind jmp with
  | Goto _ | Ret _ | Int (_,_) -> assert false
  | Call destination ->
    begin match Call.target destination with
      | Indirect _ -> assert false
      | Direct addr ->
        begin match get_symbol addr symbols with
          | Some symbol -> symbol.name
          | None -> assert false
        end
    end

(** [get_direct_callsites_of_sub sub] returns the jump terms of [sub] that
    are calls with a [Direct] target. The order is block order, and within
    each block it is jump order. Gotos, returns, interrupts and indirect
    calls are excluded. *)
let get_direct_callsites_of_sub sub =
  let is_direct_call jmp =
    match Jmp.kind jmp with
    | Goto _ | Ret _ | Int _ -> false
    | Call destination ->
      begin match Call.target destination with
        | Direct _ -> true
        | Indirect _ -> false
      end
  in
  Term.enum blk_t sub
  |> Seq.concat_map ~f:(fun blk ->
      Term.enum jmp_t blk |> Seq.filter ~f:is_direct_call)

(** [sub_calls_symbol prog sub symbol_name] is [true] iff [symbol_name]
    names a subroutine of [prog] and [sub] contains a direct call to that
    subroutine. *)
let sub_calls_symbol prog sub symbol_name =
  let calls target callsite =
    match Jmp.kind callsite with
    | Call call ->
      (match Call.target call with
       | Direct addr -> addr = target
       | _ -> false)
    | Goto _ | Ret _ | Int (_,_) -> false
  in
  match find_symbol prog symbol_name with
  | None -> false
  | Some target ->
    Seq.exists (get_direct_callsites_of_sub sub) ~f:(calls target)

(** A direct call from a call site to a known symbol, as built by
    [transform_call_to_concrete_call]. *)
type concrete_call =
  {
    call_site : tid; (** tid of the calling jump term *)
    symbol_address : tid; (** tid of the call's direct target *)
    name : string; (** name of the symbol found at [symbol_address] *)
  }

(** Term visitor that collects direct calls. For every jump that is a call
    with a [Direct] target, [enter_jmp] prepends the pair
    [(tid of the jump, tid of the target)] to the accumulated list. Every
    other jump leaves the accumulator unchanged. *)
let call_finder = object
  inherit [(tid * tid) list] Term.visitor
  method! enter_jmp jmp found =
    let direct_target =
      match Jmp.kind jmp with
      | Call call ->
        (match Call.target call with
         | Direct callee -> Some callee
         | _ -> None)
      | Goto _ | Ret _ | Int (_,_) -> None
    in
    match direct_target with
    | Some callee -> (Term.tid jmp, callee) :: found
    | None -> found
end


(** [transform_call_to_concrete_call (src_tid, dst_tid) symbols] builds the
    [concrete_call] from call site [src_tid] to [dst_tid]. The name is taken
    from the symbol in [symbols] located at [dst_tid].

    @raise Assert_failure if no symbol in [symbols] has address [dst_tid]. *)
let transform_call_to_concrete_call (src_tid, dst_tid) symbols =
  let symbol =
    match get_symbol dst_tid symbols with
    | None -> assert false
    | Some found -> found
  in
  {call_site = src_tid; symbol_address = dst_tid; name = symbol.name}

(** [filter_calls_to_symbols calls symbols] keeps the [(call_site, target)]
    pairs of [calls] whose target is the address of some symbol in
    [symbols], and turns each one into a [concrete_call] carrying that
    symbol's name. Pairs with other targets are dropped, and the input
    order is preserved. *)
let filter_calls_to_symbols calls symbols =
  (* A single lookup per call both decides membership and yields the symbol.
     The previous exists-then-find approach scanned [symbols] twice per call. *)
  List.filter_map calls ~f:(fun (call_site, symbol_address) ->
      match get_symbol symbol_address symbols with
      | Some symbol -> Some {call_site; symbol_address; name = symbol.name}
      | None -> None)

(** [is_interesting_callsite jmp relevant_calls] is [true] iff [jmp] is a
    direct call whose target tid equals the [symbol_address] of some element
    of [relevant_calls]. *)
let is_interesting_callsite jmp relevant_calls =
  let targets_relevant_symbol target =
    List.exists relevant_calls ~f:(fun call -> call.symbol_address = target)
  in
  match Jmp.kind jmp with
  | Call dst ->
    (match Call.target dst with
     | Direct tid -> targets_relevant_symbol tid
     | _ -> false)
  | Goto _ | Ret _ | Int (_,_) -> false


(** [check_calls relevant_calls prog proj tid_map symbols check_func] walks
    every subroutine, block and jump of [prog]. For each jump accepted by
    [is_interesting_callsite] it calls
    [check_func proj prog sub blk jmp tid_map symbols]. *)
let check_calls relevant_calls prog proj tid_map symbols check_func =
  let check_jmp sub blk jmp =
    if is_interesting_callsite jmp relevant_calls then
      check_func proj prog sub blk jmp tid_map symbols
  in
  let check_blk sub blk =
    Term.enum jmp_t blk |> Seq.iter ~f:(check_jmp sub blk)
  in
  Term.enum sub_t prog
  |> Seq.iter ~f:(fun sub -> Term.enum blk_t sub |> Seq.iter ~f:(check_blk sub))
  
(** [get_symbol_call_count_of_sub symbol_name sub prog] counts the direct
    call sites in [sub] whose target is the subroutine of [prog] named
    [symbol_name]. The count is [0] when [prog] has no such subroutine. *)
let get_symbol_call_count_of_sub symbol_name sub prog =
  match find_symbol prog symbol_name with
  | None -> 0
  | Some target ->
    get_direct_callsites_of_sub sub
    |> Seq.count ~f:(fun callsite ->
        match Jmp.kind callsite with
        | Call call ->
          (match Call.target call with
           | Direct addr -> addr = target
           | _ -> false)
        | Goto _ | Ret _ | Int (_,_) -> false)