Merge pull request #2067 from guardicore/1965-credentials-serialization

1965 credentials serialization
This commit is contained in:
Mike Salvatore 2022-07-07 11:31:41 -04:00 committed by GitHub
commit 402a5f5860
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 606 additions and 64 deletions

View File

@ -19,12 +19,14 @@ from .agent_sub_configurations import (
class InvalidConfigurationError(Exception): class InvalidConfigurationError(Exception):
pass def __init__(self, message: str):
self._message = message
def __str__(self) -> str:
INVALID_CONFIGURATION_ERROR_MESSAGE = ( return (
"Cannot construct an AgentConfiguration object with the supplied, invalid data:" f"Cannot construct an AgentConfiguration object with the supplied, invalid data: "
) f"{self._message}"
)
@dataclass(frozen=True) @dataclass(frozen=True)
@ -42,7 +44,7 @@ class AgentConfiguration:
try: try:
AgentConfigurationSchema().dump(self) AgentConfigurationSchema().dump(self)
except Exception as err: except Exception as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") raise InvalidConfigurationError(str(err))
@staticmethod @staticmethod
def from_mapping(config_mapping: Mapping[str, Any]) -> AgentConfiguration: def from_mapping(config_mapping: Mapping[str, Any]) -> AgentConfiguration:
@ -59,7 +61,7 @@ class AgentConfiguration:
config_dict = AgentConfigurationSchema().load(config_mapping) config_dict = AgentConfigurationSchema().load(config_mapping)
return AgentConfiguration(**config_dict) return AgentConfiguration(**config_dict)
except MarshmallowError as err: except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") raise InvalidConfigurationError(str(err))
@staticmethod @staticmethod
def from_json(config_json: str) -> AgentConfiguration: def from_json(config_json: str) -> AgentConfiguration:
@ -75,7 +77,7 @@ class AgentConfiguration:
config_dict = AgentConfigurationSchema().loads(config_json) config_dict = AgentConfigurationSchema().loads(config_json)
return AgentConfiguration(**config_dict) return AgentConfiguration(**config_dict)
except MarshmallowError as err: except MarshmallowError as err:
raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") raise InvalidConfigurationError(str(err))
@staticmethod @staticmethod
def to_json(config: AgentConfiguration) -> str: def to_json(config: AgentConfiguration) -> str:

View File

@ -1,8 +1,12 @@
from .credential_component_type import CredentialComponentType from .credential_component_type import CredentialComponentType
from .i_credential_component import ICredentialComponent from .i_credential_component import ICredentialComponent
from .credentials import Credentials
from .validators import InvalidCredentialComponentError, InvalidCredentialsError
from .lm_hash import LMHash from .lm_hash import LMHash
from .nt_hash import NTHash from .nt_hash import NTHash
from .password import Password from .password import Password
from .ssh_keypair import SSHKeypair from .ssh_keypair import SSHKeypair
from .username import Username from .username import Username
from .credentials import Credentials

View File

@ -0,0 +1,20 @@
from marshmallow import Schema, post_load, validate
from marshmallow_enum import EnumField
from common.utils.code_utils import del_key
from . import CredentialComponentType
class CredentialTypeField(EnumField):
def __init__(self, credential_component_type: CredentialComponentType):
super().__init__(
CredentialComponentType, validate=validate.Equal(credential_component_type)
)
class CredentialComponentSchema(Schema):
@post_load
def _strip_credential_type(self, data, **kwargs):
del_key(data, "credential_type")
return data

View File

@ -1,10 +1,144 @@
from dataclasses import dataclass from __future__ import annotations
from typing import Tuple
from dataclasses import dataclass
from typing import Any, Mapping, MutableMapping, Sequence, Tuple
from marshmallow import Schema, fields, post_load, pre_dump
from marshmallow.exceptions import MarshmallowError
from . import (
CredentialComponentType,
InvalidCredentialComponentError,
InvalidCredentialsError,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
from .i_credential_component import ICredentialComponent from .i_credential_component import ICredentialComponent
from .lm_hash import LMHashSchema
from .nt_hash import NTHashSchema
from .password import PasswordSchema
from .ssh_keypair import SSHKeypairSchema
from .username import UsernameSchema
CREDENTIAL_COMPONENT_TYPE_TO_CLASS = {
CredentialComponentType.LM_HASH: LMHash,
CredentialComponentType.NT_HASH: NTHash,
CredentialComponentType.PASSWORD: Password,
CredentialComponentType.SSH_KEYPAIR: SSHKeypair,
CredentialComponentType.USERNAME: Username,
}
CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA = {
CredentialComponentType.LM_HASH: LMHashSchema(),
CredentialComponentType.NT_HASH: NTHashSchema(),
CredentialComponentType.PASSWORD: PasswordSchema(),
CredentialComponentType.SSH_KEYPAIR: SSHKeypairSchema(),
CredentialComponentType.USERNAME: UsernameSchema(),
}
class CredentialsSchema(Schema):
# Use fields.List instead of fields.Tuple because marshmallow requires fields.Tuple to have a
# fixed length.
identities = fields.List(fields.Mapping())
secrets = fields.List(fields.Mapping())
@post_load
def _make_credentials(
self, data: MutableMapping, **kwargs: Mapping[str, Any]
) -> Mapping[str, Sequence[Mapping[str, Any]]]:
data["identities"] = tuple(
[
CredentialsSchema._build_credential_component(component)
for component in data["identities"]
]
)
data["secrets"] = tuple(
[
CredentialsSchema._build_credential_component(component)
for component in data["secrets"]
]
)
return data
@staticmethod
def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent:
try:
credential_component_type = CredentialComponentType[data["credential_type"]]
except KeyError as err:
raise InvalidCredentialsError(f"Unknown credential component type {err}")
credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type]
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
credential_component_type
]
try:
return credential_component_class(**credential_component_schema.load(data))
except MarshmallowError as err:
raise InvalidCredentialComponentError(credential_component_class, str(err))
@pre_dump
def _serialize_credentials(
self, credentials: Credentials, **kwargs
) -> Mapping[str, Sequence[Mapping[str, Any]]]:
data = {}
data["identities"] = tuple(
[
CredentialsSchema._serialize_credential_component(component)
for component in credentials.identities
]
)
data["secrets"] = tuple(
[
CredentialsSchema._serialize_credential_component(component)
for component in credentials.secrets
]
)
return data
@staticmethod
def _serialize_credential_component(
credential_component: ICredentialComponent,
) -> Mapping[str, Any]:
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
credential_component.credential_type
]
return credential_component_schema.dump(credential_component)
@dataclass(frozen=True) @dataclass(frozen=True)
class Credentials: class Credentials:
identities: Tuple[ICredentialComponent] identities: Tuple[ICredentialComponent]
secrets: Tuple[ICredentialComponent] secrets: Tuple[ICredentialComponent]
@staticmethod
def from_mapping(credentials: Mapping) -> Credentials:
try:
deserialized_data = CredentialsSchema().load(credentials)
return Credentials(**deserialized_data)
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
raise err
except MarshmallowError as err:
raise InvalidCredentialsError(str(err))
@staticmethod
def from_json(credentials: str) -> Credentials:
try:
deserialized_data = CredentialsSchema().loads(credentials)
return Credentials(**deserialized_data)
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
raise err
except MarshmallowError as err:
raise InvalidCredentialsError(str(err))
@staticmethod
def to_json(credentials: Credentials) -> str:
return CredentialsSchema().dumps(credentials)

View File

@ -1,6 +1,15 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from marshmallow import fields
from . import CredentialComponentType, ICredentialComponent from . import CredentialComponentType, ICredentialComponent
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
from .validators import credential_component_validator, ntlm_hash_validator
class LMHashSchema(CredentialComponentSchema):
credential_type = CredentialTypeField(CredentialComponentType.LM_HASH)
lm_hash = fields.Str(validate=ntlm_hash_validator)
@dataclass(frozen=True) @dataclass(frozen=True)
@ -9,3 +18,6 @@ class LMHash(ICredentialComponent):
default=CredentialComponentType.LM_HASH, init=False default=CredentialComponentType.LM_HASH, init=False
) )
lm_hash: str lm_hash: str
def __post_init__(self):
credential_component_validator(LMHashSchema(), self)

View File

@ -1,6 +1,15 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from marshmallow import fields
from . import CredentialComponentType, ICredentialComponent from . import CredentialComponentType, ICredentialComponent
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
from .validators import credential_component_validator, ntlm_hash_validator
class NTHashSchema(CredentialComponentSchema):
credential_type = CredentialTypeField(CredentialComponentType.NT_HASH)
nt_hash = fields.Str(validate=ntlm_hash_validator)
@dataclass(frozen=True) @dataclass(frozen=True)
@ -9,3 +18,6 @@ class NTHash(ICredentialComponent):
default=CredentialComponentType.NT_HASH, init=False default=CredentialComponentType.NT_HASH, init=False
) )
nt_hash: str nt_hash: str
def __post_init__(self):
credential_component_validator(NTHashSchema(), self)

