diff --git a/monkey/monkey_island/cc/repository/__init__.py b/monkey/monkey_island/cc/repository/__init__.py index 6f683cc1a..d90f6aac6 100644 --- a/monkey/monkey_island/cc/repository/__init__.py +++ b/monkey/monkey_island/cc/repository/__init__.py @@ -1,6 +1,7 @@ from .errors import RemovalError, RetrievalError, StorageError from .i_file_repository import FileNotFoundError, IFileRepository from .local_storage_file_repository import LocalStorageFileRepository +from .file_repository_caching_decorator import FileRepositoryCachingDecorator from .file_repository_locking_decorator import FileRepositoryLockingDecorator from .file_repository_logging_decorator import FileRepositoryLoggingDecorator from .i_agent_binary_repository import IAgentBinaryRepository diff --git a/monkey/monkey_island/cc/repository/file_repository_caching_decorator.py b/monkey/monkey_island/cc/repository/file_repository_caching_decorator.py new file mode 100644 index 000000000..76739716f --- /dev/null +++ b/monkey/monkey_island/cc/repository/file_repository_caching_decorator.py @@ -0,0 +1,37 @@ +import io +import shutil +from functools import lru_cache +from typing import BinaryIO + +from . import IFileRepository + + +class FileRepositoryCachingDecorator(IFileRepository): + def __init__(self, file_repository: IFileRepository): + self._file_repository = file_repository + + def save_file(self, unsafe_file_name: str, file_contents: BinaryIO): + self._open_file.cache_clear() + return self._file_repository.save_file(unsafe_file_name, file_contents) + + def open_file(self, unsafe_file_name: str) -> BinaryIO: + original_file = self._open_file(unsafe_file_name) + file_copy = io.BytesIO() + + shutil.copyfileobj(original_file, file_copy) + original_file.seek(0) + file_copy.seek(0) + + return file_copy + + @lru_cache(maxsize=16) + def _open_file(self, unsafe_file_name: str) -> BinaryIO: + return self._file_repository.open_file(unsafe_file_name) + + def delete_file(self, unsafe_file_name: str): + self._open_file.cache_clear() + return self._file_repository.delete_file(unsafe_file_name) + + def delete_all_files(self): + self._open_file.cache_clear() + return self._file_repository.delete_all_files() diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_repository_caching_decorator.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_repository_caching_decorator.py new file mode 100644 index 000000000..14fdf8c27 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_repository_caching_decorator.py @@ -0,0 +1,55 @@ +import io + +import pytest +from tests.monkey_island import SingleFileRepository + +from monkey_island.cc import repository +from monkey_island.cc.repository import FileRepositoryCachingDecorator + + +@pytest.fixture +def file_repository(): + return FileRepositoryCachingDecorator(SingleFileRepository()) + + +def test_open_cache_file(file_repository): + file_name = "test.txt" + file_contents = b"Hello World!" + + file_repository.save_file(file_name, io.BytesIO(file_contents)) + assert file_repository.open_file(file_name).read() == file_contents + assert file_repository.open_file(file_name).read() == file_contents + + +def test_overwrite_file(file_repository): + file_name = "test.txt" + file_contents_1 = b"Hello World!" + file_contents_2 = b"Goodbye World!" + + file_repository.save_file(file_name, io.BytesIO(file_contents_1)) + assert file_repository.open_file(file_name).read() == file_contents_1 + + file_repository.save_file(file_name, io.BytesIO(file_contents_2)) + assert file_repository.open_file(file_name).read() == file_contents_2 + + +def test_delete_file(file_repository): + file_name = "test.txt" + file_contents = b"Hello World!" + + file_repository.save_file(file_name, io.BytesIO(file_contents)) + file_repository.delete_file(file_name) + + with pytest.raises(repository.FileNotFoundError): + file_repository.open_file(file_name) + + +def test_delete_all_files(file_repository): + file_name = "test.txt" + file_contents = b"Hello World!" + + file_repository.save_file(file_name, io.BytesIO(file_contents)) + file_repository.delete_all_files() + + with pytest.raises(repository.FileNotFoundError): + file_repository.open_file(file_name)