Island: Add new node communication on agent registration

This commit is contained in:
Mike Salvatore 2022-09-21 12:04:37 -04:00
parent 1e8a60c890
commit c0870e6696
2 changed files with 109 additions and 20 deletions

View File

@ -1,9 +1,16 @@
from contextlib import suppress from contextlib import suppress
from ipaddress import IPv4Address, IPv4Interface
from typing import Optional from typing import Optional
from common import AgentRegistrationData from common import AgentRegistrationData
from monkey_island.cc.models import Agent, Machine from common.network.network_utils import address_to_ip_port
from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError from monkey_island.cc.models import Agent, CommunicationType, Machine
from monkey_island.cc.repository import (
IAgentRepository,
IMachineRepository,
INodeRepository,
UnknownRecordError,
)
class handle_agent_registration: class handle_agent_registration:
@ -11,13 +18,20 @@ class handle_agent_registration:
Update repositories when a new agent registers Update repositories when a new agent registers
""" """
def __init__(self, machine_repository: IMachineRepository, agent_repository: IAgentRepository): def __init__(
self,
machine_repository: IMachineRepository,
agent_repository: IAgentRepository,
node_repository: INodeRepository,
):
self._machine_repository = machine_repository self._machine_repository = machine_repository
self._agent_repository = agent_repository self._agent_repository = agent_repository
self._node_repository = node_repository
def __call__(self, agent_registration_data: AgentRegistrationData): def __call__(self, agent_registration_data: AgentRegistrationData):
machine = self._update_machine_repository(agent_registration_data) machine = self._update_machine_repository(agent_registration_data)
self._add_agent(agent_registration_data, machine) self._add_agent(agent_registration_data, machine)
self._add_node_communication(agent_registration_data, machine)
def _update_machine_repository(self, agent_registration_data: AgentRegistrationData) -> Machine: def _update_machine_repository(self, agent_registration_data: AgentRegistrationData) -> Machine:
machine = self._find_existing_machine_to_update(agent_registration_data) machine = self._find_existing_machine_to_update(agent_registration_data)
@ -86,3 +100,25 @@ class handle_agent_registration:
cc_server=agent_registration_data.cc_server, cc_server=agent_registration_data.cc_server,
) )
self._agent_repository.upsert_agent(new_agent) self._agent_repository.upsert_agent(new_agent)
def _add_node_communication(
self, agent_registration_data: AgentRegistrationData, src_machine: Machine
):
dst_machine = self._get_or_create_cc_machine(agent_registration_data.cc_server)
self._node_repository.upsert_communication(
src_machine.id, dst_machine.id, CommunicationType.CC
)
def _get_or_create_cc_machine(self, cc_server: str) -> Machine:
dst_ip = IPv4Address(address_to_ip_port(cc_server)[0])
try:
return self._machine_repository.get_machines_by_ip(dst_ip)[0]
except UnknownRecordError:
new_machine = Machine(
id=self._machine_repository.get_new_id(), network_interfaces=[IPv4Interface(dst_ip)]
)
self._machine_repository.upsert_machine(new_machine)
return new_machine

View File

