Compare commits

...

13 Commits

Author SHA1 Message Date
Mike Salvatore 563957f9c2 Island: Add NodeUpdateFacade.get_event_source_machine() 2022-10-07 10:38:12 -04:00
Mike Salvatore b6a6295ae8 Island: Remove disused agent_event_handlers/utils.py
Replaced by NodeUpdateFacade
2022-10-07 10:38:12 -04:00
Mike Salvatore e876682d84 Island: Use NodeUpdateFacade in ScanEventHandler 2022-10-07 10:38:12 -04:00
Mike Salvatore e77932f7d6 Island: Add NodeUpdateFacade 2022-10-07 10:38:12 -04:00
Mike Salvatore e1f32177e9 Island: Call get_or_create_target_machine() from ScanEventHandler 2022-10-07 10:38:12 -04:00
Mike Salvatore c4052bc5ad Island: Add utils.get_or_create_target_machine()
Both ScanEventHandler and update_nodes_on_exploitation() will need this
functionality. Extracting it some place common.

I didn't put it into the MachineRepository because the semantics of
creating a machine if not found are likely specific to a small set of
use cases, rather than part of the general interface.
2022-10-07 10:38:12 -04:00
Mike Salvatore a7d7c1a787 UT: Add missing __init__.py 2022-10-07 10:38:12 -04:00
vakarisz e54c950dc3 Island: Upsert node on TCP scan event if source of event don't exist 2022-10-07 17:21:28 +03:00
vakarisz d3c2d95a69 Island: Handle network services in TCP scan events 2022-10-07 16:12:01 +03:00
vakarisz c5c8bc1d2f Island: Add mongo_dot_encoder to encode "." characters
This encoder will be needed in mongo repository, because mongodb can't handle keys with "." character (until version 5)
2022-10-07 16:12:01 +03:00
vakarisz a96b82fa0f Island: Don't raise errors if machine upsert did no changes
It doesn't make sense to raise an error if upsert did no changes, because the purpose of "upsert" method is to ensure that data is up-to-date. If no changes were made it means it's already up-to-date.
2022-10-07 16:12:01 +03:00
vakarisz a143d7206e Island: Reuse the same open port logic in scan_event_handler.py 2022-10-07 16:11:59 +03:00
vakarisz d0d37ce595 Island: Update machine services with TCP scan event results 2022-10-07 16:08:35 +03:00
15 changed files with 353 additions and 58 deletions

View File

@ -1,3 +1,4 @@
from .save_event_to_event_repository import save_event_to_event_repository
from .save_stolen_credentials_to_repository import save_stolen_credentials_to_repository
from .scan_event_handler import ScanEventHandler
from .update_nodes_on_exploitation import update_nodes_on_exploitation

View File

@ -0,0 +1,33 @@
from functools import lru_cache
from ipaddress import IPv4Address, IPv4Interface
from common.agent_events import AbstractAgentEvent
from common.types import AgentID, MachineID
from monkey_island.cc.models import Machine
from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError
class NodeUpdateFacade:
def __init__(self, agent_repository: IAgentRepository, machine_repository: IMachineRepository):
self._agent_repository = agent_repository
self._machine_repository = machine_repository
def get_or_create_target_machine(self, target: IPv4Address):
try:
target_machines = self._machine_repository.get_machines_by_ip(target)
return target_machines[0]
except UnknownRecordError:
machine = Machine(
id=self._machine_repository.get_new_id(),
network_interfaces=[IPv4Interface(target)],
)
self._machine_repository.upsert_machine(machine)
return machine
def get_event_source_machine(self, event: AbstractAgentEvent) -> Machine:
machine_id = self._get_machine_id_from_agent_id(event.source)
return self._machine_repository.get_machine_by_id(machine_id)
@lru_cache(maxsize=None)
def _get_machine_id_from_agent_id(self, agent_id: AgentID) -> MachineID:
return self._agent_repository.get_agent_by_id(agent_id).machine_id

View File

