forked from p15670423/monkey
Agent: Separate responsibilites in network.relay.utils.find_server
This commit is contained in:
parent
c6c6cf1e79
commit
bb2b4aaf6c
|
@ -1,6 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from typing import Optional, Sequence
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
@ -12,13 +12,12 @@ from infection_monkey.utils.threading import create_daemon_thread
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def find_server(self, servers: Sequence[str]) -> Optional[str]:
|
def find_server(servers: Iterable[str]) -> Optional[str]:
|
||||||
server_found = None
|
server_found = None
|
||||||
|
|
||||||
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
|
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
|
||||||
|
|
||||||
server_iterator = (s for s in servers)
|
for server in servers:
|
||||||
for server in server_iterator:
|
|
||||||
try:
|
try:
|
||||||
debug_message = f"Trying to connect to server: {server}"
|
debug_message = f"Trying to connect to server: {server}"
|
||||||
logger.debug(debug_message)
|
logger.debug(debug_message)
|
||||||
|
@ -40,18 +39,20 @@ def find_server(self, servers: Sequence[str]) -> Optional[str]:
|
||||||
f"Exception encountered when trying to connect to server/relay {server}: {err}"
|
f"Exception encountered when trying to connect to server/relay {server}: {err}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for server in server_iterator:
|
return server_found
|
||||||
|
|
||||||
|
|
||||||
|
def send_relay_control_message(servers: Iterable[str]):
|
||||||
|
for server in servers:
|
||||||
t = create_daemon_thread(
|
t = create_daemon_thread(
|
||||||
target=_send_relay_control_message,
|
target=_open_socket_to_server,
|
||||||
name="SendControlRelayMessageThread",
|
name="SendControlRelayMessageThread",
|
||||||
args=(server,),
|
args=(server,),
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
return server_found
|
|
||||||
|
|
||||||
|
def _open_socket_to_server(server: str):
|
||||||
def _send_relay_control_message(server: str):
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
|
||||||
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)
|
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from infection_monkey.network.relay.utils import find_server
|
||||||
|
|
||||||
|
SERVER_1 = "1.1.1.1:12312"
|
||||||
|
SERVER_2 = "2.2.2.2:4321"
|
||||||
|
SERVER_3 = "3.3.3.3:3142"
|
||||||
|
SERVER_4 = "4.4.4.4:5000"
|
||||||
|
|
||||||
|
|
||||||
|
class MockConnectionError:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
raise requests.exceptions.ConnectionError
|
||||||
|
|
||||||
|
|
||||||
|
class MockRequestsGetResponsePerServerArgument:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if SERVER_1 in args[0]:
|
||||||
|
MockConnectionError()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def servers():
|
||||||
|
return [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mock_requests_get, expected",
|
||||||
|
[(MockConnectionError, None), (MockRequestsGetResponsePerServerArgument, SERVER_2)],
|
||||||
|
)
|
||||||
|
def test_find_server__no_available_relays(monkeypatch, servers, mock_requests_get, expected):
|
||||||
|
monkeypatch.setattr("infection_monkey.control.requests.get", mock_requests_get)
|
||||||
|
|
||||||
|
assert find_server(servers) is expected
|
Loading…
Reference in New Issue