From 7852904f995dfdaf13f66948ed4546953557c226 Mon Sep 17 00:00:00 2001
From: Craig Heffner <heffnercj@gmail.com>
Date: Sun, 22 Nov 2015 23:23:55 -0500
Subject: [PATCH] Re-implemented the status information to be served via a network socket.

---
 setup.py                        |  3 ++-
 src/binwalk/__init__.py         |  8 +++++---
 src/binwalk/core/module.py      | 86 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------
 src/binwalk/core/statuserver.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
 src/binwalk/modules/general.py  | 10 ++++++++++
 src/scripts/binwalk             | 64 ++++++++++++++++++++--------------------------------------------
 6 files changed, 159 insertions(+), 61 deletions(-)
 create mode 100644 src/binwalk/core/statuserver.py

diff --git a/setup.py b/setup.py
index b9785cc..29d8ea3 100755
--- a/setup.py
+++ b/setup.py
@@ -8,6 +8,7 @@ from distutils.core import setup, Command
 from distutils.dir_util import remove_tree
 
 MODULE_NAME = "binwalk"
+SCRIPT_NAME = MODULE_NAME
 
 # Python2/3 compliance
 try:
@@ -208,7 +209,7 @@ setup(name = MODULE_NAME,
       requires = [],
       packages = [MODULE_NAME],
       package_data = {MODULE_NAME : install_data_files},
-      scripts = [os.path.join("scripts", MODULE_NAME)],
+      scripts = [os.path.join("scripts", SCRIPT_NAME)],
 
       cmdclass = {'clean' : CleanCommand, 'uninstall' : UninstallCommand, 'idainstall' : IDAInstallCommand, 'idauninstall' : IDAUnInstallCommand}
 )
diff --git a/src/binwalk/__init__.py b/src/binwalk/__init__.py
index 76bbe25..e4443cf 100644
--- a/src/binwalk/__init__.py
+++ b/src/binwalk/__init__.py
@@ -1,9 +1,11 @@
-__all__ = ['scan', 'execute', 'Modules', 'ModuleException']
+__all__ = ['scan', 'execute', 'ModuleException']
 
 from binwalk.core.module import Modules, ModuleException
 
 # Convenience functions
 def scan(*args, **kwargs):
-    return Modules(*args, **kwargs).execute()
+    with Modules(*args, **kwargs) as m:
+        objs = m.execute()
+    return objs
 def execute(*args, **kwargs):
-    return Modules(*args, **kwargs).execute()
+    return scan(*args, **kwargs)
diff --git a/src/binwalk/core/module.py b/src/binwalk/core/module.py
index b5c0f4e..d70ef5c 100644
--- a/src/binwalk/core/module.py
+++ b/src/binwalk/core/module.py
@@ -9,9 +9,11 @@ import sys
 import inspect
 import argparse
 import traceback
+import binwalk.core.statuserver
 import binwalk.core.common
 import binwalk.core.settings
 import binwalk.core.plugin
+from threading import Thread
 from binwalk.core.compat import *
 
 class Option(object):
@@ -216,10 +218,11 @@ class Module(object):
     # Set to False if this is not a primary module (e.g., General, Extractor modules)
     PRIMARY = True
 
-    def __init__(self, **kwargs):
+    def __init__(self, parent, **kwargs):
         self.errors = []
         self.results = []
 
+        self.parent = parent
         self.target_file_list = []
         self.status = None
         self.enabled = False
@@ -258,6 +261,13 @@ class Module(object):
         '''
         return None
 
+    def unload(self):
+        '''
+        Invoked at module unload time, to give the module a chance to clean up.
+        May be overridden by the module sub-class.
+        '''
+        return None
+
     def reset(self):
         '''
         Invoked only for dependency modules immediately prior to starting a new primary module.
@@ -336,6 +346,17 @@ class Module(object):
 
         return args
 
