Island: Simplify agent encryption calls

This commit is contained in:
Kekoa Kaaikala 2022-09-20 13:59:54 +00:00
parent 644f3628a5
commit c0869aebba
3 changed files with 10 additions and 23 deletions

View File

@ -1,23 +1,17 @@
import json import json
from typing import Callable, Iterable from typing import Callable
from common.agent_event_serializers import JSONSerializable from common.agent_event_serializers import JSONSerializable
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
ENCRYPTED_PREFIX = "encrypted_" ENCRYPTED_PREFIX = "encrypted_"
ABSTRACT_AGENT_EVENT_FIELDS = vars(AbstractAgentEvent)["__fields__"].keys()
SERIALIZED_EVENT_FIELDS = set(ABSTRACT_AGENT_EVENT_FIELDS) | set(["type"])
def get_fields_to_encrypt(event: AbstractAgentEvent):
"""
Get the fields of the event that are not part of the base AbstractAgentEvent.
"""
return set(vars(AbstractAgentEvent)["__fields__"].keys()) ^ set(event.dict().keys())
def encrypt_event( def encrypt_event(
encrypt: Callable[[bytes], bytes], encrypt: Callable[[bytes], bytes],
event_data: JSONSerializable, event_data: JSONSerializable,
fields: Iterable[str] = [],
) -> JSONSerializable: ) -> JSONSerializable:
""" """
Encrypt a serialized AbstractAgentEvent Encrypt a serialized AbstractAgentEvent
@ -27,7 +21,6 @@ def encrypt_event(
:param encrypt: Callable used to encrypt data :param encrypt: Callable used to encrypt data
:param event_data: Serialized event to encrypt :param event_data: Serialized event to encrypt
:param fields: Fields to encrypt
:return: Serialized event with the fields encrypted :return: Serialized event with the fields encrypted
:raises TypeError: If the serialized data is not a dict :raises TypeError: If the serialized data is not a dict
""" """
@ -35,7 +28,8 @@ def encrypt_event(
raise TypeError("Event encryption only supported for dict") raise TypeError("Event encryption only supported for dict")
data = event_data.copy() data = event_data.copy()
for field in fields: fields_to_encrypt = SERIALIZED_EVENT_FIELDS ^ set(event_data.keys())
for field in fields_to_encrypt:
data[ENCRYPTED_PREFIX + field] = str( data[ENCRYPTED_PREFIX + field] = str(
encrypt(json.dumps(event_data[field]).encode()), "utf-8" encrypt(json.dumps(event_data[field]).encode()), "utf-8"
) )

View File

@ -9,7 +9,7 @@ from monkey_island.cc.repository import IAgentEventRepository
from monkey_island.cc.server_utils.encryption import ILockableEncryptor from monkey_island.cc.server_utils.encryption import ILockableEncryptor
from . import RemovalError, RetrievalError, StorageError from . import RemovalError, RetrievalError, StorageError
from .agent_event_encryption import decrypt_event, encrypt_event, get_fields_to_encrypt from .agent_event_encryption import decrypt_event, encrypt_event
from .consts import MONGO_OBJECT_ID_KEY from .consts import MONGO_OBJECT_ID_KEY
@ -30,8 +30,7 @@ class MongoAgentEventRepository(IAgentEventRepository):
try: try:
serializer = self._serializers[type(event)] serializer = self._serializers[type(event)]
serialized_event = serializer.serialize(event) serialized_event = serializer.serialize(event)
fields = get_fields_to_encrypt(event) encrypted_event = encrypt_event(self._encryptor.encrypt, serialized_event)
encrypted_event = encrypt_event(self._encryptor.encrypt, serialized_event, fields)
self._events_collection.insert_one(encrypted_event) self._events_collection.insert_one(encrypted_event)
except Exception as err: except Exception as err:
raise StorageError(f"Error saving event: {err}") raise StorageError(f"Error saving event: {err}")

View File

@ -5,11 +5,7 @@ import pytest
from common.agent_event_serializers import PydanticAgentEventSerializer from common.agent_event_serializers import PydanticAgentEventSerializer
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from monkey_island.cc.repository.agent_event_encryption import ( from monkey_island.cc.repository.agent_event_encryption import decrypt_event, encrypt_event
decrypt_event,
encrypt_event,
get_fields_to_encrypt,
)
from monkey_island.cc.server_utils.encryption import RepositoryEncryptor from monkey_island.cc.server_utils.encryption import RepositoryEncryptor
@ -43,8 +39,7 @@ def serializer():
def test_agent_event_encryption__encrypts(encryptor, serializer): def test_agent_event_encryption__encrypts(encryptor, serializer):
data = serializer.serialize(EVENT) data = serializer.serialize(EVENT)
fields = get_fields_to_encrypt(EVENT) encrypted_data = encrypt_event(encryptor.encrypt, data)
encrypted_data = encrypt_event(encryptor.encrypt, data, fields)
# Encrypted fields have the "encrypted_" prefix # Encrypted fields have the "encrypted_" prefix
assert "encrypted_data" in encrypted_data assert "encrypted_data" in encrypted_data
@ -55,8 +50,7 @@ def test_agent_event_encryption__encrypts(encryptor, serializer):
def test_agent_event_encryption__decrypts(encryptor, serializer): def test_agent_event_encryption__decrypts(encryptor, serializer):
data = serializer.serialize(EVENT) data = serializer.serialize(EVENT)
fields = get_fields_to_encrypt(EVENT) encrypted_data = encrypt_event(encryptor.encrypt, data)
encrypted_data = encrypt_event(encryptor.encrypt, data, fields)
decrypted_data = decrypt_event(encryptor.decrypt, encrypted_data) decrypted_data = decrypt_event(encryptor.decrypt, encrypted_data)
deserialized_event = serializer.deserialize(decrypted_data) deserialized_event = serializer.deserialize(decrypted_data)