diff --git a/monkey/monkey_island/cc/repository/mongo_agent_repository.py b/monkey/monkey_island/cc/repository/mongo_agent_repository.py index dfad2bbf7..021cf34bc 100644 --- a/monkey/monkey_island/cc/repository/mongo_agent_repository.py +++ b/monkey/monkey_island/cc/repository/mongo_agent_repository.py @@ -58,6 +58,14 @@ class MongoAgentRepository(IAgentRepository): except Exception as err: raise RetrievalError(f"Error retrieving running agents: {err}") + def get_progenitor(self, agent: Agent) -> Agent: + if agent.parent_id is None: + return agent + + parent = self.get_agent_by_id(agent.parent_id) + + return self.get_progenitor(parent) + def reset(self): try: self._agents_collection.drop() diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py index a1221d32d..938ea7412 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_mongo_agent_repository.py @@ -17,14 +17,29 @@ from monkey_island.cc.repository import ( ) VICTIM_ZERO_ID = uuid4() +VICTIM_TWO_ID = uuid4() +VICTIM_THREE_ID = uuid4() + +PROGENITOR_AGENT = Agent( + id=VICTIM_ZERO_ID, machine_id=1, start_time=datetime.fromtimestamp(1661856718) +) + +DESCENDANT_AGENT = Agent( + id=VICTIM_THREE_ID, + machine_id=4, + start_time=datetime.fromtimestamp(1661856868), + parent_id=VICTIM_TWO_ID, +) + RUNNING_AGENTS = ( - Agent(id=VICTIM_ZERO_ID, machine_id=1, start_time=datetime.fromtimestamp(1661856718)), + PROGENITOR_AGENT, Agent( - id=uuid4(), + id=VICTIM_TWO_ID, machine_id=2, start_time=datetime.fromtimestamp(1661856818), parent_id=VICTIM_ZERO_ID, ), + DESCENDANT_AGENT, ) STOPPED_AGENTS = ( Agent( @@ -172,6 +187,24 @@ def test_get_running_agents__retrieval_error(error_raising_agent_repository): error_raising_agent_repository.get_running_agents() +@pytest.mark.parametrize("agent", [DESCENDANT_AGENT, PROGENITOR_AGENT]) +def test_get_progenitor(agent_repository, agent): + actual_progenitor = agent_repository.get_progenitor(agent) + + assert actual_progenitor == PROGENITOR_AGENT + + +def test_get_progenitor__id_not_found(agent_repository): + dummy_agent = Agent(id=uuid4(), machine_id=10, start_time=datetime.now(), parent_id=uuid4()) + with pytest.raises(UnknownRecordError): + agent_repository.get_progenitor(dummy_agent) + + +def test_get_progenitor__retrieval_error(error_raising_agent_repository): + with pytest.raises(RetrievalError): + error_raising_agent_repository.get_progenitor(AGENTS[1]) + + def test_reset(agent_repository): # Ensure the repository is not empty for agent in AGENTS: