UT: Assert mock_agent_event_queue.publish's call args in test_tcp_scanner.py

This commit is contained in:
Shreya Malviya 2022-09-30 14:20:22 +05:30
parent 0bf9309e07
commit 0a11d34fb7
1 changed files with 31 additions and 3 deletions

View File

@ -2,9 +2,12 @@ from unittest.mock import MagicMock
import pytest import pytest
from common.agent_events import TCPScanEvent
from common.types import PortStatus from common.types import PortStatus
from infection_monkey.i_puppet import PortScanData
from infection_monkey.network_scanning import scan_tcp_ports from infection_monkey.network_scanning import scan_tcp_ports
from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN
from infection_monkey.utils.ids import get_agent_id
PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222] PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222]
@ -19,13 +22,28 @@ def patch_check_tcp_ports(monkeypatch, open_ports_data):
) )
HOST_IP = "127.0.0.1"
def _get_tcp_scan_event(port_scan_data: PortScanData):
port_statuses = {}
for port, data in port_scan_data.items():
port_statuses[port] = data.status
return TCPScanEvent(
source=get_agent_id(),
target=HOST_IP,
ports=port_statuses,
)
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_successful( def test_tcp_successful(
monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
): ):
closed_ports = [8080, 143, 445] closed_ports = [8080, 143, 445]
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue)
assert len(port_scan_data) == 6 assert len(port_scan_data) == 6
for port in open_ports_data.keys(): for port in open_ports_data.keys():
@ -38,14 +56,17 @@ def test_tcp_successful(
assert port_scan_data[port].status == PortStatus.CLOSED assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None assert port_scan_data[port].banner is None
event = _get_tcp_scan_event(port_scan_data)
assert mock_agent_event_queue.publish.call_count == 1 assert mock_agent_event_queue.publish.call_count == 1
mock_agent_event_queue.publish.assert_called_with(event)
@pytest.mark.parametrize("open_ports_data", [{}]) @pytest.mark.parametrize("open_ports_data", [{}])
def test_tcp_empty_response( def test_tcp_empty_response(
monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
): ):
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue)
assert len(port_scan_data) == 6 assert len(port_scan_data) == 6
for port in open_ports_data: for port in open_ports_data:
@ -53,17 +74,24 @@ def test_tcp_empty_response(
assert port_scan_data[port].status == PortStatus.CLOSED assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None assert port_scan_data[port].banner is None
event = _get_tcp_scan_event(port_scan_data)
assert mock_agent_event_queue.publish.call_count == 1 assert mock_agent_event_queue.publish.call_count == 1
mock_agent_event_queue.publish.assert_called_with(event)
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_no_ports_to_scan( def test_tcp_no_ports_to_scan(
monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
): ):
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0, mock_agent_event_queue) port_scan_data = scan_tcp_ports(HOST_IP, [], 0, mock_agent_event_queue)
assert len(port_scan_data) == 0 assert len(port_scan_data) == 0
event = _get_tcp_scan_event(port_scan_data)
assert mock_agent_event_queue.publish.call_count == 1 assert mock_agent_event_queue.publish.call_count == 1
mock_agent_event_queue.publish.assert_called_with(event)
def test_exception_handling(monkeypatch, mock_agent_event_queue): def test_exception_handling(monkeypatch, mock_agent_event_queue):