Common: Refactor credentials from marshmallow to pydantic

This commit is contained in:
vakarisz 2022-09-01 12:56:05 +03:00 committed by vakaris_zilius
parent 3ac60988a8
commit f868f03ea7
16 changed files with 101 additions and 659 deletions

View File

@ -1,8 +1,3 @@
from .credential_component_type import CredentialComponentType
from .i_credential_component import ICredentialComponent
from .validators import InvalidCredentialComponentError, InvalidCredentialsError
from .lm_hash import LMHash from .lm_hash import LMHash
from .nt_hash import NTHash from .nt_hash import NTHash
from .password import Password from .password import Password

View File

@ -1,20 +0,0 @@
from marshmallow import Schema, post_load, validate
from marshmallow_enum import EnumField
from common.utils.code_utils import del_key
from . import CredentialComponentType
class CredentialTypeField(EnumField):
def __init__(self, credential_component_type: CredentialComponentType):
super().__init__(
CredentialComponentType, validate=validate.Equal(credential_component_type)
)
class CredentialComponentSchema(Schema):
@post_load
def _strip_credential_type(self, data, **kwargs):
del_key(data, "credential_type")
return data

View File

@ -1,9 +0,0 @@
from enum import Enum, auto
class CredentialComponentType(Enum):
USERNAME = auto()
PASSWORD = auto()
NT_HASH = auto()
LM_HASH = auto()
SSH_KEYPAIR = auto()

View File

@ -1,190 +1,14 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from typing import Optional, Union
from typing import Any, Mapping, Optional, Type
from marshmallow import Schema, fields, post_load, pre_dump from ..base_models import InfectionMonkeyBaseModel
from marshmallow.exceptions import MarshmallowError from . import LMHash, NTHash, Password, SSHKeypair, Username
from ..utils import IJSONSerializable Secret = Union[Password, LMHash, NTHash, SSHKeypair]
from . import ( Identity = Username
CredentialComponentType,
InvalidCredentialComponentError,
InvalidCredentialsError,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
from .i_credential_component import ICredentialComponent
from .lm_hash import LMHashSchema
from .nt_hash import NTHashSchema
from .password import PasswordSchema
from .ssh_keypair import SSHKeypairSchema
from .username import UsernameSchema
CREDENTIAL_COMPONENT_TYPE_TO_CLASS: Mapping[CredentialComponentType, Type[ICredentialComponent]] = {
CredentialComponentType.LM_HASH: LMHash,
CredentialComponentType.NT_HASH: NTHash,
CredentialComponentType.PASSWORD: Password,
CredentialComponentType.SSH_KEYPAIR: SSHKeypair,
CredentialComponentType.USERNAME: Username,
}
CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA: Mapping[CredentialComponentType, Schema] = {
CredentialComponentType.LM_HASH: LMHashSchema(),
CredentialComponentType.NT_HASH: NTHashSchema(),
CredentialComponentType.PASSWORD: PasswordSchema(),
CredentialComponentType.SSH_KEYPAIR: SSHKeypairSchema(),
CredentialComponentType.USERNAME: UsernameSchema(),
}
CredentialComponentMapping = Optional[Mapping[str, Any]]
CredentialsMapping = Mapping[str, CredentialComponentMapping]
class CredentialsSchema(Schema): class Credentials(InfectionMonkeyBaseModel):
identity = fields.Mapping(allow_none=True) identity: Optional[Identity]
secret = fields.Mapping(allow_none=True) secret: Optional[Secret]
@post_load
def _make_credentials(
self,
credentials: CredentialsMapping,
**kwargs: Mapping[str, Any],
) -> Mapping[str, Optional[ICredentialComponent]]:
if not any(credentials.values()):
raise InvalidCredentialsError("At least one credentials component must be defined")
return {
key: CredentialsSchema._build_credential_component(credential_component_mapping)
for key, credential_component_mapping in credentials.items()
}
@staticmethod
def _build_credential_component(
credential_component: CredentialComponentMapping,
) -> Optional[ICredentialComponent]:
if credential_component is None:
return None
try:
credential_component_type = CredentialComponentType[
credential_component["credential_type"]
]
except KeyError as err:
raise InvalidCredentialsError(f"Unknown credential component type {err}")
credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type]
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
credential_component_type
]
try:
return credential_component_class(
**credential_component_schema.load(credential_component)
)
except MarshmallowError as err:
raise InvalidCredentialComponentError(credential_component_class, str(err))
@pre_dump
def _serialize_credentials(self, credentials: Credentials, **kwargs) -> CredentialsMapping:
return {
"identity": CredentialsSchema._serialize_credential_component(credentials.identity),
"secret": CredentialsSchema._serialize_credential_component(credentials.secret),
}
@staticmethod
def _serialize_credential_component(
credential_component: Optional[ICredentialComponent],
) -> CredentialComponentMapping:
if credential_component is None:
return None
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
credential_component.credential_type
]
return credential_component_schema.dump(credential_component)
@dataclass(frozen=True)
class Credentials(IJSONSerializable):
identity: Optional[ICredentialComponent]
secret: Optional[ICredentialComponent]
def __post_init__(self):
schema = CredentialsSchema()
try:
serialized_data = schema.dump(self)
# This will raise an exception if the object is invalid. Calling this in __post__init()
# makes it impossible to construct an invalid object
schema.load(serialized_data)
except Exception as err:
raise InvalidCredentialsError(err)
@staticmethod
def from_mapping(credentials: CredentialsMapping) -> Credentials:
"""
Construct a Credentials object from a Mapping
:param credentials: A mapping that represents a Credentials object
:return: A Credentials object
:raises InvalidCredentialsError: If the provided Mapping does not represent a valid
Credentials object
:raises InvalidCredentialComponentError: If any of the contents of `identities` or `secrets`
are not a valid ICredentialComponent
"""
try:
deserialized_data = CredentialsSchema().load(credentials)
return Credentials(**deserialized_data)
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
raise err
except MarshmallowError as err:
raise InvalidCredentialsError(str(err))
@classmethod
def from_json(cls, credentials: str) -> Credentials:
"""
Construct a Credentials object from a JSON string
:param credentials: A JSON string that represents a Credentials object
:return: A Credentials object
:raises InvalidCredentialsError: If the provided JSON does not represent a valid
Credentials object
:raises InvalidCredentialComponentError: If any of the contents of `identities` or `secrets`
are not a valid ICredentialComponent
"""
try:
deserialized_data = CredentialsSchema().loads(credentials)
return Credentials(**deserialized_data)
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
raise err
except MarshmallowError as err:
raise InvalidCredentialsError(str(err))
@staticmethod
def to_mapping(credentials: Credentials) -> CredentialsMapping:
"""
Serialize a Credentials object to a Mapping
:param credentials: A Credentials object
:return: A mapping representing a Credentials object
"""
return CredentialsSchema().dump(credentials)
@classmethod
def to_json(cls, credentials: Credentials) -> str:
"""
Serialize a Credentials object to JSON
:param credentials: A Credentials object
:return: A JSON string representing a Credentials object
"""
return CredentialsSchema().dumps(credentials)

