forked from p15670423/monkey
Island: Implement MongoAgentRepository.get_agent_by_id()
This commit is contained in:
parent
6f285ba80c
commit
ba228e56b6
|
@ -0,0 +1,38 @@
|
||||||
|
from typing import Any, MutableMapping, Sequence
|
||||||
|
|
||||||
|
from pymongo import MongoClient
|
||||||
|
|
||||||
|
from monkey_island.cc.models import Agent, AgentID
|
||||||
|
from monkey_island.cc.repository import IAgentRepository, RetrievalError, UnknownRecordError
|
||||||
|
|
||||||
|
from .consts import MONGO_OBJECT_ID_KEY
|
||||||
|
|
||||||
|
|
||||||
|
class MongoAgentRepository(IAgentRepository):
|
||||||
|
def __init__(self, mongo_client: MongoClient):
|
||||||
|
self._agents_collection = mongo_client.monkey_island.agents
|
||||||
|
|
||||||
|
def upsert_agent(self, agent: Agent):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_agent_by_id(self, agent_id: AgentID) -> Agent:
|
||||||
|
try:
|
||||||
|
agent_dict = self._agents_collection.find_one({"id": str(agent_id)})
|
||||||
|
except Exception as err:
|
||||||
|
raise RetrievalError(f'Error retrieving agent with "id == {agent_id}": {err}')
|
||||||
|
|
||||||
|
if agent_dict is None:
|
||||||
|
raise UnknownRecordError(f'Unknown ID "{agent_id}"')
|
||||||
|
|
||||||
|
return MongoAgentRepository._mongo_record_to_agent(agent_dict)
|
||||||
|
|
||||||
|
def get_running_agents(self) -> Sequence[Agent]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _mongo_record_to_agent(mongo_record: MutableMapping[str, Any]) -> Agent:
|
||||||
|
del mongo_record[MONGO_OBJECT_ID_KEY]
|
||||||
|
return Agent(**mongo_record)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
|
@ -0,0 +1,71 @@
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import mongomock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from monkey_island.cc.models import Agent
|
||||||
|
from monkey_island.cc.repository import (
|
||||||
|
IAgentRepository,
|
||||||
|
MongoAgentRepository,
|
||||||
|
RetrievalError,
|
||||||
|
UnknownRecordError,
|
||||||
|
)
|
||||||
|
|
||||||
|
VICTIM_ZERO_ID = uuid4()
|
||||||
|
AGENTS = (
|
||||||
|
Agent(id=VICTIM_ZERO_ID, machine_id=1, start_time=datetime.fromtimestamp(1661856718)),
|
||||||
|
Agent(
|
||||||
|
id=uuid4(),
|
||||||
|
machine_id=2,
|
||||||
|
start_time=datetime.fromtimestamp(1661856818),
|
||||||
|
parent_id=VICTIM_ZERO_ID,
|
||||||
|
),
|
||||||
|
Agent(
|
||||||
|
id=uuid4(),
|
||||||
|
machine_id=3,
|
||||||
|
start_time=datetime.fromtimestamp(1661856758),
|
||||||
|
parent_id=VICTIM_ZERO_ID,
|
||||||
|
stop_time=datetime.fromtimestamp(1661856773),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_repository() -> IAgentRepository:
|
||||||
|
mongo_client = mongomock.MongoClient()
|
||||||
|
mongo_client.monkey_island.agents.insert_many((a.dict(simplify=True) for a in AGENTS))
|
||||||
|
return MongoAgentRepository(mongo_client)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def error_raising_mock_mongo_client() -> mongomock.MongoClient:
|
||||||
|
mongo_client = MagicMock(spec=mongomock.MongoClient)
|
||||||
|
mongo_client.monkey_island = MagicMock(spec=mongomock.Database)
|
||||||
|
mongo_client.monkey_island.agents = MagicMock(spec=mongomock.Collection)
|
||||||
|
|
||||||
|
# The first call to find() must succeed
|
||||||
|
mongo_client.monkey_island.agents.find_one = MagicMock(side_effect=Exception("some exception"))
|
||||||
|
|
||||||
|
return mongo_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def error_raising_agent_repository(error_raising_mock_mongo_client) -> IAgentRepository:
|
||||||
|
return MongoAgentRepository(error_raising_mock_mongo_client)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_agent_by_id(agent_repository):
|
||||||
|
for i, expected_agent in enumerate(AGENTS):
|
||||||
|
assert agent_repository.get_agent_by_id(expected_agent.id) == expected_agent
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_agent_by_id__not_found(agent_repository):
|
||||||
|
with pytest.raises(UnknownRecordError):
|
||||||
|
agent_repository.get_agent_by_id(uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_agent_by_id__retrieval_error(error_raising_agent_repository):
|
||||||
|
with pytest.raises(RetrievalError):
|
||||||
|
error_raising_agent_repository.get_agent_by_id(AGENTS[0].id)
|
Loading…
Reference in New Issue