diff --git a/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py b/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py index 0af258d19..6b927e20b 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryption/data_store_encryptor.py @@ -46,10 +46,10 @@ class DataStoreEncryptor(IEncryptor): return KeyBasedEncryptor(plaintext_key) - def encrypt(self, plaintext: str) -> str: + def encrypt(self, plaintext: bytes) -> bytes: return self._key_based_encryptor.encrypt(plaintext) - def decrypt(self, ciphertext: str): + def decrypt(self, ciphertext: bytes) -> bytes: return self._key_based_encryptor.decrypt(ciphertext) diff --git a/monkey/monkey_island/cc/server_utils/encryption/field_encryptors/string_encryptor.py b/monkey/monkey_island/cc/server_utils/encryption/field_encryptors/string_encryptor.py index 28f0f2c93..6de6cb385 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/field_encryptors/string_encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryption/field_encryptors/string_encryptor.py @@ -5,8 +5,8 @@ from . import IFieldEncryptor class StringEncryptor(IFieldEncryptor): @staticmethod def encrypt(value: str): - return get_datastore_encryptor().encrypt(value) + return get_datastore_encryptor().encrypt(value.encode()) @staticmethod def decrypt(value: str): - return get_datastore_encryptor().decrypt(value) + return get_datastore_encryptor().decrypt(value).decode() diff --git a/monkey/monkey_island/cc/server_utils/encryption/field_encryptors/string_list_encryptor.py b/monkey/monkey_island/cc/server_utils/encryption/field_encryptors/string_list_encryptor.py index ce0ceb8dd..9adf733a4 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/field_encryptors/string_list_encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryption/field_encryptors/string_list_encryptor.py @@ -7,8 +7,8 @@ from . import IFieldEncryptor class StringListEncryptor(IFieldEncryptor): @staticmethod def encrypt(value: List[str]): - return [get_datastore_encryptor().encrypt(string) for string in value] + return [get_datastore_encryptor().encrypt(string.encode()) for string in value] @staticmethod - def decrypt(value: List[str]): - return [get_datastore_encryptor().decrypt(string) for string in value] + def decrypt(value: List[bytes]): + return [get_datastore_encryptor().decrypt(bytes_).decode() for bytes_ in value] diff --git a/monkey/monkey_island/cc/server_utils/encryption/i_encryptor.py b/monkey/monkey_island/cc/server_utils/encryption/i_encryptor.py index d46e51c1a..8bc701bf9 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/i_encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryption/i_encryptor.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any class IEncryptor(ABC): @abstractmethod - def encrypt(self, plaintext: Any) -> Any: + def encrypt(self, plaintext: bytes) -> bytes: """ Encrypts data and returns the ciphertext. @@ -13,7 +12,7 @@ class IEncryptor(ABC): """ @abstractmethod - def decrypt(self, ciphertext: Any): + def decrypt(self, ciphertext: bytes) -> bytes: """ Decrypts data and returns the plaintext. diff --git a/monkey/monkey_island/cc/server_utils/encryption/key_based_encryptor.py b/monkey/monkey_island/cc/server_utils/encryption/key_based_encryptor.py index 630094989..67e59f94d 100644 --- a/monkey/monkey_island/cc/server_utils/encryption/key_based_encryptor.py +++ b/monkey/monkey_island/cc/server_utils/encryption/key_based_encryptor.py @@ -31,15 +31,15 @@ class KeyBasedEncryptor(IEncryptor): # something up. The main drawback to fernet is that it uses AES-128, which is not # quantum-safe. At the present time, human error is probably a greater risk than quantum # computers. - def encrypt(self, plaintext: str) -> str: + def encrypt(self, plaintext: bytes) -> bytes: cipher_iv = Random.new().read(AES.block_size) cipher = AES.new(self._key, AES.MODE_CBC, cipher_iv) - padded_plaintext = Padding.pad(plaintext.encode(), self._BLOCK_SIZE) - return base64.b64encode(cipher_iv + cipher.encrypt(padded_plaintext)).decode() + padded_plaintext = Padding.pad(plaintext, self._BLOCK_SIZE) + return base64.b64encode(cipher_iv + cipher.encrypt(padded_plaintext)) - def decrypt(self, ciphertext: str): + def decrypt(self, ciphertext: bytes) -> bytes: enc_message = base64.b64decode(ciphertext) cipher_iv = enc_message[0 : AES.block_size] cipher = AES.new(self._key, AES.MODE_CBC, cipher_iv) padded_plaintext = cipher.decrypt(enc_message[AES.block_size :]) - return Padding.unpad(padded_plaintext, self._BLOCK_SIZE).decode() + return Padding.unpad(padded_plaintext, self._BLOCK_SIZE) diff --git a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py index da4a9ec09..1887185d4 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py +++ b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_data_store_encryptor.py @@ -11,7 +11,7 @@ from monkey_island.cc.server_utils.encryption import ( # Mark all tests in this module as slow pytestmark = pytest.mark.slow -PLAINTEXT = "Hello, Monkey!" +PLAINTEXT = b"Hello, Monkey!" MOCK_SECRET = "53CR31" diff --git a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_key_based_encryptor.py b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_key_based_encryptor.py index d41f866a7..d17a5cd1a 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_key_based_encryptor.py +++ b/monkey/tests/unit_tests/monkey_island/cc/server_utils/encryption/test_key_based_encryptor.py @@ -14,19 +14,19 @@ kb_encryptor = KeyBasedEncryptor(KEY) def test_encrypt_decrypt_string_with_key(): - encrypted = kb_encryptor.encrypt(PLAINTEXT) - decrypted = kb_encryptor.decrypt(encrypted) + encrypted = kb_encryptor.encrypt(PLAINTEXT.encode()) + decrypted = kb_encryptor.decrypt(encrypted).decode() assert decrypted == PLAINTEXT @pytest.mark.parametrize("plaintext", [PLAINTEXT_UTF8_1, PLAINTEXT_UTF8_2, PLAINTEXT_UTF8_3]) def test_encrypt_decrypt_string_utf8_with_key(plaintext): - encrypted = kb_encryptor.encrypt(plaintext) - decrypted = kb_encryptor.decrypt(encrypted) + encrypted = kb_encryptor.encrypt(plaintext.encode()) + decrypted = kb_encryptor.decrypt(encrypted).decode() assert decrypted == plaintext def test_encrypt_decrypt_string_multiple_block_size_with_key(): - encrypted = kb_encryptor.encrypt(PLAINTEXT_MULTIPLE_BLOCK_SIZE) - decrypted = kb_encryptor.decrypt(encrypted) + encrypted = kb_encryptor.encrypt(PLAINTEXT_MULTIPLE_BLOCK_SIZE.encode()) + decrypted = kb_encryptor.decrypt(encrypted).decode() assert decrypted == PLAINTEXT_MULTIPLE_BLOCK_SIZE