Agent, Common: Refactor pydantic credentials to use SecretStr

This commit is contained in:
vakaris_zilius 2022-09-06 14:29:42 +00:00
parent 17e3b3d205
commit ece4d9383e
16 changed files with 132 additions and 41 deletions

View File

@ -2,6 +2,8 @@ from __future__ import annotations
from typing import Optional, Union from typing import Optional, Union
from pydantic import SecretBytes, SecretStr
from ..base_models import InfectionMonkeyBaseModel from ..base_models import InfectionMonkeyBaseModel
from . import LMHash, NTHash, Password, SSHKeypair, Username from . import LMHash, NTHash, Password, SSHKeypair, Username
@ -9,6 +11,25 @@ Secret = Union[Password, LMHash, NTHash, SSHKeypair]
Identity = Username Identity = Username
def get_plain_text(secret: Union[SecretStr, SecretBytes, None, str]) -> Optional[str]:
if secret:
return secret.get_secret_value()
else:
return secret
class Credentials(InfectionMonkeyBaseModel): class Credentials(InfectionMonkeyBaseModel):
"""Represents a credential pair (some form of identity and a secret)"""
identity: Optional[Identity] identity: Optional[Identity]
"""Identity part of credentials, like a username or an email"""
secret: Optional[Secret] secret: Optional[Secret]
"""Secret part of credentials, like a password or a hash"""
class Config:
json_encoders = {
# This makes secrets dumpable to json, but not loggable
SecretStr: lambda v: v.get_secret_value() if v else None,
SecretBytes: lambda v: v.get_secret_value() if v else None,
}

View File

@ -1,16 +1,16 @@
import re import re
from pydantic import validator from pydantic import SecretStr, validator
from pydantic.main import BaseModel
from ..base_models import InfectionMonkeyBaseModel
from .validators import ntlm_hash_regex from .validators import ntlm_hash_regex
class LMHash(BaseModel): class LMHash(InfectionMonkeyBaseModel):
lm_hash: str lm_hash: SecretStr
@validator("lm_hash") @validator("lm_hash")
def validate_hash_format(cls, nt_hash): def validate_hash_format(cls, lm_hash):
if not re.match(ntlm_hash_regex, nt_hash): if not re.match(ntlm_hash_regex, lm_hash.get_secret_value()):
raise ValueError(f"Invalid LM hash provided: {nt_hash}") raise ValueError("Invalid LM hash provided")
return nt_hash return lm_hash

View File

@ -1,16 +1,16 @@
import re import re
from pydantic import validator from pydantic import SecretStr, validator
from ..base_models import InfectionMonkeyBaseModel from ..base_models import InfectionMonkeyBaseModel
from .validators import ntlm_hash_regex from .validators import ntlm_hash_regex
class NTHash(InfectionMonkeyBaseModel): class NTHash(InfectionMonkeyBaseModel):
nt_hash: str nt_hash: SecretStr
@validator("nt_hash") @validator("nt_hash")
def validate_hash_format(cls, nt_hash): def validate_hash_format(cls, nt_hash):
if not re.match(ntlm_hash_regex, nt_hash): if not re.match(ntlm_hash_regex, nt_hash.get_secret_value()):
raise ValueError(f"Invalid NT hash provided: {nt_hash}") raise ValueError("Invalid NT hash provided")
return nt_hash return nt_hash

View File

@ -1,5 +1,7 @@
from pydantic import SecretStr
from ..base_models import InfectionMonkeyBaseModel from ..base_models import InfectionMonkeyBaseModel
class Password(InfectionMonkeyBaseModel): class Password(InfectionMonkeyBaseModel):
password: str password: SecretStr

View File

@ -1,6 +1,8 @@
from pydantic import SecretStr
from ..base_models import InfectionMonkeyBaseModel from ..base_models import InfectionMonkeyBaseModel
class SSHKeypair(InfectionMonkeyBaseModel): class SSHKeypair(InfectionMonkeyBaseModel):
private_key: str private_key: SecretStr
public_key: str public_key: str

View File

