Compare commits

...

20 Commits

Author SHA1 Message Date
Kekoa Kaaikala 9d37a38994 Agent: Fix FingerprintData mypy issues 2022-09-21 20:38:19 +00:00
Kekoa Kaaikala e40d061091 Agent: Fix PortScanData mypy issues 2022-09-21 20:30:41 +00:00
Kekoa Kaaikala d4f6c83f56 Agent: Fix VictimHost mypy issues 2022-09-21 20:19:20 +00:00
Kekoa Kaaikala d440d51f53 Island: Fix mypy issues in segmentation.py 2022-09-21 18:29:34 +00:00
Kekoa Kaaikala 218b341006 Island: Fix mypy issues in networkmap.py 2022-09-21 18:28:58 +00:00
Kekoa Kaaikala 24848e62df Island: Fix mypy issues in version.py 2022-09-21 18:28:17 +00:00
Kekoa Kaaikala a9e101fd04 Island: Fix mypy issues in T1065.py 2022-09-21 18:25:38 +00:00
Kekoa Kaaikala 1a6f48614e Island: Fix mypy issues in mongo_db_process.py 2022-09-21 18:24:47 +00:00
Kekoa Kaaikala a5b5449f73 Island: Fix mypy issues in finding_service.py 2022-09-21 18:23:33 +00:00
Kekoa Kaaikala 7013963d59 Island: Fix mypy issues in cred_exploit.py 2022-09-21 18:22:20 +00:00
Kekoa Kaaikala ed5773878e Island: Fix mypy issues in ransomware_report.py 2022-09-21 18:21:42 +00:00
Kekoa Kaaikala c870fde3cc Island: Fix mypy issues in exploit.py 2022-09-21 18:21:03 +00:00
Kekoa Kaaikala e595a70019 Island: Fix mypy issues for encryptors 2022-09-21 18:19:58 +00:00
Kekoa Kaaikala e8aa231f92 Island: Fix mypy issues in AbstractResource.py 2022-09-21 18:09:18 +00:00
Kekoa Kaaikala c0b2981150 Island: Fix mypy issues in i_log_repository.py 2022-09-21 18:08:25 +00:00
Kekoa Kaaikala 7801a98a15 Island: Fix mypy issues in server_setup.py 2022-09-21 18:06:53 +00:00
Kekoa Kaaikala 6209ec47cd Island: Fix mypy issues in app.py 2022-09-21 18:06:19 +00:00
Kekoa Kaaikala 795d9fe201 Agent: Fix mypy issues in ransomware_options.py 2022-09-21 17:58:41 +00:00
Kekoa Kaaikala ea67c07d70 Agent: Fix mypy issues in pba.py 2022-09-21 17:57:44 +00:00
Kekoa Kaaikala 70c74b87a9 Agent: Fix mypy issues in capture_output.py 2022-09-21 17:55:41 +00:00
43 changed files with 222 additions and 176 deletions

View File

@ -193,7 +193,7 @@ class SingleIpRange(NetworkRange):
:return: A tuple in format (IP, domain_name). Eg. (192.168.55.1, www.google.com) :return: A tuple in format (IP, domain_name). Eg. (192.168.55.1, www.google.com)
""" """
# The most common use case is to enter ip/range into "Scan IP/subnet list" # The most common use case is to enter ip/range into "Scan IP/subnet list"
domain_name = None domain_name = ""
if " " in string_: if " " in string_:
raise ValueError(f'"{string_}" is not a valid IP address or domain name.') raise ValueError(f'"{string_}" is not a valid IP address or domain name.')

View File

@ -1,10 +1,10 @@
import abc import abc
import threading import threading
from collections import namedtuple from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Dict, Iterable, List, Mapping, Sequence from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from common import OperatingSystem
from common.credentials import Credentials from common.credentials import Credentials
from infection_monkey.model import VictimHost from infection_monkey.model import VictimHost
@ -26,15 +26,37 @@ class ExploiterResultData:
propagation_success: bool = False propagation_success: bool = False
interrupted: bool = False interrupted: bool = False
os: str = "" os: str = ""
info: Mapping = None info: Mapping = field(default_factory=lambda: {})
attempts: Iterable = None attempts: Iterable = field(default_factory=lambda: [])
error_message: str = "" error_message: str = ""
PingScanData = namedtuple("PingScanData", ["response_received", "os"]) @dataclass(frozen=True)
PortScanData = namedtuple("PortScanData", ["port", "status", "banner", "service"]) class FingerprintData:
FingerprintData = namedtuple("FingerprintData", ["os_type", "os_version", "services"]) os_type: Optional[OperatingSystem]
PostBreachData = namedtuple("PostBreachData", ["display_name", "command", "result"]) os_version: str
services: Mapping = field(default_factory=lambda: {})
@dataclass(frozen=True)
class PingScanData:
response_received: bool
os: Optional[OperatingSystem]
@dataclass(frozen=True)
class PortScanData:
port: int
status: PortStatus
banner: str
service: str
@dataclass(frozen=True)
class PostBreachData:
display_name: str
command: str
result: Tuple[str, bool]
class IPuppet(metaclass=abc.ABCMeta): class IPuppet(metaclass=abc.ABCMeta):
@ -84,7 +106,7 @@ class IPuppet(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def scan_tcp_ports( def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = 3 self, host: str, ports: List[int], timeout: float = 3
) -> Dict[int, PortScanData]: ) -> Mapping[int, PortScanData]:
""" """
Scans a list of TCP ports on a remote host Scans a list of TCP ports on a remote host
@ -92,7 +114,7 @@ class IPuppet(metaclass=abc.ABCMeta):
:param int ports: List of TCP port numbers to scan :param int ports: List of TCP port numbers to scan
:param float timeout: The maximum amount of time (in seconds) to wait for a response :param float timeout: The maximum amount of time (in seconds) to wait for a response
:return: The data collected by scanning the provided host:ports combination :return: The data collected by scanning the provided host:ports combination
:rtype: Dict[int, PortScanData] :rtype: Mapping[int, PortScanData]
""" """
@abc.abstractmethod @abc.abstractmethod
@ -125,6 +147,7 @@ class IPuppet(metaclass=abc.ABCMeta):
name: str, name: str,
host: VictimHost, host: VictimHost,
current_depth: int, current_depth: int,
servers: Sequence[str],
options: Dict, options: Dict,
interrupt: threading.Event, interrupt: threading.Event,
) -> ExploiterResultData: ) -> ExploiterResultData:

View File

@ -2,7 +2,7 @@ import logging
from ipaddress import IPv4Interface from ipaddress import IPv4Interface
from queue import Queue from queue import Queue
from threading import Event from threading import Event
from typing import List, Sequence from typing import Dict, List, Sequence
from common.agent_configuration import ( from common.agent_configuration import (
ExploitationConfiguration, ExploitationConfiguration,
@ -154,11 +154,12 @@ class Propagator:
victim_host.services[psd.service] = {} victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)" victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None: victim_host.services[psd.service]["banner"] = psd.banner
victim_host.services[psd.service]["banner"] = psd.banner
@staticmethod @staticmethod
def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: FingerprintData): def _process_fingerprinter_results(
victim_host: VictimHost, fingerprint_data: Dict[str, FingerprintData]
):
for fd in fingerprint_data.values(): for fd in fingerprint_data.values():
# TODO: This logic preserves the existing behavior prior to introducing IMaster and # TODO: This logic preserves the existing behavior prior to introducing IMaster and
# IPuppet, but it is possibly flawed. Different fingerprinters may detect # IPuppet, but it is possibly flawed. Different fingerprinters may detect
@ -167,7 +168,7 @@ class Propagator:
if fd.os_type is not None: if fd.os_type is not None:
victim_host.os["type"] = fd.os_type victim_host.os["type"] = fd.os_type
if ("version" not in victim_host.os) and (fd.os_version is not None): if ("version" not in victim_host.os) and (fd.os_version):
victim_host.os["version"] = fd.os_version victim_host.os["version"] = fd.os_version
for service, details in fd.services.items(): for service, details in fd.services.items():

View File

@ -10,7 +10,7 @@ class VictimHost(object):
self.os: Dict[str, Any] = {} self.os: Dict[str, Any] = {}
self.services: Dict[str, Any] = {} self.services: Dict[str, Any] = {}
self.icmp = False self.icmp = False
self.default_server = None self.default_server = ""
def as_dict(self): def as_dict(self):
return self.__dict__ return self.__dict__

View File

@ -31,10 +31,10 @@ class ElasticSearchFingerprinter(IFingerprinter):
port_scan_data: Dict[int, PortScanData], port_scan_data: Dict[int, PortScanData],
_options: Dict, _options: Dict,
) -> FingerprintData: ) -> FingerprintData:
services = {} services: Dict[str, Any] = {}
if (ES_PORT not in port_scan_data) or (port_scan_data[ES_PORT].status != PortStatus.OPEN): if (ES_PORT not in port_scan_data) or (port_scan_data[ES_PORT].status != PortStatus.OPEN):
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
try: try:
elasticsearch_info = _query_elasticsearch(host) elasticsearch_info = _query_elasticsearch(host)
@ -42,7 +42,7 @@ class ElasticSearchFingerprinter(IFingerprinter):
except Exception as ex: except Exception as ex:
logger.debug(f"Did not detect an ElasticSearch cluster: {ex}") logger.debug(f"Did not detect an ElasticSearch cluster: {ex}")
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
def _query_elasticsearch(host: str) -> Dict[str, Any]: def _query_elasticsearch(host: str) -> Dict[str, Any]:

View File

@ -1,6 +1,6 @@
import logging import logging
from contextlib import closing from contextlib import closing
from typing import Any, Dict, Iterable, Optional, Set, Tuple from typing import Any, Dict, Iterable, Mapping, Optional, Set, Tuple
from requests import head from requests import head
from requests.exceptions import ConnectionError, Timeout from requests.exceptions import ConnectionError, Timeout
@ -46,7 +46,7 @@ class HTTPFingerprinter(IFingerprinter):
"data": (server_header_contents, ssl), "data": (server_header_contents, ssl),
} }
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
def _query_potential_http_server(host: str, port: int) -> Tuple[Optional[str], Optional[bool]]: def _query_potential_http_server(host: str, port: int) -> Tuple[Optional[str], Optional[bool]]:
@ -71,7 +71,7 @@ def _get_server_from_headers(url: str) -> Optional[str]:
return None return None
def _get_http_headers(url: str) -> Optional[Dict[str, Any]]: def _get_http_headers(url: str) -> Optional[Mapping[str, Any]]:
try: try:
logger.debug(f"Sending request for headers to {url}") logger.debug(f"Sending request for headers to {url}")
with closing(head(url, verify=False, timeout=1)) as response: # noqa: DUO123 with closing(head(url, verify=False, timeout=1)) as response: # noqa: DUO123

View File

@ -1,7 +1,7 @@
import errno import errno
import logging import logging
import socket import socket
from typing import Any, Dict, Optional from typing import Any, Dict
from infection_monkey.i_puppet import FingerprintData, IFingerprinter, PingScanData, PortScanData from infection_monkey.i_puppet import FingerprintData, IFingerprinter, PingScanData, PortScanData
@ -32,10 +32,10 @@ class MSSQLFingerprinter(IFingerprinter):
except Exception as ex: except Exception as ex:
logger.debug(f"Did not detect an MSSQL server: {ex}") logger.debug(f"Did not detect an MSSQL server: {ex}")
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
def _query_mssql_for_instance_data(host: str) -> Optional[bytes]: def _query_mssql_for_instance_data(host: str) -> bytes:
# Create a UDP socket and sets a timeout # Create a UDP socket and sets a timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(_MSSQL_SOCKET_TIMEOUT) sock.settimeout(_MSSQL_SOCKET_TIMEOUT)
@ -44,22 +44,20 @@ def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
# The message is a CLNT_UCAST_EX packet to get all instances # The message is a CLNT_UCAST_EX packet to get all instances
# https://msdn.microsoft.com/en-us/library/cc219745.aspx # https://msdn.microsoft.com/en-us/library/cc219745.aspx
message = "\x03" message_str = "\x03"
# Encode the message as a bytes array # Encode the message as a bytes array
message = message.encode() message = message_str.encode()
# send data and receive response # send data and receive response
try: try:
logger.info(f"Sending message to requested host: {host}, {message}") logger.info(f"Sending message to requested host: {host}, {message_str}")
sock.sendto(message, server_address) sock.sendto(message, server_address)
data, _ = sock.recvfrom(_BUFFER_SIZE) data, _ = sock.recvfrom(_BUFFER_SIZE)
return data return data
except socket.timeout as err: except socket.timeout as err:
logger.debug( logger.debug(f"Socket timeout reached, maybe browser service on host: {host} doesn't exist")
f"Socket timeout reached, maybe browser service on host: {host} doesnt " "exist"
)
raise err raise err
except socket.error as err: except socket.error as err:
if err.errno == errno.ECONNRESET: if err.errno == errno.ECONNRESET:
@ -78,7 +76,7 @@ def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
def _get_services_from_server_data(data: bytes) -> Dict[str, Any]: def _get_services_from_server_data(data: bytes) -> Dict[str, Any]:
services = {MSSQL_SERVICE: {}} services: Dict[str, Any] = {MSSQL_SERVICE: {}}
services[MSSQL_SERVICE]["display_name"] = DISPLAY_NAME services[MSSQL_SERVICE]["display_name"] = DISPLAY_NAME
services[MSSQL_SERVICE]["port"] = SQL_BROWSER_DEFAULT_PORT services[MSSQL_SERVICE]["port"] = SQL_BROWSER_DEFAULT_PORT

View File

