diff --git a/monkey/infection_monkey/payload/ransomware/ransomware.py b/monkey/infection_monkey/payload/ransomware/ransomware.py index 2f09e386f..1050bab75 100644 --- a/monkey/infection_monkey/payload/ransomware/ransomware.py +++ b/monkey/infection_monkey/payload/ransomware/ransomware.py @@ -1,4 +1,5 @@ import logging +import threading from pathlib import Path from typing import Callable, List @@ -32,7 +33,7 @@ class Ransomware: self._target_directory / README_FILE_NAME if self._target_directory else None ) - def run_payload(self): + def run(self, _: threading.Event): if not self._target_directory: return diff --git a/monkey/infection_monkey/payload/ransomware/ransomware_payload.py b/monkey/infection_monkey/payload/ransomware/ransomware_payload.py new file mode 100644 index 000000000..d785859a2 --- /dev/null +++ b/monkey/infection_monkey/payload/ransomware/ransomware_payload.py @@ -0,0 +1,12 @@ +import threading +from typing import Dict + +from infection_monkey.payload.i_payload import IPayload + +from . import ransomware_builder + + +class RansomwarePayload(IPayload): + def run(self, options: Dict, interrupt: threading.Event): + ransomware = ransomware_builder.build_ransomware(options) + ransomware.run(interrupt) diff --git a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py index a7e9f8a90..6024f2afd 100644 --- a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py +++ b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py @@ -1,3 +1,4 @@ +import threading from pathlib import PurePosixPath from unittest.mock import MagicMock @@ -73,12 +74,12 @@ def test_files_selected_from_target_dir( ransomware_config, mock_file_selector, ): - ransomware.run_payload() + ransomware.run(threading.Event()) mock_file_selector.assert_called_with(ransomware_config.target_directory) def test_all_selected_files_encrypted(ransomware_test_data, ransomware, mock_file_encryptor): - ransomware.run_payload() + ransomware.run(threading.Event()) assert mock_file_encryptor.call_count == 2 mock_file_encryptor.assert_any_call(ransomware_test_data / ALL_ZEROS_PDF) @@ -91,7 +92,7 @@ def test_encryption_skipped_if_configured_false( ransomware_config.encryption_enabled = False ransomware = build_ransomware(ransomware_config) - ransomware.run_payload() + ransomware.run(threading.Event()) assert mock_file_encryptor.call_count == 0 @@ -103,13 +104,13 @@ def test_encryption_skipped_if_no_directory( ransomware_config.target_directory = None ransomware = build_ransomware(ransomware_config) - ransomware.run_payload() + ransomware.run(threading.Event()) assert mock_file_encryptor.call_count == 0 def test_telemetry_success(ransomware, telemetry_messenger_spy): - ransomware.run_payload() + ransomware.run(threading.Event()) assert len(telemetry_messenger_spy.telemetries) == 2 telem_1 = telemetry_messenger_spy.telemetries[0] @@ -131,7 +132,7 @@ def test_telemetry_failure(build_ransomware, ransomware_config, telemetry_messen mfs = MagicMock(return_value=[PurePosixPath(file_not_exists)]) ransomware = build_ransomware(config=ransomware_config, file_encryptor=mfe, file_selector=mfs) - ransomware.run_payload() + ransomware.run(threading.Event()) telem = telemetry_messenger_spy.telemetries[0] assert file_not_exists in telem.get_data()["files"][0]["path"] @@ -143,7 +144,7 @@ def test_readme_false(build_ransomware, ransomware_config, mock_leave_readme): ransomware_config.readme_enabled = False ransomware = build_ransomware(ransomware_config) - ransomware.run_payload() + ransomware.run(threading.Event()) mock_leave_readme.assert_not_called() @@ -151,7 +152,7 @@ def test_readme_true(build_ransomware, ransomware_config, mock_leave_readme, ran ransomware_config.readme_enabled = True ransomware = build_ransomware(ransomware_config) - ransomware.run_payload() + ransomware.run(threading.Event()) mock_leave_readme.assert_called_with(README_SRC, ransomware_test_data / README_FILE_NAME) @@ -161,7 +162,7 @@ def test_no_readme_if_no_directory(build_ransomware, ransomware_config, mock_lea ransomware = build_ransomware(ransomware_config) - ransomware.run_payload() + ransomware.run(threading.Event()) mock_leave_readme.assert_not_called() @@ -171,4 +172,4 @@ def test_leave_readme_exceptions_handled(build_ransomware, ransomware_config): ransomware = build_ransomware(config=ransomware_config, leave_readme=leave_readme) # Test will fail if exception is raised and not handled - ransomware.run_payload() + ransomware.run(threading.Event())