Merge pull request #2353 from guardicore/2323-SocketAddress-in-AgentRegistrationData

SocketAddress in AgentRegistrationData
This commit is contained in:
Shreya Malviya 2022-09-27 13:56:16 +05:30 committed by GitHub
commit 14999fba4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 40 additions and 40 deletions

View File

@ -7,7 +7,7 @@ from pydantic import validator
from .base_models import InfectionMonkeyBaseModel from .base_models import InfectionMonkeyBaseModel
from .transforms import make_immutable_sequence from .transforms import make_immutable_sequence
from .types import HardwareID from .types import HardwareID, SocketAddress
class AgentRegistrationData(InfectionMonkeyBaseModel): class AgentRegistrationData(InfectionMonkeyBaseModel):
@ -15,7 +15,7 @@ class AgentRegistrationData(InfectionMonkeyBaseModel):
machine_hardware_id: HardwareID machine_hardware_id: HardwareID
start_time: datetime start_time: datetime
parent_id: Optional[UUID] parent_id: Optional[UUID]
cc_server: str cc_server: SocketAddress
network_interfaces: Sequence[IPv4Interface] network_interfaces: Sequence[IPv4Interface]
_make_immutable_sequence = validator("network_interfaces", pre=True, allow_reuse=True)( _make_immutable_sequence = validator("network_interfaces", pre=True, allow_reuse=True)(

View File

@ -3,7 +3,7 @@ import logging
import os import os
import subprocess import subprocess
import sys import sys
from ipaddress import IPv4Address, IPv4Interface from ipaddress import IPv4Interface
from pathlib import Path, WindowsPath from pathlib import Path, WindowsPath
from typing import List, Mapping, Optional, Tuple from typing import List, Mapping, Optional, Tuple
@ -119,19 +119,15 @@ class InfectionMonkey:
self._agent_event_serializer_registry = self._setup_agent_event_serializers() self._agent_event_serializer_registry = self._setup_agent_event_serializers()
server, self._island_api_client = self._connect_to_island_api() server, self._island_api_client = self._connect_to_island_api()
# TODO: `address_to_port()` should return the port as an integer.
self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server) self._cmd_island_ip, self._cmd_island_port = address_to_ip_port(server)
self._cmd_island_port = int(self._cmd_island_port)
self._island_address = SocketAddress( self._island_address = SocketAddress(self._cmd_island_ip, self._cmd_island_port)
IPv4Address(self._cmd_island_ip), self._cmd_island_port
)
self._control_client = ControlClient( self._control_client = ControlClient(
server_address=server, island_api_client=self._island_api_client server_address=server, island_api_client=self._island_api_client
) )
self._control_channel = ControlChannel(server, get_agent_id(), self._island_api_client) self._control_channel = ControlChannel(server, get_agent_id(), self._island_api_client)
self._register_agent(server) self._register_agent(self._island_address)
# TODO Refactor the telemetry messengers to accept control client # TODO Refactor the telemetry messengers to accept control client
# and remove control_client_object # and remove control_client_object
@ -180,7 +176,7 @@ class InfectionMonkey:
return server, island_api_client return server, island_api_client
def _register_agent(self, server: str): def _register_agent(self, server: SocketAddress):
agent_registration_data = AgentRegistrationData( agent_registration_data = AgentRegistrationData(
id=get_agent_id(), id=get_agent_id(),
machine_hardware_id=get_machine_id(), machine_hardware_id=get_machine_id(),

View File

@ -1,11 +1,9 @@
import logging import logging
import socket import socket
from contextlib import suppress from contextlib import suppress
from ipaddress import IPv4Address
from typing import Dict, Iterable, Iterator, Optional from typing import Dict, Iterable, Iterator, Optional
from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT
from common.network.network_utils import address_to_ip_port
from common.types import SocketAddress from common.types import SocketAddress
from infection_monkey.island_api_client import ( from infection_monkey.island_api_client import (
AbstractIslandAPIClientFactory, AbstractIslandAPIClientFactory,
@ -81,20 +79,15 @@ def _check_if_island_server(
def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]): def send_remove_from_waitlist_control_message_to_relays(servers: Iterable[str]):
for i, server in enumerate(servers, start=1): for i, server in enumerate(servers, start=1):
server_address = SocketAddress.from_string(server)
t = create_daemon_thread( t = create_daemon_thread(
target=_send_remove_from_waitlist_control_message_to_relay, target=notify_disconnect,
name=f"SendRemoveFromWaitlistControlMessageToRelaysThread-{i:02d}", name=f"SendRemoveFromWaitlistControlMessageToRelaysThread-{i:02d}",
args=(server,), args=(server_address,),
) )
t.start() t.start()
def _send_remove_from_waitlist_control_message_to_relay(server: str):
ip, port = address_to_ip_port(server)
server_address = SocketAddress(IPv4Address(ip), int(port))
notify_disconnect(server_address)
def notify_disconnect(server_address: SocketAddress): def notify_disconnect(server_address: SocketAddress):
""" """
Tell upstream relay that we no longer need the relay Tell upstream relay that we no longer need the relay

View File

@ -1,9 +1,9 @@
from contextlib import suppress from contextlib import suppress
from ipaddress import IPv4Address, IPv4Interface from ipaddress import IPv4Interface
from typing import List, Optional from typing import List, Optional
from common import AgentRegistrationData from common import AgentRegistrationData
from common.network.network_utils import address_to_ip_port from common.types import SocketAddress
from monkey_island.cc.models import Agent, CommunicationType, Machine from monkey_island.cc.models import Agent, CommunicationType, Machine
from monkey_island.cc.repository import ( from monkey_island.cc.repository import (
IAgentRepository, IAgentRepository,
@ -116,8 +116,8 @@ class handle_agent_registration:
src_machine.id, dst_machine.id, CommunicationType.CC src_machine.id, dst_machine.id, CommunicationType.CC
) )
def _get_or_create_cc_machine(self, cc_server: str) -> Machine: def _get_or_create_cc_machine(self, cc_server: SocketAddress) -> Machine:
dst_ip = IPv4Address(address_to_ip_port(cc_server)[0]) dst_ip = cc_server.ip
try: try:
return self._machine_repository.get_machines_by_ip(dst_ip)[0] return self._machine_repository.get_machines_by_ip(dst_ip)[0]

View File

@ -4,6 +4,7 @@ from typing import Optional
from pydantic import Field from pydantic import Field
from common.base_models import MutableInfectionMonkeyBaseModel from common.base_models import MutableInfectionMonkeyBaseModel
from common.types import SocketAddress
from . import AgentID, MachineID from . import AgentID, MachineID
@ -26,7 +27,7 @@ class Agent(MutableInfectionMonkeyBaseModel):
parent_id: Optional[AgentID] = Field(allow_mutation=False) parent_id: Optional[AgentID] = Field(allow_mutation=False)
"""The ID of the parent agent that spawned this agent""" """The ID of the parent agent that spawned this agent"""
cc_server: str = Field(default="") cc_server: Optional[SocketAddress]
"""The address that the agent used to communicate with the island""" """The address that the agent used to communicate with the island"""
log_contents: str = Field(default="") log_contents: str = Field(default="")

View File

@ -11,6 +11,7 @@ from common.agent_event_serializers import (
) )
from common.agent_events import AbstractAgentEvent from common.agent_events import AbstractAgentEvent
from common.agent_registration_data import AgentRegistrationData from common.agent_registration_data import AgentRegistrationData
from common.types import SocketAddress
from infection_monkey.island_api_client import ( from infection_monkey.island_api_client import (
HTTPIslandAPIClient, HTTPIslandAPIClient,
IslandAPIConnectionError, IslandAPIConnectionError,
@ -20,7 +21,7 @@ from infection_monkey.island_api_client import (
IslandAPITimeoutError, IslandAPITimeoutError,
) )
SERVER = "1.1.1.1:9999" SERVER = SocketAddress(ip="1.1.1.1", port="9999")
PBA_FILE = "dummy.pba" PBA_FILE = "dummy.pba"
WINDOWS = "windows" WINDOWS = "windows"
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")

View File

@ -7,6 +7,7 @@ from uuid import UUID
import pytest import pytest
from common import AgentRegistrationData from common import AgentRegistrationData
from common.types import SocketAddress
from monkey_island.cc.island_event_handlers import handle_agent_registration from monkey_island.cc.island_event_handlers import handle_agent_registration
from monkey_island.cc.models import Agent, CommunicationType, Machine from monkey_island.cc.models import Agent, CommunicationType, Machine
from monkey_island.cc.repository import ( from monkey_island.cc.repository import (
@ -26,12 +27,14 @@ MACHINE = Machine(
network_interfaces=[IPv4Interface("192.168.2.2/24")], network_interfaces=[IPv4Interface("192.168.2.2/24")],
) )
IP = "192.168.1.1:5000"
AGENT_REGISTRATION_DATA = AgentRegistrationData( AGENT_REGISTRATION_DATA = AgentRegistrationData(
id=AGENT_ID, id=AGENT_ID,
machine_hardware_id=MACHINE.hardware_id, machine_hardware_id=MACHINE.hardware_id,
start_time=0, start_time=0,
parent_id=None, parent_id=None,
cc_server="192.168.1.1:5000", cc_server=SocketAddress.from_string(IP),
network_interfaces=[IPv4Interface("192.168.1.2/24")], network_interfaces=[IPv4Interface("192.168.1.2/24")],
) )
@ -111,7 +114,7 @@ def test_existing_machine_updated__find_by_ip(handler, machine_repository):
machine_hardware_id=5, machine_hardware_id=5,
start_time=0, start_time=0,
parent_id=None, parent_id=None,
cc_server="192.168.1.1:5000", cc_server=SocketAddress(ip="192.168.1.1", port="5000"),
network_interfaces=[ network_interfaces=[
IPv4Interface("192.168.1.2/24"), IPv4Interface("192.168.1.2/24"),
IPv4Interface("192.168.1.4/24"), IPv4Interface("192.168.1.4/24"),
@ -215,7 +218,7 @@ def test_machine_interfaces_updated(handler, machine_repository):
machine_hardware_id=MACHINE.hardware_id, machine_hardware_id=MACHINE.hardware_id,
start_time=0, start_time=0,
parent_id=None, parent_id=None,
cc_server="192.168.1.1:5000", cc_server=SocketAddress(ip="192.168.1.1", port="5000"),
network_interfaces=[ network_interfaces=[
IPv4Interface("192.168.1.2/24"), IPv4Interface("192.168.1.2/24"),
IPv4Interface("192.168.1.3/16"), IPv4Interface("192.168.1.3/16"),

View File

@ -27,7 +27,7 @@ def test_constructor__defaults_from_objects():
a = Agent(**AGENT_OBJECT_DICT) a = Agent(**AGENT_OBJECT_DICT)
assert a.stop_time is None assert a.stop_time is None
assert a.cc_server == "" assert a.cc_server is None
def test_constructor__defaults_from_simple_dict(): def test_constructor__defaults_from_simple_dict():
@ -37,7 +37,7 @@ def test_constructor__defaults_from_simple_dict():
assert a.parent_id is None assert a.parent_id is None
assert a.stop_time is None assert a.stop_time is None
assert a.cc_server == "" assert a.cc_server is None
assert a.log_contents == "" assert a.log_contents == ""
@ -45,7 +45,7 @@ def test_to_dict():
a = Agent(**AGENT_OBJECT_DICT) a = Agent(**AGENT_OBJECT_DICT)
agent_simple_dict = AGENT_SIMPLE_DICT.copy() agent_simple_dict = AGENT_SIMPLE_DICT.copy()
agent_simple_dict["stop_time"] = None agent_simple_dict["stop_time"] = None
agent_simple_dict["cc_server"] = "" agent_simple_dict["cc_server"] = None
agent_simple_dict["log_contents"] = "" agent_simple_dict["log_contents"] = ""
assert a.dict(simplify=True) == agent_simple_dict assert a.dict(simplify=True) == agent_simple_dict
@ -59,7 +59,7 @@ def test_to_dict():
("start_time", None), ("start_time", None),
("stop_time", []), ("stop_time", []),
("parent_id", 2.1), ("parent_id", 2.1),
("cc_server", []), ("cc_server", [1]),
("log_contents", None), ("log_contents", None),
], ],
) )
@ -77,6 +77,7 @@ def test_construct_invalid_field__type_error(key, value):
("machine_id", -1), ("machine_id", -1),
("start_time", "not-a-datetime"), ("start_time", "not-a-datetime"),
("stop_time", "not-a-datetime"), ("stop_time", "not-a-datetime"),
("cc_server", []),
], ],
) )
def test_construct_invalid_field__value_error(key, value): def test_construct_invalid_field__value_error(key, value):
@ -126,7 +127,7 @@ def test_cc_server_set_validated():
a = Agent(**AGENT_SIMPLE_DICT) a = Agent(**AGENT_SIMPLE_DICT)
with pytest.raises(ValueError): with pytest.raises(ValueError):
a.cc_server = None a.cc_server = []
def test_log_contents_set_validated(): def test_log_contents_set_validated():

View File

@ -6,16 +6,18 @@ from uuid import UUID
import pytest import pytest
from common import AgentRegistrationData from common import AgentRegistrationData
from common.types import SocketAddress
AGENT_ID = UUID("012e7238-7b81-4108-8c7f-0787bc3f3c10") AGENT_ID = UUID("012e7238-7b81-4108-8c7f-0787bc3f3c10")
PARENT_ID = UUID("0fc9afcb-1902-436b-bd5c-1ad194252484") PARENT_ID = UUID("0fc9afcb-1902-436b-bd5c-1ad194252484")
SOCKET_ADDRESS = SocketAddress(ip="192.168.1.1", port="5000")
AGENT_REGISTRATION_MESSAGE_OBJECT_DICT = { AGENT_REGISTRATION_MESSAGE_OBJECT_DICT = {
"id": AGENT_ID, "id": AGENT_ID,
"machine_hardware_id": 2, "machine_hardware_id": 2,
"start_time": datetime.fromtimestamp(1660848408, tz=timezone.utc), "start_time": datetime.fromtimestamp(1660848408, tz=timezone.utc),
"parent_id": PARENT_ID, "parent_id": PARENT_ID,
"cc_server": "192.168.1.1:5000", "cc_server": SOCKET_ADDRESS,
"network_interfaces": [IPv4Interface("10.0.0.1/24"), IPv4Interface("192.168.5.32/16")], "network_interfaces": [IPv4Interface("10.0.0.1/24"), IPv4Interface("192.168.5.32/16")],
} }
@ -24,7 +26,7 @@ AGENT_REGISTRATION_MESSAGE_SIMPLE_DICT = {
"machine_hardware_id": 2, "machine_hardware_id": 2,
"start_time": "2022-08-18T18:46:48+00:00", "start_time": "2022-08-18T18:46:48+00:00",
"parent_id": str(PARENT_ID), "parent_id": str(PARENT_ID),
"cc_server": "192.168.1.1:5000", "cc_server": SOCKET_ADDRESS.dict(simplify=True),
"network_interfaces": ["10.0.0.1/24", "192.168.5.32/16"], "network_interfaces": ["10.0.0.1/24", "192.168.5.32/16"],
} }
@ -50,7 +52,7 @@ def test_from_serialized():
("machine_hardware_id", "not-an-int"), ("machine_hardware_id", "not-an-int"),
("start_time", None), ("start_time", None),
("parent_id", 2.1), ("parent_id", 2.1),
("cc_server", []), ("cc_server", [1]),
("network_interfaces", "not-a-list"), ("network_interfaces", "not-a-list"),
], ],
) )
@ -68,6 +70,7 @@ def test_construct_invalid_field__type_error(key, value):
("machine_hardware_id", -1), ("machine_hardware_id", -1),
("start_time", "not-a-date-time"), ("start_time", "not-a-date-time"),
("network_interfaces", [1, "stuff", 3]), ("network_interfaces", [1, "stuff", 3]),
("cc_server", []),
], ],
) )
def test_construct_invalid_field__value_error(key, value): def test_construct_invalid_field__value_error(key, value):
@ -85,7 +88,7 @@ def test_construct_invalid_field__value_error(key, value):
("machine_hardware_id", 99), ("machine_hardware_id", 99),
("start_time", 0), ("start_time", 0),
("parent_id", AGENT_ID), ("parent_id", AGENT_ID),
("cc_server", "10.0.0.1:4999"), ("cc_server", SOCKET_ADDRESS),
("network_interfaces", ["10.0.0.1/24"]), ("network_interfaces", ["10.0.0.1/24"]),
], ],
) )

