diff --git a/monkey/monkey_island/cc/services/reporting/report.py b/monkey/monkey_island/cc/services/reporting/report.py index 52ed04df9..729967aba 100644 --- a/monkey/monkey_island/cc/services/reporting/report.py +++ b/monkey/monkey_island/cc/services/reporting/report.py @@ -1,16 +1,20 @@ import functools import ipaddress import logging -from itertools import chain, product -from typing import List +from itertools import chain, filterfalse, product, tee +from typing import Iterable, List, Optional from common.network.network_range import NetworkRange from common.network.network_utils import get_my_ip_addresses_legacy, get_network_interfaces from common.network.segmentation_utils import get_ip_in_src_and_not_in_dst from monkey_island.cc.database import mongo -from monkey_island.cc.models import Monkey +from monkey_island.cc.models import Machine, Monkey from monkey_island.cc.models.report import get_report, save_report -from monkey_island.cc.repository import IAgentConfigurationRepository, ICredentialsRepository +from monkey_island.cc.repository import ( + IAgentConfigurationRepository, + ICredentialsRepository, + IMachineRepository, +) from monkey_island.cc.services.node import NodeService from monkey_island.cc.services.reporting.exploitations.manual_exploitation import get_manual_monkeys from monkey_island.cc.services.reporting.exploitations.monkey_exploitation import ( @@ -31,9 +35,10 @@ logger = logging.getLogger(__name__) class ReportService: - _aws_service = None - _agent_configuration_repository = None - _credentials_repository = None + _aws_service: Optional[AWSService] = None + _agent_configuration_repository: Optional[IAgentConfigurationRepository] = None + _credentials_repository: Optional[ICredentialsRepository] = None + _machine_repository: Optional[IMachineRepository] = None class DerivedIssueEnum: ZEROLOGON_PASS_RESTORE_FAILED = "zerologon_pass_restore_failed" @@ -44,10 +49,12 @@ class ReportService: aws_service: AWSService, agent_configuration_repository: IAgentConfigurationRepository, credentials_repository: ICredentialsRepository, + machine_repository: IMachineRepository, ): cls._aws_service = aws_service cls._agent_configuration_repository = agent_configuration_repository cls._credentials_repository = credentials_repository + cls._machine_repository = machine_repository # This should pull from Simulation entity @staticmethod @@ -123,6 +130,18 @@ class ReportService: return formatted_nodes + @classmethod + def get_all_machines(cls) -> Iterable[Machine]: + if cls._machine_repository is None: + return iter(()) + machines = cls._machine_repository.get_machines() + t1, t2 = tee(machines) + + def is_island(machine: Machine): + return machine.island + + return chain(filter(is_island, t1), *filterfalse(is_island, t2)) + @staticmethod def get_all_displayed_nodes(): nodes_without_monkeys = [