Merge pull request #2280 from guardicore/2179-credentialsstolenevent-pydantic

2179 credentialsstolenevent pydantic
This commit is contained in:
Mike Salvatore 2022-09-13 14:47:10 -04:00 committed by GitHub
commit 1c742c3f96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 39 additions and 34 deletions

View File

@ -1,13 +1,15 @@
import time import time
from abc import ABC from abc import ABC
from dataclasses import dataclass, field
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import FrozenSet, Union from typing import FrozenSet, Union
from uuid import UUID, getnode
from pydantic import Field
from common.base_models import InfectionMonkeyBaseModel
from common.types import AgentID
@dataclass(frozen=True) class AbstractAgentEvent(InfectionMonkeyBaseModel, ABC):
class AbstractAgentEvent(ABC):
""" """
An event that was initiated or observed by an agent An event that was initiated or observed by an agent
@ -22,7 +24,7 @@ class AbstractAgentEvent(ABC):
:param tags: The set of tags associated with the event :param tags: The set of tags associated with the event
""" """
source: UUID = field(default_factory=getnode) source: AgentID
target: Union[UUID, IPv4Address, None] = field(default=None) target: Union[int, IPv4Address, None] = Field(default=None)
timestamp: float = field(default_factory=time.time) timestamp: float = Field(default_factory=time.time)
tags: FrozenSet[str] = field(default_factory=frozenset) tags: FrozenSet[str] = Field(default_factory=frozenset)

View File

@ -1,12 +1,12 @@
from dataclasses import dataclass, field
from typing import Sequence from typing import Sequence
from pydantic import Field
from common.credentials import Credentials from common.credentials import Credentials
from . import AbstractAgentEvent from . import AbstractAgentEvent
@dataclass(frozen=True)
class CredentialsStolenEvent(AbstractAgentEvent): class CredentialsStolenEvent(AbstractAgentEvent):
""" """
An event that occurs when an agent collects credentials from the victim An event that occurs when an agent collects credentials from the victim
@ -15,4 +15,4 @@ class CredentialsStolenEvent(AbstractAgentEvent):
:param stolen_credentials: The credentials that were stolen by an agent :param stolen_credentials: The credentials that were stolen by an agent
""" """
stolen_credentials: Sequence[Credentials] = field(default_factory=list) stolen_credentials: Sequence[Credentials] = Field(default_factory=list)

View File

@ -1,4 +1,7 @@
from uuid import UUID
from pydantic import PositiveInt from pydantic import PositiveInt
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
AgentID: TypeAlias = UUID
HardwareID: TypeAlias = PositiveInt HardwareID: TypeAlias = PositiveInt

View File

@ -6,6 +6,7 @@ from common.event_queue import IAgentEventQueue
from common.events import CredentialsStolenEvent from common.events import CredentialsStolenEvent
from infection_monkey.i_puppet import ICredentialCollector from infection_monkey.i_puppet import ICredentialCollector
from infection_monkey.model import USERNAME_PREFIX from infection_monkey.model import USERNAME_PREFIX
from infection_monkey.utils.ids import get_agent_id
from . import pypykatz_handler from . import pypykatz_handler
from .windows_credentials import WindowsCredentials from .windows_credentials import WindowsCredentials
@ -76,6 +77,7 @@ class MimikatzCredentialCollector(ICredentialCollector):
def _publish_credentials_stolen_event(self, collected_credentials: Sequence[Credentials]): def _publish_credentials_stolen_event(self, collected_credentials: Sequence[Credentials]):
credentials_stolen_event = CredentialsStolenEvent( credentials_stolen_event = CredentialsStolenEvent(
source=get_agent_id(),
tags=MIMIKATZ_EVENT_TAGS, tags=MIMIKATZ_EVENT_TAGS,
stolen_credentials=collected_credentials, stolen_credentials=collected_credentials,
) )

View File

@ -11,6 +11,7 @@ from infection_monkey.telemetry.attack.t1005_telem import T1005Telem
from infection_monkey.telemetry.attack.t1145_telem import T1145Telem from infection_monkey.telemetry.attack.t1145_telem import T1145Telem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.environment import is_windows_os
from infection_monkey.utils.ids import get_agent_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -172,6 +173,7 @@ def _publish_credentials_stolen_event(
collected_credentials: Credentials, event_queue: IAgentEventQueue collected_credentials: Credentials, event_queue: IAgentEventQueue
): ):
credentials_stolen_event = CredentialsStolenEvent( credentials_stolen_event = CredentialsStolenEvent(
source=get_agent_id(),
tags=SSH_COLLECTOR_EVENT_TAGS, tags=SSH_COLLECTOR_EVENT_TAGS,
stolen_credentials=[collected_credentials], stolen_credentials=[collected_credentials],
) )

View File

@ -12,4 +12,5 @@ from .user_credentials import UserCredentials
from .machine import Machine, MachineID from .machine import Machine, MachineID
from .communication_type import CommunicationType from .communication_type import CommunicationType
from .node import Node from .node import Node
from .agent import Agent, AgentID from common.types import AgentID
from .agent import Agent

View File

