Agent, Island: Use pydantic credentials and methods

Since the interface of credential serialization changed, code was modified to use the new interface
This commit is contained in:
vakarisz 2022-09-01 14:10:54 +03:00 committed by vakaris_zilius
parent f018b85f56
commit d73cbee591
13 changed files with 92 additions and 82 deletions

View File

@ -30,7 +30,7 @@ class MonkeyIslandClient(object):
def get_propagation_credentials(self) -> Sequence[Credentials]:
response = self.requests.get("api/propagation-credentials")
return [Credentials.from_mapping(credentials) for credentials in response.json()]
return [Credentials(**credentials) for credentials in response.json()]
@avoid_race_condition
def import_config(self, test_configuration: TestConfiguration):
@ -61,7 +61,7 @@ class MonkeyIslandClient(object):
@avoid_race_condition
def _import_credentials(self, propagation_credentials: Credentials):
serialized_propagation_credentials = [
Credentials.to_mapping(credentials) for credentials in propagation_credentials
Credentials.dict(credentials) for credentials in propagation_credentials
]
response = self.requests.put_json(
"/api/propagation-credentials/configured-credentials",

View File

@ -55,22 +55,22 @@ class MimikatzCredentialCollector(ICredentialCollector):
identity = None
if wc.username:
identity = Username(wc.username)
identity = Username(username=wc.username)
if wc.password:
password = Password(wc.password)
credentials.append(Credentials(identity, password))
password = Password(password=wc.password)
credentials.append(Credentials(identity=identity, secret=password))
if wc.lm_hash:
lm_hash = LMHash(lm_hash=wc.lm_hash)
credentials.append(Credentials(identity, lm_hash))
credentials.append(Credentials(identity=identity, secret=lm_hash))
if wc.ntlm_hash:
ntlm_hash = NTHash(nt_hash=wc.ntlm_hash)
credentials.append(Credentials(identity, ntlm_hash))
credentials.append(Credentials(identity=identity, secret=ntlm_hash))
if len(credentials) == 0 and identity is not None:
credentials.append(Credentials(identity, None))
credentials.append(Credentials(identity=identity, secret=None))
return credentials

View File

@ -149,7 +149,7 @@ def to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
secret = None
if info.get("name", ""):
identity = Username(info["name"])
identity = Username(username=info["name"])
ssh_keypair = {}
for key in ["public_key", "private_key"]:
@ -158,11 +158,12 @@ def to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:
if len(ssh_keypair):
secret = SSHKeypair(
ssh_keypair.get("private_key", ""), ssh_keypair.get("public_key", "")
private_key=ssh_keypair.get("private_key", ""),
public_key=ssh_keypair.get("public_key", ""),
)
if any([identity, secret]):
ssh_credentials.append(Credentials(identity, secret))
ssh_credentials.append(Credentials(identity=identity, secret=secret))
return ssh_credentials

View File

@ -1,4 +1,3 @@
import json
from typing import Iterable
from common.common_consts.telem_categories import TelemCategoryEnum
@ -24,4 +23,4 @@ class CredentialsTelem(BaseTelem):
super().send(log_data=False)
def get_data(self):
return [json.loads(Credentials.to_json(c)) for c in self._credentials]
return [c.dict(simplify=True) for c in self._credentials]

View File

@ -59,7 +59,7 @@ class MongoCredentialsRepository(ICredentialsRepository):
for encrypted_credentials in list_collection_result:
del encrypted_credentials[MONGO_OBJECT_ID_KEY]
plaintext_credentials = self._decrypt_credentials_mapping(encrypted_credentials)
collection_result.append(Credentials.from_mapping(plaintext_credentials))
collection_result.append(Credentials(**plaintext_credentials))
return collection_result
except Exception as err:
@ -68,7 +68,7 @@ class MongoCredentialsRepository(ICredentialsRepository):
def _save_credentials_to_collection(self, credentials: Sequence[Credentials], collection):
try:
for c in credentials:
encrypted_credentials = self._encrypt_credentials_mapping(Credentials.to_mapping(c))
encrypted_credentials = self._encrypt_credentials_mapping(c.dict())
collection.insert_one(encrypted_credentials)
except Exception as err:
raise StorageError(err)

View File

@ -29,7 +29,7 @@ class PropagationCredentials(AbstractResource):
return propagation_credentials, HTTPStatus.OK
def put(self, collection=None):
credentials = [Credentials.from_mapping(c) for c in request.json]
credentials = [Credentials.parse_raw(c) for c in request.json]
if collection == _configured_collection:
self._credentials_repository.remove_configured_credentials()
self._credentials_repository.save_configured_credentials(credentials)

View File

@ -41,9 +41,9 @@ def test_pypykatz_result_parsing(monkeypatch):
win_creds = [WindowsCredentials(username="user", password="secret", ntlm_hash="", lm_hash="")]
patch_pypykatz(win_creds, monkeypatch)
username = Username("user")
password = Password("secret")
expected_credentials = Credentials(username, password)
username = Username(username="user")
password = Password(password="secret")
expected_credentials = Credentials(identity=username, secret=password)
collected_credentials = collect_credentials()
assert len(collected_credentials) == 1
@ -70,10 +70,13 @@ def test_pypykatz_result_parsing_defaults(monkeypatch):
patch_pypykatz(win_creds, monkeypatch)
# Expected credentials
username = Username("user2")
password = Password("secret2")
lm_hash = LMHash("0182BD0BD4444BF8FC83B5D9042EED2E")
expected_credentials = [Credentials(username, password), Credentials(username, lm_hash)]
username = Username(username="user2")
password = Password(password="secret2")
lm_hash = LMHash(lm_hash="0182BD0BD4444BF8FC83B5D9042EED2E")
expected_credentials = [
Credentials(identity=username, secret=password),
Credentials(identity=username, secret=lm_hash),
]
collected_credentials = collect_credentials()
assert len(collected_credentials) == 2
@ -91,9 +94,12 @@ def test_pypykatz_result_parsing_no_identities(monkeypatch):
]
patch_pypykatz(win_creds, monkeypatch)
lm_hash = LMHash("0182BD0BD4444BF8FC83B5D9042EED2E")
nt_hash = NTHash("E9F85516721DDC218359AD5280DB4450")
expected_credentials = [Credentials(None, lm_hash), Credentials(None, nt_hash)]
lm_hash = LMHash(lm_hash="0182BD0BD4444BF8FC83B5D9042EED2E")
nt_hash = NTHash(nt_hash="E9F85516721DDC218359AD5280DB4450")
expected_credentials = [
Credentials(identity=None, secret=lm_hash),
Credentials(identity=None, secret=nt_hash),
]
collected_credentials = collect_credentials()
assert len(collected_credentials) == 2
@ -112,7 +118,7 @@ def test_pypykatz_result_parsing_no_secrets(monkeypatch):
]
patch_pypykatz(win_creds, monkeypatch)
expected_credentials = [Credentials(Username(username), None)]
expected_credentials = [Credentials(identity=Username(username=username), secret=None)]
collected_credentials = collect_credentials()
assert len(collected_credentials) == 1

