diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index 713ac3aac..28b6d7533 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -32,7 +32,8 @@ class ControlChannel(IControlChannel): id=get_agent_id(), machine_hardware_id=get_machine_id(), start_time=agent_process.get_start_time(), - parent_id=parent, + # parent_id=parent, + parent_id=None, # None for now, until we change GUID to UUID cc_server=self._control_channel_server, network_interfaces=get_network_interfaces(), ) diff --git a/monkey/monkey_island/cc/models/__init__.py b/monkey/monkey_island/cc/models/__init__.py index 94e063a81..521e7f720 100644 --- a/monkey/monkey_island/cc/models/__init__.py +++ b/monkey/monkey_island/cc/models/__init__.py @@ -12,4 +12,4 @@ from .user_credentials import UserCredentials from .machine import Machine, MachineID from .communication_type import CommunicationType from .node import Node -from .agent import Agent +from .agent import Agent, AgentID diff --git a/monkey/monkey_island/cc/models/agent.py b/monkey/monkey_island/cc/models/agent.py index 1740e264d..66bca7b54 100644 --- a/monkey/monkey_island/cc/models/agent.py +++ b/monkey/monkey_island/cc/models/agent.py @@ -3,17 +3,35 @@ from typing import Optional from uuid import UUID from pydantic import Field +from typing_extensions import TypeAlias from common.base_models import MutableInfectionMonkeyBaseModel from . import MachineID +AgentID: TypeAlias = UUID + class Agent(MutableInfectionMonkeyBaseModel): - id: UUID = Field(..., allow_mutation=False) + """Represents an agent that has run on a victim machine""" + + id: AgentID = Field(..., allow_mutation=False) + """Uniquely identifies an instance of an agent""" + machine_id: MachineID = Field(..., allow_mutation=False) + """The machine that the agent ran on""" + start_time: datetime = Field(..., allow_mutation=False) + """The time the agent process started""" + stop_time: Optional[datetime] - parent_id: UUID = Field(..., allow_mutation=False) + """The time the agent process exited""" + + parent_id: Optional[AgentID] = Field(allow_mutation=False) + """The ID of the parent agent that spawned this agent""" + cc_server: str = Field(default="") + """The address that the agent used to communicate with the island""" + log_contents: str = Field(default="") + """The contents of the agent's log (empty until the agent shuts down)""" diff --git a/monkey/monkey_island/cc/repository/__init__.py b/monkey/monkey_island/cc/repository/__init__.py index caef77b9b..e1d5fd47b 100644 --- a/monkey/monkey_island/cc/repository/__init__.py +++ b/monkey/monkey_island/cc/repository/__init__.py @@ -8,6 +8,7 @@ from .i_simulation_repository import ISimulationRepository from .i_credentials_repository import ICredentialsRepository from .i_user_repository import IUserRepository from .i_machine_repository import IMachineRepository +from .i_agent_repository import IAgentRepository from .local_storage_file_repository import LocalStorageFileRepository @@ -21,3 +22,4 @@ from .file_simulation_repository import FileSimulationRepository from .json_file_user_repository import JSONFileUserRepository from .mongo_credentials_repository import MongoCredentialsRepository from .mongo_machine_repository import MongoMachineRepository +from .mongo_agent_repository import MongoAgentRepository diff --git a/monkey/monkey_island/cc/repository/i_agent_repository.py b/monkey/monkey_island/cc/repository/i_agent_repository.py index 5a784e5d2..880b4d7e7 100644 --- a/monkey/monkey_island/cc/repository/i_agent_repository.py +++ b/monkey/monkey_island/cc/repository/i_agent_repository.py @@ -1,15 +1,47 @@ -from abc import ABC -from typing import Optional, Sequence +from abc import ABC, abstractmethod +from typing import Sequence -from monkey_island.cc.models import Monkey +from monkey_island.cc.models import Agent, AgentID class IAgentRepository(ABC): - # TODO rename Monkey document to Agent - def save_agent(self, agent: Monkey): - pass + """A repository used to store and retrieve `Agent` objects""" - def get_agents( - self, id: Optional[str] = None, running: Optional[bool] = None - ) -> Sequence[Monkey]: - pass + @abstractmethod + def upsert_agent(self, agent: Agent): + """ + Upsert (insert or update) an `Agent` + + Insert the `Agent` if no `Agent` with a matching ID exists in the repository. If the agent + already exists, update it. + + :param agent: The `agent` to be inserted or updated + :raises StorageError: If an error occurred while attempting to store the `Agent` + """ + + @abstractmethod + def get_agent_by_id(self, agent_id: AgentID) -> Agent: + """ + Get an `Agent` by ID + + :param agent_id: The ID of the `Agent` to be retrieved + :return: An `Agent` with a matching `id` + :raises UnknownRecordError: If an `Agent` with the specified `id` does not exist in the + repository + :raises RetrievalError: If an error occurred while attempting to retrieve the `Agent` + """ + + @abstractmethod + def get_running_agents(self) -> Sequence[Agent]: + """ + Get all `Agents` that are currently running + + :return: All `Agents` that are currently running + :raises RetrievalError: If an error occurred while attempting to retrieve the `Agents` + """ + + @abstractmethod + def reset(self): + """ + Removes all data from the repository + """ diff --git a/monkey/monkey_island/cc/repository/mongo_agent_repository.py b/monkey/monkey_island/cc/repository/mongo_agent_repository.py new file mode 100644 index 000000000..3b1b60f0a --- /dev/null +++ b/monkey/monkey_island/cc/repository/mongo_agent_repository.py @@ -0,0 +1,68 @@ +from typing import Any, MutableMapping, Sequence + +from pymongo import MongoClient + +from monkey_island.cc.models import Agent, AgentID +from monkey_island.cc.repository import ( + IAgentRepository, + RemovalError, + RetrievalError, + StorageError, + UnknownRecordError, +) + +from .consts import MONGO_OBJECT_ID_KEY + + +class MongoAgentRepository(IAgentRepository): + def __init__(self, mongo_client: MongoClient): + self._agents_collection = mongo_client.monkey_island.agents + + def upsert_agent(self, agent: Agent): + try: + result = self._agents_collection.replace_one( + {"id": str(agent.id)}, agent.dict(simplify=True), upsert=True + ) + except Exception as err: + raise StorageError(f'Error updating agent with ID "{agent.id}": {err}') + + if result.matched_count != 0 and result.modified_count != 1: + raise StorageError( + f'Error updating agent with ID "{agent.id}": Expected to update 1 agent, ' + f"but {result.modified_count} were updated" + ) + + if result.matched_count == 0 and result.upserted_id is None: + raise StorageError( + f'Error inserting agent with ID "{agent.id}": Expected to insert 1 agent, ' + f"but no agents were inserted" + ) + + def get_agent_by_id(self, agent_id: AgentID) -> Agent: + try: + agent_dict = self._agents_collection.find_one({"id": str(agent_id)}) + except Exception as err: + raise RetrievalError(f'Error retrieving agent with "id == {agent_id}": {err}') + + if agent_dict is None: + raise UnknownRecordError(f'Unknown ID "{agent_id}"') + + return MongoAgentRepository._mongo_record_to_agent(agent_dict) + + def get_running_agents(self) -> Sequence[Agent]: + try: + cursor = self._agents_collection.find({"stop_time": None}) + return list(map(MongoAgentRepository._mongo_record_to_agent, cursor)) + except Exception as err: + raise RetrievalError(f"Error retrieving running agents: {err}") + + @staticmethod + def _mongo_record_to_agent(mongo_record: MutableMapping[str, Any]) -> Agent: + del mongo_record[MONGO_OBJECT_ID_KEY] + return Agent(**mongo_record) + + def reset(self): + try: + self._agents_collection.drop() + except Exception as err: + raise RemovalError(f"Error resetting the repository: {err}") diff --git a/monkey/tests/unit_tests/monkey_island/cc/models/test_agent.py b/monkey/tests/unit_tests/monkey_island/cc/models/test_agent.py index 1851a4873..fb528fd51 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/models/test_agent.py +++ b/monkey/tests/unit_tests/monkey_island/cc/models/test_agent.py @@ -31,8 +31,11 @@ def test_constructor__defaults_from_objects(): def test_constructor__defaults_from_simple_dict(): - a = Agent(**AGENT_SIMPLE_DICT) + agent_simple_dict = AGENT_SIMPLE_DICT.copy() + del agent_simple_dict["parent_id"] + a = Agent(**agent_simple_dict) + assert a.parent_id is None assert a.stop_time is None assert a.cc_server == "" assert a.log_contents == "" diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py new file mode 100644 index 000000000..4d3fa0d14 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py @@ -0,0 +1,168 @@ +from copy import deepcopy +from datetime import datetime +from unittest.mock import MagicMock +from uuid import uuid4 + +import mongomock +import pytest + +from monkey_island.cc.models import Agent +from monkey_island.cc.repository import ( + IAgentRepository, + MongoAgentRepository, + RemovalError, + RetrievalError, + StorageError, + UnknownRecordError, +) + +VICTIM_ZERO_ID = uuid4() +RUNNING_AGENTS = ( + Agent(id=VICTIM_ZERO_ID, machine_id=1, start_time=datetime.fromtimestamp(1661856718)), + Agent( + id=uuid4(), + machine_id=2, + start_time=datetime.fromtimestamp(1661856818), + parent_id=VICTIM_ZERO_ID, + ), +) +STOPPED_AGENTS = ( + Agent( + id=uuid4(), + machine_id=3, + start_time=datetime.fromtimestamp(1661856758), + parent_id=VICTIM_ZERO_ID, + stop_time=datetime.fromtimestamp(1661856773), + ), +) +AGENTS = ( + *RUNNING_AGENTS, + *STOPPED_AGENTS, +) + + +@pytest.fixture +def agent_repository() -> IAgentRepository: + mongo_client = mongomock.MongoClient() + mongo_client.monkey_island.agents.insert_many((a.dict(simplify=True) for a in AGENTS)) + return MongoAgentRepository(mongo_client) + + +@pytest.fixture +def empty_agent_repository() -> IAgentRepository: + mongo_client = mongomock.MongoClient() + return MongoAgentRepository(mongo_client) + + +@pytest.fixture +def error_raising_mock_mongo_client() -> mongomock.MongoClient: + mongo_client = MagicMock(spec=mongomock.MongoClient) + mongo_client.monkey_island = MagicMock(spec=mongomock.Database) + mongo_client.monkey_island.agents = MagicMock(spec=mongomock.Collection) + + mongo_client.monkey_island.agents.drop = MagicMock(side_effect=Exception("some exception")) + mongo_client.monkey_island.agents.find = MagicMock(side_effect=Exception("some exception")) + mongo_client.monkey_island.agents.find_one = MagicMock(side_effect=Exception("some exception")) + mongo_client.monkey_island.agents.replace_one = MagicMock( + side_effect=Exception("some exception") + ) + + return mongo_client + + +@pytest.fixture +def error_raising_agent_repository(error_raising_mock_mongo_client) -> IAgentRepository: + return MongoAgentRepository(error_raising_mock_mongo_client) + + +def test_upsert_agent__insert(agent_repository): + new_id = uuid4() + new_agent = Agent( + id=new_id, + machine_id=2, + start_time=datetime.fromtimestamp(1661858139), + parent_id=VICTIM_ZERO_ID, + ) + + agent_repository.upsert_agent(new_agent) + + assert agent_repository.get_agent_by_id(new_id) == new_agent + + for agent in AGENTS: + assert agent_repository.get_agent_by_id(agent.id) == agent + + +def test_upsert_agent__insert_empty_repository(empty_agent_repository): + empty_agent_repository.upsert_agent(AGENTS[0]) + + assert empty_agent_repository.get_agent_by_id(VICTIM_ZERO_ID) == AGENTS[0] + + +def test_upsert_agent__update(agent_repository): + agents = deepcopy(AGENTS) + agents[0].stop_time = datetime.now() + agents[0].cc_server = "127.0.0.1:1984" + + agent_repository.upsert_agent(agents[0]) + + for agent in agents: + assert agent_repository.get_agent_by_id(agent.id) == agent + + +def test_upsert_agent__storage_error(error_raising_agent_repository): + with pytest.raises(StorageError): + error_raising_agent_repository.upsert_agent(AGENTS[0]) + + +def test_get_agent_by_id(agent_repository): + for i, expected_agent in enumerate(AGENTS): + assert agent_repository.get_agent_by_id(expected_agent.id) == expected_agent + + +def test_get_agent_by_id__not_found(agent_repository): + with pytest.raises(UnknownRecordError): + agent_repository.get_agent_by_id(uuid4()) + + +def test_get_agent_by_id__retrieval_error(error_raising_agent_repository): + with pytest.raises(RetrievalError): + error_raising_agent_repository.get_agent_by_id(AGENTS[0].id) + + +def test_get_running_agents(agent_repository): + running_agents = agent_repository.get_running_agents() + + assert len(running_agents) == len(RUNNING_AGENTS) + for a in running_agents: + assert a in RUNNING_AGENTS + + +def test_get_running_agents__retrieval_error(error_raising_agent_repository): + with pytest.raises(RetrievalError): + error_raising_agent_repository.get_running_agents() + + +def test_reset(agent_repository): + # Ensure the repository is not empty + for agent in AGENTS: + preexisting_agent = agent_repository.get_agent_by_id(agent.id) + assert isinstance(preexisting_agent, Agent) + + agent_repository.reset() + + for agent in AGENTS: + with pytest.raises(UnknownRecordError): + agent_repository.get_agent_by_id(agent.id) + + +def test_usable_after_reset(agent_repository): + agent_repository.reset() + + agent_repository.upsert_agent(AGENTS[0]) + + assert agent_repository.get_agent_by_id(VICTIM_ZERO_ID) == AGENTS[0] + + +def test_reset__removal_error(error_raising_agent_repository): + with pytest.raises(RemovalError): + error_raising_agent_repository.reset() diff --git a/vulture_allowlist.py b/vulture_allowlist.py index 68a94a2db..ea46d5b2b 100644 --- a/vulture_allowlist.py +++ b/vulture_allowlist.py @@ -238,8 +238,9 @@ NetworkMap Arc.dst_machine IMitigationsRepository.get_mitigations IMitigationsRepository.save_mitigations -IAgentRepository.save_agent -IAgentRepository.get_agents +IAgentRepository.upsert_agent +IAgentRepository.get_agent_by_id +IAgentRepository.get_running_agents agent IAttackRepository.get_attack_report IAttackRepository.save_attack_report