forked from p15670423/monkey
Agent: Modify find_server to accept list of servers
This commit is contained in:
parent
ac058c7788
commit
804bd4eadb
|
@ -1,20 +1,24 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import platform
|
import platform
|
||||||
|
import socket
|
||||||
from socket import gethostname
|
from socket import gethostname
|
||||||
from typing import Mapping, Optional
|
from typing import Mapping, Optional, Sequence
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from requests.exceptions import ConnectionError
|
from requests.exceptions import ConnectionError
|
||||||
|
|
||||||
import infection_monkey.tunnel as tunnel
|
import infection_monkey.tunnel as tunnel
|
||||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
||||||
|
from common.network.network_utils import address_to_ip_port
|
||||||
from infection_monkey.config import GUID
|
from infection_monkey.config import GUID
|
||||||
from infection_monkey.network.info import get_host_subnets, local_ips
|
from infection_monkey.network.info import get_host_subnets, local_ips
|
||||||
|
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE
|
||||||
from infection_monkey.transport.http import HTTPConnectProxy
|
from infection_monkey.transport.http import HTTPConnectProxy
|
||||||
from infection_monkey.transport.tcp import TcpProxy
|
from infection_monkey.transport.tcp import TcpProxy
|
||||||
from infection_monkey.utils import agent_process
|
from infection_monkey.utils import agent_process
|
||||||
from infection_monkey.utils.environment import is_windows_os
|
from infection_monkey.utils.environment import is_windows_os
|
||||||
|
from infection_monkey.utils.threading import create_daemon_thread
|
||||||
|
|
||||||
requests.packages.urllib3.disable_warnings()
|
requests.packages.urllib3.disable_warnings()
|
||||||
|
|
||||||
|
@ -63,37 +67,66 @@ class ControlClient:
|
||||||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
def find_server(self, default_tunnel=None):
|
def find_server(self, servers: Sequence[str]):
|
||||||
logger.debug(f"Trying to wake up with Monkey Island server: {self.server_address}")
|
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
|
||||||
if default_tunnel:
|
|
||||||
logger.debug("default_tunnel: %s" % (default_tunnel,))
|
|
||||||
|
|
||||||
try:
|
while servers:
|
||||||
debug_message = "Trying to connect to server: %s" % self.server_address
|
server = servers[0]
|
||||||
if self.proxies:
|
|
||||||
debug_message += " through proxies: %s" % self.proxies
|
|
||||||
logger.debug(debug_message)
|
|
||||||
requests.get( # noqa: DUO123
|
|
||||||
f"https://{self.server_address}/api?action=is-up",
|
|
||||||
verify=False,
|
|
||||||
proxies=self.proxies,
|
|
||||||
timeout=MEDIUM_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except ConnectionError as exc:
|
|
||||||
logger.warning("Error connecting to control server %s: %s", self.server_address, exc)
|
|
||||||
|
|
||||||
if self.proxies:
|
try:
|
||||||
return False
|
debug_message = f"Trying to connect to server: {server}"
|
||||||
else:
|
logger.debug(debug_message)
|
||||||
logger.info("Starting tunnel lookup...")
|
requests.get( # noqa: DUO123
|
||||||
proxy_find = tunnel.find_tunnel(default=default_tunnel)
|
f"https://{server}/api?action=is-up",
|
||||||
if proxy_find:
|
verify=False,
|
||||||
self.set_proxies(proxy_find)
|
timeout=MEDIUM_REQUEST_TIMEOUT,
|
||||||
return self.find_server()
|
)
|
||||||
else:
|
|
||||||
logger.info("No tunnel found")
|
# We remove the server that has been succesfull
|
||||||
return False
|
servers.remove(server)
|
||||||
|
|
||||||
|
# TODO: Check how we are going to set the server address that the ControlCLient
|
||||||
|
# is going to use
|
||||||
|
# self.server_address = server
|
||||||
|
|
||||||
|
# If we have any other server we send them RELAY_CONTROL_MESSAGE
|
||||||
|
if servers:
|
||||||
|
for ss in servers:
|
||||||
|
t = create_daemon_thread(
|
||||||
|
target=ControlClient._send_relay_control_message,
|
||||||
|
name="SendControlRelayMessageThread",
|
||||||
|
args=(ss,),
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ConnectionError as err:
|
||||||
|
logger.error(f"Unable to connect to server/relay {server}: {err}")
|
||||||
|
servers.remove(server)
|
||||||
|
except TimeoutError as err:
|
||||||
|
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
|
||||||
|
servers.remove(server)
|
||||||
|
except Exception as err:
|
||||||
|
logger.error(
|
||||||
|
f"Exception encountered when trying to connect to server/relay {server}: {err}"
|
||||||
|
)
|
||||||
|
servers.remove(server)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _send_relay_control_message(server: str):
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
|
||||||
|
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)
|
||||||
|
|
||||||
|
try:
|
||||||
|
address, port = address_to_ip_port(server)
|
||||||
|
d_socket.connect((address, int(port)))
|
||||||
|
d_socket.send(RELAY_CONTROL_MESSAGE)
|
||||||
|
logger.info(f"Control message was sent to the server/relay {server}")
|
||||||
|
except OSError as err:
|
||||||
|
logger.error(f"Error connecting to socket {server}: {err}")
|
||||||
|
|
||||||
def set_proxies(self, proxy_find):
|
def set_proxies(self, proxy_find):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,7 +1,31 @@
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
from monkey.infection_monkey.control import ControlClient
|
from monkey.infection_monkey.control import ControlClient
|
||||||
|
|
||||||
|
SERVER_1 = "1.1.1.1:12312"
|
||||||
|
SERVER_2 = "2.2.2.2:4321"
|
||||||
|
SERVER_3 = "3.3.3.3:3142"
|
||||||
|
SERVER_4 = "4.4.4.4:5000"
|
||||||
|
|
||||||
|
|
||||||
|
class MockConnectionError:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
raise requests.exceptions.ConnectionError
|
||||||
|
|
||||||
|
|
||||||
|
class RequestsGetArgument:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if SERVER_1 in args[0]:
|
||||||
|
MockConnectionError()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def servers():
|
||||||
|
return [SERVER_1, SERVER_2, SERVER_3, SERVER_4]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"is_windows_os,expected_proxy_string",
|
"is_windows_os,expected_proxy_string",
|
||||||
|
@ -14,3 +38,33 @@ def test_control_set_proxies(monkeypatch, is_windows_os, expected_proxy_string):
|
||||||
control_client.set_proxies(("8.8.8.8", "45455"))
|
control_client.set_proxies(("8.8.8.8", "45455"))
|
||||||
|
|
||||||
assert control_client.proxies["https"] == expected_proxy_string
|
assert control_client.proxies["https"] == expected_proxy_string
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_find_server_any_exception(monkeypatch, servers):
|
||||||
|
monkeypatch.setattr("infection_monkey.control.requests.get", MockConnectionError)
|
||||||
|
|
||||||
|
cc = ControlClient(servers)
|
||||||
|
|
||||||
|
return_value = cc.find_server(servers)
|
||||||
|
|
||||||
|
assert return_value is False
|
||||||
|
assert servers == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_find_server_socket(monkeypatch, servers):
|
||||||
|
mock_connect = MagicMock()
|
||||||
|
mock_send = MagicMock()
|
||||||
|
monkeypatch.setattr("infection_monkey.control.requests.get", RequestsGetArgument)
|
||||||
|
monkeypatch.setattr("infection_monkey.control.socket.socket.connect", mock_connect)
|
||||||
|
monkeypatch.setattr("infection_monkey.control.socket.socket.send", mock_send)
|
||||||
|
|
||||||
|
cc = ControlClient(servers)
|
||||||
|
|
||||||
|
return_value = cc.find_server(servers)
|
||||||
|
|
||||||
|
assert len(servers) == 2
|
||||||
|
assert return_value is True
|
||||||
|
assert mock_connect.call_count == 2
|
||||||
|
assert mock_send.call_count == 2
|
||||||
|
# TODO: be sure that connect is called with SERVER_3 and SERVER_4
|
||||||
|
# assert mock_connect.call_args == SERVER_3
|
||||||
|
|
Loading…
Reference in New Issue