@ -26,10 +26,10 @@ def compile_scan_target_list(
scan_targets.extend(_get_ips_to_scan_from_local_interface(local_network_interfaces)) scan_targets.extend(_get_ips_to_scan_from_local_interface(local_network_interfaces))
if inaccessible_subnets: if inaccessible_subnets:
inaccessible_subnets = _get_segmentation_check_targets( inaccessible_subnet_addresses = _get_segmentation_check_targets(
inaccessible_subnets, local_network_interfaces inaccessible_subnets, local_network_interfaces
) )
scan_targets.extend(inaccessible_subnets) scan_targets.extend(inaccessible_subnet_addresses)
scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces) scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips) scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
@ -44,7 +44,7 @@ def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddr
for target in targets: for target in targets:
domain_name = target.domain domain_name = target.domain
ip = target.ip ip = target.ip
if ip not in reverse_dns or (reverse_dns[ip] is None and domain_name is not None): if ip not in reverse_dns or (not reverse_dns[ip] and domain_name):
reverse_dns[ip] = domain_name reverse_dns[ip] = domain_name
return [NetworkAddress(key, value) for (key, value) in reverse_dns.items()] return [NetworkAddress(key, value) for (key, value) in reverse_dns.items()]
@ -55,7 +55,7 @@ def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
if hasattr(range_obj, "domain_name"): if hasattr(range_obj, "domain_name"):
addresses.append(NetworkAddress(address, range_obj.domain_name)) addresses.append(NetworkAddress(address, range_obj.domain_name))
else: else:
addresses.append(NetworkAddress(address, None)) addresses.append(NetworkAddress(address, ""))
return addresses return addresses

View File

@ -144,16 +144,16 @@ class SMBFingerprinter(IFingerprinter):
port_scan_data: Dict[int, PortScanData], port_scan_data: Dict[int, PortScanData],
_options: Dict, _options: Dict,
) -> FingerprintData: ) -> FingerprintData:
services = {} services: Dict = {}
smb_service = { smb_service = {
"display_name": DISPLAY_NAME, "display_name": DISPLAY_NAME,
"port": SMB_PORT, "port": SMB_PORT,
} }
os_type = None os_type = None
os_version = None os_version = ""
if (SMB_PORT not in port_scan_data) or (port_scan_data[SMB_PORT].status != PortStatus.OPEN): if (SMB_PORT not in port_scan_data) or (port_scan_data[SMB_PORT].status != PortStatus.OPEN):
return FingerprintData(None, None, services) return FingerprintData(None, "", services)
logger.debug(f"Fingerprinting potential SMB port {SMB_PORT} on {host}") logger.debug(f"Fingerprinting potential SMB port {SMB_PORT} on {host}")

View File

@ -21,7 +21,7 @@ class SSHFingerprinter(IFingerprinter):
_options: Dict, _options: Dict,
) -> FingerprintData: ) -> FingerprintData:
os_type = None os_type = None
os_version = None os_version = ""
services = {} services = {}
for ps_data in port_scan_data.values(): for ps_data in port_scan_data.values():
@ -35,9 +35,9 @@ class SSHFingerprinter(IFingerprinter):
return FingerprintData(os_type, os_version, services) return FingerprintData(os_type, os_version, services)
@staticmethod @staticmethod
def _get_host_os(banner) -> Tuple[Optional[str], Optional[str]]: def _get_host_os(banner) -> Tuple[Optional[OperatingSystem], str]:
os = None os = None
os_version = None os_version = ""
for dist in LINUX_DIST_SSH: for dist in LINUX_DIST_SSH:
if banner.lower().find(dist) != -1: if banner.lower().find(dist) != -1:
os_version = banner.split(" ").pop().strip() os_version = banner.split(" ").pop().strip()

View File

