Agent: Use new Credentials objects in AggregatingCredentialsStore

This commit is contained in:
Mike Salvatore 2022-07-18 10:22:14 -04:00
parent ef4fbb30cc
commit ebc854735e
2 changed files with 72 additions and 61 deletions

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Iterable, Mapping from typing import Any, Iterable, Sequence
from common.credentials import CredentialComponentType, Credentials, ICredentialComponent from common.credentials import CredentialComponentType, Credentials, ICredentialComponent
from infection_monkey.custom_types import PropagationCredentials from infection_monkey.custom_types import PropagationCredentials
@ -52,9 +52,7 @@ class AggregatingCredentialsStore(ICredentialsStore):
def get_credentials(self) -> PropagationCredentials: def get_credentials(self) -> PropagationCredentials:
try: try:
propagation_credentials = self._get_credentials_from_control_channel() propagation_credentials = self._get_credentials_from_control_channel()
self.add_credentials(propagation_credentials)
# Needs to be reworked when exploiters accepts sequence of Credentials
self._aggregate_credentials(propagation_credentials)
return self._stored_credentials return self._stored_credentials
except Exception as ex: except Exception as ex:
@ -62,13 +60,9 @@ class AggregatingCredentialsStore(ICredentialsStore):
logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}") logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}")
@request_cache(CREDENTIALS_POLL_PERIOD_SEC) @request_cache(CREDENTIALS_POLL_PERIOD_SEC)
def _get_credentials_from_control_channel(self) -> PropagationCredentials: def _get_credentials_from_control_channel(self) -> Sequence[Credentials]:
return self._control_channel.get_credentials_for_propagation() return self._control_channel.get_credentials_for_propagation()
def _aggregate_credentials(self, credentials_to_aggr: Mapping):
for cred_attr, credentials_values in credentials_to_aggr.items():
self._set_attribute(cred_attr, credentials_values)
def _set_attribute(self, attribute_to_be_set: str, credentials_values: Iterable[Any]): def _set_attribute(self, attribute_to_be_set: str, credentials_values: Iterable[Any]):
if not credentials_values: if not credentials_values:
return return

View File

