Agent: Use dictionary for relay and potential users in TCPRelay

This commit is contained in:
Ilija Lazoroski 2022-09-01 12:56:26 +02:00
parent 9a3afb051d
commit 9fae6cca20
2 changed files with 16 additions and 11 deletions

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass
from ipaddress import IPv4Address
from threading import Event, Lock, Thread
from time import sleep, time
from typing import List
from typing import Dict
from infection_monkey.transport.tcp import TcpProxy
@ -35,8 +35,8 @@ class TCPRelay(Thread):
self._new_client_timeout = new_client_timeout
super(TCPRelay, self).__init__(name="MonkeyTcpRelayThread")
self.daemon = True
self._relay_users: List[RelayUser] = []
self._potential_users: List[RelayUser] = []
self._relay_users: Dict[IPv4Address, RelayUser] = {}
self._potential_users: Dict[IPv4Address, RelayUser] = {}
self._lock = Lock()
def run(self):
@ -67,14 +67,16 @@ class TCPRelay(Thread):
:param user: A user which will be added to the relay
"""
with self._lock:
self._potential_users = [u for u in self._potential_users if u.address != user_address]
self._relay_users.append(RelayUser(user_address, time()))
if user_address in self._potential_users:
del self._potential_users[user_address]
self._relay_users[user_address] = RelayUser(user_address, time())
def on_user_disconnected(self, user_address: IPv4Address):
"""Handle user disconnection."""
pass
def relay_users(self) -> List[RelayUser]:
def relay_users(self) -> Dict[IPv4Address, RelayUser]:
"""
Get the list of users connected to the relay.
"""
@ -88,7 +90,7 @@ class TCPRelay(Thread):
:param user: A potential user that tries to connect to the relay
"""
with self._lock:
self._potential_users.append(RelayUser(user_address, time()))
self._potential_users[user_address] = RelayUser(user_address, time())
def on_user_data_received(self, data: bytes, user_address: IPv4Address) -> bool:
"""
@ -104,7 +106,8 @@ class TCPRelay(Thread):
def _disconnect_user(self, user_address: IPv4Address):
with self._lock:
self._relay_users = [u for u in self._relay_users if u.address != user_address]
if user_address in self._relay_users:
del self._relay_users[user_address]
def _wait_for_users_to_disconnect(self):
stop = False
@ -112,8 +115,10 @@ class TCPRelay(Thread):
sleep(0.01)
current_time = time()
most_recent_potential_time = max(
self._potential_users, key=lambda u: u.time, default=RelayUser(IPv4Address(""), 0.0)
).time
self._potential_users.values(),
key=lambda ru: ru.last_update_time,
default=RelayUser(IPv4Address(""), 0.0),
).last_update_time
potential_elapsed = current_time - most_recent_potential_time
stop = not self._potential_users or potential_elapsed > self._new_client_timeout

View File

@ -39,7 +39,7 @@ def test_user_added(tcp_relay):
users = tcp_relay.relay_users()
assert len(users) == 1
assert users[0].address == NEW_USER_ADDRESS
assert NEW_USER_ADDRESS in users
def test_user_not_removed_on_disconnect(tcp_relay):