forked from p15670423/monkey
Agent: Use IPv4Address for RelayUser.address
This commit is contained in:
parent
1d394bbd2e
commit
9a3afb051d
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address
|
||||
from threading import Event, Lock, Thread
|
||||
from time import sleep, time
|
||||
from typing import List
|
||||
|
@ -11,7 +12,7 @@ RELAY_CONTROL_MESSAGE = b"infection-monkey-relay-control-message: -"
|
|||
|
||||
@dataclass
|
||||
class RelayUser:
|
||||
address: str
|
||||
address: IPv4Address
|
||||
last_update_time: float
|
||||
|
||||
|
||||
|
@ -59,17 +60,17 @@ class TCPRelay(Thread):
|
|||
def stop(self):
|
||||
self._stopped.set()
|
||||
|
||||
def on_user_connected(self, user: str):
|
||||
def on_user_connected(self, user_address: IPv4Address):
|
||||
"""
|
||||
Handle new user connection.
|
||||
|
||||
:param user: A user which will be added to the relay
|
||||
"""
|
||||
with self._lock:
|
||||
self._potential_users = [u for u in self._potential_users if u.address != user]
|
||||
self._relay_users.append(RelayUser(user, time()))
|
||||
self._potential_users = [u for u in self._potential_users if u.address != user_address]
|
||||
self._relay_users.append(RelayUser(user_address, time()))
|
||||
|
||||
def on_user_disconnected(self, user: str):
|
||||
def on_user_disconnected(self, user_address: IPv4Address):
|
||||
"""Handle user disconnection."""
|
||||
pass
|
||||
|
||||
|
@ -80,16 +81,16 @@ class TCPRelay(Thread):
|
|||
with self._lock:
|
||||
return self._relay_users.copy()
|
||||
|
||||
def on_potential_new_user(self, user: str):
|
||||
def on_potential_new_user(self, user_address: IPv4Address):
|
||||
"""
|
||||
Notify TCPRelay that a new user may try and connect.
|
||||
|
||||
:param user: A potential user that tries to connect to the relay
|
||||
"""
|
||||
with self._lock:
|
||||
self._potential_users.append(RelayUser(user, time()))
|
||||
self._potential_users.append(RelayUser(user_address, time()))
|
||||
|
||||
def on_user_data_received(self, data: bytes, user: str) -> bool:
|
||||
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
|
||||
"""
|
||||
Disconnect a user which a specific starting data.
|
||||
|
||||
|
@ -97,13 +98,13 @@ class TCPRelay(Thread):
|
|||
:param user: User which send the data
|
||||
"""
|
||||
if data.startswith(RELAY_CONTROL_MESSAGE):
|
||||
self._disconnect_user(user)
|
||||
self._disconnect_user(user_address)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _disconnect_user(self, user: str):
|
||||
def _disconnect_user(self, user_address: IPv4Address):
|
||||
with self._lock:
|
||||
self._relay_users = [u for u in self._relay_users if u.address != user]
|
||||
self._relay_users = [u for u in self._relay_users if u.address != user_address]
|
||||
|
||||
def _wait_for_users_to_disconnect(self):
|
||||
stop = False
|
||||
|
@ -111,7 +112,7 @@ class TCPRelay(Thread):
|
|||
sleep(0.01)
|
||||
current_time = time()
|
||||
most_recent_potential_time = max(
|
||||
self._potential_users, key=lambda u: u.time, default=RelayUser("", 0)
|
||||
self._potential_users, key=lambda u: u.time, default=RelayUser(IPv4Address(""), 0.0)
|
||||
).time
|
||||
potential_elapsed = current_time - most_recent_potential_time
|
||||
|
||||
|
|
|
@ -1,7 +1,20 @@
|
|||
from ipaddress import IPv4Address
|
||||
from threading import Thread
|
||||
|
||||
import pytest
|
||||
|
||||
from monkey.infection_monkey.tcp_relay import RELAY_CONTROL_MESSAGE, TCPRelay
|
||||
|
||||
NEW_USER_ADDRESS = IPv4Address("0.0.0.1")
|
||||
LOCAL_PORT = 9975
|
||||
TARGET_ADDRESS = "0.0.0.0"
|
||||
TARGET_PORT = 9976
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tcp_relay():
|
||||
return TCPRelay(LOCAL_PORT, TARGET_ADDRESS, TARGET_PORT)
|
||||
|
||||
|
||||
def join_or_kill_thread(thread: Thread, timeout: float):
|
||||
"""Whether or not the thread joined in the given timeout period."""
|
||||
|
@ -21,34 +34,28 @@ def join_or_kill_thread(thread: Thread, timeout: float):
|
|||
# assert join_or_kill_thread(relay, 0.2)
|
||||
|
||||
|
||||
def test_user_added():
|
||||
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||
new_user = "0.0.0.1"
|
||||
relay.on_user_connected(new_user)
|
||||
def test_user_added(tcp_relay):
|
||||
tcp_relay.on_user_connected(NEW_USER_ADDRESS)
|
||||
|
||||
users = relay.relay_users()
|
||||
users = tcp_relay.relay_users()
|
||||
assert len(users) == 1
|
||||
assert users[0].address == new_user
|
||||
assert users[0].address == NEW_USER_ADDRESS
|
||||
|
||||
|
||||
def test_user_not_removed_on_disconnect():
|
||||
def test_user_not_removed_on_disconnect(tcp_relay):
|
||||
# A user should only be disconnected when they send a disconnect request
|
||||
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||
new_user = "0.0.0.1"
|
||||
relay.on_user_connected(new_user)
|
||||
relay.on_user_disconnected(new_user)
|
||||
tcp_relay.on_user_connected(NEW_USER_ADDRESS)
|
||||
tcp_relay.on_user_disconnected(NEW_USER_ADDRESS)
|
||||
|
||||
users = relay.relay_users()
|
||||
users = tcp_relay.relay_users()
|
||||
assert len(users) == 1
|
||||
|
||||
|
||||
def test_user_removed_on_request():
|
||||
relay = TCPRelay(9975, "0.0.0.0", 9976)
|
||||
new_user = "0.0.0.1"
|
||||
relay.on_user_connected(new_user)
|
||||
relay.on_user_data_received(RELAY_CONTROL_MESSAGE, "0.0.0.1")
|
||||
def test_user_removed_on_request(tcp_relay):
|
||||
tcp_relay.on_user_connected(NEW_USER_ADDRESS)
|
||||
tcp_relay.on_user_data_received(RELAY_CONTROL_MESSAGE, NEW_USER_ADDRESS)
|
||||
|
||||
users = relay.relay_users()
|
||||
users = tcp_relay.relay_users()
|
||||
assert len(users) == 0
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue