From 3d27e42ff377131a0d03329ea5b912dfe185d950 Mon Sep 17 00:00:00 2001 From: vakarisz Date: Fri, 30 Sep 2022 11:46:05 +0300 Subject: [PATCH] Island: Simplify the report of scanned machines --- .../cc/services/reporting/report.py | 71 +++++++------------ .../cc/services/reporting/test_report.py | 55 ++++++++++++-- 2 files changed, 76 insertions(+), 50 deletions(-) diff --git a/monkey/monkey_island/cc/services/reporting/report.py b/monkey/monkey_island/cc/services/reporting/report.py index aa70b3895..20991d532 100644 --- a/monkey/monkey_island/cc/services/reporting/report.py +++ b/monkey/monkey_island/cc/services/reporting/report.py @@ -1,8 +1,8 @@ import functools import ipaddress import logging -from itertools import chain, filterfalse, product, tee -from typing import Iterable, List, Optional +from itertools import chain, product +from typing import List, Optional from common.network.network_range import NetworkRange from common.network.network_utils import get_my_ip_addresses_legacy, get_network_interfaces @@ -35,7 +35,6 @@ logger = logging.getLogger(__name__) class ReportService: - _aws_service: Optional[AWSService] = None _agent_configuration_repository: Optional[IAgentConfigurationRepository] = None _credentials_repository: Optional[ICredentialsRepository] = None @@ -113,58 +112,42 @@ class ReportService: def get_scanned(): formatted_nodes = [] - machines = ReportService.get_all_machines() + machines = ReportService._machine_repository.get_machines() for machine in machines: - # This information should be evident from the map, not sure a table/list is a good way - # to display it anyways - addresses = [str(iface.ip) for iface in machine.network_interfaces] - accessible_machines = [ - m.hostname for m in ReportService.get_accessible_machines(machine) - ] - formatted_nodes.append( - { - "label": machine.hostname, - "ip_addresses": addresses, - "accessible_from_nodes": accessible_machines, - "services": [], - "domain_name": "", - "pba_results": "None", - } - ) - - logger.info("Scanned nodes generated for reporting") + accessible_from = ReportService.get_scanners_of_machine(machine) + if accessible_from: + formatted_nodes.append( + { + "hostname": machine.hostname, + "ip_addresses": machine.network_interfaces, + "accessible_from_nodes": accessible_from, + "domain_name": "", + # TODO add services + "services": [], + } + ) return formatted_nodes @classmethod - def get_accessible_machines(cls, machine: Machine): - if cls._node_repository is None: + def get_scanners_of_machine(cls, machine: Machine) -> List[Machine]: + if not cls._node_repository: raise RuntimeError("Node repository does not exist") - elif cls._machine_repository is None: + if not cls._machine_repository: raise RuntimeError("Machine repository does not exist") nodes = cls._node_repository.get_nodes() - machine_iter = (node for node in nodes if node.machine_id == machine.id) - accessible_machines = set() - for source in machine_iter: - for dest, conn in source.connections.items(): - if CommunicationType.SCANNED in conn: - accessible_machines.add(dest) + scanner_machines = set() + for node in nodes: + for dest, conn in node.connections.items(): + if CommunicationType.SCANNED in conn and dest == machine.id: + scanner_machine = ReportService._machine_repository.get_machine_by_id( + node.machine_id + ) + scanner_machines.add(scanner_machine) - return [cls._machine_repository.get_machine_by_id(id) for id in accessible_machines] - - @classmethod - def get_all_machines(cls) -> Iterable[Machine]: - if cls._machine_repository is None: - raise RuntimeError("Machine repository does not exist") - machines = cls._machine_repository.get_machines() - t1, t2 = tee(machines) - - def is_island(machine: Machine): - return machine.island - - return chain(filter(is_island, t1), filterfalse(is_island, t2)) + return list(scanner_machines) @staticmethod def process_exploit(exploit) -> ExploiterReportInfo: diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/reporting/test_report.py b/monkey/tests/unit_tests/monkey_island/cc/services/reporting/test_report.py index 1afad4d95..c445bbed7 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/reporting/test_report.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/reporting/test_report.py @@ -1,34 +1,77 @@ from ipaddress import IPv4Interface +from unittest.mock import MagicMock from monkey_island.cc.models import CommunicationType, Machine, Node +from monkey_island.cc.services.reporting.report import ReportService ISLAND_MACHINE = Machine( id=99, island=True, + hostname="Island", hardware_id=5, network_interfaces=[IPv4Interface("10.10.10.99/24")], ) -MACHINE_A = Machine( +MACHINE_1 = Machine( id=1, hardware_id=9, + hostname="machine_1", network_interfaces=[IPv4Interface("10.10.10.1/24")], ) -MACHINE_B = Machine( +MACHINE_2 = Machine( id=2, hardware_id=9, network_interfaces=[IPv4Interface("10.10.10.2/24")], ) -MACHINE_C = Machine( +MACHINE_3 = Machine( id=3, hardware_id=9, network_interfaces=[IPv4Interface("10.10.10.3/24")], ) NODES = [ - Node(machine_id=1, connections={"2": {CommunicationType.EXPLOITED}}), - Node(machine_id=99, connections={"1": {CommunicationType.SCANNED}}), - Node(machine_id=3, connections={"99": {CommunicationType.CC}}), + Node( + machine_id=1, + connections={"2": frozenset([CommunicationType.EXPLOITED, CommunicationType.SCANNED])}, + ), + Node(machine_id=99, connections={"1": frozenset([CommunicationType.SCANNED])}), + Node( + machine_id=3, + connections={"99": frozenset([CommunicationType.CC, CommunicationType.EXPLOITED])}, + ), ] + +MACHINES = [MACHINE_1, MACHINE_2, MACHINE_3, ISLAND_MACHINE] + +EXPECTED_SCANNED_MACHINES = [ + { + "hostname": MACHINE_1.hostname, + "ip_addresses": MACHINE_1.network_interfaces, + "accessible_from_nodes": [ISLAND_MACHINE], + "services": [], + "domain_name": "", + }, + { + "hostname": MACHINE_2.hostname, + "ip_addresses": MACHINE_2.network_interfaces, + "accessible_from_nodes": [MACHINE_1], + "services": [], + "domain_name": "", + }, +] + + +def get_machine_by_id(machine_id): + return [machine for machine in MACHINES if machine_id == machine.id][0] + + +def test_get_scanned(): + ReportService._node_repository = MagicMock() + ReportService._node_repository.get_nodes.return_value = NODES + ReportService._machine_repository = MagicMock() + ReportService._machine_repository.get_machines.return_value = MACHINES + ReportService._machine_repository.get_machine_by_id = get_machine_by_id + scanned = ReportService.get_scanned() + assert scanned == EXPECTED_SCANNED_MACHINES