@ -1,11 +1,10 @@
from ipaddress import IPv4Interface
from logging import getLogger
from typing import Union
from typing import List, Union
from typing_extensions import TypeAlias
from common.agent_events import PingScanEvent, TCPScanEvent
from common.types import PortStatus, SocketAddress
from common.types import NetworkService, PortStatus, SocketAddress
from monkey_island.cc.models import CommunicationType, Machine, Node
from monkey_island.cc.repository import (
IAgentRepository,
@ -16,6 +15,8 @@ from monkey_island.cc.repository import (
UnknownRecordError,
)
from .node_update_facade import NodeUpdateFacade
ScanEvent: TypeAlias = Union[PingScanEvent, TCPScanEvent]
logger = getLogger(__name__)
@ -32,6 +33,7 @@ class ScanEventHandler:
machine_repository: IMachineRepository,
node_repository: INodeRepository,
):
self._node_update_facade = NodeUpdateFacade(agent_repository, machine_repository)
self._agent_repository = agent_repository
self._machine_repository = machine_repository
self._node_repository = node_repository
@ -49,7 +51,7 @@ class ScanEventHandler:
logger.exception("Unable to process ping scan data")
def handle_tcp_scan_event(self, event: TCPScanEvent):
num_open_ports = sum((1 for status in event.ports.values() if status == PortStatus.OPEN))
num_open_ports = len(self._get_open_ports(event))
if num_open_ports <= 0:
return
@ -60,24 +62,21 @@ class ScanEventHandler:
self._update_nodes(target_machine, event)
self._update_tcp_connections(source_node, target_machine, event)
self._update_network_services(target_machine, event)
except (RetrievalError, StorageError, UnknownRecordError):
logger.exception("Unable to process tcp scan data")
def _get_target_machine(self, event: ScanEvent) -> Machine:
try:
target_machines = self._machine_repository.get_machines_by_ip(event.target)
return target_machines[0]
except UnknownRecordError:
machine = Machine(
id=self._machine_repository.get_new_id(),
network_interfaces=[IPv4Interface(event.target)],
)
self._machine_repository.upsert_machine(machine)
return machine
return self._node_update_facade.get_or_create_target_machine(event.target)
def _get_source_node(self, event: ScanEvent) -> Node:
machine = self._get_source_machine(event)
return self._node_repository.get_node_by_machine_id(machine.id)
try:
node = self._node_repository.get_node_by_machine_id(machine.id)
except UnknownRecordError:
node = Node(machine_id=machine.id)
self._node_repository.upsert_node(node)
return node
def _get_source_machine(self, event: ScanEvent) -> Machine:
agent = self._agent_repository.get_agent_by_id(event.source)
@ -88,6 +87,17 @@ class ScanEventHandler:
machine.operating_system = event.os
self._machine_repository.upsert_machine(machine)
def _update_network_services(self, target: Machine, event: TCPScanEvent):
network_services = {
SocketAddress(ip=event.target, port=port): NetworkService.UNKNOWN
for port in self._get_open_ports(event)
}
self._machine_repository.upsert_network_services(target.id, network_services)
@staticmethod
def _get_open_ports(event: TCPScanEvent) -> List[int]:
return [port for port, status in event.ports.items() if status == PortStatus.OPEN]
def _update_nodes(self, target_machine: Machine, event: ScanEvent):
src_machine = self._get_source_machine(event)
@ -97,7 +107,7 @@ class ScanEventHandler:
def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent):
tcp_connections = set()
open_ports = (port for port, status in event.ports.items() if status == PortStatus.OPEN)
open_ports = self._get_open_ports(event)
for open_port in open_ports:
socket_address = SocketAddress(ip=event.target, port=open_port)
tcp_connections.add(socket_address)

View File

