Agent: Change ICredentialCollector interface to return Sequence

Being able to check if the ICredentialCollector returned an empty
Sequence is useful and easier than checking for an "empty" Iterable.
This commit is contained in:
Mike Salvatore 2022-02-16 15:10:38 -05:00
parent 3a3a5f0c9c
commit 0880e16c54
5 changed files with 12 additions and 12 deletions

View File

@ -1,4 +1,4 @@
from typing import Iterable from typing import Sequence
from infection_monkey.credential_collectors import LMHash, NTHash, Password, Username from infection_monkey.credential_collectors import LMHash, NTHash, Password, Username
from infection_monkey.i_puppet.credential_collection import Credentials, ICredentialCollector from infection_monkey.i_puppet.credential_collection import Credentials, ICredentialCollector
@ -8,12 +8,12 @@ from .windows_credentials import WindowsCredentials
class MimikatzCredentialCollector(ICredentialCollector): class MimikatzCredentialCollector(ICredentialCollector):
def collect_credentials(self, options=None) -> Iterable[Credentials]: def collect_credentials(self, options=None) -> Sequence[Credentials]:
creds = pypykatz_handler.get_windows_creds() creds = pypykatz_handler.get_windows_creds()
return MimikatzCredentialCollector._to_credentials(creds) return MimikatzCredentialCollector._to_credentials(creds)
@staticmethod @staticmethod
def _to_credentials(win_creds: Iterable[WindowsCredentials]) -> [Credentials]: def _to_credentials(win_creds: Sequence[WindowsCredentials]) -> [Credentials]:
all_creds = [] all_creds = []
for win_cred in win_creds: for win_cred in win_creds:
identities = [] identities = []

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Dict, Iterable, List from typing import Dict, Iterable, Sequence
from infection_monkey.credential_collectors import SSHKeypair, Username from infection_monkey.credential_collectors import SSHKeypair, Username
from infection_monkey.credential_collectors.ssh_collector import ssh_handler from infection_monkey.credential_collectors.ssh_collector import ssh_handler
@ -17,7 +17,7 @@ class SSHCredentialCollector(ICredentialCollector):
def __init__(self, telemetry_messenger: ITelemetryMessenger): def __init__(self, telemetry_messenger: ITelemetryMessenger):
self._telemetry_messenger = telemetry_messenger self._telemetry_messenger = telemetry_messenger
def collect_credentials(self, _options=None) -> List[Credentials]: def collect_credentials(self, _options=None) -> Sequence[Credentials]:
logger.info("Started scanning for SSH credentials") logger.info("Started scanning for SSH credentials")
ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger) ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger)
logger.info("Finished scanning for SSH credentials") logger.info("Finished scanning for SSH credentials")
@ -25,7 +25,7 @@ class SSHCredentialCollector(ICredentialCollector):
return SSHCredentialCollector._to_credentials(ssh_info) return SSHCredentialCollector._to_credentials(ssh_info)
@staticmethod @staticmethod
def _to_credentials(ssh_info: Iterable[Dict]) -> List[Credentials]: def _to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
ssh_credentials = [] ssh_credentials = []
for info in ssh_info: for info in ssh_info:

View File

@ -1,10 +1,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Iterable, Mapping, Optional from typing import Mapping, Optional, Sequence
from .credentials import Credentials from .credentials import Credentials
class ICredentialCollector(ABC): class ICredentialCollector(ABC):
@abstractmethod @abstractmethod
def collect_credentials(self, options: Optional[Mapping]) -> Iterable[Credentials]: def collect_credentials(self, options: Optional[Mapping]) -> Sequence[Credentials]:
pass pass

View File

@ -32,7 +32,7 @@ class Puppet(IPuppet):
credential_collector = self._plugin_registry.get_plugin( credential_collector = self._plugin_registry.get_plugin(
name, PluginType.CREDENTIAL_COLLECTOR name, PluginType.CREDENTIAL_COLLECTOR
) )
return list(credential_collector.collect_credentials(options)) return credential_collector.collect_credentials(options)
def run_pba(self, name: str, options: Dict) -> PostBreachData: def run_pba(self, name: str, options: Dict) -> PostBreachData:
return self._mock_puppet.run_pba(name, options) return self._mock_puppet.run_pba(name, options)

View File

@ -1,4 +1,4 @@
from typing import List from typing import Sequence
import pytest import pytest
@ -23,8 +23,8 @@ def patch_pypykatz(win_creds: [WindowsCredentials], monkeypatch):
) )
def collect_credentials() -> List[Credentials]: def collect_credentials() -> Sequence[Credentials]:
return list(MimikatzCredentialCollector().collect_credentials()) return MimikatzCredentialCollector().collect_credentials()
@pytest.mark.parametrize( @pytest.mark.parametrize(