Merge pull request #2375 from guardicore/2267-tcp-scanner-accept-iagenteventqueue

2267 tcp scanner accept iagenteventqueue
This commit is contained in:
Mike Salvatore 2022-09-29 15:41:48 -04:00 committed by GitHub
commit 31c97faf98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 27 additions and 16 deletions

View File

@ -5,6 +5,7 @@ import time
from pprint import pformat from pprint import pformat
from typing import Collection, Dict, Iterable, Mapping, Tuple from typing import Collection, Dict, Iterable, Mapping, Tuple
from common.event_queue import IAgentEventQueue
from common.types import PortStatus from common.types import PortStatus
from common.utils import Timer from common.utils import Timer
from infection_monkey.i_puppet import PortScanData from infection_monkey.i_puppet import PortScanData
@ -17,7 +18,7 @@ EMPTY_PORT_SCAN = {-1: PortScanData(-1, PortStatus.CLOSED, None, None)}
def scan_tcp_ports( def scan_tcp_ports(
host: str, ports_to_scan: Collection[int], timeout: float host: str, ports_to_scan: Collection[int], timeout: float, agent_event_queue: IAgentEventQueue
) -> Dict[int, PortScanData]: ) -> Dict[int, PortScanData]:
try: try:
return _scan_tcp_ports(host, ports_to_scan, timeout) return _scan_tcp_ports(host, ports_to_scan, timeout)

View File

@ -48,7 +48,7 @@ class Puppet(IPuppet):
def scan_tcp_ports( def scan_tcp_ports(
self, host: str, ports: Sequence[int], timeout: float = CONNECTION_TIMEOUT self, host: str, ports: Sequence[int], timeout: float = CONNECTION_TIMEOUT
) -> Dict[int, PortScanData]: ) -> Dict[int, PortScanData]:
return network_scanning.scan_tcp_ports(host, ports, timeout) return network_scanning.scan_tcp_ports(host, ports, timeout, self._agent_event_queue)
def fingerprint( def fingerprint(
self, self,

View File

@ -0,0 +1,10 @@
from unittest.mock import MagicMock
import pytest
from common.event_queue import IAgentEventQueue
@pytest.fixture
def mock_agent_event_queue() -> IAgentEventQueue:
return MagicMock(spec=IAgentEventQueue)

View File

@ -7,7 +7,6 @@ import pytest
import infection_monkey.network_scanning.ping_scanner # noqa: F401 import infection_monkey.network_scanning.ping_scanner # noqa: F401
from common import OperatingSystem from common import OperatingSystem
from common.agent_events import PingScanEvent from common.agent_events import PingScanEvent
from common.event_queue import IAgentEventQueue
from common.types import PingScanData from common.types import PingScanData
from infection_monkey.network_scanning import ping from infection_monkey.network_scanning import ping
from infection_monkey.network_scanning.ping_scanner import EMPTY_PING_SCAN from infection_monkey.network_scanning.ping_scanner import EMPTY_PING_SCAN
@ -99,11 +98,6 @@ def set_os_windows(monkeypatch):
monkeypatch.setattr("sys.platform", "win32") monkeypatch.setattr("sys.platform", "win32")
@pytest.fixture
def mock_agent_event_queue():
return MagicMock(spec=IAgentEventQueue)
HOST_IP = "192.168.1.1" HOST_IP = "192.168.1.1"
TIMEOUT = 1.0 TIMEOUT = 1.0

View File

@ -20,10 +20,12 @@ def patch_check_tcp_ports(monkeypatch, open_ports_data):
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data): def test_tcp_successful(
monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
):
closed_ports = [8080, 143, 445] closed_ports = [8080, 143, 445]
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue)
assert len(port_scan_data) == 6 assert len(port_scan_data) == 6
for port in open_ports_data.keys(): for port in open_ports_data.keys():
@ -38,8 +40,10 @@ def test_tcp_successful(monkeypatch, patch_check_tcp_ports, open_ports_data):
@pytest.mark.parametrize("open_ports_data", [{}]) @pytest.mark.parametrize("open_ports_data", [{}])
def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data): def test_tcp_empty_response(
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0) monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
):
port_scan_data = scan_tcp_ports("127.0.0.1", PORTS_TO_SCAN, 0, mock_agent_event_queue)
assert len(port_scan_data) == 6 assert len(port_scan_data) == 6
for port in open_ports_data: for port in open_ports_data:
@ -49,15 +53,17 @@ def test_tcp_empty_response(monkeypatch, patch_check_tcp_ports, open_ports_data)
@pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA]) @pytest.mark.parametrize("open_ports_data", [OPEN_PORTS_DATA])
def test_tcp_no_ports_to_scan(monkeypatch, patch_check_tcp_ports, open_ports_data): def test_tcp_no_ports_to_scan(
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0) monkeypatch, patch_check_tcp_ports, open_ports_data, mock_agent_event_queue
):
port_scan_data = scan_tcp_ports("127.0.0.1", [], 0, mock_agent_event_queue)
assert len(port_scan_data) == 0 assert len(port_scan_data) == 0
def test_exception_handling(monkeypatch): def test_exception_handling(monkeypatch, mock_agent_event_queue):
monkeypatch.setattr( monkeypatch.setattr(
"infection_monkey.network_scanning.tcp_scanner._scan_tcp_ports", "infection_monkey.network_scanning.tcp_scanner._scan_tcp_ports",
MagicMock(side_effect=Exception), MagicMock(side_effect=Exception),
) )
assert scan_tcp_ports("abc", [123], 123) == EMPTY_PORT_SCAN assert scan_tcp_ports("abc", [123], 123, mock_agent_event_queue) == EMPTY_PORT_SCAN