diff --git a/monkey/monkey_island/cc/models/node.py b/monkey/monkey_island/cc/models/node.py index d992d1836..ada8aac19 100644 --- a/monkey/monkey_island/cc/models/node.py +++ b/monkey/monkey_island/cc/models/node.py @@ -1,4 +1,4 @@ -from typing import FrozenSet, Mapping +from typing import FrozenSet, Mapping, Tuple from pydantic import Field from typing_extensions import TypeAlias @@ -26,5 +26,5 @@ class Node(MutableInfectionMonkeyBaseModel): connections: NodeConnections """All outbound connections from this node to other machines""" - tcp_connections: Mapping[MachineID, FrozenSet[SocketAddress]] = {} + tcp_connections: Mapping[MachineID, Tuple[SocketAddress, ...]] = {} """All successfull outbound TCP connections""" diff --git a/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py b/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py index 74a83860c..e50c493ef 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py +++ b/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py @@ -2,6 +2,7 @@ from typing import MutableSequence import pytest +from common.types import SocketAddress from monkey_island.cc.models import CommunicationType, Node @@ -11,13 +12,21 @@ def test_constructor(): 6: frozenset((CommunicationType.SCANNED,)), 7: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)), } + tcp_connections = { + 6: tuple( + (SocketAddress(ip="192.168.1.1", port=80), SocketAddress(ip="192.168.1.1", port=443)) + ), + 7: tuple((SocketAddress(ip="192.168.1.2", port=22),)), + } n = Node( - machine_id=1, + machine_id=machine_id, connections=connections, + tcp_connections=tcp_connections, ) assert n.machine_id == machine_id assert n.connections == connections + assert n.tcp_connections == tcp_connections def test_serialization(): @@ -27,9 +36,12 @@ def test_serialization(): "6": [CommunicationType.CC.value, CommunicationType.SCANNED.value], "7": [CommunicationType.EXPLOITED.value, CommunicationType.CC.value], }, + "tcp_connections": { + "6": [{"ip": "192.168.1.1", "port": 80}, {"ip": "192.168.1.1", "port": 443}], + "7": [{"ip": "192.168.1.2", "port": 22}], + }, } - # "6": frozenset((CommunicationType.CC, CommunicationType.SCANNED)), - # "7": frozenset((CommunicationType.EXPLOITED, CommunicationType.CC)), + n = Node(**node_dict) serialized_node = n.dict(simplify=True) @@ -44,6 +56,8 @@ def test_serialization(): for key, value in serialized_node["connections"].items(): assert set(value) == set(node_dict["connections"][key]) + assert serialized_node["tcp_connections"] == node_dict["tcp_connections"] + def test_machine_id_immutable(): n = Node(machine_id=1, connections={})