Agent: Remove disused Plugin abstract class

This commit is contained in:
Mike Salvatore 2022-03-25 08:34:45 -04:00
parent 4316329384
commit 703dc315bc
9 changed files with 0 additions and 180 deletions

View File

@ -1,91 +0,0 @@
import glob
import importlib
import inspect
import logging
from abc import ABCMeta, abstractmethod
from os.path import basename, dirname, isfile, join
from typing import Callable, Sequence, Type, TypeVar
logger = logging.getLogger(__name__)
def _get_candidate_files(base_package_file):
files = glob.glob(join(dirname(base_package_file), "*.py"))
return [basename(f)[:-3] for f in files if isfile(f) and not f.endswith("__init__.py")]
PluginType = TypeVar("PluginType", bound="Plugin")
class Plugin(metaclass=ABCMeta):
@staticmethod
@abstractmethod
def should_run(class_name: str) -> bool:
raise NotImplementedError()
@classmethod
def get_classes(cls) -> Sequence[Callable]:
"""
Returns the class objects from base_package_spec
base_package name and file must refer to the same package otherwise bad results
:return: A list of parent_class classes.
"""
objects = []
candidate_files = _get_candidate_files(cls.base_package_file())
logger.info(
"looking for classes of type {} in {}".format(cls.__name__, cls.base_package_name())
)
# Go through all of files
for file in candidate_files:
# Import module from that file
module = importlib.import_module("." + file, cls.base_package_name())
# Get all classes in a module
# m[1] because return object is (name,class)
classes = [
m[1]
for m in inspect.getmembers(module, inspect.isclass)
if ((m[1].__module__ == module.__name__) and issubclass(m[1], cls))
]
# Get object from class
for class_object in classes:
logger.debug("Checking if should run object {}".format(class_object.__name__))
try:
if class_object.should_run(class_object.__name__):
objects.append(class_object)
logger.debug("Added {} to list".format(class_object.__name__))
except Exception as e:
logger.warning(
"Exception {} when checking if {} should run".format(
str(e), class_object.__name__
)
)
return objects
@classmethod
def get_instances(cls) -> Sequence[Type[PluginType]]:
"""
Returns the type objects from base_package_spec.
base_package name and file must refer to the same package otherwise bad results
:return: A list of parent_class objects.
"""
class_objects = cls.get_classes()
instances = []
for class_object in class_objects:
try:
instance = class_object()
instances.append(instance)
except Exception as e:
logger.warning(
"Exception {} when initializing {}".format(str(e), class_object.__name__)
)
return instances
@staticmethod
@abstractmethod
def base_package_file():
pass
@staticmethod
@abstractmethod
def base_package_name():
pass

View File

@ -1,7 +0,0 @@
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import ( # noqa: F401, E501
PluginTester,
)
class SomeDummyPlugin:
pass

View File

@ -1,6 +0,0 @@
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import PluginTester
class BadPluginInit(PluginTester):
def __init__(self):
raise Exception("TestException")

View File

@ -1,10 +0,0 @@
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import PluginTester
class BadInit(PluginTester):
def __init__(self):
raise Exception("TestException")
class ProperClass(PluginTester):
pass

View File

@ -1,23 +0,0 @@
import tests.unit_tests.infection_monkey.utils.plugins.pluginTests
from infection_monkey.utils.plugins.plugin import Plugin
class PluginTester(Plugin):
classes_to_load = []
@staticmethod
def should_run(class_name):
"""
Decides if post breach action is enabled in config
:return: True if it needs to be ran, false otherwise
"""
return class_name in PluginTester.classes_to_load
@staticmethod
def base_package_file():
return tests.unit_tests.infection_monkey.utils.plugins.pluginTests.__file__
@staticmethod
def base_package_name():
return tests.unit_tests.infection_monkey.utils.plugins.pluginTests.__package__

View File

@ -1,5 +0,0 @@
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import PluginTester
class PluginWorking(PluginTester):
pass

View File

@ -1,38 +0,0 @@
from unittest import TestCase
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.BadImport import SomeDummyPlugin
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.BadInit import BadPluginInit
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.ComboFile import (
BadInit,
ProperClass,
)
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginTestClass import PluginTester
from tests.unit_tests.infection_monkey.utils.plugins.pluginTests.PluginWorking import PluginWorking
class TestPlugin(TestCase):
def test_combo_file(self):
PluginTester.classes_to_load = [BadInit.__name__, ProperClass.__name__]
to_init = PluginTester.get_classes()
self.assertEqual(len(to_init), 2)
objects = PluginTester.get_instances()
self.assertEqual(len(objects), 1)
def test_bad_init(self):
PluginTester.classes_to_load = [BadPluginInit.__name__]
to_init = PluginTester.get_classes()
self.assertEqual(len(to_init), 1)
objects = PluginTester.get_instances()
self.assertEqual(len(objects), 0)
def test_bad_import(self):
PluginTester.classes_to_load = [SomeDummyPlugin.__name__]
to_init = PluginTester.get_classes()
self.assertEqual(len(to_init), 0)
def test_flow(self):
PluginTester.classes_to_load = [PluginWorking.__name__]
to_init = PluginTester.get_classes()
self.assertEqual(len(to_init), 1)
objects = PluginTester.get_instances()
self.assertEqual(len(objects), 1)