From bb2b4aaf6c2b58330e1ef44254e4e99e742ecc6d Mon Sep 17 00:00:00 2001 From: Ilija Lazoroski Date: Wed, 7 Sep 2022 16:45:47 +0200 Subject: [PATCH] Agent: Separate responsibilites in network.relay.utils.find_server --- .../infection_monkey/network/relay/utils.py | 19 +++++----- .../network/relay/test_utils.py | 35 +++++++++++++++++++ 2 files changed, 45 insertions(+), 9 deletions(-) create mode 100644 monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 66f0b1d1e..8d3a76277 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -1,6 +1,6 @@ import logging import socket -from typing import Optional, Sequence +from typing import Iterable, Optional import requests @@ -12,13 +12,12 @@ from infection_monkey.utils.threading import create_daemon_thread logger = logging.getLogger(__name__) -def find_server(self, servers: Sequence[str]) -> Optional[str]: +def find_server(servers: Iterable[str]) -> Optional[str]: server_found = None logger.debug(f"Trying to wake up with servers: {', '.join(servers)}") - server_iterator = (s for s in servers) - for server in server_iterator: + for server in servers: try: debug_message = f"Trying to connect to server: {server}" logger.debug(debug_message) @@ -40,18 +39,20 @@ def find_server(self, servers: Sequence[str]) -> Optional[str]: f"Exception encountered when trying to connect to server/relay {server}: {err}" ) - for server in server_iterator: + return server_found + + +def send_relay_control_message(servers: Iterable[str]): + for server in servers: t = create_daemon_thread( - target=_send_relay_control_message, + target=_open_socket_to_server, name="SendControlRelayMessageThread", args=(server,), ) t.start() - return server_found - -def _send_relay_control_message(server: str): +def _open_socket_to_server(server: str): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket: d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT) diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py new file mode 100644 index 000000000..15c4f46da --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -0,0 +1,35 @@ +import pytest +import requests + +from infection_monkey.network.relay.utils import find_server + +SERVER_1 = "1.1.1.1:12312" +SERVER_2 = "2.2.2.2:4321" +SERVER_3 = "3.3.3.3:3142" +SERVER_4 = "4.4.4.4:5000" + + +class MockConnectionError: + def __init__(self, *args, **kwargs): + raise requests.exceptions.ConnectionError + + +class MockRequestsGetResponsePerServerArgument: + def __init__(self, *args, **kwargs): + if SERVER_1 in args[0]: + MockConnectionError() + + +@pytest.fixture +def servers(): + return [SERVER_1, SERVER_2, SERVER_3, SERVER_4] + + +@pytest.mark.parametrize( + "mock_requests_get, expected", + [(MockConnectionError, None), (MockRequestsGetResponsePerServerArgument, SERVER_2)], +) +def test_find_server__no_available_relays(monkeypatch, servers, mock_requests_get, expected): + monkeypatch.setattr("infection_monkey.control.requests.get", mock_requests_get) + + assert find_server(servers) is expected