Agent: Use IPv4Address for RelayUser.address

This commit is contained in:
Ilija Lazoroski 2022-09-01 12:10:21 +02:00
parent 1d394bbd2e
commit 9a3afb051d
2 changed files with 38 additions and 30 deletions

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from ipaddress import IPv4Address
from threading import Event, Lock, Thread
from time import sleep, time
from typing import List
@ -11,7 +12,7 @@ RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
@dataclass
class RelayUser:
address: str
address: IPv4Address
last_update_time: float
@ -59,17 +60,17 @@ class TCPRelay(Thread):
def stop(self):
self._stopped.set()
def on_user_connected(self, user: str):
def on_user_connected(self, user_address: IPv4Address):
"""
Handle new user connection.
:param user: A user which will be added to the relay
"""
with self._lock:
self._potential_users = [u for u in self._potential_users if u.address != user]
self._relay_users.append(RelayUser(user, time()))
self._potential_users = [u for u in self._potential_users if u.address != user_address]
self._relay_users.append(RelayUser(user_address, time()))
def on_user_disconnected(self, user: str):
def on_user_disconnected(self, user_address: IPv4Address):
"""Handle user disconnection."""
pass
@ -80,16 +81,16 @@ class TCPRelay(Thread):
with self._lock:
return self._relay_users.copy()
def on_potential_new_user(self, user: str):
def on_potential_new_user(self, user_address: IPv4Address):
"""
Notify TCPRelay that a new user may try and connect.
:param user: A potential user that tries to connect to the relay
"""
with self._lock:
self._potential_users.append(RelayUser(user, time()))
self._potential_users.append(RelayUser(user_address, time()))
def on_user_data_received(self, data: bytes, user: str) -> bool:
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
"""
Disconnect a user which a specific starting data.
@ -97,13 +98,13 @@ class TCPRelay(Thread):
:param user: User which send the data
"""
if data.startswith(RELAY_CONTROL_MESSAGE):
self._disconnect_user(user)
self._disconnect_user(user_address)
return False
return True
def _disconnect_user(self, user: str):
def _disconnect_user(self, user_address: IPv4Address):
with self._lock:
self._relay_users = [u for u in self._relay_users if u.address != user]
self._relay_users = [u for u in self._relay_users if u.address != user_address]
def _wait_for_users_to_disconnect(self):
stop = False
@ -111,7 +112,7 @@ class TCPRelay(Thread):
sleep(0.01)
current_time = time()
most_recent_potential_time = max(
self._potential_users, key=lambda u: u.time, default=RelayUser("", 0)
self._potential_users, key=lambda u: u.time, default=RelayUser(IPv4Address(""), 0.0)
).time
potential_elapsed = current_time - most_recent_potential_time

View File

@ -1,7 +1,20 @@
from ipaddress import IPv4Address
from threading import Thread
import pytest
from monkey.infection_monkey.tcp_relay import RELAY_CONTROL_MESSAGE, TCPRelay
NEW_USER_ADDRESS = IPv4Address("0.0.0.1")
LOCAL_PORT = 9975
TARGET_ADDRESS = "0.0.0.0"
TARGET_PORT = 9976
@pytest.fixture
def tcp_relay():
return TCPRelay(LOCAL_PORT, TARGET_ADDRESS, TARGET_PORT)
def join_or_kill_thread(thread: Thread, timeout: float):
"""Whether or not the thread joined in the given timeout period."""
@ -21,34 +34,28 @@ def join_or_kill_thread(thread: Thread, timeout: float):
# assert join_or_kill_thread(relay, 0.2)
def test_user_added():
relay = TCPRelay(9975, "0.0.0.0", 9976)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
def test_user_added(tcp_relay):
tcp_relay.on_user_connected(NEW_USER_ADDRESS)
users = relay.relay_users()
users = tcp_relay.relay_users()
assert len(users) == 1
assert users[0].address == new_user
assert users[0].address == NEW_USER_ADDRESS
def test_user_not_removed_on_disconnect():
def test_user_not_removed_on_disconnect(tcp_relay):
# A user should only be disconnected when they send a disconnect request
relay = TCPRelay(9975, "0.0.0.0", 9976)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
relay.on_user_disconnected(new_user)
tcp_relay.on_user_connected(NEW_USER_ADDRESS)
tcp_relay.on_user_disconnected(NEW_USER_ADDRESS)
users = relay.relay_users()
users = tcp_relay.relay_users()
assert len(users) == 1
def test_user_removed_on_request():
relay = TCPRelay(9975, "0.0.0.0", 9976)
new_user = "0.0.0.1"
relay.on_user_connected(new_user)
relay.on_user_data_received(RELAY_CONTROL_MESSAGE, "0.0.0.1")
def test_user_removed_on_request(tcp_relay):
tcp_relay.on_user_connected(NEW_USER_ADDRESS)
tcp_relay.on_user_data_received(RELAY_CONTROL_MESSAGE, NEW_USER_ADDRESS)
users = relay.relay_users()
users = tcp_relay.relay_users()
assert len(users) == 0