diff --git a/monkey/monkey_island/cc/agent_event_handlers/__init__.py b/monkey/monkey_island/cc/agent_event_handlers/__init__.py index b04d81961..7c0a4afc1 100644 --- a/monkey/monkey_island/cc/agent_event_handlers/__init__.py +++ b/monkey/monkey_island/cc/agent_event_handlers/__init__.py @@ -1,3 +1,4 @@ from .save_event_to_event_repository import save_event_to_event_repository from .save_stolen_credentials_to_repository import save_stolen_credentials_to_repository from .scan_event_handler import ScanEventHandler +from .update_nodes_on_exploitation import update_nodes_on_exploitation diff --git a/monkey/monkey_island/cc/agent_event_handlers/utils.py b/monkey/monkey_island/cc/agent_event_handlers/utils.py new file mode 100644 index 000000000..f7a32fa8a --- /dev/null +++ b/monkey/monkey_island/cc/agent_event_handlers/utils.py @@ -0,0 +1,17 @@ +from ipaddress import IPv4Address, IPv4Interface + +from monkey_island.cc.models import Machine +from monkey_island.cc.repository import IMachineRepository, UnknownRecordError + + +def get_or_create_target_machine(repository: IMachineRepository, target: IPv4Address): + try: + target_machines = repository.get_machines_by_ip(target) + return target_machines[0] + except UnknownRecordError: + machine = Machine( + id=repository.get_new_id(), + network_interfaces=[IPv4Interface(target)], + ) + repository.upsert_machine(machine) + return machine diff --git a/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_utils.py b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_utils.py new file mode 100644 index 000000000..fa7a128b0 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/agent_event_handlers/test_utils.py @@ -0,0 +1,46 @@ +from ipaddress import IPv4Address, IPv4Interface +from unittest.mock import MagicMock + +import pytest + +from monkey_island.cc.agent_event_handlers.utils import get_or_create_target_machine +from monkey_island.cc.models import Machine +from monkey_island.cc.repository import IMachineRepository, UnknownRecordError + +SEED_ID = 99 +IP_ADDRESS = IPv4Address("10.10.10.99") + +EXISTING_MACHINE = Machine( + id=1, + hardware_id=5, + network_interfaces=[IPv4Interface(IP_ADDRESS)], +) + +EXPECTED_CREATED_MACHINE = Machine( + id=SEED_ID, + network_interfaces=[IPv4Interface(IP_ADDRESS)], +) + + +@pytest.fixture +def machine_repository() -> IMachineRepository: + machine_repository = MagicMock(spec=IMachineRepository) + machine_repository.get_new_id = MagicMock(return_value=SEED_ID) + return machine_repository + + +def test_return_existing_machine(machine_repository): + machine_repository.get_machines_by_ip = MagicMock(return_value=[EXISTING_MACHINE]) + + target_machine = get_or_create_target_machine(machine_repository, IP_ADDRESS) + + assert target_machine == EXISTING_MACHINE + + +def test_create_new_machine(machine_repository): + machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError) + + target_machine = get_or_create_target_machine(machine_repository, IP_ADDRESS) + + assert target_machine == EXPECTED_CREATED_MACHINE + assert machine_repository.upsert_machine.called_once_with(target_machine)