Agent: Remove duplication in SSHCredentialCollector
This commit is contained in:
parent
d38a386f67
commit
142136dd41
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Iterable, Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from common.credentials import Credentials, SSHKeypair, Username
|
from common.credentials import Credentials
|
||||||
from common.event_queue import IEventQueue
|
from common.event_queue import IEventQueue
|
||||||
from infection_monkey.credential_collectors.ssh_collector import ssh_handler
|
from infection_monkey.credential_collectors.ssh_collector import ssh_handler
|
||||||
from infection_monkey.i_puppet import ICredentialCollector
|
from infection_monkey.i_puppet import ICredentialCollector
|
||||||
|
@ -24,30 +24,4 @@ class SSHCredentialCollector(ICredentialCollector):
|
||||||
ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger, self._event_queue)
|
ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger, self._event_queue)
|
||||||
logger.info("Finished scanning for SSH credentials")
|
logger.info("Finished scanning for SSH credentials")
|
||||||
|
|
||||||
return SSHCredentialCollector._to_credentials(ssh_info)
|
return ssh_handler.to_credentials(ssh_info)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
|
|
||||||
ssh_credentials = []
|
|
||||||
|
|
||||||
for info in ssh_info:
|
|
||||||
identity = None
|
|
||||||
secret = None
|
|
||||||
|
|
||||||
if info.get("name", ""):
|
|
||||||
identity = Username(info["name"])
|
|
||||||
|
|
||||||
ssh_keypair = {}
|
|
||||||
for key in ["public_key", "private_key"]:
|
|
||||||
if info.get(key) is not None:
|
|
||||||
ssh_keypair[key] = info[key]
|
|
||||||
|
|
||||||
if len(ssh_keypair):
|
|
||||||
secret = SSHKeypair(
|
|
||||||
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
|
|
||||||
)
|
|
||||||
|
|
||||||
if any([identity, secret]):
|
|
||||||
ssh_credentials.append(Credentials(identity, secret))
|
|
||||||
|
|
||||||
return ssh_credentials
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Iterable
|
from typing import Dict, Iterable, Sequence
|
||||||
|
|
||||||
from common.credentials import Credentials, SSHKeypair, Username
|
from common.credentials import Credentials, SSHKeypair, Username
|
||||||
from common.event_queue import IEventQueue
|
from common.event_queue import IEventQueue
|
||||||
|
@ -118,12 +118,7 @@ def _get_ssh_files(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
collected_credentials = Credentials(
|
collected_credentials = to_credentials([info])
|
||||||
identity=Username(info["name"]),
|
|
||||||
secret=SSHKeypair(
|
|
||||||
info["private_key"], info["public_key"]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
_publish_credentials_stolen_event(
|
_publish_credentials_stolen_event(
|
||||||
collected_credentials, event_queue
|
collected_credentials, event_queue
|
||||||
)
|
)
|
||||||
|
@ -142,6 +137,32 @@ def _get_ssh_files(
|
||||||
return user_info
|
return user_info
|
||||||
|
|
||||||
|
|
||||||
|
def to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
|
||||||
|
ssh_credentials = []
|
||||||
|
|
||||||
|
for info in ssh_info:
|
||||||
|
identity = None
|
||||||
|
secret = None
|
||||||
|
|
||||||
|
if info.get("name", ""):
|
||||||
|
identity = Username(info["name"])
|
||||||
|
|
||||||
|
ssh_keypair = {}
|
||||||
|
for key in ["public_key", "private_key"]:
|
||||||
|
if info.get(key) is not None:
|
||||||
|
ssh_keypair[key] = info[key]
|
||||||
|
|
||||||
|
if len(ssh_keypair):
|
||||||
|
secret = SSHKeypair(
|
||||||
|
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if any([identity, secret]):
|
||||||
|
ssh_credentials.append(Credentials(identity, secret))
|
||||||
|
|
||||||
|
return ssh_credentials
|
||||||
|
|
||||||
|
|
||||||
def _publish_credentials_stolen_event(collected_credentials: Credentials, event_queue: IEventQueue):
|
def _publish_credentials_stolen_event(collected_credentials: Credentials, event_queue: IEventQueue):
|
||||||
credentials_stolen_event = CredentialsStolenEvent(
|
credentials_stolen_event = CredentialsStolenEvent(
|
||||||
target=None,
|
target=None,
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pubsub.core import Publisher
|
||||||
|
|
||||||
from common.credentials import Credentials, SSHKeypair, Username
|
from common.credentials import Credentials, SSHKeypair, Username
|
||||||
from common.event_queue import IEventQueue
|
from common.event_queue import IEventQueue, PyPubSubEventQueue
|
||||||
from infection_monkey.credential_collectors import SSHCredentialCollector
|
from infection_monkey.credential_collectors import SSHCredentialCollector
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,8 +14,8 @@ def patch_telemetry_messenger():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_event_queue():
|
def event_queue() -> IEventQueue:
|
||||||
return MagicMock(spec=IEventQueue)
|
return PyPubSubEventQueue(Publisher())
|
||||||
|
|
||||||
|
|
||||||
def patch_ssh_handler(ssh_creds, monkeypatch):
|
def patch_ssh_handler(ssh_creds, monkeypatch):
|
||||||
|
@ -27,17 +28,15 @@ def patch_ssh_handler(ssh_creds, monkeypatch):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])]
|
"ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])]
|
||||||
)
|
)
|
||||||
def test_ssh_credentials_empty_results(
|
def test_ssh_credentials_empty_results(monkeypatch, ssh_creds, patch_telemetry_messenger):
|
||||||
monkeypatch, ssh_creds, patch_telemetry_messenger, mock_event_queue
|
|
||||||
):
|
|
||||||
patch_ssh_handler(ssh_creds, monkeypatch)
|
patch_ssh_handler(ssh_creds, monkeypatch)
|
||||||
collected = SSHCredentialCollector(
|
collected = SSHCredentialCollector(
|
||||||
patch_telemetry_messenger, mock_event_queue
|
patch_telemetry_messenger, MagicMock(spec=IEventQueue)
|
||||||
).collect_credentials()
|
).collect_credentials()
|
||||||
assert not collected
|
assert not collected
|
||||||
|
|
||||||
|
|
||||||
def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_event_queue):
|
def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger):
|
||||||
|
|
||||||
ssh_creds = [
|
ssh_creds = [
|
||||||
{
|
{
|
||||||
|
@ -78,6 +77,6 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_ev
|
||||||
Credentials(identity=None, secret=ssh_keypair3),
|
Credentials(identity=None, secret=ssh_keypair3),
|
||||||
]
|
]
|
||||||
collected = SSHCredentialCollector(
|
collected = SSHCredentialCollector(
|
||||||
patch_telemetry_messenger, mock_event_queue
|
patch_telemetry_messenger, MagicMock(spec=IEventQueue)
|
||||||
).collect_credentials()
|
).collect_credentials()
|
||||||
assert expected == collected
|
assert expected == collected
|
||||||
|
|
Loading…
Reference in New Issue