Merge pull request #2377 from guardicore/2267-publish-tcp-scan-event

Publish TCPScanEvent
This commit is contained in:
Mike Salvatore 2022-09-30 11:59:51 -04:00 committed by GitHub
commit 9a6300481c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 13 deletions

View File

@ -1,15 +1,18 @@
import logging import logging
import select import select
import socket import socket
import time from ipaddress import IPv4Address
from pprint import pformat from pprint import pformat
from time import sleep, time
from typing import Collection, Dict, Iterable, Mapping, Tuple from typing import Collection, Dict, Iterable, Mapping, Tuple
from common.agent_events import TCPScanEvent
from common.event_queue import IAgentEventQueue from common.event_queue import IAgentEventQueue
from common.types import PortStatus from common.types import PortStatus
from common.utils import Timer from common.utils import Timer
from infection_monkey.i_puppet import PortScanData from infection_monkey.i_puppet import PortScanData
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
from infection_monkey.utils.ids import get_agent_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -21,21 +24,41 @@ def scan_tcp_ports(
host: str, ports_to_scan: Collection[int], timeout: float, agent_event_queue: IAgentEventQueue host: str, ports_to_scan: Collection[int], timeout: float, agent_event_queue: IAgentEventQueue
) -> Dict[int, PortScanData]: ) -> Dict[int, PortScanData]:
try: try:
return _scan_tcp_ports(host, ports_to_scan, timeout) return _scan_tcp_ports(host, ports_to_scan, timeout, agent_event_queue)
except Exception: except Exception:
logger.exception("Unhandled exception occurred while trying to scan tcp ports") logger.exception("Unhandled exception occurred while trying to scan tcp ports")
return EMPTY_PORT_SCAN return EMPTY_PORT_SCAN
def _scan_tcp_ports(host: str, ports_to_scan: Collection[int], timeout: float): def _scan_tcp_ports(
open_ports = _check_tcp_ports(host, ports_to_scan, timeout) host: str, ports_to_scan: Collection[int], timeout: float, agent_event_queue: IAgentEventQueue
) -> Dict[int, PortScanData]:
event_timestamp, open_ports = _check_tcp_ports(host, ports_to_scan, timeout)
return _build_port_scan_data(ports_to_scan, open_ports) port_scan_data = _build_port_scan_data(ports_to_scan, open_ports)
tcp_scan_event = _generate_tcp_scan_event(host, port_scan_data, event_timestamp)
agent_event_queue.publish(tcp_scan_event)
return port_scan_data
def _generate_tcp_scan_event(
host: str, port_scan_data: Dict[int, PortScanData], event_timestamp: float
):
port_statuses = {port: psd.status for port, psd in port_scan_data.items()}
return TCPScanEvent(
source=get_agent_id(),
target=IPv4Address(host),
timestamp=event_timestamp,
ports=port_statuses,
)
def _build_port_scan_data( def _build_port_scan_data(
ports_to_scan: Iterable[int], open_ports: Mapping[int, str] ports_to_scan: Iterable[int], open_ports: Mapping[int, str]
) -> Mapping[int, PortScanData]: ) -> Dict[int, PortScanData]:
port_scan_data = {} port_scan_data = {}
for port in ports_to_scan: for port in ports_to_scan:
if port in open_ports: if port in open_ports:
@ -55,7 +78,7 @@ def _get_closed_port_data(port: int) -> PortScanData:
def _check_tcp_ports( def _check_tcp_ports(
ip: str, ports_to_scan: Collection[int], timeout: float = DEFAULT_TIMEOUT ip: str, ports_to_scan: Collection[int], timeout: float = DEFAULT_TIMEOUT
) -> Mapping[int, str]: ) -> Tuple[float, Dict[int, str]]:
""" """
Checks whether any of the given ports are open on a target IP. Checks whether any of the given ports are open on a target IP.
:param ip: IP of host to attack :param ip: IP of host to attack
@ -72,6 +95,7 @@ def _check_tcp_ports(
connected_ports = set() connected_ports = set()
open_ports = {} open_ports = {}
event_timestamp = time()
try: try:
logger.debug( logger.debug(
"Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan)) "Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan))
@ -100,7 +124,7 @@ def _check_tcp_ports(
while (not timer.is_expired()) and sockets_to_try: while (not timer.is_expired()) and sockets_to_try:
# The call to select() may return sockets that are writeable but not actually # The call to select() may return sockets that are writeable but not actually
# connected. Adding this sleep prevents excessive looping. # connected. Adding this sleep prevents excessive looping.
time.sleep(min(POLL_INTERVAL, timer.time_remaining)) sleep(min(POLL_INTERVAL, timer.time_remaining))
sock_objects = [s[1] for s in sockets_to_try] sock_objects = [s[1] for s in sockets_to_try]
@ -136,7 +160,7 @@ def _check_tcp_ports(
_clean_up_sockets(possible_ports, connected_ports) _clean_up_sockets(possible_ports, connected_ports)
return open_ports return event_timestamp, open_ports
def _clean_up_sockets( def _clean_up_sockets(

View File

@ -2,20 +2,46 @@ from unittest.mock import MagicMock
import pytest import pytest
from common.agent_events import TCPScanEvent
from common.types import PortStatus from common.types import PortStatus
from infection_monkey.i_puppet import PortScanData
from infection_monkey.network_scanning import scan_tcp_ports from infection_monkey.network_scanning import scan_tcp_ports
from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN
from infection_monkey.utils.ids import get_agent_id
PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222] PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222]
OPEN_PORTS_DATA = {22: "SSH-banner", 80: "", 2222: "SSH2-banner"} OPEN_PORTS_DATA = {22: "SSH-banner", 80: "", 2222: "SSH2-banner"}
TIMESTAMP = 123.321
HOST_IP = "127.0.0.1"
@pytest.fixture(autouse=True)
def patch_timestamp(monkeypatch):
monkeypatch.setattr(
"infection_monkey.network_scanning.tcp_scanner.time",
lambda: TIMESTAMP,
)
@pytest.fixture @pytest.fixture
def patch_check_tcp_ports(monkeypatch, open_ports_data): def patch_check_tcp_ports(monkeypatch, open_ports_data):
monkeypatch.setattr( monkeypatch.setattr(
"infection_monkey.network_scanning.tcp_scanner._check_tcp_ports", "infection_monkey.network_scanning.tcp_scanner._check_tcp_ports",
lambda *_: open_ports_data, lambda *_: (TIMESTAMP, open_ports_data),
)
def _get_tcp_scan_event(port_scan_data: PortScanData):
port_statuses = {port: psd.status for port, psd in port_scan_data.items()}
return TCPScanEvent(
source=get_agent_id(),
target=HOST_IP,
timestamp=TIMESTAMP,
ports=port_statuses,
) )
@ -25,7 +51,7 @@ def test_tcp_successful(
): ):
closed_ports = [8080, 143, 445] closed_ports = [8080, 143, 445]
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue)
assert len(port_scan_data) == 6 assert len(port_scan_data) == 6
for port in open_ports_data.keys(): for port in open_ports_data.keys():
@ -38,12 +64,17 @@ def test_tcp_successful(
assert port_scan_data[port].status == PortStatus.CLOSED assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None assert port_scan_data[port].banner is None
event = _get_tcp_scan_event(port_scan_data)
assert mock_agent_event_queue.publish.call_count == 1
mock_agent_event_queue.publish.assert_called_with(event)
@pytest.mark.parametrize("open_ports_data", [{}]) @pytest.mark.parametrize("open_ports_data", [{}])
def test_tcp_empty_response( def test_tcp_empty_response(
monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
): ):
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue) port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue)
assert len(port_scan_data) == 6 assert len(port_scan_data) == 6
for port in open_ports_data: for port in open_ports_data:
@ -51,15 +82,25 @@ def test_tcp_empty_response(
assert port_scan_data[port].status == PortStatus.CLOSED assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None assert port_scan_data[port].banner is None
event = _get_tcp_scan_event(port_scan_data)
assert mock_agent_event_queue.publish.call_count == 1
mock_agent_event_queue.publish.assert_called_with(event)
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_no_ports_to_scan( def test_tcp_no_ports_to_scan(
monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
): ):
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0, mock_agent_event_queue) port_scan_data = scan_tcp_ports(HOST_IP, [], 0, mock_agent_event_queue)
assert len(port_scan_data) == 0 assert len(port_scan_data) == 0
event = _get_tcp_scan_event(port_scan_data)
assert mock_agent_event_queue.publish.call_count == 1
mock_agent_event_queue.publish.assert_called_with(event)
def test_exception_handling(monkeypatch, mock_agent_event_queue): def test_exception_handling(monkeypatch, mock_agent_event_queue):
monkeypatch.setattr( monkeypatch.setattr(
@ -67,3 +108,4 @@ def test_exception_handling(monkeypatch, mock_agent_event_queue):
MagicMock(side_effect=Exception), MagicMock(side_effect=Exception),
) )
assert scan_tcp_ports("abc", [123], 123, mock_agent_event_queue) == EMPTY_PORT_SCAN assert scan_tcp_ports("abc", [123], 123, mock_agent_event_queue) == EMPTY_PORT_SCAN
assert mock_agent_event_queue.publish.call_count == 0