Merge pull request #1859 from guardicore/1826-catch-exceptions-todos

Resolve catching exceptions TODOs
This commit is contained in:
VakarisZ 2022-04-08 13:40:31 +03:00 committed by GitHub
commit f9a6d13f3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 64 additions and 10 deletions

View File

@ -225,8 +225,13 @@ class AutomatedMaster(IMaster):
interrupted_message = f"Received a stop signal, skipping remaining {plugin_type}s" interrupted_message = f"Received a stop signal, skipping remaining {plugin_type}s"
for p in interruptible_iter(plugins, self._stop, interrupted_message): for p in interruptible_iter(plugins, self._stop, interrupted_message):
# TODO: Catch exceptions to prevent thread from crashing try:
callback(p) callback(p)
except Exception:
logger.exception(
f"Got unhandled exception when running {plugin_type} plugin {p}. "
f"Plugin was passed to {callback}"
)
logger.info(f"Finished running {plugin_type}s") logger.info(f"Finished running {plugin_type}s")

View File

@ -61,7 +61,6 @@ class IPScanner:
address = addresses.get_nowait() address = addresses.get_nowait()
logger.info(f"Scanning {address.ip}") logger.info(f"Scanning {address.ip}")
# TODO: Catch exceptions to prevent thread from crashing
ping_scan_data = self._puppet.ping(address.ip, icmp_timeout) ping_scan_data = self._puppet.ping(address.ip, icmp_timeout)
port_scan_data = self._puppet.scan_tcp_ports(address.ip, tcp_ports, tcp_timeout) port_scan_data = self._puppet.scan_tcp_ports(address.ip, tcp_ports, tcp_timeout)

View File

@ -6,16 +6,26 @@ import subprocess
import sys import sys
from infection_monkey.i_puppet import PingScanData from infection_monkey.i_puppet import PingScanData
from infection_monkey.utils.environment import is_windows_os
TTL_REGEX = re.compile(r"TTL=([0-9]+)\b", re.IGNORECASE) TTL_REGEX = re.compile(r"TTL=([0-9]+)\b", re.IGNORECASE)
LINUX_TTL = 64 # Windows TTL is 128 LINUX_TTL = 64 # Windows TTL is 128
PING_EXIT_TIMEOUT = 10 PING_EXIT_TIMEOUT = 10
EMPTY_PING_SCAN = PingScanData(False, None)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def ping(host: str, timeout: float) -> PingScanData: def ping(host: str, timeout: float) -> PingScanData:
if "win32" == sys.platform: try:
return _ping(host, timeout)
except Exception:
logger.exception("Unhandled exception occurred while running ping")
return EMPTY_PING_SCAN
def _ping(host: str, timeout: float) -> PingScanData:
if is_windows_os():
timeout = math.floor(timeout * 1000) timeout = math.floor(timeout * 1000)
ping_command_output = _run_ping_command(host, timeout) ping_command_output = _run_ping_command(host, timeout)

View File

@ -11,11 +11,20 @@ from infection_monkey.utils.timer import Timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
POLL_INTERVAL = 0.5 POLL_INTERVAL = 0.5
EMPTY_PORT_SCAN = {-1: PortScanData(-1, PortStatus.CLOSED, None, None)}
def scan_tcp_ports( def scan_tcp_ports(
host: str, ports_to_scan: Iterable[int], timeout: float host: str, ports_to_scan: Iterable[int], timeout: float
) -> Mapping[int, PortScanData]: ) -> Mapping[int, PortScanData]:
try:
return _scan_tcp_ports(host, ports_to_scan, timeout)
except Exception:
logger.exception("Unhandled exception occurred while trying to scan tcp ports")
return EMPTY_PORT_SCAN
def _scan_tcp_ports(host: str, ports_to_scan: Iterable[int], timeout: float):
open_ports = _check_tcp_ports(host, ports_to_scan, timeout) open_ports = _check_tcp_ports(host, ports_to_scan, timeout)
return _build_port_scan_data(ports_to_scan, open_ports) return _build_port_scan_data(ports_to_scan, open_ports)

View File