View File

@ -1,13 +0,0 @@
from abc import ABC, abstractmethod
from pydantic.dataclasses import dataclass
from . import CredentialComponentType
@dataclass
class ICredentialComponent(ABC):
@property
@abstractmethod
def credential_type(self) -> CredentialComponentType:
pass

View File

@ -1,25 +1,16 @@
from dataclasses import field import re
from typing import ClassVar
from marshmallow import fields from pydantic import validator
from pydantic.dataclasses import dataclass from pydantic.main import BaseModel
from . import CredentialComponentType, ICredentialComponent from .validators import ntlm_hash_regex
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
from .validators import credential_component_validator, ntlm_hash_validator
class LMHashSchema(CredentialComponentSchema): class LMHash(BaseModel):
credential_type = CredentialTypeField(CredentialComponentType.LM_HASH)
lm_hash = fields.Str(validate=ntlm_hash_validator)
@dataclass
class LMHash(ICredentialComponent):
credential_type: ClassVar[CredentialComponentType] = field(
default=CredentialComponentType.LM_HASH, init=False
)
lm_hash: str lm_hash: str
def __post_init__(self): @validator("lm_hash")
credential_component_validator(LMHashSchema(), self) def validate_hash_format(cls, nt_hash):
if not re.match(ntlm_hash_regex, nt_hash):
raise ValueError(f"Invalid lm hash provided: {nt_hash}")
return nt_hash

View File

