Merge pull request #2283 from guardicore/2255-mongo-node-repository

2255 mongo node repository
This commit is contained in:
Mike Salvatore 2022-09-14 09:35:16 -04:00 committed by GitHub
commit 4bb914316f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 345 additions and 29 deletions

View File

@ -1,4 +1,4 @@
from typing import Mapping, Tuple from typing import FrozenSet, Mapping
from pydantic import Field from pydantic import Field
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
@ -7,7 +7,7 @@ from common.base_models import MutableInfectionMonkeyBaseModel
from . import CommunicationType, MachineID from . import CommunicationType, MachineID
NodeConnections: TypeAlias = Mapping[MachineID, Tuple[CommunicationType, ...]] NodeConnections: TypeAlias = Mapping[MachineID, FrozenSet[CommunicationType]]
class Node(MutableInfectionMonkeyBaseModel): class Node(MutableInfectionMonkeyBaseModel):

View File

@ -24,3 +24,4 @@ from .json_file_user_repository import JSONFileUserRepository
from .mongo_credentials_repository import MongoCredentialsRepository from .mongo_credentials_repository import MongoCredentialsRepository
from .mongo_machine_repository import MongoMachineRepository from .mongo_machine_repository import MongoMachineRepository
from .mongo_agent_repository import MongoAgentRepository from .mongo_agent_repository import MongoAgentRepository
from .mongo_node_repository import MongoNodeRepository

View File

@ -33,3 +33,12 @@ class INodeRepository(ABC):
:return: All known Nodes :return: All known Nodes
:raises RetrievalError: If an error occurred while attempting to retrieve the nodes :raises RetrievalError: If an error occurred while attempting to retrieve the nodes
""" """
@abstractmethod
def reset(self):
"""
Removes all data from the repository
:raises RemovalError: If an error occurred while attempting to remove all `Nodes` from the
repository
"""

View File

@ -1,4 +1,4 @@
from typing import Any, MutableMapping, Sequence from typing import Sequence
from pymongo import MongoClient from pymongo import MongoClient
@ -40,27 +40,24 @@ class MongoAgentRepository(IAgentRepository):
def get_agent_by_id(self, agent_id: AgentID) -> Agent: def get_agent_by_id(self, agent_id: AgentID) -> Agent:
try: try:
agent_dict = self._agents_collection.find_one({"id": str(agent_id)}) agent_dict = self._agents_collection.find_one(
{"id": str(agent_id)}, {MONGO_OBJECT_ID_KEY: False}
)
except Exception as err: except Exception as err:
raise RetrievalError(f'Error retrieving agent with "id == {agent_id}": {err}') raise RetrievalError(f'Error retrieving agent with "id == {agent_id}": {err}')
if agent_dict is None: if agent_dict is None:
raise UnknownRecordError(f'Unknown ID "{agent_id}"') raise UnknownRecordError(f'Unknown ID "{agent_id}"')
return MongoAgentRepository._mongo_record_to_agent(agent_dict) return Agent(**agent_dict)
def get_running_agents(self) -> Sequence[Agent]: def get_running_agents(self) -> Sequence[Agent]:
try: try:
cursor = self._agents_collection.find({"stop_time": None}) cursor = self._agents_collection.find({"stop_time": None}, {MONGO_OBJECT_ID_KEY: False})
return list(map(MongoAgentRepository._mongo_record_to_agent, cursor)) return list(map(lambda a: Agent(**a), cursor))
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving running agents: {err}") raise RetrievalError(f"Error retrieving running agents: {err}")
@staticmethod
def _mongo_record_to_agent(mongo_record: MutableMapping[str, Any]) -> Agent:
del mongo_record[MONGO_OBJECT_ID_KEY]
return Agent(**mongo_record)
def reset(self): def reset(self):
try: try:
self._agents_collection.drop() self._agents_collection.drop()

View File

@ -55,9 +55,8 @@ class MongoCredentialsRepository(ICredentialsRepository):
def _get_credentials_from_collection(self, collection) -> Sequence[Credentials]: def _get_credentials_from_collection(self, collection) -> Sequence[Credentials]:
try: try:
collection_result = [] collection_result = []
list_collection_result = list(collection.find({})) list_collection_result = list(collection.find({}, {MONGO_OBJECT_ID_KEY: False}))
for encrypted_credentials in list_collection_result: for encrypted_credentials in list_collection_result:
del encrypted_credentials[MONGO_OBJECT_ID_KEY]
plaintext_credentials = self._decrypt_credentials_mapping(encrypted_credentials) plaintext_credentials = self._decrypt_credentials_mapping(encrypted_credentials)
collection_result.append(Credentials(**plaintext_credentials)) collection_result.append(Credentials(**plaintext_credentials))