@ -3,6 +3,7 @@ from ipaddress import IPv4Interface
from typing import Any, Dict, Mapping, Optional, Sequence
from pydantic import Field, validator
from typing_extensions import TypeAlias
from common import OperatingSystem
from common.base_models import MutableInfectionMonkeyBaseModel, MutableInfectionMonkeyModelConfig
@ -11,6 +12,8 @@ from common.types import HardwareID, NetworkService, SocketAddress
from . import MachineID
NetworkServices: TypeAlias = Dict[SocketAddress, NetworkService]
def _serialize_network_services(machine_dict: Dict, *, default):
machine_dict["network_services"] = {
@ -61,7 +64,7 @@ class Machine(MutableInfectionMonkeyBaseModel):
hostname: str = ""
"""The hostname of the machine"""
network_services: Mapping[SocketAddress, NetworkService] = Field(default_factory=dict)
network_services: NetworkServices = Field(default_factory=dict)
"""All network services found running on the machine"""
_make_immutable_sequence = validator("network_interfaces", pre=True, allow_reuse=True)(

View File

@ -24,7 +24,7 @@ class Node(MutableInfectionMonkeyBaseModel):
machine_id: MachineID = Field(..., allow_mutation=False)
"""The MachineID of the node (source)"""
connections: NodeConnections
connections: NodeConnections = {}
"""All outbound connections from this node to other machines"""
tcp_connections: TCPConnections = {}

View File

@ -4,6 +4,7 @@ from typing import Sequence
from common.types import HardwareID
from monkey_island.cc.models import Machine, MachineID
from monkey_island.cc.models.machine import NetworkServices
class IMachineRepository(ABC):
@ -29,6 +30,16 @@ class IMachineRepository(ABC):
:raises StorageError: If an error occurs while attempting to store the `Machine`
"""
@abstractmethod
def upsert_network_services(self, machine_id: MachineID, services: NetworkServices):
"""
Add/update network services on the machine
:param machine_id: ID of machine with services to be updated
:param services: Network services to be added to machine model
:raises UnknownRecordError: If the Machine is not found
:raises StorageError: If an error occurs while attempting to add/store the services
"""
@abstractmethod
def get_machine_by_id(self, machine_id: MachineID) -> Machine:
"""

View File

@ -44,6 +44,14 @@ class INodeRepository(ABC):
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
"""
@abstractmethod
def upsert_node(self, node: Node):
"""
Update or insert Node model into the database
:param node: Node model to be added to the repository
:raises StorageError: If something went wrong when upserting the Node
"""
@abstractmethod
def get_node_by_machine_id(self, machine_id: MachineID) -> Node:
"""

View File

@ -7,8 +7,10 @@ from pymongo import MongoClient
from common.types import HardwareID
from monkey_island.cc.models import Machine, MachineID
from ..models.machine import NetworkServices
from . import IMachineRepository, RemovalError, RetrievalError, StorageError, UnknownRecordError
from .consts import MONGO_OBJECT_ID_KEY
from .utils import DOT_REPLACEMENT, mongo_dot_decoder, mongo_dot_encoder
class MongoMachineRepository(IMachineRepository):
@ -32,26 +34,32 @@ class MongoMachineRepository(IMachineRepository):
def upsert_machine(self, machine: Machine):
try:
machine_dict = mongo_dot_encoder(machine.dict(simplify=True))
result = self._machines_collection.replace_one(
{"id": machine.id}, machine.dict(simplify=True), upsert=True
{"id": machine.id}, machine_dict, upsert=True
)
except Exception as err:
raise StorageError(f'Error updating machine with ID "{machine.id}": {err}')
if result.matched_count != 0 and result.modified_count != 1:
raise StorageError(
f'Error updating machine with ID "{machine.id}": Expected to update 1 machine, '
f"but {result.modified_count} were updated"
)
if result.matched_count == 0 and result.upserted_id is None:
raise StorageError(
f'Error inserting machine with ID "{machine.id}": Expected to insert 1 machine, '
f"but no machines were inserted"
)
def upsert_network_services(self, machine_id: MachineID, services: NetworkServices):
machine = self.get_machine_by_id(machine_id)
try:
machine.network_services.update(services)
self.upsert_machine(machine)
except Exception as err:
raise StorageError(f"Failed upserting the machine or adding services") from err
def get_machine_by_id(self, machine_id: MachineID) -> Machine:
return self._find_one("id", machine_id)
machine = self._find_one("id", machine_id)
if not machine:
raise UnknownRecordError(f"Machine with id {machine_id} not found")
return machine
def get_machine_by_hardware_id(self, hardware_id: HardwareID) -> Machine:
return self._find_one("hardware_id", hardware_id)
@ -67,6 +75,7 @@ class MongoMachineRepository(IMachineRepository):
if machine_dict is None:
raise UnknownRecordError(f'Unknown machine with "{key} == {search_value}"')
machine_dict = mongo_dot_decoder(machine_dict)
return Machine(**machine_dict)
def get_machines(self) -> Sequence[Machine]:
@ -75,10 +84,10 @@ class MongoMachineRepository(IMachineRepository):
except Exception as err:
raise RetrievalError(f"Error retrieving machines: {err}")
return [Machine(**m) for m in cursor]
return [Machine(**mongo_dot_decoder(m)) for m in cursor]
def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]:
ip_regex = "^" + str(ip).replace(".", "\\.") + "\\/.*$"
ip_regex = "^" + str(ip).replace(".", DOT_REPLACEMENT) + "\\/.*$"
query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}}
try:
@ -86,7 +95,7 @@ class MongoMachineRepository(IMachineRepository):
except Exception as err:
raise RetrievalError(f'Error retrieving machines with ip "{ip}": {err}')
machines = [Machine(**m) for m in cursor]
machines = [Machine(**mongo_dot_decoder(m)) for m in cursor]
if len(machines) == 0:
raise UnknownRecordError(f'No machines found with IP "{ip}"')