View File

@ -1,6 +1,14 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from marshmallow import fields
from . import CredentialComponentType, ICredentialComponent from . import CredentialComponentType, ICredentialComponent
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
class PasswordSchema(CredentialComponentSchema):
credential_type = CredentialTypeField(CredentialComponentType.PASSWORD)
password = fields.Str()
@dataclass(frozen=True) @dataclass(frozen=True)

View File

@ -1,6 +1,17 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from marshmallow import fields
from . import CredentialComponentType, ICredentialComponent from . import CredentialComponentType, ICredentialComponent
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
class SSHKeypairSchema(CredentialComponentSchema):
credential_type = CredentialTypeField(CredentialComponentType.SSH_KEYPAIR)
# TODO: Find a list of valid formats for ssh keys and add validators.
# See https://github.com/nemchik/ssh-key-regex
private_key = fields.Str()
public_key = fields.Str()
@dataclass(frozen=True) @dataclass(frozen=True)

View File

@ -1,6 +1,14 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from marshmallow import fields
from . import CredentialComponentType, ICredentialComponent from . import CredentialComponentType, ICredentialComponent
from .credential_component_schema import CredentialComponentSchema, CredentialTypeField
class UsernameSchema(CredentialComponentSchema):
credential_type = CredentialTypeField(CredentialComponentType.USERNAME)
username = fields.Str()
@dataclass(frozen=True) @dataclass(frozen=True)

View File

@ -0,0 +1,50 @@
import re
from typing import Type
from marshmallow import Schema, validate
from . import ICredentialComponent
_ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$")
ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex)
class InvalidCredentialComponentError(Exception):
def __init__(self, credential_component_class: Type[ICredentialComponent], message: str):
self._credential_component_name = credential_component_class.__name__
self._message = message
def __str__(self) -> str:
return (
f"Cannot construct a {self._credential_component_name} object with the supplied, "
f"invalid data: {self._message}"
)
class InvalidCredentialsError(Exception):
def __init__(self, message: str):
self._message = message
def __str__(self) -> str:
return (
f"Cannot construct a Credentials object with the supplied, "
f"invalid data: {self._message}"
)
def credential_component_validator(schema: Schema, credential_component: ICredentialComponent):
"""
Validate a credential component
:param schema: A marshmallow schema used for validating the component
:param credential_component: A credential component to be validated
:raises InvalidCredentialComponent: if the credential_component contains invalid data
"""
try:
serialized_data = schema.dump(credential_component)
# This will raise an exception if the object is invalid. Calling this in __post__init()
# makes it impossible to construct an invalid object
schema.load(serialized_data)
except Exception as err:
raise InvalidCredentialComponentError(credential_component.__class__, err)

View File

