diff --git a/monkey/infection_monkey/tcp_relay.py b/monkey/infection_monkey/tcp_relay.py index 235800d12..77c4808b3 100644 --- a/monkey/infection_monkey/tcp_relay.py +++ b/monkey/infection_monkey/tcp_relay.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from ipaddress import IPv4Address from threading import Event, Lock, Thread from time import sleep, time -from typing import List +from typing import Dict from infection_monkey.transport.tcp import TcpProxy @@ -35,8 +35,8 @@ class TCPRelay(Thread): self._new_client_timeout = new_client_timeout super(TCPRelay, self).__init__(name="MonkeyTcpRelayThread") self.daemon = True - self._relay_users: List[RelayUser] = [] - self._potential_users: List[RelayUser] = [] + self._relay_users: Dict[IPv4Address, RelayUser] = {} + self._potential_users: Dict[IPv4Address, RelayUser] = {} self._lock = Lock() def run(self): @@ -67,14 +67,16 @@ class TCPRelay(Thread): :param user: A user which will be added to the relay """ with self._lock: - self._potential_users = [u for u in self._potential_users if u.address != user_address] - self._relay_users.append(RelayUser(user_address, time())) + if user_address in self._potential_users: + del self._potential_users[user_address] + + self._relay_users[user_address] = RelayUser(user_address, time()) def on_user_disconnected(self, user_address: IPv4Address): """Handle user disconnection.""" pass - def relay_users(self) -> List[RelayUser]: + def relay_users(self) -> Dict[IPv4Address, RelayUser]: """ Get the list of users connected to the relay. """ @@ -88,7 +90,7 @@ class TCPRelay(Thread): :param user: A potential user that tries to connect to the relay """ with self._lock: - self._potential_users.append(RelayUser(user_address, time())) + self._potential_users[user_address] = RelayUser(user_address, time()) def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool: """ @@ -104,7 +106,8 @@ class TCPRelay(Thread): def _disconnect_user(self, user_address: IPv4Address): with self._lock: - self._relay_users = [u for u in self._relay_users if u.address != user_address] + if user_address in self._relay_users: + del self._relay_users[user_address] def _wait_for_users_to_disconnect(self): stop = False @@ -112,8 +115,10 @@ class TCPRelay(Thread): sleep(0.01) current_time = time() most_recent_potential_time = max( - self._potential_users, key=lambda u: u.time, default=RelayUser(IPv4Address(""), 0.0) - ).time + self._potential_users.values(), + key=lambda ru: ru.last_update_time, + default=RelayUser(IPv4Address(""), 0.0), + ).last_update_time potential_elapsed = current_time - most_recent_potential_time stop = not self._potential_users or potential_elapsed > self._new_client_timeout diff --git a/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py b/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py index 5caa70e0c..70c1c6bbb 100644 --- a/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py +++ b/monkey/tests/unit_tests/infection_monkey/test_tcp_relay.py @@ -39,7 +39,7 @@ def test_user_added(tcp_relay): users = tcp_relay.relay_users() assert len(users) == 1 - assert users[0].address == NEW_USER_ADDRESS + assert NEW_USER_ADDRESS in users def test_user_not_removed_on_disconnect(tcp_relay):