View File

@ -30,7 +30,7 @@ class MongoNodeRepository(INodeRepository):
except Exception as err:
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
self._upsert_node(updated_node)
self.upsert_node(updated_node)
@staticmethod
def _add_connection_to_node(
@ -57,9 +57,9 @@ class MongoNodeRepository(INodeRepository):
node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections})
else:
node.tcp_connections[target] = connections
self._upsert_node(node)
self.upsert_node(node)
def _upsert_node(self, node: Node):
def upsert_node(self, node: Node):
try:
result = self._nodes_collection.replace_one(
{SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True

View File

@ -1,12 +1,14 @@
import json
import platform
from socket import gethostname
from typing import Any, Mapping
from uuid import getnode
from common import OperatingSystem
from common.network.network_utils import get_network_interfaces
from monkey_island.cc.models import Machine
from . import IMachineRepository, UnknownRecordError
from . import IMachineRepository, StorageError, UnknownRecordError
def initialize_machine_repository(machine_repository: IMachineRepository):
@ -33,3 +35,34 @@ def initialize_machine_repository(machine_repository: IMachineRepository):
hostname=gethostname(),
)
machine_repository.upsert_machine(machine)
DOT_REPLACEMENT = ",,,"
def mongo_dot_encoder(mapping: Mapping[str, Any]) -> Mapping[str, Any]:
"""
Mongo can't store keys with "." symbols (like IP's and filenames). This method
replaces all occurances of "." with ",,,"
:param mapping: Mapping to be converted to mongo compatible mapping
:return: Mongo compatible mapping
"""
mapping_json = json.dumps(mapping)
if DOT_REPLACEMENT in mapping_json:
raise StorageError(
f"Mapping {mapping} already contains {DOT_REPLACEMENT}."
f" Aborting the encoding procedure"
)
encoded_json = mapping_json.replace(".", DOT_REPLACEMENT)
return json.loads(encoded_json)
def mongo_dot_decoder(mapping: Mapping[str, Any]):
"""
Mongo can't store keys with "." symbols (like IP's and filenames). This method
reverts changes made by "mongo_dot_encoder" by replacing all occurances of ",,," with "."
:param mapping: Mapping to be converted from mongo compatible mapping to original mapping
:return: Original mapping
"""
report_as_json = json.dumps(mapping).replace(DOT_REPLACEMENT, ".")
return json.loads(report_as_json)

View File

@ -0,0 +1,97 @@
from ipaddress import IPv4Address, IPv4Interface
from unittest.mock import MagicMock
from uuid import UUID
import pytest
from common.agent_events import AbstractAgentEvent
from common.types import AgentID, MachineID, SocketAddress
from monkey_island.cc.agent_event_handlers.node_update_facade import NodeUpdateFacade
from monkey_island.cc.models import Agent, Machine
from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError
class TestEvent(AbstractAgentEvent):
success: bool
SEED_ID = 99
IP_ADDRESS = IPv4Address("10.10.10.99")
SOURCE_MACHINE_ID = 1
SOURCE_MACHINE = Machine(
id=SOURCE_MACHINE_ID,
hardware_id=5,
network_interfaces=[IPv4Interface(IP_ADDRESS)],
)
SOURCE_AGENT_ID = UUID("655fd01c-5eec-4e42-b6e3-1fb738c2978d")
SOURCE_AGENT = Agent(
id=SOURCE_AGENT_ID,
machine_id=SOURCE_MACHINE_ID,
start_time=0,
parent_id=None,
cc_server=(SocketAddress(ip="10.10.10.10", port=5000)),
)
EXPECTED_CREATED_MACHINE = Machine(
id=SEED_ID,
network_interfaces=[IPv4Interface(IP_ADDRESS)],
)
TEST_EVENT = TestEvent(source=SOURCE_AGENT_ID, success=True)
@pytest.fixture
def agent_repository() -> IAgentRepository:
def get_agent_by_id(agent_id: AgentID) -> Agent:
if agent_id == SOURCE_AGENT_ID:
return SOURCE_AGENT
raise UnknownRecordError()
agent_repository = MagicMock(spec=IAgentRepository)
agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id)
return agent_repository
@pytest.fixture
def machine_repository() -> IMachineRepository:
def get_machine_by_id(machine_id: MachineID) -> Machine:
if machine_id == SOURCE_MACHINE_ID:
return SOURCE_MACHINE
raise UnknownRecordError()
machine_repository = MagicMock(spec=IMachineRepository)
machine_repository.get_new_id = MagicMock(return_value=SEED_ID)
machine_repository.get_machine_by_id = MagicMock(side_effect=get_machine_by_id)
return machine_repository
@pytest.fixture
def node_update_facade(
agent_repository: IAgentRepository, machine_repository: IMachineRepository
) -> NodeUpdateFacade:
return NodeUpdateFacade(agent_repository, machine_repository)
def test_return_existing_machine(node_update_facade, machine_repository):
machine_repository.get_machines_by_ip = MagicMock(return_value=[SOURCE_MACHINE])
target_machine = node_update_facade.get_or_create_target_machine(IP_ADDRESS)
assert target_machine == SOURCE_MACHINE
def test_create_new_machine(node_update_facade, machine_repository):
machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
target_machine = node_update_facade.get_or_create_target_machine(IP_ADDRESS)
assert target_machine == EXPECTED_CREATED_MACHINE
assert machine_repository.upsert_machine.called_once_with(target_machine)
def test_get_event_source_machine(node_update_facade):
assert node_update_facade.get_event_source_machine(TEST_EVENT) == SOURCE_MACHINE

