forked from p15670423/monkey
Agent: Add register_agent to IslandAPIClient
This commit is contained in:
parent
dfa1709064
commit
54ef77698c
|
@ -4,10 +4,14 @@ from typing import List, Sequence
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from common import OperatingSystem
|
from common import AgentRegistrationData, OperatingSystem
|
||||||
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
from common.agent_event_serializers import AgentEventSerializerRegistry, JSONSerializable
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
|
from common.common_consts.timeouts import (
|
||||||
|
LONG_REQUEST_TIMEOUT,
|
||||||
|
MEDIUM_REQUEST_TIMEOUT,
|
||||||
|
SHORT_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
AbstractIslandAPIClientFactory,
|
AbstractIslandAPIClientFactory,
|
||||||
|
@ -116,6 +120,25 @@ class HTTPIslandAPIClient(IIslandAPIClient):
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
def register_agent(self, agent_registration_data: AgentRegistrationData):
|
||||||
|
try:
|
||||||
|
url = f"https://{agent_registration_data.cc_server}/api/agents"
|
||||||
|
response = requests.post( # noqa: DUO123
|
||||||
|
url,
|
||||||
|
json=agent_registration_data.dict(simplify=True),
|
||||||
|
verify=False,
|
||||||
|
timeout=SHORT_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
except (
|
||||||
|
requests.exceptions.ConnectionError,
|
||||||
|
requests.exceptions.TooManyRedirects,
|
||||||
|
requests.exceptions.HTTPError,
|
||||||
|
) as e:
|
||||||
|
raise IslandAPIConnectionError(e)
|
||||||
|
except requests.exceptions.Timeout as e:
|
||||||
|
raise IslandAPITimeoutError(e)
|
||||||
|
|
||||||
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
|
def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
|
||||||
serialized_events: List[JSONSerializable] = []
|
serialized_events: List[JSONSerializable] = []
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,8 @@ from typing import Optional, Sequence
|
||||||
from common import OperatingSystem
|
from common import OperatingSystem
|
||||||
from common.agent_events import AbstractAgentEvent
|
from common.agent_events import AbstractAgentEvent
|
||||||
|
|
||||||
|
from common import AgentRegistrationData
|
||||||
|
|
||||||
|
|
||||||
class IIslandAPIClient(ABC):
|
class IIslandAPIClient(ABC):
|
||||||
"""
|
"""
|
||||||
|
@ -74,7 +76,6 @@ class IIslandAPIClient(ABC):
|
||||||
:raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
|
:raises IslandAPITimeoutError: If a timeout occurs while attempting to connect to the island
|
||||||
:raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the
|
:raises IslandAPIError: If an unexpected error occurs while attempting to retrieve the
|
||||||
agent binary
|
agent binary
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -92,3 +93,14 @@ class IIslandAPIClient(ABC):
|
||||||
:raises IslandAPIError: If an unexpected error occurs while attempting to send events to
|
:raises IslandAPIError: If an unexpected error occurs while attempting to send events to
|
||||||
the island
|
the island
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def register_agent(self, agent_registration_data: AgentRegistrationData):
|
||||||
|
"""
|
||||||
|
Register an agent with the Island
|
||||||
|
|
||||||
|
:param agent_registration_data: Information about the agent to register
|
||||||
|
with the island
|
||||||
|
:raises IslandAPIConnectionError: If the client could not connect to the island
|
||||||
|
:raises IslandAPITimeoutError: If the command timed out
|
||||||
|
"""
|
||||||
|
|
|
@ -13,6 +13,11 @@ from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT
|
||||||
from common.credentials import Credentials
|
from common.credentials import Credentials
|
||||||
from common.network.network_utils import get_network_interfaces
|
from common.network.network_utils import get_network_interfaces
|
||||||
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
|
||||||
|
from infection_monkey.island_api_client import (
|
||||||
|
IIslandAPIClient,
|
||||||
|
IslandAPIConnectionError,
|
||||||
|
IslandAPITimeoutError,
|
||||||
|
)
|
||||||
from infection_monkey.utils import agent_process
|
from infection_monkey.utils import agent_process
|
||||||
from infection_monkey.utils.ids import get_agent_id, get_machine_id
|
from infection_monkey.utils.ids import get_agent_id, get_machine_id
|
||||||
|
|
||||||
|
@ -22,9 +27,10 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ControlChannel(IControlChannel):
|
class ControlChannel(IControlChannel):
|
||||||
def __init__(self, server: str, agent_id: str):
|
def __init__(self, server: str, agent_id: str, api_client: IIslandAPIClient):
|
||||||
self._agent_id = agent_id
|
self._agent_id = agent_id
|
||||||
self._control_channel_server = server
|
self._control_channel_server = server
|
||||||
|
self._island_api_client = api_client
|
||||||
|
|
||||||
def register_agent(self, parent: Optional[UUID] = None):
|
def register_agent(self, parent: Optional[UUID] = None):
|
||||||
agent_registration_data = AgentRegistrationData(
|
agent_registration_data = AgentRegistrationData(
|
||||||
|
@ -38,20 +44,8 @@ class ControlChannel(IControlChannel):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url = f"https://{self._control_channel_server}/api/agents"
|
self._island_api_client.register_agent(agent_registration_data)
|
||||||
response = requests.post( # noqa: DUO123
|
except (IslandAPIConnectionError, IslandAPITimeoutError) as e:
|
||||||
url,
|
|
||||||
json=agent_registration_data.dict(simplify=True),
|
|
||||||
verify=False,
|
|
||||||
timeout=SHORT_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
except (
|
|
||||||
requests.exceptions.ConnectionError,
|
|
||||||
requests.exceptions.Timeout,
|
|
||||||
requests.exceptions.TooManyRedirects,
|
|
||||||
requests.exceptions.HTTPError,
|
|
||||||
) as e:
|
|
||||||
raise IslandCommunicationError(e)
|
raise IslandCommunicationError(e)
|
||||||
|
|
||||||
def should_agent_stop(self) -> bool:
|
def should_agent_stop(self) -> bool:
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from infection_monkey.i_control_channel import IslandCommunicationError
|
||||||
|
from infection_monkey.island_api_client import (
|
||||||
|
IIslandAPIClient,
|
||||||
|
IslandAPIConnectionError,
|
||||||
|
IslandAPITimeoutError,
|
||||||
|
)
|
||||||
|
from infection_monkey.master.control_channel import ControlChannel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def island_api_client() -> IIslandAPIClient:
|
||||||
|
client = MagicMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def control_channel(island_api_client) -> ControlChannel:
|
||||||
|
return ControlChannel("server", "agent-id", island_api_client)
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_channel__register_agent(control_channel, island_api_client):
|
||||||
|
control_channel.register_agent()
|
||||||
|
assert island_api_client.register_agent.called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_channel__register_agent_raises_on_connection_error(
|
||||||
|
control_channel, island_api_client
|
||||||
|
):
|
||||||
|
island_api_client.register_agent.side_effect = IslandAPIConnectionError()
|
||||||
|
|
||||||
|
with pytest.raises(IslandCommunicationError):
|
||||||
|
control_channel.register_agent()
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_channel__register_agent_raises_on_timeout_error(
|
||||||
|
control_channel, island_api_client
|
||||||
|
):
|
||||||
|
island_api_client.register_agent.side_effect = IslandAPITimeoutError()
|
||||||
|
|
||||||
|
with pytest.raises(IslandCommunicationError):
|
||||||
|
control_channel.register_agent()
|
Loading…
Reference in New Issue