View File

@ -1,6 +1,6 @@
from ipaddress import IPv4Address from ipaddress import IPv4Address
from threading import Lock from threading import Lock
from typing import Any, MutableMapping, Sequence from typing import Any, Sequence
from pymongo import MongoClient from pymongo import MongoClient
@ -58,36 +58,33 @@ class MongoMachineRepository(IMachineRepository):
def _find_one(self, key: str, search_value: Any) -> Machine: def _find_one(self, key: str, search_value: Any) -> Machine:
try: try:
machine_dict = self._machines_collection.find_one({key: search_value}) machine_dict = self._machines_collection.find_one(
{key: search_value}, {MONGO_OBJECT_ID_KEY: False}
)
except Exception as err: except Exception as err:
raise RetrievalError(f'Error retrieving machine with "{key} == {search_value}": {err}') raise RetrievalError(f'Error retrieving machine with "{key} == {search_value}": {err}')
if machine_dict is None: if machine_dict is None:
raise UnknownRecordError(f'Unknown machine with "{key} == {search_value}"') raise UnknownRecordError(f'Unknown machine with "{key} == {search_value}"')
return MongoMachineRepository._mongo_record_to_machine(machine_dict) return Machine(**machine_dict)
def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]: def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]:
ip_regex = "^" + str(ip).replace(".", "\\.") + "\\/.*$" ip_regex = "^" + str(ip).replace(".", "\\.") + "\\/.*$"
query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}} query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}}
try: try:
cursor = self._machines_collection.find(query) cursor = self._machines_collection.find(query, {MONGO_OBJECT_ID_KEY: False})
except Exception as err: except Exception as err:
raise RetrievalError(f'Error retrieving machines with ip "{ip}": {err}') raise RetrievalError(f'Error retrieving machines with ip "{ip}": {err}')
machines = list(map(MongoMachineRepository._mongo_record_to_machine, cursor)) machines = list(map(lambda m: Machine(**m), cursor))
if len(machines) == 0: if len(machines) == 0:
raise UnknownRecordError(f'No machines found with IP "{ip}"') raise UnknownRecordError(f'No machines found with IP "{ip}"')
return machines return machines
@staticmethod
def _mongo_record_to_machine(mongo_record: MutableMapping[str, Any]) -> Machine:
del mongo_record[MONGO_OBJECT_ID_KEY]
return Machine(**mongo_record)
def reset(self): def reset(self):
try: try:
self._machines_collection.drop() self._machines_collection.drop()

View File

