diff --git a/monkey/monkey_island/cc/repository/agent_event_encryption.py b/monkey/monkey_island/cc/repository/agent_event_encryption.py new file mode 100644 index 000000000..83469613a --- /dev/null +++ b/monkey/monkey_island/cc/repository/agent_event_encryption.py @@ -0,0 +1,44 @@ +import json +from typing import Callable, Iterable + +from common.agent_event_serializers import JSONSerializable +from common.agent_events import AbstractAgentEvent + +ENCRYPTED_PREFIX = "encrypted_" + + +def get_fields_to_encrypt(event: AbstractAgentEvent): + return set(vars(AbstractAgentEvent)["__fields__"].keys()) ^ set(event.dict().keys()) + + +def encrypt_event( + encrypt: Callable[[bytes], bytes], + event_data: JSONSerializable, + fields: Iterable[str] = [], +) -> JSONSerializable: + if not isinstance(event_data, dict): + raise TypeError("Event encryption only supported for dict") + + for field in fields: + event_data[ENCRYPTED_PREFIX + field] = str( + encrypt(json.dumps(event_data[field]).encode()), "utf-8" + ) + del event_data[field] + + return event_data + + +def decrypt_event( + decrypt: Callable[[bytes], bytes], event_data: JSONSerializable +) -> JSONSerializable: + if not isinstance(event_data, dict): + raise TypeError("Event decryption only supported for dict") + + for field in event_data.keys(): + if field.startswith("encrypted_"): + event_data[field[len(ENCRYPTED_PREFIX) :]] = json.loads( + str(decrypt(event_data[field].encode()), "utf-8") + ) + del event_data[field] + + return event_data diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_agent_event_encryption.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_agent_event_encryption.py new file mode 100644 index 000000000..0ebd0a5c6 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_agent_event_encryption.py @@ -0,0 +1,57 @@ +import uuid + +import pytest + +from common.agent_event_serializers import PydanticAgentEventSerializer +from common.agent_events import AbstractAgentEvent +from monkey_island.cc.repository.agent_event_encryption import ( + decrypt_event, + encrypt_event, + get_fields_to_encrypt, +) +from monkey_island.cc.server_utils.encryption import RepositoryEncryptor + + +class FakeAgentEvent(AbstractAgentEvent): + data: str + + +@pytest.fixture +def key_file(tmp_path): + return tmp_path / "test_key.bin" + + +@pytest.fixture +def encryptor(key_file): + encryptor = RepositoryEncryptor(key_file) + encryptor.unlock(b"password") + return encryptor + + +@pytest.fixture +def serializer(): + return PydanticAgentEventSerializer() + + +def test_agent_event_encryption__encrypts(encryptor, serializer): + event = FakeAgentEvent(source=uuid.uuid4(), data="foo") + data = serializer.serialize(event) + fields = get_fields_to_encrypt(event) + encrypted_data = encrypt_event(encryptor.encrypt, data, fields) + + # Encrypted fields have the "encrypted_" prefix + assert "encrypted_data" in encrypted_data + assert encrypted_data["encrypted_data"] is not event.data + + +def test_agent_event_encryption__decrypts(encryptor, serializer): + event = FakeAgentEvent(source=uuid.uuid4(), data="foo") + + data = serializer.serialize(event) + fields = get_fields_to_encrypt(event) + encrypted_data = encrypt_event(encryptor.encrypt, data, fields) + + decrypted_data = decrypt_event(encryptor.decrypt, encrypted_data) + deserialized_event = serializer.deserialize(decrypted_data) + + assert deserialized_event == event