Common: Add error trapping to Credentials deserialization

This commit is contained in:
Mike Salvatore 2022-07-07 08:31:28 -04:00
parent 3f61ddd584
commit a18eb1cb73
5 changed files with 68 additions and 12 deletions

View File

@ -1,7 +1,7 @@
from .credential_component_type import CredentialComponentType
from .i_credential_component import ICredentialComponent
from .validators import InvalidCredentialComponent
from .validators import InvalidCredentialComponentError, InvalidCredentialsError
from .lm_hash import LMHash
from .nt_hash import NTHash

View File

@ -5,7 +5,16 @@ from typing import Any, Mapping, MutableMapping, Sequence, Tuple
from marshmallow import Schema, fields, post_load, pre_dump
from . import CredentialComponentType, LMHash, NTHash, Password, SSHKeypair, Username
from . import (
CredentialComponentType,
InvalidCredentialComponentError,
InvalidCredentialsError,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
from .i_credential_component import ICredentialComponent
from .lm_hash import LMHashSchema
from .nt_hash import NTHashSchema
@ -57,7 +66,11 @@ class CredentialsSchema(Schema):
@staticmethod
def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent:
credential_component_type = CredentialComponentType[data["credential_type"]]
try:
credential_component_type = CredentialComponentType[data["credential_type"]]
except KeyError as err:
raise InvalidCredentialsError(f"Unknown credential component type {err}")
credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type]
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
credential_component_type
@ -104,13 +117,23 @@ class Credentials:
@staticmethod
def from_mapping(credentials: Mapping) -> Credentials:
deserialized_data = CredentialsSchema().load(credentials)
return Credentials(**deserialized_data)
try:
deserialized_data = CredentialsSchema().load(credentials)
return Credentials(**deserialized_data)
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
raise err
except Exception as err:
raise InvalidCredentialsError(str(err))
@staticmethod
def from_json(credentials: str) -> Credentials:
deserialized_data = CredentialsSchema().loads(credentials)
return Credentials(**deserialized_data)
try:
deserialized_data = CredentialsSchema().loads(credentials)
return Credentials(**deserialized_data)
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
raise err
except Exception as err:
raise InvalidCredentialsError(str(err))
@staticmethod
def to_json(credentials: Credentials) -> str:

View File

@ -9,7 +9,7 @@ _ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$")
ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex)
class InvalidCredentialComponent(Exception):
class InvalidCredentialComponentError(Exception):
def __init__(self, credential_component_class: Type[ICredentialComponent], message: str):
self._credential_component_name = credential_component_class.__name__
self._message = message
@ -21,6 +21,17 @@ class InvalidCredentialComponent(Exception):
)
class InvalidCredentialsError(Exception):
def __init__(self, message: str):
self._message = message
def __str__(self) -> str:
return (
f"Cannot construct a Credentials object with the supplied, "
f"invalid data: {self._message}"
)
def credential_component_validator(schema: Schema, credential_component: ICredentialComponent):
"""
Validate a credential component
@ -36,4 +47,4 @@ def credential_component_validator(schema: Schema, credential_component: ICreden
# makes it impossible to construct an invalid object
schema.load(serialized_data)
except Exception as err:
raise InvalidCredentialComponent(credential_component.__class__, err)
raise InvalidCredentialComponentError(credential_component.__class__, err)

View File

@ -1,6 +1,16 @@
import json
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
import pytest
from common.credentials import (
Credentials,
InvalidCredentialsError,
LMHash,
NTHash,
Password,
SSHKeypair,
Username,
)
USER1 = "test_user_1"
USER2 = "test_user_2"
@ -55,3 +65,15 @@ def test_credentials_deserialization__from_json():
deserialized_credentials = Credentials.from_json(CREDENTIALS_JSON)
assert deserialized_credentials == CREDENTIALS_OBJECT
def test_credentials_deserialization__invalid_credentials():
invalid_data = {"secrets": [], "unknown_key": []}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)
def test_credentials_deserialization__invalid_component_type():
invalid_data = {"secrets": [], "identities": [{"credential_type": "FAKE", "username": "user1"}]}
with pytest.raises(InvalidCredentialsError):
Credentials.from_mapping(invalid_data)

View File

@ -1,6 +1,6 @@
import pytest
from common.credentials import InvalidCredentialComponent, LMHash, NTHash
from common.credentials import InvalidCredentialComponentError, LMHash, NTHash
VALID_HASH = "E520AC67419A9A224A3B108F3FA6CB6D"
INVALID_HASHES = (
@ -22,5 +22,5 @@ def test_construct_valid_ntlm_hash(ntlm_hash_class):
@pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash))
def test_construct_invalid_ntlm_hash(ntlm_hash_class):
for invalid_hash in INVALID_HASHES:
with pytest.raises(InvalidCredentialComponent):
with pytest.raises(InvalidCredentialComponentError):
ntlm_hash_class(invalid_hash)