diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py index b54f8d464..b696adf40 100644 --- a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py +++ b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py @@ -1,7 +1,7 @@ import logging -from typing import Dict, Iterable, Sequence +from typing import Sequence -from common.credentials import Credentials, SSHKeypair, Username +from common.credentials import Credentials from common.event_queue import IEventQueue from infection_monkey.credential_collectors.ssh_collector import ssh_handler from infection_monkey.i_puppet import ICredentialCollector @@ -24,30 +24,4 @@ class SSHCredentialCollector(ICredentialCollector): ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger, self._event_queue) logger.info("Finished scanning for SSH credentials") - return SSHCredentialCollector._to_credentials(ssh_info) - - @staticmethod - def _to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]: - ssh_credentials = [] - - for info in ssh_info: - identity = None - secret = None - - if info.get("name", ""): - identity = Username(info["name"]) - - ssh_keypair = {} - for key in ["public_key", "private_key"]: - if info.get(key) is not None: - ssh_keypair[key] = info[key] - - if len(ssh_keypair): - secret = SSHKeypair( - ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "") - ) - - if any([identity, secret]): - ssh_credentials.append(Credentials(identity, secret)) - - return ssh_credentials + return ssh_handler.to_credentials(ssh_info) diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py index ed16bbfa2..cc4245842 100644 --- a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py +++ b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py @@ -1,7 +1,7 @@ import glob import logging import os -from typing import Dict, Iterable +from typing import Dict, Iterable, Sequence from common.credentials import Credentials, SSHKeypair, Username from common.event_queue import IEventQueue @@ -118,12 +118,7 @@ def _get_ssh_files( ) ) - collected_credentials = Credentials( - identity=Username(info["name"]), - secret=SSHKeypair( - info["private_key"], info["public_key"] - ), - ) + collected_credentials = to_credentials([info]) _publish_credentials_stolen_event( collected_credentials, event_queue ) @@ -142,6 +137,32 @@ def _get_ssh_files( return user_info +def to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]: + ssh_credentials = [] + + for info in ssh_info: + identity = None + secret = None + + if info.get("name", ""): + identity = Username(info["name"]) + + ssh_keypair = {} + for key in ["public_key", "private_key"]: + if info.get(key) is not None: + ssh_keypair[key] = info[key] + + if len(ssh_keypair): + secret = SSHKeypair( + ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "") + ) + + if any([identity, secret]): + ssh_credentials.append(Credentials(identity, secret)) + + return ssh_credentials + + def _publish_credentials_stolen_event(collected_credentials: Credentials, event_queue: IEventQueue): credentials_stolen_event = CredentialsStolenEvent( target=None, diff --git a/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_ssh_credentials_collector.py b/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_ssh_credentials_collector.py index d82236014..14a2e320a 100644 --- a/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_ssh_credentials_collector.py +++ b/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_ssh_credentials_collector.py @@ -1,9 +1,10 @@ from unittest.mock import MagicMock import pytest +from pubsub.core import Publisher from common.credentials import Credentials, SSHKeypair, Username -from common.event_queue import IEventQueue +from common.event_queue import IEventQueue, PyPubSubEventQueue from infection_monkey.credential_collectors import SSHCredentialCollector @@ -13,8 +14,8 @@ def patch_telemetry_messenger(): @pytest.fixture -def mock_event_queue(): - return MagicMock(spec=IEventQueue) +def event_queue() -> IEventQueue: + return PyPubSubEventQueue(Publisher()) def patch_ssh_handler(ssh_creds, monkeypatch): @@ -27,17 +28,15 @@ def patch_ssh_handler(ssh_creds, monkeypatch): @pytest.mark.parametrize( "ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])] ) -def test_ssh_credentials_empty_results( - monkeypatch, ssh_creds, patch_telemetry_messenger, mock_event_queue -): +def test_ssh_credentials_empty_results(monkeypatch, ssh_creds, patch_telemetry_messenger): patch_ssh_handler(ssh_creds, monkeypatch) collected = SSHCredentialCollector( - patch_telemetry_messenger, mock_event_queue + patch_telemetry_messenger, MagicMock(spec=IEventQueue) ).collect_credentials() assert not collected -def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_event_queue): +def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger): ssh_creds = [ { @@ -78,6 +77,6 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_ev Credentials(identity=None, secret=ssh_keypair3), ] collected = SSHCredentialCollector( - patch_telemetry_messenger, mock_event_queue + patch_telemetry_messenger, MagicMock(spec=IEventQueue) ).collect_credentials() assert expected == collected