Island: Add new node communication on agent registration

This commit is contained in:
Mike Salvatore 2022-09-21 12:04:37 -04:00
parent 1e8a60c890
commit c0870e6696
2 changed files with 109 additions and 20 deletions

View File

@ -1,9 +1,16 @@
from contextlib import suppress
from ipaddress import IPv4Address, IPv4Interface
from typing import Optional
from common import AgentRegistrationData
from monkey_island.cc.models import Agent, Machine
from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError
from common.network.network_utils import address_to_ip_port
from monkey_island.cc.models import Agent, CommunicationType, Machine
from monkey_island.cc.repository import (
IAgentRepository,
IMachineRepository,
INodeRepository,
UnknownRecordError,
)
class handle_agent_registration:
@ -11,13 +18,20 @@ class handle_agent_registration:
Update repositories when a new agent registers
"""
def __init__(self, machine_repository: IMachineRepository, agent_repository: IAgentRepository):
def __init__(
self,
machine_repository: IMachineRepository,
agent_repository: IAgentRepository,
node_repository: INodeRepository,
):
self._machine_repository = machine_repository
self._agent_repository = agent_repository
self._node_repository = node_repository
def __call__(self, agent_registration_data: AgentRegistrationData):
machine = self._update_machine_repository(agent_registration_data)
self._add_agent(agent_registration_data, machine)
self._add_node_communication(agent_registration_data, machine)
def _update_machine_repository(self, agent_registration_data: AgentRegistrationData) -> Machine:
machine = self._find_existing_machine_to_update(agent_registration_data)
@ -86,3 +100,25 @@ class handle_agent_registration:
cc_server=agent_registration_data.cc_server,
)
self._agent_repository.upsert_agent(new_agent)
def _add_node_communication(
self, agent_registration_data: AgentRegistrationData, src_machine: Machine
):
dst_machine = self._get_or_create_cc_machine(agent_registration_data.cc_server)
self._node_repository.upsert_communication(
src_machine.id, dst_machine.id, CommunicationType.CC
)
def _get_or_create_cc_machine(self, cc_server: str) -> Machine:
dst_ip = IPv4Address(address_to_ip_port(cc_server)[0])
try:
return self._machine_repository.get_machines_by_ip(dst_ip)[0]
except UnknownRecordError:
new_machine = Machine(
id=self._machine_repository.get_new_id(), network_interfaces=[IPv4Interface(dst_ip)]
)
self._machine_repository.upsert_machine(new_machine)
return new_machine

View File

@ -8,8 +8,13 @@ import pytest
from common import AgentRegistrationData
from monkey_island.cc.island_event_handlers import handle_agent_registration
from monkey_island.cc.models import Agent, Machine
from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError
from monkey_island.cc.models import Agent, CommunicationType, Machine
from monkey_island.cc.repository import (
IAgentRepository,
IMachineRepository,
INodeRepository,
UnknownRecordError,
)
AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e")
@ -49,8 +54,25 @@ def agent_repository() -> IAgentRepository:
@pytest.fixture
def handler(machine_repository, agent_repository) -> handle_agent_registration:
return handle_agent_registration(machine_repository, agent_repository)
def node_repository() -> INodeRepository:
node_repository = MagicMock(spec=INodeRepository)
node_repository.upsert_communication = MagicMock()
return node_repository
@pytest.fixture
def handler(machine_repository, agent_repository, node_repository) -> handle_agent_registration:
return handle_agent_registration(machine_repository, agent_repository, node_repository)
def build_get_machines_by_ip(ip_to_match: IPv4Address, machine_to_return: Machine):
def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]:
if ip == ip_to_match:
return [machine_to_return]
raise UnknownRecordError
return get_machines_by_ip
def test_new_machine_added(handler, machine_repository):
@ -61,12 +83,10 @@ def test_new_machine_added(handler, machine_repository):
)
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
handler(AGENT_REGISTRATION_DATA)
machine_repository.upsert_machine.assert_called_once()
new_machine = machine_repository.upsert_machine.call_args_list[0][0][0]
assert new_machine == expected_machine
machine_repository.upsert_machine.assert_any_call(expected_machine)
def test_existing_machine_updated__hardware_id(handler, machine_repository):
@ -79,10 +99,10 @@ def test_existing_machine_updated__hardware_id(handler, machine_repository):
],
)
machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE)
handler(AGENT_REGISTRATION_DATA)
machine_repository.upsert_machine.assert_called_once()
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
machine_repository.upsert_machine.assert_any_call(expected_updated_machine)
def test_existing_machine_updated__find_by_ip(handler, machine_repository):
@ -104,11 +124,9 @@ def test_existing_machine_updated__find_by_ip(handler, machine_repository):
network_interfaces=[agent_registration_data.network_interfaces[-1]],
)
def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]:
if ip == existing_machine.network_interfaces[0].ip:
return [existing_machine]
raise UnknownRecordError
get_machines_by_ip = build_get_machines_by_ip(
existing_machine.network_interfaces[0].ip, existing_machine
)
expected_updated_machine = existing_machine.copy()
expected_updated_machine.hardware_id = agent_registration_data.machine_hardware_id
@ -119,8 +137,7 @@ def test_existing_machine_updated__find_by_ip(handler, machine_repository):
handler(agent_registration_data)
machine_repository.upsert_machine.assert_called_once()
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
machine_repository.upsert_machine.assert_any_call(expected_updated_machine)
def test_hardware_id_mismatch(handler, machine_repository):
@ -148,3 +165,39 @@ def test_add_agent(handler, agent_repository):
handler(AGENT_REGISTRATION_DATA)
agent_repository.upsert_agent.assert_called_with(expected_agent)
def test_add_node_connection(handler, machine_repository, node_repository):
island_machine = Machine(
id=1,
hardware_id=99,
island=True,
network_interfaces=[IPv4Interface("192.168.1.1/24")],
)
get_machines_by_ip = build_get_machines_by_ip(
island_machine.network_interfaces[0].ip, island_machine
)
machine_repository.get_machines_by_ip = MagicMock(side_effect=get_machines_by_ip)
machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE)
handler(AGENT_REGISTRATION_DATA)
node_repository.upsert_communication.assert_called_once()
node_repository.upsert_communication.assert_called_with(
MACHINE.id, island_machine.id, CommunicationType.CC
)
def test_add_node_connection__unknown_server(handler, machine_repository, node_repository):
expected_new_server_machine = Machine(
id=SEED_ID,
network_interfaces=[IPv4Interface("192.168.1.1/32")],
)
machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE)
handler(AGENT_REGISTRATION_DATA)
machine_repository.upsert_machine.assert_called_with(expected_new_server_machine)
node_repository.upsert_communication.assert_called_with(
MACHINE.id, SEED_ID, CommunicationType.CC
)