@ -18,6 +18,8 @@ from infection_monkey.model import VictimHost
from .plugin_registry import PluginRegistry from .plugin_registry import PluginRegistry
EMPTY_FINGERPRINT = PingScanData(False, None)
logger = logging.getLogger() logger = logging.getLogger()
@ -54,8 +56,14 @@ class Puppet(IPuppet):
port_scan_data: Dict[int, PortScanData], port_scan_data: Dict[int, PortScanData],
options: Dict, options: Dict,
) -> FingerprintData: ) -> FingerprintData:
fingerprinter = self._plugin_registry.get_plugin(name, PluginType.FINGERPRINTER) try:
return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options) fingerprinter = self._plugin_registry.get_plugin(name, PluginType.FINGERPRINTER)
return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options)
except Exception:
logger.exception(
f"Unhandled exception occurred " f"while trying to run {name} fingerprinter"
)
return EMPTY_FINGERPRINT
def exploit_host( def exploit_host(
self, self,

View File

@ -5,6 +5,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from infection_monkey.network_scanning import ping from infection_monkey.network_scanning import ping
from infection_monkey.network_scanning.ping_scanner import EMPTY_PING_SCAN
LINUX_SUCCESS_OUTPUT = """ LINUX_SUCCESS_OUTPUT = """
PING 192.168.1.1 (192.168.1.1) 56(84) bytes of data. PING 192.168.1.1 (192.168.1.1) 56(84) bytes of data.
@ -174,3 +175,10 @@ def test_linux_timeout(assert_expected_timeout):
timeout = 1.42379 timeout = 1.42379
assert_expected_timeout(timeout_flag, timeout, str(math.ceil(timeout))) assert_expected_timeout(timeout_flag, timeout, str(math.ceil(timeout)))
def test_exception_handling(monkeypatch):
monkeypatch.setattr(
"infection_monkey.network_scanning.ping_scanner._ping", MagicMock(side_effect=Exception)
)
assert ping("abc", 10) == EMPTY_PING_SCAN

View File

@ -1,7 +1,10 @@
from unittest.mock import MagicMock
import pytest import pytest
from infection_monkey.i_puppet import PortStatus from infection_monkey.i_puppet import PortStatus
from infection_monkey.network_scanning import scan_tcp_ports from infection_monkey.network_scanning import scan_tcp_ports
from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN
PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222] PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222]
@ -36,7 +39,6 @@ def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data):
@pytest.mark.parametrize("open_ports_data", [{}]) @pytest.mark.parametrize("open_ports_data", [{}])
def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data): def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data):
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0)
assert len(port_scan_data) == 6 assert len(port_scan_data) == 6
@ -48,7 +50,14 @@ def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data)
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data): def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data):
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0) port_scan_data = scan_tcp_ports("127.0.0.1", [], 0)
assert len(port_scan_data) == 0 assert len(port_scan_data) == 0
def test_exception_handling(monkeypatch):
monkeypatch.setattr(
"infection_monkey.network_scanning.tcp_scanner._scan_tcp_ports",
MagicMock(side_effect=Exception),
)
assert scan_tcp_ports("abc", [123], 123) == EMPTY_PORT_SCAN

View File

@ -1,8 +1,8 @@
import threading import threading
from unittest.mock import MagicMock from unittest.mock import MagicMock
from infection_monkey.i_puppet import PluginType from infection_monkey.i_puppet import PingScanData, PluginType
from infection_monkey.puppet.puppet import Puppet from infection_monkey.puppet.puppet import EMPTY_FINGERPRINT, Puppet
def test_puppet_run_payload_success(): def test_puppet_run_payload_success():
@ -41,3 +41,9 @@ def test_puppet_run_multiple_payloads():
p.run_payload(payload3_name, {}, threading.Event()) p.run_payload(payload3_name, {}, threading.Event())
payload_3.run.assert_called_once() payload_3.run.assert_called_once()
def test_fingerprint_exception_handling(monkeypatch):
p = Puppet()
p._plugin_registry.get_plugin = MagicMock(side_effect=Exception)
assert p.fingerprint("", "", PingScanData("windows", False), {}, {}) == EMPTY_FINGERPRINT