Agent: Extract methods to clean up AggregatingCredentialsStore

This commit is contained in:
Mike Salvatore 2022-07-18 09:14:51 -04:00
parent 7c920cced3
commit 068dbbe963
2 changed files with 23 additions and 24 deletions

View File

@ -1,7 +1,7 @@
import logging import logging
from typing import Any, Iterable, Mapping from typing import Any, Iterable, Mapping
from common.credentials import CredentialComponentType, Credentials from common.credentials import CredentialComponentType, Credentials, ICredentialComponent
from infection_monkey.custom_types import PropagationCredentials from infection_monkey.custom_types import PropagationCredentials
from infection_monkey.i_control_channel import IControlChannel from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.utils.decorators import request_cache from infection_monkey.utils.decorators import request_cache
@ -26,31 +26,28 @@ class AggregatingCredentialsStore(ICredentialsStore):
def add_credentials(self, credentials_to_add: Iterable[Credentials]): def add_credentials(self, credentials_to_add: Iterable[Credentials]):
for credentials in credentials_to_add: for credentials in credentials_to_add:
identity = credentials.identity if credentials.identity:
if identity and identity.credential_type is CredentialComponentType.USERNAME: self._add_identity(credentials.identity)
self._stored_credentials.setdefault("exploit_user_list", set()).add(
identity.username
)
secret = credentials.secret if credentials.secret:
self._add_secret(credentials.secret)
if secret.credential_type is CredentialComponentType.PASSWORD: def _add_identity(self, identity: ICredentialComponent):
self._stored_credentials.setdefault("exploit_password_list", set()).add( if identity.credential_type is CredentialComponentType.USERNAME:
secret.password self._stored_credentials.setdefault("exploit_user_list", set()).add(identity.username)
)
elif secret.credential_type is CredentialComponentType.LM_HASH: def _add_secret(self, secret: ICredentialComponent):
self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add( if secret.credential_type is CredentialComponentType.PASSWORD:
secret.lm_hash self._stored_credentials.setdefault("exploit_password_list", set()).add(secret.password)
) elif secret.credential_type is CredentialComponentType.LM_HASH:
elif secret.credential_type is CredentialComponentType.NT_HASH: self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add(secret.lm_hash)
self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add( elif secret.credential_type is CredentialComponentType.NT_HASH:
secret.nt_hash self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add(secret.nt_hash)
) elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR:
elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR: self._set_attribute(
self._set_attribute( "exploit_ssh_keys",
"exploit_ssh_keys", [{"public_key": secret.public_key, "private_key": secret.private_key}],
[{"public_key": secret.public_key, "private_key": secret.private_key}], )
)
def get_credentials(self) -> PropagationCredentials: def get_credentials(self) -> PropagationCredentials:
try: try:

View File

@ -42,6 +42,7 @@ TEST_CREDENTIALS = [
identity=None, identity=None,
secret=Password("super_secret"), secret=Password("super_secret"),
), ),
Credentials(identity=Username("user4"), secret=None),
] ]
SSH_KEYS_CREDENTIALS = [ SSH_KEYS_CREDENTIALS = [
@ -88,6 +89,7 @@ def test_add_credentials_to_store(aggregating_credentials_store):
"root", "root",
"user1", "user1",
"user3", "user3",
"user4",
] ]
) )
assert actual_stored_credentials["exploit_password_list"] == set( assert actual_stored_credentials["exploit_password_list"] == set(