@ -8,8 +8,13 @@ import pytest
from common import AgentRegistrationData from common import AgentRegistrationData
from monkey_island.cc.island_event_handlers import handle_agent_registration from monkey_island.cc.island_event_handlers import handle_agent_registration
from monkey_island.cc.models import Agent, Machine from monkey_island.cc.models import Agent, CommunicationType, Machine
from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError from monkey_island.cc.repository import (
IAgentRepository,
IMachineRepository,
INodeRepository,
UnknownRecordError,
)
AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e") AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e")
@ -49,8 +54,25 @@ def agent_repository() -> IAgentRepository:
@pytest.fixture @pytest.fixture
def handler(machine_repository, agent_repository) -> handle_agent_registration: def node_repository() -> INodeRepository:
return handle_agent_registration(machine_repository, agent_repository) node_repository = MagicMock(spec=INodeRepository)
node_repository.upsert_communication = MagicMock()
return node_repository
@pytest.fixture
def handler(machine_repository, agent_repository, node_repository) -> handle_agent_registration:
return handle_agent_registration(machine_repository, agent_repository, node_repository)
def build_get_machines_by_ip(ip_to_match: IPv4Address, machine_to_return: Machine):
def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]:
if ip == ip_to_match:
return [machine_to_return]
raise UnknownRecordError
return get_machines_by_ip
def test_new_machine_added(handler, machine_repository): def test_new_machine_added(handler, machine_repository):
@ -61,12 +83,10 @@ def test_new_machine_added(handler, machine_repository):
) )
machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError) machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError) machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
handler(AGENT_REGISTRATION_DATA) handler(AGENT_REGISTRATION_DATA)
machine_repository.upsert_machine.assert_called_once() machine_repository.upsert_machine.assert_any_call(expected_machine)
new_machine = machine_repository.upsert_machine.call_args_list[0][0][0]
assert new_machine == expected_machine
def test_existing_machine_updated__hardware_id(handler, machine_repository): def test_existing_machine_updated__hardware_id(handler, machine_repository):
@ -79,10 +99,10 @@ def test_existing_machine_updated__hardware_id(handler, machine_repository):
], ],
) )
machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE) machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE)
handler(AGENT_REGISTRATION_DATA) handler(AGENT_REGISTRATION_DATA)
machine_repository.upsert_machine.assert_called_once() machine_repository.upsert_machine.assert_any_call(expected_updated_machine)
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
def test_existing_machine_updated__find_by_ip(handler, machine_repository): def test_existing_machine_updated__find_by_ip(handler, machine_repository):
@ -104,11 +124,9 @@ def test_existing_machine_updated__find_by_ip(handler, machine_repository):
network_interfaces=[agent_registration_data.network_interfaces[-1]], network_interfaces=[agent_registration_data.network_interfaces[-1]],
) )
def get_machines_by_ip(ip: IPv4Address) -> Sequence[Machine]: get_machines_by_ip = build_get_machines_by_ip(
if ip == existing_machine.network_interfaces[0].ip: existing_machine.network_interfaces[0].ip, existing_machine
return [existing_machine] )
raise UnknownRecordError
expected_updated_machine = existing_machine.copy() expected_updated_machine = existing_machine.copy()
expected_updated_machine.hardware_id = agent_registration_data.machine_hardware_id expected_updated_machine.hardware_id = agent_registration_data.machine_hardware_id
@ -119,8 +137,7 @@ def test_existing_machine_updated__find_by_ip(handler, machine_repository):
handler(agent_registration_data) handler(agent_registration_data)
machine_repository.upsert_machine.assert_called_once() machine_repository.upsert_machine.assert_any_call(expected_updated_machine)
machine_repository.upsert_machine.assert_called_with(expected_updated_machine)
def test_hardware_id_mismatch(handler, machine_repository): def test_hardware_id_mismatch(handler, machine_repository):
@ -148,3 +165,39 @@ def test_add_agent(handler, agent_repository):
handler(AGENT_REGISTRATION_DATA) handler(AGENT_REGISTRATION_DATA)
agent_repository.upsert_agent.assert_called_with(expected_agent) agent_repository.upsert_agent.assert_called_with(expected_agent)
def test_add_node_connection(handler, machine_repository, node_repository):
island_machine = Machine(
id=1,
hardware_id=99,
island=True,
network_interfaces=[IPv4Interface("192.168.1.1/24")],
)
get_machines_by_ip = build_get_machines_by_ip(
island_machine.network_interfaces[0].ip, island_machine
)
machine_repository.get_machines_by_ip = MagicMock(side_effect=get_machines_by_ip)
machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE)
handler(AGENT_REGISTRATION_DATA)
node_repository.upsert_communication.assert_called_once()
node_repository.upsert_communication.assert_called_with(
MACHINE.id, island_machine.id, CommunicationType.CC
)
def test_add_node_connection__unknown_server(handler, machine_repository, node_repository):
expected_new_server_machine = Machine(
id=SEED_ID,
network_interfaces=[IPv4Interface("192.168.1.1/32")],
)
machine_repository.get_machine_by_hardware_id = MagicMock(return_value=MACHINE)
handler(AGENT_REGISTRATION_DATA)
machine_repository.upsert_machine.assert_called_with(expected_new_server_machine)
node_repository.upsert_communication.assert_called_with(
MACHINE.id, SEED_ID, CommunicationType.CC
)