Merge pull request #2098 from guardicore/2072-simplify-credentials

2072 simplify credentials
This commit is contained in:
Mike Salvatore 2022-07-18 09:35:54 -04:00 committed by GitHub
commit 19a7bfd8e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 374 additions and 324 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Mapping, MutableMapping, Sequence, Tuple
from typing import Any, Mapping, Optional, Type
from marshmallow import Schema, fields, post_load, pre_dump
from marshmallow.exceptions import MarshmallowError
@ -24,7 +24,7 @@ from .password import PasswordSchema
from .ssh_keypair import SSHKeypairSchema
from .username import UsernameSchema
CREDENTIAL_COMPONENT_TYPE_TO_CLASS = {
CREDENTIAL_COMPONENT_TYPE_TO_CLASS: Mapping[CredentialComponentType, Type[ICredentialComponent]] = {
CredentialComponentType.LM_HASH: LMHash,
CredentialComponentType.NT_HASH: NTHash,
CredentialComponentType.PASSWORD: Password,
@ -32,7 +32,7 @@ CREDENTIAL_COMPONENT_TYPE_TO_CLASS = {
CredentialComponentType.USERNAME: Username,
}
CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA = {
CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA: Mapping[CredentialComponentType, Schema] = {
CredentialComponentType.LM_HASH: LMHashSchema(),
CredentialComponentType.NT_HASH: NTHashSchema(),
CredentialComponentType.PASSWORD: PasswordSchema(),
@ -40,36 +40,39 @@ CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA = {
CredentialComponentType.USERNAME: UsernameSchema(),
}
CredentialComponentMapping = Optional[Mapping[str, Any]]
CredentialsMapping = Mapping[str, CredentialComponentMapping]
class CredentialsSchema(Schema):
# Use fields.List instead of fields.Tuple because marshmallow requires fields.Tuple to have a
# fixed length.
identities = fields.List(fields.Mapping())
secrets = fields.List(fields.Mapping())
identity = fields.Mapping(allow_none=True)
secret = fields.Mapping(allow_none=True)
@post_load
def _make_credentials(
self, data: MutableMapping, **kwargs: Mapping[str, Any]
) -> Mapping[str, Sequence[Mapping[str, Any]]]:
data["identities"] = tuple(
[
CredentialsSchema._build_credential_component(component)
for component in data["identities"]
]
)
data["secrets"] = tuple(
[
CredentialsSchema._build_credential_component(component)
for component in data["secrets"]
]
)
self,
credentials: CredentialsMapping,
**kwargs: Mapping[str, Any],
) -> Mapping[str, Optional[ICredentialComponent]]:
if not any(credentials.values()):
raise InvalidCredentialsError("At least one credentials component must be defined")
return data
return {
key: CredentialsSchema._build_credential_component(credential_component_mapping)
for key, credential_component_mapping in credentials.items()
}
@staticmethod
def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent:
def _build_credential_component(
credential_component: CredentialComponentMapping,
) -> Optional[ICredentialComponent]:
if credential_component is None:
return None
try:
credential_component_type = CredentialComponentType[data["credential_type"]]
credential_component_type = CredentialComponentType[
credential_component["credential_type"]
]
except KeyError as err:
raise InvalidCredentialsError(f"Unknown credential component type {err}")
@ -79,35 +82,26 @@ class CredentialsSchema(Schema):
]
try:
return credential_component_class(**credential_component_schema.load(data))
return credential_component_class(
**credential_component_schema.load(credential_component)
)
except MarshmallowError as err:
raise InvalidCredentialComponentError(credential_component_class, str(err))
@pre_dump
def _serialize_credentials(
self, credentials: Credentials, **kwargs
) -> Mapping[str, Sequence[Mapping[str, Any]]]:
data = {}
data["identities"] = tuple(
[
CredentialsSchema._serialize_credential_component(component)
for component in credentials.identities
]
)
data["secrets"] = tuple(
[
CredentialsSchema._serialize_credential_component(component)
for component in credentials.secrets
]
)
return data
def _serialize_credentials(self, credentials: Credentials, **kwargs) -> CredentialsMapping:
return {
"identity": CredentialsSchema._serialize_credential_component(credentials.identity),
"secret": CredentialsSchema._serialize_credential_component(credentials.secret),
}
@staticmethod
def _serialize_credential_component(
credential_component: ICredentialComponent,
) -> Mapping[str, Any]:
credential_component: Optional[ICredentialComponent],
) -> CredentialComponentMapping:
if credential_component is None:
return None
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
credential_component.credential_type
]
@ -117,11 +111,22 @@ class CredentialsSchema(Schema):
@dataclass(frozen=True)
class Credentials(IJSONSerializable):
identities: Tuple[ICredentialComponent]
secrets: Tuple[ICredentialComponent]
identity: Optional[ICredentialComponent]
secret: Optional[ICredentialComponent]
def __post_init__(self):
schema = CredentialsSchema()
try:
serialized_data = schema.dump(self)
# This will raise an exception if the object is invalid. Calling this in __post__init()
# makes it impossible to construct an invalid object
schema.load(serialized_data)
except Exception as err:
raise InvalidCredentialsError(err)
@staticmethod
def from_mapping(credentials: Mapping) -> Credentials:
def from_mapping(credentials: CredentialsMapping) -> Credentials:
"""
Construct a Credentials object from a Mapping
@ -163,7 +168,7 @@ class Credentials(IJSONSerializable):
raise InvalidCredentialsError(str(err))
@staticmethod
def to_mapping(credentials: Credentials) -> Mapping:
def to_mapping(credentials: Credentials) -> CredentialsMapping:
"""
Serialize a Credentials object to a Mapping

