Merge pull request #2377 from guardicore/2267-publish-tcp-scan-event

Publish TCPScanEvent
This commit is contained in:
Mike Salvatore 2022-09-30 11:59:51 -04:00 committed by GitHub
commit 9a6300481c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 13 deletions

View File

@ -1,15 +1,18 @@
import logging
import select
import socket
import time
from ipaddress import IPv4Address
from pprint import pformat
from time import sleep, time
from typing import Collection, Dict, Iterable, Mapping, Tuple
from common.agent_events import TCPScanEvent
from common.event_queue import IAgentEventQueue
from common.types import PortStatus
from common.utils import Timer
from infection_monkey.i_puppet import PortScanData
from infection_monkey.network.tools import BANNER_READ, DEFAULT_TIMEOUT, tcp_port_to_service
from infection_monkey.utils.ids import get_agent_id
logger = logging.getLogger(__name__)
@ -21,21 +24,41 @@ def scan_tcp_ports(
host: str, ports_to_scan: Collection[int], timeout: float, agent_event_queue: IAgentEventQueue
) -> Dict[int, PortScanData]:
try:
return _scan_tcp_ports(host, ports_to_scan, timeout)
return _scan_tcp_ports(host, ports_to_scan, timeout, agent_event_queue)
except Exception:
logger.exception("Unhandled exception occurred while trying to scan tcp ports")
return EMPTY_PORT_SCAN
def _scan_tcp_ports(host: str, ports_to_scan: Collection[int], timeout: float):
open_ports = _check_tcp_ports(host, ports_to_scan, timeout)
def _scan_tcp_ports(
host: str, ports_to_scan: Collection[int], timeout: float, agent_event_queue: IAgentEventQueue
) -> Dict[int, PortScanData]:
event_timestamp, open_ports = _check_tcp_ports(host, ports_to_scan, timeout)
return _build_port_scan_data(ports_to_scan, open_ports)
port_scan_data = _build_port_scan_data(ports_to_scan, open_ports)
tcp_scan_event = _generate_tcp_scan_event(host, port_scan_data, event_timestamp)
agent_event_queue.publish(tcp_scan_event)
return port_scan_data
def _generate_tcp_scan_event(
host: str, port_scan_data: Dict[int, PortScanData], event_timestamp: float
):
port_statuses = {port: psd.status for port, psd in port_scan_data.items()}
return TCPScanEvent(
source=get_agent_id(),
target=IPv4Address(host),
timestamp=event_timestamp,
ports=port_statuses,
)
def _build_port_scan_data(
ports_to_scan: Iterable[int], open_ports: Mapping[int, str]
) -> Mapping[int, PortScanData]:
) -> Dict[int, PortScanData]:
port_scan_data = {}
for port in ports_to_scan:
if port in open_ports:
@ -55,7 +78,7 @@ def _get_closed_port_data(port: int) -> PortScanData:
def _check_tcp_ports(
ip: str, ports_to_scan: Collection[int], timeout: float = DEFAULT_TIMEOUT
) -> Mapping[int, str]:
) -> Tuple[float, Dict[int, str]]:
"""
Checks whether any of the given ports are open on a target IP.
:param ip: IP of host to attack
@ -72,6 +95,7 @@ def _check_tcp_ports(
connected_ports = set()
open_ports = {}
event_timestamp = time()
try:
logger.debug(
"Connecting to the following ports %s" % ",".join((str(x) for x in ports_to_scan))
@ -100,7 +124,7 @@ def _check_tcp_ports(
while (not timer.is_expired()) and sockets_to_try:
# The call to select() may return sockets that are writeable but not actually
# connected. Adding this sleep prevents excessive looping.
time.sleep(min(POLL_INTERVAL, timer.time_remaining))
sleep(min(POLL_INTERVAL, timer.time_remaining))
sock_objects = [s[1] for s in sockets_to_try]
@ -136,7 +160,7 @@ def _check_tcp_ports(
_clean_up_sockets(possible_ports, connected_ports)
return open_ports
return event_timestamp, open_ports
def _clean_up_sockets(

View File

@ -2,20 +2,46 @@ from unittest.mock import MagicMock
import pytest
from common.agent_events import TCPScanEvent
from common.types import PortStatus
from infection_monkey.i_puppet import PortScanData
from infection_monkey.network_scanning import scan_tcp_ports
from infection_monkey.network_scanning.tcp_scanner import EMPTY_PORT_SCAN
from infection_monkey.utils.ids import get_agent_id
PORTS_TO_SCAN = [22, 80, 8080, 143, 445, 2222]
OPEN_PORTS_DATA = {22: "SSH-banner", 80: "", 2222: "SSH2-banner"}
TIMESTAMP = 123.321
HOST_IP = "127.0.0.1"
@pytest.fixture(autouse=True)
def patch_timestamp(monkeypatch):
monkeypatch.setattr(
"infection_monkey.network_scanning.tcp_scanner.time",
lambda: TIMESTAMP,
)
@pytest.fixture
def patch_check_tcp_ports(monkeypatch, open_ports_data):
monkeypatch.setattr(
"infection_monkey.network_scanning.tcp_scanner._check_tcp_ports",
lambda *_: open_ports_data,
lambda *_: (TIMESTAMP, open_ports_data),
)
def _get_tcp_scan_event(port_scan_data: PortScanData):
port_statuses = {port: psd.status for port, psd in port_scan_data.items()}
return TCPScanEvent(
source=get_agent_id(),
target=HOST_IP,
timestamp=TIMESTAMP,
ports=port_statuses,
)
@ -25,7 +51,7 @@ def test_tcp_successful(
):
closed_ports = [8080, 143, 445]
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue)
port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue)
assert len(port_scan_data) == 6
for port in open_ports_data.keys():
@ -38,12 +64,17 @@ def test_tcp_successful(
assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None
event = _get_tcp_scan_event(port_scan_data)
assert mock_agent_event_queue.publish.call_count == 1
mock_agent_event_queue.publish.assert_called_with(event)
@pytest.mark.parametrize("open_ports_data", [{}])
def test_tcp_empty_response(
monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
):
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue)
port_scan_data = scan_tcp_ports(HOST_IP, PORTS_TO_SCAN, 0, mock_agent_event_queue)
assert len(port_scan_data) == 6
for port in open_ports_data:
@ -51,15 +82,25 @@ def test_tcp_empty_response(
assert port_scan_data[port].status == PortStatus.CLOSED
assert port_scan_data[port].banner is None
event = _get_tcp_scan_event(port_scan_data)
assert mock_agent_event_queue.publish.call_count == 1
mock_agent_event_queue.publish.assert_called_with(event)
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_no_ports_to_scan(
monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
):
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0, mock_agent_event_queue)
port_scan_data = scan_tcp_ports(HOST_IP, [], 0, mock_agent_event_queue)
assert len(port_scan_data) == 0
event = _get_tcp_scan_event(port_scan_data)
assert mock_agent_event_queue.publish.call_count == 1
mock_agent_event_queue.publish.assert_called_with(event)
def test_exception_handling(monkeypatch, mock_agent_event_queue):
monkeypatch.setattr(
@ -67,3 +108,4 @@ def test_exception_handling(monkeypatch, mock_agent_event_queue):
MagicMock(side_effect=Exception),
)
assert scan_tcp_ports("abc", [123], 123, mock_agent_event_queue) == EMPTY_PORT_SCAN
assert mock_agent_event_queue.publish.call_count == 0