Merge pull request #2099 from guardicore/2072-credentials-store-fix

2072 credentials store fix
This commit is contained in:
Mike Salvatore 2022-07-18 11:07:38 -04:00 committed by GitHub
commit 4e11ed2816
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 75 deletions

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Iterable, Mapping from typing import Any, Iterable, Sequence
from common.credentials import CredentialComponentType, Credentials, ICredentialComponent from common.credentials import CredentialComponentType, Credentials, ICredentialComponent
from infection_monkey.custom_types import PropagationCredentials from infection_monkey.custom_types import PropagationCredentials
@ -52,9 +52,7 @@ class AggregatingCredentialsStore(ICredentialsStore):
def get_credentials(self) -> PropagationCredentials: def get_credentials(self) -> PropagationCredentials:
try: try:
propagation_credentials = self._get_credentials_from_control_channel() propagation_credentials = self._get_credentials_from_control_channel()
self.add_credentials(propagation_credentials)
# Needs to be reworked when exploiters accepts sequence of Credentials
self._aggregate_credentials(propagation_credentials)
return self._stored_credentials return self._stored_credentials
except Exception as ex: except Exception as ex:
@ -62,13 +60,9 @@ class AggregatingCredentialsStore(ICredentialsStore):
logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}") logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}")
@request_cache(CREDENTIALS_POLL_PERIOD_SEC) @request_cache(CREDENTIALS_POLL_PERIOD_SEC)
def _get_credentials_from_control_channel(self) -> PropagationCredentials: def _get_credentials_from_control_channel(self) -> Sequence[Credentials]:
return self._control_channel.get_credentials_for_propagation() return self._control_channel.get_credentials_for_propagation()
def _aggregate_credentials(self, credentials_to_aggr: Mapping):
for cred_attr, credentials_values in credentials_to_aggr.items():
self._set_attribute(cred_attr, credentials_values)
def _set_attribute(self, attribute_to_be_set: str, credentials_values: Iterable[Any]): def _set_attribute(self, attribute_to_be_set: str, credentials_values: Iterable[Any]):
if not credentials_values: if not credentials_values:
return return

View File

