diff --git a/monkey/infection_monkey/i_puppet.py b/monkey/infection_monkey/i_puppet.py index 49040dd9f..f158be08c 100644 --- a/monkey/infection_monkey/i_puppet.py +++ b/monkey/infection_monkey/i_puppet.py @@ -36,22 +36,22 @@ class IPuppet(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def ping(self, host: str, options: Dict) -> PingScanData: + def ping(self, host: str, timeout: float) -> PingScanData: """ Sends a ping (ICMP packet) to a remote host :param str host: The domain name or IP address of a host - :return: A tuple that contains whether or not the host responded and the host's inferred - operating system - :rtype: Tuple[bool, Optional[str]] + :param float timeout: The maximum amount of time (in seconds) to wait for a response + :return: The data collected by attempting to ping the target host + :rtype: PingScanData """ @abc.abstractmethod - def scan_tcp_port(self, host: str, port: int, timeout: int) -> PortScanData: + def scan_tcp_port(self, host: str, port: int, timeout: float) -> PortScanData: """ Scans a TCP port on a remote host :param str host: The domain name or IP address of a host :param int port: A TCP port number to scan - :param int timeout: The maximum amount of time (in seconds) to wait for a response + :param float timeout: The maximum amount of time (in seconds) to wait for a response :return: The data collected by scanning the provided host:port combination :rtype: PortScanData """ diff --git a/monkey/infection_monkey/master/ip_scanner.py b/monkey/infection_monkey/master/ip_scanner.py index 3e469ee9c..7933202f6 100644 --- a/monkey/infection_monkey/master/ip_scanner.py +++ b/monkey/infection_monkey/master/ip_scanner.py @@ -44,7 +44,8 @@ class IPScanner: ip = ips.get_nowait() logger.info(f"Scanning {ip}") - ping_scan_data = self._puppet.ping(ip, options["icmp"]) + icmp_timeout = options["icmp"]["timeout_ms"] / 1000 + ping_scan_data = self._puppet.ping(ip, icmp_timeout) port_scan_data = self._scan_tcp_ports(ip, options["tcp"], stop) results_callback(ip, ping_scan_data, port_scan_data) @@ -59,11 +60,13 @@ class IPScanner: ) def _scan_tcp_ports(self, ip: str, options: Dict, stop: Event): + tcp_timeout = options["timeout_ms"] / 1000 port_scan_data = {} + for p in options["ports"]: if stop.is_set(): break - port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, options["timeout_ms"]) + port_scan_data[p] = self._puppet.scan_tcp_port(ip, p, tcp_timeout) return port_scan_data diff --git a/monkey/infection_monkey/master/mock_master.py b/monkey/infection_monkey/master/mock_master.py index 8c8ecebdd..551ff886c 100644 --- a/monkey/infection_monkey/master/mock_master.py +++ b/monkey/infection_monkey/master/mock_master.py @@ -66,7 +66,7 @@ class MockMaster(IMaster): for ip in ips: h = self._hosts[ip] - ping_scan_data = self._puppet.ping(ip, {}) + ping_scan_data = self._puppet.ping(ip, 1) h.icmp = ping_scan_data.response_received if ping_scan_data.os is not None: h.os["type"] = ping_scan_data.os diff --git a/monkey/infection_monkey/puppet/mock_puppet.py b/monkey/infection_monkey/puppet/mock_puppet.py index de89db172..8c6a39c65 100644 --- a/monkey/infection_monkey/puppet/mock_puppet.py +++ b/monkey/infection_monkey/puppet/mock_puppet.py @@ -156,8 +156,8 @@ class MockPuppet(IPuppet): else: return PostBreachData("pba command 2", ["pba result 2", False]) - def ping(self, host: str, options: Dict) -> PingScanData: - logger.debug(f"run_ping({host})") + def ping(self, host: str, timeout: float = 1) -> PingScanData: + logger.debug(f"run_ping({host}, {timeout})") if host == DOT_1: return PingScanData(True, "windows") diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py b/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py index 9447bdfc1..d302cbbfb 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_network_scanner.py @@ -12,12 +12,6 @@ WINDOWS_OS = "windows" LINUX_OS = "linux" -class MockPuppet(MockPuppet): - def __init__(self): - self.ping = MagicMock(side_effect=super().ping) - self.scan_tcp_port = MagicMock(side_effect=super().scan_tcp_port) - - @pytest.fixture def scan_config(): return {