forked from p15670423/monkey
Island: Simplify agent encryption calls
This commit is contained in:
parent
644f3628a5
commit
c0869aebba
|
@ -1,23 +1,17 @@
|
||||||
import json
|
import json
|
||||||
from typing import Callable, Iterable
|
from typing import Callable
|
||||||
|
|
||||||
from common.agent_event_serializers import JSONSerializable
|
from common.agent_event_serializers import JSONSerializable
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
|
|
||||||
ENCRYPTED_PREFIX = "encrypted_"
|
ENCRYPTED_PREFIX = "encrypted_"
|
||||||
|
ABSTRACT_AGENT_EVENT_FIELDS = vars(AbstractAgentEvent)["__fields__"].keys()
|
||||||
|
SERIALIZED_EVENT_FIELDS = set(ABSTRACT_AGENT_EVENT_FIELDS) | set(["type"])
|
||||||
def get_fields_to_encrypt(event: AbstractAgentEvent):
|
|
||||||
"""
|
|
||||||
Get the fields of the event that are not part of the base AbstractAgentEvent.
|
|
||||||
"""
|
|
||||||
return set(vars(AbstractAgentEvent)["__fields__"].keys()) ^ set(event.dict().keys())
|
|
||||||
|
|
||||||
|
|
||||||
def encrypt_event(
|
def encrypt_event(
|
||||||
encrypt: Callable[[bytes], bytes],
|
encrypt: Callable[[bytes], bytes],
|
||||||
event_data: JSONSerializable,
|
event_data: JSONSerializable,
|
||||||
fields: Iterable[str] = [],
|
|
||||||
) -> JSONSerializable:
|
) -> JSONSerializable:
|
||||||
"""
|
"""
|
||||||
Encrypt a serialized AbstractAgentEvent
|
Encrypt a serialized AbstractAgentEvent
|
||||||
|
@ -27,7 +21,6 @@ def encrypt_event(
|
||||||
|
|
||||||
:param encrypt: Callable used to encrypt data
|
:param encrypt: Callable used to encrypt data
|
||||||
:param event_data: Serialized event to encrypt
|
:param event_data: Serialized event to encrypt
|
||||||
:param fields: Fields to encrypt
|
|
||||||
:return: Serialized event with the fields encrypted
|
:return: Serialized event with the fields encrypted
|
||||||
:raises TypeError: If the serialized data is not a dict
|
:raises TypeError: If the serialized data is not a dict
|
||||||
"""
|
"""
|
||||||
|
@ -35,7 +28,8 @@ def encrypt_event(
|
||||||
raise TypeError("Event encryption only supported for dict")
|
raise TypeError("Event encryption only supported for dict")
|
||||||
|
|
||||||
data = event_data.copy()
|
data = event_data.copy()
|
||||||
for field in fields:
|
fields_to_encrypt = SERIALIZED_EVENT_FIELDS ^ set(event_data.keys())
|
||||||
|
for field in fields_to_encrypt:
|
||||||
data[ENCRYPTED_PREFIX + field] = str(
|
data[ENCRYPTED_PREFIX + field] = str(
|
||||||
encrypt(json.dumps(event_data[field]).encode()), "utf-8"
|
encrypt(json.dumps(event_data[field]).encode()), "utf-8"
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from monkey_island.cc.repository import IAgentEventRepository
|
||||||
from monkey_island.cc.server_utils.encryption import ILockableEncryptor
|
from monkey_island.cc.server_utils.encryption import ILockableEncryptor
|
||||||
|
|
||||||
from . import RemovalError, RetrievalError, StorageError
|
from . import RemovalError, RetrievalError, StorageError
|
||||||
from .agent_event_encryption import decrypt_event, encrypt_event, get_fields_to_encrypt
|
from .agent_event_encryption import decrypt_event, encrypt_event
|
||||||
from .consts import MONGO_OBJECT_ID_KEY
|
from .consts import MONGO_OBJECT_ID_KEY
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,8 +30,7 @@ class MongoAgentEventRepository(IAgentEventRepository):
|
||||||
try:
|
try:
|
||||||
serializer = self._serializers[type(event)]
|
serializer = self._serializers[type(event)]
|
||||||
serialized_event = serializer.serialize(event)
|
serialized_event = serializer.serialize(event)
|
||||||
fields = get_fields_to_encrypt(event)
|
encrypted_event = encrypt_event(self._encryptor.encrypt, serialized_event)
|
||||||
encrypted_event = encrypt_event(self._encryptor.encrypt, serialized_event, fields)
|
|
||||||
self._events_collection.insert_one(encrypted_event)
|
self._events_collection.insert_one(encrypted_event)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise StorageError(f"Error saving event: {err}")
|
raise StorageError(f"Error saving event: {err}")
|
||||||
|
|
|
@ -5,11 +5,7 @@ import pytest
|
||||||
|
|
||||||
from common.agent_event_serializers import PydanticAgentEventSerializer
|
from common.agent_event_serializers import PydanticAgentEventSerializer
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
from monkey_island.cc.repository.agent_event_encryption import (
|
from monkey_island.cc.repository.agent_event_encryption import decrypt_event, encrypt_event
|
||||||
decrypt_event,
|
|
||||||
encrypt_event,
|
|
||||||
get_fields_to_encrypt,
|
|
||||||
)
|
|
||||||
from monkey_island.cc.server_utils.encryption import RepositoryEncryptor
|
from monkey_island.cc.server_utils.encryption import RepositoryEncryptor
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,8 +39,7 @@ def serializer():
|
||||||
|
|
||||||
def test_agent_event_encryption__encrypts(encryptor, serializer):
|
def test_agent_event_encryption__encrypts(encryptor, serializer):
|
||||||
data = serializer.serialize(EVENT)
|
data = serializer.serialize(EVENT)
|
||||||
fields = get_fields_to_encrypt(EVENT)
|
encrypted_data = encrypt_event(encryptor.encrypt, data)
|
||||||
encrypted_data = encrypt_event(encryptor.encrypt, data, fields)
|
|
||||||
|
|
||||||
# Encrypted fields have the "encrypted_" prefix
|
# Encrypted fields have the "encrypted_" prefix
|
||||||
assert "encrypted_data" in encrypted_data
|
assert "encrypted_data" in encrypted_data
|
||||||
|
@ -55,8 +50,7 @@ def test_agent_event_encryption__encrypts(encryptor, serializer):
|
||||||
|
|
||||||
def test_agent_event_encryption__decrypts(encryptor, serializer):
|
def test_agent_event_encryption__decrypts(encryptor, serializer):
|
||||||
data = serializer.serialize(EVENT)
|
data = serializer.serialize(EVENT)
|
||||||
fields = get_fields_to_encrypt(EVENT)
|
encrypted_data = encrypt_event(encryptor.encrypt, data)
|
||||||
encrypted_data = encrypt_event(encryptor.encrypt, data, fields)
|
|
||||||
|
|
||||||
decrypted_data = decrypt_event(encryptor.decrypt, encrypted_data)
|
decrypted_data = decrypt_event(encryptor.decrypt, encrypted_data)
|
||||||
deserialized_event = serializer.deserialize(decrypted_data)
|
deserialized_event = serializer.deserialize(decrypted_data)
|
||||||
|
|
Loading…
Reference in New Issue