Agent: Separate responsibilites in network.relay.utils.find_server

This commit is contained in:
Ilija Lazoroski 2022-09-07 16:45:47 +02:00 committed by Mike Salvatore
parent c6c6cf1e79
commit bb2b4aaf6c
2 changed files with 45 additions and 9 deletions

View File

@ -1,6 +1,6 @@
import logging import logging
import socket import socket
from typing import Optional, Sequence from typing import Iterable, Optional
import requests import requests
@ -12,13 +12,12 @@ from infection_monkey.utils.threading import create_daemon_thread
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def find_server(self, servers: Sequence[str]) -> Optional[str]: def find_server(servers: Iterable[str]) -> Optional[str]:
server_found = None server_found = None
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}") logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
server_iterator = (s for s in servers) for server in servers:
for server in server_iterator:
try: try:
debug_message = f"Trying to connect to server: {server}" debug_message = f"Trying to connect to server: {server}"
logger.debug(debug_message) logger.debug(debug_message)
@ -40,18 +39,20 @@ def find_server(self, servers: Sequence[str]) -> Optional[str]:
f"Exception encountered when trying to connect to server/relay {server}: {err}" f"Exception encountered when trying to connect to server/relay {server}: {err}"
) )
for server in server_iterator: return server_found
def send_relay_control_message(servers: Iterable[str]):
for server in servers:
t = create_daemon_thread( t = create_daemon_thread(
target=_send_relay_control_message, target=_open_socket_to_server,
name="SendControlRelayMessageThread", name="SendControlRelayMessageThread",
args=(server,), args=(server,),
) )
t.start() t.start()
return server_found
def _open_socket_to_server(server: str):
def _send_relay_control_message(server: str):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT) d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)

View File

@ -0,0 +1,35 @@
import pytest
import requests
from infection_monkey.network.relay.utils import find_server
SERVER_1 = "1.1.1.1:12312"
SERVER_2 = "2.2.2.2:4321"
SERVER_3 = "3.3.3.3:3142"
SERVER_4 = "4.4.4.4:5000"
class MockConnectionError:
def __init__(self, *args, **kwargs):
raise requests.exceptions.ConnectionError
class MockRequestsGetResponsePerServerArgument:
def __init__(self, *args, **kwargs):
if SERVER_1 in args[0]:
MockConnectionError()
@pytest.fixture
def servers():
return [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
@pytest.mark.parametrize(
"mock_requests_get, expected",
[(MockConnectionError, None), (MockRequestsGetResponsePerServerArgument, SERVER_2)],
)
def test_find_server__no_available_relays(monkeypatch, servers, mock_requests_get, expected):
monkeypatch.setattr("infection_monkey.control.requests.get", mock_requests_get)
assert find_server(servers) is expected