diff --git a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py index 272749807..3c43ddd92 100644 --- a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py +++ b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py @@ -90,15 +90,15 @@ class ScanEventHandler: self._machine_repository.upsert_machine(machine) def _update_network_services(self, target: Machine, event: TCPScanEvent): - for port in self._get_open_ports(event): - socket_addr = SocketAddress(ip=event.target, port=port) - target.network_services[socket_addr] = NetworkService.UNKNOWN - - self._machine_repository.upsert_machine(target) + network_services = { + SocketAddress(ip=event.target, port=port): NetworkService.UNKNOWN + for port in self._get_open_ports(event) + } + self._machine_repository.upsert_network_services(target.id, network_services) @staticmethod def _get_open_ports(event: TCPScanEvent) -> List[int]: - return (port for port, status in event.ports.items() if status == PortStatus.OPEN) + return [port for port, status in event.ports.items() if status == PortStatus.OPEN] def _update_nodes(self, target_machine: Machine, event: ScanEvent): src_machine = self._get_source_machine(event) diff --git a/monkey/monkey_island/cc/models/machine.py b/monkey/monkey_island/cc/models/machine.py index a4dfbc982..c5144fd97 100644 --- a/monkey/monkey_island/cc/models/machine.py +++ b/monkey/monkey_island/cc/models/machine.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Interface from typing import Any, Dict, Mapping, Optional, Sequence from pydantic import Field, validator +from typing_extensions import TypeAlias from common import OperatingSystem from common.base_models import MutableInfectionMonkeyBaseModel, MutableInfectionMonkeyModelConfig @@ -11,6 +12,8 @@ from common.types import HardwareID, NetworkService, SocketAddress from . import MachineID +NetworkServices: TypeAlias = Dict[SocketAddress, NetworkService] + def _serialize_network_services(machine_dict: Dict, *, default): machine_dict["network_services"] = { @@ -61,7 +64,7 @@ class Machine(MutableInfectionMonkeyBaseModel): hostname: str = "" """The hostname of the machine""" - network_services: Mapping[SocketAddress, NetworkService] = Field(default_factory=dict) + network_services: NetworkServices = Field(default_factory=dict) """All network services found running on the machine""" _make_immutable_sequence = validator("network_interfaces", pre=True, allow_reuse=True)( diff --git a/monkey/monkey_island/cc/repository/i_machine_repository.py b/monkey/monkey_island/cc/repository/i_machine_repository.py index 7cea0bb02..fb69cddad 100644 --- a/monkey/monkey_island/cc/repository/i_machine_repository.py +++ b/monkey/monkey_island/cc/repository/i_machine_repository.py @@ -4,6 +4,7 @@ from typing import Sequence from common.types import HardwareID from monkey_island.cc.models import Machine, MachineID +from monkey_island.cc.models.machine import NetworkServices class IMachineRepository(ABC): @@ -29,6 +30,16 @@ class IMachineRepository(ABC): :raises StorageError: If an error occurs while attempting to store the `Machine` """ + @abstractmethod + def upsert_network_services(self, machine_id: MachineID, services: NetworkServices): + """ + Add/update network services on the machine + :param machine_id: ID of machine with services to be updated + :param services: Network services to be added to machine model + :raises UnknownRecordError: If the Machine is not found + :raises StorageError: If an error occurs while attempting to add/store the services + """ + @abstractmethod def get_machine_by_id(self, machine_id: MachineID) -> Machine: """ diff --git a/monkey/monkey_island/cc/repository/mongo_machine_repository.py b/monkey/monkey_island/cc/repository/mongo_machine_repository.py index a3297449e..00856139e 100644 --- a/monkey/monkey_island/cc/repository/mongo_machine_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_machine_repository.py @@ -7,8 +7,10 @@ from pymongo import MongoClient from common.types import HardwareID from monkey_island.cc.models import Machine, MachineID +from ..models.machine import NetworkServices from . import IMachineRepository, RemovalError, RetrievalError, StorageError, UnknownRecordError from .consts import MONGO_OBJECT_ID_KEY +from .utils import DOT_REPLACEMENT, mongo_dot_decoder, mongo_dot_encoder class MongoMachineRepository(IMachineRepository): @@ -32,8 +34,9 @@ class MongoMachineRepository(IMachineRepository): def upsert_machine(self, machine: Machine): try: + machine_dict = mongo_dot_encoder(machine.dict(simplify=True)) result = self._machines_collection.replace_one( - {"id": machine.id}, machine.dict(simplify=True), upsert=True + {"id": machine.id}, machine_dict, upsert=True ) except Exception as err: raise StorageError(f'Error updating machine with ID "{machine.id}": {err}') @@ -44,8 +47,19 @@ class MongoMachineRepository(IMachineRepository): f"but no machines were inserted" ) + def upsert_network_services(self, machine_id: MachineID, services: NetworkServices): + machine = self.get_machine_by_id(machine_id) + try: + machine.network_services.update(services) + self.upsert_machine(machine) + except Exception as err: + raise StorageError(f"Failed upserting the machine or adding services") from err + def get_machine_by_id(self, machine_id: MachineID) -> Machine: - return self._find_one("id", machine_id) + machine = self._find_one("id", machine_id) + if not machine: + raise UnknownRecordError(f"Machine with id {machine_id} not found") + return machine def get_machine_by_hardware_id(self, hardware_id: HardwareID) -> Machine: return self._find_one("hardware_id", hardware_id) @@ -61,6 +75,7 @@ class MongoMachineRepository(IMachineRepository): if machine_dict is None: raise UnknownRecordError(f'Unknown machine with "{key} == {search_value}"') + machine_dict = mongo_dot_decoder(machine_dict) return Machine(**machine_dict) def get_machines(self) -> Sequence[Machine]: @@ -69,10 +84,10 @@ class MongoMachineRepository(IMachineRepository): except Exception as err: raise RetrievalError(f"Error retrieving machines: {err}") - return [Machine(**m) for m in cursor] + return [Machine(**mongo_dot_decoder(m)) for m in cursor] def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]: - ip_regex = "^" + str(ip).replace(".", "\\.") + "\\/.*$" + ip_regex = "^" + str(ip).replace(".", DOT_REPLACEMENT) + "\\/.*$" query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}} try: @@ -80,7 +95,7 @@ class MongoMachineRepository(IMachineRepository): except Exception as err: raise RetrievalError(f'Error retrieving machines with ip "{ip}": {err}') - machines = [Machine(**m) for m in cursor] + machines = [Machine(**mongo_dot_decoder(m)) for m in cursor] if len(machines) == 0: raise UnknownRecordError(f'No machines found with IP "{ip}"') diff --git a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py index 673f8293c..55b8f1bce 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py +++ b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py @@ -8,7 +8,7 @@ import pytest from common import OperatingSystem from common.agent_events import PingScanEvent, TCPScanEvent -from common.types import PortStatus, SocketAddress +from common.types import NetworkService, PortStatus, SocketAddress from monkey_island.cc.agent_event_handlers import ScanEventHandler from monkey_island.cc.models import Agent, CommunicationType, Machine, Node from monkey_island.cc.repository import ( @@ -74,6 +74,11 @@ TCP_SCAN_EVENT = TCPScanEvent( ports={22: PortStatus.OPEN, 80: PortStatus.OPEN, 8080: PortStatus.CLOSED}, ) +EXPECTED_NETWORK_SERVICES = { + SocketAddress(ip=TARGET_MACHINE_IP, port=22): NetworkService.UNKNOWN, + SocketAddress(ip=TARGET_MACHINE_IP, port=80): NetworkService.UNKNOWN, +} + TCP_CONNECTIONS = { TARGET_MACHINE_ID: ( SocketAddress(ip=TARGET_MACHINE_IP, port=22), @@ -382,3 +387,11 @@ def test_failed_scan( assert not node_repository.upsert_communication.called assert not machine_repository.upsert_machine.called + + +def test_network_services_handling(scan_event_handler, machine_repository): + scan_event_handler.handle_tcp_scan_event(TCP_SCAN_EVENT) + + machine_repository.upsert_network_services.assert_called_with( + TARGET_MACHINE_ID, EXPECTED_NETWORK_SERVICES + ) diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py index 7e4ef93a8..d03cdb0ce 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_machine_repository.py @@ -6,6 +6,7 @@ import mongomock import pytest from common import OperatingSystem +from common.types import NetworkService, SocketAddress from monkey_island.cc.models import Machine from monkey_island.cc.repository import ( IMachineRepository, @@ -15,6 +16,7 @@ from monkey_island.cc.repository import ( StorageError, UnknownRecordError, ) +from monkey_island.cc.repository.utils import mongo_dot_encoder MACHINES = ( Machine( @@ -32,6 +34,10 @@ MACHINES = ( operating_system=OperatingSystem.WINDOWS, operating_system_version="eXtra Problems", hostname="hal", + network_services={ + SocketAddress(ip="192.168.1.11", port=80): NetworkService.UNKNOWN, + SocketAddress(ip="192.168.1.12", port=80): NetworkService.UNKNOWN, + }, ), Machine( id=3, @@ -40,6 +46,10 @@ MACHINES = ( operating_system=OperatingSystem.WINDOWS, operating_system_version="Vista", hostname="smith", + network_services={ + SocketAddress(ip="192.168.1.11", port=80): NetworkService.UNKNOWN, + SocketAddress(ip="192.168.1.11", port=22): NetworkService.UNKNOWN, + }, ), Machine( id=4, @@ -51,11 +61,24 @@ MACHINES = ( ), ) +SERVICES_TO_ADD = { + SocketAddress(ip="192.168.1.11", port=80): NetworkService.UNKNOWN, + SocketAddress(ip="192.168.1.11", port=22): NetworkService.UNKNOWN, +} + +EXPECTED_SERVICES_1 = EXPECTED_SERVICES_3 = SERVICES_TO_ADD +EXPECTED_SERVICES_2 = { + **SERVICES_TO_ADD, + SocketAddress(ip="192.168.1.12", port=80): NetworkService.UNKNOWN, +} + @pytest.fixture def mongo_client() -> mongomock.MongoClient: client = mongomock.MongoClient() - client.monkey_island.machines.insert_many((m.dict(simplify=True) for m in MACHINES)) + client.monkey_island.machines.insert_many( + (mongo_dot_encoder(m.dict(simplify=True)) for m in MACHINES) + ) return client @@ -264,3 +287,27 @@ def test_usable_after_reset(machine_repository): def test_reset__removal_error(error_raising_machine_repository): with pytest.raises(RemovalError): error_raising_machine_repository.reset() + + +@pytest.mark.parametrize( + "machine_id, expected_services", + [ + (MACHINES[0].id, EXPECTED_SERVICES_1), + (MACHINES[1].id, EXPECTED_SERVICES_2), + (MACHINES[2].id, EXPECTED_SERVICES_3), + ], +) +def test_service_upsert(machine_id, expected_services, machine_repository): + machine_repository.upsert_network_services(machine_id, SERVICES_TO_ADD) + assert machine_repository.get_machine_by_id(machine_id).network_services == expected_services + + +def test_service_upsert__machine_not_found(machine_repository): + with pytest.raises(UnknownRecordError): + machine_repository.upsert_network_services(machine_id=999, services=SERVICES_TO_ADD) + + +def test_service_upsert__error_on_storage(machine_repository): + malformed_services = 3 + with pytest.raises(StorageError): + machine_repository.upsert_network_services(MACHINES[0].id, malformed_services)