Agent: Rework ssh credential collector to match credential architecture

* Parametrize empty result unit test
* Apply small changes to ssh credential collector
This commit is contained in:
Ilija Lazoroski 2022-02-16 17:37:12 +01:00
parent a97b8706ec
commit 63d632d142
4 changed files with 21 additions and 25 deletions

View File

@ -2,4 +2,5 @@ from .credential_components.nt_hash import NTHash
from .credential_components.lm_hash import LMHash from .credential_components.lm_hash import LMHash
from .credential_components.password import Password from .credential_components.password import Password
from .credential_components.username import Username from .credential_components.username import Username
from .credential_components.ssh_keypair import SSHKeypair
from .mimikatz_collector import MimikatzCredentialCollector from .mimikatz_collector import MimikatzCredentialCollector

View File

@ -1 +1 @@
from .SSH_credentials_collector import SSHCollector from .ssh_credential_collector import SSHCredentialCollector

View File

@ -1,12 +1,9 @@
import logging import logging
from typing import Dict, Iterable, List from typing import Dict, Iterable, List
from infection_monkey.credential_collectors import ( from infection_monkey.credential_collectors import SSHKeypair, Username
SSHKeypair,
Username,
)
from infection_monkey.i_puppet.credential_collection import Credentials, ICredentialCollector
from infection_monkey.credential_collectors.ssh_collector import ssh_handler from infection_monkey.credential_collectors.ssh_collector import ssh_handler
from infection_monkey.i_puppet.credential_collection import Credentials, ICredentialCollector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,17 +23,17 @@ class SSHCredentialCollector(ICredentialCollector):
@staticmethod @staticmethod
def _to_credentials(ssh_info: Iterable[Dict]) -> List[Credentials]: def _to_credentials(ssh_info: Iterable[Dict]) -> List[Credentials]:
ssh_credentials = [] ssh_credentials = []
for info in ssh_info:
identities = [] identities = []
secrets = [] secrets = []
for info in ssh_info: if info.get("name", ""):
if "name" in info and info["name"] != "":
identities.append(Username(info["name"])) identities.append(Username(info["name"]))
ssh_keypair = {} ssh_keypair = {}
for key in ["public_key", "private_key"]: for key in ["public_key", "private_key"]:
if key in info and info.get(key) is not None: if info.get(key) is not None:
ssh_keypair[key] = info[key] ssh_keypair[key] = info[key]
if len(ssh_keypair): if len(ssh_keypair):

View File

@ -1,5 +1,8 @@
from infection_monkey.credential_collectors import Credentials, SSHKeypair, Username import pytest
from infection_monkey.credential_collectors.ssh_collector import SSHCollector
from infection_monkey.credential_collectors import SSHKeypair, Username
from infection_monkey.credential_collectors.ssh_collector import SSHCredentialCollector
from infection_monkey.i_puppet.credential_collection import Credentials
def patch_ssh_handler(ssh_creds, monkeypatch): def patch_ssh_handler(ssh_creds, monkeypatch):
@ -9,16 +12,13 @@ def patch_ssh_handler(ssh_creds, monkeypatch):
) )
def test_ssh_credentials_empty_results(monkeypatch): @pytest.mark.parametrize(
patch_ssh_handler([], monkeypatch) "ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])]
collected = SSHCollector().collect_credentials() )
assert [] == collected def test_ssh_credentials_empty_results(monkeypatch, ssh_creds):
ssh_creds = [{"name": "", "home_dir": "", "public_key": None, "private_key": None}]
patch_ssh_handler(ssh_creds, monkeypatch) patch_ssh_handler(ssh_creds, monkeypatch)
expected = [] collected = SSHCredentialCollector().collect_credentials()
collected = SSHCollector().collect_credentials() assert not collected
assert expected == collected
def test_ssh_info_result_parsing(monkeypatch): def test_ssh_info_result_parsing(monkeypatch):
@ -48,14 +48,12 @@ def test_ssh_info_result_parsing(monkeypatch):
ssh_keypair1 = SSHKeypair( ssh_keypair1 = SSHKeypair(
{"public_key": "SomePublicKeyUbuntu", "private_key": "ExtremelyGoodPrivateKey"} {"public_key": "SomePublicKeyUbuntu", "private_key": "ExtremelyGoodPrivateKey"}
) )
ssh_keypair2 = SSHKeypair( ssh_keypair2 = SSHKeypair({"public_key": "AnotherPublicKey"})
{"public_key": "AnotherPublicKey", "private_key": "NotSoGoodPrivateKey"}
)
expected = [ expected = [
Credentials(identities=[username], secrets=[ssh_keypair1]), Credentials(identities=[username], secrets=[ssh_keypair1]),
Credentials(identities=[username2], secrets=[ssh_keypair2]), Credentials(identities=[username2], secrets=[ssh_keypair2]),
Credentials(identities=[username3], secrets=[]), Credentials(identities=[username3], secrets=[]),
] ]
collected = SSHCollector().collect_credentials() collected = SSHCredentialCollector().collect_credentials()
assert expected == collected assert expected == collected