View File

@ -14,36 +14,38 @@ logger = logging.getLogger(__name__)
class MimikatzCredentialCollector(ICredentialCollector):
def collect_credentials(self, options=None) -> Sequence[Credentials]:
logger.info("Attempting to collect windows credentials with pypykatz.")
creds = pypykatz_handler.get_windows_creds()
logger.info(f"Pypykatz gathered {len(creds)} credentials.")
return MimikatzCredentialCollector._to_credentials(creds)
windows_credentials = pypykatz_handler.get_windows_creds()
logger.info(f"Pypykatz gathered {len(windows_credentials)} credentials.")
return MimikatzCredentialCollector._to_credentials(windows_credentials)
@staticmethod
def _to_credentials(win_creds: Sequence[WindowsCredentials]) -> [Credentials]:
all_creds = []
for win_cred in win_creds:
identities = []
secrets = []
def _to_credentials(windows_credentials: Sequence[WindowsCredentials]) -> Sequence[Credentials]:
credentials = []
for wc in windows_credentials:
# Mimikatz picks up users created by the Monkey even if they're successfully deleted
# since it picks up creds from the registry. The newly created users are not removed
# from the registry until a reboot of the system, hence this check.
if win_cred.username and not win_cred.username.startswith(USERNAME_PREFIX):
identity = Username(win_cred.username)
identities.append(identity)
if wc.username and wc.username.startswith(USERNAME_PREFIX):
continue
if win_cred.password:
password = Password(win_cred.password)
secrets.append(password)
identity = None
if win_cred.lm_hash:
lm_hash = LMHash(lm_hash=win_cred.lm_hash)
secrets.append(lm_hash)
if wc.username:
identity = Username(wc.username)
if win_cred.ntlm_hash:
lm_hash = NTHash(nt_hash=win_cred.ntlm_hash)
secrets.append(lm_hash)
if wc.password:
password = Password(wc.password)
credentials.append(Credentials(identity, password))
if identities != [] or secrets != []:
all_creds.append(Credentials(identities, secrets))
return all_creds
if wc.lm_hash:
lm_hash = LMHash(lm_hash=wc.lm_hash)
credentials.append(Credentials(identity, lm_hash))
if wc.ntlm_hash:
ntlm_hash = NTHash(nt_hash=wc.ntlm_hash)
credentials.append(Credentials(identity, ntlm_hash))
if len(credentials) == 0 and identity is not None:
credentials.append(Credentials(identity, None))
return credentials

View File

