diff --git a/monkey/common/configuration/agent_configuration.py b/monkey/common/configuration/agent_configuration.py index 097b86382..d6e61ddd1 100644 --- a/monkey/common/configuration/agent_configuration.py +++ b/monkey/common/configuration/agent_configuration.py @@ -19,12 +19,14 @@ from .agent_sub_configurations import ( class InvalidConfigurationError(Exception): - pass + def __init__(self, message: str): + self._message = message - -INVALID_CONFIGURATION_ERROR_MESSAGE = ( - "Cannot construct an AgentConfiguration object with the supplied, invalid data:" -) + def __str__(self) -> str: + return ( + f"Cannot construct an AgentConfiguration object with the supplied, invalid data: " + f"{self._message}" + ) @dataclass(frozen=True) @@ -42,7 +44,7 @@ class AgentConfiguration: try: AgentConfigurationSchema().dump(self) except Exception as err: - raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") + raise InvalidConfigurationError(str(err)) @staticmethod def from_mapping(config_mapping: Mapping[str, Any]) -> AgentConfiguration: @@ -59,7 +61,7 @@ class AgentConfiguration: config_dict = AgentConfigurationSchema().load(config_mapping) return AgentConfiguration(**config_dict) except MarshmallowError as err: - raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") + raise InvalidConfigurationError(str(err)) @staticmethod def from_json(config_json: str) -> AgentConfiguration: @@ -75,7 +77,7 @@ class AgentConfiguration: config_dict = AgentConfigurationSchema().loads(config_json) return AgentConfiguration(**config_dict) except MarshmallowError as err: - raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") + raise InvalidConfigurationError(str(err)) @staticmethod def to_json(config: AgentConfiguration) -> str: diff --git a/monkey/common/credentials/__init__.py b/monkey/common/credentials/__init__.py index 92a778886..6275e0985 100644 --- a/monkey/common/credentials/__init__.py +++ b/monkey/common/credentials/__init__.py @@ -1,8 +1,12 @@ from .credential_component_type import CredentialComponentType from .i_credential_component import ICredentialComponent -from .credentials import Credentials + +from .validators import InvalidCredentialComponentError, InvalidCredentialsError + from .lm_hash import LMHash from .nt_hash import NTHash from .password import Password from .ssh_keypair import SSHKeypair from .username import Username + +from .credentials import Credentials diff --git a/monkey/common/credentials/credential_component_schema.py b/monkey/common/credentials/credential_component_schema.py new file mode 100644 index 000000000..ff3e657be --- /dev/null +++ b/monkey/common/credentials/credential_component_schema.py @@ -0,0 +1,20 @@ +from marshmallow import Schema, post_load, validate +from marshmallow_enum import EnumField + +from common.utils.code_utils import del_key + +from . import CredentialComponentType + + +class CredentialTypeField(EnumField): + def __init__(self, credential_component_type: CredentialComponentType): + super().__init__( + CredentialComponentType, validate=validate.Equal(credential_component_type) + ) + + +class CredentialComponentSchema(Schema): + @post_load + def _strip_credential_type(self, data, **kwargs): + del_key(data, "credential_type") + return data diff --git a/monkey/common/credentials/credentials.py b/monkey/common/credentials/credentials.py index d5591f6d7..e04be27a5 100644 --- a/monkey/common/credentials/credentials.py +++ b/monkey/common/credentials/credentials.py @@ -1,10 +1,144 @@ -from dataclasses import dataclass -from typing import Tuple +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Mapping, MutableMapping, Sequence, Tuple + +from marshmallow import Schema, fields, post_load, pre_dump +from marshmallow.exceptions import MarshmallowError + +from . import ( + CredentialComponentType, + InvalidCredentialComponentError, + InvalidCredentialsError, + LMHash, + NTHash, + Password, + SSHKeypair, + Username, +) from .i_credential_component import ICredentialComponent +from .lm_hash import LMHashSchema +from .nt_hash import NTHashSchema +from .password import PasswordSchema +from .ssh_keypair import SSHKeypairSchema +from .username import UsernameSchema + +CREDENTIAL_COMPONENT_TYPE_TO_CLASS = { + CredentialComponentType.LM_HASH: LMHash, + CredentialComponentType.NT_HASH: NTHash, + CredentialComponentType.PASSWORD: Password, + CredentialComponentType.SSH_KEYPAIR: SSHKeypair, + CredentialComponentType.USERNAME: Username, +} + +CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA = { + CredentialComponentType.LM_HASH: LMHashSchema(), + CredentialComponentType.NT_HASH: NTHashSchema(), + CredentialComponentType.PASSWORD: PasswordSchema(), + CredentialComponentType.SSH_KEYPAIR: SSHKeypairSchema(), + CredentialComponentType.USERNAME: UsernameSchema(), +} + + +class CredentialsSchema(Schema): + # Use fields.List instead of fields.Tuple because marshmallow requires fields.Tuple to have a + # fixed length. + identities = fields.List(fields.Mapping()) + secrets = fields.List(fields.Mapping()) + + @post_load + def _make_credentials( + self, data: MutableMapping, **kwargs: Mapping[str, Any] + ) -> Mapping[str, Sequence[Mapping[str, Any]]]: + data["identities"] = tuple( + [ + CredentialsSchema._build_credential_component(component) + for component in data["identities"] + ] + ) + data["secrets"] = tuple( + [ + CredentialsSchema._build_credential_component(component) + for component in data["secrets"] + ] + ) + + return data + + @staticmethod + def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent: + try: + credential_component_type = CredentialComponentType[data["credential_type"]] + except KeyError as err: + raise InvalidCredentialsError(f"Unknown credential component type {err}") + + credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type] + credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[ + credential_component_type + ] + + try: + return credential_component_class(**credential_component_schema.load(data)) + except MarshmallowError as err: + raise InvalidCredentialComponentError(credential_component_class, str(err)) + + @pre_dump + def _serialize_credentials( + self, credentials: Credentials, **kwargs + ) -> Mapping[str, Sequence[Mapping[str, Any]]]: + data = {} + + data["identities"] = tuple( + [ + CredentialsSchema._serialize_credential_component(component) + for component in credentials.identities + ] + ) + data["secrets"] = tuple( + [ + CredentialsSchema._serialize_credential_component(component) + for component in credentials.secrets + ] + ) + + return data + + @staticmethod + def _serialize_credential_component( + credential_component: ICredentialComponent, + ) -> Mapping[str, Any]: + credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[ + credential_component.credential_type + ] + + return credential_component_schema.dump(credential_component) @dataclass(frozen=True) class Credentials: identities: Tuple[ICredentialComponent] secrets: Tuple[ICredentialComponent] + + @staticmethod + def from_mapping(credentials: Mapping) -> Credentials: + try: + deserialized_data = CredentialsSchema().load(credentials) + return Credentials(**deserialized_data) + except (InvalidCredentialsError, InvalidCredentialComponentError) as err: + raise err + except MarshmallowError as err: + raise InvalidCredentialsError(str(err)) + + @staticmethod + def from_json(credentials: str) -> Credentials: + try: + deserialized_data = CredentialsSchema().loads(credentials) + return Credentials(**deserialized_data) + except (InvalidCredentialsError, InvalidCredentialComponentError) as err: + raise err + except MarshmallowError as err: + raise InvalidCredentialsError(str(err)) + + @staticmethod + def to_json(credentials: Credentials) -> str: + return CredentialsSchema().dumps(credentials) diff --git a/monkey/common/credentials/lm_hash.py b/monkey/common/credentials/lm_hash.py index 818999244..5a04e8bae 100644 --- a/monkey/common/credentials/lm_hash.py +++ b/monkey/common/credentials/lm_hash.py @@ -1,6 +1,15 @@ from dataclasses import dataclass, field +from marshmallow import fields + from . import CredentialComponentType, ICredentialComponent +from .credential_component_schema import CredentialComponentSchema, CredentialTypeField +from .validators import credential_component_validator, ntlm_hash_validator + + +class LMHashSchema(CredentialComponentSchema): + credential_type = CredentialTypeField(CredentialComponentType.LM_HASH) + lm_hash = fields.Str(validate=ntlm_hash_validator) @dataclass(frozen=True) @@ -9,3 +18,6 @@ class LMHash(ICredentialComponent): default=CredentialComponentType.LM_HASH, init=False ) lm_hash: str + + def __post_init__(self): + credential_component_validator(LMHashSchema(), self) diff --git a/monkey/common/credentials/nt_hash.py b/monkey/common/credentials/nt_hash.py index 67246f695..a7145a5a0 100644 --- a/monkey/common/credentials/nt_hash.py +++ b/monkey/common/credentials/nt_hash.py @@ -1,6 +1,15 @@ from dataclasses import dataclass, field +from marshmallow import fields + from . import CredentialComponentType, ICredentialComponent +from .credential_component_schema import CredentialComponentSchema, CredentialTypeField +from .validators import credential_component_validator, ntlm_hash_validator + + +class NTHashSchema(CredentialComponentSchema): + credential_type = CredentialTypeField(CredentialComponentType.NT_HASH) + nt_hash = fields.Str(validate=ntlm_hash_validator) @dataclass(frozen=True) @@ -9,3 +18,6 @@ class NTHash(ICredentialComponent): default=CredentialComponentType.NT_HASH, init=False ) nt_hash: str + + def __post_init__(self): + credential_component_validator(NTHashSchema(), self) diff --git a/monkey/common/credentials/password.py b/monkey/common/credentials/password.py index ff6da0eae..b7bd1b84c 100644 --- a/monkey/common/credentials/password.py +++ b/monkey/common/credentials/password.py @@ -1,6 +1,14 @@ from dataclasses import dataclass, field +from marshmallow import fields + from . import CredentialComponentType, ICredentialComponent +from .credential_component_schema import CredentialComponentSchema, CredentialTypeField + + +class PasswordSchema(CredentialComponentSchema): + credential_type = CredentialTypeField(CredentialComponentType.PASSWORD) + password = fields.Str() @dataclass(frozen=True) diff --git a/monkey/common/credentials/ssh_keypair.py b/monkey/common/credentials/ssh_keypair.py index 897acba8a..6b8dcded2 100644 --- a/monkey/common/credentials/ssh_keypair.py +++ b/monkey/common/credentials/ssh_keypair.py @@ -1,6 +1,17 @@ from dataclasses import dataclass, field +from marshmallow import fields + from . import CredentialComponentType, ICredentialComponent +from .credential_component_schema import CredentialComponentSchema, CredentialTypeField + + +class SSHKeypairSchema(CredentialComponentSchema): + credential_type = CredentialTypeField(CredentialComponentType.SSH_KEYPAIR) + # TODO: Find a list of valid formats for ssh keys and add validators. + # See https://github.com/nemchik/ssh-key-regex + private_key = fields.Str() + public_key = fields.Str() @dataclass(frozen=True) diff --git a/monkey/common/credentials/username.py b/monkey/common/credentials/username.py index c3249058e..86fde05ff 100644 --- a/monkey/common/credentials/username.py +++ b/monkey/common/credentials/username.py @@ -1,6 +1,14 @@ from dataclasses import dataclass, field +from marshmallow import fields + from . import CredentialComponentType, ICredentialComponent +from .credential_component_schema import CredentialComponentSchema, CredentialTypeField + + +class UsernameSchema(CredentialComponentSchema): + credential_type = CredentialTypeField(CredentialComponentType.USERNAME) + username = fields.Str() @dataclass(frozen=True) diff --git a/monkey/common/credentials/validators.py b/monkey/common/credentials/validators.py new file mode 100644 index 000000000..2e0e2e93c --- /dev/null +++ b/monkey/common/credentials/validators.py @@ -0,0 +1,50 @@ +import re +from typing import Type + +from marshmallow import Schema, validate + +from . import ICredentialComponent + +_ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$") +ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex) + + +class InvalidCredentialComponentError(Exception): + def __init__(self, credential_component_class: Type[ICredentialComponent], message: str): + self._credential_component_name = credential_component_class.__name__ + self._message = message + + def __str__(self) -> str: + return ( + f"Cannot construct a {self._credential_component_name} object with the supplied, " + f"invalid data: {self._message}" + ) + + +class InvalidCredentialsError(Exception): + def __init__(self, message: str): + self._message = message + + def __str__(self) -> str: + return ( + f"Cannot construct a Credentials object with the supplied, " + f"invalid data: {self._message}" + ) + + +def credential_component_validator(schema: Schema, credential_component: ICredentialComponent): + """ + Validate a credential component + + :param schema: A marshmallow schema used for validating the component + :param credential_component: A credential component to be validated + :raises InvalidCredentialComponent: if the credential_component contains invalid data + """ + try: + serialized_data = schema.dump(credential_component) + + # This will raise an exception if the object is invalid. Calling this in __post__init() + # makes it impossible to construct an invalid object + schema.load(serialized_data) + except Exception as err: + raise InvalidCredentialComponentError(credential_component.__class__, err) diff --git a/monkey/common/di_container.py b/monkey/common/di_container.py index 0a9748f4b..c1c740e39 100644 --- a/monkey/common/di_container.py +++ b/monkey/common/di_container.py @@ -1,5 +1,7 @@ import inspect -from typing import Any, MutableMapping, Sequence, Type, TypeVar +from typing import Any, Sequence, Type, TypeVar + +from common.utils.code_utils import del_key T = TypeVar("T") @@ -43,7 +45,7 @@ class DIContainer: ) self._type_registry[interface] = concrete_type - DIContainer._del_key(self._instance_registry, interface) + del_key(self._instance_registry, interface) def register_instance(self, interface: Type[T], instance: T): """ @@ -59,7 +61,7 @@ class DIContainer: ) self._instance_registry[interface] = instance - DIContainer._del_key(self._type_registry, interface) + del_key(self._type_registry, interface) def register_convention(self, type_: Type[T], name: str, instance: T): """ @@ -168,8 +170,8 @@ class DIContainer: :param interface: The interface to release """ - DIContainer._del_key(self._type_registry, interface) - DIContainer._del_key(self._instance_registry, interface) + del_key(self._type_registry, interface) + del_key(self._instance_registry, interface) def release_convention(self, type_: Type[T], name: str): """ @@ -179,18 +181,4 @@ class DIContainer: :param name: The name of the dependency parameter """ convention_identifier = (type_, name) - DIContainer._del_key(self._convention_registry, convention_identifier) - - @staticmethod - def _del_key(mapping: MutableMapping[T, Any], key: T): - """ - Deletes key from mapping. Unlike the `del` keyword, this function does not raise a KeyError - if the key does not exist. - - :param mapping: A mapping from which a key will be deleted - :param key: A key to delete from `mapping` - """ - try: - del mapping[key] - except KeyError: - pass + del_key(self._convention_registry, convention_identifier) diff --git a/monkey/common/utils/code_utils.py b/monkey/common/utils/code_utils.py index 251ce9375..21c0ce175 100644 --- a/monkey/common/utils/code_utils.py +++ b/monkey/common/utils/code_utils.py @@ -1,5 +1,7 @@ import queue -from typing import Any, List +from typing import Any, List, MutableMapping, TypeVar + +T = TypeVar("T") class abstractstatic(staticmethod): @@ -30,3 +32,19 @@ def queue_to_list(q: queue.Queue) -> List[Any]: pass return list_ + + +def del_key(mapping: MutableMapping[T, Any], key: T): + """ + Delete a key from a mapping. + + Unlike the `del` keyword, this function does not raise a KeyError + if the key does not exist. + + :param mapping: A mapping from which a key will be deleted + :param key: A key to delete from `mapping` + """ + try: + del mapping[key] + except KeyError: + pass diff --git a/monkey/infection_monkey/telemetry/credentials_telem.py b/monkey/infection_monkey/telemetry/credentials_telem.py index 11e2bbb6d..7504b9c65 100644 --- a/monkey/infection_monkey/telemetry/credentials_telem.py +++ b/monkey/infection_monkey/telemetry/credentials_telem.py @@ -1,9 +1,8 @@ -import enum import json -from typing import Dict, Iterable +from typing import Iterable from common.common_consts.telem_categories import TelemCategoryEnum -from common.credentials import Credentials, ICredentialComponent +from common.credentials import Credentials from infection_monkey.telemetry.base_telem import BaseTelem @@ -24,24 +23,5 @@ class CredentialsTelem(BaseTelem): def send(self, log_data=True): super().send(log_data=False) - def get_data(self) -> Dict: - # TODO: At a later time we can consider factoring this into a Serializer class or similar. - return json.loads(json.dumps(self._credentials, default=_serialize)) - - -def _serialize(obj): - if isinstance(obj, enum.Enum): - return obj.name - - if isinstance(obj, ICredentialComponent): - # This is a workaround for ICredentialComponents that are implemented as dataclasses. If the - # credential_type attribute is populated with `field(init=False, ...)`, then credential_type - # is not added to the object's __dict__ attribute. The biggest risk of this workaround is - # that we might change the name of the credential_type field in ICredentialComponents, but - # automated refactoring tools would not detect that this string needs to change. This is - # mittigated by the call to getattr() below, which will raise an AttributeException if the - # attribute name changes and a unit test will fail under these conditions. - credential_type = getattr(obj, "credential_type") - return dict(obj.__dict__, **{"credential_type": credential_type}) - - return getattr(obj, "__dict__", str(obj)) + def get_data(self): + return [json.loads(Credentials.to_json(c)) for c in self._credentials] diff --git a/monkey/tests/unit_tests/common/credentials/__init__.py b/monkey/tests/unit_tests/common/credentials/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/monkey/tests/unit_tests/common/credentials/test_credential_components.py b/monkey/tests/unit_tests/common/credentials/test_credential_components.py new file mode 100644 index 000000000..55a2b9279 --- /dev/null +++ b/monkey/tests/unit_tests/common/credentials/test_credential_components.py @@ -0,0 +1,148 @@ +from typing import Any, Mapping + +import pytest +from marshmallow.exceptions import ValidationError + +from common.credentials import ( + CredentialComponentType, + LMHash, + NTHash, + Password, + SSHKeypair, + Username, +) +from common.credentials.lm_hash import LMHashSchema +from common.credentials.nt_hash import NTHashSchema +from common.credentials.password import PasswordSchema +from common.credentials.ssh_keypair import SSHKeypairSchema +from common.credentials.username import UsernameSchema + +PARAMETRIZED_PARAMETER_NAMES = ( + "credential_component_class, schema_class, credential_component_type, credential_component_data" +) + +PARAMETRIZED_PARAMETER_VALUES = [ + (Username, UsernameSchema, CredentialComponentType.USERNAME, {"username": "test_user"}), + (Password, PasswordSchema, CredentialComponentType.PASSWORD, {"password": "123456"}), + ( + LMHash, + LMHashSchema, + CredentialComponentType.LM_HASH, + {"lm_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"}, + ), + ( + NTHash, + NTHashSchema, + CredentialComponentType.NT_HASH, + {"nt_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"}, + ), + ( + SSHKeypair, + SSHKeypairSchema, + CredentialComponentType.SSH_KEYPAIR, + {"public_key": "TEST_PUBLIC_KEY", "private_key": "TEST_PRIVATE_KEY"}, + ), +] + + +INVALID_COMPONENT_DATA = { + CredentialComponentType.USERNAME: ({"username": None}, {"username": 1}, {"username": 2.0}), + CredentialComponentType.PASSWORD: ({"password": None}, {"password": 1}, {"password": 2.0}), + CredentialComponentType.LM_HASH: ( + {"lm_hash": None}, + {"lm_hash": 1}, + {"lm_hash": 2.0}, + {"lm_hash": "0123456789012345678901234568901"}, + {"lm_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"}, + ), + CredentialComponentType.NT_HASH: ( + {"nt_hash": None}, + {"nt_hash": 1}, + {"nt_hash": 2.0}, + {"nt_hash": "0123456789012345678901234568901"}, + {"nt_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"}, + ), + CredentialComponentType.SSH_KEYPAIR: ( + {"public_key": None, "private_key": "TEST_PRIVATE_KEY"}, + {"public_key": "TEST_PUBLIC_KEY", "private_key": None}, + {"public_key": 1, "private_key": "TEST_PRIVATE_KEY"}, + {"public_key": "TEST_PUBLIC_KEY", "private_key": 999}, + ), +} + + +def build_component_dict( + credential_component_type: CredentialComponentType, credential_component_data: Mapping[str, Any] +): + return {"credential_type": credential_component_type.name, **credential_component_data} + + +@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) +def test_credential_component_serialize( + credential_component_class, schema_class, credential_component_type, credential_component_data +): + schema = schema_class() + constructed_object = credential_component_class(**credential_component_data) + + serialized_object = schema.dump(constructed_object) + + assert serialized_object == build_component_dict( + credential_component_type, credential_component_data + ) + + +@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) +def test_credential_component_deserialize( + credential_component_class, schema_class, credential_component_type, credential_component_data +): + schema = schema_class() + credential_dict = build_component_dict(credential_component_type, credential_component_data) + expected_deserialized_object = credential_component_class(**credential_component_data) + + deserialized_object = credential_component_class(**schema.load(credential_dict)) + + assert deserialized_object == expected_deserialized_object + + +@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) +def test_invalid_credential_type( + credential_component_class, schema_class, credential_component_type, credential_component_data +): + invalid_component_dict = build_component_dict( + credential_component_type, credential_component_data + ) + invalid_component_dict["credential_type"] = "INVALID" + schema = schema_class() + + with pytest.raises(ValidationError): + credential_component_class(**schema.load(invalid_component_dict)) + + +@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) +def test_encorrect_credential_type( + credential_component_class, schema_class, credential_component_type, credential_component_data +): + incorrect_component_dict = build_component_dict( + credential_component_type, credential_component_data + ) + incorrect_component_dict["credential_type"] = ( + CredentialComponentType.USERNAME.name + if credential_component_type != CredentialComponentType.USERNAME + else CredentialComponentType.PASSWORD + ) + schema = schema_class() + + with pytest.raises(ValidationError): + credential_component_class(**schema.load(incorrect_component_dict)) + + +@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) +def test_invalid_values( + credential_component_class, schema_class, credential_component_type, credential_component_data +): + schema = schema_class() + + for invalid_component_data in INVALID_COMPONENT_DATA[credential_component_type]: + component_dict = build_component_dict(credential_component_type, invalid_component_data) + with pytest.raises(ValidationError): + credential_component_class(**schema.load(component_dict)) diff --git a/monkey/tests/unit_tests/common/credentials/test_credentials.py b/monkey/tests/unit_tests/common/credentials/test_credentials.py new file mode 100644 index 000000000..c68e1c813 --- /dev/null +++ b/monkey/tests/unit_tests/common/credentials/test_credentials.py @@ -0,0 +1,89 @@ +import json + +import pytest + +from common.credentials import ( + Credentials, + InvalidCredentialComponentError, + InvalidCredentialsError, + LMHash, + NTHash, + Password, + SSHKeypair, + Username, +) + +USER1 = "test_user_1" +USER2 = "test_user_2" +PASSWORD = "12435" +LM_HASH = "AEBD4DE384C7EC43AAD3B435B51404EE" +NT_HASH = "7A21990FCD3D759941E45C490F143D5F" +PUBLIC_KEY = "MY_PUBLIC_KEY" +PRIVATE_KEY = "MY_PRIVATE_KEY" + +CREDENTIALS_DICT = { + "identities": [ + {"credential_type": "USERNAME", "username": USER1}, + {"credential_type": "USERNAME", "username": USER2}, + ], + "secrets": [ + {"credential_type": "PASSWORD", "password": PASSWORD}, + {"credential_type": "LM_HASH", "lm_hash": LM_HASH}, + {"credential_type": "NT_HASH", "nt_hash": NT_HASH}, + { + "credential_type": "SSH_KEYPAIR", + "public_key": PUBLIC_KEY, + "private_key": PRIVATE_KEY, + }, + ], +} + +CREDENTIALS_JSON = json.dumps(CREDENTIALS_DICT) + +IDENTITIES = (Username(USER1), Username(USER2)) +SECRETS = ( + Password(PASSWORD), + LMHash(LM_HASH), + NTHash(NT_HASH), + SSHKeypair(PRIVATE_KEY, PUBLIC_KEY), +) +CREDENTIALS_OBJECT = Credentials(IDENTITIES, SECRETS) + + +def test_credentials_serialization_json(): + serialized_credentials = Credentials.to_json(CREDENTIALS_OBJECT) + + assert json.loads(serialized_credentials) == CREDENTIALS_DICT + + +def test_credentials_deserialization__from_mapping(): + deserialized_credentials = Credentials.from_mapping(CREDENTIALS_DICT) + + assert deserialized_credentials == CREDENTIALS_OBJECT + + +def test_credentials_deserialization__from_json(): + deserialized_credentials = Credentials.from_json(CREDENTIALS_JSON) + + assert deserialized_credentials == CREDENTIALS_OBJECT + + +def test_credentials_deserialization__invalid_credentials(): + invalid_data = {"secrets": [], "unknown_key": []} + with pytest.raises(InvalidCredentialsError): + Credentials.from_mapping(invalid_data) + + +def test_credentials_deserialization__invalid_component_type(): + invalid_data = {"secrets": [], "identities": [{"credential_type": "FAKE", "username": "user1"}]} + with pytest.raises(InvalidCredentialsError): + Credentials.from_mapping(invalid_data) + + +def test_credentials_deserialization__invalid_component(): + invalid_data = { + "secrets": [], + "identities": [{"credential_type": "USERNAME", "unknown_field": "user1"}], + } + with pytest.raises(InvalidCredentialComponentError): + Credentials.from_mapping(invalid_data) diff --git a/monkey/tests/unit_tests/common/credentials/test_ntlm_hash.py b/monkey/tests/unit_tests/common/credentials/test_ntlm_hash.py new file mode 100644 index 000000000..5f50110e8 --- /dev/null +++ b/monkey/tests/unit_tests/common/credentials/test_ntlm_hash.py @@ -0,0 +1,26 @@ +import pytest + +from common.credentials import InvalidCredentialComponentError, LMHash, NTHash + +VALID_HASH = "E520AC67419A9A224A3B108F3FA6CB6D" +INVALID_HASHES = ( + 0, + 1, + 2.0, + "invalid", + "0123456789012345678901234568901", + "E52GAC67419A9A224A3B108F3FA6CB6D", +) + + +@pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash)) +def test_construct_valid_ntlm_hash(ntlm_hash_class): + # This test will fail if an exception is raised + ntlm_hash_class(VALID_HASH) + + +@pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash)) +def test_construct_invalid_ntlm_hash(ntlm_hash_class): + for invalid_hash in INVALID_HASHES: + with pytest.raises(InvalidCredentialComponentError): + ntlm_hash_class(invalid_hash) diff --git a/monkey/tests/unit_tests/common/utils/test_code_utils.py b/monkey/tests/unit_tests/common/utils/test_code_utils.py index 411b07a63..e5980723d 100644 --- a/monkey/tests/unit_tests/common/utils/test_code_utils.py +++ b/monkey/tests/unit_tests/common/utils/test_code_utils.py @@ -1,6 +1,6 @@ from queue import Queue -from common.utils.code_utils import queue_to_list +from common.utils.code_utils import del_key, queue_to_list def test_empty_queue_to_empty_list(): @@ -20,3 +20,23 @@ def test_queue_to_list(): list_ = queue_to_list(q) assert list_ == expected_list + + +def test_del_key__deletes_key(): + key_to_delete = "a" + my_dict = {"a": 1, "b": 2} + expected_dict = {k: v for k, v in my_dict.items() if k != key_to_delete} + + del_key(my_dict, key_to_delete) + + assert my_dict == expected_dict + + +def test_del_key__nonexistant_key(): + key_to_delete = "a" + my_dict = {"a": 1, "b": 2} + + del_key(my_dict, key_to_delete) + + # This test passes if the following call does not raise an error + del_key(my_dict, key_to_delete) diff --git a/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_mimikatz_collector.py b/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_mimikatz_collector.py index f87dea93c..62142f6e9 100644 --- a/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_mimikatz_collector.py +++ b/monkey/tests/unit_tests/infection_monkey/credential_collectors/test_mimikatz_collector.py @@ -56,14 +56,16 @@ def test_pypykatz_result_parsing_duplicates(monkeypatch): def test_pypykatz_result_parsing_defaults(monkeypatch): win_creds = [ - WindowsCredentials(username="user2", password="secret2", lm_hash="lm_hash"), + WindowsCredentials( + username="user2", password="secret2", lm_hash="0182BD0BD4444BF8FC83B5D9042EED2E" + ), ] patch_pypykatz(win_creds, monkeypatch) # Expected credentials username = Username("user2") password = Password("secret2") - lm_hash = LMHash("lm_hash") + lm_hash = LMHash("0182BD0BD4444BF8FC83B5D9042EED2E") expected_credentials = Credentials([username], [password, lm_hash]) collected_credentials = collect_credentials() @@ -73,12 +75,17 @@ def test_pypykatz_result_parsing_defaults(monkeypatch): def test_pypykatz_result_parsing_no_identities(monkeypatch): win_creds = [ - WindowsCredentials(username="", password="", ntlm_hash="ntlm_hash", lm_hash="lm_hash"), + WindowsCredentials( + username="", + password="", + ntlm_hash="E9F85516721DDC218359AD5280DB4450", + lm_hash="0182BD0BD4444BF8FC83B5D9042EED2E", + ), ] patch_pypykatz(win_creds, monkeypatch) - lm_hash = LMHash("lm_hash") - nt_hash = NTHash("ntlm_hash") + lm_hash = LMHash("0182BD0BD4444BF8FC83B5D9042EED2E") + nt_hash = NTHash("E9F85516721DDC218359AD5280DB4450") expected_credentials = Credentials([], [lm_hash, nt_hash]) collected_credentials = collect_credentials() diff --git a/monkey/tests/unit_tests/infection_monkey/telemetry/test_credentials_telem.py b/monkey/tests/unit_tests/infection_monkey/telemetry/test_credentials_telem.py index 5fca80eff..701071fbb 100644 --- a/monkey/tests/unit_tests/infection_monkey/telemetry/test_credentials_telem.py +++ b/monkey/tests/unit_tests/infection_monkey/telemetry/test_credentials_telem.py @@ -38,8 +38,7 @@ def test_credential_telem_send(spy_send_telemetry, credentials_for_test): telem = CredentialsTelem([credentials_for_test]) telem.send() - expected_data = json.dumps(expected_data, cls=telem.json_encoder) - assert spy_send_telemetry.data == expected_data + assert json.loads(spy_send_telemetry.data) == expected_data assert spy_send_telemetry.telem_category == "credentials" diff --git a/vulture_allowlist.py b/vulture_allowlist.py index a3393325e..049e487fc 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -196,6 +196,12 @@ _make_tcp_scan_configuration # unused method (monkey/common/configuration/agent _make_network_scan_configuration # unused method (monkey/common/configuration/agent_configuration.py:110) _make_propagation_configuration # unused method (monkey/common/configuration/agent_configuration.py:167) +# Credentials +_strip_credential_type # unused method (monkey/common/credentials/password.py:18) +_make_credentials # unused method (monkey/common/credentials/credentials:39) +_serialize_credentials # unused method (monkey/common/credentials/credentials:67) + + # Models _make_simulation # unused method (monkey/monkey_island/cc/models/simulation.py:19