diff --git a/monkey/monkey_island/cc/models/__init__.py b/monkey/monkey_island/cc/models/__init__.py index c8e1a0fad..31f2700d0 100644 --- a/monkey/monkey_island/cc/models/__init__.py +++ b/monkey/monkey_island/cc/models/__init__.py @@ -11,3 +11,4 @@ from .simulation import Simulation, SimulationSchema, IslandMode from .user_credentials import UserCredentials from .machine import Machine, MachineID from .communication_type import CommunicationType +from .node import Node diff --git a/monkey/monkey_island/cc/models/node.py b/monkey/monkey_island/cc/models/node.py new file mode 100644 index 000000000..6fa307734 --- /dev/null +++ b/monkey/monkey_island/cc/models/node.py @@ -0,0 +1,18 @@ +from typing import Sequence, Tuple + +from pydantic import Field, validator + +from . import CommunicationType, MachineID +from .base_models import MutableBaseModel +from .transforms import make_immutable_nested_sequence + +ConnectionsSequence = Sequence[Tuple[MachineID, Sequence[CommunicationType]]] + + +class Node(MutableBaseModel): + machine_id: MachineID = Field(..., allow_mutation=False) + connections: ConnectionsSequence + + _make_immutable_nested_sequence = validator("connections", pre=True, allow_reuse=True)( + make_immutable_nested_sequence + ) diff --git a/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py b/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py new file mode 100644 index 000000000..e980ad5d6 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py @@ -0,0 +1,96 @@ +from typing import MutableSequence + +import pytest + +from monkey_island.cc.models import CommunicationType, Node + + +def test_constructor(): + machine_id = 1 + connections = ( + (6, (CommunicationType.SCANNED,)), + (7, (CommunicationType.SCANNED, CommunicationType.EXPLOITED)), + ) + n = Node( + machine_id=1, + connections=connections, + ) + + assert n.machine_id == machine_id + assert n.connections == connections + + +def test_serialization(): + node_dict = { + "machine_id": 1, + "connections": [ + [ + 6, + ["cc", "scanned"], + ], + [7, ["exploited", "cc_tunnel"]], + ], + } + n = Node(**node_dict) + + assert n.dict(simplify=True) == node_dict + + +def test_machine_id_immutable(): + n = Node(machine_id=1, connections=[]) + + with pytest.raises(TypeError): + n.machine_id = 2 + + +def test_machine_id__invalid_type(): + with pytest.raises(TypeError): + Node(machine_id=None, connections=[]) + + +def test_machine_id__invalid_value(): + with pytest.raises(ValueError): + Node(machine_id=-5, connections=[]) + + +def test_connections__mutable(): + n = Node(machine_id=1, connections=[]) + + # Raises exception on failure + n.connections = [(5, []), (7, [])] + + +def test_connections__invalid_machine_id(): + n = Node(machine_id=1, connections=[]) + + with pytest.raises(ValueError): + n.connections = [(5, []), (-5, [])] + + +def test_connections__recursively_immutable(): + n = Node( + machine_id=1, + connections=[ + [6, [CommunicationType.SCANNED]], + [7, [CommunicationType.SCANNED, CommunicationType.EXPLOITED]], + ], + ) + + assert not isinstance(n.connections, MutableSequence) + assert not isinstance(n.connections[0], MutableSequence) + assert not isinstance(n.connections[1], MutableSequence) + assert not isinstance(n.connections[0][1], MutableSequence) + assert not isinstance(n.connections[1][1], MutableSequence) + + +def test_connections__set_invalid_communications_type(): + connections = ( + [ + [8, [CommunicationType.SCANNED, "invalid_comm_type"]], + ], + ) + + n = Node(machine_id=1, connections=[]) + + with pytest.raises(ValueError): + n.connections = connections