@ -1,5 +1,7 @@
import inspect import inspect
from typing import Any, MutableMapping, Sequence, Type, TypeVar from typing import Any, Sequence, Type, TypeVar
from common.utils.code_utils import del_key
T = TypeVar("T") T = TypeVar("T")
@ -43,7 +45,7 @@ class DIContainer:
) )
self._type_registry[interface] = concrete_type self._type_registry[interface] = concrete_type
DIContainer._del_key(self._instance_registry, interface) del_key(self._instance_registry, interface)
def register_instance(self, interface: Type[T], instance: T): def register_instance(self, interface: Type[T], instance: T):
""" """
@ -59,7 +61,7 @@ class DIContainer:
) )
self._instance_registry[interface] = instance self._instance_registry[interface] = instance
DIContainer._del_key(self._type_registry, interface) del_key(self._type_registry, interface)
def register_convention(self, type_: Type[T], name: str, instance: T): def register_convention(self, type_: Type[T], name: str, instance: T):
""" """
@ -168,8 +170,8 @@ class DIContainer:
:param interface: The interface to release :param interface: The interface to release
""" """
DIContainer._del_key(self._type_registry, interface) del_key(self._type_registry, interface)
DIContainer._del_key(self._instance_registry, interface) del_key(self._instance_registry, interface)
def release_convention(self, type_: Type[T], name: str): def release_convention(self, type_: Type[T], name: str):
""" """
@ -179,18 +181,4 @@ class DIContainer:
:param name: The name of the dependency parameter :param name: The name of the dependency parameter
""" """
convention_identifier = (type_, name) convention_identifier = (type_, name)
DIContainer._del_key(self._convention_registry, convention_identifier) del_key(self._convention_registry, convention_identifier)
@staticmethod
def _del_key(mapping: MutableMapping[T, Any], key: T):
"""
Deletes key from mapping. Unlike the `del` keyword, this function does not raise a KeyError
if the key does not exist.
:param mapping: A mapping from which a key will be deleted
:param key: A key to delete from `mapping`
"""
try:
del mapping[key]
except KeyError:
pass

View File

@ -1,5 +1,7 @@
import queue import queue
from typing import Any, List from typing import Any, List, MutableMapping, TypeVar
T = TypeVar("T")
class abstractstatic(staticmethod): class abstractstatic(staticmethod):
@ -30,3 +32,19 @@ def queue_to_list(q: queue.Queue) -> List[Any]:
pass pass
return list_ return list_
def del_key(mapping: MutableMapping[T, Any], key: T):
"""
Delete a key from a mapping.
Unlike the `del` keyword, this function does not raise a KeyError
if the key does not exist.
:param mapping: A mapping from which a key will be deleted
:param key: A key to delete from `mapping`
"""
try:
del mapping[key]
except KeyError:
pass

View File

@ -1,9 +1,8 @@
import enum
import json import json
from typing import Dict, Iterable from typing import Iterable
from common.common_consts.telem_categories import TelemCategoryEnum from common.common_consts.telem_categories import TelemCategoryEnum
from common.credentials import Credentials, ICredentialComponent from common.credentials import Credentials
from infection_monkey.telemetry.base_telem import BaseTelem from infection_monkey.telemetry.base_telem import BaseTelem
@ -24,24 +23,5 @@ class CredentialsTelem(BaseTelem):
def send(self, log_data=True): def send(self, log_data=True):
super().send(log_data=False) super().send(log_data=False)
def get_data(self) -> Dict: def get_data(self):
# TODO: At a later time we can consider factoring this into a Serializer class or similar. return [json.loads(Credentials.to_json(c)) for c in self._credentials]
return json.loads(json.dumps(self._credentials, default=_serialize))
def _serialize(obj):
if isinstance(obj, enum.Enum):
return obj.name
if isinstance(obj, ICredentialComponent):
# This is a workaround for ICredentialComponents that are implemented as dataclasses. If the
# credential_type attribute is populated with `field(init=False, ...)`, then credential_type
# is not added to the object's __dict__ attribute. The biggest risk of this workaround is
# that we might change the name of the credential_type field in ICredentialComponents, but
# automated refactoring tools would not detect that this string needs to change. This is
# mittigated by the call to getattr() below, which will raise an AttributeException if the
# attribute name changes and a unit test will fail under these conditions.
credential_type = getattr(obj, "credential_type")
return dict(obj.__dict__, **{"credential_type": credential_type})
return getattr(obj, "__dict__", str(obj))

View File

