Agent: Fix logic in concrete Credentials Store
This commit is contained in:
parent
162dd0a920
commit
5060ddb5d1
|
@ -1,2 +1,2 @@
|
|||
from .i_credentials_store import ICredentialsStore
|
||||
from .credentials_store import CredentialsStore
|
||||
from .aggregating_credentials_store import AggregatingCredentialsStore
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
import logging
|
||||
from typing import Iterable, Mapping
|
||||
|
||||
from common.common_consts.credential_component_type import CredentialComponentType
|
||||
from infection_monkey.i_control_channel import IControlChannel
|
||||
from infection_monkey.i_puppet import Credentials
|
||||
|
||||
from .i_credentials_store import ICredentialsStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AggregatingCredentialsStore(ICredentialsStore):
|
||||
def __init__(self, control_channel: IControlChannel):
|
||||
self.stored_credentials = {}
|
||||
self._control_channel = control_channel
|
||||
|
||||
def add_credentials(self, credentials_to_add: Iterable[Credentials]) -> None:
|
||||
for credentials in credentials_to_add:
|
||||
usernames = [
|
||||
identity.username
|
||||
for identity in credentials.identities
|
||||
if identity.credential_type is CredentialComponentType.USERNAME
|
||||
]
|
||||
self._set_attribute("exploit_user_list", usernames)
|
||||
|
||||
for secret in credentials.secrets:
|
||||
if secret.credential_type is CredentialComponentType.PASSWORD:
|
||||
self._set_attribute("exploit_password_list", [secret.password])
|
||||
elif secret.credential_type is CredentialComponentType.LM_HASH:
|
||||
self._set_attribute("exploit_lm_hash_list", [secret.lm_hash])
|
||||
elif secret.credential_type is CredentialComponentType.NT_HASH:
|
||||
self._set_attribute("exploit_ntlm_hash_list", [secret.nt_hash])
|
||||
elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR:
|
||||
self._set_attribute(
|
||||
"exploit_ssh_keys",
|
||||
[{"public_key": secret.public_key, "private_key": secret.private_key}],
|
||||
)
|
||||
|
||||
def get_credentials(self):
|
||||
try:
|
||||
propagation_credentials = self._control_channel.get_credentials_for_propagation()
|
||||
self._aggregate_credentials(propagation_credentials)
|
||||
except Exception as ex:
|
||||
self.stored_credentials = {}
|
||||
logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}")
|
||||
|
||||
def _aggregate_credentials(self, credentials_to_aggr: Mapping):
|
||||
for cred_attr, credentials_values in credentials_to_aggr.items():
|
||||
if credentials_values:
|
||||
self._set_attribute(cred_attr, credentials_values)
|
||||
|
||||
def _set_attribute(self, attribute_to_be_set, credentials_values):
|
||||
if attribute_to_be_set not in self.stored_credentials:
|
||||
self.stored_credentials[attribute_to_be_set] = []
|
||||
|
||||
if isinstance(credentials_values[0], dict):
|
||||
self.stored_credentials.setdefault(attribute_to_be_set, []).extend(credentials_values)
|
||||
self.stored_credentials[attribute_to_be_set] = [
|
||||
dict(s_c)
|
||||
for s_c in set(
|
||||
frozenset(d_c.items()) for d_c in self.stored_credentials[attribute_to_be_set]
|
||||
)
|
||||
]
|
||||
else:
|
||||
self.stored_credentials[attribute_to_be_set] = sorted(
|
||||
list(set(self.stored_credentials[attribute_to_be_set]).union(credentials_values))
|
||||
)
|
|
@ -1,29 +0,0 @@
|
|||
from typing import Mapping
|
||||
|
||||
from .i_credentials_store import ICredentialsStore
|
||||
|
||||
|
||||
class CredentialsStore(ICredentialsStore):
|
||||
def __init__(self, credentials: Mapping = None):
|
||||
self.stored_credentials = credentials
|
||||
|
||||
def add_credentials(self, credentials_to_add: Mapping) -> None:
|
||||
if self.stored_credentials is None:
|
||||
self.stored_credentials = {}
|
||||
|
||||
for key, value in credentials_to_add.items():
|
||||
if key not in self.stored_credentials:
|
||||
self.stored_credentials[key] = []
|
||||
|
||||
if key != "exploit_ssh_keys":
|
||||
self.stored_credentials[key] = list(
|
||||
sorted(set(self.stored_credentials[key]).union(credentials_to_add[key]))
|
||||
)
|
||||
else:
|
||||
self.stored_credentials[key] += credentials_to_add[key]
|
||||
self.stored_credentials[key] = [
|
||||
dict(s) for s in set(frozenset(d.items()) for d in self.stored_credentials[key])
|
||||
]
|
||||
|
||||
def get_credentials(self) -> Mapping:
|
||||
return self.stored_credentials
|
|
@ -1,19 +1,20 @@
|
|||
import abc
|
||||
from typing import Mapping
|
||||
from typing import Iterable
|
||||
|
||||
from infection_monkey.i_puppet import Credentials
|
||||
|
||||
|
||||
class ICredentialsStore(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def add_credentials(self, credentials_to_add: Mapping = {}) -> None:
|
||||
"""
|
||||
def add_credentials(self, credentials_to_add: Iterable[Credentials]) -> None:
|
||||
"""a
|
||||
Method that adds credentials to the CredentialStore
|
||||
:param Credentials credentials: The credentials which will be added
|
||||
:param Credentials credentials: The credentials that will be added
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_credentials(self) -> Mapping:
|
||||
def get_credentials(self) -> None:
|
||||
"""
|
||||
Method that gets credentials from the ControlChannel
|
||||
:return: A squence of Credentials that have been added for propagation
|
||||
:rtype: Mapping
|
||||
Method that retrieves credentials from the store
|
||||
:return: Credentials that can be used for propagation
|
||||
"""
|
||||
|
|
|
@ -2,66 +2,83 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from infection_monkey.credential_collectors import Password, SSHKeypair, Username
|
||||
from infection_monkey.credential_store import AggregatingCredentialsStore
|
||||
from infection_monkey.i_puppet import Credentials
|
||||
|
||||
DEFAULT_CREDENTIALS = {
|
||||
"exploit_user_list": ["Administrator", "root", "user1"],
|
||||
"exploit_password_list": [
|
||||
"root",
|
||||
"123456",
|
||||
"password",
|
||||
"123456789",
|
||||
],
|
||||
"exploit_password_list": ["123456", "123456789", "password", "root"],
|
||||
"exploit_lm_hash_list": ["aasdf23asd1fdaasadasdfas"],
|
||||
"exploit_ntlm_hash_list": ["qw4trklxklvznksbhasd1231", "asdfadvxvsdftw3e3421234123412"],
|
||||
"exploit_ntlm_hash_list": ["asdfadvxvsdftw3e3421234123412", "qw4trklxklvznksbhasd1231"],
|
||||
"exploit_ssh_keys": [
|
||||
{"public_key": "some_public_key", "private_key": "some_private_key"},
|
||||
{
|
||||
"public_key": "ssh-ed25519 AAAAC3NzEIFaJ7xH+Yoxd\n",
|
||||
"private_key": "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BdHIAAAAGYXjl0j66VAKruPEKjS3A=\n"
|
||||
"-----END OPENSSH PRIVATE KEY-----\n",
|
||||
"user": "ubuntu",
|
||||
"ip": "10.0.3.15",
|
||||
},
|
||||
{"public_key": "some_public_key", "private_key": "some_private_key"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
SAMPLE_CREDENTIALS = {
|
||||
PROPAGATION_CREDENTIALS = {
|
||||
"exploit_user_list": ["user1", "user3"],
|
||||
"exploit_password_list": ["abcdefg", "root"],
|
||||
"exploit_ssh_keys": [{"public_key": "some_public_key", "private_key": "some_private_key"}],
|
||||
"exploit_ntlm_hash_list": [],
|
||||
}
|
||||
|
||||
TELEM_CREDENTIALS = [
|
||||
Credentials(
|
||||
[Username("user1"), Username("user3")],
|
||||
[
|
||||
Password("abcdefg"),
|
||||
Password("root"),
|
||||
SSHKeypair(public_key="some_public_key", private_key="some_private_key"),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aggregating_credentials_store() -> AggregatingCredentialsStore:
|
||||
return AggregatingCredentialsStore()
|
||||
control_channel = MagicMock()
|
||||
control_channel.get_credentials_for_propagation.return_value = DEFAULT_CREDENTIALS
|
||||
return AggregatingCredentialsStore(control_channel)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("credentials_to_store", [DEFAULT_CREDENTIALS, SAMPLE_CREDENTIALS])
|
||||
def test_get_credentials_from_store(aggregating_credentials_store, credentials_to_store):
|
||||
get_updated_credentials_for_propagation = MagicMock(return_value=credentials_to_store)
|
||||
def test_get_credentials_from_store(aggregating_credentials_store):
|
||||
aggregating_credentials_store.get_credentials()
|
||||
|
||||
aggregating_credentials_store.get_credentials(get_updated_credentials_for_propagation)
|
||||
actual_stored_credentials = aggregating_credentials_store.stored_credentials
|
||||
|
||||
assert aggregating_credentials_store.stored_credentials == credentials_to_store
|
||||
assert (
|
||||
actual_stored_credentials["exploit_user_list"] == DEFAULT_CREDENTIALS["exploit_user_list"]
|
||||
)
|
||||
assert (
|
||||
actual_stored_credentials["exploit_password_list"]
|
||||
== DEFAULT_CREDENTIALS["exploit_password_list"]
|
||||
)
|
||||
assert (
|
||||
actual_stored_credentials["exploit_ntlm_hash_list"]
|
||||
== DEFAULT_CREDENTIALS["exploit_ntlm_hash_list"]
|
||||
)
|
||||
|
||||
for ssh_keypair in actual_stored_credentials["exploit_ssh_keys"]:
|
||||
assert ssh_keypair in DEFAULT_CREDENTIALS["exploit_ssh_keys"]
|
||||
|
||||
|
||||
def test_add_credentials_to_empty_store(aggregating_credentials_store):
|
||||
aggregating_credentials_store.add_credentials(TELEM_CREDENTIALS)
|
||||
|
||||
aggregating_credentials_store.add_credentials(SAMPLE_CREDENTIALS)
|
||||
|
||||
assert aggregating_credentials_store.stored_credentials == SAMPLE_CREDENTIALS
|
||||
assert aggregating_credentials_store.stored_credentials == PROPAGATION_CREDENTIALS
|
||||
|
||||
|
||||
def test_add_credentials_to_full_store(aggregating_credentials_store):
|
||||
get_updated_credentials_for_propagation = MagicMock(return_value=DEFAULT_CREDENTIALS)
|
||||
|
||||
aggregating_credentials_store.get_credentials(get_updated_credentials_for_propagation)
|
||||
aggregating_credentials_store.get_credentials()
|
||||
|
||||
aggregating_credentials_store.add_credentials(SAMPLE_CREDENTIALS)
|
||||
aggregating_credentials_store.add_credentials(TELEM_CREDENTIALS)
|
||||
|
||||
actual_stored_credentials = aggregating_credentials_store.stored_credentials
|
||||
|
||||
|
@ -78,4 +95,6 @@ def test_add_credentials_to_full_store(aggregating_credentials_store):
|
|||
"password",
|
||||
"root",
|
||||
]
|
||||
assert actual_stored_credentials["exploit_ssh_keys"] == DEFAULT_CREDENTIALS["exploit_ssh_keys"]
|
||||
|
||||
for ssh_keypair in actual_stored_credentials["exploit_ssh_keys"]:
|
||||
assert ssh_keypair in DEFAULT_CREDENTIALS["exploit_ssh_keys"]
|
||||
|
|
Loading…
Reference in New Issue