+    def _unload_dependencies(self):
+        # Calls the unload method for all dependency modules.
+        # These modules cannot be unloaded immediately after being run, as
+        # they must persist until the module that depends on them is finished.
+        # As such, this must be done separately from the Modules.run 'unload' call.
+        for dependency in self.dependencies:
+            try:
+                getattr(self, dependency.attribute).unload()
+            except AttributeError:
+                continue # Dependency was never attached to this module (or has no unload method); skip it
+
     def next_file(self, close_previous=True):
         '''
         Gets the next file to be scanned (including pending extracted files, if applicable).
@@ -386,8 +407,10 @@ class Module(object):
 
         if fp is not None:
             self.current_target_file_name = fp.path
+            self.status.fp = fp
         else:
             self.current_target_file_name = None
+            self.status.fp = fp
 
         self.previous_next_file_fp = fp
 
@@ -499,14 +522,14 @@ class Module(object):
             if hasattr(self, dependency.attribute):
                 getattr(self, dependency.attribute).reset()
 
-    def main(self, parent):
+    def main(self):
         '''
         Responsible for calling self.init, initializing self.config.display, and calling self.run.
 
         Returns the value returned from self.run.
         '''
-        self.status = parent.status
-        self.modules = parent.loaded_modules
+        self.status = self.parent.status
+        self.modules = self.parent.executed_modules
 
         # A special exception for the extractor module, which should be allowed to
         # override the verbose setting, e.g., if --matryoshka has been specified
@@ -584,12 +607,25 @@ class Modules(object):
         Returns None.
         '''
         self.arguments = []
-        self.loaded_modules = {}
+        self.executed_modules = {}
         self.default_dependency_modules = {}
-        self.status = Status(completed=0, total=0)
+        self.status = Status(completed=0, total=0, file=None)
+        self.status_server_started = False
+        self.status_service = None
 
         self._set_arguments(list(argv), kargv)
 
+    def cleanup(self):
+        if self.status_service: # Only set once status_server() has successfully started the service
+            self.status_service.server.socket.shutdown(1) # NOTE(review): 1 is presumably socket.SHUT_WR; if this raises, close() below is skipped
+            self.status_service.server.socket.close()
+
+    def __enter__(self): # Context manager entry point
+        return self # The instance itself is the 'with ... as' target
+
+    def __exit__(self, t, v, b): # t, v, b: exception type, value and traceback (all unused)
+        self.cleanup() # Nothing is returned, so an exception raised inside the 'with' block still propagates
+
     def _set_arguments(self, argv=[], kargv={}):
         for (k,v) in iterator(kargv):
             k = self._parse_api_opt(k)
@@ -691,7 +727,7 @@ class Modules(object):
             obj = self.run(module)
 
         # Add all loaded modules that marked themselves as enabled to the run_modules list
-        for (module, obj) in iterator(self.loaded_modules):
+        for (module, obj) in iterator(self.executed_modules):
             # Report the results if the module is enabled and if it is a primary module or if it reported any results/errors
             if obj.enabled and (obj.PRIMARY or obj.results or obj.errors):
                 run_modules.append(obj)
@@ -707,12 +743,18 @@ class Modules(object):
         obj = self.load(module, kwargs)
 
         if isinstance(obj, binwalk.core.module.Module) and obj.enabled:
-            obj.main(parent=self)
+            obj.main()
             self.status.clear()
 
-        # If the module is not being loaded as a dependency, add it to the loaded modules dictionary
+        # If the module is not being loaded as a dependency, add it to the executed modules dictionary.
+        # This is used later in self.execute to determine which objects should be returned.
         if not dependency:
-            self.loaded_modules[module] = obj
+            self.executed_modules[module] = obj
+
+            # The unload method tells the module that we're done with it, and gives it a chance to do
+            # any cleanup operations that may be necessary. We still retain the object instance in self.executed_modules.
+            obj._unload_dependencies()
+            obj.unload()
 
         return obj
 
@@ -720,7 +762,7 @@ class Modules(object):
         argv = self.argv(module, argv=self.arguments)
         argv.update(kwargs)
         argv.update(self.dependencies(module, argv['enabled']))
-        return module(**argv)
+        return module(self, **argv)
 
     def dependencies(self, module, module_enabled):
         import binwalk.modules
@@ -859,6 +901,21 @@ class Modules(object):
         else:
             raise Exception("binwalk.core.module.Modules.process_kwargs: %s has no attribute 'KWARGS'" % str(obj))
 
