Agent: Use IPv4Address for RelayUser.address

This commit is contained in:
Ilija Lazoroski 2022-09-01 12:10:21 +02:00
parent 1d394bbd2e
commit 9a3afb051d
2 changed files with 38 additions and 30 deletions

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from ipaddress import IPv4Address
from threading import Event, Lock, Thread from threading import Event, Lock, Thread
from time import sleep, time from time import sleep, time
from typing import List from typing import List
@ -11,7 +12,7 @@ RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
@dataclass @dataclass
class RelayUser: class RelayUser:
address: str address: IPv4Address
last_update_time: float last_update_time: float
@ -59,17 +60,17 @@ class TCPRelay(Thread):
def stop(self): def stop(self):
self._stopped.set() self._stopped.set()
def on_user_connected(self, user: str): def on_user_connected(self, user_address: IPv4Address):
""" """
Handle new user connection. Handle new user connection.
:param user: A user which will be added to the relay :param user: A user which will be added to the relay
""" """
with self._lock: with self._lock:
self._potential_users = [u for u in self._potential_users if u.address != user] self._potential_users = [u for u in self._potential_users if u.address != user_address]
self._relay_users.append(RelayUser(user, time())) self._relay_users.append(RelayUser(user_address, time()))
def on_user_disconnected(self, user: str): def on_user_disconnected(self, user_address: IPv4Address):
"""Handle user disconnection.""" """Handle user disconnection."""
pass pass
@ -80,16 +81,16 @@ class TCPRelay(Thread):
with self._lock: with self._lock:
return self._relay_users.copy() return self._relay_users.copy()
def on_potential_new_user(self, user: str): def on_potential_new_user(self, user_address: IPv4Address):
""" """
Notify TCPRelay that a new user may try and connect. Notify TCPRelay that a new user may try and connect.
:param user: A potential user that tries to connect to the relay :param user: A potential user that tries to connect to the relay
""" """
with self._lock: with self._lock:
self._potential_users.append(RelayUser(user, time())) self._potential_users.append(RelayUser(user_address, time()))
def on_user_data_received(self, data: bytes, user: str) -> bool: def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
""" """
Disconnect a user which a specific starting data. Disconnect a user which a specific starting data.
@ -97,13 +98,13 @@ class TCPRelay(Thread):
:param user: User which send the data :param user: User which send the data
""" """
if data.startswith(RELAY_CONTROL_MESSAGE): if data.startswith(RELAY_CONTROL_MESSAGE):
self._disconnect_user(user) self._disconnect_user(user_address)
return False return False
return True return True
def _disconnect_user(self, user: str): def _disconnect_user(self, user_address: IPv4Address):
with self._lock: with self._lock:
self._relay_users = [u for u in self._relay_users if u.address != user] self._relay_users = [u for u in self._relay_users if u.address != user_address]
def _wait_for_users_to_disconnect(self): def _wait_for_users_to_disconnect(self):
stop = False stop = False
@ -111,7 +112,7 @@ class TCPRelay(Thread):
sleep(0.01) sleep(0.01)
current_time = time() current_time = time()
most_recent_potential_time = max( most_recent_potential_time = max(
self._potential_users, key=lambda u: u.time, default=RelayUser("", 0) self._potential_users, key=lambda u: u.time, default=RelayUser(IPv4Address(""), 0.0)
).time ).time
potential_elapsed = current_time - most_recent_potential_time potential_elapsed = current_time - most_recent_potential_time

View File

@ -1,7 +1,20 @@
from ipaddress import IPv4Address
from threading import Thread from threading import Thread
import pytest
from monkey.infection_monkey.tcp_relay import RELAY_CONTROL_MESSAGE, TCPRelay from monkey.infection_monkey.tcp_relay import RELAY_CONTROL_MESSAGE, TCPRelay
NEW_USER_ADDRESS = IPv4Address("0.0.0.1")
LOCAL_PORT = 9975
TARGET_ADDRESS = "0.0.0.0"
TARGET_PORT = 9976
@pytest.fixture
def tcp_relay():
return TCPRelay(LOCAL_PORT, TARGET_ADDRESS, TARGET_PORT)
def join_or_kill_thread(thread: Thread, timeout: float): def join_or_kill_thread(thread: Thread, timeout: float):
"""Whether or not the thread joined in the given timeout period.""" """Whether or not the thread joined in the given timeout period."""
@ -21,34 +34,28 @@ def join_or_kill_thread(thread: Thread, timeout: float):
# assert join_or_kill_thread(relay, 0.2) # assert join_or_kill_thread(relay, 0.2)
def test_user_added(): def test_user_added(tcp_relay):
relay = TCPRelay(9975, "0.0.0.0", 9976) tcp_relay.on_user_connected(NEW_USER_ADDRESS)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
users = relay.relay_users() users = tcp_relay.relay_users()
assert len(users) == 1 assert len(users) == 1
assert users[0].address == new_user assert users[0].address == NEW_USER_ADDRESS
def test_user_not_removed_on_disconnect(): def test_user_not_removed_on_disconnect(tcp_relay):
# A user should only be disconnected when they send a disconnect request # A user should only be disconnected when they send a disconnect request
relay = TCPRelay(9975, "0.0.0.0", 9976) tcp_relay.on_user_connected(NEW_USER_ADDRESS)
new_user = "0.0.0.1" tcp_relay.on_user_disconnected(NEW_USER_ADDRESS)
relay.on_user_connected(new_user)
relay.on_user_disconnected(new_user)
users = relay.relay_users() users = tcp_relay.relay_users()
assert len(users) == 1 assert len(users) == 1
def test_user_removed_on_request(): def test_user_removed_on_request(tcp_relay):
relay = TCPRelay(9975, "0.0.0.0", 9976) tcp_relay.on_user_connected(NEW_USER_ADDRESS)
new_user = "0.0.0.1" tcp_relay.on_user_data_received(RELAY_CONTROL_MESSAGE, NEW_USER_ADDRESS)
relay.on_user_connected(new_user)
relay.on_user_data_received(RELAY_CONTROL_MESSAGE, "0.0.0.1")
users = relay.relay_users() users = tcp_relay.relay_users()
assert len(users) == 0 assert len(users) == 0