@ -29,11 +29,11 @@ class SSHCredentialCollector(ICredentialCollector):
ssh_credentials = []
for info in ssh_info:
identities = []
secrets = []
identity = None
secret = None
if info.get("name", ""):
identities.append(Username(info["name"]))
identity = Username(info["name"])
ssh_keypair = {}
for key in ["public_key", "private_key"]:
@ -41,13 +41,11 @@ class SSHCredentialCollector(ICredentialCollector):
ssh_keypair[key] = info[key]
if len(ssh_keypair):
secrets.append(
SSHKeypair(
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
)
secret = SSHKeypair(
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
)
if identities != [] or secrets != []:
ssh_credentials.append(Credentials(identities, secrets))
if any([identity, secret]):
ssh_credentials.append(Credentials(identity, secret))
return ssh_credentials

View File

@ -1,7 +1,7 @@
import logging
from typing import Any, Iterable, Mapping
from common.credentials import CredentialComponentType, Credentials
from common.credentials import CredentialComponentType, Credentials, ICredentialComponent
from infection_monkey.custom_types import PropagationCredentials
from infection_monkey.i_control_channel import IControlChannel
from infection_monkey.utils.decorators import request_cache
@ -26,31 +26,28 @@ class AggregatingCredentialsStore(ICredentialsStore):
def add_credentials(self, credentials_to_add: Iterable[Credentials]):
for credentials in credentials_to_add:
usernames = {
identity.username
for identity in credentials.identities
if identity.credential_type is CredentialComponentType.USERNAME
}
self._stored_credentials.setdefault("exploit_user_list", set()).update(usernames)
if credentials.identity:
self._add_identity(credentials.identity)
for secret in credentials.secrets:
if secret.credential_type is CredentialComponentType.PASSWORD:
self._stored_credentials.setdefault("exploit_password_list", set()).add(
secret.password
)
elif secret.credential_type is CredentialComponentType.LM_HASH:
self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add(
secret.lm_hash
)
elif secret.credential_type is CredentialComponentType.NT_HASH:
self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add(
secret.nt_hash
)
elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR:
self._set_attribute(
"exploit_ssh_keys",
[{"public_key": secret.public_key, "private_key": secret.private_key}],
)
if credentials.secret:
self._add_secret(credentials.secret)
def _add_identity(self, identity: ICredentialComponent):
if identity.credential_type is CredentialComponentType.USERNAME:
self._stored_credentials.setdefault("exploit_user_list", set()).add(identity.username)
def _add_secret(self, secret: ICredentialComponent):
if secret.credential_type is CredentialComponentType.PASSWORD:
self._stored_credentials.setdefault("exploit_password_list", set()).add(secret.password)
elif secret.credential_type is CredentialComponentType.LM_HASH:
self._stored_credentials.setdefault("exploit_lm_hash_list", set()).add(secret.lm_hash)
elif secret.credential_type is CredentialComponentType.NT_HASH:
self._stored_credentials.setdefault("exploit_ntlm_hash_list", set()).add(secret.nt_hash)
elif secret.credential_type is CredentialComponentType.SSH_KEYPAIR:
self._set_attribute(
"exploit_ssh_keys",
[{"public_key": secret.public_key, "private_key": secret.private_key}],
)
def get_credentials(self) -> PropagationCredentials:
try:

View File

@ -71,39 +71,39 @@ class MongoCredentialsRepository(ICredentialsRepository):
except Exception as err:
raise StorageError(err)
# NOTE: The encryption/decryption is complicated and also full of mostly duplicated code. Rather
# than spend the effort to improve them now, we can revisit them when we resolve #2072.
# Resolving #2072 will make it easier to simplify these methods and remove duplication.
#
# If possible, implement the encryption/decryption as a decorator so it can be reused with
# TODO: If possible, implement the encryption/decryption as a decorator so it can be reused with
# different ICredentialsRepository implementations
def _encrypt_credentials_mapping(self, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
encrypted_mapping: Dict[str, Any] = {}
for secret_or_identity, credentials_components in mapping.items():
encrypted_mapping[secret_or_identity] = []
for component in credentials_components:
encrypted_component = {}
for key, value in component.items():
encrypted_component[key] = self._repository_encryptor.encrypt(value.encode())
for secret_or_identity, credentials_component in mapping.items():
if credentials_component is None:
encrypted_component = None
else:
encrypted_component = {
key: self._repository_encryptor.encrypt(value.encode())
for key, value in credentials_component.items()
}
encrypted_mapping[secret_or_identity].append(encrypted_component)
encrypted_mapping[secret_or_identity] = encrypted_component
return encrypted_mapping
def _decrypt_credentials_mapping(self, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
encrypted_mapping: Dict[str, Any] = {}
decrypted_mapping: Dict[str, Any] = {}
for secret_or_identity, credentials_components in mapping.items():
encrypted_mapping[secret_or_identity] = []
for component in credentials_components:
encrypted_component = {}
for key, value in component.items():
encrypted_component[key] = self._repository_encryptor.decrypt(value).decode()
for secret_or_identity, credentials_component in mapping.items():
if credentials_component is None:
decrypted_component = None
else:
decrypted_component = {
key: self._repository_encryptor.decrypt(value).decode()
for key, value in credentials_component.items()
}
encrypted_mapping[secret_or_identity].append(encrypted_component)
decrypted_mapping[secret_or_identity] = decrypted_component
return encrypted_mapping
return decrypted_mapping
@staticmethod
def _remove_credentials_fom_collection(collection):

View File

View File

@ -1,26 +1,31 @@
from common.credentials import Credentials, LMHash, NTHash, Password, Username
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
username = "m0nk3y_user"
special_username = "m0nk3y.user"
nt_hash = "C1C58F96CDF212B50837BC11A00BE47C"
lm_hash = "299BD128C1101FD6299BD128C1101FD6"
password_1 = "trytostealthis"
password_2 = "password"
password_3 = "12345678"
USERNAME = "m0nk3y_user"
SPECIAL_USERNAME = "m0nk3y.user"
NT_HASH = "C1C58F96CDF212B50837BC11A00BE47C"
LM_HASH = "299BD128C1101FD6299BD128C1101FD6"
PASSWORD_1 = "trytostealthis"
PASSWORD_2 = "password!"
PASSWORD_3 = "rubberbabybuggybumpers"
PUBLIC_KEY = "MY_PUBLIC_KEY"
PRIVATE_KEY = "MY_PRIVATE_KEY"
PROPAGATION_CREDENTIALS_1 = Credentials(
identities=(Username(username),),
secrets=(NTHash(nt_hash), LMHash(lm_hash), Password(password_1)),
)
PROPAGATION_CREDENTIALS_2 = Credentials(
identities=(Username(username), Username(special_username)),
secrets=(Password(password_1), Password(password_2), Password(password_3)),
)
PROPAGATION_CREDENTIALS_3 = Credentials(
identities=(Username(username),),
secrets=(Password(password_1),),
)
PROPAGATION_CREDENTIALS_4 = Credentials(
identities=(Username(username),),
secrets=(Password(password_2),),
PASSWORD_CREDENTIALS_1 = Credentials(identity=Username(USERNAME), secret=Password(PASSWORD_1))
PASSWORD_CREDENTIALS_2 = Credentials(identity=Username(USERNAME), secret=Password(PASSWORD_2))
LM_HASH_CREDENTIALS = Credentials(identity=Username(SPECIAL_USERNAME), secret=LMHash(LM_HASH))
NT_HASH_CREDENTIALS = Credentials(identity=Username(USERNAME), secret=NTHash(NT_HASH))
SSH_KEY_CREDENTIALS = Credentials(
identity=Username(USERNAME), secret=SSHKeypair(PRIVATE_KEY, PUBLIC_KEY)
)
EMPTY_SECRET_CREDENTIALS = Credentials(identity=Username(USERNAME), secret=None)
EMPTY_IDENTITY_CREDENTIALS = Credentials(identity=None, secret=Password(PASSWORD_3))
PROPAGATION_CREDENTIALS = [
PASSWORD_CREDENTIALS_1,
LM_HASH_CREDENTIALS,
NT_HASH_CREDENTIALS,
PASSWORD_CREDENTIALS_2,
SSH_KEY_CREDENTIALS,
EMPTY_SECRET_CREDENTIALS,
EMPTY_IDENTITY_CREDENTIALS,
]

View File

@ -1,6 +1,15 @@
import json
from itertools import product
import pytest
from tests.data_for_tests.propagation_credentials import (
LM_HASH,
NT_HASH,
PASSWORD_1,
PRIVATE_KEY,
PUBLIC_KEY,
USERNAME,
)
from common.credentials import (
Credentials,
@ -13,83 +22,101 @@ from common.credentials import (
Username,
)
USER1 = "test_user_1"
USER2 = "test_user_2"
PASSWORD = "12435"
LM_HASH = "AEBD4DE384C7EC43AAD3B435B51404EE"
NT_HASH = "7A21990FCD3D759941E45C490F143D5F"
PUBLIC_KEY = "MY_PUBLIC_KEY"
PRIVATE_KEY = "MY_PRIVATE_KEY"
IDENTITIES = [Username(USERNAME), None]
IDENTITY_DICTS = [{"credential_type": "USERNAME", "username": USERNAME}, None]
CREDENTIALS_DICT = {
"identities": [
{"credential_type": "USERNAME", "username": USER1},
{"credential_type": "USERNAME", "username": USER2},
],
"secrets": [
{"credential_type": "PASSWORD", "password": PASSWORD},
{"credential_type": "LM_HASH", "lm_hash": LM_HASH},
{"credential_type": "NT_HASH", "nt_hash": NT_HASH},
{
"credential_type": "SSH_KEYPAIR",
"public_key": PUBLIC_KEY,
"private_key": PRIVATE_KEY,
},
],
}
CREDENTIALS_JSON = json.dumps(CREDENTIALS_DICT)
IDENTITIES = (Username(USER1), Username(USER2))
SECRETS = (
Password(PASSWORD),
Password(PASSWORD_1),
LMHash(LM_HASH),
NTHash(NT_HASH),
SSHKeypair(PRIVATE_KEY, PUBLIC_KEY),
None,
)
CREDENTIALS_OBJECT = Credentials(IDENTITIES, SECRETS)
SECRET_DICTS = [
{"credential_type": "PASSWORD", "password": PASSWORD_1},
{"credential_type": "LM_HASH", "lm_hash": LM_HASH},
{"credential_type": "NT_HASH", "nt_hash": NT_HASH},
{
"credential_type": "SSH_KEYPAIR",
"public_key": PUBLIC_KEY,
"private_key": PRIVATE_KEY,
},
None,
]
CREDENTIALS = [
Credentials(identity, secret)
for identity, secret in product(IDENTITIES, SECRETS)
if not (identity is None and secret is None)
]
CREDENTIALS_DICTS = [
{"identity": identity, "secret": secret}
for identity, secret in product(IDENTITY_DICTS, SECRET_DICTS)
if not (identity is None and secret is None)
]
def test_credentials_serialization_json():
serialized_credentials = Credentials.to_json(CREDENTIALS_OBJECT)
@pytest.mark.parametrize(
"credentials, expected_credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
)
def test_credentials_serialization_json(credentials, expected_credentials_dict):
serialized_credentials = Credentials.to_json(credentials)
assert json.loads(serialized_credentials) == CREDENTIALS_DICT
assert json.loads(serialized_credentials) == expected_credentials_dict
def test_credentials_serialization_mapping():
serialized_credentials = Credentials.to_mapping(CREDENTIALS_OBJECT)
@pytest.mark.parametrize(
"credentials, expected_credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
)
def test_credentials_serialization_mapping(credentials, expected_credentials_dict):
serialized_credentials = Credentials.to_mapping(credentials)
assert serialized_credentials == CREDENTIALS_DICT
assert serialized_credentials == expected_credentials_dict
def test_credentials_deserialization__from_mapping():
deserialized_credentials = Credentials.from_mapping(CREDENTIALS_DICT)
@pytest.mark.parametrize(
"expected_credentials, credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
)
def test_credentials_deserialization__from_mapping(expected_credentials, credentials_dict):
deserialized_credentials = Credentials.from_mapping(credentials_dict)
assert deserialized_credentials == CREDENTIALS_OBJECT
assert deserialized_credentials == expected_credentials
def test_credentials_deserialization__from_json():
deserialized_credentials = Credentials.from_json(CREDENTIALS_JSON)
@pytest.mark.parametrize(
"expected_credentials, credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
)
def test_credentials_deserialization__from_json(expected_credentials, credentials_dict):
deserialized_credentials = Credentials.from_json(json.dumps(credentials_dict))
assert deserialized_credentials == CREDENTIALS_OBJECT
assert deserialized_credentials == expected_credentials
def test_credentials_deserialization__invalid_credentials():
invalid_data = {"secrets": [], "unknown_key": []}
invalid_data = {"secret": SECRET_DICTS[0], "unknown_key": []}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)
def test_credentials_deserialization__invalid_component_type():
invalid_data = {"secrets": [], "identities": [{"credential_type": "FAKE", "username": "user1"}]}
invalid_data = {
"secret": SECRET_DICTS[0],
"identity": {"credential_type": "FAKE", "username": "user1"},
}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)
def test_credentials_deserialization__invalid_component():
invalid_data = {
"secrets": [],
"identities": [{"credential_type": "USERNAME", "unknown_field": "user1"}],
"secret": SECRET_DICTS[0],
"identity": {"credential_type": "USERNAME", "unknown_field": "user1"},
}
with pytest.raises(InvalidCredentialComponentError):
Credentials.from_mapping(invalid_data)
def test_create_credentials__none_none():
with pytest.raises(InvalidCredentialsError):
Credentials(None, None)

View File

@ -36,7 +36,7 @@ def test_pypykatz_result_parsing(monkeypatch):
username = Username("user")
password = Password("secret")
expected_credentials = Credentials([username], [password])
expected_credentials = Credentials(username, password)
collected_credentials = collect_credentials()
assert len(collected_credentials) == 1
@ -66,11 +66,11 @@ def test_pypykatz_result_parsing_defaults(monkeypatch):
username = Username("user2")
password = Password("secret2")
lm_hash = LMHash("0182BD0BD4444BF8FC83B5D9042EED2E")
expected_credentials = Credentials([username], [password, lm_hash])
expected_credentials = [Credentials(username, password), Credentials(username, lm_hash)]
collected_credentials = collect_credentials()
assert len(collected_credentials) == 1
assert collected_credentials[0] == expected_credentials
assert len(collected_credentials) == 2
assert collected_credentials == expected_credentials
def test_pypykatz_result_parsing_no_identities(monkeypatch):
@ -86,8 +86,27 @@ def test_pypykatz_result_parsing_no_identities(monkeypatch):
lm_hash = LMHash("0182BD0BD4444BF8FC83B5D9042EED2E")
nt_hash = NTHash("E9F85516721DDC218359AD5280DB4450")
expected_credentials = Credentials([], [lm_hash, nt_hash])
expected_credentials = [Credentials(None, lm_hash), Credentials(None, nt_hash)]
collected_credentials = collect_credentials()
assert len(collected_credentials) == 2
assert collected_credentials == expected_credentials
def test_pypykatz_result_parsing_no_secrets(monkeypatch):
username = "user3"
win_creds = [
WindowsCredentials(
username=username,
password="",
ntlm_hash="",
lm_hash="",
),
]
patch_pypykatz(win_creds, monkeypatch)
expected_credentials = [Credentials(Username(username), None)]
collected_credentials = collect_credentials()
assert len(collected_credentials) == 1
assert collected_credentials[0] == expected_credentials
assert collected_credentials == expected_credentials

View File

@ -43,6 +43,12 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger):
"private_key": None,
},
{"name": "guest", "home_dir": "/", "public_key": None, "private_key": None},
{
"name": "",
"home_dir": "/home/mcus",
"public_key": "PubKey",
"private_key": "PrivKey",
},
]
patch_ssh_handler(ssh_creds, monkeypatch)
@ -53,11 +59,13 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger):
ssh_keypair1 = SSHKeypair("ExtremelyGoodPrivateKey", "SomePublicKeyUbuntu")
ssh_keypair2 = SSHKeypair("", "AnotherPublicKey")
ssh_keypair3 = SSHKeypair("PrivKey", "PubKey")
expected = [
Credentials(identities=[username], secrets=[ssh_keypair1]),
Credentials(identities=[username2], secrets=[ssh_keypair2]),
Credentials(identities=[username3], secrets=[]),
Credentials(identity=username, secret=ssh_keypair1),
Credentials(identity=username2, secret=ssh_keypair2),
Credentials(identity=username3, secret=None),
Credentials(identity=None, secret=ssh_keypair3),
]
collected = SSHCredentialCollector(patch_telemetry_messenger).collect_credentials()
assert expected == collected