@ -6,6 +6,7 @@ from typing import Sequence, Tuple
import pymssql import pymssql
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.credentials.credentials import get_plain_text
from common.utils.exceptions import FailedExploitationError from common.utils.exceptions import FailedExploitationError
from infection_monkey.exploit.HostExploiter import HostExploiter from infection_monkey.exploit.HostExploiter import HostExploiter
from infection_monkey.exploit.tools.helpers import get_agent_dst_path from infection_monkey.exploit.tools.helpers import get_agent_dst_path
@ -111,7 +112,7 @@ class MSSQLExploiter(HostExploiter):
conn = pymssql.connect( conn = pymssql.connect(
host, host,
user, user,
password, get_plain_text(password),
port=port, port=port,
login_timeout=self.LOGIN_TIMEOUT, login_timeout=self.LOGIN_TIMEOUT,
timeout=self.QUERY_TIMEOUT, timeout=self.QUERY_TIMEOUT,

View File

@ -11,6 +11,7 @@ from pypsrp.powershell import PowerShell, RunspacePool
from typing_extensions import Protocol from typing_extensions import Protocol
from urllib3 import connectionpool from urllib3 import connectionpool
from common.credentials.credentials import get_plain_text
from infection_monkey.exploit.powershell_utils.auth_options import AuthOptions from infection_monkey.exploit.powershell_utils.auth_options import AuthOptions
from infection_monkey.exploit.powershell_utils.credentials import Credentials, SecretType from infection_monkey.exploit.powershell_utils.credentials import Credentials, SecretType
@ -42,14 +43,16 @@ def format_password(credentials: Credentials) -> Optional[str]:
if credentials.secret_type == SecretType.CACHED: if credentials.secret_type == SecretType.CACHED:
return None return None
plaintext_secret = get_plain_text(credentials.secret)
if credentials.secret_type == SecretType.PASSWORD: if credentials.secret_type == SecretType.PASSWORD:
return credentials.secret return plaintext_secret
if credentials.secret_type == SecretType.LM_HASH: if credentials.secret_type == SecretType.LM_HASH:
return f"{credentials.secret}:00000000000000000000000000000000" return f"{plaintext_secret}:00000000000000000000000000000000"
if credentials.secret_type == SecretType.NT_HASH: if credentials.secret_type == SecretType.NT_HASH:
return f"00000000000000000000000000000000:{credentials.secret}" return f"00000000000000000000000000000000:{plaintext_secret}"
raise ValueError(f"Unknown secret type {credentials.secret_type}") raise ValueError(f"Unknown secret type {credentials.secret_type}")

View File

@ -4,6 +4,7 @@ from impacket.dcerpc.v5 import scmr, transport
from impacket.dcerpc.v5.scmr import DCERPCSessionError from impacket.dcerpc.v5.scmr import DCERPCSessionError
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.credentials.credentials import get_plain_text
from common.utils.attack_utils import ScanStatus, UsageEnum from common.utils.attack_utils import ScanStatus, UsageEnum
from infection_monkey.exploit.HostExploiter import HostExploiter from infection_monkey.exploit.HostExploiter import HostExploiter
from infection_monkey.exploit.tools.helpers import get_agent_dst_path from infection_monkey.exploit.tools.helpers import get_agent_dst_path
@ -107,7 +108,14 @@ class SMBExploiter(HostExploiter):
rpctransport.setRemoteHost(self.host.ip_addr) rpctransport.setRemoteHost(self.host.ip_addr)
if hasattr(rpctransport, "set_credentials"): if hasattr(rpctransport, "set_credentials"):
# This method exists only for selected protocol sequences. # This method exists only for selected protocol sequences.
rpctransport.set_credentials(user, password, "", lm_hash, ntlm_hash, None) rpctransport.set_credentials(
user,
get_plain_text(password),
"",
get_plain_text(lm_hash),
get_plain_text(ntlm_hash),
None,
)
rpctransport.set_kerberos(SMBExploiter.USE_KERBEROS) rpctransport.set_kerberos(SMBExploiter.USE_KERBEROS)
scmr_rpc = rpctransport.get_dce_rpc() scmr_rpc = rpctransport.get_dce_rpc()
@ -116,7 +124,8 @@ class SMBExploiter(HostExploiter):
scmr_rpc.connect() scmr_rpc.connect()
except Exception as exc: except Exception as exc:
logger.debug( logger.debug(
f"Can't connect to SCM on exploited machine {self.host}, port {port} : {exc}" f"Can't connect to SCM on exploited machine {self.host}, port {port} : "
f"{exc}"
) )
continue continue

View File

@ -5,6 +5,7 @@ from pathlib import PurePath
import paramiko import paramiko
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.credentials.credentials import get_plain_text
from common.utils import Timer from common.utils import Timer
from common.utils.attack_utils import ScanStatus from common.utils.attack_utils import ScanStatus
from common.utils.exceptions import FailedExploitationError from common.utils.exceptions import FailedExploitationError
@ -59,7 +60,7 @@ class SSHExploiter(HostExploiter):
for user, ssh_key_pair in ssh_key_pairs_iterator: for user, ssh_key_pair in ssh_key_pairs_iterator:
# Creating file-like private key for paramiko # Creating file-like private key for paramiko
pkey = io.StringIO(ssh_key_pair["private_key"]) pkey = io.StringIO(get_plain_text(ssh_key_pair["private_key"]))
ssh_string = "%s@%s" % (user, self.host.ip_addr) ssh_string = "%s@%s" % (user, self.host.ip_addr)
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()

View File

@ -8,7 +8,9 @@ from typing import Optional
from impacket.dcerpc.v5 import srvs, transport from impacket.dcerpc.v5 import srvs, transport
from impacket.smb3structs import SMB2_DIALECT_002, SMB2_DIALECT_21 from impacket.smb3structs import SMB2_DIALECT_002, SMB2_DIALECT_21
from impacket.smbconnection import SMB_DIALECT, SMBConnection from impacket.smbconnection import SMB_DIALECT, SMBConnection
from pydantic import SecretStr
from common.credentials.credentials import get_plain_text
from common.utils.attack_utils import ScanStatus from common.utils.attack_utils import ScanStatus
from infection_monkey.network.tools import get_interface_to_target from infection_monkey.network.tools import get_interface_to_target
from infection_monkey.telemetry.attack.t1105_telem import T1105Telem from infection_monkey.telemetry.attack.t1105_telem import T1105Telem
@ -26,8 +28,8 @@ class SmbTools(object):
host, host,
agent_file: BytesIO, agent_file: BytesIO,
dst_path: PurePath, dst_path: PurePath,
username, username: str,
password, password: SecretStr,
lm_hash="", lm_hash="",
ntlm_hash="", ntlm_hash="",
timeout=30, timeout=30,
@ -190,7 +192,9 @@ class SmbTools(object):
return remote_full_path return remote_full_path
@staticmethod @staticmethod
def new_smb_connection(host, username, password, lm_hash="", ntlm_hash="", timeout=30): def new_smb_connection(
host, username: str, password: SecretStr, lm_hash="", ntlm_hash="", timeout=30
):
try: try:
smb = SMBConnection(host.ip_addr, host.ip_addr, sess_port=445) smb = SMBConnection(host.ip_addr, host.ip_addr, sess_port=445)
except Exception as exc: except Exception as exc:
@ -212,7 +216,13 @@ class SmbTools(object):
# we know this should work because the WMI connection worked # we know this should work because the WMI connection worked
try: try:
smb.login(username, password, "", lm_hash, ntlm_hash) smb.login(
username,
get_plain_text(password),
"",
get_plain_text(lm_hash),
get_plain_text(ntlm_hash),
)
except Exception as exc: except Exception as exc:
logger.error(f'Error while logging into {host} using user "{username}": {exc}') logger.error(f'Error while logging into {host} using user "{username}": {exc}')
return None, dialect return None, dialect

View File

@ -5,6 +5,7 @@ import traceback
from impacket.dcerpc.v5.rpcrt import DCERPCException from impacket.dcerpc.v5.rpcrt import DCERPCException
from common.credentials.credentials import get_plain_text
from infection_monkey.exploit.HostExploiter import HostExploiter from infection_monkey.exploit.HostExploiter import HostExploiter
from infection_monkey.exploit.tools.helpers import get_agent_dst_path from infection_monkey.exploit.tools.helpers import get_agent_dst_path
from infection_monkey.exploit.tools.smb_tools import SmbTools from infection_monkey.exploit.tools.smb_tools import SmbTools
@ -44,7 +45,14 @@ class WmiExploiter(HostExploiter):
wmi_connection = WmiTools.WmiConnection() wmi_connection = WmiTools.WmiConnection()
try: try:
wmi_connection.connect(self.host, user, password, None, lm_hash, ntlm_hash) wmi_connection.connect(
self.host,
user,
get_plain_text(password),
None,
get_plain_text(lm_hash),
get_plain_text(ntlm_hash),
)
except AccessDeniedException: except AccessDeniedException:
self.report_login_attempt(False, user, password, lm_hash, ntlm_hash) self.report_login_attempt(False, user, password, lm_hash, ntlm_hash)
logger.debug(f"Failed connecting to {self.host} using WMI") logger.debug(f"Failed connecting to {self.host} using WMI")

View File

@ -299,8 +299,8 @@ class ZerologonExploiter(HostExploiter):
self, user: str, lmhash: str, nthash: str self, user: str, lmhash: str, nthash: str
) -> None: ) -> None:
extracted_credentials = [ extracted_credentials = [
Credentials(Username(user), LMHash(lmhash)), Credentials(identity=Username(username=user), secret=LMHash(lm_hash=lmhash)),
Credentials(Username(user), NTHash(nthash)), Credentials(identity=Username(username=user), secret=NTHash(nt_hash=nthash)),
] ]
self.telemetry_messenger.send_telemetry(CredentialsTelem(extracted_credentials)) self.telemetry_messenger.send_telemetry(CredentialsTelem(extracted_credentials))