@ -1,54 +1,67 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from tests.data_for_tests.propagation_credentials import (
LM_HASH,
NT_HASH,
PASSWORD_1,
PASSWORD_2,
PASSWORD_3,
PRIVATE_KEY,
PROPAGATION_CREDENTIALS,
PUBLIC_KEY,
SPECIAL_USERNAME,
USERNAME,
)
from common.credentials import Credentials, Password, SSHKeypair, Username from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
from infection_monkey.credential_store import AggregatingCredentialsStore from infection_monkey.credential_store import AggregatingCredentialsStore
CONTROL_CHANNEL_CREDENTIALS = { CONTROL_CHANNEL_CREDENTIALS = PROPAGATION_CREDENTIALS
"exploit_user_list": ["Administrator", "root", "user1"], TRANSFORMED_CONTROL_CHANNEL_CREDENTIALS = {
"exploit_password_list": ["123456", "123456789", "password", "root"], "exploit_user_list": {USERNAME, SPECIAL_USERNAME},
"exploit_lm_hash_list": ["aasdf23asd1fdaasadasdfas"], "exploit_password_list": {PASSWORD_1, PASSWORD_2, PASSWORD_3},
"exploit_ntlm_hash_list": ["asdfadvxvsdftw3e3421234123412", "qw4trklxklvznksbhasd1231"], "exploit_lm_hash_list": {LM_HASH},
"exploit_ssh_keys": [ "exploit_ntlm_hash_list": {NT_HASH},
{"public_key": "some_public_key", "private_key": "some_private_key"}, "exploit_ssh_keys": [{"public_key": PUBLIC_KEY, "private_key": PRIVATE_KEY}],
{
"public_key": "ssh-ed25519 AAAAC3NzEIFaJ7xH+Yoxd\n",
"private_key": "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BdHIAAAAGYXjl0j66VAKruPEKjS3A=\n"
"-----END OPENSSH PRIVATE KEY-----\n",
},
],
} }
EMPTY_CHANNEL_CREDENTIALS = { EMPTY_CHANNEL_CREDENTIALS = []
"exploit_user_list": [],
"exploit_password_list": [],
"exploit_lm_hash_list": [],
"exploit_ntlm_hash_list": [],
"exploit_ssh_keys": [],
}
TEST_CREDENTIALS = [ STOLEN_USERNAME_1 = "user1"
STOLEN_USERNAME_2 = "user2"
STOLEN_USERNAME_3 = "user3"
STOLEN_PASSWORD_1 = "abcdefg"
STOLEN_PASSWORD_2 = "super_secret"
STOLEN_PUBLIC_KEY_1 = "some_public_key_1"
STOLEN_PUBLIC_KEY_2 = "some_public_key_2"
STOLEN_LM_HASH = "AAD3B435B51404EEAAD3B435B51404EE"
STOLEN_NT_HASH = "C0172DFF622FE29B5327CB79DC12D24C"
STOLEN_PRIVATE_KEY_1 = "some_private_key_1"
STOLEN_PRIVATE_KEY_2 = "some_private_key_2"
STOLEN_CREDENTIALS = [
Credentials( Credentials(
identity=Username("user1"), identity=Username(STOLEN_USERNAME_1),
secret=Password("root"), secret=Password(PASSWORD_1),
), ),
Credentials(identity=Username("user1"), secret=Password("abcdefg")), Credentials(identity=Username(STOLEN_USERNAME_1), secret=Password(STOLEN_PASSWORD_1)),
Credentials( Credentials(
identity=Username("user3"), identity=Username(STOLEN_USERNAME_2),
secret=SSHKeypair(public_key="some_public_key_1", private_key="some_private_key_1"), secret=SSHKeypair(public_key=STOLEN_PUBLIC_KEY_1, private_key=STOLEN_PRIVATE_KEY_1),
), ),
Credentials( Credentials(
identity=None, identity=None,
secret=Password("super_secret"), secret=Password(STOLEN_PASSWORD_2),
), ),
Credentials(identity=Username("user4"), secret=None), Credentials(identity=Username(STOLEN_USERNAME_2), secret=LMHash(STOLEN_LM_HASH)),
Credentials(identity=Username(STOLEN_USERNAME_2), secret=NTHash(STOLEN_NT_HASH)),
Credentials(identity=Username(STOLEN_USERNAME_3), secret=None),
] ]
SSH_KEYS_CREDENTIALS = [ STOLEN_SSH_KEYS_CREDENTIALS = [
Credentials( Credentials(
Username("root"), Username(USERNAME),
SSHKeypair(public_key="some_public_key", private_key="some_private_key"), SSHKeypair(public_key=STOLEN_PUBLIC_KEY_2, private_key=STOLEN_PRIVATE_KEY_2),
) )
] ]
@ -63,45 +76,49 @@ def aggregating_credentials_store() -> AggregatingCredentialsStore:
def test_get_credentials_from_store(aggregating_credentials_store): def test_get_credentials_from_store(aggregating_credentials_store):
actual_stored_credentials = aggregating_credentials_store.get_credentials() actual_stored_credentials = aggregating_credentials_store.get_credentials()
assert actual_stored_credentials["exploit_user_list"] == set( assert (
CONTROL_CHANNEL_CREDENTIALS["exploit_user_list"] actual_stored_credentials["exploit_user_list"]
== TRANSFORMED_CONTROL_CHANNEL_CREDENTIALS["exploit_user_list"]
) )
assert actual_stored_credentials["exploit_password_list"] == set( assert (
CONTROL_CHANNEL_CREDENTIALS["exploit_password_list"] actual_stored_credentials["exploit_password_list"]
== TRANSFORMED_CONTROL_CHANNEL_CREDENTIALS["exploit_password_list"]
) )
assert actual_stored_credentials["exploit_ntlm_hash_list"] == set( assert (
CONTROL_CHANNEL_CREDENTIALS["exploit_ntlm_hash_list"] actual_stored_credentials["exploit_ntlm_hash_list"]
== TRANSFORMED_CONTROL_CHANNEL_CREDENTIALS["exploit_ntlm_hash_list"]
) )
for ssh_keypair in actual_stored_credentials["exploit_ssh_keys"]: for ssh_keypair in actual_stored_credentials["exploit_ssh_keys"]:
assert ssh_keypair in CONTROL_CHANNEL_CREDENTIALS["exploit_ssh_keys"] assert ssh_keypair in TRANSFORMED_CONTROL_CHANNEL_CREDENTIALS["exploit_ssh_keys"]
def test_add_credentials_to_store(aggregating_credentials_store): def test_add_credentials_to_store(aggregating_credentials_store):
aggregating_credentials_store.add_credentials(TEST_CREDENTIALS) aggregating_credentials_store.add_credentials(STOLEN_CREDENTIALS)
aggregating_credentials_store.add_credentials(SSH_KEYS_CREDENTIALS) aggregating_credentials_store.add_credentials(STOLEN_SSH_KEYS_CREDENTIALS)
actual_stored_credentials = aggregating_credentials_store.get_credentials() actual_stored_credentials = aggregating_credentials_store.get_credentials()
assert actual_stored_credentials["exploit_user_list"] == set( assert actual_stored_credentials["exploit_user_list"] == set(
[ [
"Administrator", USERNAME,
"root", SPECIAL_USERNAME,
"user1", STOLEN_USERNAME_1,
"user3", STOLEN_USERNAME_2,
"user4", STOLEN_USERNAME_3,
] ]
) )
assert actual_stored_credentials["exploit_password_list"] == set( assert actual_stored_credentials["exploit_password_list"] == set(
[ [
"123456", PASSWORD_1,
"123456789", PASSWORD_2,
"abcdefg", PASSWORD_3,
"password", STOLEN_PASSWORD_1,
"root", STOLEN_PASSWORD_2,
"super_secret",
] ]
) )
assert actual_stored_credentials["exploit_lm_hash_list"] == set([LM_HASH, STOLEN_LM_HASH])
assert actual_stored_credentials["exploit_ntlm_hash_list"] == set([NT_HASH, STOLEN_NT_HASH])
assert len(actual_stored_credentials["exploit_ssh_keys"]) == 3 assert len(actual_stored_credentials["exploit_ssh_keys"]) == 3