Compare commits

...

20 Commits

Author SHA1 Message Date
Kekoa Kaaikala 9d37a38994 Agent: Fix FingerprintData mypy issues 2022-09-21 20:38:19 +00:00
Kekoa Kaaikala e40d061091 Agent: Fix PortScanData mypy issues 2022-09-21 20:30:41 +00:00
Kekoa Kaaikala d4f6c83f56 Agent: Fix VictimHost mypy issues 2022-09-21 20:19:20 +00:00
Kekoa Kaaikala d440d51f53 Island: Fix mypy issues in segmentation.py 2022-09-21 18:29:34 +00:00
Kekoa Kaaikala 218b341006 Island: Fix mypy issues in networkmap.py 2022-09-21 18:28:58 +00:00
Kekoa Kaaikala 24848e62df Island: Fix mypy issues in version.py 2022-09-21 18:28:17 +00:00
Kekoa Kaaikala a9e101fd04 Island: Fix mypy issues in T1065.py 2022-09-21 18:25:38 +00:00
Kekoa Kaaikala 1a6f48614e Island: Fix mypy issues in mongo_db_process.py 2022-09-21 18:24:47 +00:00
Kekoa Kaaikala a5b5449f73 Island: Fix mypy issues in finding_service.py 2022-09-21 18:23:33 +00:00
Kekoa Kaaikala 7013963d59 Island: Fix mypy issues in cred_exploit.py 2022-09-21 18:22:20 +00:00
Kekoa Kaaikala ed5773878e Island: Fix mypy issues in ransomware_report.py 2022-09-21 18:21:42 +00:00
Kekoa Kaaikala c870fde3cc Island: Fix mypy issues in exploit.py 2022-09-21 18:21:03 +00:00
Kekoa Kaaikala e595a70019 Island: Fix mypy issues for encryptors 2022-09-21 18:19:58 +00:00
Kekoa Kaaikala e8aa231f92 Island: Fix mypy issues in AbstractResource.py 2022-09-21 18:09:18 +00:00
Kekoa Kaaikala c0b2981150 Island: Fix mypy issues in i_log_repository.py 2022-09-21 18:08:25 +00:00
Kekoa Kaaikala 7801a98a15 Island: Fix mypy issues in server_setup.py 2022-09-21 18:06:53 +00:00
Kekoa Kaaikala 6209ec47cd Island: Fix mypy issues in app.py 2022-09-21 18:06:19 +00:00
Kekoa Kaaikala 795d9fe201 Agent: Fix mypy issues in ransomware_options.py 2022-09-21 17:58:41 +00:00
Kekoa Kaaikala ea67c07d70 Agent: Fix mypy issues in pba.py 2022-09-21 17:57:44 +00:00
Kekoa Kaaikala 70c74b87a9 Agent: Fix mypy issues in capture_output.py 2022-09-21 17:55:41 +00:00
43 changed files with 222 additions and 176 deletions

View File

