Agent: Refactor credentials and credential_components as dataclasses

Using frozen dataclasses for Credentials and ICredentialComponents
automatically creates a useful __eq__() function that allows us to
easily compare credentials-related objects.
This commit is contained in:
Mike Salvatore 2022-02-15 12:23:57 -05:00
parent 811434ff22
commit ebd5642b52
6 changed files with 46 additions and 44 deletions

View File

@ -1,9 +1,10 @@
from dataclasses import dataclass, field
from ..credential_type import CredentialType from ..credential_type import CredentialType
from .i_credential_component import ICredentialComponent from .i_credential_component import ICredentialComponent
@dataclass(frozen=True)
class LMHash(ICredentialComponent): class LMHash(ICredentialComponent):
type = CredentialType.LM_HASH type: CredentialType = field(default=CredentialType.LM_HASH, init=False)
lm_hash: str
def __init__(self, lm_hash: str):
self.lm_hash = lm_hash

View File

@ -1,9 +1,10 @@
from dataclasses import dataclass, field
from ..credential_type import CredentialType from ..credential_type import CredentialType
from .i_credential_component import ICredentialComponent from .i_credential_component import ICredentialComponent
@dataclass(frozen=True)
class NTHash(ICredentialComponent): class NTHash(ICredentialComponent):
type = CredentialType.NT_HASH type: CredentialType = field(default=CredentialType.NT_HASH, init=False)
nt_hash: str
def __init__(self, nt_hash: str):
self.nt_hash = nt_hash

View File

@ -1,9 +1,10 @@
from dataclasses import dataclass, field
from ..credential_type import CredentialType from ..credential_type import CredentialType
from .i_credential_component import ICredentialComponent from .i_credential_component import ICredentialComponent
@dataclass(frozen=True)
class Password(ICredentialComponent): class Password(ICredentialComponent):
type = CredentialType.PASSWORD type: CredentialType = field(default=CredentialType.PASSWORD, init=False)
password: str
def __init__(self, password: str):
self.password = password

View File

@ -1,9 +1,10 @@
from dataclasses import dataclass, field
from ..credential_type import CredentialType from ..credential_type import CredentialType
from .i_credential_component import ICredentialComponent from .i_credential_component import ICredentialComponent
@dataclass(frozen=True)
class Username(ICredentialComponent): class Username(ICredentialComponent):
type = CredentialType.USERNAME type: CredentialType = field(default=CredentialType.USERNAME, init=False)
username: str
def __init__(self, username: str):
self.username = username

View File

@ -1,11 +1,10 @@
from typing import Iterable from dataclasses import dataclass
from typing import Tuple
from .credential_components.i_credential_component import ICredentialComponent from .credential_components.i_credential_component import ICredentialComponent
@dataclass(frozen=True)
class Credentials: class Credentials:
def __init__( identities: Tuple[ICredentialComponent]
self, identities: Iterable[ICredentialComponent], secrets: Iterable[ICredentialComponent] secrets: Tuple[ICredentialComponent]
):
self.identities = tuple(identities)
self.secrets = tuple(secrets)

View File