@ -1,25 +1,16 @@
from dataclasses import field import re
from typing import ClassVar
from marshmallow import fields from pydantic import validator
from pydantic.dataclasses import dataclass
from . import CredentialComponentType, ICredentialComponent from ..base_models import InfectionMonkeyBaseModel
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField from .validators import ntlm_hash_regex
from .validators import credential_component_validator, ntlm_hash_validator
class NTHashSchema(CredentialComponentSchema): class NTHash(InfectionMonkeyBaseModel):
credential_type = CredentialTypeField(CredentialComponentType.NT_HASH)
nt_hash = fields.Str(validate=ntlm_hash_validator)
@dataclass
class NTHash(ICredentialComponent):
credential_type: ClassVar[CredentialComponentType] = field(
default=CredentialComponentType.NT_HASH, init=False
)
nt_hash: str nt_hash: str
def __post_init__(self): @validator("nt_hash")
credential_component_validator(NTHashSchema(), self) def validate_hash_format(cls, nt_hash):
if not re.match(ntlm_hash_regex, nt_hash):
raise ValueError(f"Invalid nt hash provided: {nt_hash}")
return nt_hash

View File

@ -1,21 +1,5 @@
from dataclasses import field from ..base_models import InfectionMonkeyBaseModel
from typing import ClassVar
from marshmallow import fields
from pydantic.dataclasses import dataclass
from . import CredentialComponentType, ICredentialComponent
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
class PasswordSchema(CredentialComponentSchema): class Password(InfectionMonkeyBaseModel):
credential_type = CredentialTypeField(CredentialComponentType.PASSWORD)
password = fields.Str()
@dataclass
class Password(ICredentialComponent):
credential_type: ClassVar[CredentialComponentType] = field(
default=CredentialComponentType.PASSWORD, init=False
)
password: str password: str

View File

@ -1,27 +1,6 @@
from dataclasses import field from ..base_models import InfectionMonkeyBaseModel
from typing import ClassVar
from marshmallow import fields
from pydantic.dataclasses import dataclass
from . import CredentialComponentType, ICredentialComponent
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
class SSHKeypairSchema(CredentialComponentSchema): class SSHKeypair(InfectionMonkeyBaseModel):
credential_type: ClassVar[CredentialComponentType] = CredentialTypeField(
CredentialComponentType.SSH_KEYPAIR
)
# TODO: Find a list of valid formats for ssh keys and add validators.
# See https://github.com/nemchik/ssh-key-regex
private_key = fields.Str()
public_key = fields.Str()
@dataclass
class SSHKeypair(ICredentialComponent):
credential_type: ClassVar[CredentialComponentType] = field(
default=CredentialComponentType.SSH_KEYPAIR, init=False
)
private_key: str private_key: str
public_key: str public_key: str

View File

@ -1,21 +1,5 @@
from dataclasses import field from ..base_models import InfectionMonkeyBaseModel
from typing import ClassVar
from marshmallow import fields
from pydantic.dataclasses import dataclass
from . import CredentialComponentType, ICredentialComponent
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
class UsernameSchema(CredentialComponentSchema): class Username(InfectionMonkeyBaseModel):
credential_type = CredentialTypeField(CredentialComponentType.USERNAME)
username = fields.Str()
@dataclass
class Username(ICredentialComponent):
credential_type: ClassVar[CredentialComponentType] = field(
default=CredentialComponentType.USERNAME, init=False
)
username: str username: str

View File

@ -1,50 +1,3 @@
import re import re
from typing import Type
from marshmallow import Schema, validate ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$")
from . import ICredentialComponent
_ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$")
ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex)
class InvalidCredentialComponentError(Exception):
def __init__(self, credential_component_class: Type[ICredentialComponent], message: str):
self._credential_component_name = credential_component_class.__name__
self._message = message
def __str__(self) -> str:
return (
f"Cannot construct a {self._credential_component_name} object with the supplied, "
f"invalid data: {self._message}"
)
class InvalidCredentialsError(Exception):
def __init__(self, message: str):
self._message = message
def __str__(self) -> str:
return (
f"Cannot construct a Credentials object with the supplied, "
f"invalid data: {self._message}"
)
def credential_component_validator(schema: Schema, credential_component: ICredentialComponent):
"""
Validate a credential component
:param schema: A marshmallow schema used for validating the component
:param credential_component: A credential component to be validated
:raises InvalidCredentialComponent: if the credential_component contains invalid data
"""
try:
serialized_data = schema.dump(credential_component)
# This will raise an exception if the object is invalid. Calling this in __post__init()
# makes it impossible to construct an invalid object
schema.load(serialized_data)
except Exception as err:
raise InvalidCredentialComponentError(credential_component.__class__, err)