@ -1,15 +1,11 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from uuid import UUID
from pydantic import Field from pydantic import Field
from typing_extensions import TypeAlias
from common.base_models import MutableInfectionMonkeyBaseModel from common.base_models import MutableInfectionMonkeyBaseModel
from . import MachineID from . import AgentID, MachineID
AgentID: TypeAlias = UUID
class Agent(MutableInfectionMonkeyBaseModel): class Agent(MutableInfectionMonkeyBaseModel):

View File

@ -1,4 +1,3 @@
from dataclasses import dataclass
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Callable, FrozenSet, Union from typing import Callable, FrozenSet, Union
from uuid import UUID from uuid import UUID
@ -13,7 +12,6 @@ EVENT_TAG_1 = "event tag 1"
EVENT_TAG_2 = "event tag 2" EVENT_TAG_2 = "event tag 2"
@dataclass(frozen=True)
class TestEvent1(AbstractAgentEvent): class TestEvent1(AbstractAgentEvent):
__test__ = False __test__ = False
source: UUID = UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5") source: UUID = UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5")
@ -22,7 +20,6 @@ class TestEvent1(AbstractAgentEvent):
tags: FrozenSet = frozenset() tags: FrozenSet = frozenset()
@dataclass(frozen=True)
class TestEvent2(AbstractAgentEvent): class TestEvent2(AbstractAgentEvent):
__test__ = False __test__ = False
source: UUID = UUID("e810ad01-6b67-9446-fc58-9b8d717653f7") source: UUID = UUID("e810ad01-6b67-9446-fc58-9b8d717653f7")

View File

@ -1,25 +1,22 @@
from dataclasses import dataclass, field
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from pydantic import Field
from common.event_serializers import EventSerializerRegistry, IEventSerializer from common.event_serializers import EventSerializerRegistry, IEventSerializer
from common.events import AbstractAgentEvent from common.events import AbstractAgentEvent
@dataclass(frozen=True)
class SomeEvent(AbstractAgentEvent): class SomeEvent(AbstractAgentEvent):
some_param: int = field(default=435) some_param: int = Field(default=435)
@dataclass(frozen=True)
class OtherEvent(AbstractAgentEvent): class OtherEvent(AbstractAgentEvent):
other_param: float = field(default=123.456) other_param: float = Field(default=123.456)
@dataclass(frozen=True)
class NoneEvent(AbstractAgentEvent): class NoneEvent(AbstractAgentEvent):
none_param: float = field(default=1.0) none_param: float = Field(default=1.0)
SOME_SERIALIZER = MagicMock(spec=IEventSerializer) SOME_SERIALIZER = MagicMock(spec=IEventSerializer)

View File

@ -1,12 +1,15 @@
from abc import ABC from abc import ABC
from dataclasses import dataclass, field from dataclasses import dataclass
from uuid import UUID
import pytest import pytest
from pydantic import Field
from common.base_models import InfectionMonkeyBaseModel
from common.event_serializers import IEventSerializer, PydanticEventSerializer from common.event_serializers import IEventSerializer, PydanticEventSerializer
from common.events import AbstractAgentEvent from common.events import AbstractAgentEvent
AGENT_ID = UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5")
@dataclass(frozen=True) @dataclass(frozen=True)
class NotAgentEvent(ABC): class NotAgentEvent(ABC):
@ -14,12 +17,11 @@ class NotAgentEvent(ABC):
other_field: float other_field: float
@dataclass(frozen=True)
class SomeAgentEvent(AbstractAgentEvent): class SomeAgentEvent(AbstractAgentEvent):
bogus: int = field(default_factory=int) bogus: int = Field(default_factory=int)
class PydanticEvent(InfectionMonkeyBaseModel): class PydanticEvent(AbstractAgentEvent):
some_field: str some_field: str
@ -28,7 +30,10 @@ def pydantic_event_serializer() -> IEventSerializer:
return PydanticEventSerializer(PydanticEvent) return PydanticEventSerializer(PydanticEvent)
@pytest.mark.parametrize("event", [NotAgentEvent(1, 2.0), SomeAgentEvent(2)]) @pytest.mark.parametrize(
"event",
[NotAgentEvent(some_field=1, other_field=2.0), SomeAgentEvent(source=AGENT_ID, bogus=2)],
)
def test_pydantic_event_serializer__serialize_wrong_type(pydantic_event_serializer, event): def test_pydantic_event_serializer__serialize_wrong_type(pydantic_event_serializer, event):
with pytest.raises(TypeError): with pytest.raises(TypeError):
pydantic_event_serializer.serialize(event) pydantic_event_serializer.serialize(event)
@ -40,7 +45,7 @@ def test_pydantic_event_serializer__deserialize_wrong_type(pydantic_event_serial
def test_pydanitc_event_serializer__de_serialize(pydantic_event_serializer): def test_pydanitc_event_serializer__de_serialize(pydantic_event_serializer):
pydantic_event = PydanticEvent(some_field="some_field") pydantic_event = PydanticEvent(source=AGENT_ID, some_field="some_field")
serialized_event = pydantic_event_serializer.serialize(pydantic_event) serialized_event = pydantic_event_serializer.serialize(pydantic_event)
deserialized_object = pydantic_event_serializer.deserialize(serialized_event) deserialized_object = pydantic_event_serializer.deserialize(serialized_event)