@ -0,0 +1,148 @@
from typing import Any, Mapping
import pytest
from marshmallow.exceptions import ValidationError
from common.credentials import (
CredentialComponentType,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
from common.credentials.lm_hash import LMHashSchema
from common.credentials.nt_hash import NTHashSchema
from common.credentials.password import PasswordSchema
from common.credentials.ssh_keypair import SSHKeypairSchema
from common.credentials.username import UsernameSchema
PARAMETRIZED_PARAMETER_NAMES = (
"credential_component_class, schema_class, credential_component_type, credential_component_data"
)
PARAMETRIZED_PARAMETER_VALUES = [
(Username, UsernameSchema, CredentialComponentType.USERNAME, {"username": "test_user"}),
(Password, PasswordSchema, CredentialComponentType.PASSWORD, {"password": "123456"}),
(
LMHash,
LMHashSchema,
CredentialComponentType.LM_HASH,
{"lm_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"},
),
(
NTHash,
NTHashSchema,
CredentialComponentType.NT_HASH,
{"nt_hash": "E52CAC67419A9A224A3B108F3FA6CB6D"},
),
(
SSHKeypair,
SSHKeypairSchema,
CredentialComponentType.SSH_KEYPAIR,
{"public_key": "TEST_PUBLIC_KEY", "private_key": "TEST_PRIVATE_KEY"},
),
]
INVALID_COMPONENT_DATA = {
CredentialComponentType.USERNAME: ({"username": None}, {"username": 1}, {"username": 2.0}),
CredentialComponentType.PASSWORD: ({"password": None}, {"password": 1}, {"password": 2.0}),
CredentialComponentType.LM_HASH: (
{"lm_hash": None},
{"lm_hash": 1},
{"lm_hash": 2.0},
{"lm_hash": "0123456789012345678901234568901"},
{"lm_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"},
),
CredentialComponentType.NT_HASH: (
{"nt_hash": None},
{"nt_hash": 1},
{"nt_hash": 2.0},
{"nt_hash": "0123456789012345678901234568901"},
{"nt_hash": "E52GAC67419A9A224A3B108F3FA6CB6D"},
),
CredentialComponentType.SSH_KEYPAIR: (
{"public_key": None, "private_key": "TEST_PRIVATE_KEY"},
{"public_key": "TEST_PUBLIC_KEY", "private_key": None},
{"public_key": 1, "private_key": "TEST_PRIVATE_KEY"},
{"public_key": "TEST_PUBLIC_KEY", "private_key": 999},
),
}
def build_component_dict(
credential_component_type: CredentialComponentType, credential_component_data: Mapping[str, Any]
):
return {"credential_type": credential_component_type.name, **credential_component_data}
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_credential_component_serialize(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
schema = schema_class()
constructed_object = credential_component_class(**credential_component_data)
serialized_object = schema.dump(constructed_object)
assert serialized_object == build_component_dict(
credential_component_type, credential_component_data
)
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_credential_component_deserialize(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
schema = schema_class()
credential_dict = build_component_dict(credential_component_type, credential_component_data)
expected_deserialized_object = credential_component_class(**credential_component_data)
deserialized_object = credential_component_class(**schema.load(credential_dict))
assert deserialized_object == expected_deserialized_object
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_invalid_credential_type(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
invalid_component_dict = build_component_dict(
credential_component_type, credential_component_data
)
invalid_component_dict["credential_type"] = "INVALID"
schema = schema_class()
with pytest.raises(ValidationError):
credential_component_class(**schema.load(invalid_component_dict))
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_encorrect_credential_type(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
incorrect_component_dict = build_component_dict(
credential_component_type, credential_component_data
)
incorrect_component_dict["credential_type"] = (
CredentialComponentType.USERNAME.name
if credential_component_type != CredentialComponentType.USERNAME
else CredentialComponentType.PASSWORD
)
schema = schema_class()
with pytest.raises(ValidationError):
credential_component_class(**schema.load(incorrect_component_dict))
@pytest.mark.parametrize(PARAMETRIZED_PARAMETER_NAMES, PARAMETRIZED_PARAMETER_VALUES)
def test_invalid_values(
credential_component_class, schema_class, credential_component_type, credential_component_data
):
schema = schema_class()
for invalid_component_data in INVALID_COMPONENT_DATA[credential_component_type]:
component_dict = build_component_dict(credential_component_type, invalid_component_data)
with pytest.raises(ValidationError):
credential_component_class(**schema.load(component_dict))

View File

@ -0,0 +1,89 @@
import json
import pytest
from common.credentials import (
Credentials,
InvalidCredentialComponentError,
InvalidCredentialsError,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
USER1 = "test_user_1"
USER2 = "test_user_2"
PASSWORD = "12435"
LM_HASH = "AEBD4DE384C7EC43AAD3B435B51404EE"
NT_HASH = "7A21990FCD3D759941E45C490F143D5F"
PUBLIC_KEY = "MY_PUBLIC_KEY"
PRIVATE_KEY = "MY_PRIVATE_KEY"
CREDENTIALS_DICT = {
"identities": [
{"credential_type": "USERNAME", "username": USER1},
{"credential_type": "USERNAME", "username": USER2},
],
"secrets": [
{"credential_type": "PASSWORD", "password": PASSWORD},
{"credential_type": "LM_HASH", "lm_hash": LM_HASH},
{"credential_type": "NT_HASH", "nt_hash": NT_HASH},
{
"credential_type": "SSH_KEYPAIR",
"public_key": PUBLIC_KEY,
"private_key": PRIVATE_KEY,
},
],
}
CREDENTIALS_JSON = json.dumps(CREDENTIALS_DICT)
IDENTITIES = (Username(USER1), Username(USER2))
SECRETS = (
Password(PASSWORD),
LMHash(LM_HASH),
NTHash(NT_HASH),
SSHKeypair(PRIVATE_KEY, PUBLIC_KEY),
)
CREDENTIALS_OBJECT = Credentials(IDENTITIES, SECRETS)
def test_credentials_serialization_json():
serialized_credentials = Credentials.to_json(CREDENTIALS_OBJECT)
assert json.loads(serialized_credentials) == CREDENTIALS_DICT
def test_credentials_deserialization__from_mapping():
deserialized_credentials = Credentials.from_mapping(CREDENTIALS_DICT)
assert deserialized_credentials == CREDENTIALS_OBJECT
def test_credentials_deserialization__from_json():
deserialized_credentials = Credentials.from_json(CREDENTIALS_JSON)
assert deserialized_credentials == CREDENTIALS_OBJECT
def test_credentials_deserialization__invalid_credentials():
invalid_data = {"secrets": [], "unknown_key": []}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)
def test_credentials_deserialization__invalid_component_type():
invalid_data = {"secrets": [], "identities": [{"credential_type": "FAKE", "username": "user1"}]}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)
def test_credentials_deserialization__invalid_component():
invalid_data = {
"secrets": [],
"identities": [{"credential_type": "USERNAME", "unknown_field": "user1"}],
}
with pytest.raises(InvalidCredentialComponentError):
Credentials.from_mapping(invalid_data)

View File

@ -0,0 +1,26 @@
import pytest
from common.credentials import InvalidCredentialComponentError, LMHash, NTHash
VALID_HASH = "E520AC67419A9A224A3B108F3FA6CB6D"
INVALID_HASHES = (
0,
1,
2.0,
"invalid",
"0123456789012345678901234568901",
"E52GAC67419A9A224A3B108F3FA6CB6D",
)
@pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash))
def test_construct_valid_ntlm_hash(ntlm_hash_class):
# This test will fail if an exception is raised
ntlm_hash_class(VALID_HASH)
@pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash))
def test_construct_invalid_ntlm_hash(ntlm_hash_class):
for invalid_hash in INVALID_HASHES:
with pytest.raises(InvalidCredentialComponentError):
ntlm_hash_class(invalid_hash)

