From a18eb1cb73e48194fb18ecd832b00394ac38388d Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 7 Jul 2022 08:31:28 -0400 Subject: [PATCH] Common: Add error trapping to Credentials deserialization --- monkey/common/credentials/__init__.py | 2 +- monkey/common/credentials/credentials.py | 35 +++++++++++++++---- monkey/common/credentials/validators.py | 15 ++++++-- .../common/credentials/test_credentials.py | 24 ++++++++++++- .../common/credentials/test_ntlm_hash.py | 4 +-- 5 files changed, 68 insertions(+), 12 deletions(-) diff --git a/monkey/common/credentials/__init__.py b/monkey/common/credentials/__init__.py index f49f6af03..6275e0985 100644 --- a/monkey/common/credentials/__init__.py +++ b/monkey/common/credentials/__init__.py @@ -1,7 +1,7 @@ from .credential_component_type import CredentialComponentType from .i_credential_component import ICredentialComponent -from .validators import InvalidCredentialComponent +from .validators import InvalidCredentialComponentError, InvalidCredentialsError from .lm_hash import LMHash from .nt_hash import NTHash diff --git a/monkey/common/credentials/credentials.py b/monkey/common/credentials/credentials.py index 75775f22d..1837fd8b6 100644 --- a/monkey/common/credentials/credentials.py +++ b/monkey/common/credentials/credentials.py @@ -5,7 +5,16 @@ from typing import Any, Mapping, MutableMapping, Sequence, Tuple from marshmallow import Schema, fields, post_load, pre_dump -from . import CredentialComponentType, LMHash, NTHash, Password, SSHKeypair, Username +from . import ( + CredentialComponentType, + InvalidCredentialComponentError, + InvalidCredentialsError, + LMHash, + NTHash, + Password, + SSHKeypair, + Username, +) from .i_credential_component import ICredentialComponent from .lm_hash import LMHashSchema from .nt_hash import NTHashSchema @@ -57,7 +66,11 @@ class CredentialsSchema(Schema): @staticmethod def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent: - credential_component_type = CredentialComponentType[data["credential_type"]] + try: + credential_component_type = CredentialComponentType[data["credential_type"]] + except KeyError as err: + raise InvalidCredentialsError(f"Unknown credential component type {err}") + credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type] credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[ credential_component_type @@ -104,13 +117,23 @@ class Credentials: @staticmethod def from_mapping(credentials: Mapping) -> Credentials: - deserialized_data = CredentialsSchema().load(credentials) - return Credentials(**deserialized_data) + try: + deserialized_data = CredentialsSchema().load(credentials) + return Credentials(**deserialized_data) + except (InvalidCredentialsError, InvalidCredentialComponentError) as err: + raise err + except Exception as err: + raise InvalidCredentialsError(str(err)) @staticmethod def from_json(credentials: str) -> Credentials: - deserialized_data = CredentialsSchema().loads(credentials) - return Credentials(**deserialized_data) + try: + deserialized_data = CredentialsSchema().loads(credentials) + return Credentials(**deserialized_data) + except (InvalidCredentialsError, InvalidCredentialComponentError) as err: + raise err + except Exception as err: + raise InvalidCredentialsError(str(err)) @staticmethod def to_json(credentials: Credentials) -> str: diff --git a/monkey/common/credentials/validators.py b/monkey/common/credentials/validators.py index aa5fc7735..2e0e2e93c 100644 --- a/monkey/common/credentials/validators.py +++ b/monkey/common/credentials/validators.py @@ -9,7 +9,7 @@ _ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$") ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex) -class InvalidCredentialComponent(Exception): +class InvalidCredentialComponentError(Exception): def __init__(self, credential_component_class: Type[ICredentialComponent], message: str): self._credential_component_name = credential_component_class.__name__ self._message = message @@ -21,6 +21,17 @@ class InvalidCredentialComponent(Exception): ) +class InvalidCredentialsError(Exception): + def __init__(self, message: str): + self._message = message + + def __str__(self) -> str: + return ( + f"Cannot construct a Credentials object with the supplied, " + f"invalid data: {self._message}" + ) + + def credential_component_validator(schema: Schema, credential_component: ICredentialComponent): """ Validate a credential component @@ -36,4 +47,4 @@ def credential_component_validator(schema: Schema, credential_component: ICreden # makes it impossible to construct an invalid object schema.load(serialized_data) except Exception as err: - raise InvalidCredentialComponent(credential_component.__class__, err) + raise InvalidCredentialComponentError(credential_component.__class__, err) diff --git a/monkey/tests/unit_tests/common/credentials/test_credentials.py b/monkey/tests/unit_tests/common/credentials/test_credentials.py index 4496a2f05..d6fd55ef5 100644 --- a/monkey/tests/unit_tests/common/credentials/test_credentials.py +++ b/monkey/tests/unit_tests/common/credentials/test_credentials.py @@ -1,6 +1,16 @@ import json -from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username +import pytest + +from common.credentials import ( + Credentials, + InvalidCredentialsError, + LMHash, + NTHash, + Password, + SSHKeypair, + Username, +) USER1 = "test_user_1" USER2 = "test_user_2" @@ -55,3 +65,15 @@ def test_credentials_deserialization__from_json(): deserialized_credentials = Credentials.from_json(CREDENTIALS_JSON) assert deserialized_credentials == CREDENTIALS_OBJECT + + +def test_credentials_deserialization__invalid_credentials(): + invalid_data = {"secrets": [], "unknown_key": []} + with pytest.raises(InvalidCredentialsError): + Credentials.from_mapping(invalid_data) + + +def test_credentials_deserialization__invalid_component_type(): + invalid_data = {"secrets": [], "identities": [{"credential_type": "FAKE", "username": "user1"}]} + with pytest.raises(InvalidCredentialsError): + Credentials.from_mapping(invalid_data) diff --git a/monkey/tests/unit_tests/common/credentials/test_ntlm_hash.py b/monkey/tests/unit_tests/common/credentials/test_ntlm_hash.py index 28f7bcaae..5f50110e8 100644 --- a/monkey/tests/unit_tests/common/credentials/test_ntlm_hash.py +++ b/monkey/tests/unit_tests/common/credentials/test_ntlm_hash.py @@ -1,6 +1,6 @@ import pytest -from common.credentials import InvalidCredentialComponent, LMHash, NTHash +from common.credentials import InvalidCredentialComponentError, LMHash, NTHash VALID_HASH = "E520AC67419A9A224A3B108F3FA6CB6D" INVALID_HASHES = ( @@ -22,5 +22,5 @@ def test_construct_valid_ntlm_hash(ntlm_hash_class): @pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash)) def test_construct_invalid_ntlm_hash(ntlm_hash_class): for invalid_hash in INVALID_HASHES: - with pytest.raises(InvalidCredentialComponent): + with pytest.raises(InvalidCredentialComponentError): ntlm_hash_class(invalid_hash)