From 6ab7bd2f45a27748646a93a7c79023e1967c662b Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Tue, 29 Mar 2022 19:34:19 +0200 Subject: [PATCH] Agent, UT: Remove leftover that cause overwrite in CredentialsStore * Use `add` instead of `update` - `add` doesn't let to have duplicates * Move TestTelem to conftest in UT telemetry messenger --- .../aggregating_credentials_store.py | 13 +++++----- .../test_aggregating_credentials_store.py | 24 ++++++++++--------- .../telemetry/messengers/conftest.py | 18 ++++++++++++++ ...ntials_intercepting_telemetry_messenger.py | 14 +---------- ...xploit_intercepting_telemetry_messenger.py | 14 +---------- 5 files changed, 39 insertions(+), 44 deletions(-) create mode 100644 monkey/tests/unit_tests/infection_monkey/telemetry/messengers/conftest.py diff --git a/monkey/infection_monkey/credential_store/aggregating_credentials_store.py b/monkey/infection_monkey/credential_store/aggregating_credentials_store.py index 696b87bd5..eef6f72fc 100644 --- a/monkey/infection_monkey/credential_store/aggregating_credentials_store.py +++ b/monkey/infection_monkey/credential_store/aggregating_credentials_store.py @@ -36,16 +36,16 @@ class AggregatingCredentialsStore(ICredentialsStore): for secret in credentials.secrets: if secret.credential_type is CredentialComponentType.PASSWORD: - self._stored_credentials.setdefault("exploit_password_list", set()).update( - [secret.password] + self._stored_credentials.setdefault("exploit_password_list", set()).add( + secret.password ) elif secret.credential_type is CredentialComponentType.LM_HASH: - self._stored_credentials.setdefault("exploit_lm_hash_list", set()).update( - [secret.lm_hash] + self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add( + secret.lm_hash ) elif secret.credential_type is CredentialComponentType.NT_HASH: - self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).update( - [secret.nt_hash] + self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add( + secret.nt_hash ) elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR: self._set_attribute( @@ -78,7 +78,6 @@ class AggregatingCredentialsStore(ICredentialsStore): return if isinstance(credentials_values[0], dict): - self._stored_credentials[attribute_to_be_set] = [] self._stored_credentials.setdefault(attribute_to_be_set, []).extend(credentials_values) self._stored_credentials[attribute_to_be_set] = [ dict(s_c) diff --git a/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_credentials_store.py b/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_credentials_store.py index ff25dc02f..0b6fd8545 100644 --- a/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_credentials_store.py +++ b/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_credentials_store.py @@ -21,19 +21,21 @@ CONTROL_CHANNEL_CREDENTIALS = { ], } - -PROPAGATION_CREDENTIALS = { - "exploit_user_list": ["user1", "user3"], - "exploit_password_list": ["abcdefg", "root"], - "exploit_ssh_keys": [{"public_key": "some_public_key", "private_key": "some_private_key"}], -} - -CREDENTIALS_COLLECTION = [ +TEST_CREDENTIALS = [ Credentials( [Username("user1"), Username("user3")], [ Password("abcdefg"), Password("root"), + SSHKeypair(public_key="some_public_key_1", private_key="some_private_key_1"), + ], + ) +] + +SSH_KEYS_CREDENTIALS = [ + Credentials( + [Username("root")], + [ SSHKeypair(public_key="some_public_key", private_key="some_private_key"), ], ) @@ -67,7 +69,8 @@ def test_get_credentials_from_store(aggregating_credentials_store): def test_add_credentials_to_store(aggregating_credentials_store): - aggregating_credentials_store.add_credentials(CREDENTIALS_COLLECTION) + aggregating_credentials_store.add_credentials(TEST_CREDENTIALS) + aggregating_credentials_store.add_credentials(SSH_KEYS_CREDENTIALS) actual_stored_credentials = aggregating_credentials_store.get_credentials() @@ -89,5 +92,4 @@ def test_add_credentials_to_store(aggregating_credentials_store): ] ) - for ssh_keypair in actual_stored_credentials["exploit_ssh_keys"]: - assert ssh_keypair in CONTROL_CHANNEL_CREDENTIALS["exploit_ssh_keys"] + assert len(actual_stored_credentials["exploit_ssh_keys"]) == 3 diff --git a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/conftest.py b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/conftest.py new file mode 100644 index 000000000..c29555262 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/conftest.py @@ -0,0 +1,18 @@ +import pytest + +from infection_monkey.telemetry.base_telem import BaseTelem + + +@pytest.fixture(scope="package") +def TestTelem(): + class InnerTestTelem(BaseTelem): + telem_category = None + __test__ = False + + def __init__(self): + pass + + def get_data(self): + return {} + + return InnerTestTelem diff --git a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_credentials_intercepting_telemetry_messenger.py b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_credentials_intercepting_telemetry_messenger.py index 214f001b6..7481bff34 100644 --- a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_credentials_intercepting_telemetry_messenger.py +++ b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_credentials_intercepting_telemetry_messenger.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock from infection_monkey.credential_collectors import Password, SSHKeypair, Username from infection_monkey.i_puppet import Credentials -from infection_monkey.telemetry.base_telem import BaseTelem from infection_monkey.telemetry.credentials_telem import CredentialsTelem from infection_monkey.telemetry.messengers.credentials_intercepting_telemetry_messenger import ( CredentialsInterceptingTelemetryMessenger, @@ -20,17 +19,6 @@ TELEM_CREDENTIALS = [ ] -class TestTelem(BaseTelem): - telem_category = None - __test__ = False - - def __init__(self): - pass - - def get_data(self): - return {} - - class MockCredentialsTelem(CredentialsTelem): def __init(self, credentials): super().__init__(credentials) @@ -39,7 +27,7 @@ class MockCredentialsTelem(CredentialsTelem): return {} -def test_credentials_generic_telemetry(): +def test_credentials_generic_telemetry(TestTelem): mock_telemetry_messenger = MagicMock() mock_credentials_store = MagicMock() diff --git a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py index b07ea4a1d..969489107 100644 --- a/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py +++ b/monkey/tests/unit_tests/infection_monkey/telemetry/messengers/test_exploit_intercepting_telemetry_messenger.py @@ -2,24 +2,12 @@ from unittest.mock import MagicMock from infection_monkey.i_puppet.i_puppet import ExploiterResultData from infection_monkey.model.host import VictimHost -from infection_monkey.telemetry.base_telem import BaseTelem from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.messengers.exploit_intercepting_telemetry_messenger import ( ExploitInterceptingTelemetryMessenger, ) -class TestTelem(BaseTelem): - telem_category = None - __test__ = False - - def __init__(self): - pass - - def get_data(self): - return {} - - class MockExpliotTelem(ExploitTelem): def __init__(self, propagation_success): erd = ExploiterResultData() @@ -30,7 +18,7 @@ class MockExpliotTelem(ExploitTelem): return {} -def test_generic_telemetry(): +def test_generic_telemetry(TestTelem): mock_telemetry_messenger = MagicMock() mock_tunnel = MagicMock()