+    def status_server(self, port):
+        '''
+        Starts the progress bar TCP service on the specified port.
+        This service will only be started once per instance, regardless of the
+        number of times this method is invoked.
+
+        Failure to start the status service is considered non-critical; that is,
+        a warning will be displayed to the user, but normal operation will proceed.
+        '''
+        if self.status_server_started == False:
+            self.status_server_started = True # Set before the attempt, so a failed start is never retried
+            try:
+                self.status_service = binwalk.core.statuserver.StatusServer(port, self)
+            except Exception as e:
+                binwalk.core.common.warning("Failed to start status server on port %d: %s" % (port, str(e)))
 
 def process_kwargs(obj, kwargs):
     '''
@@ -869,7 +926,9 @@ def process_kwargs(obj, kwargs):
 
     Returns None.
     '''
-    return Modules().kwargs(obj, kwargs)
+    with Modules() as m:
+        kwargs = m.kwargs(obj, kwargs)
+    return kwargs
 
 def show_help(fd=sys.stdout):
     '''
@@ -879,6 +938,7 @@ def show_help(fd=sys.stdout):
 
     Returns None.
     '''
-    fd.write(Modules().help())
+    with Modules() as m:
+        fd.write(m.help())
 
 
diff --git a/src/binwalk/core/statuserver.py b/src/binwalk/core/statuserver.py
new file mode 100644
index 0000000..5d3b426
--- /dev/null
+++ b/src/binwalk/core/statuserver.py
@@ -0,0 +1,49 @@
+# Provides scan status information via a TCP socket service.
+
+import time
+import threading
+import SocketServer
+
+class StatusRequestHandler(SocketServer.BaseRequestHandler):
+    # Streams a self-overwriting, one-line scan progress message to the connected client.
+    def handle(self):
+        message_format = 'Binwalk scan progress: %3d%%   Currently at byte %d of %d total bytes in file %s'
+        last_status_message_len = 0
+
+        while True:
+            time.sleep(0.1)
+
+            try:
+                # Erase the previous message: back up over it, blank it out, then back up again
+                self.request.send('\b' * last_status_message_len)
+                self.request.send(' ' * last_status_message_len)
+                self.request.send('\b' * last_status_message_len)
+
+                percentage = ((float(self.server.binwalk.status.completed) / float(self.server.binwalk.status.total)) * 100)
+                status_message = message_format % (percentage,
+                                                   self.server.binwalk.status.completed,
+                                                   self.server.binwalk.status.total,
+                                                   self.server.binwalk.status.fp.path)
+                last_status_message_len = len(status_message)
+
+                self.request.send(status_message)
+            except KeyboardInterrupt as e:
+                raise e
+            except IOError:
+                break # Send failed (socket.error is an IOError): the client is gone, so end this handler thread
+            except Exception as e:
+                pass # Status not available yet (e.g., total is 0 or no file is open); retry on the next tick
+
+class ThreadedStatusServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
+    daemon_threads = True # Per-connection handler threads must not keep the process alive at exit
+    allow_reuse_address = True # Allow the port to be re-bound right after a previous run (SO_REUSEADDR)
+
+class StatusServer(object):
+    # Binds the status service to localhost on the given port and serves it from a background thread.
+    def __init__(self, port, binwalk):
+        self.server = ThreadedStatusServer(('127.0.0.1', port), StatusRequestHandler)
+        self.server.binwalk = binwalk # Request handlers read the scan status through self.server.binwalk
+        # Run serve_forever in a daemon thread so that the caller is not blocked
+        t = threading.Thread(target=self.server.serve_forever)
+        t.setDaemon(True)
+        t.start()
diff --git a/src/binwalk/modules/general.py b/src/binwalk/modules/general.py
index 6434007..9d25758 100644
--- a/src/binwalk/modules/general.py
+++ b/src/binwalk/modules/general.py
@@ -80,6 +80,11 @@ class General(Module):
                type=str,
                kwargs={'file_name_exclude_regex' : ""},
                description='Do not scan files whose names match this regex'),
