forked from p15670423/monkey
Island: Move tcp_connection addition to node repository
This commit is contained in:
parent
c90044074d
commit
b0ec035909
|
@ -1,4 +1,3 @@
|
|||
from copy import deepcopy
|
||||
from ipaddress import IPv4Interface
|
||||
from logging import getLogger
|
||||
from typing import Union
|
||||
|
@ -98,16 +97,16 @@ class ScanEventHandler:
|
|||
)
|
||||
|
||||
def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent):
|
||||
node_connections = dict(deepcopy(src_node.tcp_connections))
|
||||
machine_connections = set(node_connections.get(target_machine.id, set()))
|
||||
open_ports = [port for port, status in event.ports.items() if status == PortStatus.OPEN]
|
||||
tcp_connections = set()
|
||||
open_ports = (port for port, status in event.ports.items() if status == PortStatus.OPEN)
|
||||
for open_port in open_ports:
|
||||
socket_address = SocketAddress(ip=event.target, port=open_port)
|
||||
machine_connections.add(socket_address)
|
||||
tcp_connections.add(socket_address)
|
||||
|
||||
node_connections[target_machine.id] = tuple(machine_connections)
|
||||
src_node.tcp_connections = node_connections
|
||||
self._node_repository.upsert_node(src_node)
|
||||
if tcp_connections:
|
||||
self._node_repository.add_tcp_connections(
|
||||
src_node.machine_id, {target_machine.id: tcp_connections}
|
||||
)
|
||||
|
||||
def _get_source_machine(self, event: ScanEvent) -> Machine:
|
||||
agent = self._agent_repository.get_agent_by_id(event.source)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import FrozenSet, Mapping, Tuple
|
||||
from typing import Dict, FrozenSet, Mapping, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import TypeAlias
|
||||
|
@ -9,6 +9,7 @@ from common.types import SocketAddress
|
|||
from . import CommunicationType, MachineID
|
||||
|
||||
NodeConnections: TypeAlias = Mapping[MachineID, FrozenSet[CommunicationType]]
|
||||
TCPConnections: TypeAlias = Dict[MachineID, Tuple[SocketAddress, ...]]
|
||||
|
||||
|
||||
class Node(MutableInfectionMonkeyBaseModel):
|
||||
|
@ -26,5 +27,5 @@ class Node(MutableInfectionMonkeyBaseModel):
|
|||
connections: NodeConnections
|
||||
"""All outbound connections from this node to other machines"""
|
||||
|
||||
tcp_connections: Mapping[MachineID, Tuple[SocketAddress, ...]] = {}
|
||||
tcp_connections: TCPConnections = {}
|
||||
"""All successfull outbound TCP connections"""
|
||||
|
|
|
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
|||
from typing import Sequence
|
||||
|
||||
from monkey_island.cc.models import CommunicationType, MachineID, Node
|
||||
from monkey_island.cc.models.node import TCPConnections
|
||||
|
||||
|
||||
class INodeRepository(ABC):
|
||||
|
@ -26,11 +27,12 @@ class INodeRepository(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_node(self, node: Node):
|
||||
def add_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections):
|
||||
"""
|
||||
Store the Node object in the repository by creating a new one or updating an existing one.
|
||||
:param node: Node that will be saved
|
||||
:raises StorageError: If an error occurs while attempting to upsert the Node
|
||||
Add TCP connections to Node
|
||||
:param machine_id: Machine ID of the Node that made the connections
|
||||
:param tcp_connections: TCP connections made by node
|
||||
:raises StorageError: If an error occurs while attempting to add connections
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -5,6 +5,7 @@ from pymongo import MongoClient
|
|||
|
||||
from monkey_island.cc.models import CommunicationType, MachineID, Node
|
||||
|
||||
from ..models.node import TCPConnections
|
||||
from . import INodeRepository, RemovalError, RetrievalError, StorageError
|
||||
from .consts import MONGO_OBJECT_ID_KEY
|
||||
|
||||
|
@ -47,7 +48,17 @@ class MongoNodeRepository(INodeRepository):
|
|||
|
||||
return new_node
|
||||
|
||||
def upsert_node(self, node: Node):
|
||||
def add_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections):
|
||||
node = self._get_node_by_id(machine_id)
|
||||
|
||||
for target, connections in tcp_connections.items():
|
||||
if target in node.tcp_connections:
|
||||
node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections})
|
||||
else:
|
||||
node.tcp_connections[target] = connections
|
||||
self._upsert_node(node)
|
||||
|
||||
def _upsert_node(self, node: Node):
|
||||
try:
|
||||
result = self._nodes_collection.replace_one(
|
||||
{SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True
|
||||
|
|
|
@ -95,6 +95,13 @@ TCP_SCAN_EVENT = TCPScanEvent(
|
|||
ports={22: PortStatus.OPEN, 80: PortStatus.OPEN, 8080: PortStatus.CLOSED},
|
||||
)
|
||||
|
||||
TCP_CONNECTIONS = {
|
||||
TARGET_MACHINE_ID: (
|
||||
SocketAddress(ip=TARGET_MACHINE_IP, port=22),
|
||||
SocketAddress(ip=TARGET_MACHINE_IP, port=80),
|
||||
)
|
||||
}
|
||||
|
||||
TCP_SCAN_EVENT_CLOSED = TCPScanEvent(
|
||||
source=AGENT_ID,
|
||||
target=IPv4Address(TARGET_MACHINE_IP),
|
||||
|
@ -220,31 +227,29 @@ def test_tcp_scan_event_target_machine_not_exists(
|
|||
machine_repository.upsert_machine.assert_called_with(expected_machine)
|
||||
|
||||
|
||||
def test_handle_tcp_scan_event__tcp_connections(
|
||||
def test_handle_tcp_scan_event__no_open_ports(
|
||||
scan_event_handler, machine_repository, node_repository
|
||||
):
|
||||
event = TCP_SCAN_EVENT_CLOSED
|
||||
scan_event_handler._update_nodes = MagicMock()
|
||||
scan_event_handler.handle_tcp_scan_event(event)
|
||||
|
||||
assert not node_repository.add_tcp_connections.called
|
||||
|
||||
|
||||
def test_handle_tcp_scan_event__ports_found(
|
||||
scan_event_handler, machine_repository, node_repository
|
||||
):
|
||||
event = TCP_SCAN_EVENT
|
||||
scan_event_handler._update_nodes = MagicMock()
|
||||
scan_event_handler.handle_tcp_scan_event(event)
|
||||
|
||||
node_passed = node_repository.upsert_node.call_args[0][0]
|
||||
assert set(node_passed.tcp_connections[TARGET_MACHINE_ID]) == set(
|
||||
EXPECTED_NODE.tcp_connections[TARGET_MACHINE_ID]
|
||||
)
|
||||
|
||||
|
||||
def test_handle_tcp_scan_event__tcp_connections_upsert(
|
||||
scan_event_handler, machine_repository, node_repository
|
||||
):
|
||||
event = TCP_SCAN_EVENT
|
||||
node_repository.get_nodes.return_value = [deepcopy(SOURCE_NODE_2)]
|
||||
scan_event_handler._update_nodes = MagicMock()
|
||||
scan_event_handler.handle_tcp_scan_event(event)
|
||||
|
||||
node_passed = node_repository.upsert_node.call_args[0][0]
|
||||
assert set(node_passed.tcp_connections[TARGET_MACHINE_ID]) == set(
|
||||
EXPECTED_NODE.tcp_connections[TARGET_MACHINE_ID]
|
||||
)
|
||||
call_args = node_repository.add_tcp_connections.call_args[0]
|
||||
assert call_args[0] == MACHINE_ID
|
||||
assert TARGET_MACHINE_ID in call_args[1]
|
||||
open_socket_addresses = call_args[1][TARGET_MACHINE_ID]
|
||||
assert set(open_socket_addresses) == set(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
||||
assert len(open_socket_addresses) == len(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
||||
|
||||
|
||||
def test_handle_tcp_scan_event__no_source(
|
||||
|
|
|
@ -3,6 +3,7 @@ from unittest.mock import MagicMock
|
|||
import mongomock
|
||||
import pytest
|
||||
|
||||
from common.types import SocketAddress
|
||||
from monkey_island.cc.models import CommunicationType, Node
|
||||
from monkey_island.cc.repository import (
|
||||
INodeRepository,
|
||||
|
@ -12,6 +13,14 @@ from monkey_island.cc.repository import (
|
|||
StorageError,
|
||||
)
|
||||
|
||||
TARGET_MACHINE_IP = "2.2.2.2"
|
||||
|
||||
TCP_CONNECTION_PORT_22 = {3: (SocketAddress(ip=TARGET_MACHINE_IP, port=22),)}
|
||||
TCP_CONNECTION_PORT_80 = {3: (SocketAddress(ip=TARGET_MACHINE_IP, port=80),)}
|
||||
ALL_TCP_CONNECTIONS = {
|
||||
3: (SocketAddress(ip=TARGET_MACHINE_IP, port=22), SocketAddress(ip=TARGET_MACHINE_IP, port=80))
|
||||
}
|
||||
|
||||
NODES = (
|
||||
Node(
|
||||
machine_id=1,
|
||||
|
@ -23,6 +32,7 @@ NODES = (
|
|||
Node(
|
||||
machine_id=2,
|
||||
connections={1: frozenset((CommunicationType.CC,))},
|
||||
tcp_connections=TCP_CONNECTION_PORT_22,
|
||||
),
|
||||
Node(
|
||||
machine_id=3,
|
||||
|
@ -32,10 +42,7 @@ NODES = (
|
|||
5: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
|
||||
},
|
||||
),
|
||||
Node(
|
||||
machine_id=4,
|
||||
connections={},
|
||||
),
|
||||
Node(machine_id=4, connections={}, tcp_connections=ALL_TCP_CONNECTIONS),
|
||||
Node(
|
||||
machine_id=5,
|
||||
connections={
|
||||
|
@ -201,3 +208,27 @@ def test_reset(node_repository):
|
|||
def test_reset__removal_error(error_raising_node_repository):
|
||||
with pytest.raises(RemovalError):
|
||||
error_raising_node_repository.reset()
|
||||
|
||||
|
||||
def test_upsert_tcp_connections__empty_connections(node_repository):
|
||||
node_repository.add_tcp_connections(1, TCP_CONNECTION_PORT_22)
|
||||
nodes = node_repository.get_nodes()
|
||||
for node in nodes:
|
||||
if node.machine_id == 1:
|
||||
assert node.tcp_connections == TCP_CONNECTION_PORT_22
|
||||
|
||||
|
||||
def test_upsert_tcp_connections__upsert_new_port(node_repository):
|
||||
node_repository.add_tcp_connections(2, TCP_CONNECTION_PORT_80)
|
||||
nodes = node_repository.get_nodes()
|
||||
modified_node = [node for node in nodes if node.machine_id == 2][0]
|
||||
assert set(modified_node.tcp_connections) == set(ALL_TCP_CONNECTIONS)
|
||||
assert len(modified_node.tcp_connections) == len(ALL_TCP_CONNECTIONS)
|
||||
|
||||
|
||||
def test_upsert_tcp_connections__port_already_present(node_repository):
|
||||
node_repository.add_tcp_connections(4, TCP_CONNECTION_PORT_80)
|
||||
nodes = node_repository.get_nodes()
|
||||
modified_node = [node for node in nodes if node.machine_id == 4][0]
|
||||
assert set(modified_node.tcp_connections) == set(ALL_TCP_CONNECTIONS)
|
||||
assert len(modified_node.tcp_connections) == len(ALL_TCP_CONNECTIONS)
|
||||
|
|
Loading…
Reference in New Issue