Agent: Add a Timer class

This commit is contained in:
Mike Salvatore 2021-12-02 10:45:58 -05:00
parent 73bf93050f
commit 9279d82adf
5 changed files with 106 additions and 2 deletions

View File

@ -32,6 +32,7 @@ class SSHExploiter(HostExploiter):
self.skip_exist = self._config.skip_exploit_if_file_exist self.skip_exist = self._config.skip_exploit_if_file_exist
def log_transfer(self, transferred, total): def log_transfer(self, transferred, total):
# TODO: Replace with infection_monkey.utils.timer.Timer
if time.time() - self._update_timestamp > TRANSFER_UPDATE_RATE: if time.time() - self._update_timestamp > TRANSFER_UPDATE_RATE:
logger.debug("SFTP transferred: %d bytes, total: %d bytes", transferred, total) logger.debug("SFTP transferred: %d bytes, total: %d bytes", transferred, total)
self._update_timestamp = time.time() self._update_timestamp = time.time()

View File

@ -40,8 +40,7 @@ class BatchingTelemetryMessenger(ITelemetryMessenger):
self._period = period self._period = period
self._should_run_batch_thread = True self._should_run_batch_thread = True
# TODO: Create a "timer" or "countdown" class and inject an object instead of # TODO: Replace with infection_monkey.utils.timer.Timer
# using time.time()
self._last_sent_time = time.time() self._last_sent_time = time.time()
self._telemetry_batches: Dict[str, IBatchableTelem] = {} self._telemetry_batches: Dict[str, IBatchableTelem] = {}

View File

@ -173,6 +173,7 @@ class MonkeyTunnel(Thread):
# wait till all of the tunnel clients has been disconnected, or no one used the tunnel in # wait till all of the tunnel clients has been disconnected, or no one used the tunnel in
# QUIT_TIMEOUT seconds # QUIT_TIMEOUT seconds
# TODO: Replace with infection_monkey.utils.timer.Timer
while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT): while self._clients and (time.time() - get_last_serve_time() < QUIT_TIMEOUT):
try: try:
search, address = self._broad_sock.recvfrom(BUFFER_READ) search, address = self._broad_sock.recvfrom(BUFFER_READ)

View File

@ -0,0 +1,34 @@
import time
class Timer:
"""
A class for checking whether or not a certain amount of time has elapsed.
"""
def __init__(self):
self._timeout_sec = 0
self._start_time = 0
def set(self, timeout_sec: float):
"""
Set a timer
:param float timeout_sec: A fractional number of seconds to set the timeout for.
"""
self._timeout_sec = timeout_sec
self._start_time = time.time()
def is_expired(self):
"""
Check whether or not the timer has expired
:return: True if the elapsed time since set(TIMEOUT_SEC) was called is greater than
TIMEOUT_SEC, False otherwise
:rtype: bool
"""
return (time.time() - self._start_time) >= self._timeout_sec
def reset(self):
"""
Reset the timer without changing the timeout
"""
self._start_time = time.time()

View File

@ -0,0 +1,69 @@
import time
import pytest
from infection_monkey.utils.timer import Timer
@pytest.fixture
def start_time(set_current_time):
start_time = 100
set_current_time(start_time)
return start_time
@pytest.fixture
def set_current_time(monkeypatch):
def inner(current_time):
monkeypatch.setattr(time, "time", lambda: current_time)
return inner
@pytest.mark.parametrize(("timeout"), [5, 1.25])
def test_timer_not_expired(start_time, set_current_time, timeout):
t = Timer()
t.set(timeout)
assert not t.is_expired()
set_current_time(start_time + (timeout - 0.001))
assert not t.is_expired()
@pytest.mark.parametrize(("timeout"), [5, 1.25])
def test_timer_expired(start_time, set_current_time, timeout):
t = Timer()
t.set(timeout)
assert not t.is_expired()
set_current_time(start_time + timeout)
assert t.is_expired()
set_current_time(start_time + timeout + 0.001)
assert t.is_expired()
def test_unset_timer_expired():
t = Timer()
assert t.is_expired()
@pytest.mark.parametrize(("timeout"), [5, 1.25])
def test_timer_reset(start_time, set_current_time, timeout):
t = Timer()
t.set(timeout)
assert not t.is_expired()
set_current_time(start_time + timeout)
assert t.is_expired()
t.reset()
assert not t.is_expired()
set_current_time(start_time + (2 * timeout))
assert t.is_expired()