diff --git a/monkey/common/credentials/username.py b/monkey/common/credentials/username.py index c3249058e..837ba2b98 100644 --- a/monkey/common/credentials/username.py +++ b/monkey/common/credentials/username.py @@ -1,8 +1,25 @@ from dataclasses import dataclass, field +from marshmallow import Schema, fields, post_load, validate +from marshmallow_enum import EnumField + +from common.utils.code_utils import del_key + from . import CredentialComponentType, ICredentialComponent +class UsernameSchema(Schema): + credential_type = EnumField( + CredentialComponentType, validate=validate.Equal(CredentialComponentType.USERNAME) + ) + username = fields.Str() + + @post_load + def _strip_credential_type(self, data, **kwargs): + del_key(data, "credential_type") + return data + + @dataclass(frozen=True) class Username(ICredentialComponent): credential_type: CredentialComponentType = field( diff --git a/monkey/tests/unit_tests/common/credentials/test_username.py b/monkey/tests/unit_tests/common/credentials/test_username.py new file mode 100644 index 000000000..3bee18978 --- /dev/null +++ b/monkey/tests/unit_tests/common/credentials/test_username.py @@ -0,0 +1,49 @@ +from copy import deepcopy + +import pytest +from marshmallow.exceptions import ValidationError + +from common.credentials import CredentialComponentType, Username +from common.credentials.username import UsernameSchema + +USERNAME_VALUE = "test_user" +USERNAME_DICT = { + "credential_type": CredentialComponentType.USERNAME.name, + "username": USERNAME_VALUE, +} + + +def test_username_serialize(): + schema = UsernameSchema() + username = Username(USERNAME_VALUE) + + serialized_username = schema.dump(username) + + assert serialized_username == USERNAME_DICT + + +def test_username_deserialize(): + schema = UsernameSchema() + + username = Username(**schema.load(USERNAME_DICT)) + + assert username.credential_type == CredentialComponentType.USERNAME + assert username.username == USERNAME_VALUE + + +def test_invalid_credential_type(): + invalid_username_dict = deepcopy(USERNAME_DICT) + invalid_username_dict["credential_type"] = "INVALID" + schema = UsernameSchema() + + with pytest.raises(ValidationError): + Username(**schema.load(invalid_username_dict)) + + +def test_incorrect_credential_type(): + invalid_username_dict = deepcopy(USERNAME_DICT) + invalid_username_dict["credential_type"] = CredentialComponentType.PASSWORD.name + schema = UsernameSchema() + + with pytest.raises(ValidationError): + Username(**schema.load(invalid_username_dict))