diff --git a/monkey/infection_monkey/master/automated_master.py b/monkey/infection_monkey/master/automated_master.py index 0ef31129c..af263af6e 100644 --- a/monkey/infection_monkey/master/automated_master.py +++ b/monkey/infection_monkey/master/automated_master.py @@ -225,8 +225,13 @@ class AutomatedMaster(IMaster): interrupted_message = f"Received a stop signal, skipping remaining {plugin_type}s" for p in interruptible_iter(plugins, self._stop, interrupted_message): - # TODO: Catch exceptions to prevent thread from crashing - callback(p) + try: + callback(p) + except Exception: + logger.exception( + f"Got unhandled exception when running {plugin_type} plugin {p}. " + f"Plugin was passed to {callback}" + ) logger.info(f"Finished running {plugin_type}s") diff --git a/monkey/infection_monkey/master/ip_scanner.py b/monkey/infection_monkey/master/ip_scanner.py index 2702642c9..8c0ea5caa 100644 --- a/monkey/infection_monkey/master/ip_scanner.py +++ b/monkey/infection_monkey/master/ip_scanner.py @@ -61,7 +61,6 @@ class IPScanner: address = addresses.get_nowait() logger.info(f"Scanning {address.ip}") - # TODO: Catch exceptions to prevent thread from crashing ping_scan_data = self._puppet.ping(address.ip, icmp_timeout) port_scan_data = self._puppet.scan_tcp_ports(address.ip, tcp_ports, tcp_timeout) diff --git a/monkey/infection_monkey/network_scanning/ping_scanner.py b/monkey/infection_monkey/network_scanning/ping_scanner.py index 66ae79b2a..16fb2df96 100644 --- a/monkey/infection_monkey/network_scanning/ping_scanner.py +++ b/monkey/infection_monkey/network_scanning/ping_scanner.py @@ -6,16 +6,26 @@ import subprocess import sys from infection_monkey.i_puppet import PingScanData +from infection_monkey.utils.environment import is_windows_os TTL_REGEX = re.compile(r"TTL=([0-9]+)\b", re.IGNORECASE) LINUX_TTL = 64 # Windows TTL is 128 PING_EXIT_TIMEOUT = 10 +EMPTY_PING_SCAN = PingScanData(False, None) logger = logging.getLogger(__name__) def ping(host: str, timeout: float) -> PingScanData: - if "win32" == sys.platform: + try: + return _ping(host, timeout) + except Exception: + logger.exception("Unhandled exception occurred while running ping") + return EMPTY_PING_SCAN + + +def _ping(host: str, timeout: float) -> PingScanData: + if is_windows_os(): timeout = math.floor(timeout * 1000) ping_command_output = _run_ping_command(host, timeout) diff --git a/monkey/infection_monkey/network_scanning/tcp_scanner.py b/monkey/infection_monkey/network_scanning/tcp_scanner.py index 6fdded293..d0c6e3e7a 100644 --- a/monkey/infection_monkey/network_scanning/tcp_scanner.py +++ b/monkey/infection_monkey/network_scanning/tcp_scanner.py @@ -11,11 +11,20 @@ from infection_monkey.utils.timer import Timer logger = logging.getLogger(__name__) POLL_INTERVAL = 0.5 +EMPTY_PORT_SCAN = {-1: PortScanData(-1, PortStatus.CLOSED, None, None)} def scan_tcp_ports( host: str, ports_to_scan: Iterable[int], timeout: float ) -> Mapping[int, PortScanData]: + try: + return _scan_tcp_ports(host, ports_to_scan, timeout) + except Exception: + logger.exception("Unhandled exception occurred while trying to scan tcp ports") + return EMPTY_PORT_SCAN + + +def _scan_tcp_ports(host: str, ports_to_scan: Iterable[int], timeout: float): open_ports = _check_tcp_ports(host, ports_to_scan, timeout) return _build_port_scan_data(ports_to_scan, open_ports) diff --git a/monkey/infection_monkey/puppet/puppet.py b/monkey/infection_monkey/puppet/puppet.py index 030008a04..c7953d83c 100644 --- a/monkey/infection_monkey/puppet/puppet.py +++ b/monkey/infection_monkey/puppet/puppet.py @@ -18,6 +18,8 @@ from infection_monkey.model import VictimHost from .plugin_registry import PluginRegistry +EMPTY_FINGERPRINT = PingScanData(False, None) + logger = logging.getLogger() @@ -54,8 +56,14 @@ class Puppet(IPuppet): port_scan_data: Dict[int, PortScanData], options: Dict, ) -> FingerprintData: - fingerprinter = self._plugin_registry.get_plugin(name, PluginType.FINGERPRINTER) - return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options) + try: + fingerprinter = self._plugin_registry.get_plugin(name, PluginType.FINGERPRINTER) + return fingerprinter.get_host_fingerprint(host, ping_scan_data, port_scan_data, options) + except Exception: + logger.exception( + f"Unhandled exception occurred " f"while trying to run {name} fingerprinter" + ) + return EMPTY_FINGERPRINT def exploit_host( self, diff --git a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping.py b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping_scanner.py similarity index 94% rename from monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping.py rename to monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping_scanner.py index 37e365682..88c9dbeca 100644 --- a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping.py +++ b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_ping_scanner.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock import pytest from infection_monkey.network_scanning import ping +from infection_monkey.network_scanning.ping_scanner import EMPTY_PING_SCAN LINUX_SUCCESS_OUTPUT = """ PING 192.168.1.1 (192.168.1.1) 56(84) bytes of data. @@ -174,3 +175,10 @@ def test_linux_timeout(assert_expected_timeout): timeout = 1.42379 assert_expected_timeout(timeout_flag, timeout, str(math.ceil(timeout))) + + +def test_exception_handling(monkeypatch): + monkeypatch.setattr( + "infection_monkey.network_scanning.ping_scanner._ping", MagicMock(side_effect=Exception) + ) + assert ping("abc", 10) == EMPTY_PING_SCAN diff --git a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanning.py b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py similarity index 83% rename from monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanning.py rename to monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py index 725a3aaa0..837b3da0d 100644 --- a/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanning.py +++ b/monkey/tests/unit_tests/infection_monkey/network_scanning/test_tcp_scanner.py @@ -1,7 +1,10 @@ +from unittest.mock import MagicMock + import pytest from infection_monkey.i_puppet import PortStatus from infection_monkey.network_scanning import scan_tcp_ports +from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222] @@ -36,7 +39,6 @@ def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data): @pytest.mark.parametrize("open_ports_data", [{}]) def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data): - port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) assert len(port_scan_data) == 6 @@ -48,7 +50,14 @@ def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data): - port_scan_data = scan_tcp_ports("127.0.0.1", [], 0) assert len(port_scan_data) == 0 + + +def test_exception_handling(monkeypatch): + monkeypatch.setattr( + "infection_monkey.network_scanning.tcp_scanner._scan_tcp_ports", + MagicMock(side_effect=Exception), + ) + assert scan_tcp_ports("abc", [123], 123) == EMPTY_PORT_SCAN diff --git a/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py b/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py index 70c98d252..39273faee 100644 --- a/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py +++ b/monkey/tests/unit_tests/infection_monkey/puppet/test_puppet.py @@ -1,8 +1,8 @@ import threading from unittest.mock import MagicMock -from infection_monkey.i_puppet import PluginType -from infection_monkey.puppet.puppet import Puppet +from infection_monkey.i_puppet import PingScanData, PluginType +from infection_monkey.puppet.puppet import EMPTY_FINGERPRINT, Puppet def test_puppet_run_payload_success(): @@ -41,3 +41,9 @@ def test_puppet_run_multiple_payloads(): p.run_payload(payload3_name, {}, threading.Event()) payload_3.run.assert_called_once() + + +def test_fingerprint_exception_handling(monkeypatch): + p = Puppet() + p._plugin_registry.get_plugin = MagicMock(side_effect=Exception) + assert p.fingerprint("", "", PingScanData("windows", False), {}, {}) == EMPTY_FINGERPRINT