Merge pull request #2353 from guardicore/2323-SocketAddress-in-AgentRegistrationData
SocketAddress in AgentRegistrationData
This commit is contained in:
commit
14999fba4e
|
@ -7,7 +7,7 @@ from pydantic import validator
|
|||
|
||||
from .base_models import InfectionMonkeyBaseModel
|
||||
from .transforms import make_immutable_sequence
|
||||
from .types import HardwareID
|
||||
from .types import HardwareID, SocketAddress
|
||||
|
||||
|
||||
class AgentRegistrationData(InfectionMonkeyBaseModel):
|
||||
|
@ -15,7 +15,7 @@ class AgentRegistrationData(InfectionMonkeyBaseModel):
|
|||
machine_hardware_id: HardwareID
|
||||
start_time: datetime
|
||||
parent_id: Optional[UUID]
|
||||
cc_server: str
|
||||
cc_server: SocketAddress
|
||||
network_interfaces: Sequence[IPv4Interface]
|
||||
|
||||
_make_immutable_sequence = validator("network_interfaces", pre=True, allow_reuse=True)(
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from ipaddress import IPv4Address, IPv4Interface
|
||||
from ipaddress import IPv4Interface
|
||||
from pathlib import Path, WindowsPath
|
||||
from typing import List, Mapping, Optional, Tuple
|
||||
|
||||
|
@ -119,19 +119,15 @@ class InfectionMonkey:
|
|||
self._agent_event_serializer_registry = self._setup_agent_event_serializers()
|
||||
|
||||
server, self._island_api_client = self._connect_to_island_api()
|
||||
# TODO: `address_to_port()` should return the port as an integer.
|
||||
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
|
||||
self._cmd_island_port = int(self._cmd_island_port)
|
||||
|
||||
self._island_address = SocketAddress(
|
||||
IPv4Address(self._cmd_island_ip), self._cmd_island_port
|
||||
)
|
||||
self._island_address = SocketAddress(self._cmd_island_ip, self._cmd_island_port)
|
||||
|
||||
self._control_client = ControlClient(
|
||||
server_address=server, island_api_client=self._island_api_client
|
||||
)
|
||||
self._control_channel = ControlChannel(server, get_agent_id(), self._island_api_client)
|
||||
self._register_agent(server)
|
||||
self._register_agent(self._island_address)
|
||||
|
||||
# TODO Refactor the telemetry messengers to accept control client
|
||||
# and remove control_client_object
|
||||
|
@ -180,7 +176,7 @@ class InfectionMonkey:
|
|||
|
||||
return server, island_api_client
|
||||
|
||||
def _register_agent(self, server: str):
|
||||
def _register_agent(self, server: SocketAddress):
|
||||
agent_registration_data = AgentRegistrationData(
|
||||
id=get_agent_id(),
|
||||
machine_hardware_id=get_machine_id(),
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
import logging
|
||||
import socket
|
||||
from contextlib import suppress
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Iterable, Iterator, Optional
|
||||
|
||||
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
|
||||
from common.network.network_utils import address_to_ip_port
|
||||
from common.types import SocketAddress
|
||||
from infection_monkey.island_api_client import (
|
||||
AbstractIslandAPIClientFactory,
|
||||
|
@ -81,20 +79,15 @@ def _check_if_island_server(
|
|||
|
||||
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
|
||||
for i, server in enumerate(servers, start=1):
|
||||
server_address = SocketAddress.from_string(server)
|
||||
t = create_daemon_thread(
|
||||
target=_send_remove_from_waitlist_control_message_to_relay,
|
||||
target=notify_disconnect,
|
||||
name=f"SendRemoveFromWaitlistControlMessageToRelaysThread-{i:02d}",
|
||||
args=(server,),
|
||||
args=(server_address,),
|
||||
)
|
||||
t.start()
|
||||
|
||||
|
||||
def _send_remove_from_waitlist_control_message_to_relay(server: str):
|
||||
ip, port = address_to_ip_port(server)
|
||||
server_address = SocketAddress(IPv4Address(ip), int(port))
|
||||
notify_disconnect(server_address)
|
||||
|
||||
|
||||
def notify_disconnect(server_address: SocketAddress):
|
||||
"""
|
||||
Tell upstream relay that we no longer need the relay
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from contextlib import suppress
|
||||
from ipaddress import IPv4Address, IPv4Interface
|
||||
from ipaddress import IPv4Interface
|
||||
from typing import List, Optional
|
||||
|
||||
from common import AgentRegistrationData
|
||||
from common.network.network_utils import address_to_ip_port
|
||||
from common.types import SocketAddress
|
||||
from monkey_island.cc.models import Agent, CommunicationType, Machine
|
||||
from monkey_island.cc.repository import (
|
||||
IAgentRepository,
|
||||
|
@ -116,8 +116,8 @@ class handle_agent_registration:
|
|||
src_machine.id, dst_machine.id, CommunicationType.CC
|
||||
)
|
||||
|
||||
def _get_or_create_cc_machine(self, cc_server: str) -> Machine:
|
||||
dst_ip = IPv4Address(address_to_ip_port(cc_server)[0])
|
||||
def _get_or_create_cc_machine(self, cc_server: SocketAddress) -> Machine:
|
||||
dst_ip = cc_server.ip
|
||||
|
||||
try:
|
||||
return self._machine_repository.get_machines_by_ip(dst_ip)[0]
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
|||
from pydantic import Field
|
||||
|
||||
from common.base_models import MutableInfectionMonkeyBaseModel
|
||||
from common.types import SocketAddress
|
||||
|
||||
from . import AgentID, MachineID
|
||||
|
||||
|
@ -26,7 +27,7 @@ class Agent(MutableInfectionMonkeyBaseModel):
|
|||
parent_id: Optional[AgentID] = Field(allow_mutation=False)
|
||||
"""The ID of the parent agent that spawned this agent"""
|
||||
|
||||
cc_server: str = Field(default="")
|
||||
cc_server: Optional[SocketAddress]
|
||||
"""The address that the agent used to communicate with the island"""
|
||||
|
||||
log_contents: str = Field(default="")
|
||||
|
|
|
@ -11,6 +11,7 @@ from common.agent_event_serializers import (
|
|||
)
|
||||
from common.agent_events import AbstractAgentEvent
|
||||
from common.agent_registration_data import AgentRegistrationData
|
||||
from common.types import SocketAddress
|
||||
from infection_monkey.island_api_client import (
|
||||
HTTPIslandAPIClient,
|
||||
IslandAPIConnectionError,
|
||||
|
@ -20,7 +21,7 @@ from infection_monkey.island_api_client import (
|
|||
IslandAPITimeoutError,
|
||||
)
|
||||
|
||||
SERVER = "1.1.1.1:9999"
|
||||
SERVER = SocketAddress(ip="1.1.1.1", port="9999")
|
||||
PBA_FILE = "dummy.pba"
|
||||
WINDOWS = "windows"
|
||||
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")
|
||||
|
|
|
@ -7,6 +7,7 @@ from uuid import UUID
|
|||
import pytest
|
||||
|
||||
from common import AgentRegistrationData
|
||||
from common.types import SocketAddress
|
||||
from monkey_island.cc.island_event_handlers import handle_agent_registration
|
||||
from monkey_island.cc.models import Agent, CommunicationType, Machine
|
||||
from monkey_island.cc.repository import (
|
||||
|
@ -26,12 +27,14 @@ MACHINE = Machine(
|
|||
network_interfaces=[IPv4Interface("192.168.2.2/24")],
|
||||
)
|
||||
|
||||
IP = "192.168.1.1:5000"
|
||||
|
||||
AGENT_REGISTRATION_DATA = AgentRegistrationData(
|
||||
id=AGENT_ID,
|
||||
machine_hardware_id=MACHINE.hardware_id,
|
||||
start_time=0,
|
||||
parent_id=None,
|
||||
cc_server="192.168.1.1:5000",
|
||||
cc_server=SocketAddress.from_string(IP),
|
||||
network_interfaces=[IPv4Interface("192.168.1.2/24")],
|
||||
)
|
||||
|
||||
|
@ -111,7 +114,7 @@ def test_existing_machine_updated__find_by_ip(handler, machine_repository):
|
|||
machine_hardware_id=5,
|
||||
start_time=0,
|
||||
parent_id=None,
|
||||
cc_server="192.168.1.1:5000",
|
||||
cc_server=SocketAddress(ip="192.168.1.1", port="5000"),
|
||||
network_interfaces=[
|
||||
IPv4Interface("192.168.1.2/24"),
|
||||
IPv4Interface("192.168.1.4/24"),
|
||||
|
@ -215,7 +218,7 @@ def test_machine_interfaces_updated(handler, machine_repository):
|
|||
machine_hardware_id=MACHINE.hardware_id,
|
||||
start_time=0,
|
||||
parent_id=None,
|
||||
cc_server="192.168.1.1:5000",
|
||||
cc_server=SocketAddress(ip="192.168.1.1", port="5000"),
|
||||
network_interfaces=[
|
||||
IPv4Interface("192.168.1.2/24"),
|
||||
IPv4Interface("192.168.1.3/16"),
|
||||
|
|
|
@ -27,7 +27,7 @@ def test_constructor__defaults_from_objects():
|
|||
a = Agent(**AGENT_OBJECT_DICT)
|
||||
|
||||
assert a.stop_time is None
|
||||
assert a.cc_server == ""
|
||||
assert a.cc_server is None
|
||||
|
||||
|
||||
def test_constructor__defaults_from_simple_dict():
|
||||
|
@ -37,7 +37,7 @@ def test_constructor__defaults_from_simple_dict():
|
|||
|
||||
assert a.parent_id is None
|
||||
assert a.stop_time is None
|
||||
assert a.cc_server == ""
|
||||
assert a.cc_server is None
|
||||
assert a.log_contents == ""
|
||||
|
||||
|
||||
|
@ -45,7 +45,7 @@ def test_to_dict():
|
|||
a = Agent(**AGENT_OBJECT_DICT)
|
||||
agent_simple_dict = AGENT_SIMPLE_DICT.copy()
|
||||
agent_simple_dict["stop_time"] = None
|
||||
agent_simple_dict["cc_server"] = ""
|
||||
agent_simple_dict["cc_server"] = None
|
||||
agent_simple_dict["log_contents"] = ""
|
||||
|
||||
assert a.dict(simplify=True) == agent_simple_dict
|
||||
|
@ -59,7 +59,7 @@ def test_to_dict():
|
|||
("start_time", None),
|
||||
("stop_time", []),
|
||||
("parent_id", 2.1),
|
||||
("cc_server", []),
|
||||
("cc_server", [1]),
|
||||
("log_contents", None),
|
||||
],
|
||||
)
|
||||
|
@ -77,6 +77,7 @@ def test_construct_invalid_field__type_error(key, value):
|
|||
("machine_id", -1),
|
||||
("start_time", "not-a-datetime"),
|
||||
("stop_time", "not-a-datetime"),
|
||||
("cc_server", []),
|
||||
],
|
||||
)
|
||||
def test_construct_invalid_field__value_error(key, value):
|
||||
|
@ -126,7 +127,7 @@ def test_cc_server_set_validated():
|
|||
a = Agent(**AGENT_SIMPLE_DICT)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
a.cc_server = None
|
||||
a.cc_server = []
|
||||
|
||||
|
||||
def test_log_contents_set_validated():
|
||||
|
|
|
@ -6,16 +6,18 @@ from uuid import UUID
|
|||
import pytest
|
||||
|
||||
from common import AgentRegistrationData
|
||||
from common.types import SocketAddress
|
||||
|
||||
AGENT_ID = UUID("012e7238-7b81-4108-8c7f-0787bc3f3c10")
|
||||
PARENT_ID = UUID("0fc9afcb-1902-436b-bd5c-1ad194252484")
|
||||
SOCKET_ADDRESS = SocketAddress(ip="192.168.1.1", port="5000")
|
||||
|
||||
AGENT_REGISTRATION_MESSAGE_OBJECT_DICT = {
|
||||
"id": AGENT_ID,
|
||||
"machine_hardware_id": 2,
|
||||
"start_time": datetime.fromtimestamp(1660848408, tz=timezone.utc),
|
||||
"parent_id": PARENT_ID,
|
||||
"cc_server": "192.168.1.1:5000",
|
||||
"cc_server": SOCKET_ADDRESS,
|
||||
"network_interfaces": [IPv4Interface("10.0.0.1/24"), IPv4Interface("192.168.5.32/16")],
|
||||
}
|
||||
|
||||
|
@ -24,7 +26,7 @@ AGENT_REGISTRATION_MESSAGE_SIMPLE_DICT = {
|
|||
"machine_hardware_id": 2,
|
||||
"start_time": "2022-08-18T18:46:48+00:00",
|
||||
"parent_id": str(PARENT_ID),
|
||||
"cc_server": "192.168.1.1:5000",
|
||||
"cc_server": SOCKET_ADDRESS.dict(simplify=True),
|
||||
"network_interfaces": ["10.0.0.1/24", "192.168.5.32/16"],
|
||||
}
|
||||
|
||||
|
@ -50,7 +52,7 @@ def test_from_serialized():
|
|||
("machine_hardware_id", "not-an-int"),
|
||||
("start_time", None),
|
||||
("parent_id", 2.1),
|
||||
("cc_server", []),
|
||||
("cc_server", [1]),
|
||||
("network_interfaces", "not-a-list"),
|
||||
],
|
||||
)
|
||||
|
@ -68,6 +70,7 @@ def test_construct_invalid_field__type_error(key, value):
|
|||
("machine_hardware_id", -1),
|
||||
("start_time", "not-a-date-time"),
|
||||
("network_interfaces", [1, "stuff", 3]),
|
||||
("cc_server", []),
|
||||
],
|
||||
)
|
||||
def test_construct_invalid_field__value_error(key, value):
|
||||
|
@ -85,7 +88,7 @@ def test_construct_invalid_field__value_error(key, value):
|
|||
("machine_hardware_id", 99),
|
||||
("start_time", 0),
|
||||
("parent_id", AGENT_ID),
|
||||
("cc_server", "10.0.0.1:4999"),
|
||||
("cc_server", SOCKET_ADDRESS),
|
||||
("network_interfaces", ["10.0.0.1/24"]),
|
||||
],
|
||||
)
|
||||
|
|
|
@ -6,6 +6,7 @@ from uuid import uuid4
|
|||
import mongomock
|
||||
import pytest
|
||||
|
||||
from common.types import SocketAddress
|
||||
from monkey_island.cc.models import Agent
|
||||
from monkey_island.cc.repository import (
|
||||
IAgentRepository,
|
||||
|
@ -54,6 +55,7 @@ AGENTS = (
|
|||
*RUNNING_AGENTS,
|
||||
*STOPPED_AGENTS,
|
||||
)
|
||||
CC_SERVER = SocketAddress(ip="127.0.0.1", port="1984")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -116,7 +118,7 @@ def test_upsert_agent__insert_empty_repository(empty_agent_repository):
|
|||
def test_upsert_agent__update(agent_repository):
|
||||
agents = deepcopy(AGENTS)
|
||||
agents[0].stop_time = datetime.now()
|
||||
agents[0].cc_server = "127.0.0.1:1984"
|
||||
agents[0].cc_server = CC_SERVER
|
||||
|
||||
agent_repository.upsert_agent(agents[0])
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ AGENT_REGISTRATION_DICT = {
|
|||
"machine_hardware_id": 1,
|
||||
"start_time": 0,
|
||||
"parent_id": UUID("9d55ba33-95c2-417d-bd86-d3d11e47daeb"),
|
||||
"cc_server": "10.0.0.1:5000",
|
||||
"cc_server": {"ip": "10.0.0.1", "port": "5000"},
|
||||
"network_interfaces": ["10.1.1.2/24"],
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue