Agent: Use dictionary for relay and potential users in TCPRelay
This commit is contained in:
parent
9a3afb051d
commit
9fae6cca20
|
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||
from ipaddress import IPv4Address
|
||||
from threading import Event, Lock, Thread
|
||||
from time import sleep, time
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
|
||||
from infection_monkey.transport.tcp import TcpProxy
|
||||
|
||||
|
@ -35,8 +35,8 @@ class TCPRelay(Thread):
|
|||
self._new_client_timeout = new_client_timeout
|
||||
super(TCPRelay, self).__init__(name="MonkeyTcpRelayThread")
|
||||
self.daemon = True
|
||||
self._relay_users: List[RelayUser] = []
|
||||
self._potential_users: List[RelayUser] = []
|
||||
self._relay_users: Dict[IPv4Address, RelayUser] = {}
|
||||
self._potential_users: Dict[IPv4Address, RelayUser] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def run(self):
|
||||
|
@ -67,14 +67,16 @@ class TCPRelay(Thread):
|
|||
:param user: A user which will be added to the relay
|
||||
"""
|
||||
with self._lock:
|
||||
self._potential_users = [u for u in self._potential_users if u.address != user_address]
|
||||
self._relay_users.append(RelayUser(user_address, time()))
|
||||
if user_address in self._potential_users:
|
||||
del self._potential_users[user_address]
|
||||
|
||||
self._relay_users[user_address] = RelayUser(user_address, time())
|
||||
|
||||
def on_user_disconnected(self, user_address: IPv4Address):
|
||||
"""Handle user disconnection."""
|
||||
pass
|
||||
|
||||
def relay_users(self) -> List[RelayUser]:
|
||||
def relay_users(self) -> Dict[IPv4Address, RelayUser]:
|
||||
"""
|
||||
Get the list of users connected to the relay.
|
||||
"""
|
||||
|
@ -88,7 +90,7 @@ class TCPRelay(Thread):
|
|||
:param user: A potential user that tries to connect to the relay
|
||||
"""
|
||||
with self._lock:
|
||||
self._potential_users.append(RelayUser(user_address, time()))
|
||||
self._potential_users[user_address] = RelayUser(user_address, time())
|
||||
|
||||
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
|
||||
"""
|
||||
|
@ -104,7 +106,8 @@ class TCPRelay(Thread):
|
|||
|
||||
def _disconnect_user(self, user_address: IPv4Address):
|
||||
with self._lock:
|
||||
self._relay_users = [u for u in self._relay_users if u.address != user_address]
|
||||
if user_address in self._relay_users:
|
||||
del self._relay_users[user_address]
|
||||
|
||||
def _wait_for_users_to_disconnect(self):
|
||||
stop = False
|
||||
|
@ -112,8 +115,10 @@ class TCPRelay(Thread):
|
|||
sleep(0.01)
|
||||
current_time = time()
|
||||
most_recent_potential_time = max(
|
||||
self._potential_users, key=lambda u: u.time, default=RelayUser(IPv4Address(""), 0.0)
|
||||
).time
|
||||
self._potential_users.values(),
|
||||
key=lambda ru: ru.last_update_time,
|
||||
default=RelayUser(IPv4Address(""), 0.0),
|
||||
).last_update_time
|
||||
potential_elapsed = current_time - most_recent_potential_time
|
||||
|
||||
stop = not self._potential_users or potential_elapsed > self._new_client_timeout
|
||||
|
|
|
@ -39,7 +39,7 @@ def test_user_added(tcp_relay):
|
|||
|
||||
users = tcp_relay.relay_users()
|
||||
assert len(users) == 1
|
||||
assert users[0].address == NEW_USER_ADDRESS
|
||||
assert NEW_USER_ADDRESS in users
|
||||
|
||||
|
||||
def test_user_not_removed_on_disconnect(tcp_relay):
|
||||
|
|
Loading…
Reference in New Issue