forked from p15670423/monkey
Island: Add MongoNodeRepository
This commit is contained in:
parent
76b51d25b9
commit
cd6a46a304
|
@ -24,3 +24,4 @@ from .json_file_user_repository import JSONFileUserRepository
|
|||
from .mongo_credentials_repository import MongoCredentialsRepository
|
||||
from .mongo_machine_repository import MongoMachineRepository
|
||||
from .mongo_agent_repository import MongoAgentRepository
|
||||
from .mongo_node_repository import MongoNodeRepository
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
from copy import deepcopy
|
||||
from typing import Any, MutableMapping, Sequence
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
from monkey_island.cc.models import CommunicationType, MachineID, Node
|
||||
|
||||
from . import INodeRepository, RemovalError, RetrievalError, StorageError
|
||||
from .consts import MONGO_OBJECT_ID_KEY
|
||||
|
||||
UPSERT_ERROR_MESSAGE = "An error occurred while attempting to upsert a node"
|
||||
SRC_FIELD_NAME = "machine_id"
|
||||
|
||||
|
||||
class MongoNodeRepository(INodeRepository):
|
||||
def __init__(self, mongo_client: MongoClient):
|
||||
self._nodes_collection = mongo_client.monkey_island.nodes
|
||||
|
||||
def upsert_communication(
|
||||
self, src: MachineID, dst: MachineID, communication_type: CommunicationType
|
||||
):
|
||||
try:
|
||||
node_dict = self._nodes_collection.find_one({SRC_FIELD_NAME: src})
|
||||
except Exception as err:
|
||||
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
||||
|
||||
if node_dict is None:
|
||||
updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))})
|
||||
else:
|
||||
node = MongoNodeRepository._mongo_record_to_node(node_dict)
|
||||
updated_node = MongoNodeRepository._add_connection_to_node(
|
||||
node, dst, communication_type
|
||||
)
|
||||
|
||||
self._upsert_node(updated_node)
|
||||
|
||||
@staticmethod
|
||||
def _mongo_record_to_node(mongo_record: MutableMapping[str, Any]) -> Node:
|
||||
del mongo_record[MONGO_OBJECT_ID_KEY]
|
||||
return Node(**mongo_record)
|
||||
|
||||
@staticmethod
|
||||
def _add_connection_to_node(
|
||||
node: Node, dst: MachineID, communication_type: CommunicationType
|
||||
) -> Node:
|
||||
connections = dict(deepcopy(node.connections))
|
||||
communications = set(connections.get(dst, set()))
|
||||
communications.add(communication_type)
|
||||
connections[dst] = frozenset(communications)
|
||||
|
||||
new_node = node.copy()
|
||||
new_node.connections = connections
|
||||
|
||||
return new_node
|
||||
|
||||
def _upsert_node(self, node: Node):
|
||||
try:
|
||||
result = self._nodes_collection.replace_one(
|
||||
{SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True
|
||||
)
|
||||
except Exception as err:
|
||||
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
||||
|
||||
if result.matched_count != 0 and result.modified_count != 1:
|
||||
raise StorageError(
|
||||
f'Error updating node with source ID "{node.machine_id}": Expected to update 1 '
|
||||
f"node, but {result.modified_count} were updated"
|
||||
)
|
||||
|
||||
if result.matched_count == 0 and result.upserted_id is None:
|
||||
raise StorageError(
|
||||
f'Error inserting node with source ID "{node.machine_id}": Expected to insert 1 '
|
||||
f"node, but no nodes were inserted"
|
||||
)
|
||||
|
||||
def get_nodes(self) -> Sequence[Node]:
|
||||
try:
|
||||
cursor = self._nodes_collection.find()
|
||||
return list(map(MongoNodeRepository._mongo_record_to_node, cursor))
|
||||
except Exception as err:
|
||||
raise RetrievalError(f"Error retrieving nodes from the repository: {err}")
|
||||
|
||||
def reset(self):
|
||||
try:
|
||||
self._nodes_collection.drop()
|
||||
except Exception as err:
|
||||
raise RemovalError(f"Error resetting the repository: {err}")
|
|
@ -0,0 +1,214 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import mongomock
|
||||
import pytest
|
||||
|
||||
from monkey_island.cc.models import CommunicationType, Node
|
||||
from monkey_island.cc.repository import (
|
||||
INodeRepository,
|
||||
MongoNodeRepository,
|
||||
RemovalError,
|
||||
RetrievalError,
|
||||
StorageError,
|
||||
)
|
||||
|
||||
NODES = (
|
||||
Node(
|
||||
machine_id=1,
|
||||
connections={
|
||||
2: frozenset((CommunicationType.SCANNED,)),
|
||||
3: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
|
||||
},
|
||||
),
|
||||
Node(
|
||||
machine_id=2,
|
||||
connections={1: frozenset((CommunicationType.CC,))},
|
||||
),
|
||||
Node(
|
||||
machine_id=3,
|
||||
connections={
|
||||
1: frozenset((CommunicationType.CC,)),
|
||||
4: frozenset((CommunicationType.SCANNED,)),
|
||||
5: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
|
||||
},
|
||||
),
|
||||
Node(
|
||||
machine_id=4,
|
||||
connections={},
|
||||
),
|
||||
Node(
|
||||
machine_id=5,
|
||||
connections={
|
||||
2: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
|
||||
3: frozenset((CommunicationType.CC,)),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_node_repository() -> INodeRepository:
|
||||
return MongoNodeRepository(mongomock.MongoClient())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mongo_client() -> mongomock.MongoClient:
|
||||
client = mongomock.MongoClient()
|
||||
client.monkey_island.nodes.insert_many((n.dict(simplify=True) for n in NODES))
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node_repository(mongo_client) -> INodeRepository:
|
||||
return MongoNodeRepository(mongo_client)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def error_raising_mock_mongo_client() -> mongomock.MongoClient:
|
||||
mongo_client = MagicMock(spec=mongomock.MongoClient)
|
||||
mongo_client.monkey_island = MagicMock(spec=mongomock.Database)
|
||||
mongo_client.monkey_island.nodes = MagicMock(spec=mongomock.Collection)
|
||||
|
||||
mongo_client.monkey_island.nodes.find = MagicMock(side_effect=Exception("some exception"))
|
||||
mongo_client.monkey_island.nodes.find_one = MagicMock(side_effect=Exception("some exception"))
|
||||
mongo_client.monkey_island.nodes.replace_one = MagicMock(
|
||||
side_effect=Exception("some exception")
|
||||
)
|
||||
mongo_client.monkey_island.nodes.drop = MagicMock(side_effect=Exception("some exception"))
|
||||
|
||||
return mongo_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def error_raising_node_repository(error_raising_mock_mongo_client) -> INodeRepository:
|
||||
return MongoNodeRepository(error_raising_mock_mongo_client)
|
||||
|
||||
|
||||
def test_upsert_communication__empty_repository(empty_node_repository):
|
||||
src_machine_id = 1
|
||||
dst_machine_id = 2
|
||||
expected_node = Node(
|
||||
machine_id=src_machine_id,
|
||||
connections={dst_machine_id: frozenset((CommunicationType.SCANNED,))},
|
||||
)
|
||||
|
||||
empty_node_repository.upsert_communication(
|
||||
src_machine_id, dst_machine_id, CommunicationType.SCANNED
|
||||
)
|
||||
nodes = empty_node_repository.get_nodes()
|
||||
|
||||
assert len(nodes) == 1
|
||||
assert nodes[0] == expected_node
|
||||
|
||||
|
||||
def test_upsert_communication__new_node(node_repository):
|
||||
src_machine_id = NODES[-1].machine_id + 100
|
||||
dst_machine_id = 1
|
||||
expected_nodes = NODES + (
|
||||
Node(
|
||||
machine_id=src_machine_id,
|
||||
connections={dst_machine_id: frozenset((CommunicationType.CC,))},
|
||||
),
|
||||
)
|
||||
node_repository.upsert_communication(src_machine_id, dst_machine_id, CommunicationType.CC)
|
||||
nodes = node_repository.get_nodes()
|
||||
|
||||
assert len(nodes) == len(expected_nodes)
|
||||
for en in expected_nodes:
|
||||
assert en in nodes
|
||||
|
||||
|
||||
def test_upsert_communication__update_existing_connection(node_repository):
|
||||
src_machine_id = 1
|
||||
dst_machine_id = 2
|
||||
expected_node = NODES[0].copy(deep=True)
|
||||
expected_node.connections[2] = frozenset(
|
||||
(*expected_node.connections[2], CommunicationType.EXPLOITED)
|
||||
)
|
||||
node_repository.upsert_communication(
|
||||
src_machine_id, dst_machine_id, CommunicationType.EXPLOITED
|
||||
)
|
||||
nodes = node_repository.get_nodes()
|
||||
|
||||
for node in nodes:
|
||||
if node.machine_id == src_machine_id:
|
||||
assert node == expected_node
|
||||
break
|
||||
|
||||
|
||||
def test_upsert_communication__update_existing_node_add_connection(node_repository):
|
||||
src_machine_id = 2
|
||||
dst_machine_id = 5
|
||||
expected_node = NODES[1].copy(deep=True)
|
||||
expected_node.connections[5] = frozenset((CommunicationType.SCANNED,))
|
||||
node_repository.upsert_communication(src_machine_id, dst_machine_id, CommunicationType.SCANNED)
|
||||
nodes = node_repository.get_nodes()
|
||||
|
||||
for node in nodes:
|
||||
if node.machine_id == src_machine_id:
|
||||
assert node == expected_node
|
||||
break
|
||||
|
||||
|
||||
def test_upsert_communication__find_one_fails(error_raising_node_repository):
|
||||
with pytest.raises(StorageError):
|
||||
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
|
||||
|
||||
|
||||
def test_upsert_communication__replace_one_fails(
|
||||
error_raising_mock_mongo_client, error_raising_node_repository
|
||||
):
|
||||
error_raising_mock_mongo_client.monkey_island.nodes.find_one = lambda _: None
|
||||
with pytest.raises(StorageError):
|
||||
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
|
||||
|
||||
|
||||
def test_upsert_communication__replace_one_matched_without_modify(
|
||||
error_raising_mock_mongo_client, error_raising_node_repository
|
||||
):
|
||||
mock_result = MagicMock()
|
||||
mock_result.matched_count = 1
|
||||
mock_result.modified_count = 0
|
||||
error_raising_mock_mongo_client.monkey_island.nodes.find_one = lambda _: None
|
||||
error_raising_mock_mongo_client.monkey_island.nodes.replace_one = lambda *_, **__: mock_result
|
||||
|
||||
with pytest.raises(StorageError):
|
||||
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
|
||||
|
||||
|
||||
def test_upsert_communication__replace_one_insert_fails(
|
||||
error_raising_mock_mongo_client, error_raising_node_repository
|
||||
):
|
||||
mock_result = MagicMock()
|
||||
mock_result.matched_count = 0
|
||||
mock_result.upserted_id = None
|
||||
error_raising_mock_mongo_client.monkey_island.nodes.find_one = lambda _: None
|
||||
error_raising_mock_mongo_client.monkey_island.nodes.replace_one = lambda *_, **__: mock_result
|
||||
|
||||
with pytest.raises(StorageError):
|
||||
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
|
||||
|
||||
|
||||
def test_get_nodes(node_repository):
|
||||
nodes = node_repository.get_nodes()
|
||||
assert len(nodes) == len(nodes)
|
||||
for n in nodes:
|
||||
assert n in NODES
|
||||
|
||||
|
||||
def test_get_nodes__find_fails(error_raising_node_repository):
|
||||
with pytest.raises(RetrievalError):
|
||||
error_raising_node_repository.get_nodes()
|
||||
|
||||
|
||||
def test_reset(node_repository):
|
||||
assert len(node_repository.get_nodes()) > 0
|
||||
|
||||
node_repository.reset()
|
||||
|
||||
assert len(node_repository.get_nodes()) == 0
|
||||
|
||||
|
||||
def test_reset__removal_error(error_raising_node_repository):
|
||||
with pytest.raises(RemovalError):
|
||||
error_raising_node_repository.reset()
|
Loading…
Reference in New Issue