View File

@ -1,6 +1,6 @@
from queue import Queue from queue import Queue
from common.utils.code_utils import queue_to_list from common.utils.code_utils import del_key, queue_to_list
def test_empty_queue_to_empty_list(): def test_empty_queue_to_empty_list():
@ -20,3 +20,23 @@ def test_queue_to_list():
list_ = queue_to_list(q) list_ = queue_to_list(q)
assert list_ == expected_list assert list_ == expected_list
def test_del_key__deletes_key():
key_to_delete = "a"
my_dict = {"a": 1, "b": 2}
expected_dict = {k: v for k, v in my_dict.items() if k != key_to_delete}
del_key(my_dict, key_to_delete)
assert my_dict == expected_dict
def test_del_key__nonexistant_key():
key_to_delete = "a"
my_dict = {"a": 1, "b": 2}
del_key(my_dict, key_to_delete)
# This test passes if the following call does not raise an error
del_key(my_dict, key_to_delete)

View File

@ -56,14 +56,16 @@ def test_pypykatz_result_parsing_duplicates(monkeypatch):
def test_pypykatz_result_parsing_defaults(monkeypatch): def test_pypykatz_result_parsing_defaults(monkeypatch):
win_creds = [ win_creds = [
WindowsCredentials(username="user2", password="secret2", lm_hash="lm_hash"), WindowsCredentials(
username="user2", password="secret2", lm_hash="0182BD0BD4444BF8FC83B5D9042EED2E"
),
] ]
patch_pypykatz(win_creds, monkeypatch) patch_pypykatz(win_creds, monkeypatch)
# Expected credentials # Expected credentials
username = Username("user2") username = Username("user2")
password = Password("secret2") password = Password("secret2")
lm_hash = LMHash("lm_hash") lm_hash = LMHash("0182BD0BD4444BF8FC83B5D9042EED2E")
expected_credentials = Credentials([username], [password, lm_hash]) expected_credentials = Credentials([username], [password, lm_hash])
collected_credentials = collect_credentials() collected_credentials = collect_credentials()
@ -73,12 +75,17 @@ def test_pypykatz_result_parsing_defaults(monkeypatch):
def test_pypykatz_result_parsing_no_identities(monkeypatch): def test_pypykatz_result_parsing_no_identities(monkeypatch):
win_creds = [ win_creds = [
WindowsCredentials(username="", password="", ntlm_hash="ntlm_hash", lm_hash="lm_hash"), WindowsCredentials(
username="",
password="",
ntlm_hash="E9F85516721DDC218359AD5280DB4450",
lm_hash="0182BD0BD4444BF8FC83B5D9042EED2E",
),
] ]
patch_pypykatz(win_creds, monkeypatch) patch_pypykatz(win_creds, monkeypatch)
lm_hash = LMHash("lm_hash") lm_hash = LMHash("0182BD0BD4444BF8FC83B5D9042EED2E")
nt_hash = NTHash("ntlm_hash") nt_hash = NTHash("E9F85516721DDC218359AD5280DB4450")
expected_credentials = Credentials([], [lm_hash, nt_hash]) expected_credentials = Credentials([], [lm_hash, nt_hash])
collected_credentials = collect_credentials() collected_credentials = collect_credentials()

