forked from p15670423/monkey
Island: Implement MongoAgentRepository.upsert_agent()
This commit is contained in:
parent
1745b76122
commit
d854eb7576
|
@ -3,7 +3,12 @@ from typing import Any, MutableMapping, Sequence
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
|
|
||||||
from monkey_island.cc.models import Agent, AgentID
|
from monkey_island.cc.models import Agent, AgentID
|
||||||
from monkey_island.cc.repository import IAgentRepository, RetrievalError, UnknownRecordError
|
from monkey_island.cc.repository import (
|
||||||
|
IAgentRepository,
|
||||||
|
RetrievalError,
|
||||||
|
StorageError,
|
||||||
|
UnknownRecordError,
|
||||||
|
)
|
||||||
|
|
||||||
from .consts import MONGO_OBJECT_ID_KEY
|
from .consts import MONGO_OBJECT_ID_KEY
|
||||||
|
|
||||||
|
@ -13,7 +18,24 @@ class MongoAgentRepository(IAgentRepository):
|
||||||
self._agents_collection = mongo_client.monkey_island.agents
|
self._agents_collection = mongo_client.monkey_island.agents
|
||||||
|
|
||||||
def upsert_agent(self, agent: Agent):
|
def upsert_agent(self, agent: Agent):
|
||||||
pass
|
try:
|
||||||
|
result = self._agents_collection.replace_one(
|
||||||
|
{"id": str(agent.id)}, agent.dict(simplify=True), upsert=True
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
raise StorageError(f'Error updating agent with ID "{agent.id}": {err}')
|
||||||
|
|
||||||
|
if result.matched_count != 0 and result.modified_count != 1:
|
||||||
|
raise StorageError(
|
||||||
|
f'Error updating agent with ID "{agent.id}": Expected to update 1 agent, '
|
||||||
|
f"but {result.modified_count} were updated"
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.matched_count == 0 and result.upserted_id is None:
|
||||||
|
raise StorageError(
|
||||||
|
f'Error inserting agent with ID "{agent.id}": Expected to insert 1 agent, '
|
||||||
|
f"but no agents were inserted"
|
||||||
|
)
|
||||||
|
|
||||||
def get_agent_by_id(self, agent_id: AgentID) -> Agent:
|
def get_agent_by_id(self, agent_id: AgentID) -> Agent:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
@ -10,6 +11,7 @@ from monkey_island.cc.repository import (
|
||||||
IAgentRepository,
|
IAgentRepository,
|
||||||
MongoAgentRepository,
|
MongoAgentRepository,
|
||||||
RetrievalError,
|
RetrievalError,
|
||||||
|
StorageError,
|
||||||
UnknownRecordError,
|
UnknownRecordError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,6 +47,12 @@ def agent_repository() -> IAgentRepository:
|
||||||
return MongoAgentRepository(mongo_client)
|
return MongoAgentRepository(mongo_client)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def empty_agent_repository() -> IAgentRepository:
|
||||||
|
mongo_client = mongomock.MongoClient()
|
||||||
|
return MongoAgentRepository(mongo_client)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def error_raising_mock_mongo_client() -> mongomock.MongoClient:
|
def error_raising_mock_mongo_client() -> mongomock.MongoClient:
|
||||||
mongo_client = MagicMock(spec=mongomock.MongoClient)
|
mongo_client = MagicMock(spec=mongomock.MongoClient)
|
||||||
|
@ -54,6 +62,9 @@ def error_raising_mock_mongo_client() -> mongomock.MongoClient:
|
||||||
# The first call to find() must succeed
|
# The first call to find() must succeed
|
||||||
mongo_client.monkey_island.agents.find_one = MagicMock(side_effect=Exception("some exception"))
|
mongo_client.monkey_island.agents.find_one = MagicMock(side_effect=Exception("some exception"))
|
||||||
mongo_client.monkey_island.agents.find = MagicMock(side_effect=Exception("some exception"))
|
mongo_client.monkey_island.agents.find = MagicMock(side_effect=Exception("some exception"))
|
||||||
|
mongo_client.monkey_island.agents.replace_one = MagicMock(
|
||||||
|
side_effect=Exception("some exception")
|
||||||
|
)
|
||||||
|
|
||||||
return mongo_client
|
return mongo_client
|
||||||
|
|
||||||
|
@ -63,6 +74,45 @@ def error_raising_agent_repository(error_raising_mock_mongo_client) -> IAgentRep
|
||||||
return MongoAgentRepository(error_raising_mock_mongo_client)
|
return MongoAgentRepository(error_raising_mock_mongo_client)
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_agent__insert(agent_repository):
|
||||||
|
new_id = uuid4()
|
||||||
|
new_agent = Agent(
|
||||||
|
id=new_id,
|
||||||
|
machine_id=2,
|
||||||
|
start_time=datetime.fromtimestamp(1661858139),
|
||||||
|
parent_id=VICTIM_ZERO_ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_repository.upsert_agent(new_agent)
|
||||||
|
|
||||||
|
assert agent_repository.get_agent_by_id(new_id) == new_agent
|
||||||
|
|
||||||
|
for agent in AGENTS:
|
||||||
|
assert agent_repository.get_agent_by_id(agent.id) == agent
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_agent__insert_empty_repository(empty_agent_repository):
|
||||||
|
empty_agent_repository.upsert_agent(AGENTS[0])
|
||||||
|
|
||||||
|
assert empty_agent_repository.get_agent_by_id(VICTIM_ZERO_ID) == AGENTS[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_agent__update(agent_repository):
|
||||||
|
agents = deepcopy(AGENTS)
|
||||||
|
agents[0].stop_time = datetime.now()
|
||||||
|
agents[0].cc_server = "127.0.0.1:1984"
|
||||||
|
|
||||||
|
agent_repository.upsert_agent(agents[0])
|
||||||
|
|
||||||
|
for agent in agents:
|
||||||
|
assert agent_repository.get_agent_by_id(agent.id) == agent
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_agent__storage_error(error_raising_agent_repository):
|
||||||
|
with pytest.raises(StorageError):
|
||||||
|
error_raising_agent_repository.upsert_agent(AGENTS[0])
|
||||||
|
|
||||||
|
|
||||||
def test_get_agent_by_id(agent_repository):
|
def test_get_agent_by_id(agent_repository):
|
||||||
for i, expected_agent in enumerate(AGENTS):
|
for i, expected_agent in enumerate(AGENTS):
|
||||||
assert agent_repository.get_agent_by_id(expected_agent.id) == expected_agent
|
assert agent_repository.get_agent_by_id(expected_agent.id) == expected_agent
|
||||||
|
|
Loading…
Reference in New Issue