View File

@ -8,7 +8,7 @@ import pytest
from common import OperatingSystem
from common.agent_events import PingScanEvent, TCPScanEvent
from common.types import PortStatus, SocketAddress
from common.types import NetworkService, PortStatus, SocketAddress
from monkey_island.cc.agent_event_handlers import ScanEventHandler
from monkey_island.cc.models import Agent, CommunicationType, Machine, Node
from monkey_island.cc.repository import (
@ -22,11 +22,13 @@ from monkey_island.cc.repository import (
SEED_ID = 99
AGENT_ID = UUID("1d8ce743-a0f4-45c5-96af-91106529d3e2")
MACHINE_ID = 11
SOURCE_MACHINE_ID = 11
CC_SERVER = SocketAddress(ip="10.10.10.100", port="5000")
AGENT = Agent(id=AGENT_ID, machine_id=MACHINE_ID, start_time=0, parent_id=None, cc_server=CC_SERVER)
AGENT = Agent(
id=AGENT_ID, machine_id=SOURCE_MACHINE_ID, start_time=0, parent_id=None, cc_server=CC_SERVER
)
SOURCE_MACHINE = Machine(
id=MACHINE_ID,
id=SOURCE_MACHINE_ID,
hardware_id=5,
network_interfaces=[IPv4Interface("10.10.10.99/24")],
)
@ -74,6 +76,11 @@ TCP_SCAN_EVENT = TCPScanEvent(
ports={22: PortStatus.OPEN, 80: PortStatus.OPEN, 8080: PortStatus.CLOSED},
)
EXPECTED_NETWORK_SERVICES = {
SocketAddress(ip=TARGET_MACHINE_IP, port=22): NetworkService.UNKNOWN,
SocketAddress(ip=TARGET_MACHINE_IP, port=80): NetworkService.UNKNOWN,
}
TCP_CONNECTIONS = {
TARGET_MACHINE_ID: (
SocketAddress(ip=TARGET_MACHINE_IP, port=22),
@ -120,7 +127,7 @@ def scan_event_handler(agent_repository, machine_repository, node_repository):
return ScanEventHandler(agent_repository, machine_repository, node_repository)
MACHINES_BY_ID = {MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
MACHINES_BY_ID = {SOURCE_MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
MACHINES_BY_IP = {
IPv4Address("10.10.10.99"): [SOURCE_MACHINE],
IPv4Address(TARGET_MACHINE_IP): [TARGET_MACHINE],
@ -225,14 +232,14 @@ def test_handle_tcp_scan_event__ports_found(
scan_event_handler.handle_tcp_scan_event(event)
call_args = node_repository.upsert_tcp_connections.call_args[0]
assert call_args[0] == MACHINE_ID
assert call_args[0] == SOURCE_MACHINE_ID
assert TARGET_MACHINE_ID in call_args[1]
open_socket_addresses = call_args[1][TARGET_MACHINE_ID]
assert set(open_socket_addresses) == set(TCP_CONNECTIONS[TARGET_MACHINE_ID])
assert len(open_socket_addresses) == len(TCP_CONNECTIONS[TARGET_MACHINE_ID])
def test_handle_tcp_scan_event__no_source(
def test_handle_tcp_scan_event__no_source_node(
caplog, scan_event_handler, machine_repository, node_repository
):
event = TCP_SCAN_EVENT
@ -240,8 +247,11 @@ def test_handle_tcp_scan_event__no_source(
scan_event_handler._update_nodes = MagicMock()
scan_event_handler.handle_tcp_scan_event(event)
assert "ERROR" in caplog.text
assert "no source" in caplog.text
expected_node = Node(machine_id=SOURCE_MACHINE_ID)
node_called = node_repository.upsert_node.call_args[0][0]
assert expected_node.machine_id == node_called.machine_id
assert expected_node.connections == node_called.connections
assert expected_node.tcp_connections == node_called.tcp_connections
@pytest.mark.parametrize(
@ -382,3 +392,11 @@ def test_failed_scan(
assert not node_repository.upsert_communication.called
assert not machine_repository.upsert_machine.called
def test_network_services_handling(scan_event_handler, machine_repository):
scan_event_handler.handle_tcp_scan_event(TCP_SCAN_EVENT)
machine_repository.upsert_network_services.assert_called_with(
TARGET_MACHINE_ID, EXPECTED_NETWORK_SERVICES
)

View File

@ -0,0 +1,40 @@
import pytest
from monkey_island.cc.repository import StorageError
from monkey_island.cc.repository.utils import DOT_REPLACEMENT, mongo_dot_decoder, mongo_dot_encoder
DATASET = [
({"no:changes;expectes": "Nothing'$ changed"}, {"no:changes;expectes": "Nothing'$ changed"}),
(
{"192.168.56.1": "monkeys-running-wild.com"},
{
f"192{DOT_REPLACEMENT}168{DOT_REPLACEMENT}56{DOT_REPLACEMENT}1": f"monkeys-running-wild{DOT_REPLACEMENT}com"
},
),
(
{"...dots...": ",comma,comma,,comedy"},
{
f"{DOT_REPLACEMENT}{DOT_REPLACEMENT}{DOT_REPLACEMENT}dots"
f"{DOT_REPLACEMENT}{DOT_REPLACEMENT}{DOT_REPLACEMENT}": ",comma,comma,,comedy"
},
),
(
{"one": {"two": {"three": "this.is.nested"}}},
{"one": {"two": {"three": f"this{DOT_REPLACEMENT}is{DOT_REPLACEMENT}nested"}}},
),
]
# This dict already contains the replacement used, encoding procedure would lose data
FLAWED_DICT = {"one": {".two": {"three": f"this is with {DOT_REPLACEMENT} already!!!!"}}}
@pytest.mark.parametrize("input, expected_output", DATASET)
def test_mongo_dot_encoding_and_decoding(input, expected_output):
encoded = mongo_dot_encoder(input)
assert encoded == expected_output
assert mongo_dot_decoder(encoded) == input
def test_mongo_dot_encoding__data_loss():
with pytest.raises(StorageError):
mongo_dot_encoder(FLAWED_DICT)

View File

@ -6,6 +6,7 @@ import mongomock
import pytest
from common import OperatingSystem
from common.types import NetworkService, SocketAddress
from monkey_island.cc.models import Machine
from monkey_island.cc.repository import (
IMachineRepository,
@ -15,6 +16,7 @@ from monkey_island.cc.repository import (
StorageError,
UnknownRecordError,
)
from monkey_island.cc.repository.utils import mongo_dot_encoder
MACHINES = (
Machine(
@ -32,6 +34,10 @@ MACHINES = (
operating_system=OperatingSystem.WINDOWS,
operating_system_version="eXtra Problems",
hostname="hal",
network_services={
SocketAddress(ip="192.168.1.11", port=80): NetworkService.UNKNOWN,
SocketAddress(ip="192.168.1.12", port=80): NetworkService.UNKNOWN,
},
),
Machine(
id=3,
@ -40,6 +46,10 @@ MACHINES = (
operating_system=OperatingSystem.WINDOWS,
operating_system_version="Vista",
hostname="smith",
network_services={
SocketAddress(ip="192.168.1.11", port=80): NetworkService.UNKNOWN,
SocketAddress(ip="192.168.1.11", port=22): NetworkService.UNKNOWN,
},
),
Machine(
id=4,
@ -51,11 +61,24 @@ MACHINES = (
),
)
SERVICES_TO_ADD = {
SocketAddress(ip="192.168.1.11", port=80): NetworkService.UNKNOWN,
SocketAddress(ip="192.168.1.11", port=22): NetworkService.UNKNOWN,
}
EXPECTED_SERVICES_1 = EXPECTED_SERVICES_3 = SERVICES_TO_ADD
EXPECTED_SERVICES_2 = {
**SERVICES_TO_ADD,
SocketAddress(ip="192.168.1.12", port=80): NetworkService.UNKNOWN,
}
@pytest.fixture
def mongo_client() -> mongomock.MongoClient:
client = mongomock.MongoClient()
client.monkey_island.machines.insert_many((m.dict(simplify=True) for m in MACHINES))
client.monkey_island.machines.insert_many(
(mongo_dot_encoder(m.dict(simplify=True)) for m in MACHINES)
)
return client
@ -146,21 +169,6 @@ def test_upsert_machine__storage_error_exception(error_raising_machine_repositor
error_raising_machine_repository.upsert_machine(machine)
def test_upsert_machine__storage_error_update_failed(error_raising_mock_mongo_client):
mock_result = MagicMock()
mock_result.matched_count = 1
mock_result.modified_count = 0
error_raising_mock_mongo_client.monkey_island.machines.replace_one = MagicMock(
return_value=mock_result
)
machine_repository = MongoMachineRepository(error_raising_mock_mongo_client)
machine = MACHINES[0]
with pytest.raises(StorageError):
machine_repository.upsert_machine(machine)
def test_upsert_machine__storage_error_insert_failed(error_raising_mock_mongo_client):
mock_result = MagicMock()
mock_result.matched_count = 0
@ -279,3 +287,27 @@ def test_usable_after_reset(machine_repository):
def test_reset__removal_error(error_raising_machine_repository):
with pytest.raises(RemovalError):
error_raising_machine_repository.reset()
@pytest.mark.parametrize(
"machine_id, expected_services",
[
(MACHINES[0].id, EXPECTED_SERVICES_1),
(MACHINES[1].id, EXPECTED_SERVICES_2),
(MACHINES[2].id, EXPECTED_SERVICES_3),
],
)
def test_service_upsert(machine_id, expected_services, machine_repository):
machine_repository.upsert_network_services(machine_id, SERVICES_TO_ADD)
assert machine_repository.get_machine_by_id(machine_id).network_services == expected_services
def test_service_upsert__machine_not_found(machine_repository):
with pytest.raises(UnknownRecordError):
machine_repository.upsert_network_services(machine_id=999, services=SERVICES_TO_ADD)
def test_service_upsert__error_on_storage(machine_repository):
malformed_services = 3
with pytest.raises(StorageError):
machine_repository.upsert_network_services(MACHINES[0].id, malformed_services)