Common: Refactor credentials from marshmallow to pydantic
This commit is contained in:
parent
3ac60988a8
commit
f868f03ea7
|
@ -1,8 +1,3 @@
|
|||
from .credential_component_type import CredentialComponentType
|
||||
from .i_credential_component import ICredentialComponent
|
||||
|
||||
from .validators import InvalidCredentialComponentError, InvalidCredentialsError
|
||||
|
||||
from .lm_hash import LMHash
|
||||
from .nt_hash import NTHash
|
||||
from .password import Password
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
from marshmallow import Schema, post_load, validate
|
||||
from marshmallow_enum import EnumField
|
||||
|
||||
from common.utils.code_utils import del_key
|
||||
|
||||
from . import CredentialComponentType
|
||||
|
||||
|
||||
class CredentialTypeField(EnumField):
|
||||
def __init__(self, credential_component_type: CredentialComponentType):
|
||||
super().__init__(
|
||||
CredentialComponentType, validate=validate.Equal(credential_component_type)
|
||||
)
|
||||
|
||||
|
||||
class CredentialComponentSchema(Schema):
|
||||
@post_load
|
||||
def _strip_credential_type(self, data, **kwargs):
|
||||
del_key(data, "credential_type")
|
||||
return data
|
|
@ -1,9 +0,0 @@
|
|||
from enum import Enum, auto
|
||||
|
||||
|
||||
class CredentialComponentType(Enum):
|
||||
USERNAME = auto()
|
||||
PASSWORD = auto()
|
||||
NT_HASH = auto()
|
||||
LM_HASH = auto()
|
||||
SSH_KEYPAIR = auto()
|
|
@ -1,190 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping, Optional, Type
|
||||
from typing import Optional, Union
|
||||
|
||||
from marshmallow import Schema, fields, post_load, pre_dump
|
||||
from marshmallow.exceptions import MarshmallowError
|
||||
from ..base_models import InfectionMonkeyBaseModel
|
||||
from . import LMHash, NTHash, Password, SSHKeypair, Username
|
||||
|
||||
from ..utils import IJSONSerializable
|
||||
from . import (
|
||||
CredentialComponentType,
|
||||
InvalidCredentialComponentError,
|
||||
InvalidCredentialsError,
|
||||
LMHash,
|
||||
NTHash,
|
||||
Password,
|
||||
SSHKeypair,
|
||||
Username,
|
||||
)
|
||||
from .i_credential_component import ICredentialComponent
|
||||
from .lm_hash import LMHashSchema
|
||||
from .nt_hash import NTHashSchema
|
||||
from .password import PasswordSchema
|
||||
from .ssh_keypair import SSHKeypairSchema
|
||||
from .username import UsernameSchema
|
||||
|
||||
CREDENTIAL_COMPONENT_TYPE_TO_CLASS: Mapping[CredentialComponentType, Type[ICredentialComponent]] = {
|
||||
CredentialComponentType.LM_HASH: LMHash,
|
||||
CredentialComponentType.NT_HASH: NTHash,
|
||||
CredentialComponentType.PASSWORD: Password,
|
||||
CredentialComponentType.SSH_KEYPAIR: SSHKeypair,
|
||||
CredentialComponentType.USERNAME: Username,
|
||||
}
|
||||
|
||||
CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA: Mapping[CredentialComponentType, Schema] = {
|
||||
CredentialComponentType.LM_HASH: LMHashSchema(),
|
||||
CredentialComponentType.NT_HASH: NTHashSchema(),
|
||||
CredentialComponentType.PASSWORD: PasswordSchema(),
|
||||
CredentialComponentType.SSH_KEYPAIR: SSHKeypairSchema(),
|
||||
CredentialComponentType.USERNAME: UsernameSchema(),
|
||||
}
|
||||
|
||||
CredentialComponentMapping = Optional[Mapping[str, Any]]
|
||||
CredentialsMapping = Mapping[str, CredentialComponentMapping]
|
||||
Secret = Union[Password, LMHash, NTHash, SSHKeypair]
|
||||
Identity = Username
|
||||
|
||||
|
||||
class CredentialsSchema(Schema):
|
||||
identity = fields.Mapping(allow_none=True)
|
||||
secret = fields.Mapping(allow_none=True)
|
||||
|
||||
@post_load
|
||||
def _make_credentials(
|
||||
self,
|
||||
credentials: CredentialsMapping,
|
||||
**kwargs: Mapping[str, Any],
|
||||
) -> Mapping[str, Optional[ICredentialComponent]]:
|
||||
if not any(credentials.values()):
|
||||
raise InvalidCredentialsError("At least one credentials component must be defined")
|
||||
|
||||
return {
|
||||
key: CredentialsSchema._build_credential_component(credential_component_mapping)
|
||||
for key, credential_component_mapping in credentials.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_credential_component(
|
||||
credential_component: CredentialComponentMapping,
|
||||
) -> Optional[ICredentialComponent]:
|
||||
if credential_component is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
credential_component_type = CredentialComponentType[
|
||||
credential_component["credential_type"]
|
||||
]
|
||||
except KeyError as err:
|
||||
raise InvalidCredentialsError(f"Unknown credential component type {err}")
|
||||
|
||||
credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type]
|
||||
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
|
||||
credential_component_type
|
||||
]
|
||||
|
||||
try:
|
||||
return credential_component_class(
|
||||
**credential_component_schema.load(credential_component)
|
||||
)
|
||||
except MarshmallowError as err:
|
||||
raise InvalidCredentialComponentError(credential_component_class, str(err))
|
||||
|
||||
@pre_dump
|
||||
def _serialize_credentials(self, credentials: Credentials, **kwargs) -> CredentialsMapping:
|
||||
return {
|
||||
"identity": CredentialsSchema._serialize_credential_component(credentials.identity),
|
||||
"secret": CredentialsSchema._serialize_credential_component(credentials.secret),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_credential_component(
|
||||
credential_component: Optional[ICredentialComponent],
|
||||
) -> CredentialComponentMapping:
|
||||
if credential_component is None:
|
||||
return None
|
||||
|
||||
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
|
||||
credential_component.credential_type
|
||||
]
|
||||
|
||||
return credential_component_schema.dump(credential_component)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Credentials(IJSONSerializable):
|
||||
identity: Optional[ICredentialComponent]
|
||||
secret: Optional[ICredentialComponent]
|
||||
|
||||
def __post_init__(self):
|
||||
schema = CredentialsSchema()
|
||||
try:
|
||||
serialized_data = schema.dump(self)
|
||||
|
||||
# This will raise an exception if the object is invalid. Calling this in __post__init()
|
||||
# makes it impossible to construct an invalid object
|
||||
schema.load(serialized_data)
|
||||
except Exception as err:
|
||||
raise InvalidCredentialsError(err)
|
||||
|
||||
@staticmethod
|
||||
def from_mapping(credentials: CredentialsMapping) -> Credentials:
|
||||
"""
|
||||
Construct a Credentials object from a Mapping
|
||||
|
||||
:param credentials: A mapping that represents a Credentials object
|
||||
:return: A Credentials object
|
||||
:raises InvalidCredentialsError: If the provided Mapping does not represent a valid
|
||||
Credentials object
|
||||
:raises InvalidCredentialComponentError: If any of the contents of `identities` or `secrets`
|
||||
are not a valid ICredentialComponent
|
||||
"""
|
||||
|
||||
try:
|
||||
deserialized_data = CredentialsSchema().load(credentials)
|
||||
return Credentials(**deserialized_data)
|
||||
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
|
||||
raise err
|
||||
except MarshmallowError as err:
|
||||
raise InvalidCredentialsError(str(err))
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, credentials: str) -> Credentials:
|
||||
"""
|
||||
Construct a Credentials object from a JSON string
|
||||
|
||||
:param credentials: A JSON string that represents a Credentials object
|
||||
:return: A Credentials object
|
||||
:raises InvalidCredentialsError: If the provided JSON does not represent a valid
|
||||
Credentials object
|
||||
:raises InvalidCredentialComponentError: If any of the contents of `identities` or `secrets`
|
||||
are not a valid ICredentialComponent
|
||||
"""
|
||||
|
||||
try:
|
||||
deserialized_data = CredentialsSchema().loads(credentials)
|
||||
return Credentials(**deserialized_data)
|
||||
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
|
||||
raise err
|
||||
except MarshmallowError as err:
|
||||
raise InvalidCredentialsError(str(err))
|
||||
|
||||
@staticmethod
|
||||
def to_mapping(credentials: Credentials) -> CredentialsMapping:
|
||||
"""
|
||||
Serialize a Credentials object to a Mapping
|
||||
|
||||
:param credentials: A Credentials object
|
||||
:return: A mapping representing a Credentials object
|
||||
"""
|
||||
|
||||
return CredentialsSchema().dump(credentials)
|
||||
|
||||
@classmethod
|
||||
def to_json(cls, credentials: Credentials) -> str:
|
||||
"""
|
||||
Serialize a Credentials object to JSON
|
||||
|
||||
:param credentials: A Credentials object
|
||||
:return: A JSON string representing a Credentials object
|
||||
"""
|
||||
|
||||
return CredentialsSchema().dumps(credentials)
|
||||
class Credentials(InfectionMonkeyBaseModel):
|
||||
identity: Optional[Identity]
|
||||
secret: Optional[Secret]
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from . import CredentialComponentType
|
||||
|
||||
|
||||
@dataclass
|
||||
class ICredentialComponent(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def credential_type(self) -> CredentialComponentType:
|
||||
pass
|
|
@ -1,25 +1,16 @@
|
|||
from dataclasses import field
|
||||
from typing import ClassVar
|
||||
import re
|
||||
|
||||
from marshmallow import fields
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic import validator
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from . import CredentialComponentType, ICredentialComponent
|
||||
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
|
||||
from .validators import credential_component_validator, ntlm_hash_validator
|
||||
from .validators import ntlm_hash_regex
|
||||
|
||||
|
||||
class LMHashSchema(CredentialComponentSchema):
|
||||
credential_type = CredentialTypeField(CredentialComponentType.LM_HASH)
|
||||
lm_hash = fields.Str(validate=ntlm_hash_validator)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LMHash(ICredentialComponent):
|
||||
credential_type: ClassVar[CredentialComponentType] = field(
|
||||
default=CredentialComponentType.LM_HASH, init=False
|
||||
)
|
||||
class LMHash(BaseModel):
|
||||
lm_hash: str
|
||||
|
||||
def __post_init__(self):
|
||||
credential_component_validator(LMHashSchema(), self)
|
||||
@validator("lm_hash")
|
||||
def validate_hash_format(cls, nt_hash):
|
||||
if not re.match(ntlm_hash_regex, nt_hash):
|
||||
raise ValueError(f"Invalid lm hash provided: {nt_hash}")
|
||||
return nt_hash
|
||||
|
|
|
@ -1,25 +1,16 @@
|
|||
from dataclasses import field
|
||||
from typing import ClassVar
|
||||
import re
|
||||
|
||||
from marshmallow import fields
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic import validator
|
||||
|
||||
from . import CredentialComponentType, ICredentialComponent
|
||||
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
|
||||
from .validators import credential_component_validator, ntlm_hash_validator
|
||||
from ..base_models import InfectionMonkeyBaseModel
|
||||
from .validators import ntlm_hash_regex
|
||||
|
||||
|
||||
class NTHashSchema(CredentialComponentSchema):
|
||||
credential_type = CredentialTypeField(CredentialComponentType.NT_HASH)
|
||||
nt_hash = fields.Str(validate=ntlm_hash_validator)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NTHash(ICredentialComponent):
|
||||
credential_type: ClassVar[CredentialComponentType] = field(
|
||||
default=CredentialComponentType.NT_HASH, init=False
|
||||
)
|
||||
class NTHash(InfectionMonkeyBaseModel):
|
||||
nt_hash: str
|
||||
|
||||
def __post_init__(self):
|
||||
credential_component_validator(NTHashSchema(), self)
|
||||
@validator("nt_hash")
|
||||
def validate_hash_format(cls, nt_hash):
|
||||
if not re.match(ntlm_hash_regex, nt_hash):
|
||||
raise ValueError(f"Invalid nt hash provided: {nt_hash}")
|
||||
return nt_hash
|
||||
|
|
|
@ -1,21 +1,5 @@
|
|||
from dataclasses import field
|
||||
from typing import ClassVar
|
||||
|
||||
from marshmallow import fields
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from . import CredentialComponentType, ICredentialComponent
|
||||
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
|
||||
from ..base_models import InfectionMonkeyBaseModel
|
||||
|
||||
|
||||
class PasswordSchema(CredentialComponentSchema):
|
||||
credential_type = CredentialTypeField(CredentialComponentType.PASSWORD)
|
||||
password = fields.Str()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Password(ICredentialComponent):
|
||||
credential_type: ClassVar[CredentialComponentType] = field(
|
||||
default=CredentialComponentType.PASSWORD, init=False
|
||||
)
|
||||
class Password(InfectionMonkeyBaseModel):
|
||||
password: str
|
||||
|
|
|
@ -1,27 +1,6 @@
|
|||
from dataclasses import field
|
||||
from typing import ClassVar
|
||||
|
||||
from marshmallow import fields
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from . import CredentialComponentType, ICredentialComponent
|
||||
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
|
||||
from ..base_models import InfectionMonkeyBaseModel
|
||||
|
||||
|
||||
class SSHKeypairSchema(CredentialComponentSchema):
|
||||
credential_type: ClassVar[CredentialComponentType] = CredentialTypeField(
|
||||
CredentialComponentType.SSH_KEYPAIR
|
||||
)
|
||||
# TODO: Find a list of valid formats for ssh keys and add validators.
|
||||
# See https://github.com/nemchik/ssh-key-regex
|
||||
private_key = fields.Str()
|
||||
public_key = fields.Str()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSHKeypair(ICredentialComponent):
|
||||
credential_type: ClassVar[CredentialComponentType] = field(
|
||||
default=CredentialComponentType.SSH_KEYPAIR, init=False
|
||||
)
|
||||
class SSHKeypair(InfectionMonkeyBaseModel):
|
||||
private_key: str
|
||||
public_key: str
|
||||
|
|
|
@ -1,21 +1,5 @@
|
|||
from dataclasses import field
|
||||
from typing import ClassVar
|
||||
|
||||
from marshmallow import fields
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from . import CredentialComponentType, ICredentialComponent
|
||||
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
|
||||
from ..base_models import InfectionMonkeyBaseModel
|
||||
|
||||
|
||||
class UsernameSchema(CredentialComponentSchema):
|
||||
credential_type = CredentialTypeField(CredentialComponentType.USERNAME)
|
||||
username = fields.Str()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Username(ICredentialComponent):
|
||||
credential_type: ClassVar[CredentialComponentType] = field(
|
||||
default=CredentialComponentType.USERNAME, init=False
|
||||
)
|
||||
class Username(InfectionMonkeyBaseModel):
|
||||
username: str
|
||||
|
|
|
@ -1,50 +1,3 @@
|
|||
import re
|
||||
from typing import Type
|
||||
|
||||
from marshmallow import Schema, validate
|
||||
|
||||
from . import ICredentialComponent
|
||||
|
||||
_ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$")
|
||||
ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex)
|
||||
|
||||
|
||||
class InvalidCredentialComponentError(Exception):
|
||||
def __init__(self, credential_component_class: Type[ICredentialComponent], message: str):
|
||||
self._credential_component_name = credential_component_class.__name__
|
||||
self._message = message
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"Cannot construct a {self._credential_component_name} object with the supplied, "
|
||||
f"invalid data: {self._message}"
|
||||
)
|
||||
|
||||
|
||||
class InvalidCredentialsError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self._message = message
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"Cannot construct a Credentials object with the supplied, "
|
||||
f"invalid data: {self._message}"
|
||||
)
|
||||
|
||||
|
||||
def credential_component_validator(schema: Schema, credential_component: ICredentialComponent):
|
||||
"""
|
||||
Validate a credential component
|
||||
|
||||
:param schema: A marshmallow schema used for validating the component
|
||||
:param credential_component: A credential component to be validated
|
||||
:raises InvalidCredentialComponent: if the credential_component contains invalid data
|
||||
"""
|
||||
try:
|
||||
serialized_data = schema.dump(credential_component)
|
||||
|
||||
# This will raise an exception if the object is invalid. Calling this in __post__init()
|
||||
# makes it impossible to construct an invalid object
|
||||
schema.load(serialized_data)
|
||||
except Exception as err:
|
||||
raise InvalidCredentialComponentError(credential_component.__class__, err)
|
||||
ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$")
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from itertools import product
|
||||
|
||||
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
|
||||
|
||||
USERNAME = "m0nk3y_user"
|
||||
|
@ -10,22 +12,47 @@ PASSWORD_3 = "rubberbabybuggybumpers"
|
|||
PUBLIC_KEY = "MY_PUBLIC_KEY"
|
||||
PRIVATE_KEY = "MY_PRIVATE_KEY"
|
||||
|
||||
PASSWORD_CREDENTIALS_1 = Credentials(identity=Username(USERNAME), secret=Password(PASSWORD_1))
|
||||
PASSWORD_CREDENTIALS_2 = Credentials(identity=Username(USERNAME), secret=Password(PASSWORD_2))
|
||||
LM_HASH_CREDENTIALS = Credentials(identity=Username(SPECIAL_USERNAME), secret=LMHash(LM_HASH))
|
||||
NT_HASH_CREDENTIALS = Credentials(identity=Username(USERNAME), secret=NTHash(NT_HASH))
|
||||
SSH_KEY_CREDENTIALS = Credentials(
|
||||
identity=Username(USERNAME), secret=SSHKeypair(PRIVATE_KEY, PUBLIC_KEY)
|
||||
)
|
||||
EMPTY_SECRET_CREDENTIALS = Credentials(identity=Username(USERNAME), secret=None)
|
||||
EMPTY_IDENTITY_CREDENTIALS = Credentials(identity=None, secret=Password(PASSWORD_3))
|
||||
IDENTITIES = [Username(username=USERNAME), None, Username(username=SPECIAL_USERNAME)]
|
||||
IDENTITY_DICTS = [{"username": USERNAME}, None]
|
||||
|
||||
PROPAGATION_CREDENTIALS = [
|
||||
PASSWORD_CREDENTIALS_1,
|
||||
LM_HASH_CREDENTIALS,
|
||||
NT_HASH_CREDENTIALS,
|
||||
PASSWORD_CREDENTIALS_2,
|
||||
SSH_KEY_CREDENTIALS,
|
||||
EMPTY_SECRET_CREDENTIALS,
|
||||
EMPTY_IDENTITY_CREDENTIALS,
|
||||
SECRETS = (
|
||||
Password(password=PASSWORD_1),
|
||||
Password(password=PASSWORD_2),
|
||||
Password(password=PASSWORD_3),
|
||||
LMHash(lm_hash=LM_HASH),
|
||||
NTHash(nt_hash=NT_HASH),
|
||||
SSHKeypair(private_key=PRIVATE_KEY, public_key=PUBLIC_KEY),
|
||||
None,
|
||||
)
|
||||
SECRET_DICTS = [
|
||||
{"password": PASSWORD_1},
|
||||
{"lm_hash": LM_HASH},
|
||||
{"nt_hash": NT_HASH},
|
||||
{
|
||||
"public_key": PUBLIC_KEY,
|
||||
"private_key": PRIVATE_KEY,
|
||||
},
|
||||
None,
|
||||
]
|
||||
|
||||
CREDENTIALS = [
|
||||
Credentials(identity=identity, secret=secret)
|
||||
for identity, secret in product(IDENTITIES, SECRETS)
|
||||
]
|
||||
|
||||
FULL_CREDENTIALS = [
|
||||
credentials
|
||||
for credentials in CREDENTIALS
|
||||
if not (credentials.identity is None and credentials.secret is None)
|
||||
]
|
||||
|
||||
CREDENTIALS_DICTS = [
|
||||
{"identity": identity, "secret": secret}
|
||||
for identity, secret in product(IDENTITY_DICTS, SECRET_DICTS)
|
||||
]
|
||||
|
||||
FULL_CREDENTIALS_DICTS = [
|
||||
credentials
|
||||
for credentials in CREDENTIALS_DICTS
|
||||
if not (credentials["identity"] is None and credentials["secret"] is None)
|
||||
]
|
||||
|
|
|
@ -1,148 +0,0 @@
|
|||
from typing import Any, Mapping
|
||||
|
||||
import pytest
|
||||
from marshmallow.exceptions import ValidationError
|
||||
|
||||
from common.credentials import (
|
||||
CredentialComponentType,
|
||||
LMHash,
|
||||
NTHash,
|
||||
Password,
|
||||
SSHKeypair,
|
||||
Username,
|
||||
)
|
||||
from common.credentials.lm_hash import LMHashSchema
|
||||
from common.credentials.nt_hash import NTHashSchema
|
||||
from common.credentials.password import PasswordSchema
|
||||
from common.credentials.ssh_keypair import SSHKeypairSchema
|
||||
from common.credentials.username import UsernameSchema
|
||||
|
||||
PARAMETRIZED_PARAMETER_NAMES = (
|
||||
"credential_component_class, schema_class, credential_component_type, credential_component_data"
|
||||
)
|
||||
|
||||
PARAMETRIZED_PARAMETER_VALUES = [
|
||||
(Username, UsernameSchema, CredentialComponentType.USERNAME, {"username": "test_user"}),
|
||||
(Password, PasswordSchema, CredentialComponentType.PASSWORD, {"password": "123456"}),
|
||||
(
|
||||
LMHash,
|
||||
LMHashSchema,
|
||||
CredentialComponentType.LM_HASH,
|
||||
{"lm_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"},
|
||||
),
|
||||
(
|
||||
NTHash,
|
||||
NTHashSchema,
|
||||
CredentialComponentType.NT_HASH,
|
||||
{"nt_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"},
|
||||
),
|
||||
(
|
||||
SSHKeypair,
|
||||
SSHKeypairSchema,
|
||||
CredentialComponentType.SSH_KEYPAIR,
|
||||
{"public_key": "TEST_PUBLIC_KEY", "private_key": "TEST_PRIVATE_KEY"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
INVALID_COMPONENT_DATA = {
|
||||
CredentialComponentType.USERNAME: ({"username": None}, {"username": 1}, {"username": 2.0}),
|
||||
CredentialComponentType.PASSWORD: ({"password": None}, {"password": 1}, {"password": 2.0}),
|
||||
CredentialComponentType.LM_HASH: (
|
||||
{"lm_hash": None},
|
||||
{"lm_hash": 1},
|
||||
{"lm_hash": 2.0},
|
||||
{"lm_hash": "0123456789012345678901234568901"},
|
||||
{"lm_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"},
|
||||
),
|
||||
CredentialComponentType.NT_HASH: (
|
||||
{"nt_hash": None},
|
||||
{"nt_hash": 1},
|
||||
{"nt_hash": 2.0},
|
||||
{"nt_hash": "0123456789012345678901234568901"},
|
||||
{"nt_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"},
|
||||
),
|
||||
CredentialComponentType.SSH_KEYPAIR: (
|
||||
{"public_key": None, "private_key": "TEST_PRIVATE_KEY"},
|
||||
{"public_key": "TEST_PUBLIC_KEY", "private_key": None},
|
||||
{"public_key": 1, "private_key": "TEST_PRIVATE_KEY"},
|
||||
{"public_key": "TEST_PUBLIC_KEY", "private_key": 999},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def build_component_dict(
|
||||
credential_component_type: CredentialComponentType, credential_component_data: Mapping[str, Any]
|
||||
):
|
||||
return {"credential_type": credential_component_type.name, **credential_component_data}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
|
||||
def test_credential_component_serialize(
|
||||
credential_component_class, schema_class, credential_component_type, credential_component_data
|
||||
):
|
||||
schema = schema_class()
|
||||
constructed_object = credential_component_class(**credential_component_data)
|
||||
|
||||
serialized_object = schema.dump(constructed_object)
|
||||
|
||||
assert serialized_object == build_component_dict(
|
||||
credential_component_type, credential_component_data
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
|
||||
def test_credential_component_deserialize(
|
||||
credential_component_class, schema_class, credential_component_type, credential_component_data
|
||||
):
|
||||
schema = schema_class()
|
||||
credential_dict = build_component_dict(credential_component_type, credential_component_data)
|
||||
expected_deserialized_object = credential_component_class(**credential_component_data)
|
||||
|
||||
deserialized_object = credential_component_class(**schema.load(credential_dict))
|
||||
|
||||
assert deserialized_object == expected_deserialized_object
|
||||
|
||||
|
||||
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
|
||||
def test_invalid_credential_type(
|
||||
credential_component_class, schema_class, credential_component_type, credential_component_data
|
||||
):
|
||||
invalid_component_dict = build_component_dict(
|
||||
credential_component_type, credential_component_data
|
||||
)
|
||||
invalid_component_dict["credential_type"] = "INVALID"
|
||||
schema = schema_class()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
credential_component_class(**schema.load(invalid_component_dict))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
|
||||
def test_encorrect_credential_type(
|
||||
credential_component_class, schema_class, credential_component_type, credential_component_data
|
||||
):
|
||||
incorrect_component_dict = build_component_dict(
|
||||
credential_component_type, credential_component_data
|
||||
)
|
||||
incorrect_component_dict["credential_type"] = (
|
||||
CredentialComponentType.USERNAME.name
|
||||
if credential_component_type != CredentialComponentType.USERNAME
|
||||
else CredentialComponentType.PASSWORD
|
||||
)
|
||||
schema = schema_class()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
credential_component_class(**schema.load(incorrect_component_dict))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
|
||||
def test_invalid_values(
|
||||
credential_component_class, schema_class, credential_component_type, credential_component_data
|
||||
):
|
||||
schema = schema_class()
|
||||
|
||||
for invalid_component_data in INVALID_COMPONENT_DATA[credential_component_type]:
|
||||
component_dict = build_component_dict(credential_component_type, invalid_component_data)
|
||||
with pytest.raises(ValidationError):
|
||||
credential_component_class(**schema.load(component_dict))
|
|
@ -1,122 +1,14 @@
|
|||
import json
|
||||
from itertools import product
|
||||
|
||||
import pytest
|
||||
from tests.data_for_tests.propagation_credentials import (
|
||||
LM_HASH,
|
||||
NT_HASH,
|
||||
PASSWORD_1,
|
||||
PRIVATE_KEY,
|
||||
PUBLIC_KEY,
|
||||
USERNAME,
|
||||
)
|
||||
from tests.data_for_tests.propagation_credentials import CREDENTIALS, CREDENTIALS_DICTS
|
||||
|
||||
from common.credentials import (
|
||||
Credentials,
|
||||
InvalidCredentialComponentError,
|
||||
InvalidCredentialsError,
|
||||
LMHash,
|
||||
NTHash,
|
||||
Password,
|
||||
SSHKeypair,
|
||||
Username,
|
||||
)
|
||||
|
||||
IDENTITIES = [Username(USERNAME), None]
|
||||
IDENTITY_DICTS = [{"credential_type": "USERNAME", "username": USERNAME}, None]
|
||||
|
||||
SECRETS = (
|
||||
Password(PASSWORD_1),
|
||||
LMHash(LM_HASH),
|
||||
NTHash(NT_HASH),
|
||||
SSHKeypair(PRIVATE_KEY, PUBLIC_KEY),
|
||||
None,
|
||||
)
|
||||
SECRET_DICTS = [
|
||||
{"credential_type": "PASSWORD", "password": PASSWORD_1},
|
||||
{"credential_type": "LM_HASH", "lm_hash": LM_HASH},
|
||||
{"credential_type": "NT_HASH", "nt_hash": NT_HASH},
|
||||
{
|
||||
"credential_type": "SSH_KEYPAIR",
|
||||
"public_key": PUBLIC_KEY,
|
||||
"private_key": PRIVATE_KEY,
|
||||
},
|
||||
None,
|
||||
]
|
||||
|
||||
CREDENTIALS = [
|
||||
Credentials(identity=identity, secret=secret)
|
||||
for identity, secret in product(IDENTITIES, SECRETS)
|
||||
if not (identity is None and secret is None)
|
||||
]
|
||||
|
||||
CREDENTIALS_DICTS = [
|
||||
{"identity": identity, "secret": secret}
|
||||
for identity, secret in product(IDENTITY_DICTS, SECRET_DICTS)
|
||||
if not (identity is None and secret is None)
|
||||
]
|
||||
from common.credentials import Credentials
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"credentials, expected_credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
|
||||
)
|
||||
def test_credentials_serialization_json(credentials, expected_credentials_dict):
|
||||
serialized_credentials = Credentials.to_json(credentials)
|
||||
serialized_credentials = credentials.json()
|
||||
deserialized_credentials = Credentials.parse_raw(serialized_credentials)
|
||||
|
||||
assert json.loads(serialized_credentials) == expected_credentials_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"credentials, expected_credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
|
||||
)
|
||||
def test_credentials_serialization_mapping(credentials, expected_credentials_dict):
|
||||
serialized_credentials = Credentials.to_mapping(credentials)
|
||||
|
||||
assert serialized_credentials == expected_credentials_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected_credentials, credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
|
||||
)
|
||||
def test_credentials_deserialization__from_mapping(expected_credentials, credentials_dict):
|
||||
deserialized_credentials = Credentials.from_mapping(credentials_dict)
|
||||
|
||||
assert deserialized_credentials == expected_credentials
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected_credentials, credentials_dict", zip(CREDENTIALS, CREDENTIALS_DICTS)
|
||||
)
|
||||
def test_credentials_deserialization__from_json(expected_credentials, credentials_dict):
|
||||
deserialized_credentials = Credentials.from_json(json.dumps(credentials_dict))
|
||||
|
||||
assert deserialized_credentials == expected_credentials
|
||||
|
||||
|
||||
def test_credentials_deserialization__invalid_credentials():
|
||||
invalid_data = {"secret": SECRET_DICTS[0], "unknown_key": []}
|
||||
with pytest.raises(InvalidCredentialsError):
|
||||
Credentials.from_mapping(invalid_data)
|
||||
|
||||
|
||||
def test_credentials_deserialization__invalid_component_type():
|
||||
invalid_data = {
|
||||
"secret": SECRET_DICTS[0],
|
||||
"identity": {"credential_type": "FAKE", "username": "user1"},
|
||||
}
|
||||
with pytest.raises(InvalidCredentialsError):
|
||||
Credentials.from_mapping(invalid_data)
|
||||
|
||||
|
||||
def test_credentials_deserialization__invalid_component():
|
||||
invalid_data = {
|
||||
"secret": SECRET_DICTS[0],
|
||||
"identity": {"credential_type": "USERNAME", "unknown_field": "user1"},
|
||||
}
|
||||
with pytest.raises(InvalidCredentialComponentError):
|
||||
Credentials.from_mapping(invalid_data)
|
||||
|
||||
|
||||
def test_create_credentials__none_none():
|
||||
with pytest.raises(InvalidCredentialsError):
|
||||
Credentials(None, None)
|
||||
assert credentials == deserialized_credentials
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
import pytest
|
||||
|
||||
from common.credentials import LMHash
|
||||
|
||||
|
||||
def test_construct_valid_nt_hash(valid_ntlm_hash):
|
||||
# This test will fail if an exception is raised
|
||||
LMHash(lm_hash=valid_ntlm_hash)
|
||||
|
||||
|
||||
def test_construct_invalid_nt_hash(invalid_ntlm_hashes):
|
||||
for invalid_hash in invalid_ntlm_hashes:
|
||||
with pytest.raises(ValueError):
|
||||
LMHash(lm_hash=invalid_hash)
|
|
@ -7,8 +7,7 @@ from common.agent_configuration.agent_sub_configurations import (
|
|||
CustomPBAConfiguration,
|
||||
ScanTargetConfiguration,
|
||||
)
|
||||
from common.credentials import Credentials
|
||||
from common.utils import IJSONSerializable
|
||||
from common.credentials import LMHash, NTHash
|
||||
from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory
|
||||
from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue
|
||||
from monkey_island.cc.models import Report
|
||||
|
@ -166,6 +165,8 @@ LDAPServerFactory.buildProtocol
|
|||
get_file_sha256_hash
|
||||
strict_slashes # unused attribute (monkey/monkey_island/cc/app.py:96)
|
||||
post_breach_actions # unused variable (monkey\infection_monkey\config.py:95)
|
||||
LMHash.validate_hash_format
|
||||
NTHash.validate_hash_format
|
||||
|
||||
# Deployments
|
||||
DEVELOP # unused variable (monkey/monkey/monkey_island/cc/deployment.py:5)
|
||||
|
@ -314,9 +315,6 @@ EXPLOITED
|
|||
CC
|
||||
CC_TUNNEL
|
||||
|
||||
Credentials.from_json
|
||||
IJSONSerializable.from_json
|
||||
|
||||
IslandEventTopic.AGENT_CONNECTED
|
||||
IslandEventTopic.CLEAR_SIMULATION_DATA
|
||||
IslandEventTopic.RESET_AGENT_CONFIGURATION
|
||||
|
|
Loading…
Reference in New Issue