Agent: Extract methods to clean up AggregatingCredentialsStore

This commit is contained in:
Mike Salvatore 2022-07-18 09:14:51 -04:00
parent 7c920cced3
commit 068dbbe963
2 changed files with 23 additions and 24 deletions

View File

@ -1,7 +1,7 @@
import logging import logging
from typing import Any, Iterable, Mapping from typing import Any, Iterable, Mapping
from common.credentials import CredentialComponentType, Credentials from common.credentials import CredentialComponentType, Credentials, ICredentialComponent
from infection_monkey.custom_types import PropagationCredentials from infection_monkey.custom_types import PropagationCredentials
from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.utils.decorators import request_cache from infection_monkey.utils.decorators import request_cache
@ -26,26 +26,23 @@ class AggregatingCredentialsStore(ICredentialsStore):
def add_credentials(self, credentials_to_add: Iterable[Credentials]): def add_credentials(self, credentials_to_add: Iterable[Credentials]):
for credentials in credentials_to_add: for credentials in credentials_to_add:
identity = credentials.identity if credentials.identity:
if identity and identity.credential_type is CredentialComponentType.USERNAME: self._add_identity(credentials.identity)
self._stored_credentials.setdefault("exploit_user_list", set()).add(
identity.username
)
secret = credentials.secret if credentials.secret:
self._add_secret(credentials.secret)
def _add_identity(self, identity: ICredentialComponent):
if identity.credential_type is CredentialComponentType.USERNAME:
self._stored_credentials.setdefault("exploit_user_list", set()).add(identity.username)
def _add_secret(self, secret: ICredentialComponent):
if secret.credential_type is CredentialComponentType.PASSWORD: if secret.credential_type is CredentialComponentType.PASSWORD:
self._stored_credentials.setdefault("exploit_password_list", set()).add( self._stored_credentials.setdefault("exploit_password_list", set()).add(secret.password)
secret.password
)
elif secret.credential_type is CredentialComponentType.LM_HASH: elif secret.credential_type is CredentialComponentType.LM_HASH:
self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add( self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add(secret.lm_hash)
secret.lm_hash
)
elif secret.credential_type is CredentialComponentType.NT_HASH: elif secret.credential_type is CredentialComponentType.NT_HASH:
self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add( self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add(secret.nt_hash)
secret.nt_hash
)
elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR: elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR:
self._set_attribute( self._set_attribute(
"exploit_ssh_keys", "exploit_ssh_keys",

View File

@ -42,6 +42,7 @@ TEST_CREDENTIALS = [
identity=None, identity=None,
secret=Password("super_secret"), secret=Password("super_secret"),
), ),
Credentials(identity=Username("user4"), secret=None),
] ]
SSH_KEYS_CREDENTIALS = [ SSH_KEYS_CREDENTIALS = [
@ -88,6 +89,7 @@ def test_add_credentials_to_store(aggregating_credentials_store):
"root", "root",
"user1", "user1",
"user3", "user3",
"user4",
] ]
) )
assert actual_stored_credentials["exploit_password_list"] == set( assert actual_stored_credentials["exploit_password_list"] == set(