+        Option(short='s',
+               long='status',
+               type=int,
+               kwargs={'status_server_port' : 0},
+               description='Enable the status server on the specified port'),
         Option(long=None,
                short=None,
                type=binwalk.core.common.BlockFile,
@@ -96,6 +101,7 @@ class General(Module):
         Kwarg(name='offset', default=0),
         Kwarg(name='base', default=0),
         Kwarg(name='block', default=0),
+        Kwarg(name='status_server_port', default=0),
         Kwarg(name='swap_size', default=0),
         Kwarg(name='log_file', default=None),
         Kwarg(name='csv', default=False),
@@ -113,6 +119,7 @@ class General(Module):
     PRIMARY = False
 
     def load(self):
+        self.threads_active = False
         self.target_files = []
 
         # A special case for when we're loaded into IDA
@@ -141,6 +148,9 @@ class General(Module):
             if not binwalk.core.idb.LOADED_IN_IDA:
                 sys.exit(0)
 
+        if self.status_server_port > 0:
+            self.parent.status_server(self.status_server_port)
+
     def reset(self):
         pass
 
diff --git a/src/scripts/binwalk b/src/scripts/binwalk
index 1fe1a4a..e9ebdd7 100755
--- a/src/scripts/binwalk
+++ b/src/scripts/binwalk
@@ -2,7 +2,6 @@
 
 import os
 import sys
-from threading import Thread
 
 # If installed to a custom prefix directory, binwalk may not be in
 # the default module search path(s). Try to resolve the prefix module
@@ -15,7 +14,8 @@ for _module_path in [
     # from build dir: build/scripts-3.4/ -> build/lib/
     os.path.join(_parent_dir, "lib"),
     # installed in non-default path: bin/ -> lib/python3.4/site-packages/
-    os.path.join(_parent_dir, "lib",
+    os.path.join(_parent_dir,
+                 "lib",
                  "python%d.%d" % (sys.version_info[0], sys.version_info[1]),
                  "site-packages")
 ]:
@@ -24,51 +24,27 @@ for _module_path in [
 
 import binwalk
 import binwalk.modules
-from binwalk.core.compat import user_input
-
-def display_status(m):
-    # Display the current scan progress when the enter key is pressed.
-    while True:
-        try:
-            user_input()
-            percentage = ((float(m.status.completed) / float(m.status.total)) * 100)
-            sys.stderr.write("Progress: %.2f%% (%d / %d)\n\n" % (percentage,
-                                                                 m.status.completed,
-                                                                 m.status.total))
-        except KeyboardInterrupt as e:
-            raise e
-        except Exception:
-            pass
-
-def usage(modules):
-    sys.stderr.write(modules.help())
-    sys.exit(1)
 
 def main():
-    modules = binwalk.Modules()
-
-    # Start the display_status function as a daemon thread.
-    t = Thread(target=display_status, args=(modules,))
-    t.setDaemon(True)
-    t.start()
-
-    try:
-        if len(sys.argv) == 1:
-            usage(modules)
-        # If no explicit module was enabled in the command line arguments,
-        # run again with the default signature scan explicitly enabled.
-        elif not modules.execute():
-            # Make sure the Signature module is loaded before attempting 
-            # an implicit signature scan; else, the error message received
-            # by the end user is not very helpful.
-            if hasattr(binwalk.modules, "Signature"):
-                modules.execute(*sys.argv[1:], signature=True)
-            else:
-                sys.stderr.write("Error: Signature scans not supported; ")
-                sys.stderr.write("make sure you have python-lzma installed and try again.\n")
+    with binwalk.Modules() as modules:
+        try:
+            if len(sys.argv) == 1:
+                sys.stderr.write(modules.help())
                 sys.exit(1)
-    except binwalk.ModuleException as e:
-        sys.exit(1)
+            # If no explicit module was enabled in the command line arguments,
+            # run again with the default signature scan explicitly enabled.
+            elif not modules.execute():
+                # Make sure the Signature module is loaded before attempting 
+                # an implicit signature scan; else, the error message received
+                # by the end user is not very helpful.
+                if hasattr(binwalk.modules, "Signature"):
+                    modules.execute(*sys.argv[1:], signature=True)
+                else:
+                    sys.stderr.write("Error: Signature scans not supported; ")
+                    sys.stderr.write("make sure you have python-lzma installed and try again.\n")
+                    sys.exit(2)
+        except binwalk.ModuleException as e:
+            sys.exit(3)
 
 if __name__ == '__main__':
     try:
--
libgit2 0.26.0