View File

@ -30,21 +30,25 @@ EMPTY_CHANNEL_CREDENTIALS = {
TEST_CREDENTIALS = [
Credentials(
[Username("user1"), Username("user3")],
[
Password("abcdefg"),
Password("root"),
SSHKeypair(public_key="some_public_key_1", private_key="some_private_key_1"),
],
)
identity=Username("user1"),
secret=Password("root"),
),
Credentials(identity=Username("user1"), secret=Password("abcdefg")),
Credentials(
identity=Username("user3"),
secret=SSHKeypair(public_key="some_public_key_1", private_key="some_private_key_1"),
),
Credentials(
identity=None,
secret=Password("super_secret"),
),
Credentials(identity=Username("user4"), secret=None),
]
SSH_KEYS_CREDENTIALS = [
Credentials(
[Username("root")],
[
SSHKeypair(public_key="some_public_key", private_key="some_private_key"),
],
Username("root"),
SSHKeypair(public_key="some_public_key", private_key="some_private_key"),
)
]
@ -85,6 +89,7 @@ def test_add_credentials_to_store(aggregating_credentials_store):
"root",
"user1",
"user3",
"user4",
]
)
assert actual_stored_credentials["exploit_password_list"] == set(
@ -94,6 +99,7 @@ def test_add_credentials_to_store(aggregating_credentials_store):
"abcdefg",
"password",
"root",
"super_secret",
]
)

