Agent: Remove disused Plugin abstract class
This commit is contained in:
parent
4316329384
commit
703dc315bc
|
@ -1,91 +0,0 @@
|
|||
import glob
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from os.path import basename, dirname, isfile, join
|
||||
from typing import Callable, Sequence, Type, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_candidate_files(base_package_file):
|
||||
files = glob.glob(join(dirname(base_package_file), "*.py"))
|
||||
return [basename(f)[:-3] for f in files if isfile(f) and not f.endswith("__init__.py")]
|
||||
|
||||
|
||||
PluginType = TypeVar("PluginType", bound="Plugin")
|
||||
|
||||
|
||||
class Plugin(metaclass=ABCMeta):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def should_run(class_name: str) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def get_classes(cls) -> Sequence[Callable]:
|
||||
"""
|
||||
Returns the class objects from base_package_spec
|
||||
base_package name and file must refer to the same package otherwise bad results
|
||||
:return: A list of parent_class classes.
|
||||
"""
|
||||
objects = []
|
||||
candidate_files = _get_candidate_files(cls.base_package_file())
|
||||
logger.info(
|
||||
"looking for classes of type {} in {}".format(cls.__name__, cls.base_package_name())
|
||||
)
|
||||
# Go through all of files
|
||||
for file in candidate_files:
|
||||
# Import module from that file
|
||||
module = importlib.import_module("." + file, cls.base_package_name())
|
||||
# Get all classes in a module
|
||||
# m[1] because return object is (name,class)
|
||||
classes = [
|
||||
m[1]
|
||||
for m in inspect.getmembers(module, inspect.isclass)
|
||||
if ((m[1].__module__ == module.__name__) and issubclass(m[1], cls))
|
||||
]
|
||||
# Get object from class
|
||||
for class_object in classes:
|
||||
logger.debug("Checking if should run object {}".format(class_object.__name__))
|
||||
try:
|
||||
if class_object.should_run(class_object.__name__):
|
||||
objects.append(class_object)
|
||||
logger.debug("Added {} to list".format(class_object.__name__))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Exception {} when checking if {} should run".format(
|
||||
str(e), class_object.__name__
|
||||
)
|
||||
)
|
||||
return objects
|
||||
|
||||
@classmethod
|
||||
def get_instances(cls) -> Sequence[Type[PluginType]]:
|
||||
"""
|
||||
Returns the type objects from base_package_spec.
|
||||
base_package name and file must refer to the same package otherwise bad results
|
||||
:return: A list of parent_class objects.
|
||||
"""
|
||||
class_objects = cls.get_classes()
|
||||
instances = []
|
||||
for class_object in class_objects:
|
||||
try:
|
||||
instance = class_object()
|
||||
instances.append(instance)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Exception {} when initializing {}".format(str(e), class_object.__name__)
|
||||
)
|
||||
return instances
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def base_package_file():
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def base_package_name():
|
||||
pass
|
|
@ -1,7 +0,0 @@
|
|||
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import ( # noqa: F401, E501
|
||||
PluginTester,
|
||||
)
|
||||
|
||||
|
||||
class SomeDummyPlugin:
|
||||
pass
|
|
@ -1,6 +0,0 @@
|
|||
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import PluginTester
|
||||
|
||||
|
||||
class BadPluginInit(PluginTester):
|
||||
def __init__(self):
|
||||
raise Exception("TestException")
|
|
@ -1,10 +0,0 @@
|
|||
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import PluginTester
|
||||
|
||||
|
||||
class BadInit(PluginTester):
|
||||
def __init__(self):
|
||||
raise Exception("TestException")
|
||||
|
||||
|
||||
class ProperClass(PluginTester):
|
||||
pass
|
|
@ -1,23 +0,0 @@
|
|||
import tests.unit_tests.infection_monkey.utils.plugins.pluginTests
|
||||
|
||||
from infection_monkey.utils.plugins.plugin import Plugin
|
||||
|
||||
|
||||
class PluginTester(Plugin):
|
||||
classes_to_load = []
|
||||
|
||||
@staticmethod
|
||||
def should_run(class_name):
|
||||
"""
|
||||
Decides if post breach action is enabled in config
|
||||
:return: True if it needs to be ran, false otherwise
|
||||
"""
|
||||
return class_name in PluginTester.classes_to_load
|
||||
|
||||
@staticmethod
|
||||
def base_package_file():
|
||||
return tests.unit_tests.infection_monkey.utils.plugins.pluginTests.__file__
|
||||
|
||||
@staticmethod
|
||||
def base_package_name():
|
||||
return tests.unit_tests.infection_monkey.utils.plugins.pluginTests.__package__
|
|
@ -1,5 +0,0 @@
|
|||
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import PluginTester
|
||||
|
||||
|
||||
class PluginWorking(PluginTester):
|
||||
pass
|
|
@ -1,38 +0,0 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.BadImport import SomeDummyPlugin
|
||||
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.BadInit import BadPluginInit
|
||||
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.ComboFile import (
|
||||
BadInit,
|
||||
ProperClass,
|
||||
)
|
||||
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import PluginTester
|
||||
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginWorking import PluginWorking
|
||||
|
||||
|
||||
class TestPlugin(TestCase):
|
||||
def test_combo_file(self):
|
||||
PluginTester.classes_to_load = [BadInit.__name__, ProperClass.__name__]
|
||||
to_init = PluginTester.get_classes()
|
||||
self.assertEqual(len(to_init), 2)
|
||||
objects = PluginTester.get_instances()
|
||||
self.assertEqual(len(objects), 1)
|
||||
|
||||
def test_bad_init(self):
|
||||
PluginTester.classes_to_load = [BadPluginInit.__name__]
|
||||
to_init = PluginTester.get_classes()
|
||||
self.assertEqual(len(to_init), 1)
|
||||
objects = PluginTester.get_instances()
|
||||
self.assertEqual(len(objects), 0)
|
||||
|
||||
def test_bad_import(self):
|
||||
PluginTester.classes_to_load = [SomeDummyPlugin.__name__]
|
||||
to_init = PluginTester.get_classes()
|
||||
self.assertEqual(len(to_init), 0)
|
||||
|
||||
def test_flow(self):
|
||||
PluginTester.classes_to_load = [PluginWorking.__name__]
|
||||
to_init = PluginTester.get_classes()
|
||||
self.assertEqual(len(to_init), 1)
|
||||
objects = PluginTester.get_instances()
|
||||
self.assertEqual(len(objects), 1)
|
Loading…
Reference in New Issue