diff --git a/monkey/infection_monkey/puppet/puppet.py b/monkey/infection_monkey/puppet/puppet.py index 80293d44d..65d4cfd0c 100644 --- a/monkey/infection_monkey/puppet/puppet.py +++ b/monkey/infection_monkey/puppet/puppet.py @@ -4,6 +4,7 @@ from typing import Dict, Iterable, Sequence from common.common_consts.timeouts import CONNECTION_TIMEOUT from common.credentials import Credentials +from common.event_queue import IAgentEventQueue from common.types import PingScanData from infection_monkey import network_scanning from infection_monkey.i_puppet import ( @@ -24,8 +25,9 @@ logger = logging.getLogger() class Puppet(IPuppet): - def __init__(self) -> None: + def __init__(self, agent_event_queue: IAgentEventQueue) -> None: self._plugin_registry = PluginRegistry() + self._agent_event_queue = agent_event_queue def load_plugin(self, plugin_name: str, plugin: object, plugin_type: PluginType) -> None: self._plugin_registry.load_plugin(plugin_name, plugin, plugin_type) @@ -41,7 +43,7 @@ class Puppet(IPuppet): return pba.run(options) def ping(self, host: str, timeout: float = CONNECTION_TIMEOUT) -> PingScanData: - return network_scanning.ping(host, timeout) + return network_scanning.ping(host, timeout, self._agent_event_queue) def scan_tcp_ports( self, host: str, ports: Sequence[int], timeout: float = CONNECTION_TIMEOUT diff --git a/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py b/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py index a0df06fd6..1b0f6b9ee 100644 --- a/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py +++ b/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py @@ -1,13 +1,14 @@ import threading from unittest.mock import MagicMock +from common.event_queue import IAgentEventQueue from common.types import PingScanData from infection_monkey.i_puppet import PluginType from infection_monkey.puppet.puppet import EMPTY_FINGERPRINT, Puppet def test_puppet_run_payload_success(): - p = Puppet() + p = Puppet(agent_event_queue=MagicMock(spec=IAgentEventQueue)) payload = MagicMock() payload_name = "PayloadOne" @@ -19,7 +20,7 @@ def test_puppet_run_payload_success(): def test_puppet_run_multiple_payloads(): - p = Puppet() + p = Puppet(agent_event_queue=MagicMock(spec=IAgentEventQueue)) payload_1 = MagicMock() payload1_name = "PayloadOne" @@ -45,6 +46,6 @@ def test_puppet_run_multiple_payloads(): def test_fingerprint_exception_handling(monkeypatch): - p = Puppet() + p = Puppet(agent_event_queue=MagicMock(spec=IAgentEventQueue)) p._plugin_registry.get_plugin = MagicMock(side_effect=Exception) assert p.fingerprint("", "", PingScanData("windows", False), {}, {}) == EMPTY_FINGERPRINT