diff --git a/monkey/common/credentials/credentials.py b/monkey/common/credentials/credentials.py index f1c08beb1..676e77275 100644 --- a/monkey/common/credentials/credentials.py +++ b/monkey/common/credentials/credentials.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Mapping, MutableMapping, Sequence, Type +from typing import Any, Mapping, MutableMapping, Optional, Sequence, Type from marshmallow import Schema, fields, post_load, pre_dump from marshmallow.exceptions import MarshmallowError @@ -42,8 +42,8 @@ CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA: Mapping[CredentialComponentType, Sche class CredentialsSchema(Schema): - identity = fields.Mapping() - secret = fields.Mapping() + identity = fields.Mapping(allow_none=True) + secret = fields.Mapping(allow_none=True) @post_load def _make_credentials( @@ -55,9 +55,16 @@ class CredentialsSchema(Schema): return data @staticmethod - def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent: + def _build_credential_component( + credential_component: Optional[Mapping[str, Any]] + ) -> Optional[ICredentialComponent]: + if credential_component is None: + return None + try: - credential_component_type = CredentialComponentType[data["credential_type"]] + credential_component_type = CredentialComponentType[ + credential_component["credential_type"] + ] except KeyError as err: raise InvalidCredentialsError(f"Unknown credential component type {err}") @@ -67,7 +74,9 @@ class CredentialsSchema(Schema): ] try: - return credential_component_class(**credential_component_schema.load(data)) + return credential_component_class( + **credential_component_schema.load(credential_component) + ) except MarshmallowError as err: raise InvalidCredentialComponentError(credential_component_class, str(err)) @@ -84,8 +93,11 @@ class CredentialsSchema(Schema): @staticmethod def _serialize_credential_component( - credential_component: ICredentialComponent, - ) -> Mapping[str, Any]: + credential_component: Optional[ICredentialComponent], + ) -> Optional[Mapping[str, Any]]: + if credential_component is None: + return None + credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[ credential_component.credential_type ] @@ -95,8 +107,8 @@ class CredentialsSchema(Schema): @dataclass(frozen=True) class Credentials(IJSONSerializable): - identity: ICredentialComponent - secret: ICredentialComponent + identity: Optional[ICredentialComponent] + secret: Optional[ICredentialComponent] @staticmethod def from_mapping(credentials: Mapping) -> Credentials: diff --git a/monkey/tests/unit_tests/common/credentials/test_credentials.py b/monkey/tests/unit_tests/common/credentials/test_credentials.py index fe3e919ae..d26cde8fa 100644 --- a/monkey/tests/unit_tests/common/credentials/test_credentials.py +++ b/monkey/tests/unit_tests/common/credentials/test_credentials.py @@ -22,14 +22,15 @@ from common.credentials import ( Username, ) -IDENTITIES = [Username(USERNAME)] -IDENTITY_DICTS = [{"credential_type": "USERNAME", "username": USERNAME}] +IDENTITIES = [Username(USERNAME), None] +IDENTITY_DICTS = [{"credential_type": "USERNAME", "username": USERNAME}, None] SECRETS = ( Password(PASSWORD_1), LMHash(LM_HASH), NTHash(NT_HASH), SSHKeypair(PRIVATE_KEY, PUBLIC_KEY), + None, ) SECRET_DICTS = [ {"credential_type": "PASSWORD", "password": PASSWORD_1}, @@ -40,13 +41,19 @@ SECRET_DICTS = [ "public_key": PUBLIC_KEY, "private_key": PRIVATE_KEY, }, + None, ] -CREDENTIALS = [Credentials(identity, secret) for identity, secret in product(IDENTITIES, SECRETS)] +CREDENTIALS = [ + Credentials(identity, secret) + for identity, secret in product(IDENTITIES, SECRETS) + if not (identity is None and secret is None) +] CREDENTIALS_DICTS = [ {"identity": identity, "secret": secret} for identity, secret in product(IDENTITY_DICTS, SECRET_DICTS) + if not (identity is None and secret is None) ]