diff --git a/monkey/common/events/abstract_agent_event.py b/monkey/common/events/abstract_agent_event.py index 89f515187..760061aa1 100644 --- a/monkey/common/events/abstract_agent_event.py +++ b/monkey/common/events/abstract_agent_event.py @@ -1,13 +1,15 @@ import time from abc import ABC -from dataclasses import dataclass, field from ipaddress import IPv4Address from typing import FrozenSet, Union -from uuid import UUID, getnode + +from pydantic import Field + +from common.base_models import InfectionMonkeyBaseModel +from common.types import AgentID -@dataclass(frozen=True) -class AbstractAgentEvent(ABC): +class AbstractAgentEvent(InfectionMonkeyBaseModel, ABC): """ An event that was initiated or observed by an agent @@ -22,7 +24,7 @@ class AbstractAgentEvent(ABC): :param tags: The set of tags associated with the event """ - source: UUID = field(default_factory=getnode) - target: Union[UUID, IPv4Address, None] = field(default=None) - timestamp: float = field(default_factory=time.time) - tags: FrozenSet[str] = field(default_factory=frozenset) + source: AgentID + target: Union[int, IPv4Address, None] = Field(default=None) + timestamp: float = Field(default_factory=time.time) + tags: FrozenSet[str] = Field(default_factory=frozenset) diff --git a/monkey/common/events/credentials_stolen_events.py b/monkey/common/events/credentials_stolen_events.py index 184ee81c3..06d2f967e 100644 --- a/monkey/common/events/credentials_stolen_events.py +++ b/monkey/common/events/credentials_stolen_events.py @@ -1,12 +1,12 @@ -from dataclasses import dataclass, field from typing import Sequence +from pydantic import Field + from common.credentials import Credentials from . import AbstractAgentEvent -@dataclass(frozen=True) class CredentialsStolenEvent(AbstractAgentEvent): """ An event that occurs when an agent collects credentials from the victim @@ -15,4 +15,4 @@ class CredentialsStolenEvent(AbstractAgentEvent): :param stolen_credentials: The credentials that were stolen by an agent """ - stolen_credentials: Sequence[Credentials] = field(default_factory=list) + stolen_credentials: Sequence[Credentials] = Field(default_factory=list) diff --git a/monkey/common/types.py b/monkey/common/types.py index d1808d762..56dee91a5 100644 --- a/monkey/common/types.py +++ b/monkey/common/types.py @@ -1,4 +1,7 @@ +from uuid import UUID + from pydantic import PositiveInt from typing_extensions import TypeAlias +AgentID: TypeAlias = UUID HardwareID: TypeAlias = PositiveInt diff --git a/monkey/infection_monkey/credential_collectors/mimikatz_collector/mimikatz_credential_collector.py b/monkey/infection_monkey/credential_collectors/mimikatz_collector/mimikatz_credential_collector.py index 10bccade3..6a9995692 100644 --- a/monkey/infection_monkey/credential_collectors/mimikatz_collector/mimikatz_credential_collector.py +++ b/monkey/infection_monkey/credential_collectors/mimikatz_collector/mimikatz_credential_collector.py @@ -6,6 +6,7 @@ from common.event_queue import IAgentEventQueue from common.events import CredentialsStolenEvent from infection_monkey.i_puppet import ICredentialCollector from infection_monkey.model import USERNAME_PREFIX +from infection_monkey.utils.ids import get_agent_id from . import pypykatz_handler from .windows_credentials import WindowsCredentials @@ -76,6 +77,7 @@ class MimikatzCredentialCollector(ICredentialCollector): def _publish_credentials_stolen_event(self, collected_credentials: Sequence[Credentials]): credentials_stolen_event = CredentialsStolenEvent( + source=get_agent_id(), tags=MIMIKATZ_EVENT_TAGS, stolen_credentials=collected_credentials, ) diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py index 8c1129455..33d2c67d8 100644 --- a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py +++ b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py @@ -11,6 +11,7 @@ from infection_monkey.telemetry.attack.t1005_telem import T1005Telem from infection_monkey.telemetry.attack.t1145_telem import T1145Telem from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.utils.environment import is_windows_os +from infection_monkey.utils.ids import get_agent_id logger = logging.getLogger(__name__) @@ -172,6 +173,7 @@ def _publish_credentials_stolen_event( collected_credentials: Credentials, event_queue: IAgentEventQueue ): credentials_stolen_event = CredentialsStolenEvent( + source=get_agent_id(), tags=SSH_COLLECTOR_EVENT_TAGS, stolen_credentials=[collected_credentials], ) diff --git a/monkey/monkey_island/cc/models/__init__.py b/monkey/monkey_island/cc/models/__init__.py index 29c38674a..30cb5ad95 100644 --- a/monkey/monkey_island/cc/models/__init__.py +++ b/monkey/monkey_island/cc/models/__init__.py @@ -12,4 +12,5 @@ from .user_credentials import UserCredentials from .machine import Machine, MachineID from .communication_type import CommunicationType from .node import Node -from .agent import Agent, AgentID +from common.types import AgentID +from .agent import Agent diff --git a/monkey/monkey_island/cc/models/agent.py b/monkey/monkey_island/cc/models/agent.py index 66bca7b54..c5ae10d22 100644 --- a/monkey/monkey_island/cc/models/agent.py +++ b/monkey/monkey_island/cc/models/agent.py @@ -1,15 +1,11 @@ from datetime import datetime from typing import Optional -from uuid import UUID from pydantic import Field -from typing_extensions import TypeAlias from common.base_models import MutableInfectionMonkeyBaseModel -from . import MachineID - -AgentID: TypeAlias = UUID +from . import AgentID, MachineID class Agent(MutableInfectionMonkeyBaseModel): diff --git a/monkey/tests/unit_tests/common/event_queue/test_pypubsub_agent_event_queue.py b/monkey/tests/unit_tests/common/event_queue/test_pypubsub_agent_event_queue.py index 3ee7ba482..891fa445f 100644 --- a/monkey/tests/unit_tests/common/event_queue/test_pypubsub_agent_event_queue.py +++ b/monkey/tests/unit_tests/common/event_queue/test_pypubsub_agent_event_queue.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from ipaddress import IPv4Address from typing import Callable, FrozenSet, Union from uuid import UUID @@ -13,7 +12,6 @@ EVENT_TAG_1 = "event tag 1" EVENT_TAG_2 = "event tag 2" -@dataclass(frozen=True) class TestEvent1(AbstractAgentEvent): __test__ = False source: UUID = UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5") @@ -22,7 +20,6 @@ class TestEvent1(AbstractAgentEvent): tags: FrozenSet = frozenset() -@dataclass(frozen=True) class TestEvent2(AbstractAgentEvent): __test__ = False source: UUID = UUID("e810ad01-6b67-9446-fc58-9b8d717653f7") diff --git a/monkey/tests/unit_tests/common/event_serializers/test_event_serializer_registry.py b/monkey/tests/unit_tests/common/event_serializers/test_event_serializer_registry.py index b4ddd6612..f83c54ba4 100644 --- a/monkey/tests/unit_tests/common/event_serializers/test_event_serializer_registry.py +++ b/monkey/tests/unit_tests/common/event_serializers/test_event_serializer_registry.py @@ -1,25 +1,22 @@ -from dataclasses import dataclass, field from unittest.mock import MagicMock import pytest +from pydantic import Field from common.event_serializers import EventSerializerRegistry, IEventSerializer from common.events import AbstractAgentEvent -@dataclass(frozen=True) class SomeEvent(AbstractAgentEvent): - some_param: int = field(default=435) + some_param: int = Field(default=435) -@dataclass(frozen=True) class OtherEvent(AbstractAgentEvent): - other_param: float = field(default=123.456) + other_param: float = Field(default=123.456) -@dataclass(frozen=True) class NoneEvent(AbstractAgentEvent): - none_param: float = field(default=1.0) + none_param: float = Field(default=1.0) SOME_SERIALIZER = MagicMock(spec=IEventSerializer) diff --git a/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py b/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py index e3cb7a561..94bf3fc14 100644 --- a/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py +++ b/monkey/tests/unit_tests/common/event_serializers/test_pydantic_event_serializer.py @@ -1,12 +1,15 @@ from abc import ABC -from dataclasses import dataclass, field +from dataclasses import dataclass +from uuid import UUID import pytest +from pydantic import Field -from common.base_models import InfectionMonkeyBaseModel from common.event_serializers import IEventSerializer, PydanticEventSerializer from common.events import AbstractAgentEvent +AGENT_ID = UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5") + @dataclass(frozen=True) class NotAgentEvent(ABC): @@ -14,12 +17,11 @@ class NotAgentEvent(ABC): other_field: float -@dataclass(frozen=True) class SomeAgentEvent(AbstractAgentEvent): - bogus: int = field(default_factory=int) + bogus: int = Field(default_factory=int) -class PydanticEvent(InfectionMonkeyBaseModel): +class PydanticEvent(AbstractAgentEvent): some_field: str @@ -28,7 +30,10 @@ def pydantic_event_serializer() -> IEventSerializer: return PydanticEventSerializer(PydanticEvent) -@pytest.mark.parametrize("event", [NotAgentEvent(1, 2.0), SomeAgentEvent(2)]) +@pytest.mark.parametrize( + "event", + [NotAgentEvent(some_field=1, other_field=2.0), SomeAgentEvent(source=AGENT_ID, bogus=2)], +) def test_pydantic_event_serializer__serialize_wrong_type(pydantic_event_serializer, event): with pytest.raises(TypeError): pydantic_event_serializer.serialize(event) @@ -40,7 +45,7 @@ def test_pydantic_event_serializer__deserialize_wrong_type(pydantic_event_serial def test_pydanitc_event_serializer__de_serialize(pydantic_event_serializer): - pydantic_event = PydanticEvent(some_field="some_field") + pydantic_event = PydanticEvent(source=AGENT_ID, some_field="some_field") serialized_event = pydantic_event_serializer.serialize(pydantic_event) deserialized_object = pydantic_event_serializer.deserialize(serialized_event)