View File

@ -1,3 +1,5 @@
from itertools import product
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
USERNAME = "m0nk3y_user" USERNAME = "m0nk3y_user"
@ -10,22 +12,47 @@ PASSWORD_3 = "rubberbabybuggybumpers"
PUBLIC_KEY = "MY_PUBLIC_KEY" PUBLIC_KEY = "MY_PUBLIC_KEY"
PRIVATE_KEY = "MY_PRIVATE_KEY" PRIVATE_KEY = "MY_PRIVATE_KEY"
PASSWORD_CREDENTIALS_1 = Credentials(identity=Username(USERNAME), secret=Password(PASSWORD_1)) IDENTITIES = [Username(username=USERNAME), None, Username(username=SPECIAL_USERNAME)]
PASSWORD_CREDENTIALS_2 = Credentials(identity=Username(USERNAME), secret=Password(PASSWORD_2)) IDENTITY_DICTS = [{"username": USERNAME}, None]
LM_HASH_CREDENTIALS = Credentials(identity=Username(SPECIAL_USERNAME), secret=LMHash(LM_HASH))
NT_HASH_CREDENTIALS = Credentials(identity=Username(USERNAME), secret=NTHash(NT_HASH))
SSH_KEY_CREDENTIALS = Credentials(
identity=Username(USERNAME), secret=SSHKeypair(PRIVATE_KEY, PUBLIC_KEY)
)
EMPTY_SECRET_CREDENTIALS = Credentials(identity=Username(USERNAME), secret=None)
EMPTY_IDENTITY_CREDENTIALS = Credentials(identity=None, secret=Password(PASSWORD_3))
PROPAGATION_CREDENTIALS = [ SECRETS = (
PASSWORD_CREDENTIALS_1, Password(password=PASSWORD_1),
LM_HASH_CREDENTIALS, Password(password=PASSWORD_2),
NT_HASH_CREDENTIALS, Password(password=PASSWORD_3),
PASSWORD_CREDENTIALS_2, LMHash(lm_hash=LM_HASH),
SSH_KEY_CREDENTIALS, NTHash(nt_hash=NT_HASH),
EMPTY_SECRET_CREDENTIALS, SSHKeypair(private_key=PRIVATE_KEY, public_key=PUBLIC_KEY),
EMPTY_IDENTITY_CREDENTIALS, None,
)
SECRET_DICTS = [
{"password": PASSWORD_1},
{"lm_hash": LM_HASH},
{"nt_hash": NT_HASH},
{
"public_key": PUBLIC_KEY,
"private_key": PRIVATE_KEY,
},
None,
]
CREDENTIALS = [
Credentials(identity=identity, secret=secret)
for identity, secret in product(IDENTITIES, SECRETS)
]
FULL_CREDENTIALS = [
credentials
for credentials in CREDENTIALS
if not (credentials.identity is None and credentials.secret is None)
]
CREDENTIALS_DICTS = [
{"identity": identity, "secret": secret}
for identity, secret in product(IDENTITY_DICTS, SECRET_DICTS)
]
FULL_CREDENTIALS_DICTS = [
credentials
for credentials in CREDENTIALS_DICTS
if not (credentials["identity"] is None and credentials["secret"] is None)
] ]

View File