@ -0,0 +1,84 @@
from copy import deepcopy
from typing import Sequence
from pymongo import MongoClient
from monkey_island.cc.models import CommunicationType, MachineID, Node
from . import INodeRepository, RemovalError, RetrievalError, StorageError
from .consts import MONGO_OBJECT_ID_KEY
UPSERT_ERROR_MESSAGE = "An error occurred while attempting to upsert a node"
SRC_FIELD_NAME = "machine_id"
class MongoNodeRepository(INodeRepository):
def __init__(self, mongo_client: MongoClient):
self._nodes_collection = mongo_client.monkey_island.nodes
def upsert_communication(
self, src: MachineID, dst: MachineID, communication_type: CommunicationType
):
try:
node_dict = self._nodes_collection.find_one(
{SRC_FIELD_NAME: src}, {MONGO_OBJECT_ID_KEY: False}
)
except Exception as err:
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
if node_dict is None:
updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))})
else:
node = Node(**node_dict)
updated_node = MongoNodeRepository._add_connection_to_node(
node, dst, communication_type
)
self._upsert_node(updated_node)
@staticmethod
def _add_connection_to_node(
node: Node, dst: MachineID, communication_type: CommunicationType
) -> Node:
connections = dict(deepcopy(node.connections))
communications = set(connections.get(dst, set()))
communications.add(communication_type)
connections[dst] = frozenset(communications)
new_node = node.copy()
new_node.connections = connections
return new_node
def _upsert_node(self, node: Node):
try:
result = self._nodes_collection.replace_one(
{SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True
)
except Exception as err:
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
if result.matched_count != 0 and result.modified_count != 1:
raise StorageError(
f'Error updating node with source ID "{node.machine_id}": Expected to update 1 '
f"node, but {result.modified_count} were updated"
)
if result.matched_count == 0 and result.upserted_id is None:
raise StorageError(
f'Error inserting node with source ID "{node.machine_id}": Expected to insert 1 '
f"node, but no nodes were inserted"
)
def get_nodes(self) -> Sequence[Node]:
try:
cursor = self._nodes_collection.find({}, {MONGO_OBJECT_ID_KEY: False})
return list(map(lambda n: Node(**n), cursor))
except Exception as err:
raise RetrievalError(f"Error retrieving nodes from the repository: {err}")
def reset(self):
try:
self._nodes_collection.drop()
except Exception as err:
raise RemovalError(f"Error resetting the repository: {err}")

View File

@ -8,8 +8,8 @@ from monkey_island.cc.models import CommunicationType, Node
def test_constructor(): def test_constructor():
machine_id = 1 machine_id = 1
connections = { connections = {
6: (CommunicationType.SCANNED,), 6: frozenset((CommunicationType.SCANNED,)),
7: (CommunicationType.SCANNED, CommunicationType.EXPLOITED), 7: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
} }
n = Node( n = Node(
machine_id=1, machine_id=1,
@ -24,13 +24,25 @@ def test_serialization():
node_dict = { node_dict = {
"machine_id": 1, "machine_id": 1,
"connections": { "connections": {
"6": ["cc", "scanned"], "6": [CommunicationType.CC.value, CommunicationType.SCANNED.value],
"7": ["exploited", "cc"], "7": [CommunicationType.EXPLOITED.value, CommunicationType.CC.value],
}, },
} }
# "6": frozenset((CommunicationType.CC, CommunicationType.SCANNED)),
# "7": frozenset((CommunicationType.EXPLOITED, CommunicationType.CC)),
n = Node(**node_dict) n = Node(**node_dict)
assert n.dict(simplify=True) == node_dict serialized_node = n.dict(simplify=True)
# NOTE: Comparing these nodes is difficult because sets are not ordered
assert len(serialized_node) == len(node_dict)
for key in serialized_node.keys():
assert key in node_dict
assert len(serialized_node["connections"]) == len(node_dict["connections"])
for key, value in serialized_node["connections"].items():
assert set(value) == set(node_dict["connections"][key])
def test_machine_id_immutable(): def test_machine_id_immutable():

View File

@ -0,0 +1,214 @@
from unittest.mock import MagicMock
import mongomock
import pytest
from monkey_island.cc.models import CommunicationType, Node
from monkey_island.cc.repository import (
INodeRepository,
MongoNodeRepository,
RemovalError,
RetrievalError,
StorageError,
)
NODES = (
Node(
machine_id=1,
connections={
2: frozenset((CommunicationType.SCANNED,)),
3: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
},
),
Node(
machine_id=2,
connections={1: frozenset((CommunicationType.CC,))},
),
Node(
machine_id=3,
connections={
1: frozenset((CommunicationType.CC,)),
4: frozenset((CommunicationType.SCANNED,)),
5: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
},
),
Node(
machine_id=4,
connections={},
),
Node(
machine_id=5,
connections={
2: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
3: frozenset((CommunicationType.CC,)),
},
),
)
@pytest.fixture
def empty_node_repository() -> INodeRepository:
return MongoNodeRepository(mongomock.MongoClient())
@pytest.fixture
def mongo_client() -> mongomock.MongoClient:
client = mongomock.MongoClient()
client.monkey_island.nodes.insert_many((n.dict(simplify=True) for n in NODES))
return client
@pytest.fixture
def node_repository(mongo_client) -> INodeRepository:
return MongoNodeRepository(mongo_client)
@pytest.fixture
def error_raising_mock_mongo_client() -> mongomock.MongoClient:
mongo_client = MagicMock(spec=mongomock.MongoClient)
mongo_client.monkey_island = MagicMock(spec=mongomock.Database)
mongo_client.monkey_island.nodes = MagicMock(spec=mongomock.Collection)
mongo_client.monkey_island.nodes.find = MagicMock(side_effect=Exception("some exception"))
mongo_client.monkey_island.nodes.find_one = MagicMock(side_effect=Exception("some exception"))
mongo_client.monkey_island.nodes.replace_one = MagicMock(
side_effect=Exception("some exception")
)
mongo_client.monkey_island.nodes.drop = MagicMock(side_effect=Exception("some exception"))
return mongo_client
@pytest.fixture
def error_raising_node_repository(error_raising_mock_mongo_client) -> INodeRepository:
return MongoNodeRepository(error_raising_mock_mongo_client)
def test_upsert_communication__empty_repository(empty_node_repository):
src_machine_id = 1
dst_machine_id = 2
expected_node = Node(
machine_id=src_machine_id,
connections={dst_machine_id: frozenset((CommunicationType.SCANNED,))},
)
empty_node_repository.upsert_communication(
src_machine_id, dst_machine_id, CommunicationType.SCANNED
)
nodes = empty_node_repository.get_nodes()
assert len(nodes) == 1
assert nodes[0] == expected_node
def test_upsert_communication__new_node(node_repository):
src_machine_id = NODES[-1].machine_id + 100
dst_machine_id = 1
expected_nodes = NODES + (
Node(
machine_id=src_machine_id,
connections={dst_machine_id: frozenset((CommunicationType.CC,))},
),
)
node_repository.upsert_communication(src_machine_id, dst_machine_id, CommunicationType.CC)
nodes = node_repository.get_nodes()
assert len(nodes) == len(expected_nodes)
for en in expected_nodes:
assert en in nodes
def test_upsert_communication__update_existing_connection(node_repository):
src_machine_id = 1
dst_machine_id = 2
expected_node = NODES[0].copy(deep=True)
expected_node.connections[2] = frozenset(
(*expected_node.connections[2], CommunicationType.EXPLOITED)
)
node_repository.upsert_communication(
src_machine_id, dst_machine_id, CommunicationType.EXPLOITED
)
nodes = node_repository.get_nodes()
for node in nodes:
if node.machine_id == src_machine_id:
assert node == expected_node
break
def test_upsert_communication__update_existing_node_add_connection(node_repository):
src_machine_id = 2
dst_machine_id = 5
expected_node = NODES[1].copy(deep=True)
expected_node.connections[5] = frozenset((CommunicationType.SCANNED,))
node_repository.upsert_communication(src_machine_id, dst_machine_id, CommunicationType.SCANNED)
nodes = node_repository.get_nodes()
for node in nodes:
if node.machine_id == src_machine_id:
assert node == expected_node
break
def test_upsert_communication__find_one_fails(error_raising_node_repository):
with pytest.raises(StorageError):
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
def test_upsert_communication__replace_one_fails(
error_raising_mock_mongo_client, error_raising_node_repository
):
error_raising_mock_mongo_client.monkey_island.nodes.find_one = lambda _: None
with pytest.raises(StorageError):
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
def test_upsert_communication__replace_one_matched_without_modify(
error_raising_mock_mongo_client, error_raising_node_repository
):
mock_result = MagicMock()
mock_result.matched_count = 1
mock_result.modified_count = 0
error_raising_mock_mongo_client.monkey_island.nodes.find_one = lambda _: None
error_raising_mock_mongo_client.monkey_island.nodes.replace_one = lambda *_, **__: mock_result
with pytest.raises(StorageError):
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
def test_upsert_communication__replace_one_insert_fails(
error_raising_mock_mongo_client, error_raising_node_repository
):
mock_result = MagicMock()
mock_result.matched_count = 0
mock_result.upserted_id = None
error_raising_mock_mongo_client.monkey_island.nodes.find_one = lambda _: None
error_raising_mock_mongo_client.monkey_island.nodes.replace_one = lambda *_, **__: mock_result
with pytest.raises(StorageError):
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
def test_get_nodes(node_repository):
nodes = node_repository.get_nodes()
assert len(nodes) == len(nodes)
for n in nodes:
assert n in NODES
def test_get_nodes__find_fails(error_raising_node_repository):
with pytest.raises(RetrievalError):
error_raising_node_repository.get_nodes()
def test_reset(node_repository):
assert len(node_repository.get_nodes()) > 0
node_repository.reset()
assert len(node_repository.get_nodes()) == 0
def test_reset__removal_error(error_raising_node_repository):
with pytest.raises(RemovalError):
error_raising_node_repository.reset()

View File

@ -12,6 +12,7 @@ from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFacto
from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue
from monkey_island.cc.models import Report from monkey_island.cc.models import Report
from monkey_island.cc.models.networkmap import Arc, NetworkMap from monkey_island.cc.models.networkmap import Arc, NetworkMap
from monkey_island.cc.repository import MongoAgentRepository, MongoMachineRepository
from monkey_island.cc.repository.attack.IMitigationsRepository import IMitigationsRepository from monkey_island.cc.repository.attack.IMitigationsRepository import IMitigationsRepository
from monkey_island.cc.repository.i_agent_repository import IAgentRepository from monkey_island.cc.repository.i_agent_repository import IAgentRepository
from monkey_island.cc.repository.i_attack_repository import IAttackRepository from monkey_island.cc.repository.i_attack_repository import IAttackRepository
@ -275,6 +276,8 @@ ICredentialsRepository.save_stolen_credentials
ICredentialsRepository.save_configured_credentials ICredentialsRepository.save_configured_credentials
IEventRepository.get_events IEventRepository.get_events
IFindingRepository.get_findings IFindingRepository.get_findings
MongoAgentRepository
MongoMachineRepository
key_list key_list
simulation simulation
netmap netmap