Refactor PBA to use generic plugin

This commit is contained in:
Daniel Goldberg 2019-10-30 15:14:27 +02:00
parent 1df3e30003
commit 9c40b4a022
2 changed files with 6 additions and 32 deletions

View File

@ -1,11 +0,0 @@
from os.path import dirname, basename, isfile, join
import glob
def get_pba_files():
"""
Gets all files under current directory(/actions)
:return: list of all files without .py ending
"""
files = glob.glob(join(dirname(__file__), "*.py"))
return [basename(f)[:-3] for f in files if isfile(f) and not f.endswith('__init__.py')]

View File

@ -1,9 +1,8 @@
import logging import logging
import inspect
import importlib
from infection_monkey.post_breach.pba import PBA
from infection_monkey.post_breach.actions import get_pba_files
from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.environment import is_windows_os
from infection_monkey.utils.load_plugins import get_instances
from infection_monkey.post_breach.pba import PBA
import infection_monkey.post_breach.actions
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -16,6 +15,7 @@ class PostBreach(object):
""" """
This class handles post breach actions execution This class handles post breach actions execution
""" """
def __init__(self): def __init__(self):
self.os_is_linux = not is_windows_os() self.os_is_linux = not is_windows_os()
self.pba_list = self.config_to_pba_list() self.pba_list = self.config_to_pba_list()
@ -38,20 +38,5 @@ class PostBreach(object):
Passes config to each post breach action class and aggregates results into a list. Passes config to each post breach action class and aggregates results into a list.
:return: A list of PBA objects. :return: A list of PBA objects.
""" """
pba_list = [] return get_instances(infection_monkey.post_breach.actions.__package__,
pba_files = get_pba_files() infection_monkey.post_breach.actions.__file__, PBA)
# Go through all of files in ./actions
for pba_file in pba_files:
# Import module from that file
module = importlib.import_module(PATH_TO_ACTIONS + pba_file)
# Get all classes in a module
pba_classes = [m[1] for m in inspect.getmembers(module, inspect.isclass)
if ((m[1].__module__ == module.__name__) and issubclass(m[1], PBA))]
# Get post breach action object from class
for pba_class in pba_classes:
LOG.debug("Checking if should run PBA {}".format(pba_class.__name__))
if pba_class.should_run(pba_class.__name__):
pba = pba_class()
pba_list.append(pba)
LOG.debug("Added PBA {} to PBA list".format(pba_class.__name__))
return pba_list