From f868f03ea7eb363507add24241f6489b4c3dc4f0 Mon Sep 17 00:00:00 2001 From: vakarisz Date: Thu, 1 Sep 2022 12:56:05 +0300 Subject: [PATCH] Common: Refactor credentials from marshmallow to pydantic --- monkey/common/credentials/__init__.py | 5 - .../credential_component_schema.py | 20 -- .../credentials/credential_component_type.py | 9 - monkey/common/credentials/credentials.py | 192 +----------------- .../credentials/i_credential_component.py | 13 -- monkey/common/credentials/lm_hash.py | 29 +-- monkey/common/credentials/nt_hash.py | 29 +-- monkey/common/credentials/password.py | 20 +- monkey/common/credentials/ssh_keypair.py | 25 +-- monkey/common/credentials/username.py | 20 +- monkey/common/credentials/validators.py | 49 +---- .../data_for_tests/propagation_credentials.py | 61 ++++-- .../credentials/test_credential_components.py | 148 -------------- .../common/credentials/test_credentials.py | 118 +---------- .../common/credentials/test_lm_hash.py | 14 ++ vulture_allowlist.py | 8 +- 16 files changed, 101 insertions(+), 659 deletions(-) delete mode 100644 monkey/common/credentials/credential_component_schema.py delete mode 100644 monkey/common/credentials/credential_component_type.py delete mode 100644 monkey/common/credentials/i_credential_component.py delete mode 100644 monkey/tests/unit_tests/common/credentials/test_credential_components.py create mode 100644 monkey/tests/unit_tests/common/credentials/test_lm_hash.py diff --git a/monkey/common/credentials/__init__.py b/monkey/common/credentials/__init__.py index 6275e0985..66b91971d 100644 --- a/monkey/common/credentials/__init__.py +++ b/monkey/common/credentials/__init__.py @@ -1,8 +1,3 @@ -from .credential_component_type import CredentialComponentType -from .i_credential_component import ICredentialComponent - -from .validators import InvalidCredentialComponentError, InvalidCredentialsError - from .lm_hash import LMHash from .nt_hash import NTHash from .password import Password diff --git a/monkey/common/credentials/credential_component_schema.py b/monkey/common/credentials/credential_component_schema.py deleted file mode 100644 index ff3e657be..000000000 --- a/monkey/common/credentials/credential_component_schema.py +++ /dev/null @@ -1,20 +0,0 @@ -from marshmallow import Schema, post_load, validate -from marshmallow_enum import EnumField - -from common.utils.code_utils import del_key - -from . import CredentialComponentType - - -class CredentialTypeField(EnumField): - def __init__(self, credential_component_type: CredentialComponentType): - super().__init__( - CredentialComponentType, validate=validate.Equal(credential_component_type) - ) - - -class CredentialComponentSchema(Schema): - @post_load - def _strip_credential_type(self, data, **kwargs): - del_key(data, "credential_type") - return data diff --git a/monkey/common/credentials/credential_component_type.py b/monkey/common/credentials/credential_component_type.py deleted file mode 100644 index 25bd3a168..000000000 --- a/monkey/common/credentials/credential_component_type.py +++ /dev/null @@ -1,9 +0,0 @@ -from enum import Enum, auto - - -class CredentialComponentType(Enum): - USERNAME = auto() - PASSWORD = auto() - NT_HASH = auto() - LM_HASH = auto() - SSH_KEYPAIR = auto() diff --git a/monkey/common/credentials/credentials.py b/monkey/common/credentials/credentials.py index 93390f212..053d7f33d 100644 --- a/monkey/common/credentials/credentials.py +++ b/monkey/common/credentials/credentials.py @@ -1,190 +1,14 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any, Mapping, Optional, Type +from typing import Optional, Union -from marshmallow import Schema, fields, post_load, pre_dump -from marshmallow.exceptions import MarshmallowError +from ..base_models import InfectionMonkeyBaseModel +from . import LMHash, NTHash, Password, SSHKeypair, Username -from ..utils import IJSONSerializable -from . import ( - CredentialComponentType, - InvalidCredentialComponentError, - InvalidCredentialsError, - LMHash, - NTHash, - Password, - SSHKeypair, - Username, -) -from .i_credential_component import ICredentialComponent -from .lm_hash import LMHashSchema -from .nt_hash import NTHashSchema -from .password import PasswordSchema -from .ssh_keypair import SSHKeypairSchema -from .username import UsernameSchema - -CREDENTIAL_COMPONENT_TYPE_TO_CLASS: Mapping[CredentialComponentType, Type[ICredentialComponent]] = { - CredentialComponentType.LM_HASH: LMHash, - CredentialComponentType.NT_HASH: NTHash, - CredentialComponentType.PASSWORD: Password, - CredentialComponentType.SSH_KEYPAIR: SSHKeypair, - CredentialComponentType.USERNAME: Username, -} - -CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA: Mapping[CredentialComponentType, Schema] = { - CredentialComponentType.LM_HASH: LMHashSchema(), - CredentialComponentType.NT_HASH: NTHashSchema(), - CredentialComponentType.PASSWORD: PasswordSchema(), - CredentialComponentType.SSH_KEYPAIR: SSHKeypairSchema(), - CredentialComponentType.USERNAME: UsernameSchema(), -} - -CredentialComponentMapping = Optional[Mapping[str, Any]] -CredentialsMapping = Mapping[str, CredentialComponentMapping] +Secret = Union[Password, LMHash, NTHash, SSHKeypair] +Identity = Username -class CredentialsSchema(Schema): - identity = fields.Mapping(allow_none=True) - secret = fields.Mapping(allow_none=True) - - @post_load - def _make_credentials( - self, - credentials: CredentialsMapping, - **kwargs: Mapping[str, Any], - ) -> Mapping[str, Optional[ICredentialComponent]]: - if not any(credentials.values()): - raise InvalidCredentialsError("At least one credentials component must be defined") - - return { - key: CredentialsSchema._build_credential_component(credential_component_mapping) - for key, credential_component_mapping in credentials.items() - } - - @staticmethod - def _build_credential_component( - credential_component: CredentialComponentMapping, - ) -> Optional[ICredentialComponent]: - if credential_component is None: - return None - - try: - credential_component_type = CredentialComponentType[ - credential_component["credential_type"] - ] - except KeyError as err: - raise InvalidCredentialsError(f"Unknown credential component type {err}") - - credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type] - credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[ - credential_component_type - ] - - try: - return credential_component_class( - **credential_component_schema.load(credential_component) - ) - except MarshmallowError as err: - raise InvalidCredentialComponentError(credential_component_class, str(err)) - - @pre_dump - def _serialize_credentials(self, credentials: Credentials, **kwargs) -> CredentialsMapping: - return { - "identity": CredentialsSchema._serialize_credential_component(credentials.identity), - "secret": CredentialsSchema._serialize_credential_component(credentials.secret), - } - - @staticmethod - def _serialize_credential_component( - credential_component: Optional[ICredentialComponent], - ) -> CredentialComponentMapping: - if credential_component is None: - return None - - credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[ - credential_component.credential_type - ] - - return credential_component_schema.dump(credential_component) - - -@dataclass(frozen=True) -class Credentials(IJSONSerializable): - identity: Optional[ICredentialComponent] - secret: Optional[ICredentialComponent] - - def __post_init__(self): - schema = CredentialsSchema() - try: - serialized_data = schema.dump(self) - - # This will raise an exception if the object is invalid. Calling this in __post__init() - # makes it impossible to construct an invalid object - schema.load(serialized_data) - except Exception as err: - raise InvalidCredentialsError(err) - - @staticmethod - def from_mapping(credentials: CredentialsMapping) -> Credentials: - """ - Construct a Credentials object from a Mapping - - :param credentials: A mapping that represents a Credentials object - :return: A Credentials object - :raises InvalidCredentialsError: If the provided Mapping does not represent a valid - Credentials object - :raises InvalidCredentialComponentError: If any of the contents of `identities` or `secrets` - are not a valid ICredentialComponent - """ - - try: - deserialized_data = CredentialsSchema().load(credentials) - return Credentials(**deserialized_data) - except (InvalidCredentialsError, InvalidCredentialComponentError) as err: - raise err - except MarshmallowError as err: - raise InvalidCredentialsError(str(err)) - - @classmethod - def from_json(cls, credentials: str) -> Credentials: - """ - Construct a Credentials object from a JSON string - - :param credentials: A JSON string that represents a Credentials object - :return: A Credentials object - :raises InvalidCredentialsError: If the provided JSON does not represent a valid - Credentials object - :raises InvalidCredentialComponentError: If any of the contents of `identities` or `secrets` - are not a valid ICredentialComponent - """ - - try: - deserialized_data = CredentialsSchema().loads(credentials) - return Credentials(**deserialized_data) - except (InvalidCredentialsError, InvalidCredentialComponentError) as err: - raise err - except MarshmallowError as err: - raise InvalidCredentialsError(str(err)) - - @staticmethod - def to_mapping(credentials: Credentials) -> CredentialsMapping: - """ - Serialize a Credentials object to a Mapping - - :param credentials: A Credentials object - :return: A mapping representing a Credentials object - """ - - return CredentialsSchema().dump(credentials) - - @classmethod - def to_json(cls, credentials: Credentials) -> str: - """ - Serialize a Credentials object to JSON - - :param credentials: A Credentials object - :return: A JSON string representing a Credentials object - """ - - return CredentialsSchema().dumps(credentials) +class Credentials(InfectionMonkeyBaseModel): + identity: Optional[Identity] + secret: Optional[Secret] diff --git a/monkey/common/credentials/i_credential_component.py b/monkey/common/credentials/i_credential_component.py deleted file mode 100644 index ba55b3cce..000000000 --- a/monkey/common/credentials/i_credential_component.py +++ /dev/null @@ -1,13 +0,0 @@ -from abc import ABC, abstractmethod - -from pydantic.dataclasses import dataclass - -from . import CredentialComponentType - - -@dataclass -class ICredentialComponent(ABC): - @property - @abstractmethod - def credential_type(self) -> CredentialComponentType: - pass diff --git a/monkey/common/credentials/lm_hash.py b/monkey/common/credentials/lm_hash.py index df59beeb3..02525d386 100644 --- a/monkey/common/credentials/lm_hash.py +++ b/monkey/common/credentials/lm_hash.py @@ -1,25 +1,16 @@ -from dataclasses import field -from typing import ClassVar +import re -from marshmallow import fields -from pydantic.dataclasses import dataclass +from pydantic import validator +from pydantic.main import BaseModel -from . import CredentialComponentType, ICredentialComponent -from .credential_component_schema import CredentialComponentSchema, CredentialTypeField -from .validators import credential_component_validator, ntlm_hash_validator +from .validators import ntlm_hash_regex -class LMHashSchema(CredentialComponentSchema): - credential_type = CredentialTypeField(CredentialComponentType.LM_HASH) - lm_hash = fields.Str(validate=ntlm_hash_validator) - - -@dataclass -class LMHash(ICredentialComponent): - credential_type: ClassVar[CredentialComponentType] = field( - default=CredentialComponentType.LM_HASH, init=False - ) +class LMHash(BaseModel): lm_hash: str - def __post_init__(self): - credential_component_validator(LMHashSchema(), self) + @validator("lm_hash") + def validate_hash_format(cls, nt_hash): + if not re.match(ntlm_hash_regex, nt_hash): + raise ValueError(f"Invalid lm hash provided: {nt_hash}") + return nt_hash diff --git a/monkey/common/credentials/nt_hash.py b/monkey/common/credentials/nt_hash.py index e7b84449a..dd288afba 100644 --- a/monkey/common/credentials/nt_hash.py +++ b/monkey/common/credentials/nt_hash.py @@ -1,25 +1,16 @@ -from dataclasses import field -from typing import ClassVar +import re -from marshmallow import fields -from pydantic.dataclasses import dataclass +from pydantic import validator -from . import CredentialComponentType, ICredentialComponent -from .credential_component_schema import CredentialComponentSchema, CredentialTypeField -from .validators import credential_component_validator, ntlm_hash_validator +from ..base_models import InfectionMonkeyBaseModel +from .validators import ntlm_hash_regex -class NTHashSchema(CredentialComponentSchema): - credential_type = CredentialTypeField(CredentialComponentType.NT_HASH) - nt_hash = fields.Str(validate=ntlm_hash_validator) - - -@dataclass -class NTHash(ICredentialComponent): - credential_type: ClassVar[CredentialComponentType] = field( - default=CredentialComponentType.NT_HASH, init=False - ) +class NTHash(InfectionMonkeyBaseModel): nt_hash: str - def __post_init__(self): - credential_component_validator(NTHashSchema(), self) + @validator("nt_hash") + def validate_hash_format(cls, nt_hash): + if not re.match(ntlm_hash_regex, nt_hash): + raise ValueError(f"Invalid nt hash provided: {nt_hash}") + return nt_hash diff --git a/monkey/common/credentials/password.py b/monkey/common/credentials/password.py index 8fac2a37f..a9c1f2016 100644 --- a/monkey/common/credentials/password.py +++ b/monkey/common/credentials/password.py @@ -1,21 +1,5 @@ -from dataclasses import field -from typing import ClassVar - -from marshmallow import fields -from pydantic.dataclasses import dataclass - -from . import CredentialComponentType, ICredentialComponent -from .credential_component_schema import CredentialComponentSchema, CredentialTypeField +from ..base_models import InfectionMonkeyBaseModel -class PasswordSchema(CredentialComponentSchema): - credential_type = CredentialTypeField(CredentialComponentType.PASSWORD) - password = fields.Str() - - -@dataclass -class Password(ICredentialComponent): - credential_type: ClassVar[CredentialComponentType] = field( - default=CredentialComponentType.PASSWORD, init=False - ) +class Password(InfectionMonkeyBaseModel): password: str diff --git a/monkey/common/credentials/ssh_keypair.py b/monkey/common/credentials/ssh_keypair.py index 3183adbd4..49d886b8b 100644 --- a/monkey/common/credentials/ssh_keypair.py +++ b/monkey/common/credentials/ssh_keypair.py @@ -1,27 +1,6 @@ -from dataclasses import field -from typing import ClassVar - -from marshmallow import fields -from pydantic.dataclasses import dataclass - -from . import CredentialComponentType, ICredentialComponent -from .credential_component_schema import CredentialComponentSchema, CredentialTypeField +from ..base_models import InfectionMonkeyBaseModel -class SSHKeypairSchema(CredentialComponentSchema): - credential_type: ClassVar[CredentialComponentType] = CredentialTypeField( - CredentialComponentType.SSH_KEYPAIR - ) - # TODO: Find a list of valid formats for ssh keys and add validators. - # See https://github.com/nemchik/ssh-key-regex - private_key = fields.Str() - public_key = fields.Str() - - -@dataclass -class SSHKeypair(ICredentialComponent): - credential_type: ClassVar[CredentialComponentType] = field( - default=CredentialComponentType.SSH_KEYPAIR, init=False - ) +class SSHKeypair(InfectionMonkeyBaseModel): private_key: str public_key: str diff --git a/monkey/common/credentials/username.py b/monkey/common/credentials/username.py index 115b205a6..2bfc0b25d 100644 --- a/monkey/common/credentials/username.py +++ b/monkey/common/credentials/username.py @@ -1,21 +1,5 @@ -from dataclasses import field -from typing import ClassVar - -from marshmallow import fields -from pydantic.dataclasses import dataclass - -from . import CredentialComponentType, ICredentialComponent -from .credential_component_schema import CredentialComponentSchema, CredentialTypeField +from ..base_models import InfectionMonkeyBaseModel -class UsernameSchema(CredentialComponentSchema): - credential_type = CredentialTypeField(CredentialComponentType.USERNAME) - username = fields.Str() - - -@dataclass -class Username(ICredentialComponent): - credential_type: ClassVar[CredentialComponentType] = field( - default=CredentialComponentType.USERNAME, init=False - ) +class Username(InfectionMonkeyBaseModel): username: str diff --git a/monkey/common/credentials/validators.py b/monkey/common/credentials/validators.py index 2e0e2e93c..1d94575d4 100644 --- a/monkey/common/credentials/validators.py +++ b/monkey/common/credentials/validators.py @@ -1,50 +1,3 @@ import re -from typing import Type -from marshmallow import Schema, validate - -from . import ICredentialComponent - -_ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$") -ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex) - - -class InvalidCredentialComponentError(Exception): - def __init__(self, credential_component_class: Type[ICredentialComponent], message: str): - self._credential_component_name = credential_component_class.__name__ - self._message = message - - def __str__(self) -> str: - return ( - f"Cannot construct a {self._credential_component_name} object with the supplied, " - f"invalid data: {self._message}" - ) - - -class InvalidCredentialsError(Exception): - def __init__(self, message: str): - self._message = message - - def __str__(self) -> str: - return ( - f"Cannot construct a Credentials object with the supplied, " - f"invalid data: {self._message}" - ) - - -def credential_component_validator(schema: Schema, credential_component: ICredentialComponent): - """ - Validate a credential component - - :param schema: A marshmallow schema used for validating the component - :param credential_component: A credential component to be validated - :raises InvalidCredentialComponent: if the credential_component contains invalid data - """ - try: - serialized_data = schema.dump(credential_component) - - # This will raise an exception if the object is invalid. Calling this in __post__init() - # makes it impossible to construct an invalid object - schema.load(serialized_data) - except Exception as err: - raise InvalidCredentialComponentError(credential_component.__class__, err) +ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$") diff --git a/monkey/tests/data_for_tests/propagation_credentials.py b/monkey/tests/data_for_tests/propagation_credentials.py index 6efe9b7af..a1d783f78 100644 --- a/monkey/tests/data_for_tests/propagation_credentials.py +++ b/monkey/tests/data_for_tests/propagation_credentials.py @@ -1,3 +1,5 @@ +from itertools import product + from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username USERNAME = "m0nk3y_user" @@ -10,22 +12,47 @@ PASSWORD_3 = "rubberbabybuggybumpers" PUBLIC_KEY = "MY_PUBLIC_KEY" PRIVATE_KEY = "MY_PRIVATE_KEY" -PASSWORD_CREDENTIALS_1 = Credentials(identity=Username(USERNAME), secret=Password(PASSWORD_1)) -PASSWORD_CREDENTIALS_2 = Credentials(identity=Username(USERNAME), secret=Password(PASSWORD_2)) -LM_HASH_CREDENTIALS = Credentials(identity=Username(SPECIAL_USERNAME), secret=LMHash(LM_HASH)) -NT_HASH_CREDENTIALS = Credentials(identity=Username(USERNAME), secret=NTHash(NT_HASH)) -SSH_KEY_CREDENTIALS = Credentials( - identity=Username(USERNAME), secret=SSHKeypair(PRIVATE_KEY, PUBLIC_KEY) -) -EMPTY_SECRET_CREDENTIALS = Credentials(identity=Username(USERNAME), secret=None) -EMPTY_IDENTITY_CREDENTIALS = Credentials(identity=None, secret=Password(PASSWORD_3)) +IDENTITIES = [Username(username=USERNAME), None, Username(username=SPECIAL_USERNAME)] +IDENTITY_DICTS = [{"username": USERNAME}, None] -PROPAGATION_CREDENTIALS = [ - PASSWORD_CREDENTIALS_1, - LM_HASH_CREDENTIALS, - NT_HASH_CREDENTIALS, - PASSWORD_CREDENTIALS_2, - SSH_KEY_CREDENTIALS, - EMPTY_SECRET_CREDENTIALS, - EMPTY_IDENTITY_CREDENTIALS, +SECRETS = ( + Password(password=PASSWORD_1), + Password(password=PASSWORD_2), + Password(password=PASSWORD_3), + LMHash(lm_hash=LM_HASH), + NTHash(nt_hash=NT_HASH), + SSHKeypair(private_key=PRIVATE_KEY, public_key=PUBLIC_KEY), + None, +) +SECRET_DICTS = [ + {"password": PASSWORD_1}, + {"lm_hash": LM_HASH}, + {"nt_hash": NT_HASH}, + { + "public_key": PUBLIC_KEY, + "private_key": PRIVATE_KEY, + }, + None, +] + +CREDENTIALS = [ + Credentials(identity=identity, secret=secret) + for identity, secret in product(IDENTITIES, SECRETS) +] + +FULL_CREDENTIALS = [ + credentials + for credentials in CREDENTIALS + if not (credentials.identity is None and credentials.secret is None) +] + +CREDENTIALS_DICTS = [ + {"identity": identity, "secret": secret} + for identity, secret in product(IDENTITY_DICTS, SECRET_DICTS) +] + +FULL_CREDENTIALS_DICTS = [ + credentials + for credentials in CREDENTIALS_DICTS + if not (credentials["identity"] is None and credentials["secret"] is None) ] diff --git a/monkey/tests/unit_tests/common/credentials/test_credential_components.py b/monkey/tests/unit_tests/common/credentials/test_credential_components.py deleted file mode 100644 index 55a2b9279..000000000 --- a/monkey/tests/unit_tests/common/credentials/test_credential_components.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Any, Mapping - -import pytest -from marshmallow.exceptions import ValidationError - -from common.credentials import ( - CredentialComponentType, - LMHash, - NTHash, - Password, - SSHKeypair, - Username, -) -from common.credentials.lm_hash import LMHashSchema -from common.credentials.nt_hash import NTHashSchema -from common.credentials.password import PasswordSchema -from common.credentials.ssh_keypair import SSHKeypairSchema -from common.credentials.username import UsernameSchema - -PARAMETRIZED_PARAMETER_NAMES = ( - "credential_component_class, schema_class, credential_component_type, credential_component_data" -) - -PARAMETRIZED_PARAMETER_VALUES = [ - (Username, UsernameSchema, CredentialComponentType.USERNAME, {"username": "test_user"}), - (Password, PasswordSchema, CredentialComponentType.PASSWORD, {"password": "123456"}), - ( - LMHash, - LMHashSchema, - CredentialComponentType.LM_HASH, - {"lm_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"}, - ), - ( - NTHash, - NTHashSchema, - CredentialComponentType.NT_HASH, - {"nt_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"}, - ), - ( - SSHKeypair, - SSHKeypairSchema, - CredentialComponentType.SSH_KEYPAIR, - {"public_key": "TEST_PUBLIC_KEY", "private_key": "TEST_PRIVATE_KEY"}, - ), -] - - -INVALID_COMPONENT_DATA = { - CredentialComponentType.USERNAME: ({"username": None}, {"username": 1}, {"username": 2.0}), - CredentialComponentType.PASSWORD: ({"password": None}, {"password": 1}, {"password": 2.0}), - CredentialComponentType.LM_HASH: ( - {"lm_hash": None}, - {"lm_hash": 1}, - {"lm_hash": 2.0}, - {"lm_hash": "0123456789012345678901234568901"}, - {"lm_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"}, - ), - CredentialComponentType.NT_HASH: ( - {"nt_hash": None}, - {"nt_hash": 1}, - {"nt_hash": 2.0}, - {"nt_hash": "0123456789012345678901234568901"}, - {"nt_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"}, - ), - CredentialComponentType.SSH_KEYPAIR: ( - {"public_key": None, "private_key": "TEST_PRIVATE_KEY"}, - {"public_key": "TEST_PUBLIC_KEY", "private_key": None}, - {"public_key": 1, "private_key": "TEST_PRIVATE_KEY"}, - {"public_key": "TEST_PUBLIC_KEY", "private_key": 999}, - ), -} - - -def build_component_dict( - credential_component_type: CredentialComponentType, credential_component_data: Mapping[str, Any] -): - return {"credential_type": credential_component_type.name, **credential_component_data} - - -@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) -def test_credential_component_serialize( - credential_component_class, schema_class, credential_component_type, credential_component_data -): - schema = schema_class() - constructed_object = credential_component_class(**credential_component_data) - - serialized_object = schema.dump(constructed_object) - - assert serialized_object == build_component_dict( - credential_component_type, credential_component_data - ) - - -@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) -def test_credential_component_deserialize( - credential_component_class, schema_class, credential_component_type, credential_component_data -): - schema = schema_class() - credential_dict = build_component_dict(credential_component_type, credential_component_data) - expected_deserialized_object = credential_component_class(**credential_component_data) - - deserialized_object = credential_component_class(**schema.load(credential_dict)) - - assert deserialized_object == expected_deserialized_object - - -@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) -def test_invalid_credential_type( - credential_component_class, schema_class, credential_component_type, credential_component_data -): - invalid_component_dict = build_component_dict( - credential_component_type, credential_component_data - ) - invalid_component_dict["credential_type"] = "INVALID" - schema = schema_class() - - with pytest.raises(ValidationError): - credential_component_class(**schema.load(invalid_component_dict)) - - -@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) -def test_encorrect_credential_type( - credential_component_class, schema_class, credential_component_type, credential_component_data -): - incorrect_component_dict = build_component_dict( - credential_component_type, credential_component_data - ) - incorrect_component_dict["credential_type"] = ( - CredentialComponentType.USERNAME.name - if credential_component_type != CredentialComponentType.USERNAME - else CredentialComponentType.PASSWORD - ) - schema = schema_class() - - with pytest.raises(ValidationError): - credential_component_class(**schema.load(incorrect_component_dict)) - - -@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES) -def test_invalid_values( - credential_component_class, schema_class, credential_component_type, credential_component_data -): - schema = schema_class() - - for invalid_component_data in INVALID_COMPONENT_DATA[credential_component_type]: - component_dict = build_component_dict(credential_component_type, invalid_component_data) - with pytest.raises(ValidationError): - credential_component_class(**schema.load(component_dict)) diff --git a/monkey/tests/unit_tests/common/credentials/test_credentials.py b/monkey/tests/unit_tests/common/credentials/test_credentials.py index d4e14fbad..89238b839 100644 --- a/monkey/tests/unit_tests/common/credentials/test_credentials.py +++ b/monkey/tests/unit_tests/common/credentials/test_credentials.py @@ -1,122 +1,14 @@ -import json -from itertools import product - import pytest -from tests.data_for_tests.propagation_credentials import ( - LM_HASH, - NT_HASH, - PASSWORD_1, - PRIVATE_KEY, - PUBLIC_KEY, - USERNAME, -) +from tests.data_for_tests.propagation_credentials import CREDENTIALS, CREDENTIALS_DICTS -from common.credentials import ( - Credentials, - InvalidCredentialComponentError, - InvalidCredentialsError, - LMHash, - NTHash, - Password, - SSHKeypair, - Username, -) - -IDENTITIES = [Username(USERNAME), None] -IDENTITY_DICTS = [{"credential_type": "USERNAME", "username": USERNAME}, None] - -SECRETS = ( - Password(PASSWORD_1), - LMHash(LM_HASH), - NTHash(NT_HASH), - SSHKeypair(PRIVATE_KEY, PUBLIC_KEY), - None, -) -SECRET_DICTS = [ - {"credential_type": "PASSWORD", "password": PASSWORD_1}, - {"credential_type": "LM_HASH", "lm_hash": LM_HASH}, - {"credential_type": "NT_HASH", "nt_hash": NT_HASH}, - { - "credential_type": "SSH_KEYPAIR", - "public_key": PUBLIC_KEY, - "private_key": PRIVATE_KEY, - }, - None, -] - -CREDENTIALS = [ - Credentials(identity=identity, secret=secret) - for identity, secret in product(IDENTITIES, SECRETS) - if not (identity is None and secret is None) -] - -CREDENTIALS_DICTS = [ - {"identity": identity, "secret": secret} - for identity, secret in product(IDENTITY_DICTS, SECRET_DICTS) - if not (identity is None and secret is None) -] +from common.credentials import Credentials @pytest.mark.parametrize( "credentials, expected_credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS) ) def test_credentials_serialization_json(credentials, expected_credentials_dict): - serialized_credentials = Credentials.to_json(credentials) + serialized_credentials = credentials.json() + deserialized_credentials = Credentials.parse_raw(serialized_credentials) - assert json.loads(serialized_credentials) == expected_credentials_dict - - -@pytest.mark.parametrize( - "credentials, expected_credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS) -) -def test_credentials_serialization_mapping(credentials, expected_credentials_dict): - serialized_credentials = Credentials.to_mapping(credentials) - - assert serialized_credentials == expected_credentials_dict - - -@pytest.mark.parametrize( - "expected_credentials, credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS) -) -def test_credentials_deserialization__from_mapping(expected_credentials, credentials_dict): - deserialized_credentials = Credentials.from_mapping(credentials_dict) - - assert deserialized_credentials == expected_credentials - - -@pytest.mark.parametrize( - "expected_credentials, credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS) -) -def test_credentials_deserialization__from_json(expected_credentials, credentials_dict): - deserialized_credentials = Credentials.from_json(json.dumps(credentials_dict)) - - assert deserialized_credentials == expected_credentials - - -def test_credentials_deserialization__invalid_credentials(): - invalid_data = {"secret": SECRET_DICTS[0], "unknown_key": []} - with pytest.raises(InvalidCredentialsError): - Credentials.from_mapping(invalid_data) - - -def test_credentials_deserialization__invalid_component_type(): - invalid_data = { - "secret": SECRET_DICTS[0], - "identity": {"credential_type": "FAKE", "username": "user1"}, - } - with pytest.raises(InvalidCredentialsError): - Credentials.from_mapping(invalid_data) - - -def test_credentials_deserialization__invalid_component(): - invalid_data = { - "secret": SECRET_DICTS[0], - "identity": {"credential_type": "USERNAME", "unknown_field": "user1"}, - } - with pytest.raises(InvalidCredentialComponentError): - Credentials.from_mapping(invalid_data) - - -def test_create_credentials__none_none(): - with pytest.raises(InvalidCredentialsError): - Credentials(None, None) + assert credentials == deserialized_credentials diff --git a/monkey/tests/unit_tests/common/credentials/test_lm_hash.py b/monkey/tests/unit_tests/common/credentials/test_lm_hash.py new file mode 100644 index 000000000..4bcfb3270 --- /dev/null +++ b/monkey/tests/unit_tests/common/credentials/test_lm_hash.py @@ -0,0 +1,14 @@ +import pytest + +from common.credentials import LMHash + + +def test_construct_valid_nt_hash(valid_ntlm_hash): + # This test will fail if an exception is raised + LMHash(lm_hash=valid_ntlm_hash) + + +def test_construct_invalid_nt_hash(invalid_ntlm_hashes): + for invalid_hash in invalid_ntlm_hashes: + with pytest.raises(ValueError): + LMHash(lm_hash=invalid_hash) diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 42d46716f..127a319d2 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -7,8 +7,7 @@ from common.agent_configuration.agent_sub_configurations import ( CustomPBAConfiguration, ScanTargetConfiguration, ) -from common.credentials import Credentials -from common.utils import IJSONSerializable +from common.credentials import LMHash, NTHash from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue from monkey_island.cc.models import Report @@ -166,6 +165,8 @@ LDAPServerFactory.buildProtocol get_file_sha256_hash strict_slashes # unused attribute (monkey/monkey_island/cc/app.py:96) post_breach_actions # unused variable (monkey\infection_monkey\config.py:95) +LMHash.validate_hash_format +NTHash.validate_hash_format # Deployments DEVELOP # unused variable (monkey/monkey/monkey_island/cc/deployment.py:5) @@ -314,9 +315,6 @@ EXPLOITED CC CC_TUNNEL -Credentials.from_json -IJSONSerializable.from_json - IslandEventTopic.AGENT_CONNECTED IslandEventTopic.CLEAR_SIMULATION_DATA IslandEventTopic.RESET_AGENT_CONFIGURATION