Commit 6fc8d05c by fwkz

Refactoring implementation.

parent 4b3231ab
from inspect import getmodule import unittest
from unittest import main, TestCase, TestSuite
from routersploit.exploits import Exploit
from routersploit.utils import iter_modules from routersploit.utils import iter_modules
class ModuleTest(TestCase): class ModuleTest(unittest.TestCase):
"""A test case that every module must pass. """A test case that every module must pass.
Attributes: Attributes:
...@@ -13,48 +11,37 @@ class ModuleTest(TestCase): ...@@ -13,48 +11,37 @@ class ModuleTest(TestCase):
metadata (Dict): The info associated with the module. metadata (Dict): The info associated with the module.
""" """
def test_has_exploit(self): def __init__(self, methodName='runTest', module=None):
self.assertIsInstance(self.module, Exploit) super(ModuleTest, self).__init__(methodName)
self.module = module
def test_has_metadata(self): def __str__(self):
self.assertIsInstance(self.metadata, dict) return " ".join([super(ModuleTest, self).__str__(), self.module.__module__])
def test_legal_metadata_keys(self): @property
def module_metadata(self):
return getattr(self.module, "_{}__info__".format(self.module.__name__))
legal_keys = set([ def test_required_metadata(self):
required_metadata = (
"name", "name",
"description", "description",
"devices", "devices",
"authors", "authors",
"references"]) "references"
)
self.assertTrue(set(self.metadata.keys()).issubset(legal_keys)) self.assertItemsEqual(required_metadata, self.module_metadata.keys())
def load_tests(loader, tests, pattern): def load_tests(loader, tests, pattern):
"""Map every module to a test case, and group them into a suite.""" """ Map every module to a test case, and group them into a suite. """
suite = TestSuite()
for m in iter_modules():
class ParametrizedModuleTest(ModuleTest):
# bind module
module = m()
@property
def metadata(self):
return getattr(self.module, "_{}__info__".format(self.module.__class__.__name__))
def shortDescription(self):
# provide the module name in the test description
return getmodule(self.module).__name__
# add the tests from this test case
suite.addTests(loader.loadTestsFromTestCase(ParametrizedModuleTest))
suite = unittest.TestSuite()
test_names = loader.getTestCaseNames(ModuleTest)
for module in iter_modules():
suite.addTests([ModuleTest(name, module) for name in test_names])
return suite return suite
if __name__ == '__main__': if __name__ == '__main__':
main() unittest.main()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment