diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py index aa9a52b72..bdcc56098 100644 --- a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py +++ b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py @@ -4,6 +4,7 @@ from typing import Dict, Iterable, List from infection_monkey.credential_collectors import SSHKeypair, Username from infection_monkey.credential_collectors.ssh_collector import ssh_handler from infection_monkey.i_puppet.credential_collection import Credentials, ICredentialCollector +from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger logger = logging.getLogger(__name__) @@ -13,9 +14,12 @@ class SSHCredentialCollector(ICredentialCollector): SSH keys credential collector """ + def __init__(self, telemetry_messenger: ITelemetryMessenger): + self._telemetry_messenger = telemetry_messenger + def collect_credentials(self, _options=None) -> List[Credentials]: logger.info("Started scanning for SSH credentials") - ssh_info = ssh_handler.get_ssh_info() + ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger) logger.info("Finished scanning for SSH credentials") return SSHCredentialCollector._to_credentials(ssh_info) diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py index a204550f5..8c635d92b 100644 --- a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py +++ b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py @@ -7,15 +7,16 @@ from typing import Dict, Iterable from common.utils.attack_utils import ScanStatus from infection_monkey.telemetry.attack.t1005_telem import T1005Telem from infection_monkey.telemetry.attack.t1145_telem import T1145Telem +from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger logger = logging.getLogger(__name__) DEFAULT_DIRS = ["/.ssh/", "/"] -def get_ssh_info() -> Iterable[Dict]: +def get_ssh_info(telemetry_messenger: ITelemetryMessenger) -> Iterable[Dict]: home_dirs = _get_home_dirs() - ssh_info = _get_ssh_files(home_dirs) + ssh_info = _get_ssh_files(home_dirs, telemetry_messenger) return ssh_info @@ -51,7 +52,9 @@ def _get_ssh_struct(name: str, home_dir: str) -> Dict: } -def _get_ssh_files(usr_info: Iterable[Dict]) -> Iterable[Dict]: +def _get_ssh_files( + usr_info: Iterable[Dict], telemetry_messenger: ITelemetryMessenger +) -> Iterable[Dict]: for info in usr_info: path = info["home_dir"] for directory in DEFAULT_DIRS: @@ -79,12 +82,16 @@ def _get_ssh_files(usr_info: Iterable[Dict]) -> Iterable[Dict]: if private_key.find("ENCRYPTED") == -1: info["private_key"] = private_key logger.info("Found private key in %s" % private) - T1005Telem( - ScanStatus.USED, "SSH key", "Path: %s" % private - ).send() - T1145Telem( - ScanStatus.USED, info["name"], info["home_dir"] - ).send() + telemetry_messenger.send_telemetry( + T1005Telem( + ScanStatus.USED, "SSH key", "Path: %s" % private + ) + ) + telemetry_messenger.send_telemetry( + T1145Telem( + ScanStatus.USED, info["name"], info["home_dir"] + ) + ) else: continue except (IOError, OSError): diff --git a/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py b/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py index a19434282..2762892bf 100644 --- a/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py +++ b/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + import pytest from infection_monkey.credential_collectors import SSHKeypair, Username @@ -5,23 +7,28 @@ from infection_monkey.credential_collectors.ssh_collector import SSHCredentialCo from infection_monkey.i_puppet.credential_collection import Credentials +@pytest.fixture +def patch_telemetry_messenger(): + return MagicMock() + + def patch_ssh_handler(ssh_creds, monkeypatch): monkeypatch.setattr( "infection_monkey.credential_collectors.ssh_collector.ssh_handler.get_ssh_info", - lambda: ssh_creds, + lambda _: ssh_creds, ) @pytest.mark.parametrize( "ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])] ) -def test_ssh_credentials_empty_results(monkeypatch, ssh_creds): +def test_ssh_credentials_empty_results(monkeypatch, ssh_creds, patch_telemetry_messenger): patch_ssh_handler(ssh_creds, monkeypatch) - collected = SSHCredentialCollector().collect_credentials() + collected = SSHCredentialCollector(patch_telemetry_messenger).collect_credentials() assert not collected -def test_ssh_info_result_parsing(monkeypatch): +def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger): ssh_creds = [ { @@ -55,5 +62,5 @@ def test_ssh_info_result_parsing(monkeypatch): Credentials(identities=[username2], secrets=[ssh_keypair2]), Credentials(identities=[username3], secrets=[]), ] - collected = SSHCredentialCollector().collect_credentials() + collected = SSHCredentialCollector(patch_telemetry_messenger).collect_credentials() assert expected == collected