Commit a1153f67 by fwkz

ThreadPoolExecutor as context manager

parent b9248c7a
...@@ -33,15 +33,17 @@ class Exploit(exploits.Exploit): ...@@ -33,15 +33,17 @@ class Exploit(exploits.Exploit):
def run(self): def run(self):
self.vulnerabilities = [] self.vulnerabilities = []
executor = threads.ThreadPoolExecutor(self.threads)
executor.feed(utils.iter_modules(utils.EXPLOITS_DIR))
executor.run(self.target_function)
with threads.ThreadPoolExecutor(self.threads) as executor:
for exploit in utils.iter_modules(utils.EXPLOITS_DIR):
executor.submit(self.target_function, exploit)
print_info()
if self.vulnerabilities: if self.vulnerabilities:
print_info()
print_success("Device is vulnerable!") print_success("Device is vulnerable!")
for v in self.vulnerabilities: for v in self.vulnerabilities:
print_info(" - {}".format(v)) print_info(" - {}".format(v))
print_info()
else: else:
print_error("Device is not vulnerable to any exploits!\n") print_error("Device is not vulnerable to any exploits!\n")
......
...@@ -15,33 +15,18 @@ from . import utils ...@@ -15,33 +15,18 @@ from . import utils
data_queue = queue.Queue() data_queue = queue.Queue()
class DataProducerThread(threading.Thread):
def __init__(self, data):
super(DataProducerThread, self).__init__(name=self.__class__.__name__)
self.data = data
def run(self):
for record in self.data:
data_queue.put(record)
def stop(self):
data_queue.queue.clear()
def join_queue(self):
data_queue.join()
class WorkerThread(threading.Thread): class WorkerThread(threading.Thread):
def __init__(self, target, name): def __init__(self, name):
super(WorkerThread, self).__init__(target=target, name=name) super(WorkerThread, self).__init__(name=name)
self.target = target
self.name = name self.name = name
def run(self): def run(self):
while not data_queue.empty(): while not data_queue.empty():
record = data_queue.get() record = data_queue.get()
target = record[0]
args = record[1:]
try: try:
self.target(record) target(*args)
finally: finally:
data_queue.task_done() data_queue.task_done()
...@@ -49,21 +34,19 @@ class WorkerThread(threading.Thread): ...@@ -49,21 +34,19 @@ class WorkerThread(threading.Thread):
class ThreadPoolExecutor(object): class ThreadPoolExecutor(object):
def __init__(self, threads): def __init__(self, threads):
self.threads = threads self.threads = threads
self.data_producer = None self.workers = []
def feed(self, dataset):
self.data_producer = DataProducerThread(dataset)
self.data_producer.start()
time.sleep(0.1)
def run(self, target): def __enter__(self):
workers = [] self.workers = []
for worker_id in xrange(int(self.threads)): for worker_id in xrange(int(self.threads)):
worker = WorkerThread( worker = WorkerThread(
target=target,
name='worker-{}'.format(worker_id), name='worker-{}'.format(worker_id),
) )
workers.append(worker) self.workers.append(worker)
return self
def __exit__(self, *args):
for worker in self.workers:
worker.start() worker.start()
start = time.time() start = time.time()
...@@ -73,10 +56,13 @@ class ThreadPoolExecutor(object): ...@@ -73,10 +56,13 @@ class ThreadPoolExecutor(object):
except KeyboardInterrupt: except KeyboardInterrupt:
utils.print_info() utils.print_info()
utils.print_status("Waiting for already scheduled jobs to finish...") utils.print_status("Waiting for already scheduled jobs to finish...")
self.data_producer.stop() data_queue.queue.clear()
for worker in workers: for worker in self.workers:
worker.join() worker.join()
else: else:
self.data_producer.join_queue() data_queue.join()
utils.print_status('Elapsed time: ', time.time() - start, 'seconds') utils.print_status('Elapsed time: ', time.time() - start, 'seconds')
def submit(self, *args):
data_queue.put(args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment