Agent: Add RelayConnectionHandler

This commit is contained in:
Kekoa Kaaikala 2022-09-02 19:17:43 +00:00
parent b179f602c4
commit bbc9cf16e6
4 changed files with 95 additions and 2 deletions

View File

@ -1,3 +1,4 @@
from .relay_connection_handler import RelayConnectionHandler
from .relay_user_handler import RelayUser, RelayUserHandler from .relay_user_handler import RelayUser, RelayUserHandler
from .tcp_connection_handler import TCPConnectionHandler from .tcp_connection_handler import TCPConnectionHandler
from .tcp import SocketsPipe from .tcp import SocketsPipe

View File

@ -0,0 +1,23 @@
import socket
from ipaddress import IPv4Address
from .relay_user_handler import RelayUserHandler
from .tcp_pipe_spawner import TCPPipeSpawner
RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
class RelayConnectionHandler:
def __init__(self, pipe_spawner: TCPPipeSpawner, relay_user_handler: RelayUserHandler):
self._pipe_spawner = pipe_spawner
self._relay_user_handler = relay_user_handler
def handle_new_connection(self, sock: socket.socket):
control_message = sock.recv(socket.MSG_PEEK)
addr, _ = sock.getpeername() # TODO check the type of the addr object
if control_message.startswith(RELAY_CONTROL_MESSAGE):
self._relay_user_handler.disconnect_user(IPv4Address(addr))
else:
self._relay_user_handler.add_relay_user(IPv4Address(addr))
self._pipe_spawner.spawn_pipe(sock)

View File

@ -52,11 +52,16 @@ class RelayUserHandler:
:param user_address: An address defining RelayUser which received the data :param user_address: An address defining RelayUser which received the data
""" """
if data.startswith(RELAY_CONTROL_MESSAGE): if data.startswith(RELAY_CONTROL_MESSAGE):
self._disconnect_user(user_address) self.disconnect_user(user_address)
return False return False
return True return True
def _disconnect_user(self, user_address: IPv4Address): def disconnect_user(self, user_address: IPv4Address):
"""
Handle when a user disconnects.
:param user_address: The address of the disconnecting user.
"""
with self._lock: with self._lock:
if user_address in self._relay_users: if user_address in self._relay_users:
del self._relay_users[user_address] del self._relay_users[user_address]

View File

@ -0,0 +1,64 @@
import socket
from ipaddress import IPv4Address
from unittest.mock import MagicMock
import pytest
from monkey.infection_monkey.network.relay import (
RelayConnectionHandler,
RelayUserHandler,
TCPPipeSpawner,
)
from monkey.infection_monkey.network.relay.relay_connection_handler import RELAY_CONTROL_MESSAGE
USER_ADDRESS = "0.0.0.1"
@pytest.fixture
def pipe_spawner():
return MagicMock(spec=TCPPipeSpawner)
@pytest.fixture
def relay_user_handler():
return MagicMock(spec=RelayUserHandler)
@pytest.fixture
def close_socket():
sock = MagicMock(spec=socket.socket)
sock.recv.return_value = RELAY_CONTROL_MESSAGE
sock.getpeername.return_value = (USER_ADDRESS, 12345)
return sock
@pytest.fixture
def data_socket():
sock = MagicMock(spec=socket.socket)
sock.recv.return_value = b"some data"
sock.getpeername.return_value = (USER_ADDRESS, 12345)
return sock
def test_control_message_disconnects_user(pipe_spawner, relay_user_handler, close_socket):
connection_handler = RelayConnectionHandler(pipe_spawner, relay_user_handler)
connection_handler.handle_new_connection(close_socket)
relay_user_handler.disconnect_user.assert_called_once_with(IPv4Address(USER_ADDRESS))
def test_connection_spawns_pipe(pipe_spawner, relay_user_handler, data_socket):
connection_handler = RelayConnectionHandler(pipe_spawner, relay_user_handler)
connection_handler.handle_new_connection(data_socket)
assert pipe_spawner.spawn_pipe.called
def test_connection_adds_user(pipe_spawner, relay_user_handler, data_socket):
connection_handler = RelayConnectionHandler(pipe_spawner, relay_user_handler)
connection_handler.handle_new_connection(data_socket)
relay_user_handler.add_relay_user.assert_called_once_with(IPv4Address(USER_ADDRESS))