Agent, UT: Remove leftover that cause overwrite in CredentialsStore

* Use `add` instead of `update` - `add` doesn't let to have duplicates
* Move TestTelem to conftest in UT telemetry messenger
This commit is contained in:
Ilija Lazoroski 2022-03-29 19:34:19 +02:00
parent 763cf578c7
commit 6ab7bd2f45
5 changed files with 39 additions and 44 deletions

View File

@ -36,16 +36,16 @@ class AggregatingCredentialsStore(ICredentialsStore):
for secret in credentials.secrets: for secret in credentials.secrets:
if secret.credential_type is CredentialComponentType.PASSWORD: if secret.credential_type is CredentialComponentType.PASSWORD:
self._stored_credentials.setdefault("exploit_password_list", set()).update( self._stored_credentials.setdefault("exploit_password_list", set()).add(
[secret.password] secret.password
) )
elif secret.credential_type is CredentialComponentType.LM_HASH: elif secret.credential_type is CredentialComponentType.LM_HASH:
self._stored_credentials.setdefault("exploit_lm_hash_list", set()).update( self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add(
[secret.lm_hash] secret.lm_hash
) )
elif secret.credential_type is CredentialComponentType.NT_HASH: elif secret.credential_type is CredentialComponentType.NT_HASH:
self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).update( self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add(
[secret.nt_hash] secret.nt_hash
) )
elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR: elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR:
self._set_attribute( self._set_attribute(
@ -78,7 +78,6 @@ class AggregatingCredentialsStore(ICredentialsStore):
return return
if isinstance(credentials_values[0], dict): if isinstance(credentials_values[0], dict):
self._stored_credentials[attribute_to_be_set] = []
self._stored_credentials.setdefault(attribute_to_be_set, []).extend(credentials_values) self._stored_credentials.setdefault(attribute_to_be_set, []).extend(credentials_values)
self._stored_credentials[attribute_to_be_set] = [ self._stored_credentials[attribute_to_be_set] = [
dict(s_c) dict(s_c)

View File

@ -21,19 +21,21 @@ CONTROL_CHANNEL_CREDENTIALS = {
], ],
} }
TEST_CREDENTIALS = [
PROPAGATION_CREDENTIALS = {
"exploit_user_list": ["user1", "user3"],
"exploit_password_list": ["abcdefg", "root"],
"exploit_ssh_keys": [{"public_key": "some_public_key", "private_key": "some_private_key"}],
}
CREDENTIALS_COLLECTION = [
Credentials( Credentials(
[Username("user1"), Username("user3")], [Username("user1"), Username("user3")],
[ [
Password("abcdefg"), Password("abcdefg"),
Password("root"), Password("root"),
SSHKeypair(public_key="some_public_key_1", private_key="some_private_key_1"),
],
)
]
SSH_KEYS_CREDENTIALS = [
Credentials(
[Username("root")],
[
SSHKeypair(public_key="some_public_key", private_key="some_private_key"), SSHKeypair(public_key="some_public_key", private_key="some_private_key"),
], ],
) )
@ -67,7 +69,8 @@ def test_get_credentials_from_store(aggregating_credentials_store):
def test_add_credentials_to_store(aggregating_credentials_store): def test_add_credentials_to_store(aggregating_credentials_store):
aggregating_credentials_store.add_credentials(CREDENTIALS_COLLECTION) aggregating_credentials_store.add_credentials(TEST_CREDENTIALS)
aggregating_credentials_store.add_credentials(SSH_KEYS_CREDENTIALS)
actual_stored_credentials = aggregating_credentials_store.get_credentials() actual_stored_credentials = aggregating_credentials_store.get_credentials()
@ -89,5 +92,4 @@ def test_add_credentials_to_store(aggregating_credentials_store):
] ]
) )
for ssh_keypair in actual_stored_credentials["exploit_ssh_keys"]: assert len(actual_stored_credentials["exploit_ssh_keys"]) == 3
assert ssh_keypair in CONTROL_CHANNEL_CREDENTIALS["exploit_ssh_keys"]

View File

@ -0,0 +1,18 @@
import pytest
from infection_monkey.telemetry.base_telem import BaseTelem
@pytest.fixture(scope="package")
def TestTelem():
class InnerTestTelem(BaseTelem):
telem_category = None
__test__ = False
def __init__(self):
pass
def get_data(self):
return {}
return InnerTestTelem

View File

@ -2,7 +2,6 @@ from unittest.mock import MagicMock
from infection_monkey.credential_collectors import Password, SSHKeypair, Username from infection_monkey.credential_collectors import Password, SSHKeypair, Username
from infection_monkey.i_puppet import Credentials from infection_monkey.i_puppet import Credentials
from infection_monkey.telemetry.base_telem import BaseTelem
from infection_monkey.telemetry.credentials_telem import CredentialsTelem from infection_monkey.telemetry.credentials_telem import CredentialsTelem
from infection_monkey.telemetry.messengers.credentials_intercepting_telemetry_messenger import ( from infection_monkey.telemetry.messengers.credentials_intercepting_telemetry_messenger import (
CredentialsInterceptingTelemetryMessenger, CredentialsInterceptingTelemetryMessenger,
@ -20,17 +19,6 @@ TELEM_CREDENTIALS = [
] ]
class TestTelem(BaseTelem):
telem_category = None
__test__ = False
def __init__(self):
pass
def get_data(self):
return {}
class MockCredentialsTelem(CredentialsTelem): class MockCredentialsTelem(CredentialsTelem):
def __init(self, credentials): def __init(self, credentials):
super().__init__(credentials) super().__init__(credentials)
@ -39,7 +27,7 @@ class MockCredentialsTelem(CredentialsTelem):
return {} return {}
def test_credentials_generic_telemetry(): def test_credentials_generic_telemetry(TestTelem):
mock_telemetry_messenger = MagicMock() mock_telemetry_messenger = MagicMock()
mock_credentials_store = MagicMock() mock_credentials_store = MagicMock()

View File

@ -2,24 +2,12 @@ from unittest.mock import MagicMock
from infection_monkey.i_puppet.i_puppet import ExploiterResultData from infection_monkey.i_puppet.i_puppet import ExploiterResultData
from infection_monkey.model.host import VictimHost from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.base_telem import BaseTelem
from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.telemetry.messengers.exploit_intercepting_telemetry_messenger import ( from infection_monkey.telemetry.messengers.exploit_intercepting_telemetry_messenger import (
ExploitInterceptingTelemetryMessenger, ExploitInterceptingTelemetryMessenger,
) )
class TestTelem(BaseTelem):
telem_category = None
__test__ = False
def __init__(self):
pass
def get_data(self):
return {}
class MockExpliotTelem(ExploitTelem): class MockExpliotTelem(ExploitTelem):
def __init__(self, propagation_success): def __init__(self, propagation_success):
erd = ExploiterResultData() erd = ExploiterResultData()
@ -30,7 +18,7 @@ class MockExpliotTelem(ExploitTelem):
return {} return {}
def test_generic_telemetry(): def test_generic_telemetry(TestTelem):
mock_telemetry_messenger = MagicMock() mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock() mock_tunnel = MagicMock()