From 1e8a60c8902bcb1f4cb85413625717856f2e7d88 Mon Sep 17 00:00:00 2001
From: Mike Salvatore <mike.s.salvatore@gmail.com>
Date: Wed, 21 Sep 2022 10:29:55 -0400
Subject: [PATCH] Island: Add new agent to repository on agent registration

---
 .../handle_agent_registration.py              | 24 +++++++++++----
 .../test_handle_agent_registration.py         | 30 ++++++++++++++++---
 2 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py
index 9e3ead2b6..eaea53ffb 100644
--- a/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py
+++ b/monkey/monkey_island/cc/island_event_handlers/handle_agent_registration.py
@@ -2,8 +2,8 @@ from contextlib import suppress
 from typing import Optional
 
 from common import AgentRegistrationData
-from monkey_island.cc.models import Machine
-from monkey_island.cc.repository import IMachineRepository, UnknownRecordError
+from monkey_island.cc.models import Agent, Machine
+from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError
 
 
 class handle_agent_registration:
@@ -11,13 +11,15 @@ class handle_agent_registration:
     Update repositories when a new agent registers
     """
 
-    def __init__(self, machine_repository: IMachineRepository):
+    def __init__(self, machine_repository: IMachineRepository, agent_repository: IAgentRepository):
         self._machine_repository = machine_repository
+        self._agent_repository = agent_repository
 
     def __call__(self, agent_registration_data: AgentRegistrationData):
-        self._update_machine_repository(agent_registration_data)
+        machine = self._update_machine_repository(agent_registration_data)
+        self._add_agent(agent_registration_data, machine)
 
-    def _update_machine_repository(self, agent_registration_data: AgentRegistrationData):
+    def _update_machine_repository(self, agent_registration_data: AgentRegistrationData) -> Machine:
         machine = self._find_existing_machine_to_update(agent_registration_data)
 
         if machine is None:
@@ -25,6 +27,8 @@ class handle_agent_registration:
 
         self._upsert_machine(machine, agent_registration_data)
 
+        return machine
+
     def _find_existing_machine_to_update(
         self, agent_registration_data: AgentRegistrationData
     ) -> Optional[Machine]:
@@ -72,3 +76,13 @@ class handle_agent_registration:
         )
 
         machine.network_interfaces = sorted(updated_network_interfaces)
+
+    def _add_agent(self, agent_registration_data: AgentRegistrationData, machine: Machine):
+        new_agent = Agent(
+            id=agent_registration_data.id,
+            machine_id=machine.id,
+            start_time=agent_registration_data.start_time,
+            parent_id=agent_registration_data.parent_id,
+            cc_server=agent_registration_data.cc_server,
+        )
+        self._agent_repository.upsert_agent(new_agent)
diff --git a/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py
index 37e95b14f..91d392368 100644
--- a/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py
+++ b/monkey/tests/unit_tests/monkey_island/cc/island_event_handlers/test_handle_agent_registration.py
@@ -8,8 +8,8 @@ import pytest
 
 from common import AgentRegistrationData
 from monkey_island.cc.island_event_handlers import handle_agent_registration
-from monkey_island.cc.models import Machine
-from monkey_island.cc.repository import IMachineRepository, UnknownRecordError
+from monkey_island.cc.models import Agent, Machine
+from monkey_island.cc.repository import IAgentRepository, IMachineRepository, UnknownRecordError
 
 AGENT_ID = UUID("860aff5b-d2af-43ea-afb5-62bac3d30b7e")
 
@@ -36,12 +36,21 @@ def machine_repository() -> IMachineRepository:
     machine_repository = MagicMock(spec=IMachineRepository)
     machine_repository.get_new_id = MagicMock(side_effect=count(SEED_ID))
     machine_repository.upsert_machine = MagicMock()
+    machine_repository.get_machine_by_hardware_id = MagicMock(side_effect=UnknownRecordError)
+    machine_repository.get_machines_by_ip = MagicMock(side_effect=UnknownRecordError)
     return machine_repository
 
 
 @pytest.fixture
-def handler(machine_repository) -> handle_agent_registration:
-    return handle_agent_registration(machine_repository)
+def agent_repository() -> IAgentRepository:
+    agent_repository = MagicMock(spec=IAgentRepository)
+    agent_repository.upsert_agent = MagicMock()
+    return agent_repository
+
+
+@pytest.fixture
+def handler(machine_repository, agent_repository) -> handle_agent_registration:
+    return handle_agent_registration(machine_repository, agent_repository)
 
 
 def test_new_machine_added(handler, machine_repository):
@@ -126,3 +135,16 @@ def test_hardware_id_mismatch(handler, machine_repository):
 
     with pytest.raises(Exception):
         handler(AGENT_REGISTRATION_DATA)
+
+
+def test_add_agent(handler, agent_repository):
+    expected_agent = Agent(
+        id=AGENT_REGISTRATION_DATA.id,
+        machine_id=SEED_ID,
+        start_time=AGENT_REGISTRATION_DATA.start_time,
+        parent_id=AGENT_REGISTRATION_DATA.parent_id,
+        cc_server=AGENT_REGISTRATION_DATA.cc_server,
+    )
+    handler(AGENT_REGISTRATION_DATA)
+
+    agent_repository.upsert_agent.assert_called_with(expected_agent)