Merge pull request #2196 from guardicore/2176-modify-ssh-collector-for-events

2176 modify ssh collector for events
This commit is contained in:
Mike Salvatore 2022-08-16 12:41:14 -04:00 committed by GitHub
commit 2edaf52140
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 116 additions and 62 deletions

View File

@ -1,8 +1,9 @@
import time
from abc import ABC
from dataclasses import dataclass
from dataclasses import dataclass, field
from ipaddress import IPv4Address
from typing import FrozenSet, Union
from uuid import UUID
from uuid import UUID, getnode
@dataclass(frozen=True)
@ -21,7 +22,7 @@ class AbstractEvent(ABC):
:param tags: The set of tags associated with the event
"""
source: UUID
target: Union[UUID, IPv4Address, None]
timestamp: float
tags: FrozenSet[str]
source: UUID = field(default_factory=getnode)
target: Union[UUID, IPv4Address, None] = field(default=None)
timestamp: float = field(default_factory=time.time)
tags: FrozenSet[str] = field(default_factory=frozenset)

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Sequence
from common.credentials import Credentials
@ -15,4 +15,4 @@ class CredentialsStolenEvent(AbstractEvent):
:param stolen_credentials: The credentials that were stolen by an agent
"""
stolen_credentials: Sequence[Credentials]
stolen_credentials: Sequence[Credentials] = field(default_factory=list)

View File