View File

@ -31,7 +31,6 @@ def test_ssh_credentials_empty_results(monkeypatch, ssh_creds, patch_telemetry_m
def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger):
ssh_creds = [
{
"name": "ubuntu",
@ -56,13 +55,15 @@ def test_ssh_info_result_parsing(monkeypatch, patch_telemetry_messenger):
patch_ssh_handler(ssh_creds, monkeypatch)
# Expected credentials
username = Username("ubuntu")
username2 = Username("mcus")
username3 = Username("guest")
username = Username(username="ubuntu")
username2 = Username(username="mcus")
username3 = Username(username="guest")
ssh_keypair1 = SSHKeypair("ExtremelyGoodPrivateKey", "SomePublicKeyUbuntu")
ssh_keypair2 = SSHKeypair("", "AnotherPublicKey")
ssh_keypair3 = SSHKeypair("PrivKey", "PubKey")
ssh_keypair1 = SSHKeypair(
private_key="ExtremelyGoodPrivateKey", public_key="SomePublicKeyUbuntu"
)
ssh_keypair2 = SSHKeypair(private_key="", public_key="AnotherPublicKey")
ssh_keypair3 = SSHKeypair(private_key="PrivKey", public_key="PubKey")
expected = [
Credentials(identity=username, secret=ssh_keypair1),

View File

@ -8,7 +8,11 @@ from infection_monkey.credential_repository import (
add_credentials_from_event_to_propagation_credentials_repository,
)
credentials = [Credentials(identity=Username("test_username"), secret=Password("some_password"))]
credentials = [
Credentials(
identity=Username(username="test_username"), secret=Password(password="some_password")
)
]
credentials_stolen_event = CredentialsStolenEvent(
source=UUID("f811ad00-5a68-4437-bd51-7b5cc1768ad5"),

View File

@ -2,13 +2,13 @@ from unittest.mock import MagicMock
import pytest
from tests.data_for_tests.propagation_credentials import (
CREDENTIALS,
LM_HASH,
NT_HASH,
PASSWORD_1,
PASSWORD_2,
PASSWORD_3,
PRIVATE_KEY,
PROPAGATION_CREDENTIALS,
PUBLIC_KEY,
SPECIAL_USERNAME,
USERNAME,
@ -17,7 +17,6 @@ from tests.data_for_tests.propagation_credentials import (
from common.credentials import Credentials, LMHash, NTHash, Password, SSHKeypair, Username
from infection_monkey.credential_repository import AggregatingPropagationCredentialsRepository
CONTROL_CHANNEL_CREDENTIALS = PROPAGATION_CREDENTIALS
TRANSFORMED_CONTROL_CHANNEL_CREDENTIALS = {
"exploit_user_list": {USERNAME, SPECIAL_USERNAME},
"exploit_password_list": {PASSWORD_1, PASSWORD_2, PASSWORD_3},
@ -41,26 +40,32 @@ STOLEN_PRIVATE_KEY_1 = "some_private_key_1"
STOLEN_PRIVATE_KEY_2 = "some_private_key_2"
STOLEN_CREDENTIALS = [
Credentials(
identity=Username(STOLEN_USERNAME_1),
secret=Password(PASSWORD_1),
identity=Username(username=STOLEN_USERNAME_1),
secret=Password(password=PASSWORD_1),
),
Credentials(identity=Username(STOLEN_USERNAME_1), secret=Password(STOLEN_PASSWORD_1)),
Credentials(
identity=Username(STOLEN_USERNAME_2),
identity=Username(username=STOLEN_USERNAME_1), secret=Password(password=STOLEN_PASSWORD_1)
),
Credentials(
identity=Username(username=STOLEN_USERNAME_2),
secret=SSHKeypair(public_key=STOLEN_PUBLIC_KEY_1, private_key=STOLEN_PRIVATE_KEY_1),
),
Credentials(
identity=None,
secret=Password(STOLEN_PASSWORD_2),
secret=Password(password=STOLEN_PASSWORD_2),
),
Credentials(identity=Username(STOLEN_USERNAME_2), secret=LMHash(STOLEN_LM_HASH)),
Credentials(identity=Username(STOLEN_USERNAME_2), secret=NTHash(STOLEN_NT_HASH)),
Credentials(identity=Username(STOLEN_USERNAME_3), secret=None),
Credentials(
identity=Username(username=STOLEN_USERNAME_2), secret=LMHash(lm_hash=STOLEN_LM_HASH)
),
Credentials(
identity=Username(username=STOLEN_USERNAME_2), secret=NTHash(nt_hash=STOLEN_NT_HASH)
),
Credentials(identity=Username(username=STOLEN_USERNAME_3), secret=None),
]
STOLEN_SSH_KEYS_CREDENTIALS = [
Credentials(
identity=Username(USERNAME),
identity=Username(username=USERNAME),
secret=SSHKeypair(public_key=STOLEN_PUBLIC_KEY_2, private_key=STOLEN_PRIVATE_KEY_2),
)
]
@ -69,7 +74,7 @@ STOLEN_SSH_KEYS_CREDENTIALS = [
@pytest.fixture
def aggregating_credentials_repository() -> AggregatingPropagationCredentialsRepository:
control_channel = MagicMock()
control_channel.get_credentials_for_propagation.return_value = CONTROL_CHANNEL_CREDENTIALS
control_channel.get_credentials_for_propagation.return_value = CREDENTIALS
return AggregatingPropagationCredentialsRepository(control_channel)

View File

@ -13,13 +13,12 @@ PRIVATE_KEY = "priv_key"
@pytest.fixture
def credentials_for_test():
return Credentials(Username(USERNAME), Password(PASSWORD))
return Credentials(identity=Username(username=USERNAME), secret=Password(password=PASSWORD))
def test_credential_telem_send(spy_send_telemetry, credentials_for_test):
expected_data = [Credentials.to_mapping(credentials_for_test)]
expected_data = [credentials_for_test.dict(simplify=True)]
telem = CredentialsTelem([credentials_for_test])
telem.send()

View File

@ -6,14 +6,14 @@ import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from tests.data_for_tests.propagation_credentials import PROPAGATION_CREDENTIALS
from tests.data_for_tests.propagation_credentials import CREDENTIALS
from common.credentials import Credentials
from monkey_island.cc.repository import MongoCredentialsRepository
from monkey_island.cc.server_utils.encryption import ILockableEncryptor
CONFIGURED_CREDENTIALS = PROPAGATION_CREDENTIALS[0:3]
STOLEN_CREDENTIALS = PROPAGATION_CREDENTIALS[3:]
CONFIGURED_CREDENTIALS = CREDENTIALS[0:3]
STOLEN_CREDENTIALS = CREDENTIALS[3:]
def reverse(data: bytes) -> bytes:
@ -59,9 +59,9 @@ def test_mongo_repository_get_all(mongo_repository):
def test_mongo_repository_configured(mongo_repository):
mongo_repository.save_configured_credentials(PROPAGATION_CREDENTIALS)
mongo_repository.save_configured_credentials(CREDENTIALS)
actual_configured_credentials = mongo_repository.get_configured_credentials()
assert actual_configured_credentials == PROPAGATION_CREDENTIALS
assert actual_configured_credentials == CREDENTIALS
mongo_repository.remove_configured_credentials()
actual_configured_credentials = mongo_repository.get_configured_credentials()
@ -82,7 +82,7 @@ def test_mongo_repository_all(mongo_repository):
mongo_repository.save_configured_credentials(CONFIGURED_CREDENTIALS)
mongo_repository.save_stolen_credentials(STOLEN_CREDENTIALS)
actual_credentials = mongo_repository.get_all_credentials()
assert actual_credentials == PROPAGATION_CREDENTIALS
assert actual_credentials == CREDENTIALS
mongo_repository.remove_all_credentials()
@ -91,7 +91,7 @@ def test_mongo_repository_all(mongo_repository):
assert mongo_repository.get_configured_credentials() == []
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
@pytest.mark.parametrize("credentials", CREDENTIALS)
def test_configured_secrets_encrypted(
mongo_repository: MongoCredentialsRepository,
mongo_client: MongoClient,
@ -101,14 +101,14 @@ def test_configured_secrets_encrypted(
check_if_stored_credentials_encrypted(mongo_client, credentials)
@pytest.mark.parametrize("credentials", PROPAGATION_CREDENTIALS)
@pytest.mark.parametrize("credentials", CREDENTIALS)
def test_stolen_secrets_encrypted(mongo_repository, mongo_client, credentials: Credentials):
mongo_repository.save_stolen_credentials([credentials])
check_if_stored_credentials_encrypted(mongo_client, credentials)
def check_if_stored_credentials_encrypted(mongo_client: MongoClient, original_credentials):
original_credentials_mapping = Credentials.to_mapping(original_credentials)
original_credentials_mapping = original_credentials.dict()
raw_credentials = get_all_credentials_in_mongo(mongo_client)
for rc in raw_credentials:

View File

@ -5,15 +5,10 @@ from urllib.parse import urljoin
import pytest
from tests.common import StubDIContainer
from tests.data_for_tests.propagation_credentials import (
LM_HASH_CREDENTIALS,
NT_HASH_CREDENTIALS,
PASSWORD_CREDENTIALS_1,
PASSWORD_CREDENTIALS_2,
)
from tests.data_for_tests.propagation_credentials import LM_HASH, NT_HASH, PASSWORD_1, PASSWORD_2
from tests.monkey_island import InMemoryCredentialsRepository
from common.credentials import Credentials
from common.credentials import Credentials, LMHash, NTHash, Password
from monkey_island.cc.repository import ICredentialsRepository
from monkey_island.cc.resources import PropagationCredentials
from monkey_island.cc.resources.propagation_credentials import (
@ -24,6 +19,10 @@ from monkey_island.cc.resources.propagation_credentials import (
ALL_CREDENTIALS_URL = PropagationCredentials.urls[0]
CONFIGURED_CREDENTIALS_URL = urljoin(ALL_CREDENTIALS_URL + "/", _configured_collection)
STOLEN_CREDENTIALS_URL = urljoin(ALL_CREDENTIALS_URL + "/", _stolen_collection)
CREDENTIALS_1 = Credentials(identity=None, secret=Password(password=PASSWORD_1))
CREDENTIALS_2 = Credentials(identity=None, secret=LMHash(lm_hash=LM_HASH))
CREDENTIALS_3 = Credentials(identity=None, secret=NTHash(nt_hash=NT_HASH))
CREDENTIALS_4 = Credentials(identity=None, secret=Password(password=PASSWORD_2))
@pytest.fixture
@ -42,20 +41,18 @@ def flask_client(build_flask_client, credentials_repository):
def test_propagation_credentials_endpoint_get(flask_client, credentials_repository):
credentials_repository.save_configured_credentials(
[PASSWORD_CREDENTIALS_1, NT_HASH_CREDENTIALS]
)
credentials_repository.save_stolen_credentials([LM_HASH_CREDENTIALS, PASSWORD_CREDENTIALS_2])
credentials_repository.save_configured_credentials([CREDENTIALS_1, CREDENTIALS_2])
credentials_repository.save_stolen_credentials([CREDENTIALS_3, CREDENTIALS_4])
resp = flask_client.get(ALL_CREDENTIALS_URL)
actual_propagation_credentials = [Credentials.from_mapping(creds) for creds in resp.json]
actual_propagation_credentials = [Credentials(**creds) for creds in resp.json]
assert resp.status_code == HTTPStatus.OK
assert len(actual_propagation_credentials) == 4
assert PASSWORD_CREDENTIALS_1 in actual_propagation_credentials
assert LM_HASH_CREDENTIALS in actual_propagation_credentials
assert NT_HASH_CREDENTIALS in actual_propagation_credentials
assert PASSWORD_CREDENTIALS_2 in actual_propagation_credentials
assert CREDENTIALS_1 in actual_propagation_credentials
assert CREDENTIALS_2 in actual_propagation_credentials
assert CREDENTIALS_3 in actual_propagation_credentials
assert CREDENTIALS_4 in actual_propagation_credentials
def pre_populate_repository(
@ -69,24 +66,22 @@ def pre_populate_repository(
@pytest.mark.parametrize("url", [CONFIGURED_CREDENTIALS_URL, STOLEN_CREDENTIALS_URL])
def test_propagation_credentials_endpoint__get_stolen(flask_client, credentials_repository, url):
pre_populate_repository(
url, credentials_repository, [PASSWORD_CREDENTIALS_1, LM_HASH_CREDENTIALS]
)
pre_populate_repository(url, credentials_repository, [CREDENTIALS_1, CREDENTIALS_2])
resp = flask_client.get(url)
actual_propagation_credentials = [Credentials.from_mapping(creds) for creds in resp.json]
actual_propagation_credentials = [Credentials(**creds) for creds in resp.json]
assert resp.status_code == HTTPStatus.OK
assert len(actual_propagation_credentials) == 2
assert actual_propagation_credentials[0] == PASSWORD_CREDENTIALS_1
assert actual_propagation_credentials[1] == LM_HASH_CREDENTIALS
assert actual_propagation_credentials[0].secret.password == PASSWORD_1
assert actual_propagation_credentials[1].secret.lm_hash == LM_HASH
def test_configured_propagation_credentials_endpoint_put(flask_client, credentials_repository):
pre_populate_repository(
CONFIGURED_CREDENTIALS_URL,
credentials_repository,
[PASSWORD_CREDENTIALS_1, LM_HASH_CREDENTIALS],
[CREDENTIALS_1, CREDENTIALS_2],
)
resp = flask_client.put(CONFIGURED_CREDENTIALS_URL, json=[])
assert resp.status_code == HTTPStatus.NO_CONTENT