forked from p15670423/monkey
Island: Upsert node on TCP scan event if source of event don't exist
This commit is contained in:
parent
d3c2d95a69
commit
e54c950dc3
|
@ -78,7 +78,12 @@ class ScanEventHandler:
|
||||||
|
|
||||||
def _get_source_node(self, event: ScanEvent) -> Node:
|
def _get_source_node(self, event: ScanEvent) -> Node:
|
||||||
machine = self._get_source_machine(event)
|
machine = self._get_source_machine(event)
|
||||||
return self._node_repository.get_node_by_machine_id(machine.id)
|
try:
|
||||||
|
node = self._node_repository.get_node_by_machine_id(machine.id)
|
||||||
|
except UnknownRecordError:
|
||||||
|
node = Node(machine_id=machine.id)
|
||||||
|
self._node_repository.upsert_node(node)
|
||||||
|
return node
|
||||||
|
|
||||||
def _get_source_machine(self, event: ScanEvent) -> Machine:
|
def _get_source_machine(self, event: ScanEvent) -> Machine:
|
||||||
agent = self._agent_repository.get_agent_by_id(event.source)
|
agent = self._agent_repository.get_agent_by_id(event.source)
|
||||||
|
|
|
@ -24,7 +24,7 @@ class Node(MutableInfectionMonkeyBaseModel):
|
||||||
machine_id: MachineID = Field(..., allow_mutation=False)
|
machine_id: MachineID = Field(..., allow_mutation=False)
|
||||||
"""The MachineID of the node (source)"""
|
"""The MachineID of the node (source)"""
|
||||||
|
|
||||||
connections: NodeConnections
|
connections: NodeConnections = {}
|
||||||
"""All outbound connections from this node to other machines"""
|
"""All outbound connections from this node to other machines"""
|
||||||
|
|
||||||
tcp_connections: TCPConnections = {}
|
tcp_connections: TCPConnections = {}
|
||||||
|
|
|
@ -44,6 +44,14 @@ class INodeRepository(ABC):
|
||||||
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
|
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def upsert_node(self, node: Node):
|
||||||
|
"""
|
||||||
|
Update or insert Node model into the database
|
||||||
|
:param node: Node model to be added to the repository
|
||||||
|
:raises StorageError: If something went wrong when upserting the Node
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_node_by_machine_id(self, machine_id: MachineID) -> Node:
|
def get_node_by_machine_id(self, machine_id: MachineID) -> Node:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -30,7 +30,7 @@ class MongoNodeRepository(INodeRepository):
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
||||||
|
|
||||||
self._upsert_node(updated_node)
|
self.upsert_node(updated_node)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _add_connection_to_node(
|
def _add_connection_to_node(
|
||||||
|
@ -57,9 +57,9 @@ class MongoNodeRepository(INodeRepository):
|
||||||
node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections})
|
node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections})
|
||||||
else:
|
else:
|
||||||
node.tcp_connections[target] = connections
|
node.tcp_connections[target] = connections
|
||||||
self._upsert_node(node)
|
self.upsert_node(node)
|
||||||
|
|
||||||
def _upsert_node(self, node: Node):
|
def upsert_node(self, node: Node):
|
||||||
try:
|
try:
|
||||||
result = self._nodes_collection.replace_one(
|
result = self._nodes_collection.replace_one(
|
||||||
{SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True
|
{SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True
|
||||||
|
|
|
@ -22,11 +22,13 @@ from monkey_island.cc.repository import (
|
||||||
|
|
||||||
SEED_ID = 99
|
SEED_ID = 99
|
||||||
AGENT_ID = UUID("1d8ce743-a0f4-45c5-96af-91106529d3e2")
|
AGENT_ID = UUID("1d8ce743-a0f4-45c5-96af-91106529d3e2")
|
||||||
MACHINE_ID = 11
|
SOURCE_MACHINE_ID = 11
|
||||||
CC_SERVER = SocketAddress(ip="10.10.10.100", port="5000")
|
CC_SERVER = SocketAddress(ip="10.10.10.100", port="5000")
|
||||||
AGENT = Agent(id=AGENT_ID, machine_id=MACHINE_ID, start_time=0, parent_id=None, cc_server=CC_SERVER)
|
AGENT = Agent(
|
||||||
|
id=AGENT_ID, machine_id=SOURCE_MACHINE_ID, start_time=0, parent_id=None, cc_server=CC_SERVER
|
||||||
|
)
|
||||||
SOURCE_MACHINE = Machine(
|
SOURCE_MACHINE = Machine(
|
||||||
id=MACHINE_ID,
|
id=SOURCE_MACHINE_ID,
|
||||||
hardware_id=5,
|
hardware_id=5,
|
||||||
network_interfaces=[IPv4Interface("10.10.10.99/24")],
|
network_interfaces=[IPv4Interface("10.10.10.99/24")],
|
||||||
)
|
)
|
||||||
|
@ -125,7 +127,7 @@ def scan_event_handler(agent_repository, machine_repository, node_repository):
|
||||||
return ScanEventHandler(agent_repository, machine_repository, node_repository)
|
return ScanEventHandler(agent_repository, machine_repository, node_repository)
|
||||||
|
|
||||||
|
|
||||||
MACHINES_BY_ID = {MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
|
MACHINES_BY_ID = {SOURCE_MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
|
||||||
MACHINES_BY_IP = {
|
MACHINES_BY_IP = {
|
||||||
IPv4Address("10.10.10.99"): [SOURCE_MACHINE],
|
IPv4Address("10.10.10.99"): [SOURCE_MACHINE],
|
||||||
IPv4Address(TARGET_MACHINE_IP): [TARGET_MACHINE],
|
IPv4Address(TARGET_MACHINE_IP): [TARGET_MACHINE],
|
||||||
|
@ -230,14 +232,14 @@ def test_handle_tcp_scan_event__ports_found(
|
||||||
scan_event_handler.handle_tcp_scan_event(event)
|
scan_event_handler.handle_tcp_scan_event(event)
|
||||||
|
|
||||||
call_args = node_repository.upsert_tcp_connections.call_args[0]
|
call_args = node_repository.upsert_tcp_connections.call_args[0]
|
||||||
assert call_args[0] == MACHINE_ID
|
assert call_args[0] == SOURCE_MACHINE_ID
|
||||||
assert TARGET_MACHINE_ID in call_args[1]
|
assert TARGET_MACHINE_ID in call_args[1]
|
||||||
open_socket_addresses = call_args[1][TARGET_MACHINE_ID]
|
open_socket_addresses = call_args[1][TARGET_MACHINE_ID]
|
||||||
assert set(open_socket_addresses) == set(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
assert set(open_socket_addresses) == set(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
||||||
assert len(open_socket_addresses) == len(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
assert len(open_socket_addresses) == len(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
||||||
|
|
||||||
|
|
||||||
def test_handle_tcp_scan_event__no_source(
|
def test_handle_tcp_scan_event__no_source_node(
|
||||||
caplog, scan_event_handler, machine_repository, node_repository
|
caplog, scan_event_handler, machine_repository, node_repository
|
||||||
):
|
):
|
||||||
event = TCP_SCAN_EVENT
|
event = TCP_SCAN_EVENT
|
||||||
|
@ -245,8 +247,11 @@ def test_handle_tcp_scan_event__no_source(
|
||||||
scan_event_handler._update_nodes = MagicMock()
|
scan_event_handler._update_nodes = MagicMock()
|
||||||
|
|
||||||
scan_event_handler.handle_tcp_scan_event(event)
|
scan_event_handler.handle_tcp_scan_event(event)
|
||||||
assert "ERROR" in caplog.text
|
expected_node = Node(machine_id=SOURCE_MACHINE_ID)
|
||||||
assert "no source" in caplog.text
|
node_called = node_repository.upsert_node.call_args[0][0]
|
||||||
|
assert expected_node.machine_id == node_called.machine_id
|
||||||
|
assert expected_node.connections == node_called.connections
|
||||||
|
assert expected_node.tcp_connections == node_called.tcp_connections
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
Loading…
Reference in New Issue