diff --git a/monkey/monkey_island/cc/repository/file_agent_log_repository.py b/monkey/monkey_island/cc/repository/file_agent_log_repository.py index 728d16da8..d92e07170 100644 --- a/monkey/monkey_island/cc/repository/file_agent_log_repository.py +++ b/monkey/monkey_island/cc/repository/file_agent_log_repository.py @@ -2,7 +2,12 @@ import io import re from monkey_island.cc.models import AgentID -from monkey_island.cc.repository import IAgentLogRepository, IFileRepository, RetrievalError +from monkey_island.cc.repository import ( + IAgentLogRepository, + IFileRepository, + RepositoryError, + RetrievalError, +) AGENT_LOG_FILE_NAME_PATTERN = "agent-*.log" AGENT_LOG_FILE_NAME_REGEX = re.compile(r"^agent-[\w-]+.log$") @@ -22,6 +27,8 @@ class FileAgentLogRepository(IAgentLogRepository): with self._file_repository.open_file(self._get_agent_log_file_name(agent_id)) as f: log_contents = f.read().decode() return log_contents + except RepositoryError as err: + raise err except Exception as err: raise RetrievalError(f"Error retrieving the agent logs: {err}") diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_log_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_log_repository.py index d49d0957d..ac841096b 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_log_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_log_repository.py @@ -1,9 +1,16 @@ +import io +from unittest.mock import MagicMock from uuid import UUID import pytest from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository -from monkey_island.cc.repository import FileAgentLogRepository, RetrievalError +from monkey_island.cc.repository import ( + FileAgentLogRepository, + IFileRepository, + RetrievalError, + UnknownRecordError, +) LOG_CONTENTS = "lots of useful information" AGENT_ID = UUID("6bfd8b64-43d8-4449-8c70-d898aca74ad8") @@ -22,7 +29,7 @@ def test_store_agent_log(repository): def test_get_agent_log__unknown_record_error(repository): - with pytest.raises(RetrievalError): + with pytest.raises(UnknownRecordError): repository.get_agent_log(AGENT_ID) @@ -32,8 +39,17 @@ def test_get_agent_log__retrieval_error(): repository.get_agent_log(AGENT_ID) +def test_get_agent_log__corrupt_data(): + file_repository = MagicMock(spec=IFileRepository) + file_repository.open_file = MagicMock(return_value=io.BytesIO(b"\xff\xfe")) + repository = FileAgentLogRepository(file_repository) + + with pytest.raises(RetrievalError): + repository.get_agent_log(AGENT_ID) + + def test_reset_agent_logs(repository): repository.upsert_agent_log(AGENT_ID, LOG_CONTENTS) repository.reset() - with pytest.raises(RetrievalError): + with pytest.raises(UnknownRecordError): repository.get_agent_log(AGENT_ID)