From 31fd24f0770ab8b43515ff28d299704f5ae7a6f3 Mon Sep 17 00:00:00 2001 From: Shreya Malviya Date: Wed, 9 Feb 2022 21:24:55 +0530 Subject: [PATCH 01/10] Agent: Address CR comments + minor changes in tcp_scanner.py --- .../infection_monkey/network/tcp_scanner.py | 47 ++++++++----------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index 3d3f66a14..84453cfa7 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -46,8 +46,9 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT :return: List of open ports. """ sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))] - # CR: Don't use list comprehensions if you don't need a list - [s.setblocking(False) for s in sockets] + for s in sockets: + s.setblocking(False) + possible_ports = [] connected_ports_sockets = [] try: @@ -58,13 +59,10 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT connected_ports_sockets.append((port, sock)) possible_ports.append((port, sock)) continue - # BUG: I don't think a socket will ever connect successfully if this error is raised. - # From the documentation: "Resource temporarily unavailable... It is a nonfatal - # error, **and the operation should be retried later**." (emphasis mine). If the - # operation is not retried later, I don't see the point in appending this to - # possible_ports. - if err == 10035: # WSAEWOULDBLOCK is valid, see - # https://msdn.microsoft.com/en-us/library/windows/desktop/ms740668%28v=vs.85%29.aspx?f=255&MSPPError=-2147217396 + if err == 10035: # WSAEWOULDBLOCK is valid. + # https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect + # says, "Use the select function to determine the completion of the connection + # request by checking to see if the socket is writable," which is being done below. possible_ports.append((port, sock)) continue if err == 115: # EINPROGRESS 115 /* Operation now in progress */ @@ -74,15 +72,11 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT if len(possible_ports) != 0: timeout = int(round(timeout)) # clamp to integer, to avoid checking input - sockets_to_try = possible_ports[:] - # BUG: If any sockets were added to connected_ports_sockets on line 94, this would - # remove them. - connected_ports_sockets = [] + sockets_to_try = possible_ports.copy() while (timeout >= 0) and sockets_to_try: sock_objects = [s[1] for s in sockets_to_try] - # BUG: Since timeout is 0, this could block indefinitely - _, writeable_sockets, _ = select.select(sock_objects, sock_objects, sock_objects, 0) + _, writeable_sockets, _ = select.select([], sock_objects, [], 0) for s in writeable_sockets: try: # actual test connected_ports_sockets.append((s.getpeername()[1], s)) @@ -97,6 +91,7 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT "On host %s discovered the following ports %s" % (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets])) ) + banners = [] if len(connected_ports_sockets) != 0: readable_sockets, _, _ = select.select( @@ -104,20 +99,18 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT ) # read first BANNER_READ bytes. We ignore errors because service might not send a # decodable byte string. - # CR: Because of how black formats this, it is difficult to parse. Refactor to be - # easier to read. + for port, sock in connected_ports_sockets: + if sock in readable_sockets: + banners.append(sock.recv(BANNER_READ).decode(errors="ignore")) + else: + banners.append("") - # TODO: Rework the return of this function. Consider using dictionary - banners = [ - sock.recv(BANNER_READ).decode(errors="ignore") - if sock in readable_sockets - else "" - for port, sock in connected_ports_sockets - ] - pass # try to cleanup - # CR: Evaluate whether or not we should call shutdown() before close() on each socket. - [s[1].close() for s in possible_ports] + for s in possible_ports: + s[1].shutdown(socket.SHUT_RDWR) + s[1].close() + + # TODO: Rework the return of this function. Consider using dictionary return [port for port, sock in connected_ports_sockets], banners else: return [], [] From e981ead1500e58892a9317c07e456f13c01dc708 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Wed, 9 Feb 2022 18:18:15 -0500 Subject: [PATCH 02/10] Agent: Add new time_remaining() method to Timer --- monkey/infection_monkey/utils/timer.py | 13 +++++++++- .../infection_monkey/utils/test_timer.py | 25 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/monkey/infection_monkey/utils/timer.py b/monkey/infection_monkey/utils/timer.py index 366a10b20..597eb6020 100644 --- a/monkey/infection_monkey/utils/timer.py +++ b/monkey/infection_monkey/utils/timer.py @@ -25,7 +25,18 @@ class Timer: TIMEOUT_SEC, False otherwise :rtype: bool """ - return (time.time() - self._start_time) >= self._timeout_sec + return self.time_remaining == 0 + + @property + def time_remaining(self) -> float: + """ + Return the amount of time remaining until the timer expires. + :return: The number of seconds until the timer expires. If the timer is expired, this + function returns 0 (it will never return a negative number). + :rtype: float + """ + time_remaining = self._timeout_sec - (time.time() - self._start_time) + return time_remaining if time_remaining > 0 else 0 def reset(self): """ diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_timer.py b/monkey/tests/unit_tests/infection_monkey/utils/test_timer.py index 5359b8c79..b5291cc0e 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_timer.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_timer.py @@ -67,3 +67,28 @@ def test_timer_reset(start_time, set_current_time, timeout): set_current_time(start_time + (2 * timeout)) assert t.is_expired() + + +def test_time_remaining(start_time, set_current_time): + timeout = 5 + + t = Timer() + t.set(timeout) + + assert t.time_remaining == timeout + + set_current_time(start_time + 2) + assert t.time_remaining == 3 + + +def test_time_remaining_is_zero(start_time, set_current_time): + timeout = 5 + + t = Timer() + t.set(timeout) + + set_current_time(start_time + timeout) + assert t.time_remaining == 0 + + set_current_time(start_time + (2 * timeout)) + assert t.time_remaining == 0 From 0e7f171c4a1e1843883c8fcf60d7d590030baec3 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Wed, 9 Feb 2022 19:21:45 -0500 Subject: [PATCH 03/10] Agent: Use a Timer in _check_tcp_ports() to simplify logic --- monkey/infection_monkey/network/tcp_scanner.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index 84453cfa7..9e69d9351 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -1,14 +1,12 @@ import logging import select import socket -import time from itertools import zip_longest from typing import Dict, List, Set from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service - -SLEEP_BETWEEN_POLL = 0.5 +from infection_monkey.utils.timer import Timer logger = logging.getLogger(__name__) @@ -71,21 +69,21 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT logger.warning("Failed to connect to port %s, error code is %d", port, err) if len(possible_ports) != 0: - timeout = int(round(timeout)) # clamp to integer, to avoid checking input sockets_to_try = possible_ports.copy() - while (timeout >= 0) and sockets_to_try: + + timer = Timer() + timer.set(timeout) + + while (not timer.is_expired()) and sockets_to_try: sock_objects = [s[1] for s in sockets_to_try] - _, writeable_sockets, _ = select.select([], sock_objects, [], 0) + _, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining) for s in writeable_sockets: try: # actual test connected_ports_sockets.append((s.getpeername()[1], s)) except socket.error: # bad socket, select didn't filter it properly pass sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets] - if sockets_to_try: - time.sleep(SLEEP_BETWEEN_POLL) - timeout -= SLEEP_BETWEEN_POLL logger.debug( "On host %s discovered the following ports %s" @@ -95,7 +93,7 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT banners = [] if len(connected_ports_sockets) != 0: readable_sockets, _, _ = select.select( - [s[1] for s in connected_ports_sockets], [], [], 0 + [s[1] for s in connected_ports_sockets], [], [], timer.time_remaining ) # read first BANNER_READ bytes. We ignore errors because service might not send a # decodable byte string. From eb1a322ff83607fcded8e67df6a46a220985f33e Mon Sep 17 00:00:00 2001 From: Shreya Malviya Date: Thu, 10 Feb 2022 15:37:00 +0530 Subject: [PATCH 04/10] Agent: Rework return value in _check_tcp_ports in tcp_scanner.py --- .../infection_monkey/network/tcp_scanner.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index 9e69d9351..a8ec9751b 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -14,8 +14,10 @@ logger = logging.getLogger(__name__) def scan_tcp_ports(host: str, ports: List[int], timeout: float) -> Dict[int, PortScanData]: ports_scan = {} - open_ports, banners = _check_tcp_ports(host, ports, timeout) - open_ports = set(open_ports) + open_ports_data = _check_tcp_ports(host, ports, timeout) + + open_ports = set(open_ports_data["open_ports"]) + banners = open_ports_data["banners"] for port, banner in zip_longest(ports, banners, fillvalue=None): ports_scan[port] = _build_port_scan_data(port, open_ports, banner) @@ -35,14 +37,18 @@ def _get_closed_port_data(port: int) -> PortScanData: return PortScanData(port, PortStatus.CLOSED, None, None) -def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT): +def _check_tcp_ports( + ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT +) -> Dict[str, List]: """ Checks whether any of the given ports are open on a target IP. :param ip: IP of host to attack :param ports: List of ports to attack. Must not be empty. :param timeout: Amount of time to wait for connection - :return: List of open ports. + :return: Dict with list of open ports and list of banners. """ + open_ports_data = {"open_ports": [], "banners": []} + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))] for s in sockets: s.setblocking(False) @@ -108,11 +114,11 @@ def _check_tcp_ports(ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT s[1].shutdown(socket.SHUT_RDWR) s[1].close() - # TODO: Rework the return of this function. Consider using dictionary - return [port for port, sock in connected_ports_sockets], banners - else: - return [], [] + open_ports_data["open_ports"] = [port for port, _ in connected_ports_sockets] + open_ports_data["banners"] = banners except socket.error as exc: logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc) - return [], [] + + finally: + return open_ports_data From d3dd6ffeb0e3bc98cc46fa21430421cb2666983b Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 10 Feb 2022 09:13:09 -0500 Subject: [PATCH 05/10] Agent: Simplify logic in Timer.time_remaining --- monkey/infection_monkey/utils/timer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monkey/infection_monkey/utils/timer.py b/monkey/infection_monkey/utils/timer.py index 597eb6020..2ed17d551 100644 --- a/monkey/infection_monkey/utils/timer.py +++ b/monkey/infection_monkey/utils/timer.py @@ -36,7 +36,7 @@ class Timer: :rtype: float """ time_remaining = self._timeout_sec - (time.time() - self._start_time) - return time_remaining if time_remaining > 0 else 0 + return max(time_remaining, 0) def reset(self): """ From a53b61175951dadcb3c8f018c01214e67435b46b Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 10 Feb 2022 09:32:14 -0500 Subject: [PATCH 06/10] Agent: Change _check_tcp_ports() to return Mapping[int, str] --- .../infection_monkey/network/tcp_scanner.py | 71 ++++++++++--------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index a8ec9751b..62af733e7 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -1,8 +1,7 @@ import logging import select import socket -from itertools import zip_longest -from typing import Dict, List, Set +from typing import Iterable, Mapping from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service @@ -11,26 +10,28 @@ from infection_monkey.utils.timer import Timer logger = logging.getLogger(__name__) -def scan_tcp_ports(host: str, ports: List[int], timeout: float) -> Dict[int, PortScanData]: - ports_scan = {} +def scan_tcp_ports( + host: str, ports_to_scan: Iterable[int], timeout: float +) -> Mapping[int, PortScanData]: + open_ports = _check_tcp_ports(host, ports_to_scan, timeout) - open_ports_data = _check_tcp_ports(host, ports, timeout) - - open_ports = set(open_ports_data["open_ports"]) - banners = open_ports_data["banners"] - - for port, banner in zip_longest(ports, banners, fillvalue=None): - ports_scan[port] = _build_port_scan_data(port, open_ports, banner) - - return ports_scan + return _build_port_scan_data(ports_to_scan, open_ports) -def _build_port_scan_data(port: int, open_ports: Set[int], banner: str) -> PortScanData: - if port in open_ports: - service = tcp_port_to_service(port) - return PortScanData(port, PortStatus.OPEN, banner, service) - else: - return _get_closed_port_data(port) +def _build_port_scan_data( + ports_to_scan: Iterable[int], open_ports: Mapping[int, str] +) -> Mapping[int, PortScanData]: + port_scan_data = {} + for port in ports_to_scan: + if port in open_ports: + service = tcp_port_to_service(port) + banner = open_ports[port] + + port_scan_data[port] = PortScanData(port, PortStatus.OPEN, banner, service) + else: + port_scan_data[port] = _get_closed_port_data(port) + + return port_scan_data def _get_closed_port_data(port: int) -> PortScanData: @@ -38,26 +39,29 @@ def _get_closed_port_data(port: int) -> PortScanData: def _check_tcp_ports( - ip: str, ports: List[int], timeout: float = DEFAULT_TIMEOUT -) -> Dict[str, List]: + ip: str, ports_to_scan: Iterable[int], timeout: float = DEFAULT_TIMEOUT +) -> Mapping[int, str]: """ Checks whether any of the given ports are open on a target IP. :param ip: IP of host to attack - :param ports: List of ports to attack. Must not be empty. + :param ports_to_scan: An iterable of ports to scan. Must not be empty. :param timeout: Amount of time to wait for connection - :return: Dict with list of open ports and list of banners. + :return: Mapping where the key is an open port and the value is the banner + :rtype: Mapping """ - open_ports_data = {"open_ports": [], "banners": []} - - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports))] + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(len(ports_to_scan))] for s in sockets: s.setblocking(False) possible_ports = [] connected_ports_sockets = [] + open_ports = {} + try: - logger.debug("Connecting to the following ports %s" % ",".join((str(x) for x in ports))) - for sock, port in zip(sockets, ports): + logger.debug( + "Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan)) + ) + for sock, port in zip(sockets, ports_to_scan): err = sock.connect_ex((ip, port)) if err == 0: # immediate connect connected_ports_sockets.append((port, sock)) @@ -96,7 +100,7 @@ def _check_tcp_ports( % (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets])) ) - banners = [] + open_ports = {port: "" for port, _ in connected_ports_sockets} if len(connected_ports_sockets) != 0: readable_sockets, _, _ = select.select( [s[1] for s in connected_ports_sockets], [], [], timer.time_remaining @@ -105,20 +109,17 @@ def _check_tcp_ports( # decodable byte string. for port, sock in connected_ports_sockets: if sock in readable_sockets: - banners.append(sock.recv(BANNER_READ).decode(errors="ignore")) + open_ports[port] = sock.recv(BANNER_READ).decode(errors="ignore") else: - banners.append("") + open_ports[port] = "" # try to cleanup for s in possible_ports: s[1].shutdown(socket.SHUT_RDWR) s[1].close() - open_ports_data["open_ports"] = [port for port, _ in connected_ports_sockets] - open_ports_data["banners"] = banners - except socket.error as exc: logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc) finally: - return open_ports_data + return open_ports From 2ae77ce897c56ed678e119ba5d8dc6ae620cb493 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 10 Feb 2022 10:02:36 -0500 Subject: [PATCH 07/10] Agent: Fix error when shutting down sockets in _check_tcp_ports() An error is raised if shutdown() is called on a socket that has not successfully connected. This commit modifies the cleanup logic so that shutdown() is only called on sockets that are known to be connected and close() is called on all sockets. --- .../infection_monkey/network/tcp_scanner.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index 62af733e7..45de37d30 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -1,7 +1,7 @@ import logging import select import socket -from typing import Iterable, Mapping +from typing import Iterable, Mapping, Tuple from infection_monkey.i_puppet import PortScanData, PortStatus from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service @@ -113,13 +113,28 @@ def _check_tcp_ports( else: open_ports[port] = "" - # try to cleanup - for s in possible_ports: - s[1].shutdown(socket.SHUT_RDWR) - s[1].close() - except socket.error as exc: logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc) - finally: - return open_ports + _clean_up_sockets(possible_ports, connected_ports_sockets) + + return open_ports + + +def _clean_up_sockets( + possible_ports: Iterable[Tuple[int, socket.socket]], + connected_ports_sockets: Iterable[Tuple[int, socket.socket]], +): + # Only call shutdown() on sockets we know to be connected + for port, s in connected_ports_sockets: + try: + s.shutdown(socket.SHUT_RDWR) + except socket.error as exc: + logger.warning(f"Error occurred while shutting down socket on port {port}: {exc}") + + # Call close() for all sockets + for port, s in possible_ports: + try: + s.close() + except socket.error as exc: + logger.warning(f"Error occurred while closing socket on port {port}: {exc}") From 21ede3e3410df50cd04766b5175df466d407805d Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 10 Feb 2022 10:17:29 -0500 Subject: [PATCH 08/10] Agent: Improve readability of _check_tcp_ports() --- .../infection_monkey/network/tcp_scanner.py | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index 45de37d30..330b23d52 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -53,8 +53,8 @@ def _check_tcp_ports( for s in sockets: s.setblocking(False) - possible_ports = [] - connected_ports_sockets = [] + possible_ports = set() + connected_ports = set() open_ports = {} try: @@ -64,19 +64,17 @@ def _check_tcp_ports( for sock, port in zip(sockets, ports_to_scan): err = sock.connect_ex((ip, port)) if err == 0: # immediate connect - connected_ports_sockets.append((port, sock)) - possible_ports.append((port, sock)) - continue - if err == 10035: # WSAEWOULDBLOCK is valid. + connected_ports.add((port, sock)) + possible_ports.add((port, sock)) + elif err == 10035: # WSAEWOULDBLOCK is valid. # https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect # says, "Use the select function to determine the completion of the connection # request by checking to see if the socket is writable," which is being done below. - possible_ports.append((port, sock)) - continue - if err == 115: # EINPROGRESS 115 /* Operation now in progress */ - possible_ports.append((port, sock)) - continue - logger.warning("Failed to connect to port %s, error code is %d", port, err) + possible_ports.add((port, sock)) + elif err == 115: # EINPROGRESS 115 /* Operation now in progress */ + possible_ports.add((port, sock)) + else: + logger.warning("Failed to connect to port %s, error code is %d", port, err) if len(possible_ports) != 0: sockets_to_try = possible_ports.copy() @@ -90,24 +88,25 @@ def _check_tcp_ports( _, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining) for s in writeable_sockets: try: # actual test - connected_ports_sockets.append((s.getpeername()[1], s)) + connected_ports.add((s.getpeername()[1], s)) except socket.error: # bad socket, select didn't filter it properly pass - sockets_to_try = [s for s in sockets_to_try if s not in connected_ports_sockets] + + sockets_to_try = sockets_to_try - connected_ports logger.debug( "On host %s discovered the following ports %s" - % (str(ip), ",".join([str(s[0]) for s in connected_ports_sockets])) + % (str(ip), ",".join([str(s[0]) for s in connected_ports])) ) - open_ports = {port: "" for port, _ in connected_ports_sockets} - if len(connected_ports_sockets) != 0: + open_ports = {port: "" for port, _ in connected_ports} + if len(connected_ports) != 0: readable_sockets, _, _ = select.select( - [s[1] for s in connected_ports_sockets], [], [], timer.time_remaining + [s[1] for s in connected_ports], [], [], timer.time_remaining ) # read first BANNER_READ bytes. We ignore errors because service might not send a # decodable byte string. - for port, sock in connected_ports_sockets: + for port, sock in connected_ports: if sock in readable_sockets: open_ports[port] = sock.recv(BANNER_READ).decode(errors="ignore") else: @@ -116,7 +115,7 @@ def _check_tcp_ports( except socket.error as exc: logger.warning("Exception when checking ports on host %s, Exception: %s", str(ip), exc) - _clean_up_sockets(possible_ports, connected_ports_sockets) + _clean_up_sockets(possible_ports, connected_ports) return open_ports From 36a2b3ff6bf1a851a51704c6857358390d128607 Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Thu, 10 Feb 2022 10:49:21 -0500 Subject: [PATCH 09/10] Agent: Add sleep back into _check_tcp_ports() --- monkey/infection_monkey/network/tcp_scanner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/monkey/infection_monkey/network/tcp_scanner.py b/monkey/infection_monkey/network/tcp_scanner.py index 330b23d52..6fdded293 100644 --- a/monkey/infection_monkey/network/tcp_scanner.py +++ b/monkey/infection_monkey/network/tcp_scanner.py @@ -1,6 +1,7 @@ import logging import select import socket +import time from typing import Iterable, Mapping, Tuple from infection_monkey.i_puppet import PortScanData, PortStatus @@ -9,6 +10,8 @@ from infection_monkey.utils.timer import Timer logger = logging.getLogger(__name__) +POLL_INTERVAL = 0.5 + def scan_tcp_ports( host: str, ports_to_scan: Iterable[int], timeout: float @@ -83,6 +86,10 @@ def _check_tcp_ports( timer.set(timeout) while (not timer.is_expired()) and sockets_to_try: + # The call to select() may return sockets that are writeable but not actually + # connected. Adding this sleep prevents excessive looping. + time.sleep(min(POLL_INTERVAL, timer.time_remaining)) + sock_objects = [s[1] for s in sockets_to_try] _, writeable_sockets, _ = select.select([], sock_objects, [], timer.time_remaining) From 543ff24ac38b94c703ca59ea952931aed890325c Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Thu, 10 Feb 2022 17:56:38 +0100 Subject: [PATCH 10/10] UT: Add tests for tcp scanning --- .../network/test_tcp_scanning.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 monkey/tests/unit_tests/infection_monkey/network/test_tcp_scanning.py diff --git a/monkey/tests/unit_tests/infection_monkey/network/test_tcp_scanning.py b/monkey/tests/unit_tests/infection_monkey/network/test_tcp_scanning.py new file mode 100644 index 000000000..e383e1004 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/network/test_tcp_scanning.py @@ -0,0 +1,54 @@ +import pytest + +from infection_monkey.i_puppet import PortStatus +from infection_monkey.network import scan_tcp_ports + +PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222] + +OPEN_PORTS_DATA = {22: "SSH-banner", 80: "", 2222: "SSH2-banner"} + + +@pytest.fixture +def patch_check_tcp_ports(monkeypatch, open_ports_data): + monkeypatch.setattr( + "infection_monkey.network.tcp_scanner._check_tcp_ports", + lambda *_: open_ports_data, + ) + + +@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) +def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data): + closed_ports = [8080, 143, 445] + + port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) + + assert len(port_scan_data) == 6 + for port in open_ports_data.keys(): + assert port_scan_data[port].port == port + assert port_scan_data[port].status == PortStatus.OPEN + assert port_scan_data[port].banner == open_ports_data.get(port) + + for port in closed_ports: + assert port_scan_data[port].port == port + assert port_scan_data[port].status == PortStatus.CLOSED + assert port_scan_data[port].banner is None + + +@pytest.mark.parametrize("open_ports_data", [{}]) +def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data): + + port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) + + assert len(port_scan_data) == 6 + for port in open_ports_data: + assert port_scan_data[port].port == port + assert port_scan_data[port].status == PortStatus.CLOSED + assert port_scan_data[port].banner is None + + +@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) +def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data): + + port_scan_data = scan_tcp_ports("127.0.0.1", [], 0) + + assert len(port_scan_data) == 0