From 62ce91b59b0eb83d0cea4693e4adf935392c1473 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Fri, 15 Jul 2022 12:19:12 -0400 Subject: [PATCH] Common: Prevent invalid Credentials objects from being constructed --- monkey/common/credentials/credentials.py | 14 ++++++++++++++ .../common/credentials/test_credentials.py | 5 +++++ 2 files changed, 19 insertions(+) diff --git a/monkey/common/credentials/credentials.py b/monkey/common/credentials/credentials.py index 676e77275..c65ffab55 100644 --- a/monkey/common/credentials/credentials.py +++ b/monkey/common/credentials/credentials.py @@ -49,6 +49,9 @@ class CredentialsSchema(Schema): def _make_credentials( self, data: MutableMapping, **kwargs: Mapping[str, Any] ) -> Mapping[str, Sequence[Mapping[str, Any]]]: + if not any(data.values()): + raise InvalidCredentialsError("At least one credentials component must be defined") + data["identity"] = CredentialsSchema._build_credential_component(data["identity"]) data["secret"] = CredentialsSchema._build_credential_component(data["secret"]) @@ -110,6 +113,17 @@ class Credentials(IJSONSerializable): identity: Optional[ICredentialComponent] secret: Optional[ICredentialComponent] + def __post_init__(self): + schema = CredentialsSchema() + try: + serialized_data = schema.dump(self) + + # This will raise an exception if the object is invalid. Calling this in __post__init() + # makes it impossible to construct an invalid object + schema.load(serialized_data) + except Exception as err: + raise InvalidCredentialsError(err) + @staticmethod def from_mapping(credentials: Mapping) -> Credentials: """ diff --git a/monkey/tests/unit_tests/common/credentials/test_credentials.py b/monkey/tests/unit_tests/common/credentials/test_credentials.py index d26cde8fa..4eeb8647f 100644 --- a/monkey/tests/unit_tests/common/credentials/test_credentials.py +++ b/monkey/tests/unit_tests/common/credentials/test_credentials.py @@ -115,3 +115,8 @@ def test_credentials_deserialization__invalid_component(): } with pytest.raises(InvalidCredentialComponentError): Credentials.from_mapping(invalid_data) + + +def test_create_credentials__none_none(): + with pytest.raises(InvalidCredentialsError): + Credentials(None, None)