Agent: Publish CredentialsStolenEvent each time we find a SSHKeypair

This commit is contained in:
Ilija Lazoroski 2022-08-15 17:09:32 +02:00
parent e439a53bde
commit b22ccdb942
3 changed files with 35 additions and 35 deletions

View File

@ -1,19 +1,14 @@
import logging import logging
import time
from typing import Dict, Iterable, Sequence from typing import Dict, Iterable, Sequence
from common.credentials import Credentials, SSHKeypair, Username from common.credentials import Credentials, SSHKeypair, Username
from common.event_queue import IEventQueue from common.event_queue import IEventQueue
from common.events import CredentialsStolenEvent
from infection_monkey.config import GUID
from infection_monkey.credential_collectors.ssh_collector import ssh_handler from infection_monkey.credential_collectors.ssh_collector import ssh_handler
from infection_monkey.i_puppet import ICredentialCollector from infection_monkey.i_puppet import ICredentialCollector
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SSH_CREDENTIAL_COLLECTOR_TAG = "SSHCredentialsStolen"
class SSHCredentialCollector(ICredentialCollector): class SSHCredentialCollector(ICredentialCollector):
""" """
@ -26,31 +21,12 @@ class SSHCredentialCollector(ICredentialCollector):
def collect_credentials(self, _options=None) -> Sequence[Credentials]: def collect_credentials(self, _options=None) -> Sequence[Credentials]:
logger.info("Started scanning for SSH credentials") logger.info("Started scanning for SSH credentials")
ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger) ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger, self._event_queue)
logger.info("Finished scanning for SSH credentials") logger.info("Finished scanning for SSH credentials")
ssh_collector_credentials = SSHCredentialCollector._to_credentials(ssh_info) ssh_collector_credentials = SSHCredentialCollector._to_credentials(ssh_info)
credentials_stolen_event = SSHCredentialCollector._generate_credentials_stolen_event(
ssh_collector_credentials
)
self._event_queue.publish(credentials_stolen_event)
return ssh_collector_credentials return ssh_collector_credentials
@staticmethod
def _generate_credentials_stolen_event(
collected_credentials: Sequence[Credentials],
) -> CredentialsStolenEvent:
credentials_stolen_event = CredentialsStolenEvent(
source=GUID,
target=None,
timestamp=time.time(),
tags=frozenset({SSH_CREDENTIAL_COLLECTOR_TAG, "T1005", "T1145"}),
stolen_credentials=collected_credentials,
)
return credentials_stolen_event
@staticmethod @staticmethod
def _to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]: def _to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
ssh_credentials = [] ssh_credentials = []

View File

@ -1,8 +1,13 @@
import glob import glob
import logging import logging
import os import os
import time
import uuid
from typing import Dict, Iterable from typing import Dict, Iterable
from common.credentials import Credentials, SSHKeypair, Username
from common.event_queue import IEventQueue
from common.events import CredentialsStolenEvent
from common.utils.attack_utils import ScanStatus from common.utils.attack_utils import ScanStatus
from infection_monkey.telemetry.attack.t1005_telem import T1005Telem from infection_monkey.telemetry.attack.t1005_telem import T1005Telem
from infection_monkey.telemetry.attack.t1145_telem import T1145Telem from infection_monkey.telemetry.attack.t1145_telem import T1145Telem
@ -12,9 +17,12 @@ from infection_monkey.utils.environment import is_windows_os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_DIRS = ["/.ssh/", "/"] DEFAULT_DIRS = ["/.ssh/", "/"]
SSH_CREDENTIAL_COLLECTOR_TAG = "SSHCredentialsStolen"
def get_ssh_info(telemetry_messenger: ITelemetryMessenger) -> Iterable[Dict]: def get_ssh_info(
telemetry_messenger: ITelemetryMessenger, event_queue: IEventQueue
) -> Iterable[Dict]:
# TODO: Remove this check when this is turned into a plugin. # TODO: Remove this check when this is turned into a plugin.
if is_windows_os(): if is_windows_os():
logger.debug( logger.debug(
@ -23,7 +31,7 @@ def get_ssh_info(telemetry_messenger: ITelemetryMessenger) -> Iterable[Dict]:
return [] return []
home_dirs = _get_home_dirs() home_dirs = _get_home_dirs()
ssh_info = _get_ssh_files(home_dirs, telemetry_messenger) ssh_info = _get_ssh_files(home_dirs, telemetry_messenger, event_queue)
return ssh_info return ssh_info
@ -62,7 +70,7 @@ def _get_ssh_struct(name: str, home_dir: str) -> Dict:
def _get_ssh_files( def _get_ssh_files(
usr_info: Iterable[Dict], telemetry_messenger: ITelemetryMessenger usr_info: Iterable[Dict], telemetry_messenger: ITelemetryMessenger, event_queue: IEventQueue
) -> Iterable[Dict]: ) -> Iterable[Dict]:
for info in usr_info: for info in usr_info:
path = info["home_dir"] path = info["home_dir"]
@ -101,6 +109,16 @@ def _get_ssh_files(
ScanStatus.USED, info["name"], info["home_dir"] ScanStatus.USED, info["name"], info["home_dir"]
) )
) )
collected_credentials = Credentials(
identity=Username(info["name"]),
secrets=SSHKeypair(
info["private_key"], info["public_key"]
),
)
_publish_credentials_stolen_event(
collected_credentials, event_queue
)
else: else:
continue continue
except (IOError, OSError): except (IOError, OSError):
@ -114,3 +132,15 @@ def _get_ssh_files(
pass pass
usr_info = [info for info in usr_info if info["private_key"] or info["public_key"]] usr_info = [info for info in usr_info if info["private_key"] or info["public_key"]]
return usr_info return usr_info
def _publish_credentials_stolen_event(collected_credentials: Credentials, event_queue: IEventQueue):
credentials_stolen_event = CredentialsStolenEvent(
source=uuid.getnode(),
target=None,
timestamp=time.time(),
tags=frozenset({SSH_CREDENTIAL_COLLECTOR_TAG, "T1005", "T1145"}),
stolen_credentials=[collected_credentials],
)
event_queue.publish(credentials_stolen_event)

View File

@ -20,14 +20,10 @@ def mock_event_queue():
def patch_ssh_handler(ssh_creds, monkeypatch): def patch_ssh_handler(ssh_creds, monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.ssh_handler.get_ssh_info", "infection_monkey.credential_collectors.ssh_collector.ssh_handler.get_ssh_info",
lambda _: ssh_creds, lambda _, __: ssh_creds,
) )
def patch_guid(monkeypatch):
monkeypatch.setattr("infection_monkey.config.GUID", "1-2-3-4-5-6")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])] "ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])]
) )
@ -39,7 +35,6 @@ def test_ssh_credentials_empty_results(
patch_telemetry_messenger, mock_event_queue patch_telemetry_messenger, mock_event_queue
).collect_credentials() ).collect_credentials()
assert not collected assert not collected
mock_event_queue.publish.assert_called_once()
def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_event_queue): def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_event_queue):
@ -86,4 +81,3 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_ev
patch_telemetry_messenger, mock_event_queue patch_telemetry_messenger, mock_event_queue
).collect_credentials() ).collect_credentials()
assert expected == collected assert expected == collected
mock_event_queue.publish.assert_called_once()