View File

@ -38,8 +38,7 @@ def test_credential_telem_send(spy_send_telemetry, credentials_for_test):
telem = CredentialsTelem([credentials_for_test]) telem = CredentialsTelem([credentials_for_test])
telem.send() telem.send()
expected_data = json.dumps(expected_data, cls=telem.json_encoder) assert json.loads(spy_send_telemetry.data) == expected_data
assert spy_send_telemetry.data == expected_data
assert spy_send_telemetry.telem_category == "credentials" assert spy_send_telemetry.telem_category == "credentials"

View File

@ -196,6 +196,12 @@ _make_tcp_scan_configuration # unused method (monkey/common/configuration/agent
_make_network_scan_configuration # unused method (monkey/common/configuration/agent_configuration.py:110) _make_network_scan_configuration # unused method (monkey/common/configuration/agent_configuration.py:110)
_make_propagation_configuration # unused method (monkey/common/configuration/agent_configuration.py:167) _make_propagation_configuration # unused method (monkey/common/configuration/agent_configuration.py:167)
# Credentials
_strip_credential_type # unused method (monkey/common/credentials/password.py:18)
_make_credentials # unused method (monkey/common/credentials/credentials:39)
_serialize_credentials # unused method (monkey/common/credentials/credentials:67)
# Models # Models
_make_simulation # unused method (monkey/monkey_island/cc/models/simulation.py:19 _make_simulation # unused method (monkey/monkey_island/cc/models/simulation.py:19