Merge pull request #2347 from guardicore/2261-move-agent-signals-to-common
2261 move agent signals to common
This commit is contained in:
commit
a49ddf7a4a
|
@ -7,3 +7,4 @@ from .operating_system import OperatingSystem
|
|||
from . import types
|
||||
from . import base_models
|
||||
from .agent_registration_data import AgentRegistrationData
|
||||
from .agent_signals import AgentSignals
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from common.base_models import InfectionMonkeyBaseModel
|
||||
from .base_models import InfectionMonkeyBaseModel
|
||||
|
||||
|
||||
class AgentSignals(InfectionMonkeyBaseModel):
|
|
@ -1,13 +1,12 @@
|
|||
import functools
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pprint import pformat
|
||||
from typing import List, Optional, Sequence
|
||||
from typing import List, Sequence
|
||||
|
||||
import requests
|
||||
|
||||
from common import AgentRegistrationData, OperatingSystem
|
||||
from common import AgentRegistrationData, AgentSignals, OperatingSystem
|
||||
from common.agent_configuration import AgentConfiguration
|
||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
|
@ -189,7 +188,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
|||
|
||||
@handle_island_errors
|
||||
@convert_json_error_to_island_api_error
|
||||
def get_agent_signals(self, agent_id: str) -> Optional[datetime]:
|
||||
def get_agent_signals(self, agent_id: str) -> AgentSignals:
|
||||
url = f"{self._api_url}/agent-signals/{agent_id}"
|
||||
response = requests.get( # noqa: DUO123
|
||||
url,
|
||||
|
@ -197,7 +196,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
|||
timeout=SHORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["terminate"]
|
||||
return AgentSignals(**response.json())
|
||||
|
||||
|
||||
class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from common import AgentRegistrationData, OperatingSystem
|
||||
from common import AgentRegistrationData, AgentSignals, OperatingSystem
|
||||
from common.agent_configuration import AgentConfiguration
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from common.credentials import Credentials
|
||||
|
@ -133,7 +132,7 @@ class IIslandAPIClient(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_agent_signals(self, agent_id: str) -> Optional[datetime]:
|
||||
def get_agent_signals(self, agent_id: str) -> AgentSignals:
|
||||
"""
|
||||
Gets an agent's signals from the island
|
||||
|
||||
|
@ -142,5 +141,5 @@ class IIslandAPIClient(ABC):
|
|||
:raises IslandAPIRequestError: If there was a problem with the client request
|
||||
:raises IslandAPIRequestFailedError: If the server experienced an error
|
||||
:raises IslandAPITimeoutError: If the command timed out
|
||||
:return: The relevant agent's terminate signal's timestamp
|
||||
:return: The relevant agent's signals
|
||||
"""
|
||||
|
|
|
@ -36,7 +36,8 @@ class ControlChannel(IControlChannel):
|
|||
if not self._control_channel_server:
|
||||
logger.error("Agent should stop because it can't connect to the C&C server.")
|
||||
return True
|
||||
return self._island_api_client.get_agent_signals(self._agent_id) is not None
|
||||
agent_signals = self._island_api_client.get_agent_signals(self._agent_id)
|
||||
return agent_signals.terminate is not None
|
||||
|
||||
@handle_island_api_errors
|
||||
def get_config(self) -> AgentConfiguration:
|
||||
|
|
|
@ -15,4 +15,3 @@ from .communication_type import CommunicationType
|
|||
from .node import Node
|
||||
from common.types import AgentID
|
||||
from .agent import Agent
|
||||
from .agent_signals import AgentSignals
|
||||
|
|
|
@ -2,8 +2,9 @@ import logging
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from common.agent_signals import AgentSignals
|
||||
from common.types import AgentID
|
||||
from monkey_island.cc.models import AgentSignals, Simulation
|
||||
from monkey_island.cc.models import Simulation
|
||||
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -4,7 +4,7 @@ import pytest
|
|||
import requests
|
||||
import requests_mock
|
||||
|
||||
from common import OperatingSystem
|
||||
from common import AgentSignals, OperatingSystem
|
||||
from common.agent_event_serializers import (
|
||||
AgentEventSerializerRegistry,
|
||||
PydanticAgentEventSerializer,
|
||||
|
@ -456,16 +456,17 @@ def test_island_api_client_get_agent_signals__status_code(
|
|||
island_api_client.get_agent_signals(agent_id=AGENT_ID)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("expected_timestamp", [TIMESTAMP, None])
|
||||
def test_island_api_client_get_agent_signals(island_api_client, expected_timestamp):
|
||||
@pytest.mark.parametrize("timestamp", [TIMESTAMP, None])
|
||||
def test_island_api_client_get_agent_signals(island_api_client, timestamp):
|
||||
expected_agent_signals = AgentSignals(terminate=timestamp)
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(ISLAND_URI)
|
||||
island_api_client.connect(SERVER)
|
||||
|
||||
m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": expected_timestamp})
|
||||
actual_terminate_timestamp = island_api_client.get_agent_signals(agent_id=AGENT_ID)
|
||||
m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": timestamp})
|
||||
actual_agent_signals = island_api_client.get_agent_signals(agent_id=AGENT_ID)
|
||||
|
||||
assert actual_terminate_timestamp == expected_timestamp
|
||||
assert actual_agent_signals == expected_agent_signals
|
||||
|
||||
|
||||
def test_island_api_client_get_agent_signals__bad_json(island_api_client):
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from infection_monkey.i_control_channel import IslandCommunicationError
|
||||
from common import AgentSignals
|
||||
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
||||
from infection_monkey.island_api_client import (
|
||||
IIslandAPIClient,
|
||||
IslandAPIConnectionError,
|
||||
|
@ -33,9 +35,17 @@ def control_channel(island_api_client) -> ControlChannel:
|
|||
return ControlChannel(SERVER, AGENT_ID, island_api_client)
|
||||
|
||||
|
||||
def test_control_channel__should_agent_stop(control_channel, island_api_client):
|
||||
control_channel.should_agent_stop()
|
||||
assert island_api_client.get_agent_signals.called_once()
|
||||
@pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)])
|
||||
def test_control_channel__should_agent_stop(
|
||||
control_channel: IControlChannel,
|
||||
island_api_client: IIslandAPIClient,
|
||||
signal_time: Optional[int],
|
||||
expected_should_stop: bool,
|
||||
):
|
||||
island_api_client.get_agent_signals = MagicMock(
|
||||
return_value=AgentSignals(terminate=signal_time)
|
||||
)
|
||||
assert control_channel.should_agent_stop() is expected_should_stop
|
||||
|
||||
|
||||
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)
|
||||
|
|
|
@ -5,7 +5,7 @@ from uuid import UUID
|
|||
import pytest
|
||||
from tests.common import StubDIContainer
|
||||
|
||||
from monkey_island.cc.models import AgentSignals as Signals
|
||||
from common.agent_signals import AgentSignals as Signals
|
||||
from monkey_island.cc.repository import RetrievalError, StorageError
|
||||
from monkey_island.cc.services import AgentSignalsService
|
||||
|
||||
|
|
Loading…
Reference in New Issue