From 62ce91b59b0eb83d0cea4693e4adf935392c1473 Mon Sep 17 00:00:00 2001
From: Mike Salvatore <mike.s.salvatore@gmail.com>
Date: Fri, 15 Jul 2022 12:19:12 -0400
Subject: [PATCH] Common: Prevent invalid Credentials objects from being
 constructed

---
 monkey/common/credentials/credentials.py           | 14 ++++++++++++++
 .../common/credentials/test_credentials.py         |  5 +++++
 2 files changed, 19 insertions(+)

diff --git a/monkey/common/credentials/credentials.py b/monkey/common/credentials/credentials.py
index 676e77275..c65ffab55 100644
--- a/monkey/common/credentials/credentials.py
+++ b/monkey/common/credentials/credentials.py
@@ -49,6 +49,9 @@ class CredentialsSchema(Schema):
     def _make_credentials(
         self, data: MutableMapping, **kwargs: Mapping[str, Any]
     ) -> Mapping[str, Sequence[Mapping[str, Any]]]:
+        if not any(data.values()):
+            raise InvalidCredentialsError("At least one credentials component must be defined")
+
         data["identity"] = CredentialsSchema._build_credential_component(data["identity"])
         data["secret"] = CredentialsSchema._build_credential_component(data["secret"])
 
@@ -110,6 +113,17 @@ class Credentials(IJSONSerializable):
     identity: Optional[ICredentialComponent]
     secret: Optional[ICredentialComponent]
 
+    def __post_init__(self):
+        schema = CredentialsSchema()
+        try:
+            serialized_data = schema.dump(self)
+
+            # This will raise an exception if the object is invalid. Calling this in __post__init()
+            # makes it impossible to construct an invalid object
+            schema.load(serialized_data)
+        except Exception as err:
+            raise InvalidCredentialsError(err)
+
     @staticmethod
     def from_mapping(credentials: Mapping) -> Credentials:
         """
diff --git a/monkey/tests/unit_tests/common/credentials/test_credentials.py b/monkey/tests/unit_tests/common/credentials/test_credentials.py
index d26cde8fa..4eeb8647f 100644
--- a/monkey/tests/unit_tests/common/credentials/test_credentials.py
+++ b/monkey/tests/unit_tests/common/credentials/test_credentials.py
@@ -115,3 +115,8 @@ def test_credentials_deserialization__invalid_component():
     }
     with pytest.raises(InvalidCredentialComponentError):
         Credentials.from_mapping(invalid_data)
+
+
+def test_create_credentials__none_none():
+    with pytest.raises(InvalidCredentialsError):
+        Credentials(None, None)