@ -193,7 +193,7 @@ class SingleIpRange(NetworkRange):
:return: A tuple in format (IP, domain_name). Eg. (192.168.55.1, www.google.com)
"""
# The most common use case is to enter ip/range into "Scan IP/subnet list"
domain_name = None
domain_name = ""
if " " in string_:
raise ValueError(f'"{string_}" is not a valid IP address or domain name.')

View File

@ -1,10 +1,10 @@
import abc
import threading
from collections import namedtuple
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Iterable, List, Mapping, Sequence
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from common import OperatingSystem
from common.credentials import Credentials
from infection_monkey.model import VictimHost
@ -26,15 +26,37 @@ class ExploiterResultData:
propagation_success: bool = False
interrupted: bool = False
os: str = ""
info: Mapping = None
attempts: Iterable = None
info: Mapping = field(default_factory=lambda: {})
attempts: Iterable = field(default_factory=lambda: [])
error_message: str = ""
PingScanData = namedtuple("PingScanData", ["response_received", "os"])
PortScanData = namedtuple("PortScanData", ["port", "status", "banner", "service"])
FingerprintData = namedtuple("FingerprintData", ["os_type", "os_version", "services"])
PostBreachData = namedtuple("PostBreachData", ["display_name", "command", "result"])
@dataclass(frozen=True)
class FingerprintData:
os_type: Optional[OperatingSystem]
os_version: str
services: Mapping = field(default_factory=lambda: {})
@dataclass(frozen=True)
class PingScanData:
response_received: bool
os: Optional[OperatingSystem]
@dataclass(frozen=True)
class PortScanData:
port: int
status: PortStatus
banner: str
service: str
@dataclass(frozen=True)
class PostBreachData:
display_name: str
command: str
result: Tuple[str, bool]
class IPuppet(metaclass=abc.ABCMeta):
@ -84,7 +106,7 @@ class IPuppet(metaclass=abc.ABCMeta):
@abc.abstractmethod
def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = 3
) -> Dict[int, PortScanData]:
) -> Mapping[int, PortScanData]:
"""
Scans a list of TCP ports on a remote host
@ -92,7 +114,7 @@ class IPuppet(metaclass=abc.ABCMeta):
:param int ports: List of TCP port numbers to scan
:param float timeout: The maximum amount of time (in seconds) to wait for a response
:return: The data collected by scanning the provided host:ports combination
:rtype: Dict[int, PortScanData]
:rtype: Mapping[int, PortScanData]
"""
@abc.abstractmethod
@ -125,6 +147,7 @@ class IPuppet(metaclass=abc.ABCMeta):
name: str,
host: VictimHost,
current_depth: int,
servers: Sequence[str],
options: Dict,
interrupt: threading.Event,
) -> ExploiterResultData:

View File

@ -2,7 +2,7 @@ import logging
from ipaddress import IPv4Interface
from queue import Queue
from threading import Event
from typing import List, Sequence
from typing import Dict, List, Sequence
from common.agent_configuration import (
ExploitationConfiguration,
@ -154,11 +154,12 @@ class Propagator:
victim_host.services[psd.service] = {}
victim_host.services[psd.service]["display_name"] = "unknown(TCP)"
victim_host.services[psd.service]["port"] = psd.port
if psd.banner is not None:
victim_host.services[psd.service]["banner"] = psd.banner
victim_host.services[psd.service]["banner"] = psd.banner
@staticmethod
def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: FingerprintData):
def _process_fingerprinter_results(
victim_host: VictimHost, fingerprint_data: Dict[str, FingerprintData]
):
for fd in fingerprint_data.values():
# TODO: This logic preserves the existing behavior prior to introducing IMaster and
# IPuppet, but it is possibly flawed. Different fingerprinters may detect
@ -167,7 +168,7 @@ class Propagator:
if fd.os_type is not None:
victim_host.os["type"] = fd.os_type
if ("version" not in victim_host.os) and (fd.os_version is not None):
if ("version" not in victim_host.os) and (fd.os_version):
victim_host.os["version"] = fd.os_version
for service, details in fd.services.items():

View File

@ -10,7 +10,7 @@ class VictimHost(object):
self.os: Dict[str, Any] = {}
self.services: Dict[str, Any] = {}
self.icmp = False
self.default_server = None
self.default_server = ""
def as_dict(self):
return self.__dict__

View File

@ -31,10 +31,10 @@ class ElasticSearchFingerprinter(IFingerprinter):
port_scan_data: Dict[int, PortScanData],
_options: Dict,
) -> FingerprintData:
services = {}
services: Dict[str, Any] = {}
if (ES_PORT not in port_scan_data) or (port_scan_data[ES_PORT].status != PortStatus.OPEN):
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
try:
elasticsearch_info = _query_elasticsearch(host)
@ -42,7 +42,7 @@ class ElasticSearchFingerprinter(IFingerprinter):
except Exception as ex:
logger.debug(f"Did not detect an ElasticSearch cluster: {ex}")
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
def _query_elasticsearch(host: str) -> Dict[str, Any]:

View File

@ -1,6 +1,6 @@
import logging
from contextlib import closing
from typing import Any, Dict, Iterable, Optional, Set, Tuple
from typing import Any, Dict, Iterable, Mapping, Optional, Set, Tuple
from requests import head
from requests.exceptions import ConnectionError, Timeout
@ -46,7 +46,7 @@ class HTTPFingerprinter(IFingerprinter):
"data": (server_header_contents, ssl),
}
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
def _query_potential_http_server(host: str, port: int) -> Tuple[Optional[str], Optional[bool]]:
@ -71,7 +71,7 @@ def _get_server_from_headers(url: str) -> Optional[str]:
return None
def _get_http_headers(url: str) -> Optional[Dict[str, Any]]:
def _get_http_headers(url: str) -> Optional[Mapping[str, Any]]:
try:
logger.debug(f"Sending request for headers to {url}")
with closing(head(url, verify=False, timeout=1)) as response: # noqa: DUO123

View File

@ -1,7 +1,7 @@
import errno
import logging
import socket
from typing import Any, Dict, Optional
from typing import Any, Dict
from infection_monkey.i_puppet import FingerprintData, IFingerprinter, PingScanData, PortScanData
@ -32,10 +32,10 @@ class MSSQLFingerprinter(IFingerprinter):
except Exception as ex:
logger.debug(f"Did not detect an MSSQL server: {ex}")
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
def _query_mssql_for_instance_data(host: str) -> bytes:
# Create a UDP socket and sets a timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(_MSSQL_SOCKET_TIMEOUT)
@ -44,22 +44,20 @@ def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
# The message is a CLNT_UCAST_EX packet to get all instances
# https://msdn.microsoft.com/en-us/library/cc219745.aspx
message = "\x03"
message_str = "\x03"
# Encode the message as a bytes array
message = message.encode()
message = message_str.encode()
# send data and receive response
try:
logger.info(f"Sending message to requested host: {host}, {message}")
logger.info(f"Sending message to requested host: {host}, {message_str}")
sock.sendto(message, server_address)
data, _ = sock.recvfrom(_BUFFER_SIZE)
return data
except socket.timeout as err:
logger.debug(
f"Socket timeout reached, maybe browser service on host: {host} doesnt " "exist"
)
logger.debug(f"Socket timeout reached, maybe browser service on host: {host} doesn't exist")
raise err
except socket.error as err:
if err.errno == errno.ECONNRESET:
@ -78,7 +76,7 @@ def _query_mssql_for_instance_data(host: str) -> Optional[bytes]:
def _get_services_from_server_data(data: bytes) -> Dict[str, Any]:
services = {MSSQL_SERVICE: {}}
services: Dict[str, Any] = {MSSQL_SERVICE: {}}
services[MSSQL_SERVICE]["display_name"] = DISPLAY_NAME
services[MSSQL_SERVICE]["port"] = SQL_BROWSER_DEFAULT_PORT

View File

@ -26,10 +26,10 @@ def compile_scan_target_list(
scan_targets.extend(_get_ips_to_scan_from_local_interface(local_network_interfaces))
if inaccessible_subnets:
inaccessible_subnets = _get_segmentation_check_targets(
inaccessible_subnet_addresses = _get_segmentation_check_targets(
inaccessible_subnets, local_network_interfaces
)
scan_targets.extend(inaccessible_subnets)
scan_targets.extend(inaccessible_subnet_addresses)
scan_targets = _remove_interface_ips(scan_targets, local_network_interfaces)
scan_targets = _remove_blocklisted_ips(scan_targets, blocklisted_ips)
@ -44,7 +44,7 @@ def _remove_redundant_targets(targets: List[NetworkAddress]) -> List[NetworkAddr
for target in targets:
domain_name = target.domain
ip = target.ip
if ip not in reverse_dns or (reverse_dns[ip] is None and domain_name is not None):
if ip not in reverse_dns or (not reverse_dns[ip] and domain_name):
reverse_dns[ip] = domain_name
return [NetworkAddress(key, value) for (key, value) in reverse_dns.items()]
@ -55,7 +55,7 @@ def _range_to_addresses(range_obj: NetworkRange) -> List[NetworkAddress]:
if hasattr(range_obj, "domain_name"):
addresses.append(NetworkAddress(address, range_obj.domain_name))
else:
addresses.append(NetworkAddress(address, None))
addresses.append(NetworkAddress(address, ""))
return addresses

View File

@ -144,16 +144,16 @@ class SMBFingerprinter(IFingerprinter):
port_scan_data: Dict[int, PortScanData],
_options: Dict,
) -> FingerprintData:
services = {}
services: Dict = {}
smb_service = {
"display_name": DISPLAY_NAME,
"port": SMB_PORT,
}
os_type = None
os_version = None
os_version = ""
if (SMB_PORT not in port_scan_data) or (port_scan_data[SMB_PORT].status != PortStatus.OPEN):
return FingerprintData(None, None, services)
return FingerprintData(None, "", services)
logger.debug(f"Fingerprinting potential SMB port {SMB_PORT} on {host}")

View File

@ -21,7 +21,7 @@ class SSHFingerprinter(IFingerprinter):
_options: Dict,
) -> FingerprintData:
os_type = None
os_version = None
os_version = ""
services = {}
for ps_data in port_scan_data.values():
@ -35,9 +35,9 @@ class SSHFingerprinter(IFingerprinter):
return FingerprintData(os_type, os_version, services)
@staticmethod
def _get_host_os(banner) -> Tuple[Optional[str], Optional[str]]:
def _get_host_os(banner) -> Tuple[Optional[OperatingSystem], str]:
os = None
os_version = None
os_version = ""
for dist in LINUX_DIST_SSH:
if banner.lower().find(dist) != -1:
os_version = banner.split(" ").pop().strip()

View File

@ -12,7 +12,7 @@ from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_por
logger = logging.getLogger(__name__)
POLL_INTERVAL = 0.5
EMPTY_PORT_SCAN = {-1: PortScanData(-1, PortStatus.CLOSED, None, None)}
EMPTY_PORT_SCAN = {-1: PortScanData(-1, PortStatus.CLOSED, "", "")}
def scan_tcp_ports(
@ -48,7 +48,7 @@ def _build_port_scan_data(
def _get_closed_port_data(port: int) -> PortScanData:
return PortScanData(port, PortStatus.CLOSED, None, None)
return PortScanData(port, PortStatus.CLOSED, "", "")
def _check_tcp_ports(

View File

@ -1,4 +1,6 @@
import logging
from pathlib import Path
from typing import Optional
from common.utils.file_utils import InvalidPath, expand_path
from infection_monkey.utils.environment import is_windows_os
@ -12,7 +14,7 @@ class RansomwareOptions:
self.file_extension = options["encryption"]["file_extension"]
self.readme_enabled = options["other_behaviors"]["readme"]
self.target_directory = None
self.target_directory: Optional[Path] = None
self._set_target_directory(options["encryption"]["directories"])
def _set_target_directory(self, os_target_directories: dict):

View File

@ -1,6 +1,6 @@
import logging
import subprocess
from typing import Dict, Iterable
from typing import Dict, Iterable, List, Tuple
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.utils.attack_utils import ScanStatus
@ -33,7 +33,7 @@ class PBA:
"""
self.command = PBA.choose_command(linux_cmd, windows_cmd)
self.name = name
self.pba_data = []
self.pba_data: List[PostBreachData] = []
self.telemetry_messenger = telemetry_messenger
self.timeout = timeout
@ -73,7 +73,7 @@ class PBA:
pba_execution_succeeded = pba_execution_result[1]
return pba_execution_succeeded and self.is_script()
def _execute_default(self):
def _execute_default(self) -> Tuple[str, bool]:
"""
Default post breach command execution routine
:return: Tuple of command's output string and boolean, indicating if it succeeded
@ -84,7 +84,7 @@ class PBA:
).decode()
return output, True
except subprocess.CalledProcessError as err:
return err.output.decode(), False
return bytes(err.output).decode(), False
except subprocess.TimeoutExpired as err:
return str(err), False

View File

@ -1,6 +1,6 @@
import logging
import threading
from typing import Dict, Iterable, List, Sequence
from typing import Dict, Iterable, List, Mapping, Sequence
from common.common_consts.timeouts import CONNECTION_TIMEOUT
from common.credentials import Credentials
@ -8,6 +8,7 @@ from infection_monkey import network_scanning
from infection_monkey.i_puppet import (
ExploiterResultData,
FingerprintData,
IFingerprinter,
IPuppet,
PingScanData,
PluginType,
@ -18,7 +19,7 @@ from infection_monkey.model import VictimHost
from .plugin_registry import PluginRegistry
EMPTY_FINGERPRINT = PingScanData(False, None)
EMPTY_FINGERPRINT = FingerprintData(None, "", {})
logger = logging.getLogger()
@ -45,7 +46,7 @@ class Puppet(IPuppet):
def scan_tcp_ports(
self, host: str, ports: List[int], timeout: float = CONNECTION_TIMEOUT
) -> Dict[int, PortScanData]:
) -> Mapping[int, PortScanData]:
return network_scanning.scan_tcp_ports(host, ports, timeout)
def fingerprint(
@ -57,7 +58,9 @@ class Puppet(IPuppet):
options: Dict,
) -> FingerprintData:
try:
fingerprinter = self._plugin_registry.get_plugin(name, PluginType.FINGERPRINTER)
fingerprinter: IFingerprinter = self._plugin_registry.get_plugin(
name, PluginType.FINGERPRINTER
)
return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options)
except Exception:
logger.exception(

View File

@ -7,7 +7,6 @@ class StdoutCapture:
self._orig_stdout = sys.stdout
self._new_stdout = io.StringIO()
sys.stdout = self._new_stdout
return self
def get_captured_stdout_output(self) -> str:
self._new_stdout.seek(0)

View File

@ -2,7 +2,7 @@ import os
import re
import uuid
from datetime import timedelta
from typing import Iterable, Type
from typing import Iterable, Set, Type
import flask_restful
from flask import Flask, Response, send_from_directory
@ -120,7 +120,7 @@ class FlaskDIWrapper:
def __init__(self, api: flask_restful.Api, container: DIContainer):
self._api = api
self._container = container
self._reserved_urls = set()
self._reserved_urls: Set[str] = set()
def add_resource(self, resource: Type[AbstractResource]):
if len(resource.urls) == 0:

View File

@ -1,6 +1,14 @@
from dataclasses import dataclass
from typing import Mapping, Sequence
from monkey_island.cc.models import Machine
@dataclass
class Arc:
dst_machine: Machine # noqa: F821
status: str
# This is the most concise way to represent a graph:
# Machine id as key, Arch list as a value
@ -9,9 +17,3 @@ from typing import Mapping, Sequence
@dataclass
class NetworkMap:
nodes: Mapping[str, Sequence[Arc]] # noqa: F821
@dataclass
class Arc:
dst_machine: Machine # noqa: F821
status: str

View File

@ -2,8 +2,12 @@ from abc import ABC
from typing import Optional, Sequence
# TODO: Actually define the Log class
class Log:
pass
class ILogRepository(ABC):
# Define log object
def get_logs(self, agent_id: Optional[str] = None) -> Sequence[Log]: # noqa: F821
pass

View File

@ -1,6 +1,8 @@
from typing import List
import flask_restful
# The purpose of this class is to decouple resources from flask
class AbstractResource(flask_restful.Resource):
urls = []
urls: List[str] = []

View File

@ -1,6 +1,6 @@
import logging
from monkey_island.cc import Version
from monkey_island.cc import Version as IslandVersion
from monkey_island.cc.resources.AbstractResource import AbstractResource
from monkey_island.cc.resources.request_authentication import jwt_required
@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
class Version(AbstractResource):
urls = ["/api/island/version"]
def __init__(self, version: Version):
def __init__(self, version: IslandVersion):
self._version = version
@jwt_required

View File

@ -4,7 +4,7 @@ import logging
import sys
from pathlib import Path
from threading import Thread
from typing import Sequence, Tuple
from typing import Optional, Sequence, Tuple
import gevent.hub
import requests
@ -151,7 +151,7 @@ def _start_mongodb(data_dir: Path) -> MongoDbProcess:
return mongo_db_process
def _connect_to_mongodb(mongo_db_process: MongoDbProcess):
def _connect_to_mongodb(mongo_db_process: Optional[MongoDbProcess]):
try:
mongo_setup.connect_to_mongodb(MONGO_CONNECTION_TIMEOUT)
except mongo_setup.MongoDBTimeOutError as err:

View File

@ -1,7 +1,7 @@
import os
import secrets
from pathlib import Path
from typing import Union
from typing import Optional
from monkey_island.cc.server_utils.encryption.encryption_key_types import EncryptionKey32Bytes
from monkey_island.cc.server_utils.file_utils import open_new_securely_permissioned_file
@ -12,7 +12,7 @@ from .password_based_bytes_encryptor import PasswordBasedBytesEncryptor
_KEY_FILE_NAME = "mongo_key.bin"
_encryptor: Union[None, IEncryptor] = None
_encryptor: Optional[IEncryptor] = None
# NOTE: This class is being replaced by RepositoryEncryptor
@ -73,5 +73,5 @@ def _initialize_datastore_encryptor(key_file: Path, secret: str):
_encryptor = DataStoreEncryptor(secret, key_file)
def get_datastore_encryptor() -> IEncryptor:
def get_datastore_encryptor() -> Optional[IEncryptor]:
return _encryptor

View File

@ -1,10 +1,11 @@
import secrets
from pathlib import Path
from typing import Optional
from monkey_island.cc.server_utils.encryption.encryption_key_types import EncryptionKey32Bytes
from monkey_island.cc.server_utils.file_utils import open_new_securely_permissioned_file
from . import ILockableEncryptor, LockedKeyError, ResetKeyError, UnlockError
from . import IEncryptor, ILockableEncryptor, LockedKeyError, ResetKeyError, UnlockError
from .key_based_encryptor import KeyBasedEncryptor
from .password_based_bytes_encryptor import PasswordBasedBytesEncryptor
@ -12,33 +13,32 @@ from .password_based_bytes_encryptor import PasswordBasedBytesEncryptor
class RepositoryEncryptor(ILockableEncryptor):
def __init__(self, key_file: Path):
self._key_file = key_file
self._password_based_encryptor = None
self._key_based_encryptor = None
self._key_based_encryptor: Optional[IEncryptor] = None
def unlock(self, secret: bytes):
try:
self._password_based_encryptor = PasswordBasedBytesEncryptor(secret.decode())
self._key_based_encryptor = self._initialize_key_based_encryptor()
encryptor = PasswordBasedBytesEncryptor(secret.decode())
self._key_based_encryptor = self._initialize_key_based_encryptor(encryptor)
except Exception as err:
raise UnlockError(err)
def _initialize_key_based_encryptor(self):
def _initialize_key_based_encryptor(self, encryptor: IEncryptor) -> KeyBasedEncryptor:
if self._key_file.is_file():
return self._load_key()
return self._load_key(encryptor)
return self._create_key()
return self._create_key(encryptor)
def _load_key(self) -> KeyBasedEncryptor:
def _load_key(self, encryptor: IEncryptor) -> KeyBasedEncryptor:
with open(self._key_file, "rb") as f:
encrypted_key = f.read()
plaintext_key = EncryptionKey32Bytes(self._password_based_encryptor.decrypt(encrypted_key))
plaintext_key = EncryptionKey32Bytes(encryptor.decrypt(encrypted_key))
return KeyBasedEncryptor(plaintext_key)
def _create_key(self) -> KeyBasedEncryptor:
def _create_key(self, encryptor: IEncryptor) -> KeyBasedEncryptor:
plaintext_key = EncryptionKey32Bytes(secrets.token_bytes(32))
encrypted_key = self._password_based_encryptor.encrypt(plaintext_key)
encrypted_key = encryptor.encrypt(plaintext_key)
with open_new_securely_permissioned_file(str(self._key_file), "wb") as f:
f.write(encrypted_key)
@ -54,7 +54,6 @@ class RepositoryEncryptor(ILockableEncryptor):
except Exception as err:
raise ResetKeyError(err)
self._password_based_encryptor = None
self._key_based_encryptor = None
def encrypt(self, plaintext: bytes) -> bytes:

View File

@ -25,4 +25,8 @@ class T1065(AttackTechnique):
@staticmethod
def get_tunnel_ports() -> Sequence[str]:
telems = Telemetry.objects(telem_category="tunnel", data__proxy__ne=None)
return [address_to_ip_port(telem["data"]["proxy"])[1] for telem in telems]
return [
p
for p in (address_to_ip_port(telem["data"]["proxy"])[1] for telem in telems)
if p is not None
]

View File

@ -18,8 +18,8 @@ def get_propagation_stats() -> Dict:
}
def _get_exploit_counts(exploited: List[MonkeyExploitation]) -> Dict:
exploit_counts = {}
def _get_exploit_counts(exploited: List[MonkeyExploitation]) -> Dict[str, int]:
exploit_counts: Dict[str, int] = {}
for node in exploited:
for exploit in node.exploits:

View File

@ -16,12 +16,12 @@ class CredExploitProcessor:
if attempt["result"]:
exploit_info.username = attempt["user"]
if attempt["password"]:
exploit_info.credential_type = CredentialType.PASSWORD.value
exploit_info.credential_type = CredentialType.PASSWORD
exploit_info.password = attempt["password"]
elif attempt["ssh_key"]:
exploit_info.credential_type = CredentialType.KEY.value
exploit_info.credential_type = CredentialType.KEY
exploit_info.ssh_key = attempt["ssh_key"]
else:
exploit_info.credential_type = CredentialType.HASH.value
exploit_info.credential_type = CredentialType.HASH
return exploit_info
return exploit_info

View File

@ -1,6 +1,6 @@
import copy
import dateutil
from dateutil import parser as dateutil_parser
from monkey_island.cc.models import Monkey
from monkey_island.cc.server_utils.encryption import get_datastore_encryptor
@ -29,10 +29,10 @@ def process_exploit_telemetry(telemetry_json, _):
def update_network_with_exploit(edge: EdgeService, telemetry_json):
telemetry_json["data"]["info"]["started"] = dateutil.parser.parse(
telemetry_json["data"]["info"]["started"] = dateutil_parser.parse(
telemetry_json["data"]["info"]["started"]
)
telemetry_json["data"]["info"]["finished"] = dateutil.parser.parse(
telemetry_json["data"]["info"]["finished"] = dateutil_parser.parse(
telemetry_json["data"]["info"]["finished"]
)
new_exploit = copy.deepcopy(telemetry_json["data"])

View File

@ -67,6 +67,8 @@ def is_segmentation_violation(
return cross_segment_ip is not None
return False
def get_segmentation_violation_event(current_monkey, source_subnet, target_ip, target_subnet):
return Event.create_event(

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Union
from typing import Dict, List, Union, cast
from bson import SON
@ -30,21 +30,26 @@ class FindingService:
@staticmethod
def get_all_findings_for_ui() -> List[EnrichedFinding]:
findings = FindingService.get_all_findings_from_db()
for i in range(len(findings)):
details = FindingService._get_finding_details(findings[i])
findings[i] = findings[i].to_mongo()
findings[i] = FindingService._get_enriched_finding(findings[i])
findings[i].details = details
return findings
enriched_findings: List[EnrichedFinding] = []
for finding in findings:
finding_data = finding.to_mongo()
enriched_finding = FindingService._get_enriched_finding(finding_data)
details = FindingService._get_finding_details(finding)
enriched_finding.details = details
enriched_findings.append(enriched_finding)
return enriched_findings
@staticmethod
def _get_enriched_finding(finding: Finding) -> EnrichedFinding:
def _get_enriched_finding(finding: SON) -> EnrichedFinding:
test_info = zero_trust_consts.TESTS_MAP[finding["test"]]
enriched_finding = EnrichedFinding(
finding_id=str(finding["_id"]),
test=test_info[zero_trust_consts.FINDING_EXPLANATION_BY_STATUS_KEY][finding["status"]],
test=cast(
Dict[str, str], test_info[zero_trust_consts.FINDING_EXPLANATION_BY_STATUS_KEY]
)[finding["status"]],
test_key=finding["test"],
pillars=test_info[zero_trust_consts.PILLARS_KEY],
pillars=cast(List[str], test_info[zero_trust_consts.PILLARS_KEY]),
status=finding["status"],
details=None,
)

View File

@ -49,7 +49,7 @@ class MongoDbProcess:
self._process.kill()
def is_running(self) -> bool:
if self._process.poll() is None:
if self._process and self._process.poll() is None:
return True
return False

View File

@ -77,14 +77,14 @@ class MockPuppet(IPuppet):
) -> Dict[int, PortScanData]:
logger.debug(f"run_scan_tcp_port({host}, {ports}, {timeout})")
dot_1_results = {
22: PortScanData(22, PortStatus.CLOSED, None, None),
22: PortScanData(22, PortStatus.CLOSED, "", ""),
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
3389: PortScanData(3389, PortStatus.OPEN, "", "tcp-3389"),
}
dot_3_results = {
22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"),
443: PortScanData(443, PortStatus.OPEN, "HTTPS BANNER", "tcp-443"),
3389: PortScanData(3389, PortStatus.CLOSED, "", None),
3389: PortScanData(3389, PortStatus.CLOSED, "", ""),
}
if host == DOT_1:
@ -104,7 +104,7 @@ class MockPuppet(IPuppet):
options: Dict,
) -> FingerprintData:
logger.debug(f"fingerprint({name}, {host})")
empty_fingerprint_data = FingerprintData(None, None, {})
empty_fingerprint_data = FingerprintData(None, "", {})
dot_1_results = {
"SMBFinger": FingerprintData(
@ -118,7 +118,7 @@ class MockPuppet(IPuppet):
),
"HTTPFinger": FingerprintData(
None,
None,
"",
{
"tcp-80": {"name": "http", "data": ("SERVER_HEADERS", False)},
"tcp-443": {"name": "http", "data": ("SERVER_HEADERS_2", True)},
@ -248,4 +248,4 @@ class MockPuppet(IPuppet):
def _get_empty_results(port: int):
return PortScanData(port, PortStatus.CLOSED, None, None)
return PortScanData(port, PortStatus.CLOSED, "", "")

View File

@ -154,7 +154,7 @@ def assert_fingerprint_results_no_3(fingerprint_data):
def assert_scan_results_host_down(address, ping_scan_data, port_scan_data, fingerprint_data):
assert address.ip not in {"10.0.0.1", "10.0.0.3"}
assert address.domain is None
assert not address.domain
assert ping_scan_data.response_received is False
assert len(port_scan_data.keys()) == 6
@ -178,9 +178,9 @@ def test_scan_single_ip(callback, scan_config, stop):
def test_scan_multiple_ips(callback, scan_config, stop):
addresses = [
NetworkAddress("10.0.0.1", "d1"),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", "d3"),
NetworkAddress("10.0.0.4", None),
NetworkAddress("10.0.0.4", ""),
]
ns = IPScanner(MockPuppet(), num_workers=4)
@ -203,7 +203,7 @@ def test_scan_multiple_ips(callback, scan_config, stop):
@pytest.mark.slow
def test_scan_lots_of_ips(callback, scan_config, stop):
addresses = [NetworkAddress(f"10.0.0.{i}", None) for i in range(0, 255)]
addresses = [NetworkAddress(f"10.0.0.{i}", "") for i in range(0, 255)]
ns = IPScanner(MockPuppet(), num_workers=4)
ns.scan(addresses, scan_config, callback, stop)
@ -223,10 +223,10 @@ def test_stop_after_callback(scan_config, stop):
stoppable_callback = MagicMock(side_effect=_callback)
addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", ""),
]
ns = IPScanner(MockPuppet(), num_workers=2)
@ -251,10 +251,10 @@ def test_interrupt_before_fingerprinting(callback, scan_config, stop):
puppet.fingerprint = MagicMock()
addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", ""),
]
ns = IPScanner(puppet, num_workers=2)
@ -270,7 +270,7 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
stoppable_fingerprint.barrier.wait()
stop.set()
return FingerprintData(None, None, {})
return FingerprintData(None, "", {})
stoppable_fingerprint.barrier = Barrier(2)
@ -278,10 +278,10 @@ def test_interrupt_fingerprinting(callback, scan_config, stop):
puppet.fingerprint = MagicMock(side_effect=stoppable_fingerprint)
addresses = [
NetworkAddress("10.0.0.1", None),
NetworkAddress("10.0.0.2", None),
NetworkAddress("10.0.0.3", None),
NetworkAddress("10.0.0.4", None),
NetworkAddress("10.0.0.1", ""),
NetworkAddress("10.0.0.2", ""),
NetworkAddress("10.0.0.3", ""),
NetworkAddress("10.0.0.4", ""),
]
ns = IPScanner(puppet, num_workers=2)

View File

@ -35,12 +35,12 @@ def mock_victim_host_factory():
return MockVictimHostFactory()
empty_fingerprint_data = FingerprintData(None, None, {})
empty_fingerprint_data = FingerprintData(None, "", {})
dot_1_scan_results = IPScanResults(
PingScanData(True, "windows"),
{
22: PortScanData(22, PortStatus.CLOSED, None, None),
22: PortScanData(22, PortStatus.CLOSED, "", ""),
445: PortScanData(445, PortStatus.OPEN, "SMB BANNER", "tcp-445"),
3389: PortScanData(3389, PortStatus.OPEN, "", "tcp-3389"),
},
@ -56,7 +56,7 @@ dot_3_scan_results = IPScanResults(
{
22: PortScanData(22, PortStatus.OPEN, "SSH BANNER", "tcp-22"),
443: PortScanData(443, PortStatus.OPEN, "HTTPS BANNER", "tcp-443"),
3389: PortScanData(3389, PortStatus.CLOSED, "", None),
3389: PortScanData(3389, PortStatus.CLOSED, "", ""),
},
{
"SSHFinger": FingerprintData(
@ -64,7 +64,7 @@ dot_3_scan_results = IPScanResults(
),
"HTTPFinger": FingerprintData(
None,
None,
"",
{
"tcp-80": {"name": "http", "data": ("SERVER_HEADERS", False)},
"tcp-443": {"name": "http", "data": ("SERVER_HEADERS_2", True)},
@ -77,9 +77,9 @@ dot_3_scan_results = IPScanResults(
dead_host_scan_results = IPScanResults(
PingScanData(False, None),
{
22: PortScanData(22, PortStatus.CLOSED, None, None),
443: PortScanData(443, PortStatus.CLOSED, None, None),
3389: PortScanData(3389, PortStatus.CLOSED, "", None),
22: PortScanData(22, PortStatus.CLOSED, "", ""),
443: PortScanData(443, PortStatus.CLOSED, "", ""),
3389: PortScanData(3389, PortStatus.CLOSED, "", ""),
},
{},
)

View File

@ -13,7 +13,7 @@ def mock_get_interface_to_target(monkeypatch):
def test_factory_no_tunnel():
factory = VictimHostFactory(island_ip="192.168.56.1", island_port="5000", on_island=False)
network_address = NetworkAddress("192.168.56.2", None)
network_address = NetworkAddress("192.168.56.2", "")
victim = factory.build_victim_host(network_address)
@ -49,4 +49,4 @@ def test_factory_no_default_server():
victim = factory.build_victim_host(network_address)
assert victim.default_server is None
assert not victim.default_server

View File

@ -38,7 +38,7 @@ def test_successful(monkeypatch, fingerprinter):
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1
es_service = fingerprint_data.services[ES_SERVICE]
@ -60,7 +60,7 @@ def test_fingerprinting_skipped_if_port_closed(monkeypatch, fingerprinter, port_
assert not mock_query_elasticsearch.called
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0
@ -82,5 +82,5 @@ def test_no_response_from_server(monkeypatch, fingerprinter, mock_query_function
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0

View File

@ -63,7 +63,7 @@ def test_fingerprint_only_port_443(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_called_with("https://127.0.0.1:443")
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
@ -86,7 +86,7 @@ def test_open_port_no_http_server(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_any_call("http://127.0.0.1:9200")
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0
@ -106,7 +106,7 @@ def test_multiple_open_ports(mock_get_http_headers, http_fingerprinter):
mock_get_http_headers.assert_any_call("http://127.0.0.1:8080")
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 2
assert fingerprint_data.services["tcp-443"]["data"][0] == PYTHON_SERVER_HEADER["Server"]
@ -126,7 +126,7 @@ def test_server_missing_from_http_headers(mock_get_http_headers, http_fingerprin
assert mock_get_http_headers.call_count == 2
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1
assert fingerprint_data.services["tcp-1080"]["data"][0] == ""

View File

@ -45,7 +45,7 @@ def test_mssql_fingerprint_successful(monkeypatch, fingerprinter):
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 1
# Each mssql instance is under his name
@ -78,7 +78,7 @@ def test_mssql_no_response_from_server(monkeypatch, fingerprinter, mock_query_fu
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0
@ -98,5 +98,5 @@ def test_mssql_wrong_response_from_server(monkeypatch, fingerprinter):
)
assert fingerprint_data.os_type is None
assert fingerprint_data.os_version is None
assert not fingerprint_data.os_version
assert len(fingerprint_data.services.keys()) == 0

View File

@ -24,7 +24,7 @@ def test_single_subnet():
assert len(scan_targets) == 255
for i in range(0, 255):
assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{i}", "") in scan_targets
@pytest.mark.parametrize("single_ip", ["10.0.0.2", "10.0.0.2/32", "10.0.0.2-10.0.0.2"])
@ -33,8 +33,8 @@ def test_single_ip(single_ip):
scan_targets = compile_ranges_only([single_ip])
assert len(scan_targets) == 1
assert NetworkAddress("10.0.0.2", None) in scan_targets
assert NetworkAddress("10.0.0.2", None) == scan_targets[0]
assert NetworkAddress("10.0.0.2", "") in scan_targets
assert NetworkAddress("10.0.0.2", "") == scan_targets[0]
def test_multiple_subnet():
@ -43,10 +43,10 @@ def test_multiple_subnet():
assert len(scan_targets) == 262
for i in range(0, 255):
assert NetworkAddress(f"10.0.0.{i}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{i}", "") in scan_targets
for i in range(8, 15):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_middle_of_range_subnet():
@ -55,7 +55,7 @@ def test_middle_of_range_subnet():
assert len(scan_targets) == 7
for i in range(0, 7):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
@pytest.mark.parametrize(
@ -68,7 +68,7 @@ def test_ip_range(ip_range):
assert len(scan_targets) == 9
for i in range(25, 34):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_no_duplicates():
@ -77,7 +77,7 @@ def test_no_duplicates():
assert len(scan_targets) == 7
for i in range(0, 7):
assert NetworkAddress(f"192.168.56.{i}", None) in scan_targets
assert NetworkAddress(f"192.168.56.{i}", "") in scan_targets
def test_blocklisted_ips():
@ -146,7 +146,7 @@ def test_no_redundant_targets():
)
assert len(scan_targets) == 2
assert NetworkAddress(ip="127.0.0.0", domain=None) in scan_targets
assert NetworkAddress(ip="127.0.0.0", domain="") in scan_targets
assert NetworkAddress(ip="127.0.0.1", domain="localhost") in scan_targets
@ -212,7 +212,7 @@ def test_local_subnet_added():
assert len(scan_targets) == 254
for ip in chain(range(0, 5), range(6, 255)):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
def test_multiple_local_subnets_added():
@ -232,10 +232,10 @@ def test_multiple_local_subnets_added():
assert len(scan_targets) == 2 * (255 - 1)
for ip in chain(range(0, 5), range(6, 255)):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in chain(range(0, 99), range(100, 255)):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets
assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_blocklisted_ips_missing_from_local_subnets():
@ -273,12 +273,12 @@ def test_local_subnets_and_ranges_added():
assert len(scan_targets) == 254 + 3
for ip in range(0, 5):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in range(6, 255):
assert NetworkAddress(f"10.0.0.{ip}", None) in scan_targets
assert NetworkAddress(f"10.0.0.{ip}", "") in scan_targets
for ip in range(40, 43):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets
assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_local_network_interfaces_specified_but_disabled():
@ -295,7 +295,7 @@ def test_local_network_interfaces_specified_but_disabled():
assert len(scan_targets) == 3
for ip in range(40, 43):
assert NetworkAddress(f"172.33.66.{ip}", None) in scan_targets
assert NetworkAddress(f"172.33.66.{ip}", "") in scan_targets
def test_local_network_interfaces_subnet_masks():
@ -315,7 +315,7 @@ def test_local_network_interfaces_subnet_masks():
assert len(scan_targets) == 4
for ip in [108, 110, 145, 146]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_targets():
@ -334,7 +334,7 @@ def test_segmentation_targets():
assert len(scan_targets) == 3
for ip in [144, 145, 146]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_clash_with_blocked():
@ -377,7 +377,7 @@ def test_segmentation_clash_with_targets():
assert len(scan_targets) == 3
for ip in [148, 149, 150]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_segmentation_one_network():
@ -443,7 +443,7 @@ def test_invalid_inputs():
assert len(scan_targets) == 3
for ip in [148, 149, 150]:
assert NetworkAddress(f"172.60.145.{ip}", None) in scan_targets
assert NetworkAddress(f"172.60.145.{ip}", "") in scan_targets
def test_invalid_blocklisted_ip():

View File

@ -18,7 +18,7 @@ def test_no_ssh_ports_open(ssh_fingerprinter):
}
results = ssh_fingerprinter.get_host_fingerprint("127.0.0.1", None, port_scan_data, None)
assert results == FingerprintData(None, None, {})
assert results == FingerprintData(None, "", {})
def test_no_os(ssh_fingerprinter):
@ -32,7 +32,7 @@ def test_no_os(ssh_fingerprinter):
assert results == FingerprintData(
None,
None,
"",
{
"tcp-22": {
"display_name": "SSH",

View File

@ -34,7 +34,7 @@ def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data):
for port in closed_ports:
assert port_scan_data[port].port == port
assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None
assert not port_scan_data[port].banner
@pytest.mark.parametrize("open_ports_data", [{}])

View File

@ -1,11 +1,12 @@
import json
from typing import Iterable
import pytest
from infection_monkey.exploit.sshexec import SSHExploiter
from infection_monkey.i_puppet.i_puppet import ExploiterResultData
from infection_monkey.model.host import VictimHost
from infection_monkey.telemetry.exploit_telem import ExploitTelem
from monkey.infection_monkey.i_puppet.i_puppet import ExploiterResultData
DOMAIN_NAME = "domain-name"
IP = "0.0.0.0"
@ -16,7 +17,7 @@ HOST_AS_DICT = {
"os": {},
"services": {},
"icmp": False,
"default_server": None,
"default_server": "",
}
EXPLOITER_NAME = "SSHExploiter"
EXPLOITER_INFO = {
@ -27,7 +28,7 @@ EXPLOITER_INFO = {
"vulnerable_ports": [],
"executed_cmds": [],
}
EXPLOITER_ATTEMPTS = []
EXPLOITER_ATTEMPTS: Iterable = []
RESULT = False
OS_LINUX = "linux"
ERROR_MSG = "failed because yolo"

View File

@ -1,4 +1,5 @@
import json
from typing import Any, Dict
import pytest
@ -14,9 +15,9 @@ HOST_AS_DICT = {
"os": {},
"services": {},
"icmp": False,
"default_server": None,
"default_server": "",
}
HOST_SERVICES = {}
HOST_SERVICES: Dict[str, Any] = {}
@pytest.fixture

View File

@ -18,7 +18,7 @@ SCAN_DATA_MOCK = [
},
},
"monkey_exe": None,
"default_server": None,
"default_server": "",
},
}
]