Agent: Remove duplication in SSHCredentialCollector

This commit is contained in:
Ilija Lazoroski 2022-08-16 17:14:37 +02:00
parent d38a386f67
commit 142136dd41
3 changed files with 39 additions and 45 deletions

View File

@ -1,7 +1,7 @@
import logging
from typing import Dict, Iterable, Sequence
from typing import Sequence
from common.credentials import Credentials, SSHKeypair, Username
from common.credentials import Credentials
from common.event_queue import IEventQueue
from infection_monkey.credential_collectors.ssh_collector import ssh_handler
from infection_monkey.i_puppet import ICredentialCollector
@ -24,30 +24,4 @@ class SSHCredentialCollector(ICredentialCollector):
ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger, self._event_queue)
logger.info("Finished scanning for SSH credentials")
return SSHCredentialCollector._to_credentials(ssh_info)
@staticmethod
def _to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
ssh_credentials = []
for info in ssh_info:
identity = None
secret = None
if info.get("name", ""):
identity = Username(info["name"])
ssh_keypair = {}
for key in ["public_key", "private_key"]:
if info.get(key) is not None:
ssh_keypair[key] = info[key]
if len(ssh_keypair):
secret = SSHKeypair(
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
)
if any([identity, secret]):
ssh_credentials.append(Credentials(identity, secret))
return ssh_credentials
return ssh_handler.to_credentials(ssh_info)

View File

@ -1,7 +1,7 @@
import glob
import logging
import os
from typing import Dict, Iterable
from typing import Dict, Iterable, Sequence
from common.credentials import Credentials, SSHKeypair, Username
from common.event_queue import IEventQueue
@ -118,12 +118,7 @@ def _get_ssh_files(
)
)
collected_credentials = Credentials(
identity=Username(info["name"]),
secret=SSHKeypair(
info["private_key"], info["public_key"]
),
)
collected_credentials = to_credentials([info])
_publish_credentials_stolen_event(
collected_credentials, event_queue
)
@ -142,6 +137,32 @@ def _get_ssh_files(
return user_info
def to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
ssh_credentials = []
for info in ssh_info:
identity = None
secret = None
if info.get("name", ""):
identity = Username(info["name"])
ssh_keypair = {}
for key in ["public_key", "private_key"]:
if info.get(key) is not None:
ssh_keypair[key] = info[key]
if len(ssh_keypair):
secret = SSHKeypair(
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
)
if any([identity, secret]):
ssh_credentials.append(Credentials(identity, secret))
return ssh_credentials
def _publish_credentials_stolen_event(collected_credentials: Credentials, event_queue: IEventQueue):
credentials_stolen_event = CredentialsStolenEvent(
target=None,

View File

@ -1,9 +1,10 @@
from unittest.mock import MagicMock
import pytest
from pubsub.core import Publisher
from common.credentials import Credentials, SSHKeypair, Username
from common.event_queue import IEventQueue
from common.event_queue import IEventQueue, PyPubSubEventQueue
from infection_monkey.credential_collectors import SSHCredentialCollector
@ -13,8 +14,8 @@ def patch_telemetry_messenger():
@pytest.fixture
def mock_event_queue():
return MagicMock(spec=IEventQueue)
def event_queue() -> IEventQueue:
return PyPubSubEventQueue(Publisher())
def patch_ssh_handler(ssh_creds, monkeypatch):
@ -27,17 +28,15 @@ def patch_ssh_handler(ssh_creds, monkeypatch):
@pytest.mark.parametrize(
"ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])]
)
def test_ssh_credentials_empty_results(
monkeypatch, ssh_creds, patch_telemetry_messenger, mock_event_queue
):
def test_ssh_credentials_empty_results(monkeypatch, ssh_creds, patch_telemetry_messenger):
patch_ssh_handler(ssh_creds, monkeypatch)
collected = SSHCredentialCollector(
patch_telemetry_messenger, mock_event_queue
patch_telemetry_messenger, MagicMock(spec=IEventQueue)
).collect_credentials()
assert not collected
def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_event_queue):
def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger):
ssh_creds = [
{
@ -78,6 +77,6 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_ev
Credentials(identity=None, secret=ssh_keypair3),
]
collected = SSHCredentialCollector(
patch_telemetry_messenger, mock_event_queue
patch_telemetry_messenger, MagicMock(spec=IEventQueue)
).collect_credentials()
assert expected == collected