diff --git a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py index 6f3a7030f..73e3023db 100644 --- a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py +++ b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py @@ -1,4 +1,3 @@ -from copy import deepcopy from ipaddress import IPv4Interface from logging import getLogger from typing import Union @@ -98,16 +97,16 @@ class ScanEventHandler: ) def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent): - node_connections = dict(deepcopy(src_node.tcp_connections)) - machine_connections = set(node_connections.get(target_machine.id, set())) - open_ports = [port for port, status in event.ports.items() if status == PortStatus.OPEN] + tcp_connections = set() + open_ports = (port for port, status in event.ports.items() if status == PortStatus.OPEN) for open_port in open_ports: socket_address = SocketAddress(ip=event.target, port=open_port) - machine_connections.add(socket_address) + tcp_connections.add(socket_address) - node_connections[target_machine.id] = tuple(machine_connections) - src_node.tcp_connections = node_connections - self._node_repository.upsert_node(src_node) + if tcp_connections: + self._node_repository.add_tcp_connections( + src_node.machine_id, {target_machine.id: tcp_connections} + ) def _get_source_machine(self, event: ScanEvent) -> Machine: agent = self._agent_repository.get_agent_by_id(event.source) diff --git a/monkey/monkey_island/cc/models/node.py b/monkey/monkey_island/cc/models/node.py index ada8aac19..95404f7d1 100644 --- a/monkey/monkey_island/cc/models/node.py +++ b/monkey/monkey_island/cc/models/node.py @@ -1,4 +1,4 @@ -from typing import FrozenSet, Mapping, Tuple +from typing import Dict, FrozenSet, Mapping, Tuple from pydantic import Field from typing_extensions import TypeAlias @@ -9,6 +9,7 @@ from common.types import SocketAddress from . import CommunicationType, MachineID NodeConnections: TypeAlias = Mapping[MachineID, FrozenSet[CommunicationType]] +TCPConnections: TypeAlias = Dict[MachineID, Tuple[SocketAddress, ...]] class Node(MutableInfectionMonkeyBaseModel): @@ -26,5 +27,5 @@ class Node(MutableInfectionMonkeyBaseModel): connections: NodeConnections """All outbound connections from this node to other machines""" - tcp_connections: Mapping[MachineID, Tuple[SocketAddress, ...]] = {} + tcp_connections: TCPConnections = {} """All successfull outbound TCP connections""" diff --git a/monkey/monkey_island/cc/repository/i_node_repository.py b/monkey/monkey_island/cc/repository/i_node_repository.py index 3738a1eb6..11983206c 100644 --- a/monkey/monkey_island/cc/repository/i_node_repository.py +++ b/monkey/monkey_island/cc/repository/i_node_repository.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from typing import Sequence from monkey_island.cc.models import CommunicationType, MachineID, Node +from monkey_island.cc.models.node import TCPConnections class INodeRepository(ABC): @@ -26,11 +27,12 @@ class INodeRepository(ABC): """ @abstractmethod - def upsert_node(self, node: Node): + def add_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections): """ - Store the Node object in the repository by creating a new one or updating an existing one. - :param node: Node that will be saved - :raises StorageError: If an error occurs while attempting to upsert the Node + Add TCP connections to Node + :param machine_id: Machine ID of the Node that made the connections + :param tcp_connections: TCP connections made by node + :raises StorageError: If an error occurs while attempting to add connections """ @abstractmethod diff --git a/monkey/monkey_island/cc/repository/mongo_node_repository.py b/monkey/monkey_island/cc/repository/mongo_node_repository.py index 248b0f973..599bddf61 100644 --- a/monkey/monkey_island/cc/repository/mongo_node_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_node_repository.py @@ -5,6 +5,7 @@ from pymongo import MongoClient from monkey_island.cc.models import CommunicationType, MachineID, Node +from ..models.node import TCPConnections from . import INodeRepository, RemovalError, RetrievalError, StorageError from .consts import MONGO_OBJECT_ID_KEY @@ -47,7 +48,17 @@ class MongoNodeRepository(INodeRepository): return new_node - def upsert_node(self, node: Node): + def add_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections): + node = self._get_node_by_id(machine_id) + + for target, connections in tcp_connections.items(): + if target in node.tcp_connections: + node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections}) + else: + node.tcp_connections[target] = connections + self._upsert_node(node) + + def _upsert_node(self, node: Node): try: result = self._nodes_collection.replace_one( {SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True diff --git a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py index 160d00ae1..1d8e71869 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py +++ b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_scan_event_handler.py @@ -95,6 +95,13 @@ TCP_SCAN_EVENT = TCPScanEvent( ports={22: PortStatus.OPEN, 80: PortStatus.OPEN, 8080: PortStatus.CLOSED}, ) +TCP_CONNECTIONS = { + TARGET_MACHINE_ID: ( + SocketAddress(ip=TARGET_MACHINE_IP, port=22), + SocketAddress(ip=TARGET_MACHINE_IP, port=80), + ) +} + TCP_SCAN_EVENT_CLOSED = TCPScanEvent( source=AGENT_ID, target=IPv4Address(TARGET_MACHINE_IP), @@ -220,31 +227,29 @@ def test_tcp_scan_event_target_machine_not_exists( machine_repository.upsert_machine.assert_called_with(expected_machine) -def test_handle_tcp_scan_event__tcp_connections( +def test_handle_tcp_scan_event__no_open_ports( + scan_event_handler, machine_repository, node_repository +): + event = TCP_SCAN_EVENT_CLOSED + scan_event_handler._update_nodes = MagicMock() + scan_event_handler.handle_tcp_scan_event(event) + + assert not node_repository.add_tcp_connections.called + + +def test_handle_tcp_scan_event__ports_found( scan_event_handler, machine_repository, node_repository ): event = TCP_SCAN_EVENT scan_event_handler._update_nodes = MagicMock() scan_event_handler.handle_tcp_scan_event(event) - node_passed = node_repository.upsert_node.call_args[0][0] - assert set(node_passed.tcp_connections[TARGET_MACHINE_ID]) == set( - EXPECTED_NODE.tcp_connections[TARGET_MACHINE_ID] - ) - - -def test_handle_tcp_scan_event__tcp_connections_upsert( - scan_event_handler, machine_repository, node_repository -): - event = TCP_SCAN_EVENT - node_repository.get_nodes.return_value = [deepcopy(SOURCE_NODE_2)] - scan_event_handler._update_nodes = MagicMock() - scan_event_handler.handle_tcp_scan_event(event) - - node_passed = node_repository.upsert_node.call_args[0][0] - assert set(node_passed.tcp_connections[TARGET_MACHINE_ID]) == set( - EXPECTED_NODE.tcp_connections[TARGET_MACHINE_ID] - ) + call_args = node_repository.add_tcp_connections.call_args[0] + assert call_args[0] == MACHINE_ID + assert TARGET_MACHINE_ID in call_args[1] + open_socket_addresses = call_args[1][TARGET_MACHINE_ID] + assert set(open_socket_addresses) == set(TCP_CONNECTIONS[TARGET_MACHINE_ID]) + assert len(open_socket_addresses) == len(TCP_CONNECTIONS[TARGET_MACHINE_ID]) def test_handle_tcp_scan_event__no_source( diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py index 338526d76..d4b16cd46 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_node_repository.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock import mongomock import pytest +from common.types import SocketAddress from monkey_island.cc.models import CommunicationType, Node from monkey_island.cc.repository import ( INodeRepository, @@ -12,6 +13,14 @@ from monkey_island.cc.repository import ( StorageError, ) +TARGET_MACHINE_IP = "2.2.2.2" + +TCP_CONNECTION_PORT_22 = {3: (SocketAddress(ip=TARGET_MACHINE_IP, port=22),)} +TCP_CONNECTION_PORT_80 = {3: (SocketAddress(ip=TARGET_MACHINE_IP, port=80),)} +ALL_TCP_CONNECTIONS = { + 3: (SocketAddress(ip=TARGET_MACHINE_IP, port=22), SocketAddress(ip=TARGET_MACHINE_IP, port=80)) +} + NODES = ( Node( machine_id=1, @@ -23,6 +32,7 @@ NODES = ( Node( machine_id=2, connections={1: frozenset((CommunicationType.CC,))}, + tcp_connections=TCP_CONNECTION_PORT_22, ), Node( machine_id=3, @@ -32,10 +42,7 @@ NODES = ( 5: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)), }, ), - Node( - machine_id=4, - connections={}, - ), + Node(machine_id=4, connections={}, tcp_connections=ALL_TCP_CONNECTIONS), Node( machine_id=5, connections={ @@ -201,3 +208,27 @@ def test_reset(node_repository): def test_reset__removal_error(error_raising_node_repository): with pytest.raises(RemovalError): error_raising_node_repository.reset() + + +def test_upsert_tcp_connections__empty_connections(node_repository): + node_repository.add_tcp_connections(1, TCP_CONNECTION_PORT_22) + nodes = node_repository.get_nodes() + for node in nodes: + if node.machine_id == 1: + assert node.tcp_connections == TCP_CONNECTION_PORT_22 + + +def test_upsert_tcp_connections__upsert_new_port(node_repository): + node_repository.add_tcp_connections(2, TCP_CONNECTION_PORT_80) + nodes = node_repository.get_nodes() + modified_node = [node for node in nodes if node.machine_id == 2][0] + assert set(modified_node.tcp_connections) == set(ALL_TCP_CONNECTIONS) + assert len(modified_node.tcp_connections) == len(ALL_TCP_CONNECTIONS) + + +def test_upsert_tcp_connections__port_already_present(node_repository): + node_repository.add_tcp_connections(4, TCP_CONNECTION_PORT_80) + nodes = node_repository.get_nodes() + modified_node = [node for node in nodes if node.machine_id == 4][0] + assert set(modified_node.tcp_connections) == set(ALL_TCP_CONNECTIONS) + assert len(modified_node.tcp_connections) == len(ALL_TCP_CONNECTIONS)