UT: Add some tests for exception handling

This commit is contained in:
vakarisz 2022-04-08 12:00:06 +03:00
parent 45c6cac60c
commit e1b52428d1
4 changed files with 36 additions and 54 deletions

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet
from infection_monkey.i_puppet import FingerprintData, PingScanData, PortScanData, PortStatus from infection_monkey.i_puppet import FingerprintData, PortScanData, PortStatus
from infection_monkey.master import IPScanner from infection_monkey.master import IPScanner
from infection_monkey.network import NetworkAddress from infection_monkey.network import NetworkAddress
@ -202,47 +202,6 @@ def test_scan_lots_of_ips(callback, scan_config, stop):
assert callback.call_count == 255 assert callback.call_count == 255
def test_exception_in_pinging(callback, scan_config, stop):
addresses = [NetworkAddress("10.0.0.1", "d1")]
puppet = MockPuppet()
puppet.ping = MagicMock(side_effect=Exception("Exception raised during pinging."))
ns = IPScanner(puppet, num_workers=1)
ns.scan(addresses, scan_config, callback, stop)
(_, scan_results) = callback.call_args_list[0][0]
assert scan_results.ping_scan_data == PingScanData(False, None)
def test_exception_in_port_scanning(callback, scan_config, stop):
addresses = [NetworkAddress("10.0.0.1", "d1")]
puppet = MockPuppet()
puppet.scan_tcp_ports = MagicMock(
side_effect=Exception("Exception raised when scanning TCP ports.")
)
ns = IPScanner(puppet, num_workers=1)
ns.scan(addresses, scan_config, callback, stop)
(_, scan_results) = callback.call_args_list[0][0]
assert scan_results.port_scan_data == {-1: PortScanData(-1, PortStatus.CLOSED, None, None)}
def test_exception_in_fingerprinting(callback, scan_config, stop):
addresses = [NetworkAddress("10.0.0.1", "d1")]
puppet = MockPuppet()
puppet.fingerprint = MagicMock(side_effect=Exception("Exception raised during fingerprinting."))
ns = IPScanner(puppet, num_workers=1)
ns.scan(addresses, scan_config, callback, stop)
(_, scan_results) = callback.call_args_list[0][0]
assert scan_results.fingerprint_data == FingerprintData(None, None, {})
def test_stop_after_callback(scan_config, stop): def test_stop_after_callback(scan_config, stop):
def _callback(*_): def _callback(*_):
# Block all threads here until 2 threads reach this barrier, then set stop # Block all threads here until 2 threads reach this barrier, then set stop

View File