View File

@ -8,13 +8,10 @@ from infection_monkey.telemetry.messengers.credentials_intercepting_telemetry_me
TELEM_CREDENTIALS = [
Credentials(
[Username("user1"), Username("user3")],
[
Password("abcdefg"),
Password("root"),
SSHKeypair(public_key="some_public_key", private_key="some_private_key"),
],
)
Username("user1"),
SSHKeypair(public_key="some_public_key", private_key="some_private_key"),
),
Credentials(Username("root"), Password("password")),
]

View File

@ -2,7 +2,7 @@ import json
import pytest
from common.credentials import Credentials, Password, SSHKeypair, Username
from common.credentials import Credentials, Password, Username
from infection_monkey.telemetry.credentials_telem import CredentialsTelem
USERNAME = "m0nkey"
@ -14,26 +14,12 @@ PRIVATE_KEY = "priv_key"
@pytest.fixture
def credentials_for_test():
return Credentials(
[Username(USERNAME)], [Password(PASSWORD), SSHKeypair(PRIVATE_KEY, PUBLIC_KEY)]
)
return Credentials(Username(USERNAME), Password(PASSWORD))
def test_credential_telem_send(spy_send_telemetry, credentials_for_test):
expected_data = [
{
"identities": [{"username": USERNAME, "credential_type": "USERNAME"}],
"secrets": [
{"password": PASSWORD, "credential_type": "PASSWORD"},
{
"private_key": PRIVATE_KEY,
"public_key": PUBLIC_KEY,
"credential_type": "SSH_KEYPAIR",
},
],
}
]
expected_data = [Credentials.to_mapping(credentials_for_test)]
telem = CredentialsTelem([credentials_for_test])
telem.send()

