UT: Extract function collect_credentials() to reduce code duplication

This commit is contained in:
Mike Salvatore 2022-02-15 13:30:13 -05:00
parent 86f2c7b08c
commit 236b545816
1 changed files with 11 additions and 5 deletions

View File

@ -1,3 +1,5 @@
from typing import List
import pytest import pytest
from infection_monkey.credential_collectors import Credentials, LMHash, NTHash, Password, Username from infection_monkey.credential_collectors import Credentials, LMHash, NTHash, Password, Username
@ -17,12 +19,16 @@ def patch_pypykatz(win_creds: [WindowsCredentials], monkeypatch):
) )
def collect_credentials() -> List[Credentials]:
return list(MimikatzCredentialCollector().collect_credentials())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"win_creds", [([WindowsCredentials(username="", password="", ntlm_hash="", lm_hash="")]), ([])] "win_creds", [([WindowsCredentials(username="", password="", ntlm_hash="", lm_hash="")]), ([])]
) )
def test_empty_results(monkeypatch, win_creds): def test_empty_results(monkeypatch, win_creds):
patch_pypykatz(win_creds, monkeypatch) patch_pypykatz(win_creds, monkeypatch)
collected_credentials = list(MimikatzCredentialCollector().collect_credentials()) collected_credentials = collect_credentials()
assert not collected_credentials assert not collected_credentials
@ -34,7 +40,7 @@ def test_pypykatz_result_parsing(monkeypatch):
password = Password("secret") password = Password("secret")
expected_credentials = Credentials([username], [password]) expected_credentials = Credentials([username], [password])
collected_credentials = list(MimikatzCredentialCollector().collect_credentials()) collected_credentials = collect_credentials()
assert len(collected_credentials) == 1 assert len(collected_credentials) == 1
assert collected_credentials[0] == expected_credentials assert collected_credentials[0] == expected_credentials
@ -46,7 +52,7 @@ def test_pypykatz_result_parsing_duplicates(monkeypatch):
] ]
patch_pypykatz(win_creds, monkeypatch) patch_pypykatz(win_creds, monkeypatch)
collected_credentials = list(MimikatzCredentialCollector().collect_credentials()) collected_credentials = collect_credentials()
assert len(collected_credentials) == 2 assert len(collected_credentials) == 2
@ -62,7 +68,7 @@ def test_pypykatz_result_parsing_defaults(monkeypatch):
lm_hash = LMHash("lm_hash") lm_hash = LMHash("lm_hash")
expected_credentials = Credentials([username], [password, lm_hash]) expected_credentials = Credentials([username], [password, lm_hash])
collected_credentials = list(MimikatzCredentialCollector().collect_credentials()) collected_credentials = collect_credentials()
assert len(collected_credentials) == 1 assert len(collected_credentials) == 1
assert collected_credentials[0] == expected_credentials assert collected_credentials[0] == expected_credentials
@ -77,6 +83,6 @@ def test_pypykatz_result_parsing_no_identities(monkeypatch):
nt_hash = NTHash("ntlm_hash") nt_hash = NTHash("ntlm_hash")
expected_credentials = Credentials([], [lm_hash, nt_hash]) expected_credentials = Credentials([], [lm_hash, nt_hash])
collected_credentials = list(MimikatzCredentialCollector().collect_credentials()) collected_credentials = collect_credentials()
assert len(collected_credentials) == 1 assert len(collected_credentials) == 1
assert collected_credentials[0] == expected_credentials assert collected_credentials[0] == expected_credentials