diff --git a/monkey/common/events/abstract_event.py b/monkey/common/events/abstract_event.py index 33bb25506..4b4edbf99 100644 --- a/monkey/common/events/abstract_event.py +++ b/monkey/common/events/abstract_event.py @@ -1,8 +1,9 @@ +import time from abc import ABC -from dataclasses import dataclass +from dataclasses import dataclass, field from ipaddress import IPv4Address from typing import FrozenSet, Union -from uuid import UUID +from uuid import UUID, getnode @dataclass(frozen=True) @@ -21,7 +22,7 @@ class AbstractEvent(ABC): :param tags: The set of tags associated with the event """ - source: UUID - target: Union[UUID, IPv4Address, None] - timestamp: float - tags: FrozenSet[str] + source: UUID = field(default_factory=getnode) + target: Union[UUID, IPv4Address, None] = field(default=None) + timestamp: float = field(default_factory=time.time) + tags: FrozenSet[str] = field(default_factory=frozenset) diff --git a/monkey/common/events/credentials_stolen_events.py b/monkey/common/events/credentials_stolen_events.py index f1db1c142..ffd7707e7 100644 --- a/monkey/common/events/credentials_stolen_events.py +++ b/monkey/common/events/credentials_stolen_events.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Sequence from common.credentials import Credentials @@ -15,4 +15,4 @@ class CredentialsStolenEvent(AbstractEvent): :param stolen_credentials: The credentials that were stolen by an agent """ - stolen_credentials: Sequence[Credentials] + stolen_credentials: Sequence[Credentials] = field(default_factory=list) diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py index cf18a8efc..b696adf40 100644 --- a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py +++ b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py @@ -1,7 +1,8 @@ import logging -from typing import Dict, Iterable, Sequence +from typing import Sequence -from common.credentials import Credentials, SSHKeypair, Username +from common.credentials import Credentials +from common.event_queue import IEventQueue from infection_monkey.credential_collectors.ssh_collector import ssh_handler from infection_monkey.i_puppet import ICredentialCollector from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger @@ -14,38 +15,13 @@ class SSHCredentialCollector(ICredentialCollector): SSH keys credential collector """ - def __init__(self, telemetry_messenger: ITelemetryMessenger): + def __init__(self, telemetry_messenger: ITelemetryMessenger, event_queue: IEventQueue): self._telemetry_messenger = telemetry_messenger + self._event_queue = event_queue def collect_credentials(self, _options=None) -> Sequence[Credentials]: logger.info("Started scanning for SSH credentials") - ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger) + ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger, self._event_queue) logger.info("Finished scanning for SSH credentials") - return SSHCredentialCollector._to_credentials(ssh_info) - - @staticmethod - def _to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]: - ssh_credentials = [] - - for info in ssh_info: - identity = None - secret = None - - if info.get("name", ""): - identity = Username(info["name"]) - - ssh_keypair = {} - for key in ["public_key", "private_key"]: - if info.get(key) is not None: - ssh_keypair[key] = info[key] - - if len(ssh_keypair): - secret = SSHKeypair( - ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "") - ) - - if any([identity, secret]): - ssh_credentials.append(Credentials(identity, secret)) - - return ssh_credentials + return ssh_handler.to_credentials(ssh_info) diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py index 98ca0df4a..097d59f40 100644 --- a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py +++ b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py @@ -1,8 +1,11 @@ import glob import logging import os -from typing import Dict, Iterable +from typing import Dict, Iterable, Sequence +from common.credentials import Credentials, SSHKeypair, Username +from common.event_queue import IEventQueue +from common.events import CredentialsStolenEvent from common.utils.attack_utils import ScanStatus from infection_monkey.telemetry.attack.t1005_telem import T1005Telem from infection_monkey.telemetry.attack.t1145_telem import T1145Telem @@ -12,9 +15,22 @@ from infection_monkey.utils.environment import is_windows_os logger = logging.getLogger(__name__) DEFAULT_DIRS = ["/.ssh/", "/"] +SSH_CREDENTIAL_COLLECTOR_TAG = "ssh-credentials-collector" +T1003_ATTACK_TECHNIQUE_TAG = "attack-t1003" +T1005_ATTACK_TECHNIQUE_TAG = "attack-t1005" +T1145_ATTACK_TECHNIQUE_TAG = "attack-t1145" + +SSH_COLLECTOR_EVENT_TAG = { + SSH_CREDENTIAL_COLLECTOR_TAG, + T1003_ATTACK_TECHNIQUE_TAG, + T1005_ATTACK_TECHNIQUE_TAG, + T1145_ATTACK_TECHNIQUE_TAG, +} -def get_ssh_info(telemetry_messenger: ITelemetryMessenger) -> Iterable[Dict]: +def get_ssh_info( + telemetry_messenger: ITelemetryMessenger, event_queue: IEventQueue +) -> Iterable[Dict]: # TODO: Remove this check when this is turned into a plugin. if is_windows_os(): logger.debug( @@ -23,7 +39,7 @@ def get_ssh_info(telemetry_messenger: ITelemetryMessenger) -> Iterable[Dict]: return [] home_dirs = _get_home_dirs() - ssh_info = _get_ssh_files(home_dirs, telemetry_messenger) + ssh_info = _get_ssh_files(home_dirs, telemetry_messenger, event_queue) return ssh_info @@ -62,9 +78,9 @@ def _get_ssh_struct(name: str, home_dir: str) -> Dict: def _get_ssh_files( - usr_info: Iterable[Dict], telemetry_messenger: ITelemetryMessenger + user_info: Iterable[Dict], telemetry_messenger: ITelemetryMessenger, event_queue: IEventQueue ) -> Iterable[Dict]: - for info in usr_info: + for info in user_info: path = info["home_dir"] for directory in DEFAULT_DIRS: # TODO: Use PATH @@ -101,6 +117,11 @@ def _get_ssh_files( ScanStatus.USED, info["name"], info["home_dir"] ) ) + + collected_credentials = to_credentials([info]) + _publish_credentials_stolen_event( + collected_credentials, event_queue + ) else: continue except (IOError, OSError): @@ -112,5 +133,40 @@ def _get_ssh_files( pass except OSError: pass - usr_info = [info for info in usr_info if info["private_key"] or info["public_key"]] - return usr_info + user_info = [info for info in user_info if info["private_key"] or info["public_key"]] + return user_info + + +def to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]: + ssh_credentials = [] + + for info in ssh_info: + identity = None + secret = None + + if info.get("name", ""): + identity = Username(info["name"]) + + ssh_keypair = {} + for key in ["public_key", "private_key"]: + if info.get(key) is not None: + ssh_keypair[key] = info[key] + + if len(ssh_keypair): + secret = SSHKeypair( + ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "") + ) + + if any([identity, secret]): + ssh_credentials.append(Credentials(identity, secret)) + + return ssh_credentials + + +def _publish_credentials_stolen_event(collected_credentials: Credentials, event_queue: IEventQueue): + credentials_stolen_event = CredentialsStolenEvent( + tags=frozenset(SSH_COLLECTOR_EVENT_TAG), + stolen_credentials=[collected_credentials], + ) + + event_queue.publish(credentials_stolen_event) diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 258d05a36..48cc70c67 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -9,7 +9,7 @@ from typing import List from pubsub.core import Publisher import infection_monkey.tunnel as tunnel -from common.event_queue import PyPubSubEventQueue +from common.event_queue import IEventQueue, PyPubSubEventQueue from common.events import CredentialsStolenEvent from common.network.network_utils import address_to_ip_port from common.utils.argparse_types import positive_int @@ -199,18 +199,18 @@ class InfectionMonkey: def _build_master(self): local_network_interfaces = InfectionMonkey._get_local_network_interfaces() - _event_queue = PyPubSubEventQueue(Publisher()) - _event_queue.subscribe_type( - CredentialsStolenEvent, add_credentials_from_event_to_propagation_credentials_repository - ) - # TODO control_channel and control_client have same responsibilities, merge them control_channel = ControlChannel( self._control_client.server_address, GUID, self._control_client.proxies ) - credentials_store = AggregatingPropagationCredentialsRepository(control_channel) + propagation_credentials_repository = AggregatingPropagationCredentialsRepository( + control_channel + ) - puppet = self._build_puppet(credentials_store) + event_queue = PyPubSubEventQueue(Publisher()) + InfectionMonkey._subscribe_events(event_queue, propagation_credentials_repository) + + puppet = self._build_puppet(propagation_credentials_repository, event_queue) victim_host_factory = self._build_victim_host_factory(local_network_interfaces) @@ -218,7 +218,7 @@ class InfectionMonkey: ExploitInterceptingTelemetryMessenger( self._telemetry_messenger, self._monkey_inbound_tunnel ), - credentials_store, + propagation_credentials_repository, ) self._master = AutomatedMaster( @@ -228,7 +228,19 @@ class InfectionMonkey: victim_host_factory, control_channel, local_network_interfaces, - credentials_store, + propagation_credentials_repository, + ) + + @staticmethod + def _subscribe_events( + event_queue: IEventQueue, + propagation_credentials_repository: IPropagationCredentialsRepository, + ): + event_queue.subscribe_type( + CredentialsStolenEvent, + add_credentials_from_event_to_propagation_credentials_repository( + propagation_credentials_repository + ), ) @staticmethod @@ -239,7 +251,11 @@ class InfectionMonkey: return local_network_interfaces - def _build_puppet(self, credentials_store: IPropagationCredentialsRepository) -> IPuppet: + def _build_puppet( + self, + propagation_credentials_repository: IPropagationCredentialsRepository, + event_queue: IEventQueue, + ) -> IPuppet: puppet = Puppet() puppet.load_plugin( @@ -249,7 +265,7 @@ class InfectionMonkey: ) puppet.load_plugin( "SSHCollector", - SSHCredentialCollector(self._telemetry_messenger), + SSHCredentialCollector(self._telemetry_messenger, event_queue), PluginType.CREDENTIAL_COLLECTOR, ) @@ -281,7 +297,7 @@ class InfectionMonkey: ) zerologon_telemetry_messenger = CredentialsInterceptingTelemetryMessenger( - self._telemetry_messenger, credentials_store + self._telemetry_messenger, propagation_credentials_repository ) zerologon_wrapper = ExploiterWrapper(zerologon_telemetry_messenger, agent_repository) puppet.load_plugin( diff --git a/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_ssh_credentials_collector.py b/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_ssh_credentials_collector.py index c6d2a869d..ba12e416f 100644 --- a/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_ssh_credentials_collector.py +++ b/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_ssh_credentials_collector.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock import pytest from common.credentials import Credentials, SSHKeypair, Username +from common.event_queue import IEventQueue from infection_monkey.credential_collectors import SSHCredentialCollector @@ -14,7 +15,7 @@ def patch_telemetry_messenger(): def patch_ssh_handler(ssh_creds, monkeypatch): monkeypatch.setattr( "infection_monkey.credential_collectors.ssh_collector.ssh_handler.get_ssh_info", - lambda _: ssh_creds, + lambda _, __: ssh_creds, ) @@ -23,7 +24,9 @@ def patch_ssh_handler(ssh_creds, monkeypatch): ) def test_ssh_credentials_empty_results(monkeypatch, ssh_creds, patch_telemetry_messenger): patch_ssh_handler(ssh_creds, monkeypatch) - collected = SSHCredentialCollector(patch_telemetry_messenger).collect_credentials() + collected = SSHCredentialCollector( + patch_telemetry_messenger, MagicMock(spec=IEventQueue) + ).collect_credentials() assert not collected @@ -67,5 +70,7 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger): Credentials(identity=username3, secret=None), Credentials(identity=None, secret=ssh_keypair3), ] - collected = SSHCredentialCollector(patch_telemetry_messenger).collect_credentials() + collected = SSHCredentialCollector( + patch_telemetry_messenger, MagicMock(spec=IEventQueue) + ).collect_credentials() assert expected == collected