diff --git a/monkey/infection_monkey/network_scanning/tcp_scanner.py b/monkey/infection_monkey/network_scanning/tcp_scanner.py index 70ef3cc0f..af1efb182 100644 --- a/monkey/infection_monkey/network_scanning/tcp_scanner.py +++ b/monkey/infection_monkey/network_scanning/tcp_scanner.py @@ -5,6 +5,7 @@ import time from pprint import pformat from typing import Collection, Dict, Iterable, Mapping, Tuple +from common.event_queue import IAgentEventQueue from common.types import PortStatus from common.utils import Timer from infection_monkey.i_puppet import PortScanData @@ -17,7 +18,7 @@ EMPTY_PORT_SCAN = {-1: PortScanData(-1, PortStatus.CLOSED, None, None)} def scan_tcp_ports( - host: str, ports_to_scan: Collection[int], timeout: float + host: str, ports_to_scan: Collection[int], timeout: float, agent_event_queue: IAgentEventQueue ) -> Dict[int, PortScanData]: try: return _scan_tcp_ports(host, ports_to_scan, timeout) diff --git a/monkey/infection_monkey/puppet/puppet.py b/monkey/infection_monkey/puppet/puppet.py index 65d4cfd0c..0b4b3fdc8 100644 --- a/monkey/infection_monkey/puppet/puppet.py +++ b/monkey/infection_monkey/puppet/puppet.py @@ -48,7 +48,7 @@ class Puppet(IPuppet): def scan_tcp_ports( self, host: str, ports: Sequence[int], timeout: float = CONNECTION_TIMEOUT ) -> Dict[int, PortScanData]: - return network_scanning.scan_tcp_ports(host, ports, timeout) + return network_scanning.scan_tcp_ports(host, ports, timeout, self._agent_event_queue) def fingerprint( self, diff --git a/monkey/tests/unit_tests/infection_monkey/network_scanning/conftest.py b/monkey/tests/unit_tests/infection_monkey/network_scanning/conftest.py new file mode 100644 index 000000000..5b614fbda --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/network_scanning/conftest.py @@ -0,0 +1,10 @@ +from unittest.mock import MagicMock + +import pytest + +from common.event_queue import IAgentEventQueue + + +@pytest.fixture +def mock_agent_event_queue() -> IAgentEventQueue: + return MagicMock(spec=IAgentEventQueue) diff --git a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping_scanner.py b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping_scanner.py index 102df5dd1..ab6c76b5a 100644 --- a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping_scanner.py +++ b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping_scanner.py @@ -7,7 +7,6 @@ import pytest import infection_monkey.network_scanning.ping_scanner # noqa: F401 from common import OperatingSystem from common.agent_events import PingScanEvent -from common.event_queue import IAgentEventQueue from common.types import PingScanData from infection_monkey.network_scanning import ping from infection_monkey.network_scanning.ping_scanner import EMPTY_PING_SCAN @@ -99,11 +98,6 @@ def set_os_windows(monkeypatch): monkeypatch.setattr("sys.platform", "win32") -@pytest.fixture -def mock_agent_event_queue(): - return MagicMock(spec=IAgentEventQueue) - - HOST_IP = "192.168.1.1" TIMEOUT = 1.0 diff --git a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py index bd5a66cb6..3e0c38d7e 100644 --- a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py +++ b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py @@ -20,10 +20,12 @@ def patch_check_tcp_ports(monkeypatch, open_ports_data): @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) -def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data): +def test_tcp_successful( + monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue +): closed_ports = [8080, 143, 445] - port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) + port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) assert len(port_scan_data) == 6 for port in open_ports_data.keys(): @@ -38,8 +40,10 @@ def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data): @pytest.mark.parametrize("open_ports_data", [{}]) -def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data): - port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) +def test_tcp_empty_response( + monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue +): + port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) assert len(port_scan_data) == 6 for port in open_ports_data: @@ -49,15 +53,17 @@ def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) -def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data): - port_scan_data = scan_tcp_ports("127.0.0.1", [], 0) +def test_tcp_no_ports_to_scan( + monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue +): + port_scan_data = scan_tcp_ports("127.0.0.1", [], 0, mock_agent_event_queue) assert len(port_scan_data) == 0 -def test_exception_handling(monkeypatch): +def test_exception_handling(monkeypatch, mock_agent_event_queue): monkeypatch.setattr( "infection_monkey.network_scanning.tcp_scanner._scan_tcp_ports", MagicMock(side_effect=Exception), ) - assert scan_tcp_ports("abc", [123], 123) == EMPTY_PORT_SCAN + assert scan_tcp_ports("abc", [123], 123, mock_agent_event_queue) == EMPTY_PORT_SCAN