diff --git a/monkey/common/credentials/credential_component_schema.py b/monkey/common/credentials/credential_component_schema.py index 97f8c10b0..ff3e657be 100644 --- a/monkey/common/credentials/credential_component_schema.py +++ b/monkey/common/credentials/credential_component_schema.py @@ -1,7 +1,17 @@ -from marshmallow import Schema, post_load +from marshmallow import Schema, post_load, validate +from marshmallow_enum import EnumField from common.utils.code_utils import del_key +from . import CredentialComponentType + + +class CredentialTypeField(EnumField): + def __init__(self, credential_component_type: CredentialComponentType): + super().__init__( + CredentialComponentType, validate=validate.Equal(credential_component_type) + ) + class CredentialComponentSchema(Schema): @post_load diff --git a/monkey/common/credentials/password.py b/monkey/common/credentials/password.py index 7b4d2178f..b7bd1b84c 100644 --- a/monkey/common/credentials/password.py +++ b/monkey/common/credentials/password.py @@ -1,16 +1,13 @@ from dataclasses import dataclass, field -from marshmallow import fields, validate -from marshmallow_enum import EnumField +from marshmallow import fields from . import CredentialComponentType, ICredentialComponent -from .credential_component_schema import CredentialComponentSchema +from .credential_component_schema import CredentialComponentSchema, CredentialTypeField class PasswordSchema(CredentialComponentSchema): - credential_type = EnumField( - CredentialComponentType, validate=validate.Equal(CredentialComponentType.PASSWORD) - ) + credential_type = CredentialTypeField(CredentialComponentType.PASSWORD) password = fields.Str() diff --git a/monkey/common/credentials/username.py b/monkey/common/credentials/username.py index ffa863e15..86fde05ff 100644 --- a/monkey/common/credentials/username.py +++ b/monkey/common/credentials/username.py @@ -1,16 +1,13 @@ from dataclasses import dataclass, field -from marshmallow import fields, validate -from marshmallow_enum import EnumField +from marshmallow import fields from . import CredentialComponentType, ICredentialComponent -from .credential_component_schema import CredentialComponentSchema +from .credential_component_schema import CredentialComponentSchema, CredentialTypeField class UsernameSchema(CredentialComponentSchema): - credential_type = EnumField( - CredentialComponentType, validate=validate.Equal(CredentialComponentType.USERNAME) - ) + credential_type = CredentialTypeField(CredentialComponentType.USERNAME) username = fields.Str()