Basic tests for plugins

This commit is contained in:
Daniel Goldberg 2019-11-21 19:06:20 +02:00
parent 55d7eba2d8
commit 0f8e8925b3
6 changed files with 70 additions and 0 deletions

View File

@ -0,0 +1,5 @@
from infection_monkey.utils.plugins.pluginTests.PluginTestClass import TestPlugin
class SomeDummyPlugin:
pass

View File

@ -0,0 +1,7 @@
from infection_monkey.utils.plugins.pluginTests.PluginTestClass import TestPlugin
class badPluginInit(TestPlugin):
def __init__(self):
raise Exception("TestException")

View File

@ -0,0 +1,22 @@
from infection_monkey.utils.plugins.plugin import Plugin
import infection_monkey.utils.plugins.pluginTests
class TestPlugin(Plugin):
classes_to_load = []
@staticmethod
def should_run(class_name):
"""
Decides if post breach action is enabled in config
:return: True if it needs to be ran, false otherwise
"""
return class_name in TestPlugin.classes_to_load
@staticmethod
def base_package_file():
return infection_monkey.utils.plugins.pluginTests.__file__
@staticmethod
def base_package_name():
return infection_monkey.utils.plugins.pluginTests.__package__

View File

@ -0,0 +1,5 @@
from infection_monkey.utils.plugins.pluginTests.PluginTestClass import TestPlugin
class pluginWorking(TestPlugin):
pass

View File

@ -0,0 +1,31 @@
from unittest import TestCase
from infection_monkey.utils.plugins.pluginTests.PluginWorking import pluginWorking
from infection_monkey.utils.plugins.pluginTests.BadImport import SomeDummyPlugin
from infection_monkey.utils.plugins.pluginTests.BadInit import badPluginInit
from infection_monkey.utils.plugins.pluginTests.PluginTestClass import TestPlugin
class PluginTester(TestCase):
def setUp(self):
pass
def test_bad_init(self):
TestPlugin.classes_to_load = [badPluginInit.__name__]
to_init = TestPlugin.get_classes()
self.assertEqual(len(to_init), 1)
objects = TestPlugin.get_instances()
self.assertEqual(len(objects), 0)
def test_bad_import(self):
TestPlugin.classes_to_load = [SomeDummyPlugin.__name__]
to_init = TestPlugin.get_classes()
self.assertEqual(len(to_init), 0)
def test_flow(self):
TestPlugin.classes_to_load = [pluginWorking.__name__]
to_init = TestPlugin.get_classes()
self.assertEqual(len(to_init), 1)
objects = TestPlugin.get_instances()
self.assertEqual(len(objects), 1)