From 0a11d34fb7c3597c3678e497aaca02f4487fce6e Mon Sep 17 00:00:00 2001 From: Shreya Malviya Date: Fri, 30 Sep 2022 14:20:22 +0530 Subject: [PATCH] UT: Assert mock_agent_event_queue.publish's call args in test_tcp_scanner.py --- .../network_scanning/test_tcp_scanner.py | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py index 578f57508..c96728658 100644 --- a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py +++ b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py @@ -2,9 +2,12 @@ from unittest.mock import MagicMock import pytest +from common.agent_events import TCPScanEvent from common.types import PortStatus +from infection_monkey.i_puppet import PortScanData from infection_monkey.network_scanning import scan_tcp_ports from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN +from infection_monkey.utils.ids import get_agent_id PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222] @@ -19,13 +22,28 @@ def patch_check_tcp_ports(monkeypatch, open_ports_data): ) +HOST_IP = "127.0.0.1" + + +def _get_tcp_scan_event(port_scan_data: PortScanData): + port_statuses = {} + for port, data in port_scan_data.items(): + port_statuses[port] = data.status + + return TCPScanEvent( + source=get_agent_id(), + target=HOST_IP, + ports=port_statuses, + ) + + @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) def test_tcp_successful( monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue ): closed_ports = [8080, 143, 445] - port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) + port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue) assert len(port_scan_data) == 6 for port in open_ports_data.keys(): @@ -38,14 +56,17 @@ def test_tcp_successful( assert port_scan_data[port].status == PortStatus.CLOSED assert port_scan_data[port].banner is None + event = _get_tcp_scan_event(port_scan_data) + assert mock_agent_event_queue.publish.call_count == 1 + mock_agent_event_queue.publish.assert_called_with(event) @pytest.mark.parametrize("open_ports_data", [{}]) def test_tcp_empty_response( monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue ): - port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) + port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue) assert len(port_scan_data) == 6 for port in open_ports_data: @@ -53,17 +74,24 @@ def test_tcp_empty_response( assert port_scan_data[port].status == PortStatus.CLOSED assert port_scan_data[port].banner is None + event = _get_tcp_scan_event(port_scan_data) + assert mock_agent_event_queue.publish.call_count == 1 + mock_agent_event_queue.publish.assert_called_with(event) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) def test_tcp_no_ports_to_scan( monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue ): - port_scan_data = scan_tcp_ports("127.0.0.1", [], 0, mock_agent_event_queue) + port_scan_data = scan_tcp_ports(HOST_IP, [], 0, mock_agent_event_queue) assert len(port_scan_data) == 0 + + event = _get_tcp_scan_event(port_scan_data) + assert mock_agent_event_queue.publish.call_count == 1 + mock_agent_event_queue.publish.assert_called_with(event) def test_exception_handling(monkeypatch, mock_agent_event_queue):