import os
import re
import shlex
import subprocess
import sys

# pymongo==3.13.0
import pymongo
from bson.objectid import ObjectId
from gridfs import GridFS

# Directory prefix for the target binary; the file itself is named after the
# crash database (binary_file_path + db_name).
binary_file_path = "/home/fuzz_dir/validate_script/"
# File the AddressSanitizer-instrumented build is downloaded to.
asan_file_path = "/home/fuzz_dir/validate_script/asan_software"
# Directory the crash-triggering seeds are dumped into, one file per crash _id.
crashes_file_path = "/home/fuzz_dir/validate_script/crashes/"

# --- valgrind report lines -------------------------------------------------
# Any line carrying the "==<pid>==" prefix.
pattern_valgrind_head = re.compile(r'==\d+==')
# Summary line, matched only when it reports a non-zero error count.
pattern_valgrind_tail = re.compile(r'==\d+== ERROR SUMMARY: [1-9]+')
# Innermost stack frame ("at 0x...: ") and caller frames ("by 0x...: ").
pattern_valgrind_at = re.compile(r'==\d+== {4}at 0x\w+: ')
pattern_valgrind_by = re.compile(r'==\d+== {4}by 0x\w+: ')

# --- AddressSanitizer report lines ------------------------------------------
# Report header; its presence means ASan detected an error.
pattern_asan_head = re.compile(r'==\d+==ERROR: AddressSanitizer:')
# Any numbered stack frame, and specifically frame #0.
pattern_asan = re.compile(r' {4}#\d+ 0x\w+ in ')
pattern_asan_0 = re.compile(r' {4}#0 0x\w+ in ')

# Stack-frame signatures whose report has already been stored (used as a set:
# keys map to 1), so only the first seed per distinct cause is recorded.
invalid_cause_dict = {}


def search_file(dirname):
    """Recursively collect the paths of all files under *dirname*.

    Files whose name starts with "README" are skipped. Paths are returned in
    os.walk (top-down) order; a missing directory yields an empty list.
    """
    return [
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(dirname)
        for name in names
        if not name.startswith("README")
    ]


def generation_command(crashes_collection, target, parameter, seeds, usage="valgrind"):
    """Run *target* once per crash seed and record the sanitizer output.

    Args:
        crashes_collection: collection whose documents are keyed by the seed's ObjectId.
        target: path of the executable under test.
        parameter: argument template for *target*. The first "@@" is replaced
            by the seed path (AFL convention). When there is no "@@", the seed
            is redirected to the target's stdin instead. Otherwise the target
            would never see the seed.
        seeds: seed file paths; each basename must be the crash document's ObjectId.
        usage: wrapper prefixed to the command (e.g. "valgrind"), or "" to run
            the target directly. It is also forwarded to exec_command to select
            the output parser.
    """
    for seed in seeds:
        # The command is run with shell=True, so quote the paths we splice in.
        quoted_seed = shlex.quote(seed)
        if "@@" in parameter:
            target_args = parameter.replace("@@", quoted_seed, 1)
        else:
            target_args = f"{parameter} < {quoted_seed}"
        parts = [shlex.quote(target), target_args]
        if usage:
            parts.insert(0, usage)
        command = " ".join(parts)
        seed_id = os.path.basename(seed)
        exec_command(crashes_collection, usage, command, ObjectId(seed_id))


def _store_if_new_cause(crashes_collection, seed_id, field, error_cause, report):
    """Store *report* in *field* of crash *seed_id* unless *error_cause* was already stored."""
    if invalid_cause_dict.get(error_cause, None):
        return
    crashes_collection.update_one({"_id": seed_id}, {"$set": {field: report}})
    invalid_cause_dict[error_cause] = 1


def _handle_valgrind_report(crashes_collection, seed_id, errs):
    """Parse valgrind stderr and store it (as ``valgrind_stderr``) once per distinct cause.

    The stored report is every "==<pid>==" line, up to and including the
    non-zero ERROR SUMMARY line. The dedup key is the text after the address
    of the first "at" frame, plus up to 11 "by" frames that precede the second
    "at" frame.
    """
    if not pattern_valgrind_tail.search(errs):
        return  # no errors reported
    err_data = ''
    error_cause = ''
    at_count = 0
    by_count = 0
    for line in errs.splitlines():
        if pattern_valgrind_head.match(line) is None:
            continue  # program output, not a valgrind line
        # Keep every valgrind line. The previous version dropped the second
        # "at" frame from the stored report.
        err_data += line + "\n"
        at_match = pattern_valgrind_at.match(line)
        if at_match is not None:
            at_count += 1
            if at_count == 1:
                error_cause += line[at_match.end():]
            continue
        by_match = pattern_valgrind_by.match(line)
        if by_match is not None:
            if by_count <= 10 and at_count <= 1:
                by_count += 1
                error_cause += line[by_match.end():]
            continue
        if pattern_valgrind_tail.match(line) is not None:
            _store_if_new_cause(crashes_collection, seed_id, "valgrind_stderr", error_cause, err_data)


def _handle_asan_report(crashes_collection, seed_id, errs):
    """Parse AddressSanitizer stderr and store it (as ``asan_stderr``) once per distinct cause.

    The dedup key is frame #0 of the first stack, plus up to 11 further
    numbered frames found anywhere in the report. The whole stderr is stored.
    """
    if pattern_asan_head.search(errs) is None:
        return  # ASan did not report an error
    error_cause = ''
    frame_count = 0
    seen_frame_0 = False
    for line in errs.splitlines():
        frame0 = None if seen_frame_0 else pattern_asan_0.match(line)
        if frame0 is not None:
            seen_frame_0 = True
            error_cause += line[frame0.end():]
            continue
        frame = pattern_asan.match(line)
        if frame is not None and frame_count <= 10:
            frame_count += 1
            error_cause += line[frame.end():]
    _store_if_new_cause(crashes_collection, seed_id, "asan_stderr", error_cause, errs)


def exec_command(crashes_collection, usage, command, seed_id: ObjectId):
    """Run one shell *command* and store its sanitizer report on the crash document.

    Args:
        crashes_collection: collection holding the crash documents.
        usage: "valgrind" parses valgrind output into ``valgrind_stderr``. Any
            other value parses AddressSanitizer output into ``asan_stderr``.
        command: shell command line, executed with ``shell=True``.
        seed_id: ``_id`` of the crash document to update.

    Only the first seed that produces a given stack-frame signature gets its
    report stored (tracked in ``invalid_cause_dict``).
    """
    # stdin=DEVNULL: a target that reads stdin must not block on, or consume,
    # this script's own stdin. An explicit "< seed" redirect in *command* still wins.
    # NOTE(review): communicate() has no timeout, so a hanging seed stalls the run.
    process = subprocess.Popen(command, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE, shell=True)
    _, errs = process.communicate()
    errs = str(errs, "utf-8", errors='ignore')
    if usage == "valgrind":
        _handle_valgrind_report(crashes_collection, seed_id, errs)
    else:
        _handle_asan_report(crashes_collection, seed_id, errs)


def _dump_crash_seeds(crashes_collection):
    """Write each crash's ``seed`` bytes to ``crashes_file_path/<_id>`` and return those paths.

    The directory is created if needed, so a re-run does not fail. Only the
    files written by this call are returned, so leftovers from an earlier run
    are not re-verified.
    """
    os.makedirs(crashes_file_path, exist_ok=True)
    seed_paths = []
    for crash in crashes_collection.find():
        seed_path = os.path.join(crashes_file_path, str(crash["_id"]))
        with open(seed_path, 'wb') as seed_file:
            seed_file.write(crash["seed"])
        seed_paths.append(seed_path)
    return seed_paths


def _make_executable(path):
    """Add the user/group/other execute bits to *path*.

    This does not spawn a shell, unlike ``os.system("chmod +x ...")``. Unlike
    ``chmod +x``, the umask is not applied.
    """
    os.chmod(path, os.stat(path).st_mode | 0o111)


def _download_binary(db, db_name):
    """Reassemble the target binary from ``db["binary"]`` into ``binary_file_path + db_name``.

    The ``code`` field of every document is written in cursor order. The file
    is truncated first, so a re-run cannot append to a stale copy.

    Returns:
        The executable's path, or None when the collection is empty.
    """
    binary_collection = db["binary"]
    if binary_collection.count_documents({}) == 0:
        return None
    binary_software = binary_file_path + db_name
    # NOTE(review): chunks are concatenated in natural (unsorted) cursor order;
    # confirm they are inserted in file order, or add a sort key.
    with open(binary_software, 'wb') as binary_software_file:
        for chunk in binary_collection.find():
            binary_software_file.write(chunk["code"])
    _make_executable(binary_software)
    return binary_software


