diff --git a/monkey/common/credentials/credential_component_schema.py b/monkey/common/credentials/credential_component_schema.py new file mode 100644 index 000000000..97f8c10b0 --- /dev/null +++ b/monkey/common/credentials/credential_component_schema.py @@ -0,0 +1,10 @@ +from marshmallow import Schema, post_load + +from common.utils.code_utils import del_key + + +class CredentialComponentSchema(Schema): + @post_load + def _strip_credential_type(self, data, **kwargs): + del_key(data, "credential_type") + return data diff --git a/monkey/common/credentials/password.py b/monkey/common/credentials/password.py index 1fee478aa..7b4d2178f 100644 --- a/monkey/common/credentials/password.py +++ b/monkey/common/credentials/password.py @@ -1,24 +1,18 @@ from dataclasses import dataclass, field -from marshmallow import Schema, fields, post_load, validate +from marshmallow import fields, validate from marshmallow_enum import EnumField -from common.utils.code_utils import del_key - from . import CredentialComponentType, ICredentialComponent +from .credential_component_schema import CredentialComponentSchema -class PasswordSchema(Schema): +class PasswordSchema(CredentialComponentSchema): credential_type = EnumField( CredentialComponentType, validate=validate.Equal(CredentialComponentType.PASSWORD) ) password = fields.Str() - @post_load - def _strip_credential_type(self, data, **kwargs): - del_key(data, "credential_type") - return data - @dataclass(frozen=True) class Password(ICredentialComponent): diff --git a/monkey/common/credentials/username.py b/monkey/common/credentials/username.py index 837ba2b98..ffa863e15 100644 --- a/monkey/common/credentials/username.py +++ b/monkey/common/credentials/username.py @@ -1,24 +1,18 @@ from dataclasses import dataclass, field -from marshmallow import Schema, fields, post_load, validate +from marshmallow import fields, validate from marshmallow_enum import EnumField -from common.utils.code_utils import del_key - from . import CredentialComponentType, ICredentialComponent +from .credential_component_schema import CredentialComponentSchema -class UsernameSchema(Schema): +class UsernameSchema(CredentialComponentSchema): credential_type = EnumField( CredentialComponentType, validate=validate.Equal(CredentialComponentType.USERNAME) ) username = fields.Str() - @post_load - def _strip_credential_type(self, data, **kwargs): - del_key(data, "credential_type") - return data - @dataclass(frozen=True) class Username(ICredentialComponent):