View File

@ -1,16 +1,18 @@
from itertools import product from itertools import product
from pydantic import SecretStr
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
USERNAME = "m0nk3y_user" USERNAME = "m0nk3y_user"
SPECIAL_USERNAME = "m0nk3y.user" SPECIAL_USERNAME = "m0nk3y.user"
NT_HASH = "C1C58F96CDF212B50837BC11A00BE47C" NT_HASH = SecretStr("C1C58F96CDF212B50837BC11A00BE47C")
LM_HASH = "299BD128C1101FD6299BD128C1101FD6" LM_HASH = SecretStr("299BD128C1101FD6299BD128C1101FD6")
PASSWORD_1 = "trytostealthis" PASSWORD_1 = SecretStr("trytostealthis")
PASSWORD_2 = "password!" PASSWORD_2 = SecretStr("password!")
PASSWORD_3 = "rubberbabybuggybumpers" PASSWORD_3 = SecretStr("rubberbabybuggybumpers")
PUBLIC_KEY = "MY_PUBLIC_KEY" PUBLIC_KEY = "MY_PUBLIC_KEY"
PRIVATE_KEY = "MY_PRIVATE_KEY" PRIVATE_KEY = SecretStr("MY_PRIVATE_KEY")
IDENTITIES = [Username(username=USERNAME), None, Username(username=SPECIAL_USERNAME)] IDENTITIES = [Username(username=USERNAME), None, Username(username=SPECIAL_USERNAME)]
IDENTITY_DICTS = [{"username": USERNAME}, None] IDENTITY_DICTS = [{"username": USERNAME}, None]

