diff --git a/monkey/monkey_island/cc/agent_event_handlers/node_update_facade.py b/monkey/monkey_island/cc/agent_event_handlers/node_update_facade.py index 9659444a2..bdcfa629c 100644 --- a/monkey/monkey_island/cc/agent_event_handlers/node_update_facade.py +++ b/monkey/monkey_island/cc/agent_event_handlers/node_update_facade.py @@ -1,11 +1,15 @@ +from functools import lru_cache from ipaddress import IPv4Address, IPv4Interface +from common.agent_events import AbstractAgentEvent +from common.types import AgentID, MachineID from monkey_island.cc.models import Machine -from monkey_island.cc.repository import IMachineRepository, UnknownRecordError +from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError class NodeUpdateFacade: - def __init__(self, machine_repository: IMachineRepository): + def __init__(self, agent_repository: IAgentRepository, machine_repository: IMachineRepository): + self._agent_repository = agent_repository self._machine_repository = machine_repository def get_or_create_target_machine(self, target: IPv4Address): @@ -19,3 +23,11 @@ class NodeUpdateFacade: ) self._machine_repository.upsert_machine(machine) return machine + + def get_event_source_machine(self, event: AbstractAgentEvent) -> Machine: + machine_id = self._get_machine_id_from_agent_id(event.source) + return self._machine_repository.get_machine_by_id(machine_id) + + @lru_cache(maxsize=None) + def _get_machine_id_from_agent_id(self, agent_id: AgentID) -> MachineID: + return self._agent_repository.get_agent_by_id(agent_id).machine_id diff --git a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py index 7b35f3c6d..45ebafaad 100644 --- a/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py +++ b/monkey/monkey_island/cc/agent_event_handlers/scan_event_handler.py @@ -33,7 +33,7 @@ class ScanEventHandler: machine_repository: IMachineRepository, node_repository: INodeRepository, ): - self._node_update_facade = NodeUpdateFacade(machine_repository) + self._node_update_facade = NodeUpdateFacade(agent_repository, machine_repository) self._agent_repository = agent_repository self._machine_repository = machine_repository self._node_repository = node_repository diff --git a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_node_update_facade.py b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_node_update_facade.py index f8e6387ab..1b854e0da 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_node_update_facade.py +++ b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_node_update_facade.py @@ -1,45 +1,87 @@ from ipaddress import IPv4Address, IPv4Interface from unittest.mock import MagicMock +from uuid import UUID import pytest +from common.agent_events import AbstractAgentEvent +from common.types import AgentID, MachineID, SocketAddress from monkey_island.cc.agent_event_handlers.node_update_facade import NodeUpdateFacade -from monkey_island.cc.models import Machine -from monkey_island.cc.repository import IMachineRepository, UnknownRecordError +from monkey_island.cc.models import Agent, Machine +from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError + + +class TestEvent(AbstractAgentEvent): + success: bool + SEED_ID = 99 IP_ADDRESS = IPv4Address("10.10.10.99") -EXISTING_MACHINE = Machine( - id=1, +SOURCE_MACHINE_ID = 1 +SOURCE_MACHINE = Machine( + id=SOURCE_MACHINE_ID, hardware_id=5, network_interfaces=[IPv4Interface(IP_ADDRESS)], ) +SOURCE_AGENT_ID = UUID("655fd01c-5eec-4e42-b6e3-1fb738c2978d") +SOURCE_AGENT = Agent( + id=SOURCE_AGENT_ID, + machine_id=SOURCE_MACHINE_ID, + start_time=0, + parent_id=None, + cc_server=(SocketAddress(ip="10.10.10.10", port=5000)), +) + EXPECTED_CREATED_MACHINE = Machine( id=SEED_ID, network_interfaces=[IPv4Interface(IP_ADDRESS)], ) +TEST_EVENT = TestEvent(source=SOURCE_AGENT_ID, success=True) + + +@pytest.fixture +def agent_repository() -> IAgentRepository: + def get_agent_by_id(agent_id: AgentID) -> Agent: + if agent_id == SOURCE_AGENT_ID: + return SOURCE_AGENT + + raise UnknownRecordError() + + agent_repository = MagicMock(spec=IAgentRepository) + agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id) + return agent_repository + @pytest.fixture def machine_repository() -> IMachineRepository: + def get_machine_by_id(machine_id: MachineID) -> Machine: + if machine_id == SOURCE_MACHINE_ID: + return SOURCE_MACHINE + + raise UnknownRecordError() + machine_repository = MagicMock(spec=IMachineRepository) machine_repository.get_new_id = MagicMock(return_value=SEED_ID) + machine_repository.get_machine_by_id = MagicMock(side_effect=get_machine_by_id) return machine_repository @pytest.fixture -def node_update_facade(machine_repository) -> NodeUpdateFacade: - return NodeUpdateFacade(machine_repository) +def node_update_facade( + agent_repository: IAgentRepository, machine_repository: IMachineRepository +) -> NodeUpdateFacade: + return NodeUpdateFacade(agent_repository, machine_repository) def test_return_existing_machine(node_update_facade, machine_repository): - machine_repository.get_machines_by_ip = MagicMock(return_value=[EXISTING_MACHINE]) + machine_repository.get_machines_by_ip = MagicMock(return_value=[SOURCE_MACHINE]) target_machine = node_update_facade.get_or_create_target_machine(IP_ADDRESS) - assert target_machine == EXISTING_MACHINE + assert target_machine == SOURCE_MACHINE def test_create_new_machine(node_update_facade, machine_repository): @@ -49,3 +91,7 @@ def test_create_new_machine(node_update_facade, machine_repository): assert target_machine == EXPECTED_CREATED_MACHINE assert machine_repository.upsert_machine.called_once_with(target_machine) + + +def test_get_event_source_machine(node_update_facade): + assert node_update_facade.get_event_source_machine(TEST_EVENT) == SOURCE_MACHINE