View File

@ -1,42 +1,19 @@
from typing import Any, Iterable, Mapping, Sequence
from unittest.mock import MagicMock
import mongomock
import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
from common.credentials import Credentials
from monkey_island.cc.repository import MongoCredentialsRepository
from monkey_island.cc.server_utils.encryption import ILockableEncryptor
USER1 = "test_user_1"
USER2 = "test_user_2"
USER3 = "test_user_3"
PASSWORD = "12435"
PASSWORD2 = "password"
PASSWORD3 = "lozinka"
LM_HASH = "AEBD4DE384C7EC43AAD3B435B51404EE"
NT_HASH = "7A21990FCD3D759941E45C490F143D5F"
PUBLIC_KEY = "MY_PUBLIC_KEY"
PRIVATE_KEY = "MY_PRIVATE_KEY"
IDENTITIES_1 = (Username(USER1), Username(USER2))
SECRETS_1 = (
Password(PASSWORD),
LMHash(LM_HASH),
NTHash(NT_HASH),
SSHKeypair(PRIVATE_KEY, PUBLIC_KEY),
)
CREDENTIALS_OBJECT_1 = Credentials(IDENTITIES_1, SECRETS_1)
IDENTITIES_2 = (Username(USER3),)
SECRETS_2 = (Password(PASSWORD2), Password(PASSWORD3))
CREDENTIALS_OBJECT_2 = Credentials(IDENTITIES_2, SECRETS_2)
CONFIGURED_CREDENTIALS = [CREDENTIALS_OBJECT_1]
STOLEN_CREDENTIALS = [CREDENTIALS_OBJECT_2]
CREDENTIALS_LIST = [CREDENTIALS_OBJECT_1, CREDENTIALS_OBJECT_2]
CONFIGURED_CREDENTIALS = PROPAGATION_CREDENTIALS[0:3]
STOLEN_CREDENTIALS = PROPAGATION_CREDENTIALS[3:]
def reverse(data: bytes) -> bytes:
@ -45,6 +22,7 @@ def reverse(data: bytes) -> bytes:
@pytest.fixture
def repository_encryptor():
# NOTE: Tests will fail if any inputs to this mock encryptor are palindromes.
repository_encryptor = MagicMock(spec=ILockableEncryptor)
repository_encryptor.encrypt = MagicMock(side_effect=reverse)
repository_encryptor.decrypt = MagicMock(side_effect=reverse)
@ -81,9 +59,9 @@ def test_mongo_repository_get_all(mongo_repository):
def test_mongo_repository_configured(mongo_repository):
mongo_repository.save_configured_credentials(CREDENTIALS_LIST)
mongo_repository.save_configured_credentials(PROPAGATION_CREDENTIALS)
actual_configured_credentials = mongo_repository.get_configured_credentials()
assert actual_configured_credentials == CREDENTIALS_LIST
assert actual_configured_credentials == PROPAGATION_CREDENTIALS
mongo_repository.remove_configured_credentials()
actual_configured_credentials = mongo_repository.get_configured_credentials()
@ -93,7 +71,7 @@ def test_mongo_repository_configured(mongo_repository):
def test_mongo_repository_stolen(mongo_repository):
mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS)
actual_stolen_credentials = mongo_repository.get_stolen_credentials()
assert sorted(actual_stolen_credentials) == sorted(STOLEN_CREDENTIALS)
assert actual_stolen_credentials == STOLEN_CREDENTIALS
mongo_repository.remove_stolen_credentials()
actual_stolen_credentials = mongo_repository.get_stolen_credentials()
@ -104,7 +82,7 @@ def test_mongo_repository_all(mongo_repository):
mongo_repository.save_configured_credentials(CONFIGURED_CREDENTIALS)
mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS)
actual_credentials = mongo_repository.get_all_credentials()
assert actual_credentials == CREDENTIALS_LIST
assert actual_credentials == PROPAGATION_CREDENTIALS
mongo_repository.remove_all_credentials()
@ -113,41 +91,65 @@ def test_mongo_repository_all(mongo_repository):
assert mongo_repository.get_configured_credentials() == []
# NOTE: The following tests are complicated, but they work. Rather than spend the effort to improve
# them now, we can revisit them when we resolve #2072. Resolving #2072 will make it easier to
# simplify these tests.
def test_configured_secrets_encrypted(mongo_repository, mongo_client):
mongo_repository.save_configured_credentials([CREDENTIALS_OBJECT_2])
check_if_stored_credentials_encrypted(mongo_client, CREDENTIALS_OBJECT_2)
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
def test_configured_secrets_encrypted(
mongo_repository: MongoCredentialsRepository,
mongo_client: MongoClient,
credentials: Sequence[Credentials],
):
mongo_repository.save_configured_credentials([credentials])
check_if_stored_credentials_encrypted(mongo_client, credentials)
def test_stolen_secrets_encrypted(mongo_repository, mongo_client):
mongo_repository.save_stolen_credentials([CREDENTIALS_OBJECT_2])
check_if_stored_credentials_encrypted(mongo_client, CREDENTIALS_OBJECT_2)
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials: Credentials):
mongo_repository.save_stolen_credentials([credentials])
check_if_stored_credentials_encrypted(mongo_client, credentials)
def check_if_stored_credentials_encrypted(mongo_client, original_credentials):
raw_credentials = get_all_credentials_in_mongo(mongo_client)
def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_credentials):
original_credentials_mapping = Credentials.to_mapping(original_credentials)
raw_credentials = get_all_credentials_in_mongo(mongo_client)
for rc in raw_credentials:
for identity_or_secret, credentials_components in rc.items():
for component in credentials_components:
for key, value in component.items():
assert (
original_credentials_mapping[identity_or_secret][0].get(key, None) != value
)
for identity_or_secret, credentials_component in rc.items():
if original_credentials_mapping[identity_or_secret] is None:
assert credentials_component is None
else:
for key, value in credentials_component.items():
assert original_credentials_mapping[identity_or_secret][key] != value.decode()
def get_all_credentials_in_mongo(mongo_client):
def get_all_credentials_in_mongo(
mongo_client: MongoClient,
) -> Iterable[Mapping[str, Mapping[str, Any]]]:
encrypted_credentials = []
# Loop through all databases and collections and search for credentials. We don't want the tests
# to assume anything about the internal workings of the repository.
for db in mongo_client.list_database_names():
for collection in mongo_client[db].list_collection_names():
mongo_credentials = mongo_client[db][collection].find({})
for mc in mongo_credentials:
del mc["_id"]
encrypted_credentials.append(mc)
for collection in get_all_collections_in_mongo(mongo_client):
mongo_credentials = collection.find({})
for mc in mongo_credentials:
del mc["_id"]
encrypted_credentials.append(mc)
return encrypted_credentials
def get_all_collections_in_mongo(mongo_client: MongoClient) -> Iterable[Collection]:
collections = [
collection
for db in get_all_databases_in_mongo(mongo_client)
for collection in get_all_collections_in_database(db)
]
assert len(collections) > 0
return collections
def get_all_databases_in_mongo(mongo_client) -> Iterable[Database]:
return (mongo_client[db_name] for db_name in mongo_client.list_database_names())
def get_all_collections_in_database(db: Database) -> Iterable[Collection]:
return (db[collection_name] for collection_name in db.list_collection_names())

