diff --git a/monkey/infection_monkey/credential_store/__init__.py b/monkey/infection_monkey/credential_store/__init__.py index 3b3f4475f..e05ce3160 100644 --- a/monkey/infection_monkey/credential_store/__init__.py +++ b/monkey/infection_monkey/credential_store/__init__.py @@ -1,2 +1,2 @@ from .i_credentials_store import ICredentialsStore -from .credentials_store import CredentialsStore +from .aggregating_credentials_store import AggregatingCredentialsStore diff --git a/monkey/infection_monkey/credential_store/aggregating_credentials_store.py b/monkey/infection_monkey/credential_store/aggregating_credentials_store.py new file mode 100644 index 000000000..d855b98dd --- /dev/null +++ b/monkey/infection_monkey/credential_store/aggregating_credentials_store.py @@ -0,0 +1,68 @@ +import logging +from typing import Iterable, Mapping + +from common.common_consts.credential_component_type import CredentialComponentType +from infection_monkey.i_control_channel import IControlChannel +from infection_monkey.i_puppet import Credentials + +from .i_credentials_store import ICredentialsStore + +logger = logging.getLogger(__name__) + + +class AggregatingCredentialsStore(ICredentialsStore): + def __init__(self, control_channel: IControlChannel): + self.stored_credentials = {} + self._control_channel = control_channel + + def add_credentials(self, credentials_to_add: Iterable[Credentials]) -> None: + for credentials in credentials_to_add: + usernames = [ + identity.username + for identity in credentials.identities + if identity.credential_type is CredentialComponentType.USERNAME + ] + self._set_attribute("exploit_user_list", usernames) + + for secret in credentials.secrets: + if secret.credential_type is CredentialComponentType.PASSWORD: + self._set_attribute("exploit_password_list", [secret.password]) + elif secret.credential_type is CredentialComponentType.LM_HASH: + self._set_attribute("exploit_lm_hash_list", [secret.lm_hash]) + elif secret.credential_type is CredentialComponentType.NT_HASH: + self._set_attribute("exploit_ntlm_hash_list", [secret.nt_hash]) + elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR: + self._set_attribute( + "exploit_ssh_keys", + [{"public_key": secret.public_key, "private_key": secret.private_key}], + ) + + def get_credentials(self): + try: + propagation_credentials = self._control_channel.get_credentials_for_propagation() + self._aggregate_credentials(propagation_credentials) + except Exception as ex: + self.stored_credentials = {} + logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}") + + def _aggregate_credentials(self, credentials_to_aggr: Mapping): + for cred_attr, credentials_values in credentials_to_aggr.items(): + if credentials_values: + self._set_attribute(cred_attr, credentials_values) + + def _set_attribute(self, attribute_to_be_set, credentials_values): + if attribute_to_be_set not in self.stored_credentials: + self.stored_credentials[attribute_to_be_set] = [] + + if isinstance(credentials_values[0], dict): + self.stored_credentials.setdefault(attribute_to_be_set, []).extend(credentials_values) + self.stored_credentials[attribute_to_be_set] = [ + dict(s_c) + for s_c in set( + frozenset(d_c.items()) for d_c in self.stored_credentials[attribute_to_be_set] + ) + ] + else: + self.stored_credentials[attribute_to_be_set] = sorted( + list(set(self.stored_credentials[attribute_to_be_set]).union(credentials_values)) + ) diff --git a/monkey/infection_monkey/credential_store/credentials_store.py b/monkey/infection_monkey/credential_store/credentials_store.py deleted file mode 100644 index a0500804d..000000000 --- a/monkey/infection_monkey/credential_store/credentials_store.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Mapping - -from .i_credentials_store import ICredentialsStore - - -class CredentialsStore(ICredentialsStore): - def __init__(self, credentials: Mapping = None): - self.stored_credentials = credentials - - def add_credentials(self, credentials_to_add: Mapping) -> None: - if self.stored_credentials is None: - self.stored_credentials = {} - - for key, value in credentials_to_add.items(): - if key not in self.stored_credentials: - self.stored_credentials[key] = [] - - if key != "exploit_ssh_keys": - self.stored_credentials[key] = list( - sorted(set(self.stored_credentials[key]).union(credentials_to_add[key])) - ) - else: - self.stored_credentials[key] += credentials_to_add[key] - self.stored_credentials[key] = [ - dict(s) for s in set(frozenset(d.items()) for d in self.stored_credentials[key]) - ] - - def get_credentials(self) -> Mapping: - return self.stored_credentials diff --git a/monkey/infection_monkey/credential_store/i_credentials_store.py b/monkey/infection_monkey/credential_store/i_credentials_store.py index 7730c99d2..2ac10192b 100644 --- a/monkey/infection_monkey/credential_store/i_credentials_store.py +++ b/monkey/infection_monkey/credential_store/i_credentials_store.py @@ -1,19 +1,20 @@ import abc -from typing import Mapping +from typing import Iterable + +from infection_monkey.i_puppet import Credentials class ICredentialsStore(metaclass=abc.ABCMeta): @abc.abstractmethod - def add_credentials(self, credentials_to_add: Mapping = {}) -> None: - """ + def add_credentials(self, credentials_to_add: Iterable[Credentials]) -> None: + """a Method that adds credentials to the CredentialStore - :param Credentials credentials: The credentials which will be added + :param Credentials credentials: The credentials that will be added """ @abc.abstractmethod - def get_credentials(self) -> Mapping: + def get_credentials(self) -> None: """ - Method that gets credentials from the ControlChannel - :return: A squence of Credentials that have been added for propagation - :rtype: Mapping + Method that retrieves credentials from the store + :return: Credentials that can be used for propagation """ diff --git a/monkey/tests/unit_tests/infection_monkey/credential_store/test_credential_store.py b/monkey/tests/unit_tests/infection_monkey/credential_store/test_credential_store.py index 83382dc3e..1035de4d0 100644 --- a/monkey/tests/unit_tests/infection_monkey/credential_store/test_credential_store.py +++ b/monkey/tests/unit_tests/infection_monkey/credential_store/test_credential_store.py @@ -2,66 +2,83 @@ from unittest.mock import MagicMock import pytest +from infection_monkey.credential_collectors import Password, SSHKeypair, Username from infection_monkey.credential_store import AggregatingCredentialsStore +from infection_monkey.i_puppet import Credentials DEFAULT_CREDENTIALS = { "exploit_user_list": ["Administrator", "root", "user1"], - "exploit_password_list": [ - "root", - "123456", - "password", - "123456789", - ], + "exploit_password_list": ["123456", "123456789", "password", "root"], "exploit_lm_hash_list": ["aasdf23asd1fdaasadasdfas"], - "exploit_ntlm_hash_list": ["qw4trklxklvznksbhasd1231", "asdfadvxvsdftw3e3421234123412"], + "exploit_ntlm_hash_list": ["asdfadvxvsdftw3e3421234123412", "qw4trklxklvznksbhasd1231"], "exploit_ssh_keys": [ + {"public_key": "some_public_key", "private_key": "some_private_key"}, { "public_key": "ssh-ed25519 AAAAC3NzEIFaJ7xH+Yoxd\n", "private_key": "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BdHIAAAAGYXjl0j66VAKruPEKjS3A=\n" "-----END OPENSSH PRIVATE KEY-----\n", - "user": "ubuntu", - "ip": "10.0.3.15", }, - {"public_key": "some_public_key", "private_key": "some_private_key"}, ], } -SAMPLE_CREDENTIALS = { +PROPAGATION_CREDENTIALS = { "exploit_user_list": ["user1", "user3"], "exploit_password_list": ["abcdefg", "root"], "exploit_ssh_keys": [{"public_key": "some_public_key", "private_key": "some_private_key"}], - "exploit_ntlm_hash_list": [], } +TELEM_CREDENTIALS = [ + Credentials( + [Username("user1"), Username("user3")], + [ + Password("abcdefg"), + Password("root"), + SSHKeypair(public_key="some_public_key", private_key="some_private_key"), + ], + ) +] + @pytest.fixture def aggregating_credentials_store() -> AggregatingCredentialsStore: - return AggregatingCredentialsStore() + control_channel = MagicMock() + control_channel.get_credentials_for_propagation.return_value = DEFAULT_CREDENTIALS + return AggregatingCredentialsStore(control_channel) -@pytest.mark.parametrize("credentials_to_store", [DEFAULT_CREDENTIALS, SAMPLE_CREDENTIALS]) -def test_get_credentials_from_store(aggregating_credentials_store, credentials_to_store): - get_updated_credentials_for_propagation = MagicMock(return_value=credentials_to_store) +def test_get_credentials_from_store(aggregating_credentials_store): + aggregating_credentials_store.get_credentials() - aggregating_credentials_store.get_credentials(get_updated_credentials_for_propagation) + actual_stored_credentials = aggregating_credentials_store.stored_credentials - assert aggregating_credentials_store.stored_credentials == credentials_to_store + assert ( + actual_stored_credentials["exploit_user_list"] == DEFAULT_CREDENTIALS["exploit_user_list"] + ) + assert ( + actual_stored_credentials["exploit_password_list"] + == DEFAULT_CREDENTIALS["exploit_password_list"] + ) + assert ( + actual_stored_credentials["exploit_ntlm_hash_list"] + == DEFAULT_CREDENTIALS["exploit_ntlm_hash_list"] + ) + + for ssh_keypair in actual_stored_credentials["exploit_ssh_keys"]: + assert ssh_keypair in DEFAULT_CREDENTIALS["exploit_ssh_keys"] def test_add_credentials_to_empty_store(aggregating_credentials_store): + aggregating_credentials_store.add_credentials(TELEM_CREDENTIALS) - aggregating_credentials_store.add_credentials(SAMPLE_CREDENTIALS) - - assert aggregating_credentials_store.stored_credentials == SAMPLE_CREDENTIALS + assert aggregating_credentials_store.stored_credentials == PROPAGATION_CREDENTIALS def test_add_credentials_to_full_store(aggregating_credentials_store): - get_updated_credentials_for_propagation = MagicMock(return_value=DEFAULT_CREDENTIALS) - aggregating_credentials_store.get_credentials(get_updated_credentials_for_propagation) + aggregating_credentials_store.get_credentials() - aggregating_credentials_store.add_credentials(SAMPLE_CREDENTIALS) + aggregating_credentials_store.add_credentials(TELEM_CREDENTIALS) actual_stored_credentials = aggregating_credentials_store.stored_credentials @@ -78,4 +95,6 @@ def test_add_credentials_to_full_store(aggregating_credentials_store): "password", "root", ] - assert actual_stored_credentials["exploit_ssh_keys"] == DEFAULT_CREDENTIALS["exploit_ssh_keys"] + + for ssh_keypair in actual_stored_credentials["exploit_ssh_keys"]: + assert ssh_keypair in DEFAULT_CREDENTIALS["exploit_ssh_keys"]