UT: Add tests for ScanEventHandler
This commit is contained in:
parent
e4aec8b9a3
commit
2686a7a4ee
|
@ -0,0 +1,336 @@
|
||||||
|
from ipaddress import IPv4Address, IPv4Interface
|
||||||
|
from itertools import count
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from common import OperatingSystem
|
||||||
|
from common.agent_events import PingScanEvent, TCPScanEvent
|
||||||
|
from common.types import PortStatus, SocketAddress
|
||||||
|
from monkey_island.cc.agent_event_handlers import ScanEventHandler
|
||||||
|
from monkey_island.cc.models import Agent, CommunicationType, Machine
|
||||||
|
from monkey_island.cc.repository import (
|
||||||
|
IAgentRepository,
|
||||||
|
IMachineRepository,
|
||||||
|
INodeRepository,
|
||||||
|
RetrievalError,
|
||||||
|
StorageError,
|
||||||
|
UnknownRecordError,
|
||||||
|
)
|
||||||
|
|
||||||
|
SEED_ID = 99
|
||||||
|
AGENT_ID = UUID("1d8ce743-a0f4-45c5-96af-91106529d3e2")
|
||||||
|
MACHINE_ID = 11
|
||||||
|
CC_SERVER = SocketAddress(ip="10.10.10.100", port="5000")
|
||||||
|
AGENT = Agent(id=AGENT_ID, machine_id=MACHINE_ID, start_time=0, parent_id=None, cc_server=CC_SERVER)
|
||||||
|
SOURCE_MACHINE = Machine(
|
||||||
|
id=MACHINE_ID,
|
||||||
|
hardware_id=5,
|
||||||
|
network_interfaces=[IPv4Interface("10.10.10.99/24")],
|
||||||
|
)
|
||||||
|
TARGET_MACHINE = Machine(
|
||||||
|
id=33,
|
||||||
|
hardware_id=9,
|
||||||
|
network_interfaces=[IPv4Interface("10.10.10.1/24")],
|
||||||
|
)
|
||||||
|
|
||||||
|
PING_SCAN_EVENT = PingScanEvent(
|
||||||
|
source=AGENT_ID,
|
||||||
|
target=IPv4Address("10.10.10.1"),
|
||||||
|
response_received=True,
|
||||||
|
os=OperatingSystem.LINUX,
|
||||||
|
)
|
||||||
|
|
||||||
|
PING_SCAN_EVENT_NO_RESPONSE = PingScanEvent(
|
||||||
|
source=AGENT_ID,
|
||||||
|
target=IPv4Address("10.10.10.1"),
|
||||||
|
response_received=False,
|
||||||
|
os=OperatingSystem.LINUX,
|
||||||
|
)
|
||||||
|
|
||||||
|
PING_SCAN_EVENT_NO_OS = PingScanEvent(
|
||||||
|
source=AGENT_ID,
|
||||||
|
target=IPv4Address("10.10.10.1"),
|
||||||
|
response_received=True,
|
||||||
|
os=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TCP_SCAN_EVENT = TCPScanEvent(
|
||||||
|
source=AGENT_ID,
|
||||||
|
target=IPv4Address("10.10.10.1"),
|
||||||
|
ports={22: PortStatus.OPEN, 8080: PortStatus.CLOSED},
|
||||||
|
)
|
||||||
|
|
||||||
|
TCP_SCAN_EVENT_CLOSED = TCPScanEvent(
|
||||||
|
source=AGENT_ID,
|
||||||
|
target=IPv4Address("10.10.10.1"),
|
||||||
|
ports={145: PortStatus.CLOSED, 8080: PortStatus.CLOSED},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_repository() -> IAgentRepository:
|
||||||
|
agent_repository = MagicMock(spec=IAgentRepository)
|
||||||
|
agent_repository.upsert_agent = MagicMock()
|
||||||
|
agent_repository.get_agent_by_id = MagicMock(return_value=AGENT)
|
||||||
|
return agent_repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def machine_repository() -> IMachineRepository:
|
||||||
|
machine_repository = MagicMock(spec=IMachineRepository)
|
||||||
|
machine_repository.get_new_id = MagicMock(side_effect=count(SEED_ID))
|
||||||
|
machine_repository.upsert_machine = MagicMock()
|
||||||
|
return machine_repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def node_repository() -> INodeRepository:
|
||||||
|
node_repository = MagicMock(spec=INodeRepository)
|
||||||
|
node_repository.upsert_communication = MagicMock()
|
||||||
|
return node_repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def scan_event_handler(agent_repository, machine_repository, node_repository):
|
||||||
|
return ScanEventHandler(agent_repository, machine_repository, node_repository)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def handle_ping_scan_event(scan_event_handler):
|
||||||
|
return scan_event_handler.handle_ping_scan_event
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def handle_tcp_scan_event(scan_event_handler):
|
||||||
|
return scan_event_handler.handle_tcp_scan_event
|
||||||
|
|
||||||
|
|
||||||
|
machines = {MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
|
||||||
|
machines_by_id = {MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
|
||||||
|
machines_by_ip = {
|
||||||
|
IPv4Address("10.10.10.99"): [SOURCE_MACHINE],
|
||||||
|
IPv4Address("10.10.10.1"): [TARGET_MACHINE],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[SOURCE_MACHINE.id, TARGET_MACHINE.id])
|
||||||
|
def machine_id(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
def machine_from_id(id: int):
|
||||||
|
return machines_by_id[id]
|
||||||
|
|
||||||
|
|
||||||
|
def machines_from_ip(ip: IPv4Address):
|
||||||
|
return machines_by_ip[ip]
|
||||||
|
|
||||||
|
|
||||||
|
class error_machine_by_id:
|
||||||
|
"""Raise an error if the machine with the called ID matches the stored ID"""
|
||||||
|
|
||||||
|
def __init__(self, id: int, error):
|
||||||
|
self.id = id
|
||||||
|
self.error = error
|
||||||
|
|
||||||
|
def __call__(self, id: int):
|
||||||
|
if id == self.id:
|
||||||
|
raise self.error
|
||||||
|
else:
|
||||||
|
return machine_from_id(id)
|
||||||
|
|
||||||
|
|
||||||
|
class error_machine_by_ip:
|
||||||
|
"""Raise an error if the machine with the called IP matches the stored ID"""
|
||||||
|
|
||||||
|
def __init__(self, id: int, error):
|
||||||
|
self.id = id
|
||||||
|
self.error = error
|
||||||
|
|
||||||
|
def __call__(self, ip: IPv4Address):
|
||||||
|
print(f"IP is: {ip}")
|
||||||
|
machines = machines_from_ip(ip)
|
||||||
|
if machines[0].id == self.id:
|
||||||
|
print(f"Raise error: {self.error}")
|
||||||
|
raise self.error
|
||||||
|
else:
|
||||||
|
print(f"Return machine: {machines}")
|
||||||
|
return machines
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event,handler",
|
||||||
|
[(PING_SCAN_EVENT, "handle_ping_scan_event"), (TCP_SCAN_EVENT, "handle_tcp_scan_event")],
|
||||||
|
)
|
||||||
|
def test_scan_event_handler__target_machine_not_exists(
|
||||||
|
event, handler, machine_repository: IMachineRepository, request
|
||||||
|
):
|
||||||
|
machine_repository.get_machine_by_id = MagicMock(side_effect=machine_from_id)
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
|
||||||
|
|
||||||
|
handler = request.getfixturevalue(handler)
|
||||||
|
handler(event)
|
||||||
|
|
||||||
|
expected_machine = Machine(id=SEED_ID, network_interfaces=[IPv4Interface(event.target)])
|
||||||
|
if event == PING_SCAN_EVENT:
|
||||||
|
expected_machine.operating_system = event.os
|
||||||
|
machine_repository.upsert_machine.assert_called_with(expected_machine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event,handler",
|
||||||
|
[(PING_SCAN_EVENT, "handle_ping_scan_event"), (TCP_SCAN_EVENT, "handle_tcp_scan_event")],
|
||||||
|
)
|
||||||
|
def test_scan_event_handler__upserts_node(
|
||||||
|
event,
|
||||||
|
handler,
|
||||||
|
machine_repository: IMachineRepository,
|
||||||
|
node_repository: INodeRepository,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
machine_repository.get_machine_by_id = MagicMock(side_effect=machine_from_id)
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(return_value=[TARGET_MACHINE])
|
||||||
|
|
||||||
|
handler = request.getfixturevalue(handler)
|
||||||
|
handler(event)
|
||||||
|
|
||||||
|
node_repository.upsert_communication.assert_called_with(
|
||||||
|
SOURCE_MACHINE.id, TARGET_MACHINE.id, CommunicationType.SCANNED
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event,handler",
|
||||||
|
[(PING_SCAN_EVENT, "handle_ping_scan_event"), (TCP_SCAN_EVENT, "handle_tcp_scan_event")],
|
||||||
|
)
|
||||||
|
def test_scan_event_handler__node_not_upserted_if_no_matching_agent(
|
||||||
|
event,
|
||||||
|
handler,
|
||||||
|
agent_repository: IAgentRepository,
|
||||||
|
machine_repository: IMachineRepository,
|
||||||
|
node_repository: INodeRepository,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
agent_repository.get_agent_by_id = MagicMock(side_effect=UnknownRecordError)
|
||||||
|
machine_repository.get_machine_by_id = MagicMock(return_value=TARGET_MACHINE)
|
||||||
|
|
||||||
|
handler = request.getfixturevalue(handler)
|
||||||
|
handler(event)
|
||||||
|
|
||||||
|
assert not node_repository.upsert_communication.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event,handler",
|
||||||
|
[(PING_SCAN_EVENT, "handle_ping_scan_event"), (TCP_SCAN_EVENT, "handle_tcp_scan_event")],
|
||||||
|
)
|
||||||
|
def test_scan_event_handler__node_not_upserted_if_machine_retrievalerror(
|
||||||
|
event,
|
||||||
|
handler,
|
||||||
|
machine_repository: IMachineRepository,
|
||||||
|
node_repository: INodeRepository,
|
||||||
|
request,
|
||||||
|
machine_id,
|
||||||
|
):
|
||||||
|
machine_repository.get_machine_by_id = MagicMock(
|
||||||
|
side_effect=error_machine_by_id(machine_id, RetrievalError)
|
||||||
|
)
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(
|
||||||
|
side_effect=error_machine_by_ip(machine_id, RetrievalError)
|
||||||
|
)
|
||||||
|
|
||||||
|
handler = request.getfixturevalue(handler)
|
||||||
|
handler(event)
|
||||||
|
|
||||||
|
assert not node_repository.upsert_communication.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event,handler",
|
||||||
|
[
|
||||||
|
(PING_SCAN_EVENT_NO_OS, "handle_ping_scan_event"),
|
||||||
|
(TCP_SCAN_EVENT_CLOSED, "handle_tcp_scan_event"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_scan_event_handler__machine_not_upserted(
|
||||||
|
event, handler, machine_repository: IMachineRepository, request
|
||||||
|
):
|
||||||
|
machine_repository.get_machine_by_id = MagicMock(side_effect=machine_from_id)
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(side_effect=machines_from_ip)
|
||||||
|
|
||||||
|
handler = request.getfixturevalue(handler)
|
||||||
|
handler(event)
|
||||||
|
|
||||||
|
assert not machine_repository.upsert_machine.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event,handler",
|
||||||
|
[(PING_SCAN_EVENT, "handle_ping_scan_event"), (TCP_SCAN_EVENT, "handle_tcp_scan_event")],
|
||||||
|
)
|
||||||
|
def test_scan_event_handler__machine_not_upserted_if_existing_machine_has_os(
|
||||||
|
event, handler, machine_repository: IMachineRepository, request
|
||||||
|
):
|
||||||
|
machine_with_os = TARGET_MACHINE
|
||||||
|
machine_with_os.operating_system = OperatingSystem.WINDOWS
|
||||||
|
machine_repository.get_machine_by_ip = MagicMock(return_value=machine_with_os)
|
||||||
|
|
||||||
|
handler = request.getfixturevalue(handler)
|
||||||
|
handler(event)
|
||||||
|
|
||||||
|
assert not machine_repository.upsert_machine.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event,handler",
|
||||||
|
[(PING_SCAN_EVENT, "handle_ping_scan_event"), (TCP_SCAN_EVENT, "handle_tcp_scan_event")],
|
||||||
|
)
|
||||||
|
def test_scan_event_handler__node_not_upserted_if_machine_storageerror(
|
||||||
|
event,
|
||||||
|
handler,
|
||||||
|
machine_repository: IMachineRepository,
|
||||||
|
node_repository: INodeRepository,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
if event == PING_SCAN_EVENT:
|
||||||
|
target_machine = TARGET_MACHINE
|
||||||
|
target_machine.operating_system = None
|
||||||
|
|
||||||
|
machine_repository.get_machine_by_id = MagicMock(side_effect=machine_from_id)
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
|
||||||
|
if event == PING_SCAN_EVENT:
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(side_effect=machines_from_ip)
|
||||||
|
machine_repository.upsert_machine = MagicMock(side_effect=StorageError)
|
||||||
|
|
||||||
|
handler = request.getfixturevalue(handler)
|
||||||
|
handler(event)
|
||||||
|
|
||||||
|
assert not node_repository.upsert_communication.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"event,handler",
|
||||||
|
[
|
||||||
|
(PING_SCAN_EVENT_NO_RESPONSE, "handle_ping_scan_event"),
|
||||||
|
(TCP_SCAN_EVENT_CLOSED, "handle_tcp_scan_event"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_scan_event_handler__failed_scan(
|
||||||
|
event,
|
||||||
|
handler,
|
||||||
|
machine_repository: IMachineRepository,
|
||||||
|
node_repository: INodeRepository,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
machine_repository.upsert_machine = MagicMock(side_effect=StorageError)
|
||||||
|
machine_repository.get_machine_by_id = MagicMock(side_effect=machine_from_id)
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(side_effect=machines_from_ip)
|
||||||
|
|
||||||
|
handler = request.getfixturevalue(handler)
|
||||||
|
handler(event)
|
||||||
|
|
||||||
|
assert not node_repository.upsert_communication.called
|
||||||
|
assert not machine_repository.upsert_machine.called
|
Loading…
Reference in New Issue