diff --git a/monkey/infection_monkey/network_scanning/tcp_scanner.py b/monkey/infection_monkey/network_scanning/tcp_scanner.py index af1efb182..9b6d581e2 100644 --- a/monkey/infection_monkey/network_scanning/tcp_scanner.py +++ b/monkey/infection_monkey/network_scanning/tcp_scanner.py @@ -1,15 +1,18 @@ import logging import select import socket -import time +from ipaddress import IPv4Address from pprint import pformat +from time import sleep, time from typing import Collection, Dict, Iterable, Mapping, Tuple +from common.agent_events import TCPScanEvent from common.event_queue import IAgentEventQueue from common.types import PortStatus from common.utils import Timer from infection_monkey.i_puppet import PortScanData from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service +from infection_monkey.utils.ids import get_agent_id logger = logging.getLogger(__name__) @@ -21,21 +24,41 @@ def scan_tcp_ports( host: str, ports_to_scan: Collection[int], timeout: float, agent_event_queue: IAgentEventQueue ) -> Dict[int, PortScanData]: try: - return _scan_tcp_ports(host, ports_to_scan, timeout) + return _scan_tcp_ports(host, ports_to_scan, timeout, agent_event_queue) except Exception: logger.exception("Unhandled exception occurred while trying to scan tcp ports") return EMPTY_PORT_SCAN -def _scan_tcp_ports(host: str, ports_to_scan: Collection[int], timeout: float): - open_ports = _check_tcp_ports(host, ports_to_scan, timeout) +def _scan_tcp_ports( + host: str, ports_to_scan: Collection[int], timeout: float, agent_event_queue: IAgentEventQueue +) -> Dict[int, PortScanData]: + event_timestamp, open_ports = _check_tcp_ports(host, ports_to_scan, timeout) - return _build_port_scan_data(ports_to_scan, open_ports) + port_scan_data = _build_port_scan_data(ports_to_scan, open_ports) + + tcp_scan_event = _generate_tcp_scan_event(host, port_scan_data, event_timestamp) + agent_event_queue.publish(tcp_scan_event) + + return port_scan_data + + +def _generate_tcp_scan_event( + host: str, port_scan_data: Dict[int, PortScanData], event_timestamp: float +): + port_statuses = {port: psd.status for port, psd in port_scan_data.items()} + + return TCPScanEvent( + source=get_agent_id(), + target=IPv4Address(host), + timestamp=event_timestamp, + ports=port_statuses, + ) def _build_port_scan_data( ports_to_scan: Iterable[int], open_ports: Mapping[int, str] -) -> Mapping[int, PortScanData]: +) -> Dict[int, PortScanData]: port_scan_data = {} for port in ports_to_scan: if port in open_ports: @@ -55,7 +78,7 @@ def _get_closed_port_data(port: int) -> PortScanData: def _check_tcp_ports( ip: str, ports_to_scan: Collection[int], timeout: float = DEFAULT_TIMEOUT -) -> Mapping[int, str]: +) -> Tuple[float, Dict[int, str]]: """ Checks whether any of the given ports are open on a target IP. :param ip: IP of host to attack @@ -72,6 +95,7 @@ def _check_tcp_ports( connected_ports = set() open_ports = {} + event_timestamp = time() try: logger.debug( "Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan)) @@ -100,7 +124,7 @@ def _check_tcp_ports( while (not timer.is_expired()) and sockets_to_try: # The call to select() may return sockets that are writeable but not actually # connected. Adding this sleep prevents excessive looping. - time.sleep(min(POLL_INTERVAL, timer.time_remaining)) + sleep(min(POLL_INTERVAL, timer.time_remaining)) sock_objects = [s[1] for s in sockets_to_try] @@ -136,7 +160,7 @@ def _check_tcp_ports( _clean_up_sockets(possible_ports, connected_ports) - return open_ports + return event_timestamp, open_ports def _clean_up_sockets( diff --git a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py index 3e0c38d7e..2fd67590d 100644 --- a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py +++ b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py @@ -2,20 +2,46 @@ from unittest.mock import MagicMock import pytest +from common.agent_events import TCPScanEvent from common.types import PortStatus +from infection_monkey.i_puppet import PortScanData from infection_monkey.network_scanning import scan_tcp_ports from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN +from infection_monkey.utils.ids import get_agent_id PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222] OPEN_PORTS_DATA = {22: "SSH-banner", 80: "", 2222: "SSH2-banner"} +TIMESTAMP = 123.321 + +HOST_IP = "127.0.0.1" + + +@pytest.fixture(autouse=True) +def patch_timestamp(monkeypatch): + monkeypatch.setattr( + "infection_monkey.network_scanning.tcp_scanner.time", + lambda: TIMESTAMP, + ) + @pytest.fixture def patch_check_tcp_ports(monkeypatch, open_ports_data): monkeypatch.setattr( "infection_monkey.network_scanning.tcp_scanner._check_tcp_ports", - lambda *_: open_ports_data, + lambda *_: (TIMESTAMP, open_ports_data), + ) + + +def _get_tcp_scan_event(port_scan_data: PortScanData): + port_statuses = {port: psd.status for port, psd in port_scan_data.items()} + + return TCPScanEvent( + source=get_agent_id(), + target=HOST_IP, + timestamp=TIMESTAMP, + ports=port_statuses, ) @@ -25,7 +51,7 @@ def test_tcp_successful( ): closed_ports = [8080, 143, 445] - port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) + port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue) assert len(port_scan_data) == 6 for port in open_ports_data.keys(): @@ -38,12 +64,17 @@ def test_tcp_successful( assert port_scan_data[port].status == PortStatus.CLOSED assert port_scan_data[port].banner is None + event = _get_tcp_scan_event(port_scan_data) + + assert mock_agent_event_queue.publish.call_count == 1 + mock_agent_event_queue.publish.assert_called_with(event) + @pytest.mark.parametrize("open_ports_data", [{}]) def test_tcp_empty_response( monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue ): - port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) + port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue) assert len(port_scan_data) == 6 for port in open_ports_data: @@ -51,15 +82,25 @@ def test_tcp_empty_response( assert port_scan_data[port].status == PortStatus.CLOSED assert port_scan_data[port].banner is None + event = _get_tcp_scan_event(port_scan_data) + + assert mock_agent_event_queue.publish.call_count == 1 + mock_agent_event_queue.publish.assert_called_with(event) + @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) def test_tcp_no_ports_to_scan( monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue ): - port_scan_data = scan_tcp_ports("127.0.0.1", [], 0, mock_agent_event_queue) + port_scan_data = scan_tcp_ports(HOST_IP, [], 0, mock_agent_event_queue) assert len(port_scan_data) == 0 + event = _get_tcp_scan_event(port_scan_data) + + assert mock_agent_event_queue.publish.call_count == 1 + mock_agent_event_queue.publish.assert_called_with(event) + def test_exception_handling(monkeypatch, mock_agent_event_queue): monkeypatch.setattr( @@ -67,3 +108,4 @@ def test_exception_handling(monkeypatch, mock_agent_event_queue): MagicMock(side_effect=Exception), ) assert scan_tcp_ports("abc", [123], 123, mock_agent_event_queue) == EMPTY_PORT_SCAN + assert mock_agent_event_queue.publish.call_count == 0