From 068dbbe963d1cf8d3b22f68aba997f74b4bbb960 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Mon, 18 Jul 2022 09:14:51 -0400 Subject: [PATCH] Agent: Extract methods to clean up AggregatingCredentialsStore --- .../aggregating_credentials_store.py | 45 +++++++++---------- .../test_aggregating_credentials_store.py | 2 + 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/monkey/infection_monkey/credential_store/aggregating_credentials_store.py b/monkey/infection_monkey/credential_store/aggregating_credentials_store.py index 23ad7ee19..bfa69117f 100644 --- a/monkey/infection_monkey/credential_store/aggregating_credentials_store.py +++ b/monkey/infection_monkey/credential_store/aggregating_credentials_store.py @@ -1,7 +1,7 @@ import logging from typing import Any, Iterable, Mapping -from common.credentials import CredentialComponentType, Credentials +from common.credentials import CredentialComponentType, Credentials, ICredentialComponent from infection_monkey.custom_types import PropagationCredentials from infection_monkey.i_control_channel import IControlChannel from infection_monkey.utils.decorators import request_cache @@ -26,31 +26,28 @@ class AggregatingCredentialsStore(ICredentialsStore): def add_credentials(self, credentials_to_add: Iterable[Credentials]): for credentials in credentials_to_add: - identity = credentials.identity - if identity and identity.credential_type is CredentialComponentType.USERNAME: - self._stored_credentials.setdefault("exploit_user_list", set()).add( - identity.username - ) + if credentials.identity: + self._add_identity(credentials.identity) - secret = credentials.secret + if credentials.secret: + self._add_secret(credentials.secret) - if secret.credential_type is CredentialComponentType.PASSWORD: - self._stored_credentials.setdefault("exploit_password_list", set()).add( - secret.password - ) - elif secret.credential_type is CredentialComponentType.LM_HASH: - self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add( - secret.lm_hash - ) - elif secret.credential_type is CredentialComponentType.NT_HASH: - self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add( - secret.nt_hash - ) - elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR: - self._set_attribute( - "exploit_ssh_keys", - [{"public_key": secret.public_key, "private_key": secret.private_key}], - ) + def _add_identity(self, identity: ICredentialComponent): + if identity.credential_type is CredentialComponentType.USERNAME: + self._stored_credentials.setdefault("exploit_user_list", set()).add(identity.username) + + def _add_secret(self, secret: ICredentialComponent): + if secret.credential_type is CredentialComponentType.PASSWORD: + self._stored_credentials.setdefault("exploit_password_list", set()).add(secret.password) + elif secret.credential_type is CredentialComponentType.LM_HASH: + self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add(secret.lm_hash) + elif secret.credential_type is CredentialComponentType.NT_HASH: + self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add(secret.nt_hash) + elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR: + self._set_attribute( + "exploit_ssh_keys", + [{"public_key": secret.public_key, "private_key": secret.private_key}], + ) def get_credentials(self) -> PropagationCredentials: try: diff --git a/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_credentials_store.py b/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_credentials_store.py index 871f0db44..e5ddccfe2 100644 --- a/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_credentials_store.py +++ b/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_credentials_store.py @@ -42,6 +42,7 @@ TEST_CREDENTIALS = [ identity=None, secret=Password("super_secret"), ), + Credentials(identity=Username("user4"), secret=None), ] SSH_KEYS_CREDENTIALS = [ @@ -88,6 +89,7 @@ def test_add_credentials_to_store(aggregating_credentials_store): "root", "user1", "user3", + "user4", ] ) assert actual_stored_credentials["exploit_password_list"] == set(