Merge pull request #1721 from guardicore/1605-get-updated-credentials

1605 get updated credentials
This commit is contained in:
Mike Salvatore 2022-02-18 06:01:25 -05:00 committed by GitHub
commit 0bfa0cd1ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 188 additions and 14 deletions

View File

@ -41,7 +41,9 @@ class AutomatedMaster(IMaster):
self._control_channel = control_channel
ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
exploiter = Exploiter(self._puppet, NUM_EXPLOIT_THREADS)
exploiter = Exploiter(
self._puppet, NUM_EXPLOIT_THREADS, self._control_channel.get_credentials_for_propagation
)
self._propagator = Propagator(
self._telemetry_messenger,
ip_scanner,

View File

@ -7,11 +7,14 @@ from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from infection_monkey.config import WormConfiguration
from infection_monkey.control import ControlClient
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.utils.decorators import request_cache
requests.packages.urllib3.disable_warnings()
logger = logging.getLogger(__name__)
CREDENTIALS_POLL_PERIOD_SEC = 30
class ControlChannel(IControlChannel):
def __init__(self, server: str, agent_id: str):
@ -66,18 +69,21 @@ class ControlChannel(IControlChannel):
) as e:
raise IslandCommunicationError(e)
@request_cache(CREDENTIALS_POLL_PERIOD_SEC)
def get_credentials_for_propagation(self) -> dict:
propagation_credentials_url = (
f"https://{self._control_channel_server}/api/propagation-credentials/{self._agent_id}"
)
try:
response = requests.get( # noqa: DUO123
f"{self._control_channel_server}/api/propagation-credentials/{self._agent_id}",
propagation_credentials_url,
verify=False,
proxies=ControlClient.proxies,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
response = json.loads(response.content.decode())["propagation_credentials"]
return response
return json.loads(response.content.decode())["propagation_credentials"]
except (
json.JSONDecodeError,
requests.exceptions.ConnectionError,

View File

@ -3,7 +3,7 @@ import queue
import threading
from queue import Queue
from threading import Event
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Mapping
from infection_monkey.i_puppet import ExploiterResultData, IPuppet
from infection_monkey.model import VictimHost
@ -18,9 +18,15 @@ Callback = Callable[[ExploiterName, VictimHost, ExploiterResultData], None]
class Exploiter:
def __init__(self, puppet: IPuppet, num_workers: int):
def __init__(
self,
puppet: IPuppet,
num_workers: int,
get_updated_credentials_for_propagation: Callable[[], Mapping],
):
self._puppet = puppet
self._num_workers = num_workers
self._get_updated_credentials_for_propagation = get_updated_credentials_for_propagation
def exploit_hosts(
self,
@ -74,6 +80,7 @@ class Exploiter:
results_callback: Callback,
stop: Event,
):
for exploiter in interruptable_iter(exploiters_to_run, stop):
exploiter_name = exploiter["name"]
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
@ -86,7 +93,19 @@ class Exploiter:
self, exploiter_name: str, victim_host: VictimHost, stop: Event
) -> ExploiterResultData:
logger.debug(f"Attempting to use {exploiter_name} on {victim_host}")
return self._puppet.exploit_host(exploiter_name, victim_host.ip_addr, {}, stop)
credentials = self._get_credentials_for_propagation()
options = {"credentials": credentials}
return self._puppet.exploit_host(exploiter_name, victim_host.ip_addr, options, stop)
def _get_credentials_for_propagation(self) -> Mapping:
try:
return self._get_updated_credentials_for_propagation()
except Exception as ex:
logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}")
return {}
def _all_hosts_have_been_processed(scan_completed: Event, hosts_to_exploit: Queue):

View File

