Common: Add error trapping to Credentials deserialization

This commit is contained in:
Mike Salvatore 2022-07-07 08:31:28 -04:00
parent 3f61ddd584
commit a18eb1cb73
5 changed files with 68 additions and 12 deletions

View File

@ -1,7 +1,7 @@
from .credential_component_type import CredentialComponentType from .credential_component_type import CredentialComponentType
from .i_credential_component import ICredentialComponent from .i_credential_component import ICredentialComponent
from .validators import InvalidCredentialComponent from .validators import InvalidCredentialComponentError, InvalidCredentialsError
from .lm_hash import LMHash from .lm_hash import LMHash
from .nt_hash import NTHash from .nt_hash import NTHash

View File

@ -5,7 +5,16 @@ from typing import Any, Mapping, MutableMapping, Sequence, Tuple
from marshmallow import Schema, fields, post_load, pre_dump from marshmallow import Schema, fields, post_load, pre_dump
from . import CredentialComponentType, LMHash, NTHash, Password, SSHKeypair, Username from . import (
CredentialComponentType,
InvalidCredentialComponentError,
InvalidCredentialsError,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
from .i_credential_component import ICredentialComponent from .i_credential_component import ICredentialComponent
from .lm_hash import LMHashSchema from .lm_hash import LMHashSchema
from .nt_hash import NTHashSchema from .nt_hash import NTHashSchema
@ -57,7 +66,11 @@ class CredentialsSchema(Schema):
@staticmethod @staticmethod
def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent: def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent:
credential_component_type = CredentialComponentType[data["credential_type"]] try:
credential_component_type = CredentialComponentType[data["credential_type"]]
except KeyError as err:
raise InvalidCredentialsError(f"Unknown credential component type {err}")
credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type] credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type]
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[ credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
credential_component_type credential_component_type
@ -104,13 +117,23 @@ class Credentials:
@staticmethod @staticmethod
def from_mapping(credentials: Mapping) -> Credentials: def from_mapping(credentials: Mapping) -> Credentials:
deserialized_data = CredentialsSchema().load(credentials) try:
return Credentials(**deserialized_data) deserialized_data = CredentialsSchema().load(credentials)
return Credentials(**deserialized_data)
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
raise err
except Exception as err:
raise InvalidCredentialsError(str(err))
@staticmethod @staticmethod
def from_json(credentials: str) -> Credentials: def from_json(credentials: str) -> Credentials:
deserialized_data = CredentialsSchema().loads(credentials) try:
return Credentials(**deserialized_data) deserialized_data = CredentialsSchema().loads(credentials)
return Credentials(**deserialized_data)
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
raise err
except Exception as err:
raise InvalidCredentialsError(str(err))
@staticmethod @staticmethod
def to_json(credentials: Credentials) -> str: def to_json(credentials: Credentials) -> str:

View File

@ -9,7 +9,7 @@ _ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$")
ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex) ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex)
class InvalidCredentialComponent(Exception): class InvalidCredentialComponentError(Exception):
def __init__(self, credential_component_class: Type[ICredentialComponent], message: str): def __init__(self, credential_component_class: Type[ICredentialComponent], message: str):
self._credential_component_name = credential_component_class.__name__ self._credential_component_name = credential_component_class.__name__
self._message = message self._message = message
@ -21,6 +21,17 @@ class InvalidCredentialComponent(Exception):
) )
class InvalidCredentialsError(Exception):
def __init__(self, message: str):
self._message = message
def __str__(self) -> str:
return (
f"Cannot construct a Credentials object with the supplied, "
f"invalid data: {self._message}"
)
def credential_component_validator(schema: Schema, credential_component: ICredentialComponent): def credential_component_validator(schema: Schema, credential_component: ICredentialComponent):
""" """
Validate a credential component Validate a credential component
@ -36,4 +47,4 @@ def credential_component_validator(schema: Schema, credential_component: ICreden
# makes it impossible to construct an invalid object # makes it impossible to construct an invalid object
schema.load(serialized_data) schema.load(serialized_data)
except Exception as err: except Exception as err:
raise InvalidCredentialComponent(credential_component.__class__, err) raise InvalidCredentialComponentError(credential_component.__class__, err)

View File

@ -1,6 +1,16 @@
import json import json
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username import pytest
from common.credentials import (
Credentials,
InvalidCredentialsError,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
USER1 = "test_user_1" USER1 = "test_user_1"
USER2 = "test_user_2" USER2 = "test_user_2"
@ -55,3 +65,15 @@ def test_credentials_deserialization__from_json():
deserialized_credentials = Credentials.from_json(CREDENTIALS_JSON) deserialized_credentials = Credentials.from_json(CREDENTIALS_JSON)
assert deserialized_credentials == CREDENTIALS_OBJECT assert deserialized_credentials == CREDENTIALS_OBJECT
def test_credentials_deserialization__invalid_credentials():
invalid_data = {"secrets": [], "unknown_key": []}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)
def test_credentials_deserialization__invalid_component_type():
invalid_data = {"secrets": [], "identities": [{"credential_type": "FAKE", "username": "user1"}]}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)

View File

@ -1,6 +1,6 @@
import pytest import pytest
from common.credentials import InvalidCredentialComponent, LMHash, NTHash from common.credentials import InvalidCredentialComponentError, LMHash, NTHash
VALID_HASH = "E520AC67419A9A224A3B108F3FA6CB6D" VALID_HASH = "E520AC67419A9A224A3B108F3FA6CB6D"
INVALID_HASHES = ( INVALID_HASHES = (
@ -22,5 +22,5 @@ def test_construct_valid_ntlm_hash(ntlm_hash_class):
@pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash)) @pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash))
def test_construct_invalid_ntlm_hash(ntlm_hash_class): def test_construct_invalid_ntlm_hash(ntlm_hash_class):
for invalid_hash in INVALID_HASHES: for invalid_hash in INVALID_HASHES:
with pytest.raises(InvalidCredentialComponent): with pytest.raises(InvalidCredentialComponentError):
ntlm_hash_class(invalid_hash) ntlm_hash_class(invalid_hash)