@ -1,148 +0,0 @@
from typing import Any, Mapping
import pytest
from marshmallow.exceptions import ValidationError
from common.credentials import (
CredentialComponentType,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
from common.credentials.lm_hash import LMHashSchema
from common.credentials.nt_hash import NTHashSchema
from common.credentials.password import PasswordSchema
from common.credentials.ssh_keypair import SSHKeypairSchema
from common.credentials.username import UsernameSchema
PARAMETRIZED_PARAMETER_NAMES = (
"credential_component_class, schema_class, credential_component_type, credential_component_data"
)
PARAMETRIZED_PARAMETER_VALUES = [
(Username, UsernameSchema, CredentialComponentType.USERNAME, {"username": "test_user"}),
(Password, PasswordSchema, CredentialComponentType.PASSWORD, {"password": "123456"}),
(
LMHash,
LMHashSchema,
CredentialComponentType.LM_HASH,
{"lm_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"},
),
(
NTHash,
NTHashSchema,
CredentialComponentType.NT_HASH,
{"nt_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"},
),
(
SSHKeypair,
SSHKeypairSchema,
CredentialComponentType.SSH_KEYPAIR,
{"public_key": "TEST_PUBLIC_KEY", "private_key": "TEST_PRIVATE_KEY"},
),
]
INVALID_COMPONENT_DATA = {
CredentialComponentType.USERNAME: ({"username": None}, {"username": 1}, {"username": 2.0}),
CredentialComponentType.PASSWORD: ({"password": None}, {"password": 1}, {"password": 2.0}),
CredentialComponentType.LM_HASH: (
{"lm_hash": None},
{"lm_hash": 1},
{"lm_hash": 2.0},
{"lm_hash": "0123456789012345678901234568901"},
{"lm_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"},
),
CredentialComponentType.NT_HASH: (
{"nt_hash": None},
{"nt_hash": 1},
{"nt_hash": 2.0},
{"nt_hash": "0123456789012345678901234568901"},
{"nt_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"},
),
CredentialComponentType.SSH_KEYPAIR: (
{"public_key": None, "private_key": "TEST_PRIVATE_KEY"},
{"public_key": "TEST_PUBLIC_KEY", "private_key": None},
{"public_key": 1, "private_key": "TEST_PRIVATE_KEY"},
{"public_key": "TEST_PUBLIC_KEY", "private_key": 999},
),
}
def build_component_dict(
credential_component_type: CredentialComponentType, credential_component_data: Mapping[str, Any]
):
return {"credential_type": credential_component_type.name, **credential_component_data}
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_credential_component_serialize(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
schema = schema_class()
constructed_object = credential_component_class(**credential_component_data)
serialized_object = schema.dump(constructed_object)
assert serialized_object == build_component_dict(
credential_component_type, credential_component_data
)
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_credential_component_deserialize(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
schema = schema_class()
credential_dict = build_component_dict(credential_component_type, credential_component_data)
expected_deserialized_object = credential_component_class(**credential_component_data)
deserialized_object = credential_component_class(**schema.load(credential_dict))
assert deserialized_object == expected_deserialized_object
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_invalid_credential_type(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
invalid_component_dict = build_component_dict(
credential_component_type, credential_component_data
)
invalid_component_dict["credential_type"] = "INVALID"
schema = schema_class()
with pytest.raises(ValidationError):
credential_component_class(**schema.load(invalid_component_dict))
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_encorrect_credential_type(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
incorrect_component_dict = build_component_dict(
credential_component_type, credential_component_data
)
incorrect_component_dict["credential_type"] = (
CredentialComponentType.USERNAME.name
if credential_component_type != CredentialComponentType.USERNAME
else CredentialComponentType.PASSWORD
)
schema = schema_class()
with pytest.raises(ValidationError):
credential_component_class(**schema.load(incorrect_component_dict))
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_invalid_values(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
schema = schema_class()
for invalid_component_data in INVALID_COMPONENT_DATA[credential_component_type]:
component_dict = build_component_dict(credential_component_type, invalid_component_data)
with pytest.raises(ValidationError):
credential_component_class(**schema.load(component_dict))

View File

@ -1,122 +1,14 @@
import json
from itertools import product
import pytest import pytest
from tests.data_for_tests.propagation_credentials import ( from tests.data_for_tests.propagation_credentials import CREDENTIALS, CREDENTIALS_DICTS
LM_HASH,
NT_HASH,
PASSWORD_1,
PRIVATE_KEY,
PUBLIC_KEY,
USERNAME,
)
from common.credentials import ( from common.credentials import Credentials
Credentials,
InvalidCredentialComponentError,
InvalidCredentialsError,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
IDENTITIES = [Username(USERNAME), None]
IDENTITY_DICTS = [{"credential_type": "USERNAME", "username": USERNAME}, None]
SECRETS = (
Password(PASSWORD_1),
LMHash(LM_HASH),
NTHash(NT_HASH),
SSHKeypair(PRIVATE_KEY, PUBLIC_KEY),
None,
)
SECRET_DICTS = [
{"credential_type": "PASSWORD", "password": PASSWORD_1},
{"credential_type": "LM_HASH", "lm_hash": LM_HASH},
{"credential_type": "NT_HASH", "nt_hash": NT_HASH},
{
"credential_type": "SSH_KEYPAIR",
"public_key": PUBLIC_KEY,
"private_key": PRIVATE_KEY,
},
None,
]
CREDENTIALS = [
Credentials(identity=identity, secret=secret)
for identity, secret in product(IDENTITIES, SECRETS)
if not (identity is None and secret is None)
]
CREDENTIALS_DICTS = [
{"identity": identity, "secret": secret}
for identity, secret in product(IDENTITY_DICTS, SECRET_DICTS)
if not (identity is None and secret is None)
]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"credentials, expected_credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS) "credentials, expected_credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
) )
def test_credentials_serialization_json(credentials, expected_credentials_dict): def test_credentials_serialization_json(credentials, expected_credentials_dict):
serialized_credentials = Credentials.to_json(credentials) serialized_credentials = credentials.json()
deserialized_credentials = Credentials.parse_raw(serialized_credentials)
assert json.loads(serialized_credentials) == expected_credentials_dict assert credentials == deserialized_credentials
@pytest.mark.parametrize(
"credentials, expected_credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
)
def test_credentials_serialization_mapping(credentials, expected_credentials_dict):
serialized_credentials = Credentials.to_mapping(credentials)
assert serialized_credentials == expected_credentials_dict
@pytest.mark.parametrize(
"expected_credentials, credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
)
def test_credentials_deserialization__from_mapping(expected_credentials, credentials_dict):
deserialized_credentials = Credentials.from_mapping(credentials_dict)
assert deserialized_credentials == expected_credentials
@pytest.mark.parametrize(
"expected_credentials, credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
)
def test_credentials_deserialization__from_json(expected_credentials, credentials_dict):
deserialized_credentials = Credentials.from_json(json.dumps(credentials_dict))
assert deserialized_credentials == expected_credentials
def test_credentials_deserialization__invalid_credentials():
invalid_data = {"secret": SECRET_DICTS[0], "unknown_key": []}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)
def test_credentials_deserialization__invalid_component_type():
invalid_data = {
"secret": SECRET_DICTS[0],
"identity": {"credential_type": "FAKE", "username": "user1"},
}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)
def test_credentials_deserialization__invalid_component():
invalid_data = {
"secret": SECRET_DICTS[0],
"identity": {"credential_type": "USERNAME", "unknown_field": "user1"},
}
with pytest.raises(InvalidCredentialComponentError):
Credentials.from_mapping(invalid_data)
def test_create_credentials__none_none():
with pytest.raises(InvalidCredentialsError):
Credentials(None, None)

View File

@ -0,0 +1,14 @@
import pytest
from common.credentials import LMHash
def test_construct_valid_nt_hash(valid_ntlm_hash):
# This test will fail if an exception is raised
LMHash(lm_hash=valid_ntlm_hash)
def test_construct_invalid_nt_hash(invalid_ntlm_hashes):
for invalid_hash in invalid_ntlm_hashes:
with pytest.raises(ValueError):
LMHash(lm_hash=invalid_hash)

View File

@ -7,8 +7,7 @@ from common.agent_configuration.agent_sub_configurations import (
CustomPBAConfiguration, CustomPBAConfiguration,
ScanTargetConfiguration, ScanTargetConfiguration,
) )
from common.credentials import Credentials from common.credentials import LMHash, NTHash
from common.utils import IJSONSerializable
from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory
from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue
from monkey_island.cc.models import Report from monkey_island.cc.models import Report
@ -166,6 +165,8 @@ LDAPServerFactory.buildProtocol
get_file_sha256_hash get_file_sha256_hash
strict_slashes # unused attribute (monkey/monkey_island/cc/app.py:96) strict_slashes # unused attribute (monkey/monkey_island/cc/app.py:96)
post_breach_actions # unused variable (monkey\infection_monkey\config.py:95) post_breach_actions # unused variable (monkey\infection_monkey\config.py:95)
LMHash.validate_hash_format
NTHash.validate_hash_format
# Deployments # Deployments
DEVELOP # unused variable (monkey/monkey/monkey_island/cc/deployment.py:5) DEVELOP # unused variable (monkey/monkey/monkey_island/cc/deployment.py:5)
@ -314,9 +315,6 @@ EXPLOITED
CC CC
CC_TUNNEL CC_TUNNEL
Credentials.from_json
IJSONSerializable.from_json
IslandEventTopic.AGENT_CONNECTED IslandEventTopic.AGENT_CONNECTED
IslandEventTopic.CLEAR_SIMULATION_DATA IslandEventTopic.CLEAR_SIMULATION_DATA
IslandEventTopic.RESET_AGENT_CONFIGURATION IslandEventTopic.RESET_AGENT_CONFIGURATION