Merge pull request #2400 from guardicore/2267-add-tcp-connections
2267 add tcp connections
This commit is contained in:
commit
4709ae771b
|
@ -4,9 +4,9 @@ from typing import Union
|
||||||
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from common.agent_events import PingScanEvent, TCPScanEvent
|
from common.agent_events import AbstractAgentEvent, PingScanEvent, TCPScanEvent
|
||||||
from common.types import PortStatus
|
from common.types import PortStatus, SocketAddress
|
||||||
from monkey_island.cc.models import CommunicationType, Machine
|
from monkey_island.cc.models import CommunicationType, Machine, Node
|
||||||
from monkey_island.cc.repository import (
|
from monkey_island.cc.repository import (
|
||||||
IAgentRepository,
|
IAgentRepository,
|
||||||
IMachineRepository,
|
IMachineRepository,
|
||||||
|
@ -56,11 +56,17 @@ class ScanEventHandler:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
target_machine = self._get_target_machine(event)
|
target_machine = self._get_target_machine(event)
|
||||||
|
source_node = self._get_source_node(event)
|
||||||
|
|
||||||
self._update_nodes(target_machine, event)
|
self._update_nodes(target_machine, event)
|
||||||
|
self._update_tcp_connections(source_node, target_machine, event)
|
||||||
except (RetrievalError, StorageError, UnknownRecordError):
|
except (RetrievalError, StorageError, UnknownRecordError):
|
||||||
logger.exception("Unable to process tcp scan data")
|
logger.exception("Unable to process tcp scan data")
|
||||||
|
|
||||||
|
def _get_source_node(self, event: AbstractAgentEvent) -> Node:
|
||||||
|
machine = self._get_source_machine(event)
|
||||||
|
return self._node_repository.get_node_by_machine_id(machine.id)
|
||||||
|
|
||||||
def _get_target_machine(self, event: ScanEvent) -> Machine:
|
def _get_target_machine(self, event: ScanEvent) -> Machine:
|
||||||
try:
|
try:
|
||||||
target_machines = self._machine_repository.get_machines_by_ip(event.target)
|
target_machines = self._machine_repository.get_machines_by_ip(event.target)
|
||||||
|
@ -85,6 +91,18 @@ class ScanEventHandler:
|
||||||
src_machine.id, target_machine.id, CommunicationType.SCANNED
|
src_machine.id, target_machine.id, CommunicationType.SCANNED
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent):
|
||||||
|
tcp_connections = set()
|
||||||
|
open_ports = (port for port, status in event.ports.items() if status == PortStatus.OPEN)
|
||||||
|
for open_port in open_ports:
|
||||||
|
socket_address = SocketAddress(ip=event.target, port=open_port)
|
||||||
|
tcp_connections.add(socket_address)
|
||||||
|
|
||||||
|
if tcp_connections:
|
||||||
|
self._node_repository.upsert_tcp_connections(
|
||||||
|
src_node.machine_id, {target_machine.id: tcp_connections}
|
||||||
|
)
|
||||||
|
|
||||||
def _get_source_machine(self, event: ScanEvent) -> Machine:
|
def _get_source_machine(self, event: ScanEvent) -> Machine:
|
||||||
agent = self._agent_repository.get_agent_by_id(event.source)
|
agent = self._agent_repository.get_agent_by_id(event.source)
|
||||||
return self._machine_repository.get_machine_by_id(agent.machine_id)
|
return self._machine_repository.get_machine_by_id(agent.machine_id)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import FrozenSet, Mapping, Tuple
|
from typing import Dict, FrozenSet, Mapping, Tuple
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
@ -9,6 +9,7 @@ from common.types import SocketAddress
|
||||||
from . import CommunicationType, MachineID
|
from . import CommunicationType, MachineID
|
||||||
|
|
||||||
NodeConnections: TypeAlias = Mapping[MachineID, FrozenSet[CommunicationType]]
|
NodeConnections: TypeAlias = Mapping[MachineID, FrozenSet[CommunicationType]]
|
||||||
|
TCPConnections: TypeAlias = Dict[MachineID, Tuple[SocketAddress, ...]]
|
||||||
|
|
||||||
|
|
||||||
class Node(MutableInfectionMonkeyBaseModel):
|
class Node(MutableInfectionMonkeyBaseModel):
|
||||||
|
@ -26,5 +27,5 @@ class Node(MutableInfectionMonkeyBaseModel):
|
||||||
connections: NodeConnections
|
connections: NodeConnections
|
||||||
"""All outbound connections from this node to other machines"""
|
"""All outbound connections from this node to other machines"""
|
||||||
|
|
||||||
tcp_connections: Mapping[MachineID, Tuple[SocketAddress, ...]] = {}
|
tcp_connections: TCPConnections = {}
|
||||||
"""All successfull outbound TCP connections"""
|
"""All successfull outbound TCP connections"""
|
||||||
|
|
|
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from monkey_island.cc.models import CommunicationType, MachineID, Node
|
from monkey_island.cc.models import CommunicationType, MachineID, Node
|
||||||
|
from monkey_island.cc.models.node import TCPConnections
|
||||||
|
|
||||||
|
|
||||||
class INodeRepository(ABC):
|
class INodeRepository(ABC):
|
||||||
|
@ -25,6 +26,15 @@ class INodeRepository(ABC):
|
||||||
:raises StorageError: If an error occurs while attempting to upsert the Node
|
:raises StorageError: If an error occurs while attempting to upsert the Node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def upsert_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections):
|
||||||
|
"""
|
||||||
|
Add TCP connections to Node
|
||||||
|
:param machine_id: Machine ID of the Node that made the connections
|
||||||
|
:param tcp_connections: TCP connections made by node
|
||||||
|
:raises StorageError: If an error occurs while attempting to add connections
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_nodes(self) -> Sequence[Node]:
|
def get_nodes(self) -> Sequence[Node]:
|
||||||
"""
|
"""
|
||||||
|
@ -34,6 +44,15 @@ class INodeRepository(ABC):
|
||||||
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
|
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_node_by_machine_id(self, machine_id: MachineID) -> Node:
|
||||||
|
"""
|
||||||
|
Fetches network Node from the database based on Machine id
|
||||||
|
:param machine_id: ID of a Machine that Node represents
|
||||||
|
:return: network Node that represents the Machine
|
||||||
|
:raises UnknownRecordError: If the Node does not exist
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -5,7 +5,8 @@ from pymongo import MongoClient
|
||||||
|
|
||||||
from monkey_island.cc.models import CommunicationType, MachineID, Node
|
from monkey_island.cc.models import CommunicationType, MachineID, Node
|
||||||
|
|
||||||
from . import INodeRepository, RemovalError, RetrievalError, StorageError
|
from ..models.node import TCPConnections
|
||||||
|
from . import INodeRepository, RemovalError, RetrievalError, StorageError, UnknownRecordError
|
||||||
from .consts import MONGO_OBJECT_ID_KEY
|
from .consts import MONGO_OBJECT_ID_KEY
|
||||||
|
|
||||||
UPSERT_ERROR_MESSAGE = "An error occurred while attempting to upsert a node"
|
UPSERT_ERROR_MESSAGE = "An error occurred while attempting to upsert a node"
|
||||||
|
@ -20,19 +21,14 @@ class MongoNodeRepository(INodeRepository):
|
||||||
self, src: MachineID, dst: MachineID, communication_type: CommunicationType
|
self, src: MachineID, dst: MachineID, communication_type: CommunicationType
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
node_dict = self._nodes_collection.find_one(
|
node = self.get_node_by_machine_id(src)
|
||||||
{SRC_FIELD_NAME: src}, {MONGO_OBJECT_ID_KEY: False}
|
|
||||||
)
|
|
||||||
except Exception as err:
|
|
||||||
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
|
||||||
|
|
||||||
if node_dict is None:
|
|
||||||
updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))})
|
|
||||||
else:
|
|
||||||
node = Node(**node_dict)
|
|
||||||
updated_node = MongoNodeRepository._add_connection_to_node(
|
updated_node = MongoNodeRepository._add_connection_to_node(
|
||||||
node, dst, communication_type
|
node, dst, communication_type
|
||||||
)
|
)
|
||||||
|
except UnknownRecordError:
|
||||||
|
updated_node = Node(machine_id=src, connections={dst: frozenset((communication_type,))})
|
||||||
|
except Exception as err:
|
||||||
|
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
||||||
|
|
||||||
self._upsert_node(updated_node)
|
self._upsert_node(updated_node)
|
||||||
|
|
||||||
|
@ -50,6 +46,19 @@ class MongoNodeRepository(INodeRepository):
|
||||||
|
|
||||||
return new_node
|
return new_node
|
||||||
|
|
||||||
|
def upsert_tcp_connections(self, machine_id: MachineID, tcp_connections: TCPConnections):
|
||||||
|
try:
|
||||||
|
node = self.get_node_by_machine_id(machine_id)
|
||||||
|
except UnknownRecordError:
|
||||||
|
node = Node(machine_id=machine_id, connections={})
|
||||||
|
|
||||||
|
for target, connections in tcp_connections.items():
|
||||||
|
if target in node.tcp_connections:
|
||||||
|
node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections})
|
||||||
|
else:
|
||||||
|
node.tcp_connections[target] = connections
|
||||||
|
self._upsert_node(node)
|
||||||
|
|
||||||
def _upsert_node(self, node: Node):
|
def _upsert_node(self, node: Node):
|
||||||
try:
|
try:
|
||||||
result = self._nodes_collection.replace_one(
|
result = self._nodes_collection.replace_one(
|
||||||
|
@ -58,18 +67,20 @@ class MongoNodeRepository(INodeRepository):
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
||||||
|
|
||||||
if result.matched_count != 0 and result.modified_count != 1:
|
|
||||||
raise StorageError(
|
|
||||||
f'Error updating node with source ID "{node.machine_id}": Expected to update 1 '
|
|
||||||
f"node, but {result.modified_count} were updated"
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.matched_count == 0 and result.upserted_id is None:
|
if result.matched_count == 0 and result.upserted_id is None:
|
||||||
raise StorageError(
|
raise StorageError(
|
||||||
f'Error inserting node with source ID "{node.machine_id}": Expected to insert 1 '
|
f'Error inserting node with source ID "{node.machine_id}": Expected to insert 1 '
|
||||||
f"node, but no nodes were inserted"
|
f"node, but no nodes were inserted"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_node_by_machine_id(self, machine_id: MachineID) -> Node:
|
||||||
|
node_dict = self._nodes_collection.find_one(
|
||||||
|
{SRC_FIELD_NAME: machine_id}, {MONGO_OBJECT_ID_KEY: False}
|
||||||
|
)
|
||||||
|
if not node_dict:
|
||||||
|
raise UnknownRecordError(f"Node with machine ID {machine_id}")
|
||||||
|
return Node(**node_dict)
|
||||||
|
|
||||||
def get_nodes(self) -> Sequence[Node]:
|
def get_nodes(self) -> Sequence[Node]:
|
||||||
try:
|
try:
|
||||||
cursor = self._nodes_collection.find({}, {MONGO_OBJECT_ID_KEY: False})
|
cursor = self._nodes_collection.find({}, {MONGO_OBJECT_ID_KEY: False})
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from copy import deepcopy
|
||||||
from ipaddress import IPv4Address, IPv4Interface
|
from ipaddress import IPv4Address, IPv4Interface
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
@ -9,7 +10,7 @@ from common import OperatingSystem
|
||||||
from common.agent_events import PingScanEvent, TCPScanEvent
|
from common.agent_events import PingScanEvent, TCPScanEvent
|
||||||
from common.types import PortStatus, SocketAddress
|
from common.types import PortStatus, SocketAddress
|
||||||
from monkey_island.cc.agent_event_handlers import ScanEventHandler
|
from monkey_island.cc.agent_event_handlers import ScanEventHandler
|
||||||
from monkey_island.cc.models import Agent, CommunicationType, Machine
|
from monkey_island.cc.models import Agent, CommunicationType, Machine, Node
|
||||||
from monkey_island.cc.repository import (
|
from monkey_island.cc.repository import (
|
||||||
IAgentRepository,
|
IAgentRepository,
|
||||||
IMachineRepository,
|
IMachineRepository,
|
||||||
|
@ -29,43 +30,60 @@ SOURCE_MACHINE = Machine(
|
||||||
hardware_id=5,
|
hardware_id=5,
|
||||||
network_interfaces=[IPv4Interface("10.10.10.99/24")],
|
network_interfaces=[IPv4Interface("10.10.10.99/24")],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TARGET_MACHINE_ID = 33
|
||||||
|
TARGET_MACHINE_IP = "10.10.10.1"
|
||||||
TARGET_MACHINE = Machine(
|
TARGET_MACHINE = Machine(
|
||||||
id=33,
|
id=TARGET_MACHINE_ID,
|
||||||
hardware_id=9,
|
hardware_id=9,
|
||||||
network_interfaces=[IPv4Interface("10.10.10.1/24")],
|
network_interfaces=[IPv4Interface(f"{TARGET_MACHINE_IP}/24")],
|
||||||
|
)
|
||||||
|
|
||||||
|
SOURCE_NODE = Node(
|
||||||
|
machine_id=SOURCE_MACHINE.id,
|
||||||
|
connections=[],
|
||||||
|
tcp_connections={
|
||||||
|
44: (SocketAddress(ip="1.1.1.1", port=40), SocketAddress(ip="2.2.2.2", port=50))
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
PING_SCAN_EVENT = PingScanEvent(
|
PING_SCAN_EVENT = PingScanEvent(
|
||||||
source=AGENT_ID,
|
source=AGENT_ID,
|
||||||
target=IPv4Address("10.10.10.1"),
|
target=IPv4Address(TARGET_MACHINE_IP),
|
||||||
response_received=True,
|
response_received=True,
|
||||||
os=OperatingSystem.LINUX,
|
os=OperatingSystem.LINUX,
|
||||||
)
|
)
|
||||||
|
|
||||||
PING_SCAN_EVENT_NO_RESPONSE = PingScanEvent(
|
PING_SCAN_EVENT_NO_RESPONSE = PingScanEvent(
|
||||||
source=AGENT_ID,
|
source=AGENT_ID,
|
||||||
target=IPv4Address("10.10.10.1"),
|
target=IPv4Address(TARGET_MACHINE_IP),
|
||||||
response_received=False,
|
response_received=False,
|
||||||
os=OperatingSystem.LINUX,
|
os=OperatingSystem.LINUX,
|
||||||
)
|
)
|
||||||
|
|
||||||
PING_SCAN_EVENT_NO_OS = PingScanEvent(
|
PING_SCAN_EVENT_NO_OS = PingScanEvent(
|
||||||
source=AGENT_ID,
|
source=AGENT_ID,
|
||||||
target=IPv4Address("10.10.10.1"),
|
target=IPv4Address(TARGET_MACHINE_IP),
|
||||||
response_received=True,
|
response_received=True,
|
||||||
os=None,
|
os=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
TCP_SCAN_EVENT = TCPScanEvent(
|
TCP_SCAN_EVENT = TCPScanEvent(
|
||||||
source=AGENT_ID,
|
source=AGENT_ID,
|
||||||
target=IPv4Address("10.10.10.1"),
|
target=IPv4Address(TARGET_MACHINE_IP),
|
||||||
ports={22: PortStatus.OPEN, 8080: PortStatus.CLOSED},
|
ports={22: PortStatus.OPEN, 80: PortStatus.OPEN, 8080: PortStatus.CLOSED},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TCP_CONNECTIONS = {
|
||||||
|
TARGET_MACHINE_ID: (
|
||||||
|
SocketAddress(ip=TARGET_MACHINE_IP, port=22),
|
||||||
|
SocketAddress(ip=TARGET_MACHINE_IP, port=80),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
TCP_SCAN_EVENT_CLOSED = TCPScanEvent(
|
TCP_SCAN_EVENT_CLOSED = TCPScanEvent(
|
||||||
source=AGENT_ID,
|
source=AGENT_ID,
|
||||||
target=IPv4Address("10.10.10.1"),
|
target=IPv4Address(TARGET_MACHINE_IP),
|
||||||
ports={145: PortStatus.CLOSED, 8080: PortStatus.CLOSED},
|
ports={145: PortStatus.CLOSED, 8080: PortStatus.CLOSED},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -91,6 +109,8 @@ def machine_repository() -> IMachineRepository:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def node_repository() -> INodeRepository:
|
def node_repository() -> INodeRepository:
|
||||||
node_repository = MagicMock(spec=INodeRepository)
|
node_repository = MagicMock(spec=INodeRepository)
|
||||||
|
node_repository.get_nodes.return_value = [deepcopy(SOURCE_NODE)]
|
||||||
|
node_repository.upsert_node = MagicMock()
|
||||||
node_repository.upsert_communication = MagicMock()
|
node_repository.upsert_communication = MagicMock()
|
||||||
return node_repository
|
return node_repository
|
||||||
|
|
||||||
|
@ -103,7 +123,7 @@ def scan_event_handler(agent_repository, machine_repository, node_repository):
|
||||||
MACHINES_BY_ID = {MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
|
MACHINES_BY_ID = {MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
|
||||||
MACHINES_BY_IP = {
|
MACHINES_BY_IP = {
|
||||||
IPv4Address("10.10.10.99"): [SOURCE_MACHINE],
|
IPv4Address("10.10.10.99"): [SOURCE_MACHINE],
|
||||||
IPv4Address("10.10.10.1"): [TARGET_MACHINE],
|
IPv4Address(TARGET_MACHINE_IP): [TARGET_MACHINE],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -186,6 +206,44 @@ def test_tcp_scan_event_target_machine_not_exists(
|
||||||
machine_repository.upsert_machine.assert_called_with(expected_machine)
|
machine_repository.upsert_machine.assert_called_with(expected_machine)
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_tcp_scan_event__no_open_ports(
|
||||||
|
scan_event_handler, machine_repository, node_repository
|
||||||
|
):
|
||||||
|
event = TCP_SCAN_EVENT_CLOSED
|
||||||
|
scan_event_handler._update_nodes = MagicMock()
|
||||||
|
scan_event_handler.handle_tcp_scan_event(event)
|
||||||
|
|
||||||
|
assert not node_repository.upsert_tcp_connections.called
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_tcp_scan_event__ports_found(
|
||||||
|
scan_event_handler, machine_repository, node_repository
|
||||||
|
):
|
||||||
|
event = TCP_SCAN_EVENT
|
||||||
|
scan_event_handler._update_nodes = MagicMock()
|
||||||
|
node_repository.get_node_by_machine_id.return_value = SOURCE_NODE
|
||||||
|
scan_event_handler.handle_tcp_scan_event(event)
|
||||||
|
|
||||||
|
call_args = node_repository.upsert_tcp_connections.call_args[0]
|
||||||
|
assert call_args[0] == MACHINE_ID
|
||||||
|
assert TARGET_MACHINE_ID in call_args[1]
|
||||||
|
open_socket_addresses = call_args[1][TARGET_MACHINE_ID]
|
||||||
|
assert set(open_socket_addresses) == set(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
||||||
|
assert len(open_socket_addresses) == len(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_tcp_scan_event__no_source(
|
||||||
|
caplog, scan_event_handler, machine_repository, node_repository
|
||||||
|
):
|
||||||
|
event = TCP_SCAN_EVENT
|
||||||
|
node_repository.get_node_by_machine_id = MagicMock(side_effect=UnknownRecordError("no source"))
|
||||||
|
scan_event_handler._update_nodes = MagicMock()
|
||||||
|
|
||||||
|
scan_event_handler.handle_tcp_scan_event(event)
|
||||||
|
assert "ERROR" in caplog.text
|
||||||
|
assert "no source" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"event,handler",
|
"event,handler",
|
||||||
[(PING_SCAN_EVENT, HANDLE_PING_SCAN_METHOD), (TCP_SCAN_EVENT, HANDLE_TCP_SCAN_METHOD)],
|
[(PING_SCAN_EVENT, HANDLE_PING_SCAN_METHOD), (TCP_SCAN_EVENT, HANDLE_TCP_SCAN_METHOD)],
|
||||||
|
|
|
@ -3,6 +3,7 @@ from unittest.mock import MagicMock
|
||||||
import mongomock
|
import mongomock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from common.types import SocketAddress
|
||||||
from monkey_island.cc.models import CommunicationType, Node
|
from monkey_island.cc.models import CommunicationType, Node
|
||||||
from monkey_island.cc.repository import (
|
from monkey_island.cc.repository import (
|
||||||
INodeRepository,
|
INodeRepository,
|
||||||
|
@ -10,8 +11,17 @@ from monkey_island.cc.repository import (
|
||||||
RemovalError,
|
RemovalError,
|
||||||
RetrievalError,
|
RetrievalError,
|
||||||
StorageError,
|
StorageError,
|
||||||
|
UnknownRecordError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TARGET_MACHINE_IP = "2.2.2.2"
|
||||||
|
|
||||||
|
TCP_CONNECTION_PORT_22 = {3: (SocketAddress(ip=TARGET_MACHINE_IP, port=22),)}
|
||||||
|
TCP_CONNECTION_PORT_80 = {3: (SocketAddress(ip=TARGET_MACHINE_IP, port=80),)}
|
||||||
|
ALL_TCP_CONNECTIONS = {
|
||||||
|
3: (SocketAddress(ip=TARGET_MACHINE_IP, port=22), SocketAddress(ip=TARGET_MACHINE_IP, port=80))
|
||||||
|
}
|
||||||
|
|
||||||
NODES = (
|
NODES = (
|
||||||
Node(
|
Node(
|
||||||
machine_id=1,
|
machine_id=1,
|
||||||
|
@ -23,6 +33,7 @@ NODES = (
|
||||||
Node(
|
Node(
|
||||||
machine_id=2,
|
machine_id=2,
|
||||||
connections={1: frozenset((CommunicationType.CC,))},
|
connections={1: frozenset((CommunicationType.CC,))},
|
||||||
|
tcp_connections=TCP_CONNECTION_PORT_22,
|
||||||
),
|
),
|
||||||
Node(
|
Node(
|
||||||
machine_id=3,
|
machine_id=3,
|
||||||
|
@ -32,10 +43,7 @@ NODES = (
|
||||||
5: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
|
5: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
Node(
|
Node(machine_id=4, connections={}, tcp_connections=ALL_TCP_CONNECTIONS),
|
||||||
machine_id=4,
|
|
||||||
connections={},
|
|
||||||
),
|
|
||||||
Node(
|
Node(
|
||||||
machine_id=5,
|
machine_id=5,
|
||||||
connections={
|
connections={
|
||||||
|
@ -163,21 +171,6 @@ def test_upsert_communication__replace_one_fails(
|
||||||
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
|
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
|
||||||
|
|
||||||
|
|
||||||
def test_upsert_communication__replace_one_matched_without_modify(
|
|
||||||
error_raising_mock_mongo_client, error_raising_node_repository
|
|
||||||
):
|
|
||||||
mock_result = MagicMock()
|
|
||||||
mock_result.matched_count = 1
|
|
||||||
mock_result.modified_count = 0
|
|
||||||
error_raising_mock_mongo_client.monkey_island.nodes.find_one = MagicMock(return_value=None)
|
|
||||||
error_raising_mock_mongo_client.monkey_island.nodes.replace_one = MagicMock(
|
|
||||||
return_value=mock_result
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(StorageError):
|
|
||||||
error_raising_node_repository.upsert_communication(1, 2, CommunicationType.SCANNED)
|
|
||||||
|
|
||||||
|
|
||||||
def test_upsert_communication__replace_one_insert_fails(
|
def test_upsert_communication__replace_one_insert_fails(
|
||||||
error_raising_mock_mongo_client, error_raising_node_repository
|
error_raising_mock_mongo_client, error_raising_node_repository
|
||||||
):
|
):
|
||||||
|
@ -216,3 +209,43 @@ def test_reset(node_repository):
|
||||||
def test_reset__removal_error(error_raising_node_repository):
|
def test_reset__removal_error(error_raising_node_repository):
|
||||||
with pytest.raises(RemovalError):
|
with pytest.raises(RemovalError):
|
||||||
error_raising_node_repository.reset()
|
error_raising_node_repository.reset()
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_tcp_connections__empty_connections(node_repository):
|
||||||
|
node_repository.upsert_tcp_connections(1, TCP_CONNECTION_PORT_22)
|
||||||
|
nodes = node_repository.get_nodes()
|
||||||
|
for node in nodes:
|
||||||
|
if node.machine_id == 1:
|
||||||
|
assert node.tcp_connections == TCP_CONNECTION_PORT_22
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_tcp_connections__upsert_new_port(node_repository):
|
||||||
|
node_repository.upsert_tcp_connections(2, TCP_CONNECTION_PORT_80)
|
||||||
|
nodes = node_repository.get_nodes()
|
||||||
|
modified_node = [node for node in nodes if node.machine_id == 2][0]
|
||||||
|
assert set(modified_node.tcp_connections) == set(ALL_TCP_CONNECTIONS)
|
||||||
|
assert len(modified_node.tcp_connections) == len(ALL_TCP_CONNECTIONS)
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_tcp_connections__port_already_present(node_repository):
|
||||||
|
node_repository.upsert_tcp_connections(4, TCP_CONNECTION_PORT_80)
|
||||||
|
nodes = node_repository.get_nodes()
|
||||||
|
modified_node = [node for node in nodes if node.machine_id == 4][0]
|
||||||
|
assert set(modified_node.tcp_connections) == set(ALL_TCP_CONNECTIONS)
|
||||||
|
assert len(modified_node.tcp_connections) == len(ALL_TCP_CONNECTIONS)
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_tcp_connections__node_missing(node_repository):
|
||||||
|
node_repository.upsert_tcp_connections(999, TCP_CONNECTION_PORT_80)
|
||||||
|
nodes = node_repository.get_nodes()
|
||||||
|
modified_node = [node for node in nodes if node.machine_id == 999][0]
|
||||||
|
assert set(modified_node.tcp_connections) == set(TCP_CONNECTION_PORT_80)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_by_machine_id(node_repository):
|
||||||
|
assert node_repository.get_node_by_machine_id(1) == NODES[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_by_machine_id__no_node(node_repository):
|
||||||
|
with pytest.raises(UnknownRecordError):
|
||||||
|
node_repository.get_node_by_machine_id(999)
|
||||||
|
|
Loading…
Reference in New Issue