From 76b51d25b9631602600bc30375c5100ce9ffd99b Mon Sep 17 00:00:00 2001 From: Mike Salvatore Date: Tue, 13 Sep 2022 20:15:04 -0400 Subject: [PATCH] Island: Use frozenset instead of tuple for communication types --- monkey/monkey_island/cc/models/node.py | 4 ++-- .../monkey_island/cc/models/test_node.py | 24 +++++++++++++++---- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/monkey/monkey_island/cc/models/node.py b/monkey/monkey_island/cc/models/node.py index 4b3ab5608..715e52bb3 100644 --- a/monkey/monkey_island/cc/models/node.py +++ b/monkey/monkey_island/cc/models/node.py @@ -1,4 +1,4 @@ -from typing import Mapping, Tuple +from typing import FrozenSet, Mapping from pydantic import Field from typing_extensions import TypeAlias @@ -7,7 +7,7 @@ from common.base_models import MutableInfectionMonkeyBaseModel from . import CommunicationType, MachineID -NodeConnections: TypeAlias = Mapping[MachineID, Tuple[CommunicationType, ...]] +NodeConnections: TypeAlias = Mapping[MachineID, FrozenSet[CommunicationType]] class Node(MutableInfectionMonkeyBaseModel): diff --git a/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py b/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py index e70df6975..603dee2e4 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py +++ b/monkey/tests/unit_tests/monkey_island/cc/models/test_node.py @@ -8,8 +8,8 @@ from monkey_island.cc.models import CommunicationType, Node def test_constructor(): machine_id = 1 connections = { - 6: (CommunicationType.SCANNED,), - 7: (CommunicationType.SCANNED, CommunicationType.EXPLOITED), + 6: frozenset((CommunicationType.SCANNED,)), + 7: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)), } n = Node( machine_id=1, @@ -24,13 +24,27 @@ def test_serialization(): node_dict = { "machine_id": 1, "connections": { - "6": ["cc", "scanned"], - "7": ["exploited", "cc"], + "6": [CommunicationType.CC.value, CommunicationType.SCANNED.value], + "7": [CommunicationType.EXPLOITED.value, CommunicationType.CC.value], }, } + # "6": frozenset((CommunicationType.CC, CommunicationType.SCANNED)), + # "7": frozenset((CommunicationType.EXPLOITED, CommunicationType.CC)), n = Node(**node_dict) - assert n.dict(simplify=True) == node_dict + serialized_node = n.dict(simplify=True) + + # NOTE: Comparing these nodes is difficult because sets are not ordered + assert len(serialized_node) == len(node_dict) + for key in serialized_node.keys(): + assert key in node_dict + + assert len(serialized_node["connections"]) == len(node_dict["connections"]) + + for key, value in serialized_node["connections"].items(): + assert len(value) == len(node_dict["connections"][key]) + for comm_type in value: + assert comm_type in node_dict["connections"][key] def test_machine_id_immutable():