Island: Allow RepositoryErrors to be reraised

Previously, FileAgentLogRepository.get_agent_log() wrapped all errors as
RetrievalError, which is not necessarily correct. This commit allows all
repository errors raised by IFileRepository to be reraised, and all
other, unexpected errors to be reraised as RetrievalError.
This commit is contained in:
Mike Salvatore 2022-09-28 12:38:56 -04:00
parent 3c2ee32bdf
commit d49d16bc37
2 changed files with 27 additions and 4 deletions

View File

@ -2,7 +2,12 @@ import io
import re import re
from monkey_island.cc.models import AgentID from monkey_island.cc.models import AgentID
from monkey_island.cc.repository import IAgentLogRepository, IFileRepository, RetrievalError from monkey_island.cc.repository import (
IAgentLogRepository,
IFileRepository,
RepositoryError,
RetrievalError,
)
AGENT_LOG_FILE_NAME_PATTERN = "agent-*.log" AGENT_LOG_FILE_NAME_PATTERN = "agent-*.log"
AGENT_LOG_FILE_NAME_REGEX = re.compile(r"^agent-[\w-]+.log$") AGENT_LOG_FILE_NAME_REGEX = re.compile(r"^agent-[\w-]+.log$")
@ -22,6 +27,8 @@ class FileAgentLogRepository(IAgentLogRepository):
with self._file_repository.open_file(self._get_agent_log_file_name(agent_id)) as f: with self._file_repository.open_file(self._get_agent_log_file_name(agent_id)) as f:
log_contents = f.read().decode() log_contents = f.read().decode()
return log_contents return log_contents
except RepositoryError as err:
raise err
except Exception as err: except Exception as err:
raise RetrievalError(f"Error retrieving the agent logs: {err}") raise RetrievalError(f"Error retrieving the agent logs: {err}")

View File

@ -1,9 +1,16 @@
import io
from unittest.mock import MagicMock
from uuid import UUID from uuid import UUID
import pytest import pytest
from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository
from monkey_island.cc.repository import FileAgentLogRepository, RetrievalError from monkey_island.cc.repository import (
FileAgentLogRepository,
IFileRepository,
RetrievalError,
UnknownRecordError,
)
LOG_CONTENTS = "lots of useful information" LOG_CONTENTS = "lots of useful information"
AGENT_ID = UUID("6bfd8b64-43d8-4449-8c70-d898aca74ad8") AGENT_ID = UUID("6bfd8b64-43d8-4449-8c70-d898aca74ad8")
@ -22,7 +29,7 @@ def test_store_agent_log(repository):
def test_get_agent_log__unknown_record_error(repository): def test_get_agent_log__unknown_record_error(repository):
with pytest.raises(RetrievalError): with pytest.raises(UnknownRecordError):
repository.get_agent_log(AGENT_ID) repository.get_agent_log(AGENT_ID)
@ -32,8 +39,17 @@ def test_get_agent_log__retrieval_error():
repository.get_agent_log(AGENT_ID) repository.get_agent_log(AGENT_ID)
def test_get_agent_log__corrupt_data():
file_repository = MagicMock(spec=IFileRepository)
file_repository.open_file = MagicMock(return_value=io.BytesIO(b"\xff\xfe"))
repository = FileAgentLogRepository(file_repository)
with pytest.raises(RetrievalError):
repository.get_agent_log(AGENT_ID)
def test_reset_agent_logs(repository): def test_reset_agent_logs(repository):
repository.upsert_agent_log(AGENT_ID, LOG_CONTENTS) repository.upsert_agent_log(AGENT_ID, LOG_CONTENTS)
repository.reset() repository.reset()
with pytest.raises(RetrievalError): with pytest.raises(UnknownRecordError):
repository.get_agent_log(AGENT_ID) repository.get_agent_log(AGENT_ID)