diff --git a/monkey/common/credentials/i_credential_component.py b/monkey/common/credentials/i_credential_component.py index 03dc75604..ba55b3cce 100644 --- a/monkey/common/credentials/i_credential_component.py +++ b/monkey/common/credentials/i_credential_component.py @@ -1,8 +1,11 @@ from abc import ABC, abstractmethod +from pydantic.dataclasses import dataclass + from . import CredentialComponentType +@dataclass class ICredentialComponent(ABC): @property @abstractmethod diff --git a/monkey/common/credentials/lm_hash.py b/monkey/common/credentials/lm_hash.py index 5a04e8bae..df59beeb3 100644 --- a/monkey/common/credentials/lm_hash.py +++ b/monkey/common/credentials/lm_hash.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass, field +from dataclasses import field +from typing import ClassVar from marshmallow import fields +from pydantic.dataclasses import dataclass from . import CredentialComponentType, ICredentialComponent from .credential_component_schema import CredentialComponentSchema, CredentialTypeField @@ -12,9 +14,9 @@ class LMHashSchema(CredentialComponentSchema): lm_hash = fields.Str(validate=ntlm_hash_validator) -@dataclass(frozen=True) +@dataclass class LMHash(ICredentialComponent): - credential_type: CredentialComponentType = field( + credential_type: ClassVar[CredentialComponentType] = field( default=CredentialComponentType.LM_HASH, init=False ) lm_hash: str diff --git a/monkey/common/credentials/nt_hash.py b/monkey/common/credentials/nt_hash.py index a7145a5a0..e7b84449a 100644 --- a/monkey/common/credentials/nt_hash.py +++ b/monkey/common/credentials/nt_hash.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass, field +from dataclasses import field +from typing import ClassVar from marshmallow import fields +from pydantic.dataclasses import dataclass from . import CredentialComponentType, ICredentialComponent from .credential_component_schema import CredentialComponentSchema, CredentialTypeField @@ -12,9 +14,9 @@ class NTHashSchema(CredentialComponentSchema): nt_hash = fields.Str(validate=ntlm_hash_validator) -@dataclass(frozen=True) +@dataclass class NTHash(ICredentialComponent): - credential_type: CredentialComponentType = field( + credential_type: ClassVar[CredentialComponentType] = field( default=CredentialComponentType.NT_HASH, init=False ) nt_hash: str diff --git a/monkey/common/credentials/password.py b/monkey/common/credentials/password.py index b7bd1b84c..8fac2a37f 100644 --- a/monkey/common/credentials/password.py +++ b/monkey/common/credentials/password.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass, field +from dataclasses import field +from typing import ClassVar from marshmallow import fields +from pydantic.dataclasses import dataclass from . import CredentialComponentType, ICredentialComponent from .credential_component_schema import CredentialComponentSchema, CredentialTypeField @@ -11,9 +13,9 @@ class PasswordSchema(CredentialComponentSchema): password = fields.Str() -@dataclass(frozen=True) +@dataclass class Password(ICredentialComponent): - credential_type: CredentialComponentType = field( + credential_type: ClassVar[CredentialComponentType] = field( default=CredentialComponentType.PASSWORD, init=False ) password: str diff --git a/monkey/common/credentials/ssh_keypair.py b/monkey/common/credentials/ssh_keypair.py index 6b8dcded2..3183adbd4 100644 --- a/monkey/common/credentials/ssh_keypair.py +++ b/monkey/common/credentials/ssh_keypair.py @@ -1,22 +1,26 @@ -from dataclasses import dataclass, field +from dataclasses import field +from typing import ClassVar from marshmallow import fields +from pydantic.dataclasses import dataclass from . import CredentialComponentType, ICredentialComponent from .credential_component_schema import CredentialComponentSchema, CredentialTypeField class SSHKeypairSchema(CredentialComponentSchema): - credential_type = CredentialTypeField(CredentialComponentType.SSH_KEYPAIR) + credential_type: ClassVar[CredentialComponentType] = CredentialTypeField( + CredentialComponentType.SSH_KEYPAIR + ) # TODO: Find a list of valid formats for ssh keys and add validators. # See https://github.com/nemchik/ssh-key-regex private_key = fields.Str() public_key = fields.Str() -@dataclass(frozen=True) +@dataclass class SSHKeypair(ICredentialComponent): - credential_type: CredentialComponentType = field( + credential_type: ClassVar[CredentialComponentType] = field( default=CredentialComponentType.SSH_KEYPAIR, init=False ) private_key: str diff --git a/monkey/common/credentials/username.py b/monkey/common/credentials/username.py index 86fde05ff..115b205a6 100644 --- a/monkey/common/credentials/username.py +++ b/monkey/common/credentials/username.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass, field +from dataclasses import field +from typing import ClassVar from marshmallow import fields +from pydantic.dataclasses import dataclass from . import CredentialComponentType, ICredentialComponent from .credential_component_schema import CredentialComponentSchema, CredentialTypeField @@ -11,9 +13,9 @@ class UsernameSchema(CredentialComponentSchema): username = fields.Str() -@dataclass(frozen=True) +@dataclass class Username(ICredentialComponent): - credential_type: CredentialComponentType = field( + credential_type: ClassVar[CredentialComponentType] = field( default=CredentialComponentType.USERNAME, init=False ) username: str diff --git a/monkey/tests/unit_tests/common/credentials/test_credentials.py b/monkey/tests/unit_tests/common/credentials/test_credentials.py index 4eeb8647f..d4e14fbad 100644 --- a/monkey/tests/unit_tests/common/credentials/test_credentials.py +++ b/monkey/tests/unit_tests/common/credentials/test_credentials.py @@ -45,7 +45,7 @@ SECRET_DICTS = [ ] CREDENTIALS = [ - Credentials(identity, secret) + Credentials(identity=identity, secret=secret) for identity, secret in product(IDENTITIES, SECRETS) if not (identity is None and secret is None) ] diff --git a/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_propagation_credentials_repository.py b/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_propagation_credentials_repository.py index c992aaf85..074095cb6 100644 --- a/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_propagation_credentials_repository.py +++ b/monkey/tests/unit_tests/infection_monkey/credential_store/test_aggregating_propagation_credentials_repository.py @@ -60,8 +60,8 @@ STOLEN_CREDENTIALS = [ STOLEN_SSH_KEYS_CREDENTIALS = [ Credentials( - Username(USERNAME), - SSHKeypair(public_key=STOLEN_PUBLIC_KEY_2, private_key=STOLEN_PRIVATE_KEY_2), + identity=Username(USERNAME), + secret=SSHKeypair(public_key=STOLEN_PUBLIC_KEY_2, private_key=STOLEN_PRIVATE_KEY_2), ) ] diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/reporting/test_format_credentials.py b/monkey/tests/unit_tests/monkey_island/cc/services/reporting/test_format_credentials.py index bb51b89dd..079cbd4e3 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/reporting/test_format_credentials.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/reporting/test_format_credentials.py @@ -24,12 +24,12 @@ identities = (fake_username,) secrets = (fake_nt_hash, fake_lm_hash, fake_password, fake_ssh_key) fake_credentials = [ - Credentials(fake_username, fake_nt_hash), - Credentials(fake_username, fake_lm_hash), - Credentials(fake_username, fake_password), - Credentials(fake_username, fake_ssh_key), - Credentials(None, fake_ssh_key), - Credentials(fake_username, None), + Credentials(identity=fake_username, secret=fake_nt_hash), + Credentials(identity=fake_username, secret=fake_lm_hash), + Credentials(identity=fake_username, secret=fake_password), + Credentials(identity=fake_username, secret=fake_ssh_key), + Credentials(identity=None, secret=fake_ssh_key), + Credentials(identity=fake_username, secret=None), ]