@ -0,0 +1,46 @@
import threading
from functools import wraps
from .timer import Timer
def request_cache(ttl: float):
"""
This is a decorator that allows a single response of a function to be cached with an expiration
time (TTL). The first call to the function is executed and the response is cached. Subsequent
calls to the function result in the cached value being returned until the TTL elapses. Once the
TTL elapses, the cache is considered stale and the decorated function will be called, its
response cached, and the TTL reset.
An example usage of this decorator is to wrap a function that makes frequent slow calls to an
external resource, such as an HTTP request to a remote endpoint. If the most up-to-date
information is not need, this decorator provides a simple way to cache the response for a
certain amount of time.
Example:
@request_cache(600)
def raining_outside():
return requests.get(f"https://weather.service.api/check_for_rain/{MY_ZIP_CODE}")
:param ttl: The time-to-live in seconds for the cached return value
:return: The return value of the decorated function, or the cached return value if the TTL has
not elapsed.
"""
def decorator(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
with wrapper.lock:
if wrapper.timer.is_expired():
wrapper.cached_value = fn(*args, **kwargs)
wrapper.timer.set(ttl)
return wrapper.cached_value
wrapper.cached_value = None
wrapper.timer = Timer()
wrapper.lock = threading.Lock()
return wrapper
return decorator

View File

@ -14,7 +14,7 @@ INTERVAL = 0.001
def test_terminate_without_start():
m = AutomatedMaster(None, None, None, None, [])
m = AutomatedMaster(None, None, None, MagicMock(), [])
# Test that call to terminate does not raise exception
m.terminate()

View File

@ -59,12 +59,27 @@ def hosts_to_exploit(hosts):
return q
def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, hosts_to_exploit):
# Set this so that Exploiter() exits once it has processed all victims
scan_completed.set()
CREDENTIALS_FOR_PROPAGATION = {"usernames": ["m0nk3y", "user"], "passwords": ["1234", "pword"]}
e = Exploiter(MockPuppet(), 2)
e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop)
def get_credentials_for_propagation():
return CREDENTIALS_FOR_PROPAGATION
@pytest.fixture
def run_exploiters(exploiter_config, hosts_to_exploit, callback, scan_completed, stop):
def inner(puppet, num_workers):
# Set this so that Exploiter() exits once it has processed all victims
scan_completed.set()
e = Exploiter(puppet, num_workers, get_credentials_for_propagation)
e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop)
return inner
def test_exploiter(callback, hosts, hosts_to_exploit, run_exploiters):
run_exploiters(MockPuppet(), 2)
assert callback.call_count == 5
host_exploit_combos = set()
@ -81,6 +96,14 @@ def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, host
assert ("SSHExploiter", hosts[1]) in host_exploit_combos
def test_credentials_passed_to_exploiter(run_exploiters):
mock_puppet = MagicMock()
run_exploiters(mock_puppet, 1)
for call_args in mock_puppet.exploit_host.call_args_list:
assert call_args[0][2].get("credentials") == CREDENTIALS_FOR_PROPAGATION
def test_stop_after_callback(exploiter_config, callback, scan_completed, stop, hosts_to_exploit):
callback_barrier_count = 2
@ -96,7 +119,7 @@ def test_stop_after_callback(exploiter_config, callback, scan_completed, stop, h
# Intentionally NOT setting scan_completed.set(); _callback() will set stop
e = Exploiter(MockPuppet(), callback_barrier_count + 2)
e = Exploiter(MockPuppet(), callback_barrier_count + 2, get_credentials_for_propagation)
e.exploit_hosts(exploiter_config, hosts_to_exploit, stoppable_callback, scan_completed, stop)
assert stoppable_callback.call_count == 2

View File

@ -0,0 +1,78 @@
import time
from unittest.mock import MagicMock
import pytest
from infection_monkey.utils.decorators import request_cache
from infection_monkey.utils.timer import Timer
class MockTimer(Timer):
def __init__(self):
self._time_remaining = 0
self._set_time = 0
def set(self, timeout_sec: float):
self._time_remaining = timeout_sec
self._set_time = timeout_sec
def set_expired(self):
self._time_remaining = 0
@property
def time_remaining(self) -> float:
return self._time_remaining
def reset(self):
"""
Reset the timer without changing the timeout
"""
self._time_remaining = self._set_time
class MockTimerFactory:
def __init__(self):
self._instance = None
def __call__(self):
if self._instance is None:
mt = MockTimer()
self._instance = mt
return self._instance
def reset(self):
self._instance = None
mock_timer_factory = MockTimerFactory()
@pytest.fixture
def mock_timer(monkeypatch):
mock_timer_factory.reset
monkeypatch.setattr("infection_monkey.utils.decorators.Timer", mock_timer_factory)
return mock_timer_factory()
def test_request_cache(mock_timer):
mock_request = MagicMock(side_effect=lambda: time.time())
@request_cache(10)
def make_request():
return mock_request()
t1 = make_request()
t2 = make_request()
assert t1 == t2
mock_timer.set_expired()
t3 = make_request()
t4 = make_request()
assert t3 != t1
assert t3 == t4