diff --git a/monkey/monkey_island/cc/repository/mongo_agent_repository.py b/monkey/monkey_island/cc/repository/mongo_agent_repository.py index 5166e7168..64b6ebba9 100644 --- a/monkey/monkey_island/cc/repository/mongo_agent_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_agent_repository.py @@ -3,7 +3,12 @@ from typing import Any, MutableMapping, Sequence from pymongo import MongoClient from monkey_island.cc.models import Agent, AgentID -from monkey_island.cc.repository import IAgentRepository, RetrievalError, UnknownRecordError +from monkey_island.cc.repository import ( + IAgentRepository, + RetrievalError, + StorageError, + UnknownRecordError, +) from .consts import MONGO_OBJECT_ID_KEY @@ -13,7 +18,24 @@ class MongoAgentRepository(IAgentRepository): self._agents_collection = mongo_client.monkey_island.agents def upsert_agent(self, agent: Agent): - pass + try: + result = self._agents_collection.replace_one( + {"id": str(agent.id)}, agent.dict(simplify=True), upsert=True + ) + except Exception as err: + raise StorageError(f'Error updating agent with ID "{agent.id}": {err}') + + if result.matched_count != 0 and result.modified_count != 1: + raise StorageError( + f'Error updating agent with ID "{agent.id}": Expected to update 1 agent, ' + f"but {result.modified_count} were updated" + ) + + if result.matched_count == 0 and result.upserted_id is None: + raise StorageError( + f'Error inserting agent with ID "{agent.id}": Expected to insert 1 agent, ' + f"but no agents were inserted" + ) def get_agent_by_id(self, agent_id: AgentID) -> Agent: try: diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py index b4e556c77..18748b349 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py @@ -1,3 +1,4 @@ +from copy import deepcopy from datetime import datetime from unittest.mock import MagicMock from uuid import uuid4 @@ -10,6 +11,7 @@ from monkey_island.cc.repository import ( IAgentRepository, MongoAgentRepository, RetrievalError, + StorageError, UnknownRecordError, ) @@ -45,6 +47,12 @@ def agent_repository() -> IAgentRepository: return MongoAgentRepository(mongo_client) +@pytest.fixture +def empty_agent_repository() -> IAgentRepository: + mongo_client = mongomock.MongoClient() + return MongoAgentRepository(mongo_client) + + @pytest.fixture def error_raising_mock_mongo_client() -> mongomock.MongoClient: mongo_client = MagicMock(spec=mongomock.MongoClient) @@ -54,6 +62,9 @@ def error_raising_mock_mongo_client() -> mongomock.MongoClient: # The first call to find() must succeed mongo_client.monkey_island.agents.find_one = MagicMock(side_effect=Exception("some exception")) mongo_client.monkey_island.agents.find = MagicMock(side_effect=Exception("some exception")) + mongo_client.monkey_island.agents.replace_one = MagicMock( + side_effect=Exception("some exception") + ) return mongo_client @@ -63,6 +74,45 @@ def error_raising_agent_repository(error_raising_mock_mongo_client) -> IAgentRep return MongoAgentRepository(error_raising_mock_mongo_client) +def test_upsert_agent__insert(agent_repository): + new_id = uuid4() + new_agent = Agent( + id=new_id, + machine_id=2, + start_time=datetime.fromtimestamp(1661858139), + parent_id=VICTIM_ZERO_ID, + ) + + agent_repository.upsert_agent(new_agent) + + assert agent_repository.get_agent_by_id(new_id) == new_agent + + for agent in AGENTS: + assert agent_repository.get_agent_by_id(agent.id) == agent + + +def test_upsert_agent__insert_empty_repository(empty_agent_repository): + empty_agent_repository.upsert_agent(AGENTS[0]) + + assert empty_agent_repository.get_agent_by_id(VICTIM_ZERO_ID) == AGENTS[0] + + +def test_upsert_agent__update(agent_repository): + agents = deepcopy(AGENTS) + agents[0].stop_time = datetime.now() + agents[0].cc_server = "127.0.0.1:1984" + + agent_repository.upsert_agent(agents[0]) + + for agent in agents: + assert agent_repository.get_agent_by_id(agent.id) == agent + + +def test_upsert_agent__storage_error(error_raising_agent_repository): + with pytest.raises(StorageError): + error_raising_agent_repository.upsert_agent(AGENTS[0]) + + def test_get_agent_by_id(agent_repository): for i, expected_agent in enumerate(AGENTS): assert agent_repository.get_agent_by_id(expected_agent.id) == expected_agent