diff --git a/monkey/infection_monkey/credential_collectors/__init__.py b/monkey/infection_monkey/credential_collectors/__init__.py index a9d22a4c4..a5d48e466 100644 --- a/monkey/infection_monkey/credential_collectors/__init__.py +++ b/monkey/infection_monkey/credential_collectors/__init__.py @@ -2,4 +2,5 @@ from .credential_components.nt_hash import NTHash from .credential_components.lm_hash import LMHash from .credential_components.password import Password from .credential_components.username import Username +from .credential_components.ssh_keypair import SSHKeypair from .mimikatz_collector import MimikatzCredentialCollector diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/__init__.py b/monkey/infection_monkey/credential_collectors/ssh_collector/__init__.py index adc6a2dc5..d89d836f8 100644 --- a/monkey/infection_monkey/credential_collectors/ssh_collector/__init__.py +++ b/monkey/infection_monkey/credential_collectors/ssh_collector/__init__.py @@ -1 +1 @@ -from .SSH_credentials_collector import SSHCollector +from .ssh_credential_collector import SSHCredentialCollector diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py index 85a9c505a..aa9a52b72 100644 --- a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py +++ b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_credential_collector.py @@ -1,12 +1,9 @@ import logging from typing import Dict, Iterable, List -from infection_monkey.credential_collectors import ( - SSHKeypair, - Username, -) -from infection_monkey.i_puppet.credential_collection import Credentials, ICredentialCollector +from infection_monkey.credential_collectors import SSHKeypair, Username from infection_monkey.credential_collectors.ssh_collector import ssh_handler +from infection_monkey.i_puppet.credential_collection import Credentials, ICredentialCollector logger = logging.getLogger(__name__) @@ -26,17 +23,17 @@ class SSHCredentialCollector(ICredentialCollector): @staticmethod def _to_credentials(ssh_info: Iterable[Dict]) -> List[Credentials]: ssh_credentials = [] - identities = [] - secrets = [] for info in ssh_info: + identities = [] + secrets = [] - if "name" in info and info["name"] != "": + if info.get("name", ""): identities.append(Username(info["name"])) ssh_keypair = {} for key in ["public_key", "private_key"]: - if key in info and info.get(key) is not None: + if info.get(key) is not None: ssh_keypair[key] = info[key] if len(ssh_keypair): diff --git a/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py b/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py index 45aff0878..a19434282 100644 --- a/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py +++ b/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py @@ -1,5 +1,8 @@ -from infection_monkey.credential_collectors import Credentials, SSHKeypair, Username -from infection_monkey.credential_collectors.ssh_collector import SSHCollector +import pytest + +from infection_monkey.credential_collectors import SSHKeypair, Username +from infection_monkey.credential_collectors.ssh_collector import SSHCredentialCollector +from infection_monkey.i_puppet.credential_collection import Credentials def patch_ssh_handler(ssh_creds, monkeypatch): @@ -9,16 +12,13 @@ def patch_ssh_handler(ssh_creds, monkeypatch): ) -def test_ssh_credentials_empty_results(monkeypatch): - patch_ssh_handler([], monkeypatch) - collected = SSHCollector().collect_credentials() - assert [] == collected - - ssh_creds = [{"name": "", "home_dir": "", "public_key": None, "private_key": None}] +@pytest.mark.parametrize( + "ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])] +) +def test_ssh_credentials_empty_results(monkeypatch, ssh_creds): patch_ssh_handler(ssh_creds, monkeypatch) - expected = [] - collected = SSHCollector().collect_credentials() - assert expected == collected + collected = SSHCredentialCollector().collect_credentials() + assert not collected def test_ssh_info_result_parsing(monkeypatch): @@ -48,14 +48,12 @@ def test_ssh_info_result_parsing(monkeypatch): ssh_keypair1 = SSHKeypair( {"public_key": "SomePublicKeyUbuntu", "private_key": "ExtremelyGoodPrivateKey"} ) - ssh_keypair2 = SSHKeypair( - {"public_key": "AnotherPublicKey", "private_key": "NotSoGoodPrivateKey"} - ) + ssh_keypair2 = SSHKeypair({"public_key": "AnotherPublicKey"}) expected = [ Credentials(identities=[username], secrets=[ssh_keypair1]), Credentials(identities=[username2], secrets=[ssh_keypair2]), Credentials(identities=[username3], secrets=[]), ] - collected = SSHCollector().collect_credentials() + collected = SSHCredentialCollector().collect_credentials() assert expected == collected