diff --git a/monkey/monkey_island/cc/repository/mongo_agent_repository.py b/monkey/monkey_island/cc/repository/mongo_agent_repository.py index 64b6ebba9..3b1b60f0a 100644 --- a/monkey/monkey_island/cc/repository/mongo_agent_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_agent_repository.py @@ -5,6 +5,7 @@ from pymongo import MongoClient from monkey_island.cc.models import Agent, AgentID from monkey_island.cc.repository import ( IAgentRepository, + RemovalError, RetrievalError, StorageError, UnknownRecordError, @@ -61,4 +62,7 @@ class MongoAgentRepository(IAgentRepository): return Agent(**mongo_record) def reset(self): - pass + try: + self._agents_collection.drop() + except Exception as err: + raise RemovalError(f"Error resetting the repository: {err}") diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py index 18748b349..4d3fa0d14 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py @@ -10,6 +10,7 @@ from monkey_island.cc.models import Agent from monkey_island.cc.repository import ( IAgentRepository, MongoAgentRepository, + RemovalError, RetrievalError, StorageError, UnknownRecordError, @@ -59,9 +60,9 @@ def error_raising_mock_mongo_client() -> mongomock.MongoClient: mongo_client.monkey_island = MagicMock(spec=mongomock.Database) mongo_client.monkey_island.agents = MagicMock(spec=mongomock.Collection) - # The first call to find() must succeed - mongo_client.monkey_island.agents.find_one = MagicMock(side_effect=Exception("some exception")) + mongo_client.monkey_island.agents.drop = MagicMock(side_effect=Exception("some exception")) mongo_client.monkey_island.agents.find = MagicMock(side_effect=Exception("some exception")) + mongo_client.monkey_island.agents.find_one = MagicMock(side_effect=Exception("some exception")) mongo_client.monkey_island.agents.replace_one = MagicMock( side_effect=Exception("some exception") ) @@ -139,3 +140,29 @@ def test_get_running_agents(agent_repository): def test_get_running_agents__retrieval_error(error_raising_agent_repository): with pytest.raises(RetrievalError): error_raising_agent_repository.get_running_agents() + + +def test_reset(agent_repository): + # Ensure the repository is not empty + for agent in AGENTS: + preexisting_agent = agent_repository.get_agent_by_id(agent.id) + assert isinstance(preexisting_agent, Agent) + + agent_repository.reset() + + for agent in AGENTS: + with pytest.raises(UnknownRecordError): + agent_repository.get_agent_by_id(agent.id) + + +def test_usable_after_reset(agent_repository): + agent_repository.reset() + + agent_repository.upsert_agent(AGENTS[0]) + + assert agent_repository.get_agent_by_id(VICTIM_ZERO_ID) == AGENTS[0] + + +def test_reset__removal_error(error_raising_agent_repository): + with pytest.raises(RemovalError): + error_raising_agent_repository.reset()