Island: Add handle_agent_registration event handler
- Add handle_agent_registration callable class - Add/Update machine to the repository
This commit is contained in:
parent
41dbb92eef
commit
c95c2ffdf9
|
@ -1,3 +1,4 @@
|
|||
from .handle_agent_registration import handle_agent_registration
|
||||
from .reset_agent_configuration import reset_agent_configuration
|
||||
from .reset_machine_repository import reset_machine_repository
|
||||
from .set_agent_configuration_per_island_mode import set_agent_configuration_per_island_mode
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
from contextlib import suppress
|
||||
from typing import Optional
|
||||
|
||||
from common import AgentRegistrationData
|
||||
from monkey_island.cc.models import Machine
|
||||
from monkey_island.cc.repository import IMachineRepository, UnknownRecordError
|
||||
|
||||
|
||||
class handle_agent_registration:
|
||||
"""
|
||||
Update repositories when a new agent registers
|
||||
"""
|
||||
|
||||
def __init__(self, machine_repository: IMachineRepository):
|
||||
self._machine_repository = machine_repository
|
||||
|
||||
def __call__(self, agent_registration_data: AgentRegistrationData):
|
||||
self._update_machine_repository(agent_registration_data)
|
||||
|
||||
def _update_machine_repository(self, agent_registration_data: AgentRegistrationData):
|
||||
machine = self._find_existing_machine_to_update(agent_registration_data)
|
||||
|
||||
if machine is None:
|
||||
machine = Machine(id=self._machine_repository.get_new_id())
|
||||
|
||||
self._upsert_machine(machine, agent_registration_data)
|
||||
|
||||
def _find_existing_machine_to_update(
|
||||
self, agent_registration_data: AgentRegistrationData
|
||||
) -> Optional[Machine]:
|
||||
with suppress(UnknownRecordError):
|
||||
return self._machine_repository.get_machine_by_hardware_id(
|
||||
agent_registration_data.machine_hardware_id
|
||||
)
|
||||
|
||||
for network_interface in agent_registration_data.network_interfaces:
|
||||
with suppress(UnknownRecordError):
|
||||
# NOTE: For now, assume IPs are unique. In reality, two machines could share the
|
||||
# same IP if there's a router between them.
|
||||
return self._machine_repository.get_machines_by_ip(network_interface.ip)[0]
|
||||
|
||||
return None
|
||||
|
||||
def _upsert_machine(
|
||||
self, existing_machine: Machine, agent_registration_data: AgentRegistrationData
|
||||
):
|
||||
updated_machine = existing_machine.copy()
|
||||
|
||||
self._update_hardware_id(updated_machine, agent_registration_data)
|
||||
self._update_network_interfaces(updated_machine, agent_registration_data)
|
||||
|
||||
self._machine_repository.upsert_machine(updated_machine)
|
||||
|
||||
def _update_hardware_id(self, machine: Machine, agent_registration_data: AgentRegistrationData):
|
||||
if (
|
||||
machine.hardware_id is not None
|
||||
and machine.hardware_id != agent_registration_data.machine_hardware_id
|
||||
):
|
||||
raise Exception(
|
||||
f"Hardware ID mismatch:\n\tMachine: {machine}\n\t"
|
||||
f"AgentRegistrationData: {agent_registration_data}"
|
||||
)
|
||||
|
||||
machine.hardware_id = agent_registration_data.machine_hardware_id
|
||||
|
||||
def _update_network_interfaces(
|
||||
self, machine: Machine, agent_registration_data: AgentRegistrationData
|
||||
):
|
||||
updated_network_interfaces = set(machine.network_interfaces)
|
||||
updated_network_interfaces = updated_network_interfaces.union(
|
||||
agent_registration_data.network_interfaces
|
||||
)
|
||||
|
||||
machine.network_interfaces = sorted(updated_network_interfaces)
|
|
@ -0,0 +1,128 @@
|
|||
from ipaddress import IPv4Address, IPv4Interface
|
||||
from itertools import count
|
||||
from typing import Sequence
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from common import AgentRegistrationData
|
||||
from monkey_island.cc.island_event_handlers import handle_agent_registration
|
||||
from monkey_island.cc.models import Machine
|
||||
from monkey_island.cc.repository import IMachineRepository, UnknownRecordError
|
||||
|
||||
AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e")
|
||||
|
||||
SEED_ID = 10
|
||||
|
||||
MACHINE = Machine(
|
||||
id=2,
|
||||
hardware_id=5,
|
||||
network_interfaces=[IPv4Interface("192.168.2.2/24")],
|
||||
)
|
||||
|
||||
AGENT_REGISTRATION_DATA = AgentRegistrationData(
|
||||
id=AGENT_ID,
|
||||
machine_hardware_id=MACHINE.hardware_id,
|
||||
start_time=0,
|
||||
parent_id=None,
|
||||
cc_server="192.168.1.1:5000",
|
||||
network_interfaces=[IPv4Interface("192.168.1.2/24")],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def machine_repository() -> IMachineRepository:
|
||||
machine_repository = MagicMock(spec=IMachineRepository)
|
||||
machine_repository.get_new_id = MagicMock(side_effect=count(SEED_ID))
|
||||
machine_repository.upsert_machine = MagicMock()
|
||||
return machine_repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler(machine_repository) -> handle_agent_registration:
|
||||
return handle_agent_registration(machine_repository)
|
||||
|
||||
|
||||
def test_new_machine_added(handler, machine_repository):
|
||||
expected_machine = Machine(
|
||||
id=SEED_ID,
|
||||
hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id,
|
||||
network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces,
|
||||
)
|
||||
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
|
||||
machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
|
||||
handler(AGENT_REGISTRATION_DATA)
|
||||
|
||||
machine_repository.upsert_machine.assert_called_once()
|
||||
new_machine = machine_repository.upsert_machine.call_args_list[0][0][0]
|
||||
|
||||
assert new_machine == expected_machine
|
||||
|
||||
|
||||
def test_existing_machine_updated__hardware_id(handler, machine_repository):
|
||||
expected_updated_machine = Machine(
|
||||
id=MACHINE.id,
|
||||
hardware_id=MACHINE.hardware_id,
|
||||
network_interfaces=[
|
||||
AGENT_REGISTRATION_DATA.network_interfaces[0],
|
||||
MACHINE.network_interfaces[0],
|
||||
],
|
||||
)
|
||||
machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE)
|
||||
handler(AGENT_REGISTRATION_DATA)
|
||||
|
||||
machine_repository.upsert_machine.assert_called_once()
|
||||
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
|
||||
|
||||
|
||||
def test_existing_machine_updated__find_by_ip(handler, machine_repository):
|
||||
agent_registration_data = AgentRegistrationData(
|
||||
id=AGENT_ID,
|
||||
machine_hardware_id=5,
|
||||
start_time=0,
|
||||
parent_id=None,
|
||||
cc_server="192.168.1.1:5000",
|
||||
network_interfaces=[
|
||||
IPv4Interface("192.168.1.2/24"),
|
||||
IPv4Interface("192.168.1.4/24"),
|
||||
IPv4Interface("192.168.1.5/24"),
|
||||
],
|
||||
)
|
||||
|
||||
existing_machine = Machine(
|
||||
id=1,
|
||||
network_interfaces=[agent_registration_data.network_interfaces[-1]],
|
||||
)
|
||||
|
||||
def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]:
|
||||
if ip == existing_machine.network_interfaces[0].ip:
|
||||
return [existing_machine]
|
||||
|
||||
raise UnknownRecordError
|
||||
|
||||
expected_updated_machine = existing_machine.copy()
|
||||
expected_updated_machine.hardware_id = agent_registration_data.machine_hardware_id
|
||||
expected_updated_machine.network_interfaces = agent_registration_data.network_interfaces
|
||||
|
||||
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
|
||||
machine_repository.get_machines_by_ip = MagicMock(side_effect=get_machines_by_ip)
|
||||
|
||||
handler(agent_registration_data)
|
||||
|
||||
machine_repository.upsert_machine.assert_called_once()
|
||||
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
|
||||
|
||||
|
||||
def test_hardware_id_mismatch(handler, machine_repository):
|
||||
existing_machine = Machine(
|
||||
id=1,
|
||||
hardware_id=AGENT_REGISTRATION_DATA.machine_hardware_id + 99,
|
||||
network_interfaces=AGENT_REGISTRATION_DATA.network_interfaces,
|
||||
)
|
||||
|
||||
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
|
||||
machine_repository.get_machines_by_ip = MagicMock(return_value=[existing_machine])
|
||||
|
||||
with pytest.raises(Exception):
|
||||
handler(AGENT_REGISTRATION_DATA)
|
Loading…
Reference in New Issue