diff --git a/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py index 9e3ead2b6..eaea53ffb 100644 --- a/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py +++ b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py @@ -2,8 +2,8 @@ from contextlib import suppress from typing import Optional from common import AgentRegistrationData -from monkey_island.cc.models import Machine -from monkey_island.cc.repository import IMachineRepository, UnknownRecordError +from monkey_island.cc.models import Agent, Machine +from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError class handle_agent_registration: @@ -11,13 +11,15 @@ class handle_agent_registration: Update repositories when a new agent registers """ - def __init__(self, machine_repository: IMachineRepository): + def __init__(self, machine_repository: IMachineRepository, agent_repository: IAgentRepository): self._machine_repository = machine_repository + self._agent_repository = agent_repository def __call__(self, agent_registration_data: AgentRegistrationData): - self._update_machine_repository(agent_registration_data) + machine = self._update_machine_repository(agent_registration_data) + self._add_agent(agent_registration_data, machine) - def _update_machine_repository(self, agent_registration_data: AgentRegistrationData): + def _update_machine_repository(self, agent_registration_data: AgentRegistrationData) -> Machine: machine = self._find_existing_machine_to_update(agent_registration_data) if machine is None: @@ -25,6 +27,8 @@ class handle_agent_registration: self._upsert_machine(machine, agent_registration_data) + return machine + def _find_existing_machine_to_update( self, agent_registration_data: AgentRegistrationData ) -> Optional[Machine]: @@ -72,3 +76,13 @@ class handle_agent_registration: ) machine.network_interfaces = sorted(updated_network_interfaces) + + def _add_agent(self, agent_registration_data: AgentRegistrationData, machine: Machine): + new_agent = Agent( + id=agent_registration_data.id, + machine_id=machine.id, + start_time=agent_registration_data.start_time, + parent_id=agent_registration_data.parent_id, + cc_server=agent_registration_data.cc_server, + ) + self._agent_repository.upsert_agent(new_agent) diff --git a/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py index 37e95b14f..91d392368 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py +++ b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py @@ -8,8 +8,8 @@ import pytest from common import AgentRegistrationData from monkey_island.cc.island_event_handlers import handle_agent_registration -from monkey_island.cc.models import Machine -from monkey_island.cc.repository import IMachineRepository, UnknownRecordError +from monkey_island.cc.models import Agent, Machine +from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e") @@ -36,12 +36,21 @@ def machine_repository() -> IMachineRepository: machine_repository = MagicMock(spec=IMachineRepository) machine_repository.get_new_id = MagicMock(side_effect=count(SEED_ID)) machine_repository.upsert_machine = MagicMock() + machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) + machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError) return machine_repository @pytest.fixture -def handler(machine_repository) -> handle_agent_registration: - return handle_agent_registration(machine_repository) +def agent_repository() -> IAgentRepository: + agent_repository = MagicMock(spec=IAgentRepository) + agent_repository.upsert_agent = MagicMock() + return agent_repository + + +@pytest.fixture +def handler(machine_repository, agent_repository) -> handle_agent_registration: + return handle_agent_registration(machine_repository, agent_repository) def test_new_machine_added(handler, machine_repository): @@ -126,3 +135,16 @@ def test_hardware_id_mismatch(handler, machine_repository): with pytest.raises(Exception): handler(AGENT_REGISTRATION_DATA) + + +def test_add_agent(handler, agent_repository): + expected_agent = Agent( + id=AGENT_REGISTRATION_DATA.id, + machine_id=SEED_ID, + start_time=AGENT_REGISTRATION_DATA.start_time, + parent_id=AGENT_REGISTRATION_DATA.parent_id, + cc_server=AGENT_REGISTRATION_DATA.cc_server, + ) + handler(AGENT_REGISTRATION_DATA) + + agent_repository.upsert_agent.assert_called_with(expected_agent)