Agent: Move ControlClient.find_server in network/relay/utls.py

This commit is contained in:
Ilija Lazoroski 2022-09-07 15:06:21 +02:00 committed by Mike Salvatore
parent 178b296f75
commit 789d6b8441
3 changed files with 61 additions and 57 deletions

View File

@ -1,24 +1,19 @@
import json
import logging
import platform
import socket
from socket import gethostname
from typing import Mapping, Optional, Sequence
from typing import Mapping, Optional
import requests
from requests.exceptions import ConnectionError
import infection_monkey.tunnel as tunnel
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port
from infection_monkey.config import GUID
from infection_monkey.network.info import get_host_subnets, local_ips
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE
from infection_monkey.transport.http import HTTPConnectProxy
from infection_monkey.transport.tcp import TcpProxy
from infection_monkey.utils import agent_process
from infection_monkey.utils.environment import is_windows_os
from infection_monkey.utils.threading import create_daemon_thread
requests.packages.urllib3.disable_warnings()
@ -67,56 +62,6 @@ class ControlClient:
timeout=MEDIUM_REQUEST_TIMEOUT,
)
def find_server(self, servers: Sequence[str]):
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
server_iterator = (s for s in servers)
for server in server_iterator:
try:
debug_message = f"Trying to connect to server: {server}"
logger.debug(debug_message)
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
break
# TODO: Check how we are going to set the server address that the ControlCLient
# is going to use
# self.server_address = server
except ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
for server in server_iterator:
t = create_daemon_thread(
target=ControlClient._send_relay_control_message,
name="SendControlRelayMessageThread",
args=(server,),
)
t.start()
@staticmethod
def _send_relay_control_message(server: str):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)
try:
address, port = address_to_ip_port(server)
d_socket.connect((address, int(port)))
d_socket.send(RELAY_CONTROL_MESSAGE)
logger.info(f"Control message was sent to the server/relay {server}")
except OSError as err:
logger.error(f"Error connecting to socket {server}: {err}")
def set_proxies(self, proxy_find):
"""
Note: The proxy schema changes between different versions of requests and urllib3,

View File

@ -42,6 +42,7 @@ from infection_monkey.master.control_channel import ControlChannel
from infection_monkey.model import VictimHostFactory
from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_network_interfaces
from infection_monkey.network.relay.utils import find_server
from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter
from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter
@ -162,7 +163,7 @@ class InfectionMonkey:
self._control_client.wakeup(parent=self._opts.parent)
def _current_server_is_set(self) -> bool:
if self._control_client.find_server(default_tunnel=self._opts.servers):
if find_server(servers=self._opts.servers):
return True
return False

View File

@ -0,0 +1,58 @@
import logging
import socket
from typing import Sequence
import requests
from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE
from infection_monkey.utils.threading import create_daemon_thread
logger = logging.getLogger(__name__)
def find_server(self, servers: Sequence[str]):
logger.debug(f"Trying to wake up with servers: {', '.join(servers)}")
server_iterator = (s for s in servers)
for server in server_iterator:
try:
debug_message = f"Trying to connect to server: {server}"
logger.debug(debug_message)
requests.get( # noqa: DUO123
f"https://{server}/api?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
break
except requests.exceptions.ConnectionError as err:
logger.error(f"Unable to connect to server/relay {server}: {err}")
except TimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except Exception as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
)
for server in server_iterator:
t = create_daemon_thread(
target=_send_relay_control_message,
name="SendControlRelayMessageThread",
args=(server,),
)
t.start()
def _send_relay_control_message(server: str):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as d_socket:
d_socket.settimeout(MEDIUM_REQUEST_TIMEOUT)
try:
address, port = address_to_ip_port(server)
d_socket.connect((address, int(port)))
d_socket.send(RELAY_CONTROL_MESSAGE)
logger.info(f"Control message was sent to the server/relay {server}")
except OSError as err:
logger.error(f"Error connecting to socket {server}: {err}")