diff --git a/monkey/monkey_island/cc/event_queue/i_island_event_queue.py b/monkey/monkey_island/cc/event_queue/i_island_event_queue.py index bf1fcf2cc..cf123c3e2 100644 --- a/monkey/monkey_island/cc/event_queue/i_island_event_queue.py +++ b/monkey/monkey_island/cc/event_queue/i_island_event_queue.py @@ -5,7 +5,7 @@ from . import IslandEventSubscriber class IslandEventTopic(Enum): - AGENT_CONNECTED = auto() + AGENT_REGISTERED = auto() CLEAR_SIMULATION_DATA = auto() RESET_AGENT_CONFIGURATION = auto() SET_ISLAND_MODE = auto() diff --git a/monkey/monkey_island/cc/island_event_handlers/__init__.py b/monkey/monkey_island/cc/island_event_handlers/__init__.py index 11343c2fa..bd390f447 100644 --- a/monkey/monkey_island/cc/island_event_handlers/__init__.py +++ b/monkey/monkey_island/cc/island_event_handlers/__init__.py @@ -1,3 +1,4 @@ +from .handle_agent_registration import handle_agent_registration from .reset_agent_configuration import reset_agent_configuration from .reset_machine_repository import reset_machine_repository from .set_agent_configuration_per_island_mode import set_agent_configuration_per_island_mode diff --git a/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py new file mode 100644 index 000000000..363335c5f --- /dev/null +++ b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py @@ -0,0 +1,130 @@ +from contextlib import suppress +from ipaddress import IPv4Address, IPv4Interface +from typing import List, Optional + +from common import AgentRegistrationData +from common.network.network_utils import address_to_ip_port +from monkey_island.cc.models import Agent, CommunicationType, Machine +from monkey_island.cc.repository import ( + IAgentRepository, + IMachineRepository, + INodeRepository, + UnknownRecordError, +) + + +class handle_agent_registration: + """ + Update repositories when a new agent registers + """ + + def __init__( + self, + machine_repository: IMachineRepository, + agent_repository: IAgentRepository, + node_repository: INodeRepository, + ): + self._machine_repository = machine_repository + self._agent_repository = agent_repository + self._node_repository = node_repository + + def __call__(self, agent_registration_data: AgentRegistrationData): + machine = self._update_machine_repository(agent_registration_data) + self._add_agent(agent_registration_data, machine) + self._add_node_communication(agent_registration_data, machine) + + def _update_machine_repository(self, agent_registration_data: AgentRegistrationData) -> Machine: + machine = self._find_existing_machine_to_update(agent_registration_data) + + if machine is None: + machine = Machine(id=self._machine_repository.get_new_id()) + + self._upsert_machine(machine, agent_registration_data) + + return machine + + def _find_existing_machine_to_update( + self, agent_registration_data: AgentRegistrationData + ) -> Optional[Machine]: + with suppress(UnknownRecordError): + return self._machine_repository.get_machine_by_hardware_id( + agent_registration_data.machine_hardware_id + ) + + for network_interface in agent_registration_data.network_interfaces: + with suppress(UnknownRecordError): + # NOTE: For now, assume IPs are unique. In reality, two machines could share the + # same IP if there's a router between them. + return self._machine_repository.get_machines_by_ip(network_interface.ip)[0] + + return None + + def _upsert_machine(self, machine: Machine, agent_registration_data: AgentRegistrationData): + self._update_hardware_id(machine, agent_registration_data) + self._update_network_interfaces(machine, agent_registration_data) + + self._machine_repository.upsert_machine(machine) + + def _update_hardware_id(self, machine: Machine, agent_registration_data: AgentRegistrationData): + if ( + machine.hardware_id is not None + and machine.hardware_id != agent_registration_data.machine_hardware_id + ): + raise Exception( + f"Hardware ID mismatch:\n\tMachine: {machine}\n\t" + f"AgentRegistrationData: {agent_registration_data}" + ) + + machine.hardware_id = agent_registration_data.machine_hardware_id + + def _update_network_interfaces( + self, machine: Machine, agent_registration_data: AgentRegistrationData + ): + updated_network_interfaces: List[IPv4Interface] = [] + agent_registration_data_ips = set( + map(lambda iface: iface.ip, agent_registration_data.network_interfaces) + ) + + # Prefer interfaces provided by the AgentRegistrationData to those in the Machine record. + # The AgentRegistrationData was collected while running on the machine, whereas the Machine + # data may have only been collected from a scan. For example, the Machine and + # AgentRedistrationData may have the same IP with a different subnet mask. + for interface in machine.network_interfaces: + if interface.ip not in agent_registration_data_ips: + updated_network_interfaces.append(interface) + + updated_network_interfaces.extend(agent_registration_data.network_interfaces) + + machine.network_interfaces = sorted(updated_network_interfaces) + + def _add_agent(self, agent_registration_data: AgentRegistrationData, machine: Machine): + new_agent = Agent( + id=agent_registration_data.id, + machine_id=machine.id, + start_time=agent_registration_data.start_time, + parent_id=agent_registration_data.parent_id, + cc_server=agent_registration_data.cc_server, + ) + self._agent_repository.upsert_agent(new_agent) + + def _add_node_communication( + self, agent_registration_data: AgentRegistrationData, src_machine: Machine + ): + dst_machine = self._get_or_create_cc_machine(agent_registration_data.cc_server) + + self._node_repository.upsert_communication( + src_machine.id, dst_machine.id, CommunicationType.CC + ) + + def _get_or_create_cc_machine(self, cc_server: str) -> Machine: + dst_ip = IPv4Address(address_to_ip_port(cc_server)[0]) + + try: + return self._machine_repository.get_machines_by_ip(dst_ip)[0] + except UnknownRecordError: + new_machine = Machine( + id=self._machine_repository.get_new_id(), network_interfaces=[IPv4Interface(dst_ip)] + ) + self._machine_repository.upsert_machine(new_machine) + + return new_machine diff --git a/monkey/monkey_island/cc/resources/agents.py b/monkey/monkey_island/cc/resources/agents.py index b727bd9ae..9d6f701ec 100644 --- a/monkey/monkey_island/cc/resources/agents.py +++ b/monkey/monkey_island/cc/resources/agents.py @@ -5,6 +5,7 @@ from http import HTTPStatus from flask import make_response, request from common import AgentRegistrationData +from monkey_island.cc.event_queue import IIslandEventQueue, IslandEventTopic from monkey_island.cc.resources.AbstractResource import AbstractResource logger = logging.getLogger(__name__) @@ -13,12 +14,18 @@ logger = logging.getLogger(__name__) class Agents(AbstractResource): urls = ["/api/agents"] + def __init__(self, island_event_queue: IIslandEventQueue): + self._island_event_queue = island_event_queue + def post(self): try: # Just parse for now agent_registration_data = AgentRegistrationData(**request.json) logger.debug(f"Agent registered: {agent_registration_data}") + self._island_event_queue.publish( + IslandEventTopic.AGENT_REGISTERED, agent_registration_data=agent_registration_data + ) return make_response({}, HTTPStatus.NO_CONTENT) except (TypeError, ValueError, json.JSONDecodeError) as err: diff --git a/monkey/monkey_island/cc/setup/island_event_handlers.py b/monkey/monkey_island/cc/setup/island_event_handlers.py index ee37568f1..6b4cb8c53 100644 --- a/monkey/monkey_island/cc/setup/island_event_handlers.py +++ b/monkey/monkey_island/cc/setup/island_event_handlers.py @@ -3,6 +3,7 @@ from functools import partial from common import DIContainer from monkey_island.cc.event_queue import IIslandEventQueue, IslandEventTopic from monkey_island.cc.island_event_handlers import ( + handle_agent_registration, reset_agent_configuration, reset_machine_repository, set_agent_configuration_per_island_mode, @@ -20,11 +21,20 @@ from monkey_island.cc.services.database import Database def setup_island_event_handlers(container: DIContainer): island_event_queue = container.resolve(IIslandEventQueue) + _subscribe_agent_registration_events(island_event_queue, container) _subscribe_reset_agent_configuration_events(island_event_queue, container) _subscribe_clear_simulation_data_events(island_event_queue, container) _subscribe_set_island_mode_events(island_event_queue, container) +def _subscribe_agent_registration_events( + island_event_queue: IIslandEventQueue, container: DIContainer +): + topic = IslandEventTopic.AGENT_REGISTERED + + island_event_queue.subscribe(topic, container.resolve(handle_agent_registration)) + + def _subscribe_reset_agent_configuration_events( island_event_queue: IIslandEventQueue, container: DIContainer ): diff --git a/monkey/tests/unit_tests/monkey_island/cc/event_queue/test_pypubsub_island_event_queue.py b/monkey/tests/unit_tests/monkey_island/cc/event_queue/test_pypubsub_island_event_queue.py index 206009727..91af47fa0 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/event_queue/test_pypubsub_island_event_queue.py +++ b/monkey/tests/unit_tests/monkey_island/cc/event_queue/test_pypubsub_island_event_queue.py @@ -40,7 +40,7 @@ def test_subscribe_publish__no_event_body( topic=IslandEventTopic.CLEAR_SIMULATION_DATA, subscriber=event_queue_subscriber ) - event_queue.publish(topic=IslandEventTopic.AGENT_CONNECTED) + event_queue.publish(topic=IslandEventTopic.AGENT_REGISTERED) event_queue.publish(topic=IslandEventTopic.CLEAR_SIMULATION_DATA) event_queue.publish(topic=IslandEventTopic.RESET_AGENT_CONFIGURATION) @@ -64,9 +64,9 @@ def test_subscribe_publish__with_event_body( event = "my event!" my_callable = MyCallable() - event_queue.subscribe(topic=IslandEventTopic.AGENT_CONNECTED, subscriber=my_callable) + event_queue.subscribe(topic=IslandEventTopic.AGENT_REGISTERED, subscriber=my_callable) - event_queue.publish(topic=IslandEventTopic.AGENT_CONNECTED, event=event) + event_queue.publish(topic=IslandEventTopic.AGENT_REGISTERED, event=event) event_queue.publish(topic=IslandEventTopic.CLEAR_SIMULATION_DATA) event_queue.publish(topic=IslandEventTopic.RESET_AGENT_CONFIGURATION) @@ -84,10 +84,10 @@ def test_keep_subscriber_in_scope(event_queue: IIslandEventQueue): def subscribe(): # fn will go out of scope after subscribe() returns. fn = MyCallable() - event_queue.subscribe(topic=IslandEventTopic.AGENT_CONNECTED, subscriber=fn) + event_queue.subscribe(topic=IslandEventTopic.AGENT_REGISTERED, subscriber=fn) subscribe() - event_queue.publish(topic=IslandEventTopic.AGENT_CONNECTED) + event_queue.publish(topic=IslandEventTopic.AGENT_REGISTERED) assert MyCallable.called diff --git a/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py new file mode 100644 index 000000000..c6070f1ed --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py @@ -0,0 +1,233 @@ +from ipaddress import IPv4Address, IPv4Interface +from itertools import count +from typing import Sequence +from unittest.mock import MagicMock +from uuid import UUID + +import pytest + +from common import AgentRegistrationData +from monkey_island.cc.island_event_handlers import handle_agent_registration +from monkey_island.cc.models import Agent, CommunicationType, Machine +from monkey_island.cc.repository import ( + IAgentRepository, + IMachineRepository, + INodeRepository, + UnknownRecordError, +) + +AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e") + +SEED_ID = 10 + +MACHINE = Machine( + id=2, + hardware_id=5, + network_interfaces=[IPv4Interface("192.168.2.2/24")], +) + +AGENT_REGISTRATION_DATA = AgentRegistrationData( + id=AGENT_ID, + machine_hardware_id=MACHINE.hardware_id, + start_time=0, + parent_id=None, + cc_server="192.168.1.1:5000", + network_interfaces=[IPv4Interface("192.168.1.2/24")], +) + + +@pytest.fixture +def machine_repository() -> IMachineRepository: + machine_repository = MagicMock(spec=IMachineRepository) + machine_repository.get_new_id = MagicMock(side_effect=count(SEED_ID)) + machine_repository.upsert_machine = MagicMock() + machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) + machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError) + return machine_repository + + +@pytest.fixture +def agent_repository() -> IAgentRepository: + agent_repository = MagicMock(spec=IAgentRepository) + agent_repository.upsert_agent = MagicMock() + return agent_repository + + +@pytest.fixture +def node_repository() -> INodeRepository: + node_repository = MagicMock(spec=INodeRepository) + node_repository.upsert_communication = MagicMock() + return node_repository + + +@pytest.fixture +def handler(machine_repository, agent_repository, node_repository) -> handle_agent_registration: + return handle_agent_registration(machine_repository, agent_repository, node_repository) + + +def build_get_machines_by_ip(ip_to_match: IPv4Address, machine_to_return: Machine): + def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]: + if ip == ip_to_match: + return [machine_to_return] + + raise UnknownRecordError + + return get_machines_by_ip + + +def test_new_machine_added(handler, machine_repository): + expected_machine = Machine( + id=SEED_ID, + hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id, + network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces, + ) + machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) + machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError) + + handler(AGENT_REGISTRATION_DATA) + + machine_repository.upsert_machine.assert_any_call(expected_machine) + + +def test_existing_machine_updated__hardware_id(handler, machine_repository): + expected_updated_machine = Machine( + id=MACHINE.id, + hardware_id=MACHINE.hardware_id, + network_interfaces=[ + AGENT_REGISTRATION_DATA.network_interfaces[0], + MACHINE.network_interfaces[0], + ], + ) + machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE) + + handler(AGENT_REGISTRATION_DATA) + + machine_repository.upsert_machine.assert_any_call(expected_updated_machine) + + +def test_existing_machine_updated__find_by_ip(handler, machine_repository): + agent_registration_data = AgentRegistrationData( + id=AGENT_ID, + machine_hardware_id=5, + start_time=0, + parent_id=None, + cc_server="192.168.1.1:5000", + network_interfaces=[ + IPv4Interface("192.168.1.2/24"), + IPv4Interface("192.168.1.4/24"), + IPv4Interface("192.168.1.5/24"), + ], + ) + + existing_machine = Machine( + id=1, + network_interfaces=[agent_registration_data.network_interfaces[-1]], + ) + + get_machines_by_ip = build_get_machines_by_ip( + existing_machine.network_interfaces[0].ip, existing_machine + ) + + expected_updated_machine = existing_machine.copy() + expected_updated_machine.hardware_id = agent_registration_data.machine_hardware_id + expected_updated_machine.network_interfaces = agent_registration_data.network_interfaces + + machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) + machine_repository.get_machines_by_ip = MagicMock(side_effect=get_machines_by_ip) + + handler(agent_registration_data) + + machine_repository.upsert_machine.assert_any_call(expected_updated_machine) + + +def test_hardware_id_mismatch(handler, machine_repository): + existing_machine = Machine( + id=1, + hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id + 99, + network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces, + ) + + machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) + machine_repository.get_machines_by_ip = MagicMock(return_value=[existing_machine]) + + with pytest.raises(Exception): + handler(AGENT_REGISTRATION_DATA) + + +def test_add_agent(handler, agent_repository): + expected_agent = Agent( + id=AGENT_REGISTRATION_DATA.id, + machine_id=SEED_ID, + start_time=AGENT_REGISTRATION_DATA.start_time, + parent_id=AGENT_REGISTRATION_DATA.parent_id, + cc_server=AGENT_REGISTRATION_DATA.cc_server, + ) + handler(AGENT_REGISTRATION_DATA) + + agent_repository.upsert_agent.assert_called_with(expected_agent) + + +def test_add_node_connection(handler, machine_repository, node_repository): + island_machine = Machine( + id=1, + hardware_id=99, + island=True, + network_interfaces=[IPv4Interface("192.168.1.1/24")], + ) + get_machines_by_ip = build_get_machines_by_ip( + island_machine.network_interfaces[0].ip, island_machine + ) + machine_repository.get_machines_by_ip = MagicMock(side_effect=get_machines_by_ip) + machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE) + + handler(AGENT_REGISTRATION_DATA) + + node_repository.upsert_communication.assert_called_once() + node_repository.upsert_communication.assert_called_with( + MACHINE.id, island_machine.id, CommunicationType.CC + ) + + +def test_add_node_connection__unknown_server(handler, machine_repository, node_repository): + expected_new_server_machine = Machine( + id=SEED_ID, + network_interfaces=[IPv4Interface("192.168.1.1/32")], + ) + + machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE) + handler(AGENT_REGISTRATION_DATA) + + machine_repository.upsert_machine.assert_called_with(expected_new_server_machine) + node_repository.upsert_communication.assert_called_with( + MACHINE.id, SEED_ID, CommunicationType.CC + ) + + +def test_machine_interfaces_updated(handler, machine_repository): + existing_machine = Machine( + id=SEED_ID, + hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id, + network_interfaces=[IPv4Interface("192.168.1.2/32"), IPv4Interface("192.168.1.5/32")], + ) + machine_repository.get_machine_by_hardware_id = MagicMock(return_value=existing_machine) + agent_registration_data = AgentRegistrationData( + id=AGENT_ID, + machine_hardware_id=MACHINE.hardware_id, + start_time=0, + parent_id=None, + cc_server="192.168.1.1:5000", + network_interfaces=[ + IPv4Interface("192.168.1.2/24"), + IPv4Interface("192.168.1.3/16"), + IPv4Interface("192.168.1.4/24"), + ], + ) + expected_network_interfaces = sorted( + (*agent_registration_data.network_interfaces, existing_machine.network_interfaces[-1]) + ) + + handler(agent_registration_data) + updated_machine = machine_repository.upsert_machine.call_args_list[0][0][0] + actual_network_interfaces = sorted(updated_machine.network_interfaces) + + assert actual_network_interfaces == expected_network_interfaces diff --git a/monkey/tests/unit_tests/monkey_island/cc/resources/test_agents.py b/monkey/tests/unit_tests/monkey_island/cc/resources/test_agents.py index 3ab90feb6..e44c36c09 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/resources/test_agents.py +++ b/monkey/tests/unit_tests/monkey_island/cc/resources/test_agents.py @@ -1,8 +1,12 @@ from http import HTTPStatus +from unittest.mock import MagicMock from uuid import UUID +import pytest +from tests.common import StubDIContainer from tests.unit_tests.monkey_island.conftest import get_url_for_resource +from monkey_island.cc.event_queue import IIslandEventQueue from monkey_island.cc.resources import Agents AGENTS_URL = get_url_for_resource(Agents) @@ -17,8 +21,16 @@ AGENT_REGISTRATION_DICT = { } +@pytest.fixture +def flask_client(build_flask_client): + container = StubDIContainer() + container.register_instance(IIslandEventQueue, MagicMock(spec=IIslandEventQueue)) + + with build_flask_client(container) as flask_client: + yield flask_client + + def test_agent_registration(flask_client): - print(AGENTS_URL) resp = flask_client.post( AGENTS_URL, json=AGENT_REGISTRATION_DICT, diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 40a2f5cdb..156db5df5 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -318,7 +318,3 @@ SCANNED EXPLOITED CC CC_TUNNEL - -IslandEventTopic.AGENT_CONNECTED -IslandEventTopic.CLEAR_SIMULATION_DATA -IslandEventTopic.RESET_AGENT_CONFIGURATION