@ -1,4 +1,4 @@
from infection_monkey.credential_collectors import LMHash, NTHash, Password, Username from infection_monkey.credential_collectors import Credentials, LMHash, NTHash, Password, Username
from infection_monkey.credential_collectors.mimikatz_collector.mimikatz_cred_collector import ( from infection_monkey.credential_collectors.mimikatz_collector.mimikatz_cred_collector import (
MimikatzCredentialCollector, MimikatzCredentialCollector,
) )
@ -18,27 +18,26 @@ def patch_pypykatz(win_creds: [WindowsCredentials], monkeypatch):
def test_empty_results(monkeypatch): def test_empty_results(monkeypatch):
win_creds = [WindowsCredentials(username="", password="", ntlm_hash="", lm_hash="")] win_creds = [WindowsCredentials(username="", password="", ntlm_hash="", lm_hash="")]
patch_pypykatz(win_creds, monkeypatch) patch_pypykatz(win_creds, monkeypatch)
expected = [] expected_credentials = []
collected = MimikatzCredentialCollector().collect_credentials() collected_credentials = MimikatzCredentialCollector().collect_credentials()
assert expected == collected assert expected_credentials == collected_credentials
patch_pypykatz([], monkeypatch) patch_pypykatz([], monkeypatch)
collected = MimikatzCredentialCollector().collect_credentials() collected_credentials = MimikatzCredentialCollector().collect_credentials()
assert [] == collected assert not collected_credentials
def test_pypykatz_result_parsing(monkeypatch): def test_pypykatz_result_parsing(monkeypatch):
win_creds = [WindowsCredentials(username="user", password="secret", ntlm_hash="", lm_hash="")] win_creds = [WindowsCredentials(username="user", password="secret", ntlm_hash="", lm_hash="")]
patch_pypykatz(win_creds, monkeypatch) patch_pypykatz(win_creds, monkeypatch)
# Expected credentials
username = Username("user") username = Username("user")
password = Password("secret") password = Password("secret")
expected_credentials = Credentials([username], [password])
collected = MimikatzCredentialCollector().collect_credentials() collected_credentials = list(MimikatzCredentialCollector().collect_credentials())
assert len(list(collected)) == 1 assert len(collected_credentials) == 1
assert list(collected)[0].identities[0].__dict__ == username.__dict__ assert collected_credentials[0] == expected_credentials
assert list(collected)[0].secrets[0].__dict__ == password.__dict__
def test_pypykatz_result_parsing_duplicates(monkeypatch): def test_pypykatz_result_parsing_duplicates(monkeypatch):
@ -48,8 +47,8 @@ def test_pypykatz_result_parsing_duplicates(monkeypatch):
] ]
patch_pypykatz(win_creds, monkeypatch) patch_pypykatz(win_creds, monkeypatch)
collected = MimikatzCredentialCollector().collect_credentials() collected_credentials = list(MimikatzCredentialCollector().collect_credentials())
assert len(list(collected)) == 2 assert len(collected_credentials) == 2
def test_pypykatz_result_parsing_defaults(monkeypatch): def test_pypykatz_result_parsing_defaults(monkeypatch):
@ -62,11 +61,11 @@ def test_pypykatz_result_parsing_defaults(monkeypatch):
username = Username("user2") username = Username("user2")
password = Password("secret2") password = Password("secret2")
lm_hash = LMHash("lm_hash") lm_hash = LMHash("lm_hash")
expected_credentials = Credentials([username], [password, lm_hash])
collected = MimikatzCredentialCollector().collect_credentials() collected_credentials = list(MimikatzCredentialCollector().collect_credentials())
assert list(collected)[0].identities[0].__dict__ == username.__dict__ assert len(collected_credentials) == 1
assert list(collected)[0].secrets[0].__dict__ == password.__dict__ assert collected_credentials[0] == expected_credentials
assert list(collected)[0].secrets[1].__dict__ == lm_hash.__dict__
def test_pypykatz_result_parsing_no_identities(monkeypatch): def test_pypykatz_result_parsing_no_identities(monkeypatch):
@ -75,10 +74,10 @@ def test_pypykatz_result_parsing_no_identities(monkeypatch):
] ]
patch_pypykatz(win_creds, monkeypatch) patch_pypykatz(win_creds, monkeypatch)
# Expected credentials
nt_hash = NTHash("ntlm_hash")
lm_hash = LMHash("lm_hash") lm_hash = LMHash("lm_hash")
nt_hash = NTHash("ntlm_hash")
expected_credentials = Credentials([], [lm_hash, nt_hash])
collected = MimikatzCredentialCollector().collect_credentials() collected_credentials = list(MimikatzCredentialCollector().collect_credentials())
assert list(collected)[0].secrets[0].__dict__ == lm_hash.__dict__ assert len(collected_credentials) == 1
assert list(collected)[0].secrets[1].__dict__ == nt_hash.__dict__ assert collected_credentials[0] == expected_credentials