Island: Implement MongoAgentRepository.reset()

This commit is contained in:
Mike Salvatore 2022-09-01 12:24:37 -04:00
parent d854eb7576
commit 693ce9e486
2 changed files with 34 additions and 3 deletions

View File

@ -5,6 +5,7 @@ from pymongo import MongoClient
from monkey_island.cc.models import Agent, AgentID
from monkey_island.cc.repository import (
IAgentRepository,
RemovalError,
RetrievalError,
StorageError,
UnknownRecordError,
@ -61,4 +62,7 @@ class MongoAgentRepository(IAgentRepository):
return Agent(**mongo_record)
def reset(self):
pass
try:
self._agents_collection.drop()
except Exception as err:
raise RemovalError(f"Error resetting the repository: {err}")

View File

@ -10,6 +10,7 @@ from monkey_island.cc.models import Agent
from monkey_island.cc.repository import (
IAgentRepository,
MongoAgentRepository,
RemovalError,
RetrievalError,
StorageError,
UnknownRecordError,
@ -59,9 +60,9 @@ def error_raising_mock_mongo_client() -> mongomock.MongoClient:
mongo_client.monkey_island = MagicMock(spec=mongomock.Database)
mongo_client.monkey_island.agents = MagicMock(spec=mongomock.Collection)
# The first call to find() must succeed
mongo_client.monkey_island.agents.find_one = MagicMock(side_effect=Exception("some exception"))
mongo_client.monkey_island.agents.drop = MagicMock(side_effect=Exception("some exception"))
mongo_client.monkey_island.agents.find = MagicMock(side_effect=Exception("some exception"))
mongo_client.monkey_island.agents.find_one = MagicMock(side_effect=Exception("some exception"))
mongo_client.monkey_island.agents.replace_one = MagicMock(
side_effect=Exception("some exception")
)
@ -139,3 +140,29 @@ def test_get_running_agents(agent_repository):
def test_get_running_agents__retrieval_error(error_raising_agent_repository):
with pytest.raises(RetrievalError):
error_raising_agent_repository.get_running_agents()
def test_reset(agent_repository):
# Ensure the repository is not empty
for agent in AGENTS:
preexisting_agent = agent_repository.get_agent_by_id(agent.id)
assert isinstance(preexisting_agent, Agent)
agent_repository.reset()
for agent in AGENTS:
with pytest.raises(UnknownRecordError):
agent_repository.get_agent_by_id(agent.id)
def test_usable_after_reset(agent_repository):
agent_repository.reset()
agent_repository.upsert_agent(AGENTS[0])
assert agent_repository.get_agent_by_id(VICTIM_ZERO_ID) == AGENTS[0]
def test_reset__removal_error(error_raising_agent_repository):
with pytest.raises(RemovalError):
error_raising_agent_repository.reset()