From 0328d2860eb8af12ff40157853d68637bb7cffdd Mon Sep 17 00:00:00 2001
From: Mike Salvatore <mike.s.salvatore@gmail.com>
Date: Fri, 17 Dec 2021 09:17:19 -0500
Subject: [PATCH] Agent: Add a RansomwarePayload that implements to the
 IPayload interface

---
 .../payload/ransomware/ransomware.py          |  3 ++-
 .../payload/ransomware/ransomware_payload.py  | 12 +++++++++++
 .../payload/ransomware/test_ransomware.py     | 21 ++++++++++---------
 3 files changed, 25 insertions(+), 11 deletions(-)
 create mode 100644 monkey/infection_monkey/payload/ransomware/ransomware_payload.py

diff --git a/monkey/infection_monkey/payload/ransomware/ransomware.py b/monkey/infection_monkey/payload/ransomware/ransomware.py
index 2f09e386f..1050bab75 100644
--- a/monkey/infection_monkey/payload/ransomware/ransomware.py
+++ b/monkey/infection_monkey/payload/ransomware/ransomware.py
@@ -1,4 +1,5 @@
 import logging
+import threading
 from pathlib import Path
 from typing import Callable, List
 
@@ -32,7 +33,7 @@ class Ransomware:
             self._target_directory / README_FILE_NAME if self._target_directory else None
         )
 
-    def run_payload(self):
+    def run(self, _: threading.Event):
         if not self._target_directory:
             return
 
diff --git a/monkey/infection_monkey/payload/ransomware/ransomware_payload.py b/monkey/infection_monkey/payload/ransomware/ransomware_payload.py
new file mode 100644
index 000000000..d785859a2
--- /dev/null
+++ b/monkey/infection_monkey/payload/ransomware/ransomware_payload.py
@@ -0,0 +1,12 @@
+import threading
+from typing import Dict
+
+from infection_monkey.payload.i_payload import IPayload
+
+from . import ransomware_builder
+
+
+class RansomwarePayload(IPayload):
+    def run(self, options: Dict, interrupt: threading.Event):
+        ransomware = ransomware_builder.build_ransomware(options)
+        ransomware.run(interrupt)
diff --git a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py
index a7e9f8a90..6024f2afd 100644
--- a/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py
+++ b/monkey/tests/unit_tests/infection_monkey/payload/ransomware/test_ransomware.py
@@ -1,3 +1,4 @@
+import threading
 from pathlib import PurePosixPath
 from unittest.mock import MagicMock
 
@@ -73,12 +74,12 @@ def test_files_selected_from_target_dir(
     ransomware_config,
     mock_file_selector,
 ):
-    ransomware.run_payload()
+    ransomware.run(threading.Event())
     mock_file_selector.assert_called_with(ransomware_config.target_directory)
 
 
 def test_all_selected_files_encrypted(ransomware_test_data, ransomware, mock_file_encryptor):
-    ransomware.run_payload()
+    ransomware.run(threading.Event())
 
     assert mock_file_encryptor.call_count == 2
     mock_file_encryptor.assert_any_call(ransomware_test_data / ALL_ZEROS_PDF)
@@ -91,7 +92,7 @@ def test_encryption_skipped_if_configured_false(
     ransomware_config.encryption_enabled = False
 
     ransomware = build_ransomware(ransomware_config)
-    ransomware.run_payload()
+    ransomware.run(threading.Event())
 
     assert mock_file_encryptor.call_count == 0
 
@@ -103,13 +104,13 @@ def test_encryption_skipped_if_no_directory(
     ransomware_config.target_directory = None
 
     ransomware = build_ransomware(ransomware_config)
-    ransomware.run_payload()
+    ransomware.run(threading.Event())
 
     assert mock_file_encryptor.call_count == 0
 
 
 def test_telemetry_success(ransomware, telemetry_messenger_spy):
-    ransomware.run_payload()
+    ransomware.run(threading.Event())
 
     assert len(telemetry_messenger_spy.telemetries) == 2
     telem_1 = telemetry_messenger_spy.telemetries[0]
@@ -131,7 +132,7 @@ def test_telemetry_failure(build_ransomware, ransomware_config, telemetry_messen
     mfs = MagicMock(return_value=[PurePosixPath(file_not_exists)])
     ransomware = build_ransomware(config=ransomware_config, file_encryptor=mfe, file_selector=mfs)
 
-    ransomware.run_payload()
+    ransomware.run(threading.Event())
     telem = telemetry_messenger_spy.telemetries[0]
 
     assert file_not_exists in telem.get_data()["files"][0]["path"]
@@ -143,7 +144,7 @@ def test_readme_false(build_ransomware, ransomware_config, mock_leave_readme):
     ransomware_config.readme_enabled = False
     ransomware = build_ransomware(ransomware_config)
 
-    ransomware.run_payload()
+    ransomware.run(threading.Event())
     mock_leave_readme.assert_not_called()
 
 
@@ -151,7 +152,7 @@ def test_readme_true(build_ransomware, ransomware_config, mock_leave_readme, ran
     ransomware_config.readme_enabled = True
     ransomware = build_ransomware(ransomware_config)
 
-    ransomware.run_payload()
+    ransomware.run(threading.Event())
     mock_leave_readme.assert_called_with(README_SRC, ransomware_test_data / README_FILE_NAME)
 
 
@@ -161,7 +162,7 @@ def test_no_readme_if_no_directory(build_ransomware, ransomware_config, mock_lea
 
     ransomware = build_ransomware(ransomware_config)
 
-    ransomware.run_payload()
+    ransomware.run(threading.Event())
     mock_leave_readme.assert_not_called()
 
 
@@ -171,4 +172,4 @@ def test_leave_readme_exceptions_handled(build_ransomware, ransomware_config):
     ransomware = build_ransomware(config=ransomware_config, leave_readme=leave_readme)
 
     # Test will fail if exception is raised and not handled
-    ransomware.run_payload()
+    ransomware.run(threading.Event())