@ -1,7 +1,8 @@
import logging
from typing import Dict, Iterable, Sequence
from typing import Sequence
from common.credentials import Credentials, SSHKeypair, Username
from common.credentials import Credentials
from common.event_queue import IEventQueue
from infection_monkey.credential_collectors.ssh_collector import ssh_handler
from infection_monkey.i_puppet import ICredentialCollector
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
@ -14,38 +15,13 @@ class SSHCredentialCollector(ICredentialCollector):
SSH keys credential collector
"""
def __init__(self, telemetry_messenger: ITelemetryMessenger):
def __init__(self, telemetry_messenger: ITelemetryMessenger, event_queue: IEventQueue):
self._telemetry_messenger = telemetry_messenger
self._event_queue = event_queue
def collect_credentials(self, _options=None) -> Sequence[Credentials]:
logger.info("Started scanning for SSH credentials")
ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger)
ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger, self._event_queue)
logger.info("Finished scanning for SSH credentials")
return SSHCredentialCollector._to_credentials(ssh_info)
@staticmethod
def _to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
ssh_credentials = []
for info in ssh_info:
identity = None
secret = None
if info.get("name", ""):
identity = Username(info["name"])
ssh_keypair = {}
for key in ["public_key", "private_key"]:
if info.get(key) is not None:
ssh_keypair[key] = info[key]
if len(ssh_keypair):
secret = SSHKeypair(
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
)
if any([identity, secret]):
ssh_credentials.append(Credentials(identity, secret))
return ssh_credentials
return ssh_handler.to_credentials(ssh_info)

View File

@ -1,8 +1,11 @@
import glob
import logging
import os
from typing import Dict, Iterable
from typing import Dict, Iterable, Sequence
from common.credentials import Credentials, SSHKeypair, Username
from common.event_queue import IEventQueue
from common.events import CredentialsStolenEvent
from common.utils.attack_utils import ScanStatus
from infection_monkey.telemetry.attack.t1005_telem import T1005Telem
from infection_monkey.telemetry.attack.t1145_telem import T1145Telem
@ -12,9 +15,22 @@ from infection_monkey.utils.environment import is_windows_os
logger = logging.getLogger(__name__)
DEFAULT_DIRS = ["/.ssh/", "/"]
SSH_CREDENTIAL_COLLECTOR_TAG = "ssh-credentials-collector"
T1003_ATTACK_TECHNIQUE_TAG = "attack-t1003"
T1005_ATTACK_TECHNIQUE_TAG = "attack-t1005"
T1145_ATTACK_TECHNIQUE_TAG = "attack-t1145"
SSH_COLLECTOR_EVENT_TAG = {
SSH_CREDENTIAL_COLLECTOR_TAG,
T1003_ATTACK_TECHNIQUE_TAG,
T1005_ATTACK_TECHNIQUE_TAG,
T1145_ATTACK_TECHNIQUE_TAG,
}
def get_ssh_info(telemetry_messenger: ITelemetryMessenger) -> Iterable[Dict]:
def get_ssh_info(
telemetry_messenger: ITelemetryMessenger, event_queue: IEventQueue
) -> Iterable[Dict]:
# TODO: Remove this check when this is turned into a plugin.
if is_windows_os():
logger.debug(
@ -23,7 +39,7 @@ def get_ssh_info(telemetry_messenger: ITelemetryMessenger) -> Iterable[Dict]:
return []
home_dirs = _get_home_dirs()
ssh_info = _get_ssh_files(home_dirs, telemetry_messenger)
ssh_info = _get_ssh_files(home_dirs, telemetry_messenger, event_queue)
return ssh_info
@ -62,9 +78,9 @@ def _get_ssh_struct(name: str, home_dir: str) -> Dict:
def _get_ssh_files(
usr_info: Iterable[Dict], telemetry_messenger: ITelemetryMessenger
user_info: Iterable[Dict], telemetry_messenger: ITelemetryMessenger, event_queue: IEventQueue
) -> Iterable[Dict]:
for info in usr_info:
for info in user_info:
path = info["home_dir"]
for directory in DEFAULT_DIRS:
# TODO: Use PATH
@ -101,6 +117,11 @@ def _get_ssh_files(
ScanStatus.USED, info["name"], info["home_dir"]
)
)
collected_credentials = to_credentials([info])
_publish_credentials_stolen_event(
collected_credentials, event_queue
)
else:
continue
except (IOError, OSError):
@ -112,5 +133,40 @@ def _get_ssh_files(
pass
except OSError:
pass
usr_info = [info for info in usr_info if info["private_key"] or info["public_key"]]
return usr_info
user_info = [info for info in user_info if info["private_key"] or info["public_key"]]
return user_info
def to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
ssh_credentials = []
for info in ssh_info:
identity = None
secret = None
if info.get("name", ""):
identity = Username(info["name"])
ssh_keypair = {}
for key in ["public_key", "private_key"]:
if info.get(key) is not None:
ssh_keypair[key] = info[key]
if len(ssh_keypair):
secret = SSHKeypair(
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
)
if any([identity, secret]):
ssh_credentials.append(Credentials(identity, secret))
return ssh_credentials
def _publish_credentials_stolen_event(collected_credentials: Credentials, event_queue: IEventQueue):
credentials_stolen_event = CredentialsStolenEvent(
tags=frozenset(SSH_COLLECTOR_EVENT_TAG),
stolen_credentials=[collected_credentials],
)
event_queue.publish(credentials_stolen_event)

View File

@ -9,7 +9,7 @@ from typing import List
from pubsub.core import Publisher
import infection_monkey.tunnel as tunnel
from common.event_queue import PyPubSubEventQueue
from common.event_queue import IEventQueue, PyPubSubEventQueue
from common.events import CredentialsStolenEvent
from common.network.network_utils import address_to_ip_port
from common.utils.argparse_types import positive_int
@ -199,18 +199,18 @@ class InfectionMonkey:
def _build_master(self):
local_network_interfaces = InfectionMonkey._get_local_network_interfaces()
_event_queue = PyPubSubEventQueue(Publisher())
_event_queue.subscribe_type(
CredentialsStolenEvent, add_credentials_from_event_to_propagation_credentials_repository
)
# TODO control_channel and control_client have same responsibilities, merge them
control_channel = ControlChannel(
self._control_client.server_address, GUID, self._control_client.proxies
)
credentials_store = AggregatingPropagationCredentialsRepository(control_channel)
propagation_credentials_repository = AggregatingPropagationCredentialsRepository(
control_channel
)
puppet = self._build_puppet(credentials_store)
event_queue = PyPubSubEventQueue(Publisher())
InfectionMonkey._subscribe_events(event_queue, propagation_credentials_repository)
puppet = self._build_puppet(propagation_credentials_repository, event_queue)
victim_host_factory = self._build_victim_host_factory(local_network_interfaces)
@ -218,7 +218,7 @@ class InfectionMonkey:
ExploitInterceptingTelemetryMessenger(
self._telemetry_messenger, self._monkey_inbound_tunnel
),
credentials_store,
propagation_credentials_repository,
)
self._master = AutomatedMaster(
@ -228,7 +228,19 @@ class InfectionMonkey:
victim_host_factory,
control_channel,
local_network_interfaces,
credentials_store,
propagation_credentials_repository,
)
@staticmethod
def _subscribe_events(
event_queue: IEventQueue,
propagation_credentials_repository: IPropagationCredentialsRepository,
):
event_queue.subscribe_type(
CredentialsStolenEvent,
add_credentials_from_event_to_propagation_credentials_repository(
propagation_credentials_repository
),
)
@staticmethod
@ -239,7 +251,11 @@ class InfectionMonkey:
return local_network_interfaces
def _build_puppet(self, credentials_store: IPropagationCredentialsRepository) -> IPuppet:
def _build_puppet(
self,
propagation_credentials_repository: IPropagationCredentialsRepository,
event_queue: IEventQueue,
) -> IPuppet:
puppet = Puppet()
puppet.load_plugin(
@ -249,7 +265,7 @@ class InfectionMonkey:
)
puppet.load_plugin(
"SSHCollector",
SSHCredentialCollector(self._telemetry_messenger),
SSHCredentialCollector(self._telemetry_messenger, event_queue),
PluginType.CREDENTIAL_COLLECTOR,
)
@ -281,7 +297,7 @@ class InfectionMonkey:
)
zerologon_telemetry_messenger = CredentialsInterceptingTelemetryMessenger(
self._telemetry_messenger, credentials_store
self._telemetry_messenger, propagation_credentials_repository
)
zerologon_wrapper = ExploiterWrapper(zerologon_telemetry_messenger, agent_repository)
puppet.load_plugin(

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock
import pytest
from common.credentials import Credentials, SSHKeypair, Username
from common.event_queue import IEventQueue
from infection_monkey.credential_collectors import SSHCredentialCollector
@ -14,7 +15,7 @@ def patch_telemetry_messenger():
def patch_ssh_handler(ssh_creds, monkeypatch):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.ssh_handler.get_ssh_info",
lambda _: ssh_creds,
lambda _, __: ssh_creds,
)
@ -23,7 +24,9 @@ def patch_ssh_handler(ssh_creds, monkeypatch):
)
def test_ssh_credentials_empty_results(monkeypatch, ssh_creds, patch_telemetry_messenger):
patch_ssh_handler(ssh_creds, monkeypatch)
collected = SSHCredentialCollector(patch_telemetry_messenger).collect_credentials()
collected = SSHCredentialCollector(
patch_telemetry_messenger, MagicMock(spec=IEventQueue)
).collect_credentials()
assert not collected
@ -67,5 +70,7 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger):
Credentials(identity=username3, secret=None),
Credentials(identity=None, secret=ssh_keypair3),
]
collected = SSHCredentialCollector(patch_telemetry_messenger).collect_credentials()
collected = SSHCredentialCollector(
patch_telemetry_messenger, MagicMock(spec=IEventQueue)
).collect_credentials()
assert expected == collected