Common: Add error trapping to Credentials deserialization
This commit is contained in:
parent
3f61ddd584
commit
a18eb1cb73
|
@ -1,7 +1,7 @@
|
||||||
from .credential_component_type import CredentialComponentType
|
from .credential_component_type import CredentialComponentType
|
||||||
from .i_credential_component import ICredentialComponent
|
from .i_credential_component import ICredentialComponent
|
||||||
|
|
||||||
from .validators import InvalidCredentialComponent
|
from .validators import InvalidCredentialComponentError, InvalidCredentialsError
|
||||||
|
|
||||||
from .lm_hash import LMHash
|
from .lm_hash import LMHash
|
||||||
from .nt_hash import NTHash
|
from .nt_hash import NTHash
|
||||||
|
|
|
@ -5,7 +5,16 @@ from typing import Any, Mapping, MutableMapping, Sequence, Tuple
|
||||||
|
|
||||||
from marshmallow import Schema, fields, post_load, pre_dump
|
from marshmallow import Schema, fields, post_load, pre_dump
|
||||||
|
|
||||||
from . import CredentialComponentType, LMHash, NTHash, Password, SSHKeypair, Username
|
from . import (
|
||||||
|
CredentialComponentType,
|
||||||
|
InvalidCredentialComponentError,
|
||||||
|
InvalidCredentialsError,
|
||||||
|
LMHash,
|
||||||
|
NTHash,
|
||||||
|
Password,
|
||||||
|
SSHKeypair,
|
||||||
|
Username,
|
||||||
|
)
|
||||||
from .i_credential_component import ICredentialComponent
|
from .i_credential_component import ICredentialComponent
|
||||||
from .lm_hash import LMHashSchema
|
from .lm_hash import LMHashSchema
|
||||||
from .nt_hash import NTHashSchema
|
from .nt_hash import NTHashSchema
|
||||||
|
@ -57,7 +66,11 @@ class CredentialsSchema(Schema):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent:
|
def _build_credential_component(data: Mapping[str, Any]) -> ICredentialComponent:
|
||||||
credential_component_type = CredentialComponentType[data["credential_type"]]
|
try:
|
||||||
|
credential_component_type = CredentialComponentType[data["credential_type"]]
|
||||||
|
except KeyError as err:
|
||||||
|
raise InvalidCredentialsError(f"Unknown credential component type {err}")
|
||||||
|
|
||||||
credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type]
|
credential_component_class = CREDENTIAL_COMPONENT_TYPE_TO_CLASS[credential_component_type]
|
||||||
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
|
credential_component_schema = CREDENTIAL_COMPONENT_TYPE_TO_CLASS_SCHEMA[
|
||||||
credential_component_type
|
credential_component_type
|
||||||
|
@ -104,13 +117,23 @@ class Credentials:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_mapping(credentials: Mapping) -> Credentials:
|
def from_mapping(credentials: Mapping) -> Credentials:
|
||||||
deserialized_data = CredentialsSchema().load(credentials)
|
try:
|
||||||
return Credentials(**deserialized_data)
|
deserialized_data = CredentialsSchema().load(credentials)
|
||||||
|
return Credentials(**deserialized_data)
|
||||||
|
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
|
||||||
|
raise err
|
||||||
|
except Exception as err:
|
||||||
|
raise InvalidCredentialsError(str(err))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_json(credentials: str) -> Credentials:
|
def from_json(credentials: str) -> Credentials:
|
||||||
deserialized_data = CredentialsSchema().loads(credentials)
|
try:
|
||||||
return Credentials(**deserialized_data)
|
deserialized_data = CredentialsSchema().loads(credentials)
|
||||||
|
return Credentials(**deserialized_data)
|
||||||
|
except (InvalidCredentialsError, InvalidCredentialComponentError) as err:
|
||||||
|
raise err
|
||||||
|
except Exception as err:
|
||||||
|
raise InvalidCredentialsError(str(err))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_json(credentials: Credentials) -> str:
|
def to_json(credentials: Credentials) -> str:
|
||||||
|
|
|
@ -9,7 +9,7 @@ _ntlm_hash_regex = re.compile(r"^[a-fA-F0-9]{32}$")
|
||||||
ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex)
|
ntlm_hash_validator = validate.Regexp(regex=_ntlm_hash_regex)
|
||||||
|
|
||||||
|
|
||||||
class InvalidCredentialComponent(Exception):
|
class InvalidCredentialComponentError(Exception):
|
||||||
def __init__(self, credential_component_class: Type[ICredentialComponent], message: str):
|
def __init__(self, credential_component_class: Type[ICredentialComponent], message: str):
|
||||||
self._credential_component_name = credential_component_class.__name__
|
self._credential_component_name = credential_component_class.__name__
|
||||||
self._message = message
|
self._message = message
|
||||||
|
@ -21,6 +21,17 @@ class InvalidCredentialComponent(Exception):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidCredentialsError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
self._message = message
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"Cannot construct a Credentials object with the supplied, "
|
||||||
|
f"invalid data: {self._message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def credential_component_validator(schema: Schema, credential_component: ICredentialComponent):
|
def credential_component_validator(schema: Schema, credential_component: ICredentialComponent):
|
||||||
"""
|
"""
|
||||||
Validate a credential component
|
Validate a credential component
|
||||||
|
@ -36,4 +47,4 @@ def credential_component_validator(schema: Schema, credential_component: ICreden
|
||||||
# makes it impossible to construct an invalid object
|
# makes it impossible to construct an invalid object
|
||||||
schema.load(serialized_data)
|
schema.load(serialized_data)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise InvalidCredentialComponent(credential_component.__class__, err)
|
raise InvalidCredentialComponentError(credential_component.__class__, err)
|
||||||
|
|
|
@ -1,6 +1,16 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
|
import pytest
|
||||||
|
|
||||||
|
from common.credentials import (
|
||||||
|
Credentials,
|
||||||
|
InvalidCredentialsError,
|
||||||
|
LMHash,
|
||||||
|
NTHash,
|
||||||
|
Password,
|
||||||
|
SSHKeypair,
|
||||||
|
Username,
|
||||||
|
)
|
||||||
|
|
||||||
USER1 = "test_user_1"
|
USER1 = "test_user_1"
|
||||||
USER2 = "test_user_2"
|
USER2 = "test_user_2"
|
||||||
|
@ -55,3 +65,15 @@ def test_credentials_deserialization__from_json():
|
||||||
deserialized_credentials = Credentials.from_json(CREDENTIALS_JSON)
|
deserialized_credentials = Credentials.from_json(CREDENTIALS_JSON)
|
||||||
|
|
||||||
assert deserialized_credentials == CREDENTIALS_OBJECT
|
assert deserialized_credentials == CREDENTIALS_OBJECT
|
||||||
|
|
||||||
|
|
||||||
|
def test_credentials_deserialization__invalid_credentials():
|
||||||
|
invalid_data = {"secrets": [], "unknown_key": []}
|
||||||
|
with pytest.raises(InvalidCredentialsError):
|
||||||
|
Credentials.from_mapping(invalid_data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_credentials_deserialization__invalid_component_type():
|
||||||
|
invalid_data = {"secrets": [], "identities": [{"credential_type": "FAKE", "username": "user1"}]}
|
||||||
|
with pytest.raises(InvalidCredentialsError):
|
||||||
|
Credentials.from_mapping(invalid_data)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from common.credentials import InvalidCredentialComponent, LMHash, NTHash
|
from common.credentials import InvalidCredentialComponentError, LMHash, NTHash
|
||||||
|
|
||||||
VALID_HASH = "E520AC67419A9A224A3B108F3FA6CB6D"
|
VALID_HASH = "E520AC67419A9A224A3B108F3FA6CB6D"
|
||||||
INVALID_HASHES = (
|
INVALID_HASHES = (
|
||||||
|
@ -22,5 +22,5 @@ def test_construct_valid_ntlm_hash(ntlm_hash_class):
|
||||||
@pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash))
|
@pytest.mark.parametrize("ntlm_hash_class", (LMHash, NTHash))
|
||||||
def test_construct_invalid_ntlm_hash(ntlm_hash_class):
|
def test_construct_invalid_ntlm_hash(ntlm_hash_class):
|
||||||
for invalid_hash in INVALID_HASHES:
|
for invalid_hash in INVALID_HASHES:
|
||||||
with pytest.raises(InvalidCredentialComponent):
|
with pytest.raises(InvalidCredentialComponentError):
|
||||||
ntlm_hash_class(invalid_hash)
|
ntlm_hash_class(invalid_hash)
|
||||||
|
|
Loading…
Reference in New Issue