Merge pull request #2347 from guardicore/2261-move-agent-signals-to-common

2261 move agent signals to common
This commit is contained in:
Mike Salvatore 2022-09-23 12:26:06 -04:00
commit a49ddf7a4a
10 changed files with 35 additions and 24 deletions

View File

@ -7,3 +7,4 @@ from .operating_system import OperatingSystem
from . import types
from . import base_models
from .agent_registration_data import AgentRegistrationData
from .agent_signals import AgentSignals

View File

@ -1,7 +1,7 @@
from datetime import datetime
from typing import Optional
from common.base_models import InfectionMonkeyBaseModel
from .base_models import InfectionMonkeyBaseModel
class AgentSignals(InfectionMonkeyBaseModel):

View File

@ -1,13 +1,12 @@
import functools
import json
import logging
from datetime import datetime
from pprint import pformat
from typing import List, Optional, Sequence
from typing import List, Sequence
import requests
from common import AgentRegistrationData, OperatingSystem
from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
from common.agent_events import AbstractAgentEvent
@ -189,7 +188,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
@handle_island_errors
@convert_json_error_to_island_api_error
def get_agent_signals(self, agent_id: str) -> Optional[datetime]:
def get_agent_signals(self, agent_id: str) -> AgentSignals:
url = f"{self._api_url}/agent-signals/{agent_id}"
response = requests.get( # noqa: DUO123
url,
@ -197,7 +196,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()
return response.json()["terminate"]
return AgentSignals(**response.json())
class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):

View File

@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, Sequence
from common import AgentRegistrationData, OperatingSystem
from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration
from common.agent_events import AbstractAgentEvent
from common.credentials import Credentials
@ -133,7 +132,7 @@ class IIslandAPIClient(ABC):
"""
@abstractmethod
def get_agent_signals(self, agent_id: str) -> Optional[datetime]:
def get_agent_signals(self, agent_id: str) -> AgentSignals:
"""
Gets an agent's signals from the island
@ -142,5 +141,5 @@ class IIslandAPIClient(ABC):
:raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out
:return: The relevant agent's terminate signal's timestamp
:return: The relevant agent's signals
"""

View File

@ -36,7 +36,8 @@ class ControlChannel(IControlChannel):
if not self._control_channel_server:
logger.error("Agent should stop because it can't connect to the C&C server.")
return True
return self._island_api_client.get_agent_signals(self._agent_id) is not None
agent_signals = self._island_api_client.get_agent_signals(self._agent_id)
return agent_signals.terminate is not None
@handle_island_api_errors
def get_config(self) -> AgentConfiguration:

View File

@ -15,4 +15,3 @@ from .communication_type import CommunicationType
from .node import Node
from common.types import AgentID
from .agent import Agent
from .agent_signals import AgentSignals

View File

@ -2,8 +2,9 @@ import logging
from datetime import datetime
from typing import Optional
from common.agent_signals import AgentSignals
from common.types import AgentID
from monkey_island.cc.models import AgentSignals, Simulation
from monkey_island.cc.models import Simulation
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository
logger = logging.getLogger(__name__)

View File

@ -4,7 +4,7 @@ import pytest
import requests
import requests_mock
from common import OperatingSystem
from common import AgentSignals, OperatingSystem
from common.agent_event_serializers import (
AgentEventSerializerRegistry,
PydanticAgentEventSerializer,
@ -456,16 +456,17 @@ def test_island_api_client_get_agent_signals__status_code(
island_api_client.get_agent_signals(agent_id=AGENT_ID)
@pytest.mark.parametrize("expected_timestamp", [TIMESTAMP, None])
def test_island_api_client_get_agent_signals(island_api_client, expected_timestamp):
@pytest.mark.parametrize("timestamp", [TIMESTAMP, None])
def test_island_api_client_get_agent_signals(island_api_client, timestamp):
expected_agent_signals = AgentSignals(terminate=timestamp)
with requests_mock.Mocker() as m:
m.get(ISLAND_URI)
island_api_client.connect(SERVER)
m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": expected_timestamp})
actual_terminate_timestamp = island_api_client.get_agent_signals(agent_id=AGENT_ID)
m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": timestamp})
actual_agent_signals = island_api_client.get_agent_signals(agent_id=AGENT_ID)
assert actual_terminate_timestamp == expected_timestamp
assert actual_agent_signals == expected_agent_signals
def test_island_api_client_get_agent_signals__bad_json(island_api_client):

View File

@ -1,8 +1,10 @@
from typing import Optional
from unittest.mock import MagicMock
import pytest
from infection_monkey.i_control_channel import IslandCommunicationError
from common import AgentSignals
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.island_api_client import (
IIslandAPIClient,
IslandAPIConnectionError,
@ -33,9 +35,17 @@ def control_channel(island_api_client) -> ControlChannel:
return ControlChannel(SERVER, AGENT_ID, island_api_client)
def test_control_channel__should_agent_stop(control_channel, island_api_client):
control_channel.should_agent_stop()
assert island_api_client.get_agent_signals.called_once()
@pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)])
def test_control_channel__should_agent_stop(
control_channel: IControlChannel,
island_api_client: IIslandAPIClient,
signal_time: Optional[int],
expected_should_stop: bool,
):
island_api_client.get_agent_signals = MagicMock(
return_value=AgentSignals(terminate=signal_time)
)
assert control_channel.should_agent_stop() is expected_should_stop
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)

View File

@ -5,7 +5,7 @@ from uuid import UUID
import pytest
from tests.common import StubDIContainer
from monkey_island.cc.models import AgentSignals as Signals
from common.agent_signals import AgentSignals as Signals
from monkey_island.cc.repository import RetrievalError, StorageError
from monkey_island.cc.services import AgentSignalsService