Agent: Rework ssh credential collector to match credential architecture
* Parametrize empty result unit test * Apply small changes to ssh credential collector
This commit is contained in:
parent
a97b8706ec
commit
63d632d142
|
@ -2,4 +2,5 @@ from .credential_components.nt_hash import NTHash
|
|||
from .credential_components.lm_hash import LMHash
|
||||
from .credential_components.password import Password
|
||||
from .credential_components.username import Username
|
||||
from .credential_components.ssh_keypair import SSHKeypair
|
||||
from .mimikatz_collector import MimikatzCredentialCollector
|
||||
|
|
|
@ -1 +1 @@
|
|||
from .SSH_credentials_collector import SSHCollector
|
||||
from .ssh_credential_collector import SSHCredentialCollector
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
import logging
|
||||
from typing import Dict, Iterable, List
|
||||
|
||||
from infection_monkey.credential_collectors import (
|
||||
SSHKeypair,
|
||||
Username,
|
||||
)
|
||||
from infection_monkey.i_puppet.credential_collection import Credentials, ICredentialCollector
|
||||
from infection_monkey.credential_collectors import SSHKeypair, Username
|
||||
from infection_monkey.credential_collectors.ssh_collector import ssh_handler
|
||||
from infection_monkey.i_puppet.credential_collection import Credentials, ICredentialCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -26,17 +23,17 @@ class SSHCredentialCollector(ICredentialCollector):
|
|||
@staticmethod
|
||||
def _to_credentials(ssh_info: Iterable[Dict]) -> List[Credentials]:
|
||||
ssh_credentials = []
|
||||
identities = []
|
||||
secrets = []
|
||||
|
||||
for info in ssh_info:
|
||||
identities = []
|
||||
secrets = []
|
||||
|
||||
if "name" in info and info["name"] != "":
|
||||
if info.get("name", ""):
|
||||
identities.append(Username(info["name"]))
|
||||
|
||||
ssh_keypair = {}
|
||||
for key in ["public_key", "private_key"]:
|
||||
if key in info and info.get(key) is not None:
|
||||
if info.get(key) is not None:
|
||||
ssh_keypair[key] = info[key]
|
||||
|
||||
if len(ssh_keypair):
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
from infection_monkey.credential_collectors import Credentials, SSHKeypair, Username
|
||||
from infection_monkey.credential_collectors.ssh_collector import SSHCollector
|
||||
import pytest
|
||||
|
||||
from infection_monkey.credential_collectors import SSHKeypair, Username
|
||||
from infection_monkey.credential_collectors.ssh_collector import SSHCredentialCollector
|
||||
from infection_monkey.i_puppet.credential_collection import Credentials
|
||||
|
||||
|
||||
def patch_ssh_handler(ssh_creds, monkeypatch):
|
||||
|
@ -9,16 +12,13 @@ def patch_ssh_handler(ssh_creds, monkeypatch):
|
|||
)
|
||||
|
||||
|
||||
def test_ssh_credentials_empty_results(monkeypatch):
|
||||
patch_ssh_handler([], monkeypatch)
|
||||
collected = SSHCollector().collect_credentials()
|
||||
assert [] == collected
|
||||
|
||||
ssh_creds = [{"name": "", "home_dir": "", "public_key": None, "private_key": None}]
|
||||
@pytest.mark.parametrize(
|
||||
"ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])]
|
||||
)
|
||||
def test_ssh_credentials_empty_results(monkeypatch, ssh_creds):
|
||||
patch_ssh_handler(ssh_creds, monkeypatch)
|
||||
expected = []
|
||||
collected = SSHCollector().collect_credentials()
|
||||
assert expected == collected
|
||||
collected = SSHCredentialCollector().collect_credentials()
|
||||
assert not collected
|
||||
|
||||
|
||||
def test_ssh_info_result_parsing(monkeypatch):
|
||||
|
@ -48,14 +48,12 @@ def test_ssh_info_result_parsing(monkeypatch):
|
|||
ssh_keypair1 = SSHKeypair(
|
||||
{"public_key": "SomePublicKeyUbuntu", "private_key": "ExtremelyGoodPrivateKey"}
|
||||
)
|
||||
ssh_keypair2 = SSHKeypair(
|
||||
{"public_key": "AnotherPublicKey", "private_key": "NotSoGoodPrivateKey"}
|
||||
)
|
||||
ssh_keypair2 = SSHKeypair({"public_key": "AnotherPublicKey"})
|
||||
|
||||
expected = [
|
||||
Credentials(identities=[username], secrets=[ssh_keypair1]),
|
||||
Credentials(identities=[username2], secrets=[ssh_keypair2]),
|
||||
Credentials(identities=[username3], secrets=[]),
|
||||
]
|
||||
collected = SSHCollector().collect_credentials()
|
||||
collected = SSHCredentialCollector().collect_credentials()
|
||||
assert expected == collected
|
||||
|
|
Loading…
Reference in New Issue