@ -1,6 +1,8 @@
import abc import abc
from typing import Sequence
from common.configuration import AgentConfiguration from common.configuration import AgentConfiguration
from common.credentials import Credentials
class IControlChannel(metaclass=abc.ABCMeta): class IControlChannel(metaclass=abc.ABCMeta):
@ -21,10 +23,11 @@ class IControlChannel(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_credentials_for_propagation(self) -> dict: def get_credentials_for_propagation(self) -> Sequence[Credentials]:
""" """
:return: A dictionary containing propagation credentials data Get credentials to use during propagation
:rtype: dict
:return: A Sequence containing propagation credentials data
""" """
pass pass

View File

@ -1,13 +1,13 @@
import json import json
import logging import logging
from pprint import pformat from pprint import pformat
from typing import Mapping from typing import Mapping, Sequence
import requests import requests
from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from common.configuration import AgentConfiguration from common.configuration import AgentConfiguration
from infection_monkey.custom_types import PropagationCredentials from common.credentials import Credentials
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()
@ -71,7 +71,7 @@ class ControlChannel(IControlChannel):
) as e: ) as e:
raise IslandCommunicationError(e) raise IslandCommunicationError(e)
def get_credentials_for_propagation(self) -> PropagationCredentials: def get_credentials_for_propagation(self) -> Sequence[Credentials]:
propagation_credentials_url = ( propagation_credentials_url = (
f"https://{self._control_channel_server}/api/propagation-credentials" f"https://{self._control_channel_server}/api/propagation-credentials"
) )
@ -84,9 +84,9 @@ class ControlChannel(IControlChannel):
) )
response.raise_for_status() response.raise_for_status()
return json.loads(response.content.decode())["propagation_credentials"] return [Credentials.from_mapping(credentials) for credentials in response.json]
except ( except (
json.JSONDecodeError, requests.exceptions.JSONDecodeError,
requests.exceptions.ConnectionError, requests.exceptions.ConnectionError,
requests.exceptions.Timeout, requests.exceptions.Timeout,
requests.exceptions.TooManyRedirects, requests.exceptions.TooManyRedirects,

View File

@ -1,54 +1,67 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from tests.data_for_tests.propagation_credentials import (
LM_HASH,
NT_HASH,
PASSWORD_1,
PASSWORD_2,
PASSWORD_3,
PRIVATE_KEY,
PROPAGATION_CREDENTIALS,
PUBLIC_KEY,
SPECIAL_USERNAME,
USERNAME,
)
from common.credentials import Credentials, Password, SSHKeypair, Username from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
from infection_monkey.credential_store import AggregatingCredentialsStore from infection_monkey.credential_store import AggregatingCredentialsStore
CONTROL_CHANNEL_CREDENTIALS = { CONTROL_CHANNEL_CREDENTIALS = PROPAGATION_CREDENTIALS
"exploit_user_list": ["Administrator", "root", "user1"], TRANSFORMED_CONTROL_CHANNEL_CREDENTIALS = {
"exploit_password_list": ["123456", "123456789", "password", "root"], "exploit_user_list": {USERNAME, SPECIAL_USERNAME},
"exploit_lm_hash_list": ["aasdf23asd1fdaasadasdfas"], "exploit_password_list": {PASSWORD_1, PASSWORD_2, PASSWORD_3},
"exploit_ntlm_hash_list": ["asdfadvxvsdftw3e3421234123412", "qw4trklxklvznksbhasd1231"], "exploit_lm_hash_list": {LM_HASH},
"exploit_ssh_keys": [ "exploit_ntlm_hash_list": {NT_HASH},
{"public_key": "some_public_key", "private_key": "some_private_key"}, "exploit_ssh_keys": [{"public_key": PUBLIC_KEY, "private_key": PRIVATE_KEY}],
{
"public_key": "ssh-ed25519 AAAAC3NzEIFaJ7xH+Yoxd\n",
"private_key": "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BdHIAAAAGYXjl0j66VAKruPEKjS3A=\n"
"-----END OPENSSH PRIVATE KEY-----\n",
},
],
} }
EMPTY_CHANNEL_CREDENTIALS = { EMPTY_CHANNEL_CREDENTIALS = []
"exploit_user_list": [],
"exploit_password_list": [],
"exploit_lm_hash_list": [],
"exploit_ntlm_hash_list": [],
"exploit_ssh_keys": [],
}
TEST_CREDENTIALS = [ STOLEN_USERNAME_1 = "user1"
STOLEN_USERNAME_2 = "user2"
STOLEN_USERNAME_3 = "user3"
STOLEN_PASSWORD_1 = "abcdefg"
STOLEN_PASSWORD_2 = "super_secret"
STOLEN_PUBLIC_KEY_1 = "some_public_key_1"
STOLEN_PUBLIC_KEY_2 = "some_public_key_2"
STOLEN_LM_HASH = "AAD3B435B51404EEAAD3B435B51404EE"
STOLEN_NT_HASH = "C0172DFF622FE29B5327CB79DC12D24C"
STOLEN_PRIVATE_KEY_1 = "some_private_key_1"
STOLEN_PRIVATE_KEY_2 = "some_private_key_2"
STOLEN_CREDENTIALS = [
Credentials( Credentials(
identity=Username("user1"), identity=Username(STOLEN_USERNAME_1),
secret=Password("root"), secret=Password(PASSWORD_1),
), ),
Credentials(identity=Username("user1"), secret=Password("abcdefg")), Credentials(identity=Username(STOLEN_USERNAME_1), secret=Password(STOLEN_PASSWORD_1)),
Credentials( Credentials(
identity=Username("user3"), identity=Username(STOLEN_USERNAME_2),
secret=SSHKeypair(public_key="some_public_key_1", private_key="some_private_key_1"), secret=SSHKeypair(public_key=STOLEN_PUBLIC_KEY_1, private_key=STOLEN_PRIVATE_KEY_1),
), ),
Credentials( Credentials(
identity=None, identity=None,
secret=Password("super_secret"), secret=Password(STOLEN_PASSWORD_2),
), ),
Credentials(identity=Username("user4"), secret=None), Credentials(identity=Username(STOLEN_USERNAME_2), secret=LMHash(STOLEN_LM_HASH)),
Credentials(identity=Username(STOLEN_USERNAME_2), secret=NTHash(STOLEN_NT_HASH)),
Credentials(identity=Username(STOLEN_USERNAME_3), secret=None),
] ]
SSH_KEYS_CREDENTIALS = [ STOLEN_SSH_KEYS_CREDENTIALS = [
Credentials( Credentials(
Username("root"), Username(USERNAME),
SSHKeypair(public_key="some_public_key", private_key="some_private_key"), SSHKeypair(public_key=STOLEN_PUBLIC_KEY_2, private_key=STOLEN_PRIVATE_KEY_2),
) )
] ]
@ -60,48 +73,39 @@ def aggregating_credentials_store() -> AggregatingCredentialsStore:
return AggregatingCredentialsStore(control_channel) return AggregatingCredentialsStore(control_channel)
def test_get_credentials_from_store(aggregating_credentials_store): @pytest.mark.parametrize("key", TRANSFORMED_CONTROL_CHANNEL_CREDENTIALS.keys())
def test_get_credentials_from_store(aggregating_credentials_store, key):
actual_stored_credentials = aggregating_credentials_store.get_credentials() actual_stored_credentials = aggregating_credentials_store.get_credentials()
assert actual_stored_credentials["exploit_user_list"] == set( assert actual_stored_credentials[key] == TRANSFORMED_CONTROL_CHANNEL_CREDENTIALS[key]
CONTROL_CHANNEL_CREDENTIALS["exploit_user_list"]
)
assert actual_stored_credentials["exploit_password_list"] == set(
CONTROL_CHANNEL_CREDENTIALS["exploit_password_list"]
)
assert actual_stored_credentials["exploit_ntlm_hash_list"] == set(
CONTROL_CHANNEL_CREDENTIALS["exploit_ntlm_hash_list"]
)
for ssh_keypair in actual_stored_credentials["exploit_ssh_keys"]:
assert ssh_keypair in CONTROL_CHANNEL_CREDENTIALS["exploit_ssh_keys"]
def test_add_credentials_to_store(aggregating_credentials_store): def test_add_credentials_to_store(aggregating_credentials_store):
aggregating_credentials_store.add_credentials(TEST_CREDENTIALS) aggregating_credentials_store.add_credentials(STOLEN_CREDENTIALS)
aggregating_credentials_store.add_credentials(SSH_KEYS_CREDENTIALS) aggregating_credentials_store.add_credentials(STOLEN_SSH_KEYS_CREDENTIALS)
actual_stored_credentials = aggregating_credentials_store.get_credentials() actual_stored_credentials = aggregating_credentials_store.get_credentials()
assert actual_stored_credentials["exploit_user_list"] == set( assert actual_stored_credentials["exploit_user_list"] == set(
[ [
"Administrator", USERNAME,
"root", SPECIAL_USERNAME,
"user1", STOLEN_USERNAME_1,
"user3", STOLEN_USERNAME_2,
"user4", STOLEN_USERNAME_3,
] ]
) )
assert actual_stored_credentials["exploit_password_list"] == set( assert actual_stored_credentials["exploit_password_list"] == set(
[ [
"123456", PASSWORD_1,
"123456789", PASSWORD_2,
"abcdefg", PASSWORD_3,
"password", STOLEN_PASSWORD_1,
"root", STOLEN_PASSWORD_2,
"super_secret",
] ]
) )
assert actual_stored_credentials["exploit_lm_hash_list"] == set([LM_HASH, STOLEN_LM_HASH])
assert actual_stored_credentials["exploit_ntlm_hash_list"] == set([NT_HASH, STOLEN_NT_HASH])
assert len(actual_stored_credentials["exploit_ssh_keys"]) == 3 assert len(actual_stored_credentials["exploit_ssh_keys"]) == 3