diff --git a/monkey/monkey_island/cc/services/repository_service.py b/monkey/monkey_island/cc/services/repository_service.py index af0d13fa1..9910469b5 100644 --- a/monkey/monkey_island/cc/services/repository_service.py +++ b/monkey/monkey_island/cc/services/repository_service.py @@ -1,4 +1,9 @@ -from monkey_island.cc.repository import IAgentConfigurationRepository, IFileRepository +from monkey_island.cc.repository import ( + IAgentConfigurationRepository, + ICredentialsRepository, + IFileRepository, +) +from monkey_island.cc.services.database import Database class RepositoryService: @@ -6,9 +11,11 @@ class RepositoryService: self, agent_configuration_repository: IAgentConfigurationRepository, file_repository: IFileRepository, + credentials_repository: ICredentialsRepository, ): self._agent_configuration_repository = agent_configuration_repository self._file_repository = file_repository + self._credentials_repository = credentials_repository def reset_agent_configuration(self): # NOTE: This method will be replaced by an event when we implement pub/sub in the island. @@ -37,4 +44,5 @@ class RepositoryService: # NOTE: This method will be replaced by an event when we implement pub/sub in the island. # Different plugins and components will be able to register for the event and clear # any configuration data they've collected. - raise NotImplementedError + Database.reset_db(reset_config=False) + self._credentials_repository.remove_stolen_credentials() diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_repository_service.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_repository_service.py index 17dd6ffb5..a5d443456 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/test_repository_service.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_repository_service.py @@ -5,7 +5,11 @@ import pytest from tests.monkey_island import InMemoryAgentConfigurationRepository from common.configuration import AgentConfiguration -from monkey_island.cc.repository import IAgentConfigurationRepository, IFileRepository +from monkey_island.cc.repository import ( + IAgentConfigurationRepository, + ICredentialsRepository, + IFileRepository, +) from monkey_island.cc.services import RepositoryService LINUX_FILENAME = "linux_pba_file.sh" @@ -37,11 +41,21 @@ def mock_file_repository() -> IFileRepository: return MagicMock(spec=IFileRepository) -def test_reset_configuration__remove_pba_files( - agent_configuration_repository, mock_file_repository -): - repository_service = RepositoryService(agent_configuration_repository, mock_file_repository) +@pytest.fixture +def mock_credentials_repository() -> ICredentialsRepository: + return MagicMock(spec=ICredentialsRepository) + +@pytest.fixture +def repository_service( + agent_configuration_repository, mock_file_repository, mock_credentials_repository +) -> RepositoryService: + return RepositoryService( + agent_configuration_repository, mock_file_repository, mock_credentials_repository + ) + + +def test_reset_configuration__remove_pba_files(repository_service, mock_file_repository): repository_service.reset_agent_configuration() assert mock_file_repository.delete_file.called_with(LINUX_FILENAME) @@ -49,11 +63,20 @@ def test_reset_configuration__remove_pba_files( def test_reset_configuration__agent_configuration_changed( - agent_configuration_repository, agent_configuration, mock_file_repository + repository_service, agent_configuration_repository, agent_configuration ): - mock_file_repository = MagicMock(spec=IFileRepository) - repository_service = RepositoryService(agent_configuration_repository, mock_file_repository) - repository_service.reset_agent_configuration() assert agent_configuration_repository.get_configuration() != agent_configuration + + +@pytest.mark.usefixtures("uses_database") +def test_clear_simulation_data( + repository_service: RepositoryService, + mock_credentials_repository: ICredentialsRepository, + monkeypatch, +): + monkeypatch.setattr("monkey_island.cc.services.repository_service.Database", MagicMock()) + repository_service.clear_simulation_data() + + mock_credentials_repository.remove_stolen_credentials.assert_called_once()