View File

@ -6,10 +6,10 @@ from urllib.parse import urljoin
import pytest
from tests.common import StubDIContainer
from tests.data_for_tests.propagation_credentials import (
PROPAGATION_CREDENTIALS_1,
PROPAGATION_CREDENTIALS_2,
PROPAGATION_CREDENTIALS_3,
PROPAGATION_CREDENTIALS_4,
LM_HASH_CREDENTIALS,
NT_HASH_CREDENTIALS,
PASSWORD_CREDENTIALS_1,
PASSWORD_CREDENTIALS_2,
)
from tests.monkey_island import InMemoryCredentialsRepository
@ -43,21 +43,19 @@ def flask_client(build_flask_client, credentials_repository):
def test_propagation_credentials_endpoint_get(flask_client, credentials_repository):
credentials_repository.save_configured_credentials(
[PROPAGATION_CREDENTIALS_1, PROPAGATION_CREDENTIALS_3]
)
credentials_repository.save_stolen_credentials(
[PROPAGATION_CREDENTIALS_2, PROPAGATION_CREDENTIALS_4]
[PASSWORD_CREDENTIALS_1, NT_HASH_CREDENTIALS]
)
credentials_repository.save_stolen_credentials([LM_HASH_CREDENTIALS, PASSWORD_CREDENTIALS_2])
resp = flask_client.get(ALL_CREDENTIALS_URL)
actual_propagation_credentials = [Credentials.from_mapping(creds) for creds in resp.json]
assert resp.status_code == HTTPStatus.OK
assert len(actual_propagation_credentials) == 4
assert PROPAGATION_CREDENTIALS_1 in actual_propagation_credentials
assert PROPAGATION_CREDENTIALS_2 in actual_propagation_credentials
assert PROPAGATION_CREDENTIALS_3 in actual_propagation_credentials
assert PROPAGATION_CREDENTIALS_4 in actual_propagation_credentials
assert PASSWORD_CREDENTIALS_1 in actual_propagation_credentials
assert LM_HASH_CREDENTIALS in actual_propagation_credentials
assert NT_HASH_CREDENTIALS in actual_propagation_credentials
assert PASSWORD_CREDENTIALS_2 in actual_propagation_credentials
def pre_populate_repository(
@ -72,7 +70,7 @@ def pre_populate_repository(
@pytest.mark.parametrize("url", [CONFIGURED_CREDENTIALS_URL, STOLEN_CREDENTIALS_URL])
def test_propagation_credentials_endpoint__get_stolen(flask_client, credentials_repository, url):
pre_populate_repository(
url, credentials_repository, [PROPAGATION_CREDENTIALS_1, PROPAGATION_CREDENTIALS_2]
url, credentials_repository, [PASSWORD_CREDENTIALS_1, LM_HASH_CREDENTIALS]
)
resp = flask_client.get(url)
@ -80,19 +78,19 @@ def test_propagation_credentials_endpoint__get_stolen(flask_client, credentials_
assert resp.status_code == HTTPStatus.OK
assert len(actual_propagation_credentials) == 2
assert actual_propagation_credentials[0] == PROPAGATION_CREDENTIALS_1
assert actual_propagation_credentials[1] == PROPAGATION_CREDENTIALS_2
assert actual_propagation_credentials[0] == PASSWORD_CREDENTIALS_1
assert actual_propagation_credentials[1] == LM_HASH_CREDENTIALS
@pytest.mark.parametrize("url", [CONFIGURED_CREDENTIALS_URL, STOLEN_CREDENTIALS_URL])
def test_propagation_credentials_endpoint__post_stolen(flask_client, credentials_repository, url):
pre_populate_repository(url, credentials_repository, [PROPAGATION_CREDENTIALS_1])
pre_populate_repository(url, credentials_repository, [PASSWORD_CREDENTIALS_1])
resp = flask_client.post(
url,
json=[
Credentials.to_json(PROPAGATION_CREDENTIALS_2),
Credentials.to_json(PROPAGATION_CREDENTIALS_3),
Credentials.to_json(LM_HASH_CREDENTIALS),
Credentials.to_json(NT_HASH_CREDENTIALS),
],
)
assert resp.status_code == HTTPStatus.NO_CONTENT
@ -102,15 +100,15 @@ def test_propagation_credentials_endpoint__post_stolen(flask_client, credentials
assert resp.status_code == HTTPStatus.OK
assert len(retrieved_propagation_credentials) == 3
assert PROPAGATION_CREDENTIALS_1 in retrieved_propagation_credentials
assert PROPAGATION_CREDENTIALS_2 in retrieved_propagation_credentials
assert PROPAGATION_CREDENTIALS_3 in retrieved_propagation_credentials
assert PASSWORD_CREDENTIALS_1 in retrieved_propagation_credentials
assert LM_HASH_CREDENTIALS in retrieved_propagation_credentials
assert NT_HASH_CREDENTIALS in retrieved_propagation_credentials
@pytest.mark.parametrize("url", [CONFIGURED_CREDENTIALS_URL, STOLEN_CREDENTIALS_URL])
def test_stolen_propagation_credentials_endpoint_delete(flask_client, credentials_repository, url):
pre_populate_repository(
url, credentials_repository, [PROPAGATION_CREDENTIALS_1, PROPAGATION_CREDENTIALS_2]
url, credentials_repository, [PASSWORD_CREDENTIALS_1, LM_HASH_CREDENTIALS]
)
resp = flask_client.delete(url)
assert resp.status_code == HTTPStatus.NO_CONTENT
@ -136,8 +134,8 @@ def test_propagation_credentials_endpoint__post_not_found(flask_client):
resp = flask_client.post(
NON_EXISTENT_COLLECTION_URL,
json=[
Credentials.to_json(PROPAGATION_CREDENTIALS_2),
Credentials.to_json(PROPAGATION_CREDENTIALS_3),
Credentials.to_json(LM_HASH_CREDENTIALS),
Credentials.to_json(NT_HASH_CREDENTIALS),
],
)
assert resp.status_code == HTTPStatus.NOT_FOUND