Island: Add handle_agent_registration event handler

- Add handle_agent_registration callable class
- Add/Update machine to the repository
This commit is contained in:
Mike Salvatore 2022-09-21 09:58:12 -04:00
parent 41dbb92eef
commit c95c2ffdf9
3 changed files with 203 additions and 0 deletions

View File

@ -1,3 +1,4 @@
from .handle_agent_registration import handle_agent_registration
from .reset_agent_configuration import reset_agent_configuration
from .reset_machine_repository import reset_machine_repository
from .set_agent_configuration_per_island_mode import set_agent_configuration_per_island_mode

View File

@ -0,0 +1,74 @@
from contextlib import suppress
from typing import Optional
from common import AgentRegistrationData
from monkey_island.cc.models import Machine
from monkey_island.cc.repository import IMachineRepository, UnknownRecordError
class handle_agent_registration:
"""
Update repositories when a new agent registers
"""
def __init__(self, machine_repository: IMachineRepository):
self._machine_repository = machine_repository
def __call__(self, agent_registration_data: AgentRegistrationData):
self._update_machine_repository(agent_registration_data)
def _update_machine_repository(self, agent_registration_data: AgentRegistrationData):
machine = self._find_existing_machine_to_update(agent_registration_data)
if machine is None:
machine = Machine(id=self._machine_repository.get_new_id())
self._upsert_machine(machine, agent_registration_data)
def _find_existing_machine_to_update(
self, agent_registration_data: AgentRegistrationData
) -> Optional[Machine]:
with suppress(UnknownRecordError):
return self._machine_repository.get_machine_by_hardware_id(
agent_registration_data.machine_hardware_id
)
for network_interface in agent_registration_data.network_interfaces:
with suppress(UnknownRecordError):
# NOTE: For now, assume IPs are unique. In reality, two machines could share the
# same IP if there's a router between them.
return self._machine_repository.get_machines_by_ip(network_interface.ip)[0]
return None
def _upsert_machine(
self, existing_machine: Machine, agent_registration_data: AgentRegistrationData
):
updated_machine = existing_machine.copy()
self._update_hardware_id(updated_machine, agent_registration_data)
self._update_network_interfaces(updated_machine, agent_registration_data)
self._machine_repository.upsert_machine(updated_machine)
def _update_hardware_id(self, machine: Machine, agent_registration_data: AgentRegistrationData):
if (
machine.hardware_id is not None
and machine.hardware_id != agent_registration_data.machine_hardware_id
):
raise Exception(
f"Hardware ID mismatch:\n\tMachine: {machine}\n\t"
f"AgentRegistrationData: {agent_registration_data}"
)
machine.hardware_id = agent_registration_data.machine_hardware_id
def _update_network_interfaces(
self, machine: Machine, agent_registration_data: AgentRegistrationData
):
updated_network_interfaces = set(machine.network_interfaces)
updated_network_interfaces = updated_network_interfaces.union(
agent_registration_data.network_interfaces
)
machine.network_interfaces = sorted(updated_network_interfaces)

View File

@ -0,0 +1,128 @@
from ipaddress import IPv4Address, IPv4Interface
from itertools import count
from typing import Sequence
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from common import AgentRegistrationData
from monkey_island.cc.island_event_handlers import handle_agent_registration
from monkey_island.cc.models import Machine
from monkey_island.cc.repository import IMachineRepository, UnknownRecordError
AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e")
SEED_ID = 10
MACHINE = Machine(
id=2,
hardware_id=5,
network_interfaces=[IPv4Interface("192.168.2.2/24")],
)
AGENT_REGISTRATION_DATA = AgentRegistrationData(
id=AGENT_ID,
machine_hardware_id=MACHINE.hardware_id,
start_time=0,
parent_id=None,
cc_server="192.168.1.1:5000",
network_interfaces=[IPv4Interface("192.168.1.2/24")],
)
@pytest.fixture
def machine_repository() -> IMachineRepository:
machine_repository = MagicMock(spec=IMachineRepository)
machine_repository.get_new_id = MagicMock(side_effect=count(SEED_ID))
machine_repository.upsert_machine = MagicMock()
return machine_repository
@pytest.fixture
def handler(machine_repository) -> handle_agent_registration:
return handle_agent_registration(machine_repository)
def test_new_machine_added(handler, machine_repository):
expected_machine = Machine(
id=SEED_ID,
hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id,
network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces,
)
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
handler(AGENT_REGISTRATION_DATA)
machine_repository.upsert_machine.assert_called_once()
new_machine = machine_repository.upsert_machine.call_args_list[0][0][0]
assert new_machine == expected_machine
def test_existing_machine_updated__hardware_id(handler, machine_repository):
expected_updated_machine = Machine(
id=MACHINE.id,
hardware_id=MACHINE.hardware_id,
network_interfaces=[
AGENT_REGISTRATION_DATA.network_interfaces[0],
MACHINE.network_interfaces[0],
],
)
machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE)
handler(AGENT_REGISTRATION_DATA)
machine_repository.upsert_machine.assert_called_once()
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
def test_existing_machine_updated__find_by_ip(handler, machine_repository):
agent_registration_data = AgentRegistrationData(
id=AGENT_ID,
machine_hardware_id=5,
start_time=0,
parent_id=None,
cc_server="192.168.1.1:5000",
network_interfaces=[
IPv4Interface("192.168.1.2/24"),
IPv4Interface("192.168.1.4/24"),
IPv4Interface("192.168.1.5/24"),
],
)
existing_machine = Machine(
id=1,
network_interfaces=[agent_registration_data.network_interfaces[-1]],
)
def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]:
if ip == existing_machine.network_interfaces[0].ip:
return [existing_machine]
raise UnknownRecordError
expected_updated_machine = existing_machine.copy()
expected_updated_machine.hardware_id = agent_registration_data.machine_hardware_id
expected_updated_machine.network_interfaces = agent_registration_data.network_interfaces
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
machine_repository.get_machines_by_ip = MagicMock(side_effect=get_machines_by_ip)
handler(agent_registration_data)
machine_repository.upsert_machine.assert_called_once()
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
def test_hardware_id_mismatch(handler, machine_repository):
existing_machine = Machine(
id=1,
hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id + 99,
network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces,
)
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
machine_repository.get_machines_by_ip = MagicMock(return_value=[existing_machine])
with pytest.raises(Exception):
handler(AGENT_REGISTRATION_DATA)