def _download_asan_binary(default_db, software_id, fs_name):
    """Download the ASan build referenced by ``software.asan_file`` from GridFS.

    Returns:
        ``asan_file_path`` once it is written and made executable. Returns None
        when the software document or its ``asan_file`` reference is missing.
        Errors raised by ``GridFS.get`` propagate to the caller.
    """
    software = default_db["software"].find_one({"_id": ObjectId(software_id)})
    asan_file_id = software.get("asan_file") if software is not None else None
    print(asan_file_id)
    if asan_file_id is None:
        return None
    fs = GridFS(default_db, fs_name)
    asan_file_data = fs.get(asan_file_id)
    if not asan_file_data:
        return None
    os.makedirs(os.path.dirname(asan_file_path), exist_ok=True)
    with open(asan_file_path, "wb") as f:
        f.write(asan_file_data.read())
    _make_executable(asan_file_path)
    print("asan_file download successfully")
    return asan_file_path


def _adjust_used_cpu(host_node_collection, host_node_id, delta):
    """Add *delta* to ``host_node.used_cpu``; a decrement below zero is skipped.

    Returns:
        False if the host node document does not exist, otherwise True.
    """
    node_filter = {"_id": ObjectId(host_node_id)}
    host_node = host_node_collection.find_one(node_filter)
    if host_node is None:
        return False
    used_cpu = int(host_node["used_cpu"])
    if delta < 0 and used_cpu + delta < 0:
        return True
    # NOTE(review): read-modify-write is not atomic. If used_cpu is always
    # numeric, {"$inc": {"used_cpu": delta}} would avoid lost updates.
    host_node_collection.update_one(node_filter, {"$set": {"used_cpu": used_cpu + delta}})
    return True


def main(argv):
    """Verify every stored crash seed of one task under valgrind and the ASan build.

    argv (positional): mongo_address, mongo_port, db_name, default_db_name,
    parameter (target argument template, "@@" = seed path), software_id,
    task_id, host_node_id.

    The task is marked VERIFYING while the run is in progress. It is marked
    VERIFIED when the run ends, even if it fails. One ``used_cpu`` slot is held
    on the host node for the duration.
    """
    if len(argv) < 8:
        raise SystemExit(
            "usage: <mongo_address> <mongo_port> <db_name> <default_db_name> "
            "<parameter> <software_id> <task_id> <host_node_id>"
        )
    mongo_address, mongo_port, db_name, default_db_name = argv[0:4]
    parameter, software_id, task_id, host_node_id = argv[4:8]
    fs_name = "fs"

    mongo_client = pymongo.MongoClient(f"mongodb://{mongo_address}:{mongo_port}/")
    try:
        db = mongo_client[db_name]
        default_db = mongo_client[default_db_name]
        task_collection = default_db["task"]
        host_node_collection = default_db["host_node"]
        task_filter = {"_id": ObjectId(task_id)}

        # Dump the crash-triggering seeds to disk.
        crashes_collection = db["crashes"]
        crashes_file_paths = _dump_crash_seeds(crashes_collection)

        cpu_reserved = False
        try:
            task_collection.update_one(task_filter, {"$set": {"verification": "VERIFYING"}})
            if not _adjust_used_cpu(host_node_collection, host_node_id, 1):
                raise LookupError(f"host_node {host_node_id} not found")
            cpu_reserved = True

            # Plain binary, run under valgrind.
            binary_software = _download_binary(db, db_name)
            if binary_software is None:
                print("no binary found, skipping valgrind verification")
            else:
                print("binary download successfully")
                generation_command(crashes_collection, binary_software, parameter,
                                   crashes_file_paths, "valgrind")

            # ASan-instrumented binary, run directly.
            asan_software = _download_asan_binary(default_db, software_id, fs_name)
            if asan_software is not None:
                generation_command(crashes_collection, asan_software, parameter,
                                   crashes_file_paths, "")
        finally:
            task_collection.update_one(task_filter, {"$set": {"verification": "VERIFIED"}})
            # Give back only the CPU slot this run actually took.
            if cpu_reserved:
                _adjust_used_cpu(host_node_collection, host_node_id, -1)
    finally:
        mongo_client.close()


if __name__ == "__main__":
    # Forward everything after the script name as main()'s positional argv.
    cli_args = sys.argv[1:]
    main(cli_args)
