diff --git a/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py index eaea53ffb..f6add225f 100644 --- a/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py +++ b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py @@ -1,9 +1,16 @@ from contextlib import suppress +from ipaddress import IPv4Address, IPv4Interface from typing import Optional from common import AgentRegistrationData -from monkey_island.cc.models import Agent, Machine -from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError +from common.network.network_utils import address_to_ip_port +from monkey_island.cc.models import Agent, CommunicationType, Machine +from monkey_island.cc.repository import ( + IAgentRepository, + IMachineRepository, + INodeRepository, + UnknownRecordError, +) class handle_agent_registration: @@ -11,13 +18,20 @@ class handle_agent_registration: Update repositories when a new agent registers """ - def __init__(self, machine_repository: IMachineRepository, agent_repository: IAgentRepository): + def __init__( + self, + machine_repository: IMachineRepository, + agent_repository: IAgentRepository, + node_repository: INodeRepository, + ): self._machine_repository = machine_repository self._agent_repository = agent_repository + self._node_repository = node_repository def __call__(self, agent_registration_data: AgentRegistrationData): machine = self._update_machine_repository(agent_registration_data) self._add_agent(agent_registration_data, machine) + self._add_node_communication(agent_registration_data, machine) def _update_machine_repository(self, agent_registration_data: AgentRegistrationData) -> Machine: machine = self._find_existing_machine_to_update(agent_registration_data) @@ -86,3 +100,25 @@ class handle_agent_registration: cc_server=agent_registration_data.cc_server, ) self._agent_repository.upsert_agent(new_agent) + + def _add_node_communication( + self, agent_registration_data: AgentRegistrationData, src_machine: Machine + ): + dst_machine = self._get_or_create_cc_machine(agent_registration_data.cc_server) + + self._node_repository.upsert_communication( + src_machine.id, dst_machine.id, CommunicationType.CC + ) + + def _get_or_create_cc_machine(self, cc_server: str) -> Machine: + dst_ip = IPv4Address(address_to_ip_port(cc_server)[0]) + + try: + return self._machine_repository.get_machines_by_ip(dst_ip)[0] + except UnknownRecordError: + new_machine = Machine( + id=self._machine_repository.get_new_id(), network_interfaces=[IPv4Interface(dst_ip)] + ) + self._machine_repository.upsert_machine(new_machine) + + return new_machine diff --git a/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py index 91d392368..1ebfb59f7 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py +++ b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py @@ -8,8 +8,13 @@ import pytest from common import AgentRegistrationData from monkey_island.cc.island_event_handlers import handle_agent_registration -from monkey_island.cc.models import Agent, Machine -from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError +from monkey_island.cc.models import Agent, CommunicationType, Machine +from monkey_island.cc.repository import ( + IAgentRepository, + IMachineRepository, + INodeRepository, + UnknownRecordError, +) AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e") @@ -49,8 +54,25 @@ def agent_repository() -> IAgentRepository: @pytest.fixture -def handler(machine_repository, agent_repository) -> handle_agent_registration: - return handle_agent_registration(machine_repository, agent_repository) +def node_repository() -> INodeRepository: + node_repository = MagicMock(spec=INodeRepository) + node_repository.upsert_communication = MagicMock() + return node_repository + + +@pytest.fixture +def handler(machine_repository, agent_repository, node_repository) -> handle_agent_registration: + return handle_agent_registration(machine_repository, agent_repository, node_repository) + + +def build_get_machines_by_ip(ip_to_match: IPv4Address, machine_to_return: Machine): + def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]: + if ip == ip_to_match: + return [machine_to_return] + + raise UnknownRecordError + + return get_machines_by_ip def test_new_machine_added(handler, machine_repository): @@ -61,12 +83,10 @@ def test_new_machine_added(handler, machine_repository): ) machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError) + handler(AGENT_REGISTRATION_DATA) - machine_repository.upsert_machine.assert_called_once() - new_machine = machine_repository.upsert_machine.call_args_list[0][0][0] - - assert new_machine == expected_machine + machine_repository.upsert_machine.assert_any_call(expected_machine) def test_existing_machine_updated__hardware_id(handler, machine_repository): @@ -79,10 +99,10 @@ def test_existing_machine_updated__hardware_id(handler, machine_repository): ], ) machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE) + handler(AGENT_REGISTRATION_DATA) - machine_repository.upsert_machine.assert_called_once() - machine_repository.upsert_machine.assert_called_with(expected_updated_machine) + machine_repository.upsert_machine.assert_any_call(expected_updated_machine) def test_existing_machine_updated__find_by_ip(handler, machine_repository): @@ -104,11 +124,9 @@ def test_existing_machine_updated__find_by_ip(handler, machine_repository): network_interfaces=[agent_registration_data.network_interfaces[-1]], ) - def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]: - if ip == existing_machine.network_interfaces[0].ip: - return [existing_machine] - - raise UnknownRecordError + get_machines_by_ip = build_get_machines_by_ip( + existing_machine.network_interfaces[0].ip, existing_machine + ) expected_updated_machine = existing_machine.copy() expected_updated_machine.hardware_id = agent_registration_data.machine_hardware_id @@ -119,8 +137,7 @@ def test_existing_machine_updated__find_by_ip(handler, machine_repository): handler(agent_registration_data) - machine_repository.upsert_machine.assert_called_once() - machine_repository.upsert_machine.assert_called_with(expected_updated_machine) + machine_repository.upsert_machine.assert_any_call(expected_updated_machine) def test_hardware_id_mismatch(handler, machine_repository): @@ -148,3 +165,39 @@ def test_add_agent(handler, agent_repository): handler(AGENT_REGISTRATION_DATA) agent_repository.upsert_agent.assert_called_with(expected_agent) + + +def test_add_node_connection(handler, machine_repository, node_repository): + island_machine = Machine( + id=1, + hardware_id=99, + island=True, + network_interfaces=[IPv4Interface("192.168.1.1/24")], + ) + get_machines_by_ip = build_get_machines_by_ip( + island_machine.network_interfaces[0].ip, island_machine + ) + machine_repository.get_machines_by_ip = MagicMock(side_effect=get_machines_by_ip) + machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE) + + handler(AGENT_REGISTRATION_DATA) + + node_repository.upsert_communication.assert_called_once() + node_repository.upsert_communication.assert_called_with( + MACHINE.id, island_machine.id, CommunicationType.CC + ) + + +def test_add_node_connection__unknown_server(handler, machine_repository, node_repository): + expected_new_server_machine = Machine( + id=SEED_ID, + network_interfaces=[IPv4Interface("192.168.1.1/32")], + ) + + machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE) + handler(AGENT_REGISTRATION_DATA) + + machine_repository.upsert_machine.assert_called_with(expected_new_server_machine) + node_repository.upsert_communication.assert_called_with( + MACHINE.id, SEED_ID, CommunicationType.CC + )