forked from p15670423/monkey
Island: Add handle_agent_registration event handler
- Add handle_agent_registration callable class - Add/Update machine to the repository
This commit is contained in:
parent
41dbb92eef
commit
c95c2ffdf9
|
@ -1,3 +1,4 @@
|
||||||
|
from .handle_agent_registration import handle_agent_registration
|
||||||
from .reset_agent_configuration import reset_agent_configuration
|
from .reset_agent_configuration import reset_agent_configuration
|
||||||
from .reset_machine_repository import reset_machine_repository
|
from .reset_machine_repository import reset_machine_repository
|
||||||
from .set_agent_configuration_per_island_mode import set_agent_configuration_per_island_mode
|
from .set_agent_configuration_per_island_mode import set_agent_configuration_per_island_mode
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
from contextlib import suppress
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from common import AgentRegistrationData
|
||||||
|
from monkey_island.cc.models import Machine
|
||||||
|
from monkey_island.cc.repository import IMachineRepository, UnknownRecordError
|
||||||
|
|
||||||
|
|
||||||
|
class handle_agent_registration:
|
||||||
|
"""
|
||||||
|
Update repositories when a new agent registers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, machine_repository: IMachineRepository):
|
||||||
|
self._machine_repository = machine_repository
|
||||||
|
|
||||||
|
def __call__(self, agent_registration_data: AgentRegistrationData):
|
||||||
|
self._update_machine_repository(agent_registration_data)
|
||||||
|
|
||||||
|
def _update_machine_repository(self, agent_registration_data: AgentRegistrationData):
|
||||||
|
machine = self._find_existing_machine_to_update(agent_registration_data)
|
||||||
|
|
||||||
|
if machine is None:
|
||||||
|
machine = Machine(id=self._machine_repository.get_new_id())
|
||||||
|
|
||||||
|
self._upsert_machine(machine, agent_registration_data)
|
||||||
|
|
||||||
|
def _find_existing_machine_to_update(
|
||||||
|
self, agent_registration_data: AgentRegistrationData
|
||||||
|
) -> Optional[Machine]:
|
||||||
|
with suppress(UnknownRecordError):
|
||||||
|
return self._machine_repository.get_machine_by_hardware_id(
|
||||||
|
agent_registration_data.machine_hardware_id
|
||||||
|
)
|
||||||
|
|
||||||
|
for network_interface in agent_registration_data.network_interfaces:
|
||||||
|
with suppress(UnknownRecordError):
|
||||||
|
# NOTE: For now, assume IPs are unique. In reality, two machines could share the
|
||||||
|
# same IP if there's a router between them.
|
||||||
|
return self._machine_repository.get_machines_by_ip(network_interface.ip)[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _upsert_machine(
|
||||||
|
self, existing_machine: Machine, agent_registration_data: AgentRegistrationData
|
||||||
|
):
|
||||||
|
updated_machine = existing_machine.copy()
|
||||||
|
|
||||||
|
self._update_hardware_id(updated_machine, agent_registration_data)
|
||||||
|
self._update_network_interfaces(updated_machine, agent_registration_data)
|
||||||
|
|
||||||
|
self._machine_repository.upsert_machine(updated_machine)
|
||||||
|
|
||||||
|
def _update_hardware_id(self, machine: Machine, agent_registration_data: AgentRegistrationData):
|
||||||
|
if (
|
||||||
|
machine.hardware_id is not None
|
||||||
|
and machine.hardware_id != agent_registration_data.machine_hardware_id
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Hardware ID mismatch:\n\tMachine: {machine}\n\t"
|
||||||
|
f"AgentRegistrationData: {agent_registration_data}"
|
||||||
|
)
|
||||||
|
|
||||||
|
machine.hardware_id = agent_registration_data.machine_hardware_id
|
||||||
|
|
||||||
|
def _update_network_interfaces(
|
||||||
|
self, machine: Machine, agent_registration_data: AgentRegistrationData
|
||||||
|
):
|
||||||
|
updated_network_interfaces = set(machine.network_interfaces)
|
||||||
|
updated_network_interfaces = updated_network_interfaces.union(
|
||||||
|
agent_registration_data.network_interfaces
|
||||||
|
)
|
||||||
|
|
||||||
|
machine.network_interfaces = sorted(updated_network_interfaces)
|
|
@ -0,0 +1,128 @@
|
||||||
|
from ipaddress import IPv4Address, IPv4Interface
|
||||||
|
from itertools import count
|
||||||
|
from typing import Sequence
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from common import AgentRegistrationData
|
||||||
|
from monkey_island.cc.island_event_handlers import handle_agent_registration
|
||||||
|
from monkey_island.cc.models import Machine
|
||||||
|
from monkey_island.cc.repository import IMachineRepository, UnknownRecordError
|
||||||
|
|
||||||
|
AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e")
|
||||||
|
|
||||||
|
SEED_ID = 10
|
||||||
|
|
||||||
|
MACHINE = Machine(
|
||||||
|
id=2,
|
||||||
|
hardware_id=5,
|
||||||
|
network_interfaces=[IPv4Interface("192.168.2.2/24")],
|
||||||
|
)
|
||||||
|
|
||||||
|
AGENT_REGISTRATION_DATA = AgentRegistrationData(
|
||||||
|
id=AGENT_ID,
|
||||||
|
machine_hardware_id=MACHINE.hardware_id,
|
||||||
|
start_time=0,
|
||||||
|
parent_id=None,
|
||||||
|
cc_server="192.168.1.1:5000",
|
||||||
|
network_interfaces=[IPv4Interface("192.168.1.2/24")],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def machine_repository() -> IMachineRepository:
|
||||||
|
machine_repository = MagicMock(spec=IMachineRepository)
|
||||||
|
machine_repository.get_new_id = MagicMock(side_effect=count(SEED_ID))
|
||||||
|
machine_repository.upsert_machine = MagicMock()
|
||||||
|
return machine_repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def handler(machine_repository) -> handle_agent_registration:
|
||||||
|
return handle_agent_registration(machine_repository)
|
||||||
|
|
||||||
|
|
||||||
|
def test_new_machine_added(handler, machine_repository):
|
||||||
|
expected_machine = Machine(
|
||||||
|
id=SEED_ID,
|
||||||
|
hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id,
|
||||||
|
network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces,
|
||||||
|
)
|
||||||
|
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
|
||||||
|
handler(AGENT_REGISTRATION_DATA)
|
||||||
|
|
||||||
|
machine_repository.upsert_machine.assert_called_once()
|
||||||
|
new_machine = machine_repository.upsert_machine.call_args_list[0][0][0]
|
||||||
|
|
||||||
|
assert new_machine == expected_machine
|
||||||
|
|
||||||
|
|
||||||
|
def test_existing_machine_updated__hardware_id(handler, machine_repository):
|
||||||
|
expected_updated_machine = Machine(
|
||||||
|
id=MACHINE.id,
|
||||||
|
hardware_id=MACHINE.hardware_id,
|
||||||
|
network_interfaces=[
|
||||||
|
AGENT_REGISTRATION_DATA.network_interfaces[0],
|
||||||
|
MACHINE.network_interfaces[0],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE)
|
||||||
|
handler(AGENT_REGISTRATION_DATA)
|
||||||
|
|
||||||
|
machine_repository.upsert_machine.assert_called_once()
|
||||||
|
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
|
||||||
|
|
||||||
|
|
||||||
|
def test_existing_machine_updated__find_by_ip(handler, machine_repository):
|
||||||
|
agent_registration_data = AgentRegistrationData(
|
||||||
|
id=AGENT_ID,
|
||||||
|
machine_hardware_id=5,
|
||||||
|
start_time=0,
|
||||||
|
parent_id=None,
|
||||||
|
cc_server="192.168.1.1:5000",
|
||||||
|
network_interfaces=[
|
||||||
|
IPv4Interface("192.168.1.2/24"),
|
||||||
|
IPv4Interface("192.168.1.4/24"),
|
||||||
|
IPv4Interface("192.168.1.5/24"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_machine = Machine(
|
||||||
|
id=1,
|
||||||
|
network_interfaces=[agent_registration_data.network_interfaces[-1]],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]:
|
||||||
|
if ip == existing_machine.network_interfaces[0].ip:
|
||||||
|
return [existing_machine]
|
||||||
|
|
||||||
|
raise UnknownRecordError
|
||||||
|
|
||||||
|
expected_updated_machine = existing_machine.copy()
|
||||||
|
expected_updated_machine.hardware_id = agent_registration_data.machine_hardware_id
|
||||||
|
expected_updated_machine.network_interfaces = agent_registration_data.network_interfaces
|
||||||
|
|
||||||
|
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(side_effect=get_machines_by_ip)
|
||||||
|
|
||||||
|
handler(agent_registration_data)
|
||||||
|
|
||||||
|
machine_repository.upsert_machine.assert_called_once()
|
||||||
|
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hardware_id_mismatch(handler, machine_repository):
|
||||||
|
existing_machine = Machine(
|
||||||
|
id=1,
|
||||||
|
hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id + 99,
|
||||||
|
network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces,
|
||||||
|
)
|
||||||
|
|
||||||
|
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(return_value=[existing_machine])
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
handler(AGENT_REGISTRATION_DATA)
|
Loading…
Reference in New Issue