Merge pull request #2200 from guardicore/2191-fix-credentials-repository-get

2191 fix credentials repository get
This commit is contained in:
Mike Salvatore 2022-08-15 15:45:03 -04:00 committed by GitHub
commit d09c1a689e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 7 deletions

View File

@ -1,5 +1,5 @@
import logging
from typing import Any, Iterable, Sequence
from typing import Any, Iterable
from common.credentials import CredentialComponentType, Credentials, ICredentialComponent
from infection_monkey.custom_types import PropagationCredentials
@ -29,6 +29,11 @@ class AggregatingPropagationCredentialsRepository(IPropagationCredentialsReposit
}
self._control_channel = control_channel
# Ensure caching happens per-instance instead of being shared across instances
self._get_credentials_from_control_channel = request_cache(CREDENTIALS_POLL_PERIOD_SEC)(
self._control_channel.get_credentials_for_propagation
)
def add_credentials(self, credentials_to_add: Iterable[Credentials]):
for credentials in credentials_to_add:
if credentials.identity:
@ -58,15 +63,10 @@ class AggregatingPropagationCredentialsRepository(IPropagationCredentialsReposit
try:
propagation_credentials = self._get_credentials_from_control_channel()
self.add_credentials(propagation_credentials)
return self._stored_credentials
except Exception as ex:
self._stored_credentials = {}
logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}")
@request_cache(CREDENTIALS_POLL_PERIOD_SEC)
def _get_credentials_from_control_channel(self) -> Sequence[Credentials]:
return self._control_channel.get_credentials_for_propagation()
return self._stored_credentials
def _set_attribute(self, attribute_to_be_set: str, credentials_values: Iterable[Any]):
if not credentials_values:

View File

@ -122,3 +122,16 @@ def test_all_keys_if_credentials_empty():
assert "exploit_password_list" in actual_stored_credentials
assert "exploit_ntlm_hash_list" in actual_stored_credentials
assert "exploit_ssh_keys" in actual_stored_credentials
def test_credentials_obtained_if_propagation_credentials_fails():
control_channel = MagicMock()
control_channel.get_credentials_for_propagation.return_value = EMPTY_CHANNEL_CREDENTIALS
control_channel.get_credentials_for_propagation.side_effect = Exception(
"No credentials for you!"
)
credentials_repository = AggregatingPropagationCredentialsRepository(control_channel)
credentials = credentials_repository.get_credentials()
assert credentials is not None