Agent: Remove duplication in SSHCredentialCollector

This commit is contained in:
Ilija Lazoroski 2022-08-16 17:14:37 +02:00
parent d38a386f67
commit 142136dd41
3 changed files with 39 additions and 45 deletions

View File

@ -1,7 +1,7 @@
import logging import logging
from typing import Dict, Iterable, Sequence from typing import Sequence
from common.credentials import Credentials, SSHKeypair, Username from common.credentials import Credentials
from common.event_queue import IEventQueue from common.event_queue import IEventQueue
from infection_monkey.credential_collectors.ssh_collector import ssh_handler from infection_monkey.credential_collectors.ssh_collector import ssh_handler
from infection_monkey.i_puppet import ICredentialCollector from infection_monkey.i_puppet import ICredentialCollector
@ -24,30 +24,4 @@ class SSHCredentialCollector(ICredentialCollector):
ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger, self._event_queue) ssh_info = ssh_handler.get_ssh_info(self._telemetry_messenger, self._event_queue)
logger.info("Finished scanning for SSH credentials") logger.info("Finished scanning for SSH credentials")
return SSHCredentialCollector._to_credentials(ssh_info) return ssh_handler.to_credentials(ssh_info)
@staticmethod
def _to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
ssh_credentials = []
for info in ssh_info:
identity = None
secret = None
if info.get("name", ""):
identity = Username(info["name"])
ssh_keypair = {}
for key in ["public_key", "private_key"]:
if info.get(key) is not None:
ssh_keypair[key] = info[key]
if len(ssh_keypair):
secret = SSHKeypair(
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
)
if any([identity, secret]):
ssh_credentials.append(Credentials(identity, secret))
return ssh_credentials

View File

@ -1,7 +1,7 @@
import glob import glob
import logging import logging
import os import os
from typing import Dict, Iterable from typing import Dict, Iterable, Sequence
from common.credentials import Credentials, SSHKeypair, Username from common.credentials import Credentials, SSHKeypair, Username
from common.event_queue import IEventQueue from common.event_queue import IEventQueue
@ -118,12 +118,7 @@ def _get_ssh_files(
) )
) )
collected_credentials = Credentials( collected_credentials = to_credentials([info])
identity=Username(info["name"]),
secret=SSHKeypair(
info["private_key"], info["public_key"]
),
)
_publish_credentials_stolen_event( _publish_credentials_stolen_event(
collected_credentials, event_queue collected_credentials, event_queue
) )
@ -142,6 +137,32 @@ def _get_ssh_files(
return user_info return user_info
def to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
ssh_credentials = []
for info in ssh_info:
identity = None
secret = None
if info.get("name", ""):
identity = Username(info["name"])
ssh_keypair = {}
for key in ["public_key", "private_key"]:
if info.get(key) is not None:
ssh_keypair[key] = info[key]
if len(ssh_keypair):
secret = SSHKeypair(
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
)
if any([identity, secret]):
ssh_credentials.append(Credentials(identity, secret))
return ssh_credentials
def _publish_credentials_stolen_event(collected_credentials: Credentials, event_queue: IEventQueue): def _publish_credentials_stolen_event(collected_credentials: Credentials, event_queue: IEventQueue):
credentials_stolen_event = CredentialsStolenEvent( credentials_stolen_event = CredentialsStolenEvent(
target=None, target=None,

View File

@ -1,9 +1,10 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from pubsub.core import Publisher
from common.credentials import Credentials, SSHKeypair, Username from common.credentials import Credentials, SSHKeypair, Username
from common.event_queue import IEventQueue from common.event_queue import IEventQueue, PyPubSubEventQueue
from infection_monkey.credential_collectors import SSHCredentialCollector from infection_monkey.credential_collectors import SSHCredentialCollector
@ -13,8 +14,8 @@ def patch_telemetry_messenger():
@pytest.fixture @pytest.fixture
def mock_event_queue(): def event_queue() -> IEventQueue:
return MagicMock(spec=IEventQueue) return PyPubSubEventQueue(Publisher())
def patch_ssh_handler(ssh_creds, monkeypatch): def patch_ssh_handler(ssh_creds, monkeypatch):
@ -27,17 +28,15 @@ def patch_ssh_handler(ssh_creds, monkeypatch):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])] "ssh_creds", [([{"name": "", "home_dir": "", "public_key": None, "private_key": None}]), ([])]
) )
def test_ssh_credentials_empty_results( def test_ssh_credentials_empty_results(monkeypatch, ssh_creds, patch_telemetry_messenger):
monkeypatch, ssh_creds, patch_telemetry_messenger, mock_event_queue
):
patch_ssh_handler(ssh_creds, monkeypatch) patch_ssh_handler(ssh_creds, monkeypatch)
collected = SSHCredentialCollector( collected = SSHCredentialCollector(
patch_telemetry_messenger, mock_event_queue patch_telemetry_messenger, MagicMock(spec=IEventQueue)
).collect_credentials() ).collect_credentials()
assert not collected assert not collected
def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_event_queue): def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger):
ssh_creds = [ ssh_creds = [
{ {
@ -78,6 +77,6 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger, mock_ev
Credentials(identity=None, secret=ssh_keypair3), Credentials(identity=None, secret=ssh_keypair3),
] ]
collected = SSHCredentialCollector( collected = SSHCredentialCollector(
patch_telemetry_messenger, mock_event_queue patch_telemetry_messenger, MagicMock(spec=IEventQueue)
).collect_credentials() ).collect_credentials()
assert expected == collected assert expected == collected