@ -4,7 +4,8 @@ from unittest.mock import MagicMock
import pytest import pytest
from infection_monkey.network_scanning import _ping from infection_monkey.network_scanning import ping
from infection_monkey.network_scanning.ping_scanner import EMPTY_PING_SCAN
LINUX_SUCCESS_OUTPUT = """ LINUX_SUCCESS_OUTPUT = """
PING 192.168.1.1 (192.168.1.1) 56(84) bytes of data. PING 192.168.1.1 (192.168.1.1) 56(84) bytes of data.
@ -87,7 +88,7 @@ def set_os_windows(monkeypatch):
@pytest.mark.usefixtures("set_os_linux") @pytest.mark.usefixtures("set_os_linux")
def test_linux_ping_success(patch_subprocess_running_ping_with_ping_output): def test_linux_ping_success(patch_subprocess_running_ping_with_ping_output):
patch_subprocess_running_ping_with_ping_output(LINUX_SUCCESS_OUTPUT) patch_subprocess_running_ping_with_ping_output(LINUX_SUCCESS_OUTPUT)
result = _ping("192.168.1.1", 1.0) result = ping("192.168.1.1", 1.0)
assert result.response_received assert result.response_received
assert result.os == "linux" assert result.os == "linux"
@ -96,7 +97,7 @@ def test_linux_ping_success(patch_subprocess_running_ping_with_ping_output):
@pytest.mark.usefixtures("set_os_linux") @pytest.mark.usefixtures("set_os_linux")
def test_linux_ping_no_response(patch_subprocess_running_ping_with_ping_output): def test_linux_ping_no_response(patch_subprocess_running_ping_with_ping_output):
patch_subprocess_running_ping_with_ping_output(LINUX_NO_RESPONSE_OUTPUT) patch_subprocess_running_ping_with_ping_output(LINUX_NO_RESPONSE_OUTPUT)
result = _ping("192.168.1.1", 1.0) result = ping("192.168.1.1", 1.0)
assert not result.response_received assert not result.response_received
assert result.os is None assert result.os is None
@ -105,7 +106,7 @@ def test_linux_ping_no_response(patch_subprocess_running_ping_with_ping_output):
@pytest.mark.usefixtures("set_os_windows") @pytest.mark.usefixtures("set_os_windows")
def test_windows_ping_success(patch_subprocess_running_ping_with_ping_output): def test_windows_ping_success(patch_subprocess_running_ping_with_ping_output):
patch_subprocess_running_ping_with_ping_output(WINDOWS_SUCCESS_OUTPUT) patch_subprocess_running_ping_with_ping_output(WINDOWS_SUCCESS_OUTPUT)
result = _ping("192.168.1.1", 1.0) result = ping("192.168.1.1", 1.0)
assert result.response_received assert result.response_received
assert result.os == "windows" assert result.os == "windows"
@ -114,7 +115,7 @@ def test_windows_ping_success(patch_subprocess_running_ping_with_ping_output):
@pytest.mark.usefixtures("set_os_windows") @pytest.mark.usefixtures("set_os_windows")
def test_windows_ping_no_response(patch_subprocess_running_ping_with_ping_output): def test_windows_ping_no_response(patch_subprocess_running_ping_with_ping_output):
patch_subprocess_running_ping_with_ping_output(WINDOWS_NO_RESPONSE_OUTPUT) patch_subprocess_running_ping_with_ping_output(WINDOWS_NO_RESPONSE_OUTPUT)
result = _ping("192.168.1.1", 1.0) result = ping("192.168.1.1", 1.0)
assert not result.response_received assert not result.response_received
assert result.os is None assert result.os is None
@ -122,7 +123,7 @@ def test_windows_ping_no_response(patch_subprocess_running_ping_with_ping_output
def test_malformed_ping_command_response(patch_subprocess_running_ping_with_ping_output): def test_malformed_ping_command_response(patch_subprocess_running_ping_with_ping_output):
patch_subprocess_running_ping_with_ping_output(MALFORMED_OUTPUT) patch_subprocess_running_ping_with_ping_output(MALFORMED_OUTPUT)
result = _ping("192.168.1.1", 1.0) result = ping("192.168.1.1", 1.0)
assert not result.response_received assert not result.response_received
assert result.os is None assert result.os is None
@ -130,7 +131,7 @@ def test_malformed_ping_command_response(patch_subprocess_running_ping_with_ping
@pytest.mark.usefixtures("patch_subprocess_running_ping_to_raise_timeout_expired") @pytest.mark.usefixtures("patch_subprocess_running_ping_to_raise_timeout_expired")
def test_timeout_expired(): def test_timeout_expired():
result = _ping("192.168.1.1", 1.0) result = ping("192.168.1.1", 1.0)
assert not result.response_received assert not result.response_received
assert result.os is None assert result.os is None
@ -147,7 +148,7 @@ def ping_command_spy(monkeypatch):
@pytest.fixture @pytest.fixture
def assert_expected_timeout(ping_command_spy): def assert_expected_timeout(ping_command_spy):
def inner(timeout_flag, timeout_input, expected_timeout): def inner(timeout_flag, timeout_input, expected_timeout):
_ping("192.168.1.1", timeout_input) ping("192.168.1.1", timeout_input)
assert ping_command_spy.call_args is not None assert ping_command_spy.call_args is not None
@ -174,3 +175,10 @@ def test_linux_timeout(assert_expected_timeout):
timeout = 1.42379 timeout = 1.42379
assert_expected_timeout(timeout_flag, timeout, str(math.ceil(timeout))) assert_expected_timeout(timeout_flag, timeout, str(math.ceil(timeout)))
def test_exception_handling(monkeypatch):
monkeypatch.setattr(
"infection_monkey.network_scanning.ping_scanner._ping", MagicMock(side_effect=Exception)
)
assert ping("abc", 10) == EMPTY_PING_SCAN

View File

@ -1,7 +1,10 @@
from unittest.mock import MagicMock
import pytest import pytest
from infection_monkey.i_puppet import PortStatus from infection_monkey.i_puppet import PortStatus
from infection_monkey.network_scanning import scan_tcp_ports from infection_monkey.network_scanning import scan_tcp_ports
from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN
PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222] PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222]
@ -36,7 +39,6 @@ def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data):
@pytest.mark.parametrize("open_ports_data", [{}]) @pytest.mark.parametrize("open_ports_data", [{}])
def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data): def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data):
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0)
assert len(port_scan_data) == 6 assert len(port_scan_data) == 6
@ -48,7 +50,14 @@ def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data)
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data): def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data):
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0) port_scan_data = scan_tcp_ports("127.0.0.1", [], 0)
assert len(port_scan_data) == 0 assert len(port_scan_data) == 0
def test_exception_handling(monkeypatch):
monkeypatch.setattr(
"infection_monkey.network_scanning.tcp_scanner._scan_tcp_ports",
MagicMock(side_effect=Exception),
)
assert scan_tcp_ports("abc", [123], 123) == EMPTY_PORT_SCAN

View File

@ -1,8 +1,8 @@
import threading import threading
from unittest.mock import MagicMock from unittest.mock import MagicMock
from infection_monkey.i_puppet import PluginType from infection_monkey.i_puppet import PingScanData, PluginType
from infection_monkey.puppet.puppet import Puppet from infection_monkey.puppet.puppet import EMPTY_FINGERPRINT, Puppet
def test_puppet_run_payload_success(): def test_puppet_run_payload_success():
@ -41,3 +41,9 @@ def test_puppet_run_multiple_payloads():
p.run_payload(payload3_name, {}, threading.Event()) p.run_payload(payload3_name, {}, threading.Event())
payload_3.run.assert_called_once() payload_3.run.assert_called_once()
def test_fingerprint_exception_handling(monkeypatch):
p = Puppet()
p._plugin_registry.get_plugin = MagicMock(side_effect=Exception)
assert p.fingerprint("", "", PingScanData("windows", False), {}, {}) == EMPTY_FINGERPRINT