diff --git a/monkey/monkey_island/cc/island_event_handlers/__init__.py b/monkey/monkey_island/cc/island_event_handlers/__init__.py index 11343c2fa..bd390f447 100644 --- a/monkey/monkey_island/cc/island_event_handlers/__init__.py +++ b/monkey/monkey_island/cc/island_event_handlers/__init__.py @@ -1,3 +1,4 @@ +from .handle_agent_registration import handle_agent_registration from .reset_agent_configuration import reset_agent_configuration from .reset_machine_repository import reset_machine_repository from .set_agent_configuration_per_island_mode import set_agent_configuration_per_island_mode diff --git a/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py new file mode 100644 index 000000000..9e3ead2b6 --- /dev/null +++ b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py @@ -0,0 +1,74 @@ +from contextlib import suppress +from typing import Optional + +from common import AgentRegistrationData +from monkey_island.cc.models import Machine +from monkey_island.cc.repository import IMachineRepository, UnknownRecordError + + +class handle_agent_registration: + """ + Update repositories when a new agent registers + """ + + def __init__(self, machine_repository: IMachineRepository): + self._machine_repository = machine_repository + + def __call__(self, agent_registration_data: AgentRegistrationData): + self._update_machine_repository(agent_registration_data) + + def _update_machine_repository(self, agent_registration_data: AgentRegistrationData): + machine = self._find_existing_machine_to_update(agent_registration_data) + + if machine is None: + machine = Machine(id=self._machine_repository.get_new_id()) + + self._upsert_machine(machine, agent_registration_data) + + def _find_existing_machine_to_update( + self, agent_registration_data: AgentRegistrationData + ) -> Optional[Machine]: + with suppress(UnknownRecordError): + return self._machine_repository.get_machine_by_hardware_id( + agent_registration_data.machine_hardware_id + ) + + for network_interface in agent_registration_data.network_interfaces: + with suppress(UnknownRecordError): + # NOTE: For now, assume IPs are unique. In reality, two machines could share the + # same IP if there's a router between them. + return self._machine_repository.get_machines_by_ip(network_interface.ip)[0] + + return None + + def _upsert_machine( + self, existing_machine: Machine, agent_registration_data: AgentRegistrationData + ): + updated_machine = existing_machine.copy() + + self._update_hardware_id(updated_machine, agent_registration_data) + self._update_network_interfaces(updated_machine, agent_registration_data) + + self._machine_repository.upsert_machine(updated_machine) + + def _update_hardware_id(self, machine: Machine, agent_registration_data: AgentRegistrationData): + if ( + machine.hardware_id is not None + and machine.hardware_id != agent_registration_data.machine_hardware_id + ): + raise Exception( + f"Hardware ID mismatch:\n\tMachine: {machine}\n\t" + f"AgentRegistrationData: {agent_registration_data}" + ) + + machine.hardware_id = agent_registration_data.machine_hardware_id + + def _update_network_interfaces( + self, machine: Machine, agent_registration_data: AgentRegistrationData + ): + updated_network_interfaces = set(machine.network_interfaces) + updated_network_interfaces = updated_network_interfaces.union( + agent_registration_data.network_interfaces + ) + + machine.network_interfaces = sorted(updated_network_interfaces) diff --git a/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py new file mode 100644 index 000000000..37e95b14f --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py @@ -0,0 +1,128 @@ +from ipaddress import IPv4Address, IPv4Interface +from itertools import count +from typing import Sequence +from unittest.mock import MagicMock +from uuid import UUID + +import pytest + +from common import AgentRegistrationData +from monkey_island.cc.island_event_handlers import handle_agent_registration +from monkey_island.cc.models import Machine +from monkey_island.cc.repository import IMachineRepository, UnknownRecordError + +AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e") + +SEED_ID = 10 + +MACHINE = Machine( + id=2, + hardware_id=5, + network_interfaces=[IPv4Interface("192.168.2.2/24")], +) + +AGENT_REGISTRATION_DATA = AgentRegistrationData( + id=AGENT_ID, + machine_hardware_id=MACHINE.hardware_id, + start_time=0, + parent_id=None, + cc_server="192.168.1.1:5000", + network_interfaces=[IPv4Interface("192.168.1.2/24")], +) + + +@pytest.fixture +def machine_repository() -> IMachineRepository: + machine_repository = MagicMock(spec=IMachineRepository) + machine_repository.get_new_id = MagicMock(side_effect=count(SEED_ID)) + machine_repository.upsert_machine = MagicMock() + return machine_repository + + +@pytest.fixture +def handler(machine_repository) -> handle_agent_registration: + return handle_agent_registration(machine_repository) + + +def test_new_machine_added(handler, machine_repository): + expected_machine = Machine( + id=SEED_ID, + hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id, + network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces, + ) + machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) + machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError) + handler(AGENT_REGISTRATION_DATA) + + machine_repository.upsert_machine.assert_called_once() + new_machine = machine_repository.upsert_machine.call_args_list[0][0][0] + + assert new_machine == expected_machine + + +def test_existing_machine_updated__hardware_id(handler, machine_repository): + expected_updated_machine = Machine( + id=MACHINE.id, + hardware_id=MACHINE.hardware_id, + network_interfaces=[ + AGENT_REGISTRATION_DATA.network_interfaces[0], + MACHINE.network_interfaces[0], + ], + ) + machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE) + handler(AGENT_REGISTRATION_DATA) + + machine_repository.upsert_machine.assert_called_once() + machine_repository.upsert_machine.assert_called_with(expected_updated_machine) + + +def test_existing_machine_updated__find_by_ip(handler, machine_repository): + agent_registration_data = AgentRegistrationData( + id=AGENT_ID, + machine_hardware_id=5, + start_time=0, + parent_id=None, + cc_server="192.168.1.1:5000", + network_interfaces=[ + IPv4Interface("192.168.1.2/24"), + IPv4Interface("192.168.1.4/24"), + IPv4Interface("192.168.1.5/24"), + ], + ) + + existing_machine = Machine( + id=1, + network_interfaces=[agent_registration_data.network_interfaces[-1]], + ) + + def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]: + if ip == existing_machine.network_interfaces[0].ip: + return [existing_machine] + + raise UnknownRecordError + + expected_updated_machine = existing_machine.copy() + expected_updated_machine.hardware_id = agent_registration_data.machine_hardware_id + expected_updated_machine.network_interfaces = agent_registration_data.network_interfaces + + machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) + machine_repository.get_machines_by_ip = MagicMock(side_effect=get_machines_by_ip) + + handler(agent_registration_data) + + machine_repository.upsert_machine.assert_called_once() + machine_repository.upsert_machine.assert_called_with(expected_updated_machine) + + +def test_hardware_id_mismatch(handler, machine_repository): + existing_machine = Machine( + id=1, + hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id + 99, + network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces, + ) + + machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) + machine_repository.get_machines_by_ip = MagicMock(return_value=[existing_machine]) + + with pytest.raises(Exception): + handler(AGENT_REGISTRATION_DATA)