View File

@ -1,6 +1,11 @@
import logging
import pytest import pytest
from pydantic import SecretBytes
from pydantic.types import SecretStr
from tests.data_for_tests.propagation_credentials import CREDENTIALS, CREDENTIALS_DICTS from tests.data_for_tests.propagation_credentials import CREDENTIALS, CREDENTIALS_DICTS
from common.base_models import InfectionMonkeyBaseModel
from common.credentials import Credentials from common.credentials import Credentials
@ -12,3 +17,28 @@ def test_credentials_serialization_json(credentials, expected_credentials_dict):
deserialized_credentials = Credentials.parse_raw(serialized_credentials) deserialized_credentials = Credentials.parse_raw(serialized_credentials)
assert credentials == deserialized_credentials assert credentials == deserialized_credentials
logger = logging.getLogger()
logger.level = logging.DEBUG
def test_credentials_secrets_not_logged(caplog):
class TestSecret(InfectionMonkeyBaseModel):
some_secret: SecretStr
some_secret_in_bytes: SecretBytes
class TestCredentials(Credentials):
secret: TestSecret
sensitive = "super_secret"
creds = TestCredentials(
identity=None,
secret=TestSecret(some_secret=sensitive, some_secret_in_bytes=sensitive.encode()),
)
logging.getLogger().info(
f"{creds.secret.some_secret} and" f" {creds.secret.some_secret_in_bytes}"
)
assert sensitive not in caplog.text

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from pydantic import SecretStr
from tests.data_for_tests.propagation_credentials import ( from tests.data_for_tests.propagation_credentials import (
CREDENTIALS, CREDENTIALS,
LM_HASH, LM_HASH,
@ -30,14 +31,14 @@ EMPTY_CHANNEL_CREDENTIALS = []
STOLEN_USERNAME_1 = "user1" STOLEN_USERNAME_1 = "user1"
STOLEN_USERNAME_2 = "user2" STOLEN_USERNAME_2 = "user2"
STOLEN_USERNAME_3 = "user3" STOLEN_USERNAME_3 = "user3"
STOLEN_PASSWORD_1 = "abcdefg" STOLEN_PASSWORD_1 = SecretStr("abcdefg")
STOLEN_PASSWORD_2 = "super_secret" STOLEN_PASSWORD_2 = SecretStr("super_secret")
STOLEN_PUBLIC_KEY_1 = "some_public_key_1" STOLEN_PUBLIC_KEY_1 = "some_public_key_1"
STOLEN_PUBLIC_KEY_2 = "some_public_key_2" STOLEN_PUBLIC_KEY_2 = "some_public_key_2"
STOLEN_LM_HASH = "AAD3B435B51404EEAAD3B435B51404EE" STOLEN_LM_HASH = SecretStr("AAD3B435B51404EEAAD3B435B51404EE")
STOLEN_NT_HASH = "C0172DFF622FE29B5327CB79DC12D24C" STOLEN_NT_HASH = SecretStr("C0172DFF622FE29B5327CB79DC12D24C")
STOLEN_PRIVATE_KEY_1 = "some_private_key_1" STOLEN_PRIVATE_KEY_1 = SecretStr("some_private_key_1")
STOLEN_PRIVATE_KEY_2 = "some_private_key_2" STOLEN_PRIVATE_KEY_2 = SecretStr("some_private_key_2")
STOLEN_CREDENTIALS = [ STOLEN_CREDENTIALS = [
Credentials( Credentials(
identity=Username(username=STOLEN_USERNAME_1), identity=Username(username=STOLEN_USERNAME_1),

View File

@ -7,7 +7,7 @@ from common.agent_configuration.agent_sub_configurations import (
CustomPBAConfiguration, CustomPBAConfiguration,
ScanTargetConfiguration, ScanTargetConfiguration,
) )
from common.credentials import LMHash, NTHash from common.credentials import Credentials, LMHash, NTHash
from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory from infection_monkey.exploit.log4shell_utils.ldap_server import LDAPServerFactory
from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue from monkey_island.cc.event_queue import IslandEventTopic, PyPubSubIslandEventQueue
from monkey_island.cc.models import Report from monkey_island.cc.models import Report
@ -167,6 +167,7 @@ strict_slashes # unused attribute (monkey/monkey_island/cc/app.py:96)
post_breach_actions # unused variable (monkey\infection_monkey\config.py:95) post_breach_actions # unused variable (monkey\infection_monkey\config.py:95)
LMHash.validate_hash_format LMHash.validate_hash_format
NTHash.validate_hash_format NTHash.validate_hash_format
Credentials.Config.json_encoders
# Deployments # Deployments
DEVELOP # unused variable (monkey/monkey/monkey_island/cc/deployment.py:5) DEVELOP # unused variable (monkey/monkey/monkey_island/cc/deployment.py:5)