Merge pull request #2347 from guardicore/2261-move-agent-signals-to-common

2261 move agent signals to common
This commit is contained in:
Mike Salvatore 2022-09-23 12:26:06 -04:00
commit a49ddf7a4a
10 changed files with 35 additions and 24 deletions

View File

@ -7,3 +7,4 @@ from .operating_system import OperatingSystem
from . import types from . import types
from . import base_models from . import base_models
from .agent_registration_data import AgentRegistrationData from .agent_registration_data import AgentRegistrationData
from .agent_signals import AgentSignals

View File

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from common.base_models import InfectionMonkeyBaseModel from .base_models import InfectionMonkeyBaseModel
class AgentSignals(InfectionMonkeyBaseModel): class AgentSignals(InfectionMonkeyBaseModel):

View File

@ -1,13 +1,12 @@
import functools import functools
import json import json
import logging import logging
from datetime import datetime
from pprint import pformat from pprint import pformat
from typing import List, Optional, Sequence from typing import List, Sequence
import requests import requests
from common import AgentRegistrationData, OperatingSystem from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
@ -189,7 +188,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
@handle_island_errors @handle_island_errors
@convert_json_error_to_island_api_error @convert_json_error_to_island_api_error
def get_agent_signals(self, agent_id: str) -> Optional[datetime]: def get_agent_signals(self, agent_id: str) -> AgentSignals:
url = f"{self._api_url}/agent-signals/{agent_id}" url = f"{self._api_url}/agent-signals/{agent_id}"
response = requests.get( # noqa: DUO123 response = requests.get( # noqa: DUO123
url, url,
@ -197,7 +196,7 @@ class HTTPIslandAPIClient(IIslandAPIClient):
timeout=SHORT_REQUEST_TIMEOUT, timeout=SHORT_REQUEST_TIMEOUT,
) )
response.raise_for_status() response.raise_for_status()
return response.json()["terminate"] return AgentSignals(**response.json())
class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory): class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):

View File

@ -1,8 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, Sequence from typing import Optional, Sequence
from common import AgentRegistrationData, OperatingSystem from common import AgentRegistrationData, AgentSignals, OperatingSystem
from common.agent_configuration import AgentConfiguration from common.agent_configuration import AgentConfiguration
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from common.credentials import Credentials from common.credentials import Credentials
@ -133,7 +132,7 @@ class IIslandAPIClient(ABC):
""" """
@abstractmethod @abstractmethod
def get_agent_signals(self, agent_id: str) -> Optional[datetime]: def get_agent_signals(self, agent_id: str) -> AgentSignals:
""" """
Gets an agent's signals from the island Gets an agent's signals from the island
@ -142,5 +141,5 @@ class IIslandAPIClient(ABC):
:raises IslandAPIRequestError: If there was a problem with the client request :raises IslandAPIRequestError: If there was a problem with the client request
:raises IslandAPIRequestFailedError: If the server experienced an error :raises IslandAPIRequestFailedError: If the server experienced an error
:raises IslandAPITimeoutError: If the command timed out :raises IslandAPITimeoutError: If the command timed out
:return: The relevant agent's terminate signal's timestamp :return: The relevant agent's signals
""" """

View File

@ -36,7 +36,8 @@ class ControlChannel(IControlChannel):
if not self._control_channel_server: if not self._control_channel_server:
logger.error("Agent should stop because it can't connect to the C&C server.") logger.error("Agent should stop because it can't connect to the C&C server.")
return True return True
return self._island_api_client.get_agent_signals(self._agent_id) is not None agent_signals = self._island_api_client.get_agent_signals(self._agent_id)
return agent_signals.terminate is not None
@handle_island_api_errors @handle_island_api_errors
def get_config(self) -> AgentConfiguration: def get_config(self) -> AgentConfiguration:

View File

