Agent: Add caching to ControlChannel.get_credentials_for_propagation()

This commit is contained in:
Mike Salvatore 2022-02-17 14:29:07 -05:00
parent c3e9690280
commit 4005ea2924
1 changed files with 9 additions and 3 deletions

View File

@ -7,11 +7,14 @@ from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
from infection_monkey.config import WormConfiguration from infection_monkey.config import WormConfiguration
from infection_monkey.control import ControlClient from infection_monkey.control import ControlClient
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.utils.decorators import request_cache
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CREDENTIALS_POLL_PERIOD_SEC = 30
class ControlChannel(IControlChannel): class ControlChannel(IControlChannel):
def __init__(self, server: str, agent_id: str): def __init__(self, server: str, agent_id: str):
@ -66,18 +69,21 @@ class ControlChannel(IControlChannel):
) as e: ) as e:
raise IslandCommunicationError(e) raise IslandCommunicationError(e)
@request_cache(CREDENTIALS_POLL_PERIOD_SEC)
def get_credentials_for_propagation(self) -> dict: def get_credentials_for_propagation(self) -> dict:
propagation_credentials_url = (
f"https://{self._control_channel_server}/api/propagation-credentials/{self._agent_id}"
)
try: try:
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
f"{self._control_channel_server}/api/propagation-credentials/{self._agent_id}", propagation_credentials_url,
verify=False, verify=False,
proxies=ControlClient.proxies, proxies=ControlClient.proxies,
timeout=SHORT_REQUEST_TIMEOUT, timeout=SHORT_REQUEST_TIMEOUT,
) )
response.raise_for_status() response.raise_for_status()
response = json.loads(response.content.decode())["propagation_credentials"] return json.loads(response.content.decode())["propagation_credentials"]
return response
except ( except (
json.JSONDecodeError, json.JSONDecodeError,
requests.exceptions.ConnectionError, requests.exceptions.ConnectionError,