forked from p15670423/monkey
Merge pull request #1721 from guardicore/1605-get-updated-credentials
1605 get updated credentials
This commit is contained in:
commit
0bfa0cd1ca
|
@ -41,7 +41,9 @@ class AutomatedMaster(IMaster):
|
||||||
self._control_channel = control_channel
|
self._control_channel = control_channel
|
||||||
|
|
||||||
ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
|
ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
|
||||||
exploiter = Exploiter(self._puppet, NUM_EXPLOIT_THREADS)
|
exploiter = Exploiter(
|
||||||
|
self._puppet, NUM_EXPLOIT_THREADS, self._control_channel.get_credentials_for_propagation
|
||||||
|
)
|
||||||
self._propagator = Propagator(
|
self._propagator = Propagator(
|
||||||
self._telemetry_messenger,
|
self._telemetry_messenger,
|
||||||
ip_scanner,
|
ip_scanner,
|
||||||
|
|
|
@ -7,11 +7,14 @@ from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
|
||||||
from infection_monkey.config import WormConfiguration
|
from infection_monkey.config import WormConfiguration
|
||||||
from infection_monkey.control import ControlClient
|
from infection_monkey.control import ControlClient
|
||||||
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
||||||
|
from infection_monkey.utils.decorators import request_cache
|
||||||
|
|
||||||
requests.packages.urllib3.disable_warnings()
|
requests.packages.urllib3.disable_warnings()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CREDENTIALS_POLL_PERIOD_SEC = 30
|
||||||
|
|
||||||
|
|
||||||
class ControlChannel(IControlChannel):
|
class ControlChannel(IControlChannel):
|
||||||
def __init__(self, server: str, agent_id: str):
|
def __init__(self, server: str, agent_id: str):
|
||||||
|
@ -66,18 +69,21 @@ class ControlChannel(IControlChannel):
|
||||||
) as e:
|
) as e:
|
||||||
raise IslandCommunicationError(e)
|
raise IslandCommunicationError(e)
|
||||||
|
|
||||||
|
@request_cache(CREDENTIALS_POLL_PERIOD_SEC)
|
||||||
def get_credentials_for_propagation(self) -> dict:
|
def get_credentials_for_propagation(self) -> dict:
|
||||||
|
propagation_credentials_url = (
|
||||||
|
f"https://{self._control_channel_server}/api/propagation-credentials/{self._agent_id}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response = requests.get( # noqa: DUO123
|
response = requests.get( # noqa: DUO123
|
||||||
f"{self._control_channel_server}/api/propagation-credentials/{self._agent_id}",
|
propagation_credentials_url,
|
||||||
verify=False,
|
verify=False,
|
||||||
proxies=ControlClient.proxies,
|
proxies=ControlClient.proxies,
|
||||||
timeout=SHORT_REQUEST_TIMEOUT,
|
timeout=SHORT_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
response = json.loads(response.content.decode())["propagation_credentials"]
|
return json.loads(response.content.decode())["propagation_credentials"]
|
||||||
return response
|
|
||||||
except (
|
except (
|
||||||
json.JSONDecodeError,
|
json.JSONDecodeError,
|
||||||
requests.exceptions.ConnectionError,
|
requests.exceptions.ConnectionError,
|
||||||
|
|
|
@ -3,7 +3,7 @@ import queue
|
||||||
import threading
|
import threading
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from threading import Event
|
from threading import Event
|
||||||
from typing import Callable, Dict, List
|
from typing import Callable, Dict, List, Mapping
|
||||||
|
|
||||||
from infection_monkey.i_puppet import ExploiterResultData, IPuppet
|
from infection_monkey.i_puppet import ExploiterResultData, IPuppet
|
||||||
from infection_monkey.model import VictimHost
|
from infection_monkey.model import VictimHost
|
||||||
|
@ -18,9 +18,15 @@ Callback = Callable[[ExploiterName, VictimHost, ExploiterResultData], None]
|
||||||
|
|
||||||
|
|
||||||
class Exploiter:
|
class Exploiter:
|
||||||
def __init__(self, puppet: IPuppet, num_workers: int):
|
def __init__(
|
||||||
|
self,
|
||||||
|
puppet: IPuppet,
|
||||||
|
num_workers: int,
|
||||||
|
get_updated_credentials_for_propagation: Callable[[], Mapping],
|
||||||
|
):
|
||||||
self._puppet = puppet
|
self._puppet = puppet
|
||||||
self._num_workers = num_workers
|
self._num_workers = num_workers
|
||||||
|
self._get_updated_credentials_for_propagation = get_updated_credentials_for_propagation
|
||||||
|
|
||||||
def exploit_hosts(
|
def exploit_hosts(
|
||||||
self,
|
self,
|
||||||
|
@ -74,6 +80,7 @@ class Exploiter:
|
||||||
results_callback: Callback,
|
results_callback: Callback,
|
||||||
stop: Event,
|
stop: Event,
|
||||||
):
|
):
|
||||||
|
|
||||||
for exploiter in interruptable_iter(exploiters_to_run, stop):
|
for exploiter in interruptable_iter(exploiters_to_run, stop):
|
||||||
exploiter_name = exploiter["name"]
|
exploiter_name = exploiter["name"]
|
||||||
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
|
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
|
||||||
|
@ -86,7 +93,19 @@ class Exploiter:
|
||||||
self, exploiter_name: str, victim_host: VictimHost, stop: Event
|
self, exploiter_name: str, victim_host: VictimHost, stop: Event
|
||||||
) -> ExploiterResultData:
|
) -> ExploiterResultData:
|
||||||
logger.debug(f"Attempting to use {exploiter_name} on {victim_host}")
|
logger.debug(f"Attempting to use {exploiter_name} on {victim_host}")
|
||||||
return self._puppet.exploit_host(exploiter_name, victim_host.ip_addr, {}, stop)
|
|
||||||
|
credentials = self._get_credentials_for_propagation()
|
||||||
|
options = {"credentials": credentials}
|
||||||
|
|
||||||
|
return self._puppet.exploit_host(exploiter_name, victim_host.ip_addr, options, stop)
|
||||||
|
|
||||||
|
def _get_credentials_for_propagation(self) -> Mapping:
|
||||||
|
try:
|
||||||
|
return self._get_updated_credentials_for_propagation()
|
||||||
|
except Exception as ex:
|
||||||
|
logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}")
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def _all_hosts_have_been_processed(scan_completed: Event, hosts_to_exploit: Queue):
|
def _all_hosts_have_been_processed(scan_completed: Event, hosts_to_exploit: Queue):
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
import threading
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from .timer import Timer
|
||||||
|
|
||||||
|
|
||||||
|
def request_cache(ttl: float):
|
||||||
|
"""
|
||||||
|
This is a decorator that allows a single response of a function to be cached with an expiration
|
||||||
|
time (TTL). The first call to the function is executed and the response is cached. Subsequent
|
||||||
|
calls to the function result in the cached value being returned until the TTL elapses. Once the
|
||||||
|
TTL elapses, the cache is considered stale and the decorated function will be called, its
|
||||||
|
response cached, and the TTL reset.
|
||||||
|
|
||||||
|
An example usage of this decorator is to wrap a function that makes frequent slow calls to an
|
||||||
|
external resource, such as an HTTP request to a remote endpoint. If the most up-to-date
|
||||||
|
information is not need, this decorator provides a simple way to cache the response for a
|
||||||
|
certain amount of time.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@request_cache(600)
|
||||||
|
def raining_outside():
|
||||||
|
return requests.get(f"https://weather.service.api/check_for_rain/{MY_ZIP_CODE}")
|
||||||
|
|
||||||
|
:param ttl: The time-to-live in seconds for the cached return value
|
||||||
|
:return: The return value of the decorated function, or the cached return value if the TTL has
|
||||||
|
not elapsed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
with wrapper.lock:
|
||||||
|
if wrapper.timer.is_expired():
|
||||||
|
wrapper.cached_value = fn(*args, **kwargs)
|
||||||
|
wrapper.timer.set(ttl)
|
||||||
|
|
||||||
|
return wrapper.cached_value
|
||||||
|
|
||||||
|
wrapper.cached_value = None
|
||||||
|
wrapper.timer = Timer()
|
||||||
|
wrapper.lock = threading.Lock()
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
|
@ -14,7 +14,7 @@ INTERVAL = 0.001
|
||||||
|
|
||||||
|
|
||||||
def test_terminate_without_start():
|
def test_terminate_without_start():
|
||||||
m = AutomatedMaster(None, None, None, None, [])
|
m = AutomatedMaster(None, None, None, MagicMock(), [])
|
||||||
|
|
||||||
# Test that call to terminate does not raise exception
|
# Test that call to terminate does not raise exception
|
||||||
m.terminate()
|
m.terminate()
|
||||||
|
|
|
@ -59,12 +59,27 @@ def hosts_to_exploit(hosts):
|
||||||
return q
|
return q
|
||||||
|
|
||||||
|
|
||||||
def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, hosts_to_exploit):
|
CREDENTIALS_FOR_PROPAGATION = {"usernames": ["m0nk3y", "user"], "passwords": ["1234", "pword"]}
|
||||||
# Set this so that Exploiter() exits once it has processed all victims
|
|
||||||
scan_completed.set()
|
|
||||||
|
|
||||||
e = Exploiter(MockPuppet(), 2)
|
|
||||||
e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop)
|
def get_credentials_for_propagation():
|
||||||
|
return CREDENTIALS_FOR_PROPAGATION
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def run_exploiters(exploiter_config, hosts_to_exploit, callback, scan_completed, stop):
|
||||||
|
def inner(puppet, num_workers):
|
||||||
|
# Set this so that Exploiter() exits once it has processed all victims
|
||||||
|
scan_completed.set()
|
||||||
|
|
||||||
|
e = Exploiter(puppet, num_workers, get_credentials_for_propagation)
|
||||||
|
e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def test_exploiter(callback, hosts, hosts_to_exploit, run_exploiters):
|
||||||
|
run_exploiters(MockPuppet(), 2)
|
||||||
|
|
||||||
assert callback.call_count == 5
|
assert callback.call_count == 5
|
||||||
host_exploit_combos = set()
|
host_exploit_combos = set()
|
||||||
|
@ -81,6 +96,14 @@ def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, host
|
||||||
assert ("SSHExploiter", hosts[1]) in host_exploit_combos
|
assert ("SSHExploiter", hosts[1]) in host_exploit_combos
|
||||||
|
|
||||||
|
|
||||||
|
def test_credentials_passed_to_exploiter(run_exploiters):
|
||||||
|
mock_puppet = MagicMock()
|
||||||
|
run_exploiters(mock_puppet, 1)
|
||||||
|
|
||||||
|
for call_args in mock_puppet.exploit_host.call_args_list:
|
||||||
|
assert call_args[0][2].get("credentials") == CREDENTIALS_FOR_PROPAGATION
|
||||||
|
|
||||||
|
|
||||||
def test_stop_after_callback(exploiter_config, callback, scan_completed, stop, hosts_to_exploit):
|
def test_stop_after_callback(exploiter_config, callback, scan_completed, stop, hosts_to_exploit):
|
||||||
callback_barrier_count = 2
|
callback_barrier_count = 2
|
||||||
|
|
||||||
|
@ -96,7 +119,7 @@ def test_stop_after_callback(exploiter_config, callback, scan_completed, stop, h
|
||||||
|
|
||||||
# Intentionally NOT setting scan_completed.set(); _callback() will set stop
|
# Intentionally NOT setting scan_completed.set(); _callback() will set stop
|
||||||
|
|
||||||
e = Exploiter(MockPuppet(), callback_barrier_count + 2)
|
e = Exploiter(MockPuppet(), callback_barrier_count + 2, get_credentials_for_propagation)
|
||||||
e.exploit_hosts(exploiter_config, hosts_to_exploit, stoppable_callback, scan_completed, stop)
|
e.exploit_hosts(exploiter_config, hosts_to_exploit, stoppable_callback, scan_completed, stop)
|
||||||
|
|
||||||
assert stoppable_callback.call_count == 2
|
assert stoppable_callback.call_count == 2
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
import time
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from infection_monkey.utils.decorators import request_cache
|
||||||
|
from infection_monkey.utils.timer import Timer
|
||||||
|
|
||||||
|
|
||||||
|
class MockTimer(Timer):
|
||||||
|
def __init__(self):
|
||||||
|
self._time_remaining = 0
|
||||||
|
self._set_time = 0
|
||||||
|
|
||||||
|
def set(self, timeout_sec: float):
|
||||||
|
self._time_remaining = timeout_sec
|
||||||
|
self._set_time = timeout_sec
|
||||||
|
|
||||||
|
def set_expired(self):
|
||||||
|
self._time_remaining = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def time_remaining(self) -> float:
|
||||||
|
return self._time_remaining
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Reset the timer without changing the timeout
|
||||||
|
"""
|
||||||
|
self._time_remaining = self._set_time
|
||||||
|
|
||||||
|
|
||||||
|
class MockTimerFactory:
|
||||||
|
def __init__(self):
|
||||||
|
self._instance = None
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
if self._instance is None:
|
||||||
|
mt = MockTimer()
|
||||||
|
self._instance = mt
|
||||||
|
|
||||||
|
return self._instance
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._instance = None
|
||||||
|
|
||||||
|
|
||||||
|
mock_timer_factory = MockTimerFactory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_timer(monkeypatch):
|
||||||
|
mock_timer_factory.reset
|
||||||
|
|
||||||
|
monkeypatch.setattr("infection_monkey.utils.decorators.Timer", mock_timer_factory)
|
||||||
|
|
||||||
|
return mock_timer_factory()
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_cache(mock_timer):
|
||||||
|
mock_request = MagicMock(side_effect=lambda: time.time())
|
||||||
|
|
||||||
|
@request_cache(10)
|
||||||
|
def make_request():
|
||||||
|
return mock_request()
|
||||||
|
|
||||||
|
t1 = make_request()
|
||||||
|
t2 = make_request()
|
||||||
|
|
||||||
|
assert t1 == t2
|
||||||
|
|
||||||
|
mock_timer.set_expired()
|
||||||
|
|
||||||
|
t3 = make_request()
|
||||||
|
t4 = make_request()
|
||||||
|
|
||||||
|
assert t3 != t1
|
||||||
|
assert t3 == t4
|
Loading…
Reference in New Issue