@ -15,4 +15,3 @@ from .communication_type import CommunicationType
from .node import Node from .node import Node
from common.types import AgentID from common.types import AgentID
from .agent import Agent from .agent import Agent
from .agent_signals import AgentSignals

View File

@ -2,8 +2,9 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from common.agent_signals import AgentSignals
from common.types import AgentID from common.types import AgentID
from monkey_island.cc.models import AgentSignals, Simulation from monkey_island.cc.models import Simulation
from monkey_island.cc.repository import IAgentRepository, ISimulationRepository from monkey_island.cc.repository import IAgentRepository, ISimulationRepository
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -4,7 +4,7 @@ import pytest
import requests import requests
import requests_mock import requests_mock
from common import OperatingSystem from common import AgentSignals, OperatingSystem
from common.agent_event_serializers import ( from common.agent_event_serializers import (
AgentEventSerializerRegistry, AgentEventSerializerRegistry,
PydanticAgentEventSerializer, PydanticAgentEventSerializer,
@ -456,16 +456,17 @@ def test_island_api_client_get_agent_signals__status_code(
island_api_client.get_agent_signals(agent_id=AGENT_ID) island_api_client.get_agent_signals(agent_id=AGENT_ID)
@pytest.mark.parametrize("expected_timestamp", [TIMESTAMP, None]) @pytest.mark.parametrize("timestamp", [TIMESTAMP, None])
def test_island_api_client_get_agent_signals(island_api_client, expected_timestamp): def test_island_api_client_get_agent_signals(island_api_client, timestamp):
expected_agent_signals = AgentSignals(terminate=timestamp)
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get(ISLAND_URI) m.get(ISLAND_URI)
island_api_client.connect(SERVER) island_api_client.connect(SERVER)
m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": expected_timestamp}) m.get(ISLAND_GET_AGENT_SIGNALS, json={"terminate": timestamp})
actual_terminate_timestamp = island_api_client.get_agent_signals(agent_id=AGENT_ID) actual_agent_signals = island_api_client.get_agent_signals(agent_id=AGENT_ID)
assert actual_terminate_timestamp == expected_timestamp assert actual_agent_signals == expected_agent_signals
def test_island_api_client_get_agent_signals__bad_json(island_api_client): def test_island_api_client_get_agent_signals__bad_json(island_api_client):

View File

@ -1,8 +1,10 @@
from typing import Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from infection_monkey.i_control_channel import IslandCommunicationError from common import AgentSignals
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.island_api_client import ( from infection_monkey.island_api_client import (
IIslandAPIClient, IIslandAPIClient,
IslandAPIConnectionError, IslandAPIConnectionError,
@ -33,9 +35,17 @@ def control_channel(island_api_client) -> ControlChannel:
return ControlChannel(SERVER, AGENT_ID, island_api_client) return ControlChannel(SERVER, AGENT_ID, island_api_client)
def test_control_channel__should_agent_stop(control_channel, island_api_client): @pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)])
control_channel.should_agent_stop() def test_control_channel__should_agent_stop(
assert island_api_client.get_agent_signals.called_once() control_channel: IControlChannel,
island_api_client: IIslandAPIClient,
signal_time: Optional[int],
expected_should_stop: bool,
):
island_api_client.get_agent_signals = MagicMock(
return_value=AgentSignals(terminate=signal_time)
)
assert control_channel.should_agent_stop() is expected_should_stop
@pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS) @pytest.mark.parametrize("api_error", CONTROL_CHANNEL_API_ERRORS)

View File

@ -5,7 +5,7 @@ from uuid import UUID
import pytest import pytest
from tests.common import StubDIContainer from tests.common import StubDIContainer
from monkey_island.cc.models import AgentSignals as Signals from common.agent_signals import AgentSignals as Signals
from monkey_island.cc.repository import RetrievalError, StorageError from monkey_island.cc.repository import RetrievalError, StorageError
from monkey_island.cc.services import AgentSignalsService from monkey_island.cc.services import AgentSignalsService