View File

@ -6,6 +6,7 @@ from uuid import uuid4
import mongomock import mongomock
import pytest import pytest
from common.types import SocketAddress
from monkey_island.cc.models import Agent from monkey_island.cc.models import Agent
from monkey_island.cc.repository import ( from monkey_island.cc.repository import (
IAgentRepository, IAgentRepository,
@ -54,6 +55,7 @@ AGENTS = (
*RUNNING_AGENTS, *RUNNING_AGENTS,
*STOPPED_AGENTS, *STOPPED_AGENTS,
) )
CC_SERVER = SocketAddress(ip="127.0.0.1", port="1984")
@pytest.fixture @pytest.fixture
@ -116,7 +118,7 @@ def test_upsert_agent__insert_empty_repository(empty_agent_repository):
def test_upsert_agent__update(agent_repository): def test_upsert_agent__update(agent_repository):
agents = deepcopy(AGENTS) agents = deepcopy(AGENTS)
agents[0].stop_time = datetime.now() agents[0].stop_time = datetime.now()
agents[0].cc_server = "127.0.0.1:1984" agents[0].cc_server = CC_SERVER
agent_repository.upsert_agent(agents[0]) agent_repository.upsert_agent(agents[0])

View File

@ -16,7 +16,7 @@ AGENT_REGISTRATION_DICT = {
"machine_hardware_id": 1, "machine_hardware_id": 1,
"start_time": 0, "start_time": 0,
"parent_id": UUID("9d55ba33-95c2-417d-bd86-d3d11e47daeb"), "parent_id": UUID("9d55ba33-95c2-417d-bd86-d3d11e47daeb"),
"cc_server": "10.0.0.1:5000", "cc_server": {"ip": "10.0.0.1", "port": "5000"},
"network_interfaces": ["10.1.1.2/24"], "network_interfaces": ["10.1.1.2/24"],
} }