diff --git a/monkey/monkey_island/cc/repository/agent_event_encryption.py b/monkey/monkey_island/cc/repository/agent_event_encryption.py index 8904773ed..3e16855fe 100644 --- a/monkey/monkey_island/cc/repository/agent_event_encryption.py +++ b/monkey/monkey_island/cc/repository/agent_event_encryption.py @@ -1,23 +1,17 @@ import json -from typing import Callable, Iterable +from typing import Callable from common.agent_event_serializers import JSONSerializable from common.agent_events import AbstractAgentEvent ENCRYPTED_PREFIX = "encrypted_" - - -def get_fields_to_encrypt(event: AbstractAgentEvent): - """ - Get the fields of the event that are not part of the base AbstractAgentEvent. - """ - return set(vars(AbstractAgentEvent)["__fields__"].keys()) ^ set(event.dict().keys()) +ABSTRACT_AGENT_EVENT_FIELDS = vars(AbstractAgentEvent)["__fields__"].keys() +SERIALIZED_EVENT_FIELDS = set(ABSTRACT_AGENT_EVENT_FIELDS) | set(["type"]) def encrypt_event( encrypt: Callable[[bytes], bytes], event_data: JSONSerializable, - fields: Iterable[str] = [], ) -> JSONSerializable: """ Encrypt a serialized AbstractAgentEvent @@ -27,7 +21,6 @@ def encrypt_event( :param encrypt: Callable used to encrypt data :param event_data: Serialized event to encrypt - :param fields: Fields to encrypt :return: Serialized event with the fields encrypted :raises TypeError: If the serialized data is not a dict """ @@ -35,7 +28,8 @@ def encrypt_event( raise TypeError("Event encryption only supported for dict") data = event_data.copy() - for field in fields: + fields_to_encrypt = SERIALIZED_EVENT_FIELDS ^ set(event_data.keys()) + for field in fields_to_encrypt: data[ENCRYPTED_PREFIX + field] = str( encrypt(json.dumps(event_data[field]).encode()), "utf-8" ) diff --git a/monkey/monkey_island/cc/repository/mongo_agent_event_repository.py b/monkey/monkey_island/cc/repository/mongo_agent_event_repository.py index 0194f3d3d..f483ac3d2 100644 --- a/monkey/monkey_island/cc/repository/mongo_agent_event_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_agent_event_repository.py @@ -9,7 +9,7 @@ from monkey_island.cc.repository import IAgentEventRepository from monkey_island.cc.server_utils.encryption import ILockableEncryptor from . import RemovalError, RetrievalError, StorageError -from .agent_event_encryption import decrypt_event, encrypt_event, get_fields_to_encrypt +from .agent_event_encryption import decrypt_event, encrypt_event from .consts import MONGO_OBJECT_ID_KEY @@ -30,8 +30,7 @@ class MongoAgentEventRepository(IAgentEventRepository): try: serializer = self._serializers[type(event)] serialized_event = serializer.serialize(event) - fields = get_fields_to_encrypt(event) - encrypted_event = encrypt_event(self._encryptor.encrypt, serialized_event, fields) + encrypted_event = encrypt_event(self._encryptor.encrypt, serialized_event) self._events_collection.insert_one(encrypted_event) except Exception as err: raise StorageError(f"Error saving event: {err}") diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_agent_event_encryption.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_agent_event_encryption.py index 947c50f80..35873a6cb 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_agent_event_encryption.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_agent_event_encryption.py @@ -5,11 +5,7 @@ import pytest from common.agent_event_serializers import PydanticAgentEventSerializer from common.agent_events import AbstractAgentEvent -from monkey_island.cc.repository.agent_event_encryption import ( - decrypt_event, - encrypt_event, - get_fields_to_encrypt, -) +from monkey_island.cc.repository.agent_event_encryption import decrypt_event, encrypt_event from monkey_island.cc.server_utils.encryption import RepositoryEncryptor @@ -43,8 +39,7 @@ def serializer(): def test_agent_event_encryption__encrypts(encryptor, serializer): data = serializer.serialize(EVENT) - fields = get_fields_to_encrypt(EVENT) - encrypted_data = encrypt_event(encryptor.encrypt, data, fields) + encrypted_data = encrypt_event(encryptor.encrypt, data) # Encrypted fields have the "encrypted_" prefix assert "encrypted_data" in encrypted_data @@ -55,8 +50,7 @@ def test_agent_event_encryption__encrypts(encryptor, serializer): def test_agent_event_encryption__decrypts(encryptor, serializer): data = serializer.serialize(EVENT) - fields = get_fields_to_encrypt(EVENT) - encrypted_data = encrypt_event(encryptor.encrypt, data, fields) + encrypted_data = encrypt_event(encryptor.encrypt, data) decrypted_data = decrypt_event(encryptor.decrypt, encrypted_data) deserialized_event = serializer.deserialize(decrypted_data)