forked from p15670423/monkey
Island: Refactor get_node_by_id to raise UnknownRecordError
This commit is contained in:
parent
8503e0f499
commit
be4ecccdcd
|
@ -65,12 +65,7 @@ class ScanEventHandler:
|
||||||
|
|
||||||
def _get_source_node(self, event: AbstractAgentEvent) -> Node:
|
def _get_source_node(self, event: AbstractAgentEvent) -> Node:
|
||||||
machine = self._get_source_machine(event)
|
machine = self._get_source_machine(event)
|
||||||
try:
|
return self._node_repository.get_node_by_machine_id(machine.id)
|
||||||
return [
|
|
||||||
node for node in self._node_repository.get_nodes() if node.machine_id == machine.id
|
|
||||||
][0]
|
|
||||||
except IndexError:
|
|
||||||
raise UnknownRecordError(f"Source node for event {event} does not exist")
|
|
||||||
|
|
||||||
def _get_target_machine(self, event: ScanEvent) -> Machine:
|
def _get_target_machine(self, event: ScanEvent) -> Machine:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -44,6 +44,15 @@ class INodeRepository(ABC):
|
||||||
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
|
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_node_by_machine_id(self, machine_id: MachineID) -> Node:
|
||||||
|
"""
|
||||||
|
Fetches network Node from the database based on Machine id
|
||||||
|
:param machine_id: ID of a Machine that Node represents
|
||||||
|
:return: network Node that represents the Machine
|
||||||
|
:raises UnknownRecordError: If the Node does not exist
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Optional, Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
|
|
||||||
from monkey_island.cc.models import CommunicationType, MachineID, Node
|
from monkey_island.cc.models import CommunicationType, MachineID, Node
|
||||||
|
|
||||||
from ..models.node import TCPConnections
|
from ..models.node import TCPConnections
|
||||||
from . import INodeRepository, RemovalError, RetrievalError, StorageError
|
from . import INodeRepository, RemovalError, RetrievalError, StorageError, UnknownRecordError
|
||||||
from .consts import MONGO_OBJECT_ID_KEY
|
from .consts import MONGO_OBJECT_ID_KEY
|
||||||
|
|
||||||
UPSERT_ERROR_MESSAGE = "An error occurred while attempting to upsert a node"
|
UPSERT_ERROR_MESSAGE = "An error occurred while attempting to upsert a node"
|
||||||
|
@ -21,16 +21,14 @@ class MongoNodeRepository(INodeRepository):
|
||||||
self, src: MachineID, dst: MachineID, communication_type: CommunicationType
|
self, src: MachineID, dst: MachineID, communication_type: CommunicationType
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
node = self._get_node_by_id(src)
|
node = self.get_node_by_machine_id(src)
|
||||||
except Exception as err:
|
|
||||||
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
|
||||||
|
|
||||||
if node is None:
|
|
||||||
updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))})
|
|
||||||
else:
|
|
||||||
updated_node = MongoNodeRepository._add_connection_to_node(
|
updated_node = MongoNodeRepository._add_connection_to_node(
|
||||||
node, dst, communication_type
|
node, dst, communication_type
|
||||||
)
|
)
|
||||||
|
except UnknownRecordError:
|
||||||
|
updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))})
|
||||||
|
except Exception as err:
|
||||||
|
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
||||||
|
|
||||||
self._upsert_node(updated_node)
|
self._upsert_node(updated_node)
|
||||||
|
|
||||||
|
@ -49,9 +47,9 @@ class MongoNodeRepository(INodeRepository):
|
||||||
return new_node
|
return new_node
|
||||||
|
|
||||||
def upsert_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections):
|
def upsert_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections):
|
||||||
node = self._get_node_by_id(machine_id)
|
try:
|
||||||
|
node = self.get_node_by_machine_id(machine_id)
|
||||||
if node is None:
|
except UnknownRecordError:
|
||||||
node = Node(machine_id=machine_id, connections={})
|
node = Node(machine_id=machine_id, connections={})
|
||||||
|
|
||||||
for target, connections in tcp_connections.items():
|
for target, connections in tcp_connections.items():
|
||||||
|
@ -75,11 +73,13 @@ class MongoNodeRepository(INodeRepository):
|
||||||
f"node, but no nodes were inserted"
|
f"node, but no nodes were inserted"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_node_by_id(self, node_id: MachineID) -> Optional[Node]:
|
def get_node_by_machine_id(self, machine_id: MachineID) -> Node:
|
||||||
node_dict = self._nodes_collection.find_one(
|
node_dict = self._nodes_collection.find_one(
|
||||||
{SRC_FIELD_NAME: node_id}, {MONGO_OBJECT_ID_KEY: False}
|
{SRC_FIELD_NAME: machine_id}, {MONGO_OBJECT_ID_KEY: False}
|
||||||
)
|
)
|
||||||
return Node(**node_dict) if node_dict else None
|
if not node_dict:
|
||||||
|
raise UnknownRecordError(f"Node with machine ID {machine_id}")
|
||||||
|
return Node(**node_dict)
|
||||||
|
|
||||||
def get_nodes(self) -> Sequence[Node]:
|
def get_nodes(self) -> Sequence[Node]:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -221,6 +221,7 @@ def test_handle_tcp_scan_event__ports_found(
|
||||||
):
|
):
|
||||||
event = TCP_SCAN_EVENT
|
event = TCP_SCAN_EVENT
|
||||||
scan_event_handler._update_nodes = MagicMock()
|
scan_event_handler._update_nodes = MagicMock()
|
||||||
|
node_repository.get_node_by_machine_id.return_value = SOURCE_NODE
|
||||||
scan_event_handler.handle_tcp_scan_event(event)
|
scan_event_handler.handle_tcp_scan_event(event)
|
||||||
|
|
||||||
call_args = node_repository.upsert_tcp_connections.call_args[0]
|
call_args = node_repository.upsert_tcp_connections.call_args[0]
|
||||||
|
@ -235,12 +236,12 @@ def test_handle_tcp_scan_event__no_source(
|
||||||
caplog, scan_event_handler, machine_repository, node_repository
|
caplog, scan_event_handler, machine_repository, node_repository
|
||||||
):
|
):
|
||||||
event = TCP_SCAN_EVENT
|
event = TCP_SCAN_EVENT
|
||||||
node_repository.get_nodes.return_value = []
|
node_repository.get_node_by_machine_id = MagicMock(side_effect=UnknownRecordError("no source"))
|
||||||
scan_event_handler._update_nodes = MagicMock()
|
scan_event_handler._update_nodes = MagicMock()
|
||||||
|
|
||||||
scan_event_handler.handle_tcp_scan_event(event)
|
scan_event_handler.handle_tcp_scan_event(event)
|
||||||
assert "ERROR" in caplog.text
|
assert "ERROR" in caplog.text
|
||||||
assert f"Source node for event {event} does not exist" in caplog.text
|
assert "no source" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
@ -11,6 +11,7 @@ from monkey_island.cc.repository import (
|
||||||
RemovalError,
|
RemovalError,
|
||||||
RetrievalError,
|
RetrievalError,
|
||||||
StorageError,
|
StorageError,
|
||||||
|
UnknownRecordError,
|
||||||
)
|
)
|
||||||
|
|
||||||
TARGET_MACHINE_IP = "2.2.2.2"
|
TARGET_MACHINE_IP = "2.2.2.2"
|
||||||
|
@ -239,3 +240,12 @@ def test_upsert_tcp_connections__node_missing(node_repository):
|
||||||
nodes = node_repository.get_nodes()
|
nodes = node_repository.get_nodes()
|
||||||
modified_node = [node for node in nodes if node.machine_id == 999][0]
|
modified_node = [node for node in nodes if node.machine_id == 999][0]
|
||||||
assert set(modified_node.tcp_connections) == set(TCP_CONNECTION_PORT_80)
|
assert set(modified_node.tcp_connections) == set(TCP_CONNECTION_PORT_80)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_by_machine_id(node_repository):
|
||||||
|
assert node_repository.get_node_by_machine_id(1) == NODES[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_by_machine_id__no_node(node_repository):
|
||||||
|
with pytest.raises(UnknownRecordError):
|
||||||
|
node_repository.get_node_by_machine_id(999)
|
||||||
|
|
Loading…
Reference in New Issue