diff --git a/monkey/monkey_island/cc/services/reporting/report.py b/monkey/monkey_island/cc/services/reporting/report.py index 729967aba..353f19b19 100644 --- a/monkey/monkey_island/cc/services/reporting/report.py +++ b/monkey/monkey_island/cc/services/reporting/report.py @@ -8,12 +8,13 @@ from common.network.network_range import NetworkRange from common.network.network_utils import get_my_ip_addresses_legacy, get_network_interfaces from common.network.segmentation_utils import get_ip_in_src_and_not_in_dst from monkey_island.cc.database import mongo -from monkey_island.cc.models import Machine, Monkey +from monkey_island.cc.models import CommunicationType, Machine, Monkey from monkey_island.cc.models.report import get_report, save_report from monkey_island.cc.repository import ( IAgentConfigurationRepository, ICredentialsRepository, IMachineRepository, + INodeRepository, ) from monkey_island.cc.services.node import NodeService from monkey_island.cc.services.reporting.exploitations.manual_exploitation import get_manual_monkeys @@ -39,6 +40,7 @@ class ReportService: _agent_configuration_repository: Optional[IAgentConfigurationRepository] = None _credentials_repository: Optional[ICredentialsRepository] = None _machine_repository: Optional[IMachineRepository] = None + _node_repository: Optional[INodeRepository] = None class DerivedIssueEnum: ZEROLOGON_PASS_RESTORE_FAILED = "zerologon_pass_restore_failed" @@ -50,11 +52,13 @@ class ReportService: agent_configuration_repository: IAgentConfigurationRepository, credentials_repository: ICredentialsRepository, machine_repository: IMachineRepository, + node_repository: INodeRepository, ): cls._aws_service = aws_service cls._agent_configuration_repository = agent_configuration_repository cls._credentials_repository = credentials_repository cls._machine_repository = machine_repository + cls._node_repository = node_repository # This should pull from Simulation entity @staticmethod @@ -130,6 +134,21 @@ class ReportService: return formatted_nodes + @classmethod + def get_accessible_machines(cls, machine: Machine): + if cls._node_repository is None or cls._machine_repository is None: + return [] + + nodes = cls._node_repository.get_nodes() + machine_iter = (node for node in nodes if node.machine_id == machine.id) + accessible_machines = set() + for source in machine_iter: + for dest, conn in source.connections.items(): + if CommunicationType.SCANNED in conn: + accessible_machines.add(dest) + + return [cls._machine_repository.get_machine_by_id(id) for id in accessible_machines] + @classmethod def get_all_machines(cls) -> Iterable[Machine]: if cls._machine_repository is None: