forked from p15670423/monkey
Merge pull request #2200 from guardicore/2191-fix-credentials-repository-get
2191 fix credentials repository get
This commit is contained in:
commit
d09c1a689e
|
@ -1,5 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Iterable, Sequence
|
from typing import Any, Iterable
|
||||||
|
|
||||||
from common.credentials import CredentialComponentType, Credentials, ICredentialComponent
|
from common.credentials import CredentialComponentType, Credentials, ICredentialComponent
|
||||||
from infection_monkey.custom_types import PropagationCredentials
|
from infection_monkey.custom_types import PropagationCredentials
|
||||||
|
@ -29,6 +29,11 @@ class AggregatingPropagationCredentialsRepository(IPropagationCredentialsReposit
|
||||||
}
|
}
|
||||||
self._control_channel = control_channel
|
self._control_channel = control_channel
|
||||||
|
|
||||||
|
# Ensure caching happens per-instance instead of being shared across instances
|
||||||
|
self._get_credentials_from_control_channel = request_cache(CREDENTIALS_POLL_PERIOD_SEC)(
|
||||||
|
self._control_channel.get_credentials_for_propagation
|
||||||
|
)
|
||||||
|
|
||||||
def add_credentials(self, credentials_to_add: Iterable[Credentials]):
|
def add_credentials(self, credentials_to_add: Iterable[Credentials]):
|
||||||
for credentials in credentials_to_add:
|
for credentials in credentials_to_add:
|
||||||
if credentials.identity:
|
if credentials.identity:
|
||||||
|
@ -58,15 +63,10 @@ class AggregatingPropagationCredentialsRepository(IPropagationCredentialsReposit
|
||||||
try:
|
try:
|
||||||
propagation_credentials = self._get_credentials_from_control_channel()
|
propagation_credentials = self._get_credentials_from_control_channel()
|
||||||
self.add_credentials(propagation_credentials)
|
self.add_credentials(propagation_credentials)
|
||||||
|
|
||||||
return self._stored_credentials
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
self._stored_credentials = {}
|
|
||||||
logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}")
|
logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}")
|
||||||
|
|
||||||
@request_cache(CREDENTIALS_POLL_PERIOD_SEC)
|
return self._stored_credentials
|
||||||
def _get_credentials_from_control_channel(self) -> Sequence[Credentials]:
|
|
||||||
return self._control_channel.get_credentials_for_propagation()
|
|
||||||
|
|
||||||
def _set_attribute(self, attribute_to_be_set: str, credentials_values: Iterable[Any]):
|
def _set_attribute(self, attribute_to_be_set: str, credentials_values: Iterable[Any]):
|
||||||
if not credentials_values:
|
if not credentials_values:
|
||||||
|
|
|
@ -122,3 +122,16 @@ def test_all_keys_if_credentials_empty():
|
||||||
assert "exploit_password_list" in actual_stored_credentials
|
assert "exploit_password_list" in actual_stored_credentials
|
||||||
assert "exploit_ntlm_hash_list" in actual_stored_credentials
|
assert "exploit_ntlm_hash_list" in actual_stored_credentials
|
||||||
assert "exploit_ssh_keys" in actual_stored_credentials
|
assert "exploit_ssh_keys" in actual_stored_credentials
|
||||||
|
|
||||||
|
|
||||||
|
def test_credentials_obtained_if_propagation_credentials_fails():
|
||||||
|
control_channel = MagicMock()
|
||||||
|
control_channel.get_credentials_for_propagation.return_value = EMPTY_CHANNEL_CREDENTIALS
|
||||||
|
control_channel.get_credentials_for_propagation.side_effect = Exception(
|
||||||
|
"No credentials for you!"
|
||||||
|
)
|
||||||
|
credentials_repository = AggregatingPropagationCredentialsRepository(control_channel)
|
||||||
|
|
||||||
|
credentials = credentials_repository.get_credentials()
|
||||||
|
|
||||||
|
assert credentials is not None
|
||||||
|
|
Loading…
Reference in New Issue