@ -12,7 +12,7 @@ from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_por
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
POLL_INTERVAL = 0.5 POLL_INTERVAL = 0.5
EMPTY_PORT_SCAN = {-1: PortScanData(-1, PortStatus.CLOSED, None, None)} EMPTY_PORT_SCAN = {-1: PortScanData(-1, PortStatus.CLOSED, "", "")}
def scan_tcp_ports( def scan_tcp_ports(
@ -48,7 +48,7 @@ def _build_port_scan_data(
def _get_closed_port_data(port: int) -> PortScanData: def _get_closed_port_data(port: int) -> PortScanData:
return PortScanData(port, PortStatus.CLOSED, None, None) return PortScanData(port, PortStatus.CLOSED, "", "")
def _check_tcp_ports( def _check_tcp_ports(

View File

@ -1,4 +1,6 @@
import logging import logging
from pathlib import Path
from typing import Optional
from common.utils.file_utils import InvalidPath, expand_path from common.utils.file_utils import InvalidPath, expand_path
from infection_monkey.utils.environment import is_windows_os from infection_monkey.utils.environment import is_windows_os
@ -12,7 +14,7 @@ class RansomwareOptions:
self.file_extension = options["encryption"]["file_extension"] self.file_extension = options["encryption"]["file_extension"]
self.readme_enabled = options["other_behaviors"]["readme"] self.readme_enabled = options["other_behaviors"]["readme"]
self.target_directory = None self.target_directory: Optional[Path] = None
self._set_target_directory(options["encryption"]["directories"]) self._set_target_directory(options["encryption"]["directories"])
def _set_target_directory(self, os_target_directories: dict): def _set_target_directory(self, os_target_directories: dict):

View File

@ -1,6 +1,6 @@
import logging import logging
import subprocess import subprocess
from typing import Dict, Iterable from typing import Dict, Iterable, List, Tuple
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.utils.attack_utils import ScanStatus from common.utils.attack_utils import ScanStatus
@ -33,7 +33,7 @@ class PBA:
""" """
self.command = PBA.choose_command(linux_cmd, windows_cmd) self.command = PBA.choose_command(linux_cmd, windows_cmd)
self.name = name self.name = name
self.pba_data = [] self.pba_data: List[PostBreachData] = []
self.telemetry_messenger = telemetry_messenger self.telemetry_messenger = telemetry_messenger
self.timeout = timeout self.timeout = timeout
@ -73,7 +73,7 @@ class PBA:
pba_execution_succeeded = pba_execution_result[1] pba_execution_succeeded = pba_execution_result[1]
return pba_execution_succeeded and self.is_script() return pba_execution_succeeded and self.is_script()
def _execute_default(self): def _execute_default(self) -> Tuple[str, bool]:
""" """
Default post breach command execution routine Default post breach command execution routine
:return: Tuple of command's output string and boolean, indicating if it succeeded :return: Tuple of command's output string and boolean, indicating if it succeeded
@ -84,7 +84,7 @@ class PBA:
).decode() ).decode()
return output, True return output, True
except subprocess.CalledProcessError as err: except subprocess.CalledProcessError as err:
return err.output.decode(), False return bytes(err.output).decode(), False
except subprocess.TimeoutExpired as err: except subprocess.TimeoutExpired as err:
return str(err), False return str(err), False

View File

@ -1,6 +1,6 @@
import logging import logging
import threading import threading
from typing import Dict, Iterable, List, Sequence from typing import Dict, Iterable, List, Mapping, Sequence
from common.common_consts.timeouts import CONNECTION_TIMEOUT from common.common_consts.timeouts import CONNECTION_TIMEOUT
from common.credentials import Credentials from common.credentials import Credentials
@ -8,6 +8,7 @@ from infection_monkey import network_scanning
from infection_monkey.i_puppet import ( from infection_monkey.i_puppet import (
ExploiterResultData, ExploiterResultData,
FingerprintData, FingerprintData,
IFingerprinter,
IPuppet, IPuppet,
PingScanData, PingScanData,
PluginType, PluginType,
@ -18,7 +19,7 @@ from infection_monkey.model import VictimHost
from .plugin_registry import PluginRegistry from .plugin_registry import PluginRegistry
EMPTY_FINGERPRINT = PingScanData(False, None) EMPTY_FINGERPRINT = FingerprintData(None, "", {})
logger = logging.getLogger() logger = logging.getLogger()
@ -45,7 +46,7 @@ class Puppet(IPuppet):
def scan_tcp_ports( def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = CONNECTION_TIMEOUT self, host: str, ports: List[int], timeout: float = CONNECTION_TIMEOUT
) -> Dict[int, PortScanData]: ) -> Mapping[int, PortScanData]:
return network_scanning.scan_tcp_ports(host, ports, timeout) return network_scanning.scan_tcp_ports(host, ports, timeout)
def fingerprint( def fingerprint(
@ -57,7 +58,9 @@ class Puppet(IPuppet):
options: Dict, options: Dict,
) -> FingerprintData: ) -> FingerprintData:
try: try:
fingerprinter = self._plugin_registry.get_plugin(name, PluginType.FINGERPRINTER) fingerprinter: IFingerprinter = self._plugin_registry.get_plugin(
name, PluginType.FINGERPRINTER
)
return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options) return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options)
except Exception: except Exception:
logger.exception( logger.exception(

View File

@ -7,7 +7,6 @@ class StdoutCapture:
self._orig_stdout = sys.stdout self._orig_stdout = sys.stdout
self._new_stdout = io.StringIO() self._new_stdout = io.StringIO()
sys.stdout = self._new_stdout sys.stdout = self._new_stdout
return self
def get_captured_stdout_output(self) -> str: def get_captured_stdout_output(self) -> str:
self._new_stdout.seek(0) self._new_stdout.seek(0)

View File

@ -2,7 +2,7 @@ import os
import re import re
import uuid import uuid
from datetime import timedelta from datetime import timedelta
from typing import Iterable, Type from typing import Iterable, Set, Type
import flask_restful import flask_restful
from flask import Flask, Response, send_from_directory from flask import Flask, Response, send_from_directory
@ -120,7 +120,7 @@ class FlaskDIWrapper:
def __init__(self, api: flask_restful.Api, container: DIContainer): def __init__(self, api: flask_restful.Api, container: DIContainer):
self._api = api self._api = api
self._container = container self._container = container
self._reserved_urls = set() self._reserved_urls: Set[str] = set()
def add_resource(self, resource: Type[AbstractResource]): def add_resource(self, resource: Type[AbstractResource]):
if len(resource.urls) == 0: if len(resource.urls) == 0:

View File

@ -1,6 +1,14 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Mapping, Sequence from typing import Mapping, Sequence
from monkey_island.cc.models import Machine
@dataclass
class Arc:
dst_machine: Machine # noqa: F821
status: str
# This is the most concise way to represent a graph: # This is the most concise way to represent a graph:
# Machine id as key, Arch list as a value # Machine id as key, Arch list as a value
@ -9,9 +17,3 @@ from typing import Mapping, Sequence
@dataclass @dataclass
class NetworkMap: class NetworkMap:
nodes: Mapping[str, Sequence[Arc]] # noqa: F821 nodes: Mapping[str, Sequence[Arc]] # noqa: F821
@dataclass
class Arc:
dst_machine: Machine # noqa: F821
status: str

View File

@ -2,8 +2,12 @@ from abc import ABC
from typing import Optional, Sequence from typing import Optional, Sequence
# TODO: Actually define the Log class
class Log:
pass
class ILogRepository(ABC): class ILogRepository(ABC):
# Define log object
def get_logs(self, agent_id: Optional[str] = None) -> Sequence[Log]: # noqa: F821 def get_logs(self, agent_id: Optional[str] = None) -> Sequence[Log]: # noqa: F821
pass pass

View File

@ -1,6 +1,8 @@
from typing import List
import flask_restful import flask_restful
# The purpose of this class is to decouple resources from flask # The purpose of this class is to decouple resources from flask
class AbstractResource(flask_restful.Resource): class AbstractResource(flask_restful.Resource):
urls = [] urls: List[str] = []

View File

@ -1,6 +1,6 @@
import logging import logging
from monkey_island.cc import Version from monkey_island.cc import Version as IslandVersion
from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required from monkey_island.cc.resources.request_authentication import jwt_required
@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
class Version(AbstractResource): class Version(AbstractResource):
urls = ["/api/island/version"] urls = ["/api/island/version"]
def __init__(self, version: Version): def __init__(self, version: IslandVersion):
self._version = version self._version = version
@jwt_required @jwt_required

View File

@ -4,7 +4,7 @@ import logging
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from typing import Sequence, Tuple from typing import Optional, Sequence, Tuple
import gevent.hub import gevent.hub
import requests import requests
@ -151,7 +151,7 @@ def _start_mongodb(data_dir: Path) -> MongoDbProcess:
return mongo_db_process return mongo_db_process
def _connect_to_mongodb(mongo_db_process: MongoDbProcess): def _connect_to_mongodb(mongo_db_process: Optional[MongoDbProcess]):
try: try:
mongo_setup.connect_to_mongodb(MONGO_CONNECTION_TIMEOUT) mongo_setup.connect_to_mongodb(MONGO_CONNECTION_TIMEOUT)
except mongo_setup.MongoDBTimeOutError as err: except mongo_setup.MongoDBTimeOutError as err:

View File

@ -1,7 +1,7 @@
import os import os
import secrets import secrets
from pathlib import Path from pathlib import Path
from typing import Union from typing import Optional
from monkey_island.cc.server_utils.encryption.encryption_key_types import EncryptionKey32Bytes from monkey_island.cc.server_utils.encryption.encryption_key_types import EncryptionKey32Bytes
from monkey_island.cc.server_utils.file_utils import open_new_securely_permissioned_file from monkey_island.cc.server_utils.file_utils import open_new_securely_permissioned_file
@ -12,7 +12,7 @@ from .password_based_bytes_encryptor import PasswordBasedBytesEncryptor
_KEY_FILE_NAME = "mongo_key.bin" _KEY_FILE_NAME = "mongo_key.bin"
_encryptor: Union[None, IEncryptor] = None _encryptor: Optional[IEncryptor] = None
# NOTE: This class is being replaced by RepositoryEncryptor # NOTE: This class is being replaced by RepositoryEncryptor
@ -73,5 +73,5 @@ def _initialize_datastore_encryptor(key_file: Path, secret: str):
_encryptor = DataStoreEncryptor(secret, key_file) _encryptor = DataStoreEncryptor(secret, key_file)
def get_datastore_encryptor() -> IEncryptor: def get_datastore_encryptor() -> Optional[IEncryptor]:
return _encryptor return _encryptor

View File

@ -1,10 +1,11 @@
import secrets import secrets
from pathlib import Path from pathlib import Path
from typing import Optional
from monkey_island.cc.server_utils.encryption.encryption_key_types import EncryptionKey32Bytes from monkey_island.cc.server_utils.encryption.encryption_key_types import EncryptionKey32Bytes
from monkey_island.cc.server_utils.file_utils import open_new_securely_permissioned_file from monkey_island.cc.server_utils.file_utils import open_new_securely_permissioned_file
from . import ILockableEncryptor, LockedKeyError, ResetKeyError, UnlockError from . import IEncryptor, ILockableEncryptor, LockedKeyError, ResetKeyError, UnlockError
from .key_based_encryptor import KeyBasedEncryptor from .key_based_encryptor import KeyBasedEncryptor
from .password_based_bytes_encryptor import PasswordBasedBytesEncryptor from .password_based_bytes_encryptor import PasswordBasedBytesEncryptor
@ -12,33 +13,32 @@ from .password_based_bytes_encryptor import PasswordBasedBytesEncryptor
class RepositoryEncryptor(ILockableEncryptor): class RepositoryEncryptor(ILockableEncryptor):
def __init__(self, key_file: Path): def __init__(self, key_file: Path):
self._key_file = key_file self._key_file = key_file
self._password_based_encryptor = None self._key_based_encryptor: Optional[IEncryptor] = None
self._key_based_encryptor = None
def unlock(self, secret: bytes): def unlock(self, secret: bytes):
try: try:
self._password_based_encryptor = PasswordBasedBytesEncryptor(secret.decode()) encryptor = PasswordBasedBytesEncryptor(secret.decode())
self._key_based_encryptor = self._initialize_key_based_encryptor() self._key_based_encryptor = self._initialize_key_based_encryptor(encryptor)
except Exception as err: except Exception as err:
raise UnlockError(err) raise UnlockError(err)
def _initialize_key_based_encryptor(self): def _initialize_key_based_encryptor(self, encryptor: IEncryptor) -> KeyBasedEncryptor:
if self._key_file.is_file(): if self._key_file.is_file():
return self._load_key() return self._load_key(encryptor)
return self._create_key() return self._create_key(encryptor)
def _load_key(self) -> KeyBasedEncryptor: def _load_key(self, encryptor: IEncryptor) -> KeyBasedEncryptor:
with open(self._key_file, "rb") as f: with open(self._key_file, "rb") as f:
encrypted_key = f.read() encrypted_key = f.read()
plaintext_key = EncryptionKey32Bytes(self._password_based_encryptor.decrypt(encrypted_key)) plaintext_key = EncryptionKey32Bytes(encryptor.decrypt(encrypted_key))
return KeyBasedEncryptor(plaintext_key) return KeyBasedEncryptor(plaintext_key)
def _create_key(self) -> KeyBasedEncryptor: def _create_key(self, encryptor: IEncryptor) -> KeyBasedEncryptor:
plaintext_key = EncryptionKey32Bytes(secrets.token_bytes(32)) plaintext_key = EncryptionKey32Bytes(secrets.token_bytes(32))
encrypted_key = self._password_based_encryptor.encrypt(plaintext_key) encrypted_key = encryptor.encrypt(plaintext_key)
with open_new_securely_permissioned_file(str(self._key_file), "wb") as f: with open_new_securely_permissioned_file(str(self._key_file), "wb") as f:
f.write(encrypted_key) f.write(encrypted_key)
@ -54,7 +54,6 @@ class RepositoryEncryptor(ILockableEncryptor):
except Exception as err: except Exception as err:
raise ResetKeyError(err) raise ResetKeyError(err)
self._password_based_encryptor = None
self._key_based_encryptor = None self._key_based_encryptor = None
def encrypt(self, plaintext: bytes) -> bytes: def encrypt(self, plaintext: bytes) -> bytes:

View File

@ -25,4 +25,8 @@ class T1065(AttackTechnique):
@staticmethod @staticmethod
def get_tunnel_ports() -> Sequence[str]: def get_tunnel_ports() -> Sequence[str]:
telems = Telemetry.objects(telem_category="tunnel", data__proxy__ne=None) telems = Telemetry.objects(telem_category="tunnel", data__proxy__ne=None)
return [address_to_ip_port(telem["data"]["proxy"])[1] for telem in telems] return [
p
for p in (address_to_ip_port(telem["data"]["proxy"])[1] for telem in telems)
if p is not None
]

View File

@ -18,8 +18,8 @@ def get_propagation_stats() -> Dict:
} }
def _get_exploit_counts(exploited: List[MonkeyExploitation]) -> Dict: def _get_exploit_counts(exploited: List[MonkeyExploitation]) -> Dict[str, int]:
exploit_counts = {} exploit_counts: Dict[str, int] = {}
for node in exploited: for node in exploited:
for exploit in node.exploits: for exploit in node.exploits:

View File

@ -16,12 +16,12 @@ class CredExploitProcessor:
if attempt["result"]: if attempt["result"]:
exploit_info.username = attempt["user"] exploit_info.username = attempt["user"]
if attempt["password"]: if attempt["password"]:
exploit_info.credential_type = CredentialType.PASSWORD.value exploit_info.credential_type = CredentialType.PASSWORD
exploit_info.password = attempt["password"] exploit_info.password = attempt["password"]
elif attempt["ssh_key"]: elif attempt["ssh_key"]:
exploit_info.credential_type = CredentialType.KEY.value exploit_info.credential_type = CredentialType.KEY
exploit_info.ssh_key = attempt["ssh_key"] exploit_info.ssh_key = attempt["ssh_key"]
else: else:
exploit_info.credential_type = CredentialType.HASH.value exploit_info.credential_type = CredentialType.HASH
return exploit_info return exploit_info
return exploit_info return exploit_info

View File

@ -1,6 +1,6 @@
import copy import copy
import dateutil from dateutil import parser as dateutil_parser
from monkey_island.cc.models import Monkey from monkey_island.cc.models import Monkey
from monkey_island.cc.server_utils.encryption import get_datastore_encryptor from monkey_island.cc.server_utils.encryption import get_datastore_encryptor
@ -29,10 +29,10 @@ def process_exploit_telemetry(telemetry_json, _):
def update_network_with_exploit(edge: EdgeService, telemetry_json): def update_network_with_exploit(edge: EdgeService, telemetry_json):
telemetry_json["data"]["info"]["started"] = dateutil.parser.parse( telemetry_json["data"]["info"]["started"] = dateutil_parser.parse(
telemetry_json["data"]["info"]["started"] telemetry_json["data"]["info"]["started"]
) )
telemetry_json["data"]["info"]["finished"] = dateutil.parser.parse( telemetry_json["data"]["info"]["finished"] = dateutil_parser.parse(
telemetry_json["data"]["info"]["finished"] telemetry_json["data"]["info"]["finished"]
) )
new_exploit = copy.deepcopy(telemetry_json["data"]) new_exploit = copy.deepcopy(telemetry_json["data"])

View File

@ -67,6 +67,8 @@ def is_segmentation_violation(
return cross_segment_ip is not None return cross_segment_ip is not None
return False
def get_segmentation_violation_event(current_monkey, source_subnet, target_ip, target_subnet): def get_segmentation_violation_event(current_monkey, source_subnet, target_ip, target_subnet):
return Event.create_event( return Event.create_event(

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union from typing import Dict, List, Union, cast
from bson import SON from bson import SON
@ -30,21 +30,26 @@ class FindingService:
@staticmethod @staticmethod
def get_all_findings_for_ui() -> List[EnrichedFinding]: def get_all_findings_for_ui() -> List[EnrichedFinding]:
findings = FindingService.get_all_findings_from_db() findings = FindingService.get_all_findings_from_db()
for i in range(len(findings)): enriched_findings: List[EnrichedFinding] = []
details = FindingService._get_finding_details(findings[i]) for finding in findings:
findings[i] = findings[i].to_mongo() finding_data = finding.to_mongo()
findings[i] = FindingService._get_enriched_finding(findings[i]) enriched_finding = FindingService._get_enriched_finding(finding_data)
findings[i].details = details details = FindingService._get_finding_details(finding)
return findings enriched_finding.details = details
enriched_findings.append(enriched_finding)
return enriched_findings
@staticmethod @staticmethod
def _get_enriched_finding(finding: Finding) -> EnrichedFinding: def _get_enriched_finding(finding: SON) -> EnrichedFinding:
test_info = zero_trust_consts.TESTS_MAP[finding["test"]] test_info = zero_trust_consts.TESTS_MAP[finding["test"]]
enriched_finding = EnrichedFinding( enriched_finding = EnrichedFinding(
finding_id=str(finding["_id"]), finding_id=str(finding["_id"]),
test=test_info[zero_trust_consts.FINDING_EXPLANATION_BY_STATUS_KEY][finding["status"]], test=cast(
Dict[str, str], test_info[zero_trust_consts.FINDING_EXPLANATION_BY_STATUS_KEY]
)[finding["status"]],
test_key=finding["test"], test_key=finding["test"],
pillars=test_info[zero_trust_consts.PILLARS_KEY], pillars=cast(List[str], test_info[zero_trust_consts.PILLARS_KEY]),
status=finding["status"], status=finding["status"],
details=None, details=None,
) )

View File

@ -49,7 +49,7 @@ class MongoDbProcess:
self._process.kill() self._process.kill()
def is_running(self) -> bool: def is_running(self) -> bool:
if self._process.poll() is None: if self._process and self._process.poll() is None:
return True return True
return False return False

View File

@ -77,14 +77,14 @@ class MockPuppet(IPuppet):
) -> Dict[int, PortScanData]: ) -> Dict[int, PortScanData]:
logger.debug(f"run_scan_tcp_port({host}, {ports}, {timeout})") logger.debug(f"run_scan_tcp_port({host}, {ports}, {timeout})")
dot_1_results = { dot_1_results = {
22: PortScanData(22, PortStatus.CLOSED, None, None), 22: PortScanData(22, PortStatus.CLOSED, "", ""),
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"), 445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
3389: PortScanData(3389, PortStatus.OPEN, "", "tcp-3389"), 3389: PortScanData(3389, PortStatus.OPEN, "", "tcp-3389"),
} }
dot_3_results = { dot_3_results = {
22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"), 22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"),
443: PortScanData(443, PortStatus.OPEN, "HTTPS BANNER", "tcp-443"), 443: PortScanData(443, PortStatus.OPEN, "HTTPS BANNER", "tcp-443"),
3389: PortScanData(3389, PortStatus.CLOSED, "", None), 3389: PortScanData(3389, PortStatus.CLOSED, "", ""),
} }
if host == DOT_1: if host == DOT_1:
@ -104,7 +104,7 @@ class MockPuppet(IPuppet):
options: Dict, options: Dict,
) -> FingerprintData: ) -> FingerprintData:
logger.debug(f"fingerprint({name}, {host})") logger.debug(f"fingerprint({name}, {host})")
empty_fingerprint_data = FingerprintData(None, None, {}) empty_fingerprint_data = FingerprintData(None, "", {})
dot_1_results = { dot_1_results = {
"SMBFinger": FingerprintData( "SMBFinger": FingerprintData(
@ -118,7 +118,7 @@ class MockPuppet(IPuppet):
), ),
"HTTPFinger": FingerprintData( "HTTPFinger": FingerprintData(
None, None,
None, "",
{ {
"tcp-80": {"name": "http", "data": ("SERVER_HEADERS", False)}, "tcp-80": {"name": "http", "data": ("SERVER_HEADERS", False)},
"tcp-443": {"name": "http", "data": ("SERVER_HEADERS_2", True)}, "tcp-443": {"name": "http", "data": ("SERVER_HEADERS_2", True)},
@ -248,4 +248,4 @@ class MockPuppet(IPuppet):
def _get_empty_results(port: int): def _get_empty_results(port: int):
return PortScanData(port, PortStatus.CLOSED, None, None) return PortScanData(port, PortStatus.CLOSED, "", "")

View File

@ -154,7 +154,7 @@ def assert_fingerprint_results_no_3(fingerprint_data):
def assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data): def assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data):
assert address.ip not in {"10.0.0.1", "10.0.0.3"} assert address.ip not in {"10.0.0.1", "10.0.0.3"}
assert address.domain is None assert not address.domain
assert ping_scan_data.response_received is False assert ping_scan_data.response_received is False
assert len(port_scan_data.keys()) == 6 assert len(port_scan_data.keys()) == 6
@ -178,9 +178,9 @@ def test_scan_single_ip(callback, scan_config, stop):
def test_scan_multiple_ips(callback, scan_config, stop): def test_scan_multiple_ips(callback, scan_config, stop):
addresses = [ addresses = [
NetworkAddress("10.0.0.1", "d1"), NetworkAddress("10.0.0.1", "d1"),
NetworkAddress("10.0.0.2", None), NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", "d3"), NetworkAddress("10.0.0.3", "d3"),
NetworkAddress("10.0.0.4", None), NetworkAddress("10.0.0.4", ""),
] ]
ns = IPScanner(MockPuppet(), num_workers=4) ns = IPScanner(MockPuppet(), num_workers=4)
@ -203,7 +203,7 @@ def test_scan_multiple_ips(callback, scan_config, stop):
@pytest.mark.slow @pytest.mark.slow
def test_scan_lots_of_ips(callback, scan_config, stop): def test_scan_lots_of_ips(callback, scan_config, stop):
addresses = [NetworkAddress(f"10.0.0.{i}", None) for i in range(0, 255)] addresses = [NetworkAddress(f"10.0.0.{i}", "") for i in range(0, 255)]
ns = IPScanner(MockPuppet(), num_workers=4) ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(addresses, scan_config, callback, stop) ns.scan(addresses, scan_config, callback, stop)
@ -223,10 +223,10 @@ def test_stop_after_callback(scan_config, stop):
stoppable_callback = MagicMock(side_effect=_callback) stoppable_callback = MagicMock(side_effect=_callback)
addresses = [ addresses = [
NetworkAddress("10.0.0.1", None), NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", None), NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", None), NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", None), NetworkAddress("10.0.0.4", ""),
] ]
ns = IPScanner(MockPuppet(), num_workers=2) ns = IPScanner(MockPuppet(), num_workers=2)
@ -251,10 +251,10 @@ def test_interrupt_before_fingerprinting(callback, scan_config, stop):
puppet.fingerprint = MagicMock() puppet.fingerprint = MagicMock()
addresses = [ addresses = [
NetworkAddress("10.0.0.1", None), NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", None), NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", None), NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", None), NetworkAddress("10.0.0.4", ""),
] ]
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)
@ -270,7 +270,7 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
stoppable_fingerprint.barrier.wait() stoppable_fingerprint.barrier.wait()
stop.set() stop.set()
return FingerprintData(None, None, {}) return FingerprintData(None, "", {})
stoppable_fingerprint.barrier = Barrier(2) stoppable_fingerprint.barrier = Barrier(2)
@ -278,10 +278,10 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint) puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint)
addresses = [ addresses = [
NetworkAddress("10.0.0.1", None), NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", None), NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", None), NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", None), NetworkAddress("10.0.0.4", ""),
] ]
ns = IPScanner(puppet, num_workers=2) ns = IPScanner(puppet, num_workers=2)

