diff --git a/monkey/infection_monkey/network/relay/__init__.py b/monkey/infection_monkey/network/relay/__init__.py index 41bc81e80..5817c0a15 100644 --- a/monkey/infection_monkey/network/relay/__init__.py +++ b/monkey/infection_monkey/network/relay/__init__.py @@ -1,3 +1,4 @@ +from .relay_connection_handler import RelayConnectionHandler from .relay_user_handler import RelayUser, RelayUserHandler from .tcp_connection_handler import TCPConnectionHandler from .tcp import SocketsPipe diff --git a/monkey/infection_monkey/network/relay/relay_connection_handler.py b/monkey/infection_monkey/network/relay/relay_connection_handler.py new file mode 100644 index 000000000..8f039488b --- /dev/null +++ b/monkey/infection_monkey/network/relay/relay_connection_handler.py @@ -0,0 +1,23 @@ +import socket +from ipaddress import IPv4Address + +from .relay_user_handler import RelayUserHandler +from .tcp_pipe_spawner import TCPPipeSpawner + +RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -" + + +class RelayConnectionHandler: + def __init__(self, pipe_spawner: TCPPipeSpawner, relay_user_handler: RelayUserHandler): + self._pipe_spawner = pipe_spawner + self._relay_user_handler = relay_user_handler + + def handle_new_connection(self, sock: socket.socket): + control_message = sock.recv(socket.MSG_PEEK) + addr, _ = sock.getpeername() # TODO check the type of the addr object + if control_message.startswith(RELAY_CONTROL_MESSAGE): + + self._relay_user_handler.disconnect_user(IPv4Address(addr)) + else: + self._relay_user_handler.add_relay_user(IPv4Address(addr)) + self._pipe_spawner.spawn_pipe(sock) diff --git a/monkey/infection_monkey/network/relay/relay_user_handler.py b/monkey/infection_monkey/network/relay/relay_user_handler.py index 364f5985e..0a32988b5 100644 --- a/monkey/infection_monkey/network/relay/relay_user_handler.py +++ b/monkey/infection_monkey/network/relay/relay_user_handler.py @@ -52,11 +52,16 @@ class RelayUserHandler: :param user_address: An address defining RelayUser which received the data """ if data.startswith(RELAY_CONTROL_MESSAGE): - self._disconnect_user(user_address) + self.disconnect_user(user_address) return False return True - def _disconnect_user(self, user_address: IPv4Address): + def disconnect_user(self, user_address: IPv4Address): + """ + Handle when a user disconnects. + + :param user_address: The address of the disconnecting user. + """ with self._lock: if user_address in self._relay_users: del self._relay_users[user_address] diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_connection_handler.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_connection_handler.py new file mode 100644 index 000000000..c871681b6 --- /dev/null +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_relay_connection_handler.py @@ -0,0 +1,64 @@ +import socket +from ipaddress import IPv4Address +from unittest.mock import MagicMock + +import pytest + +from monkey.infection_monkey.network.relay import ( + RelayConnectionHandler, + RelayUserHandler, + TCPPipeSpawner, +) +from monkey.infection_monkey.network.relay.relay_connection_handler import RELAY_CONTROL_MESSAGE + +USER_ADDRESS = "0.0.0.1" + + +@pytest.fixture +def pipe_spawner(): + return MagicMock(spec=TCPPipeSpawner) + + +@pytest.fixture +def relay_user_handler(): + return MagicMock(spec=RelayUserHandler) + + +@pytest.fixture +def close_socket(): + sock = MagicMock(spec=socket.socket) + sock.recv.return_value = RELAY_CONTROL_MESSAGE + sock.getpeername.return_value = (USER_ADDRESS, 12345) + return sock + + +@pytest.fixture +def data_socket(): + sock = MagicMock(spec=socket.socket) + sock.recv.return_value = b"some data" + sock.getpeername.return_value = (USER_ADDRESS, 12345) + return sock + + +def test_control_message_disconnects_user(pipe_spawner, relay_user_handler, close_socket): + connection_handler = RelayConnectionHandler(pipe_spawner, relay_user_handler) + + connection_handler.handle_new_connection(close_socket) + + relay_user_handler.disconnect_user.assert_called_once_with(IPv4Address(USER_ADDRESS)) + + +def test_connection_spawns_pipe(pipe_spawner, relay_user_handler, data_socket): + connection_handler = RelayConnectionHandler(pipe_spawner, relay_user_handler) + + connection_handler.handle_new_connection(data_socket) + + assert pipe_spawner.spawn_pipe.called + + +def test_connection_adds_user(pipe_spawner, relay_user_handler, data_socket): + connection_handler = RelayConnectionHandler(pipe_spawner, relay_user_handler) + + connection_handler.handle_new_connection(data_socket) + + relay_user_handler.add_relay_user.assert_called_once_with(IPv4Address(USER_ADDRESS))