View File

@ -35,12 +35,12 @@ def mock_victim_host_factory():
return MockVictimHostFactory() return MockVictimHostFactory()
empty_fingerprint_data = FingerprintData(None, None, {}) empty_fingerprint_data = FingerprintData(None, "", {})
dot_1_scan_results = IPScanResults( dot_1_scan_results = IPScanResults(
PingScanData(True, "windows"), PingScanData(True, "windows"),
{ {
22: PortScanData(22, PortStatus.CLOSED, None, None), 22: PortScanData(22, PortStatus.CLOSED, "", ""),
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"), 445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
3389: PortScanData(3389, PortStatus.OPEN, "", "tcp-3389"), 3389: PortScanData(3389, PortStatus.OPEN, "", "tcp-3389"),
}, },
@ -56,7 +56,7 @@ dot_3_scan_results = IPScanResults(
{ {
22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"), 22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"),
443: PortScanData(443, PortStatus.OPEN, "HTTPS BANNER", "tcp-443"), 443: PortScanData(443, PortStatus.OPEN, "HTTPS BANNER", "tcp-443"),
3389: PortScanData(3389, PortStatus.CLOSED, "", None), 3389: PortScanData(3389, PortStatus.CLOSED, "", ""),
}, },
{ {
"SSHFinger": FingerprintData( "SSHFinger": FingerprintData(
@ -64,7 +64,7 @@ dot_3_scan_results = IPScanResults(
), ),
"HTTPFinger": FingerprintData( "HTTPFinger": FingerprintData(
None, None,
None, "",
{ {
"tcp-80": {"name": "http", "data": ("SERVER_HEADERS", False)}, "tcp-80": {"name": "http", "data": ("SERVER_HEADERS", False)},
"tcp-443": {"name": "http", "data": ("SERVER_HEADERS_2", True)}, "tcp-443": {"name": "http", "data": ("SERVER_HEADERS_2", True)},
@ -77,9 +77,9 @@ dot_3_scan_results = IPScanResults(
dead_host_scan_results = IPScanResults( dead_host_scan_results = IPScanResults(
PingScanData(False, None), PingScanData(False, None),
{ {
22: PortScanData(22, PortStatus.CLOSED, None, None), 22: PortScanData(22, PortStatus.CLOSED, "", ""),
443: PortScanData(443, PortStatus.CLOSED, None, None), 443: PortScanData(443, PortStatus.CLOSED, "", ""),
3389: PortScanData(3389, PortStatus.CLOSED, "", None), 3389: PortScanData(3389, PortStatus.CLOSED, "", ""),
}, },
{}, {},
) )

View File

@ -13,7 +13,7 @@ def mock_get_interface_to_target(monkeypatch):
def test_factory_no_tunnel(): def test_factory_no_tunnel():
factory = VictimHostFactory(island_ip="192.168.56.1", island_port="5000", on_island=False) factory = VictimHostFactory(island_ip="192.168.56.1", island_port="5000", on_island=False)
network_address = NetworkAddress("192.168.56.2", None) network_address = NetworkAddress("192.168.56.2", "")
victim = factory.build_victim_host(network_address) victim = factory.build_victim_host(network_address)
@ -49,4 +49,4 @@ def test_factory_no_default_server():
victim = factory.build_victim_host(network_address) victim = factory.build_victim_host(network_address)
assert victim.default_server is None assert not victim.default_server

View File

@ -38,7 +38,7 @@ def test_successful(monkeypatch, fingerprinter):
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1 assert len(fingerprint_data.services.keys()) == 1
es_service = fingerprint_data.services[ES_SERVICE] es_service = fingerprint_data.services[ES_SERVICE]
@ -60,7 +60,7 @@ def test_fingerprinting_skipped_if_port_closed(monkeypatch, fingerprinter, port_
assert not mock_query_elasticsearch.called assert not mock_query_elasticsearch.called
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0
@ -82,5 +82,5 @@ def test_no_response_from_server(monkeypatch, fingerprinter, mock_query_function
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0

View File

@ -63,7 +63,7 @@ def test_fingerprint_only_port_443(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_called_with("https://127.0.0.1:443") mock_get_http_headers.assert_called_with("https://127.0.0.1:443")
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1 assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"] assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
@ -86,7 +86,7 @@ def test_open_port_no_http_server(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_any_call("http://127.0.0.1:9200") mock_get_http_headers.assert_any_call("http://127.0.0.1:9200")
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0
@ -106,7 +106,7 @@ def test_multiple_open_ports(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_any_call("http://127.0.0.1:8080") mock_get_http_headers.assert_any_call("http://127.0.0.1:8080")
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 2 assert len(fingerprint_data.services.keys()) == 2
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"] assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
@ -126,7 +126,7 @@ def test_server_missing_from_http_headers(mock_get_http_headers, http_fingerprin
assert mock_get_http_headers.call_count == 2 assert mock_get_http_headers.call_count == 2
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1 assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-1080"]["data"][0] == "" assert fingerprint_data.services["tcp-1080"]["data"][0] == ""

View File

@ -45,7 +45,7 @@ def test_mssql_fingerprint_successful(monkeypatch, fingerprinter):
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1 assert len(fingerprint_data.services.keys()) == 1
# Each mssql instance is under his name # Each mssql instance is under his name
@ -78,7 +78,7 @@ def test_mssql_no_response_from_server(monkeypatch, fingerprinter, mock_query_fu
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0
@ -98,5 +98,5 @@ def test_mssql_wrong_response_from_server(monkeypatch, fingerprinter):
) )
assert fingerprint_data.os_type is None assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0 assert len(fingerprint_data.services.keys()) == 0

View File

@ -24,7 +24,7 @@ def test_single_subnet():
assert len(scan_targets) == 255 assert len(scan_targets) == 255
for i in range(0, 255): for i in range(0, 255):
assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets assert NetworkAddress(f"10.0.0.{i}", "") in scan_targets
@pytest.mark.parametrize("single_ip", ["10.0.0.2", "10.0.0.2/32", "10.0.0.2-10.0.0.2"]) @pytest.mark.parametrize("single_ip", ["10.0.0.2", "10.0.0.2/32", "10.0.0.2-10.0.0.2"])
@ -33,8 +33,8 @@ def test_single_ip(single_ip):
scan_targets = compile_ranges_only([single_ip]) scan_targets = compile_ranges_only([single_ip])
assert len(scan_targets) == 1 assert len(scan_targets) == 1
assert NetworkAddress("10.0.0.2", None) in scan_targets assert NetworkAddress("10.0.0.2", "") in scan_targets
assert NetworkAddress("10.0.0.2", None) == scan_targets[0] assert NetworkAddress("10.0.0.2", "") == scan_targets[0]
def test_multiple_subnet(): def test_multiple_subnet():
@ -43,10 +43,10 @@ def test_multiple_subnet():
assert len(scan_targets) == 262 assert len(scan_targets) == 262
for i in range(0, 255): for i in range(0, 255):
assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets assert NetworkAddress(f"10.0.0.{i}", "") in scan_targets
for i in range(8, 15): for i in range(8, 15):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_middle_of_range_subnet(): def test_middle_of_range_subnet():
@ -55,7 +55,7 @@ def test_middle_of_range_subnet():
assert len(scan_targets) == 7 assert len(scan_targets) == 7
for i in range(0, 7): for i in range(0, 7):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -68,7 +68,7 @@ def test_ip_range(ip_range):
assert len(scan_targets) == 9 assert len(scan_targets) == 9
for i in range(25, 34): for i in range(25, 34):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_no_duplicates(): def test_no_duplicates():
@ -77,7 +77,7 @@ def test_no_duplicates():
assert len(scan_targets) == 7 assert len(scan_targets) == 7
for i in range(0, 7): for i in range(0, 7):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_blocklisted_ips(): def test_blocklisted_ips():
@ -146,7 +146,7 @@ def test_no_redundant_targets():
) )
assert len(scan_targets) == 2 assert len(scan_targets) == 2
assert NetworkAddress(ip="127.0.0.0", domain=None) in scan_targets assert NetworkAddress(ip="127.0.0.0", domain="") in scan_targets
assert NetworkAddress(ip="127.0.0.1", domain="localhost") in scan_targets assert NetworkAddress(ip="127.0.0.1", domain="localhost") in scan_targets
@ -212,7 +212,7 @@ def test_local_subnet_added():
assert len(scan_targets) == 254 assert len(scan_targets) == 254
for ip in chain(range(0, 5), range(6, 255)): for ip in chain(range(0, 5), range(6, 255)):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
def test_multiple_local_subnets_added(): def test_multiple_local_subnets_added():
@ -232,10 +232,10 @@ def test_multiple_local_subnets_added():
assert len(scan_targets) == 2 * (255 - 1) assert len(scan_targets) == 2 * (255 - 1)
for ip in chain(range(0, 5), range(6, 255)): for ip in chain(range(0, 5), range(6, 255)):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in chain(range(0, 99), range(100, 255)): for ip in chain(range(0, 99), range(100, 255)):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_blocklisted_ips_missing_from_local_subnets(): def test_blocklisted_ips_missing_from_local_subnets():
@ -273,12 +273,12 @@ def test_local_subnets_and_ranges_added():
assert len(scan_targets) == 254 + 3 assert len(scan_targets) == 254 + 3
for ip in range(0, 5): for ip in range(0, 5):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in range(6, 255): for ip in range(6, 255):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in range(40, 43): for ip in range(40, 43):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_local_network_interfaces_specified_but_disabled(): def test_local_network_interfaces_specified_but_disabled():
@ -295,7 +295,7 @@ def test_local_network_interfaces_specified_but_disabled():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in range(40, 43): for ip in range(40, 43):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_local_network_interfaces_subnet_masks(): def test_local_network_interfaces_subnet_masks():
@ -315,7 +315,7 @@ def test_local_network_interfaces_subnet_masks():
assert len(scan_targets) == 4 assert len(scan_targets) == 4
for ip in [108, 110, 145, 146]: for ip in [108, 110, 145, 146]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_targets(): def test_segmentation_targets():
@ -334,7 +334,7 @@ def test_segmentation_targets():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in [144, 145, 146]: for ip in [144, 145, 146]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_clash_with_blocked(): def test_segmentation_clash_with_blocked():
@ -377,7 +377,7 @@ def test_segmentation_clash_with_targets():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in [148, 149, 150]: for ip in [148, 149, 150]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_one_network(): def test_segmentation_one_network():
@ -443,7 +443,7 @@ def test_invalid_inputs():
assert len(scan_targets) == 3 assert len(scan_targets) == 3
for ip in [148, 149, 150]: for ip in [148, 149, 150]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_invalid_blocklisted_ip(): def test_invalid_blocklisted_ip():

View File

@ -18,7 +18,7 @@ def test_no_ssh_ports_open(ssh_fingerprinter):
} }
results = ssh_fingerprinter.get_host_fingerprint("127.0.0.1", None, port_scan_data, None) results = ssh_fingerprinter.get_host_fingerprint("127.0.0.1", None, port_scan_data, None)
assert results == FingerprintData(None, None, {}) assert results == FingerprintData(None, "", {})
def test_no_os(ssh_fingerprinter): def test_no_os(ssh_fingerprinter):
@ -32,7 +32,7 @@ def test_no_os(ssh_fingerprinter):
assert results == FingerprintData( assert results == FingerprintData(
None, None,
None, "",
{ {
"tcp-22": { "tcp-22": {
"display_name": "SSH", "display_name": "SSH",

View File

@ -34,7 +34,7 @@ def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data):
for port in closed_ports: for port in closed_ports:
assert port_scan_data[port].port == port assert port_scan_data[port].port == port
assert port_scan_data[port].status == PortStatus.CLOSED assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None assert not port_scan_data[port].banner
@pytest.mark.parametrize("open_ports_data", [{}]) @pytest.mark.parametrize("open_ports_data", [{}])

View File

@ -1,11 +1,12 @@
import json import json
from typing import Iterable
import pytest import pytest
from infection_monkey.exploit.sshexec import SSHExploiter from infection_monkey.exploit.sshexec import SSHExploiter
from infection_monkey.i_puppet.i_puppet import ExploiterResultData
from infection_monkey.model.host import VictimHost from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.exploit_telem import ExploitTelem from infection_monkey.telemetry.exploit_telem import ExploitTelem
from monkey.infection_monkey.i_puppet.i_puppet import ExploiterResultData
DOMAIN_NAME = "domain-name" DOMAIN_NAME = "domain-name"
IP = "0.0.0.0" IP = "0.0.0.0"
@ -16,7 +17,7 @@ HOST_AS_DICT = {
"os": {}, "os": {},
"services": {}, "services": {},
"icmp": False, "icmp": False,
"default_server": None, "default_server": "",
} }
EXPLOITER_NAME = "SSHExploiter" EXPLOITER_NAME = "SSHExploiter"
EXPLOITER_INFO = { EXPLOITER_INFO = {
@ -27,7 +28,7 @@ EXPLOITER_INFO = {
"vulnerable_ports": [], "vulnerable_ports": [],
"executed_cmds": [], "executed_cmds": [],
} }
EXPLOITER_ATTEMPTS = [] EXPLOITER_ATTEMPTS: Iterable = []
RESULT = False RESULT = False
OS_LINUX = "linux" OS_LINUX = "linux"
ERROR_MSG = "failed because yolo" ERROR_MSG = "failed because yolo"

View File

@ -1,4 +1,5 @@
import json import json
from typing import Any, Dict
import pytest import pytest
@ -14,9 +15,9 @@ HOST_AS_DICT = {
"os": {}, "os": {},
"services": {}, "services": {},
"icmp": False, "icmp": False,
"default_server": None, "default_server": "",
} }
HOST_SERVICES = {} HOST_SERVICES: Dict[str, Any] = {}
@pytest.fixture @pytest.fixture

View File

@ -18,7 +18,7 @@ SCAN_DATA_MOCK = [
}, },
}, },
"monkey_exe": None, "monkey_exe": None,
"default_server": None, "default_server": "",
}, },
} }
] ]