Island: Use frozenset instead of tuple for communication types

This commit is contained in:
Mike Salvatore 2022-09-13 20:15:04 -04:00
parent 6cc8948ebf
commit 76b51d25b9
2 changed files with 21 additions and 7 deletions

View File

@ -1,4 +1,4 @@
from typing import Mapping, Tuple from typing import FrozenSet, Mapping
from pydantic import Field from pydantic import Field
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
@ -7,7 +7,7 @@ from common.base_models import MutableInfectionMonkeyBaseModel
from . import CommunicationType, MachineID from . import CommunicationType, MachineID
NodeConnections: TypeAlias = Mapping[MachineID, Tuple[CommunicationType, ...]] NodeConnections: TypeAlias = Mapping[MachineID, FrozenSet[CommunicationType]]
class Node(MutableInfectionMonkeyBaseModel): class Node(MutableInfectionMonkeyBaseModel):

View File

@ -8,8 +8,8 @@ from monkey_island.cc.models import CommunicationType, Node
def test_constructor(): def test_constructor():
machine_id = 1 machine_id = 1
connections = { connections = {
6: (CommunicationType.SCANNED,), 6: frozenset((CommunicationType.SCANNED,)),
7: (CommunicationType.SCANNED, CommunicationType.EXPLOITED), 7: frozenset((CommunicationType.SCANNED, CommunicationType.EXPLOITED)),
} }
n = Node( n = Node(
machine_id=1, machine_id=1,
@ -24,13 +24,27 @@ def test_serialization():
node_dict = { node_dict = {
"machine_id": 1, "machine_id": 1,
"connections": { "connections": {
"6": ["cc", "scanned"], "6": [CommunicationType.CC.value, CommunicationType.SCANNED.value],
"7": ["exploited", "cc"], "7": [CommunicationType.EXPLOITED.value, CommunicationType.CC.value],
}, },
} }
# "6": frozenset((CommunicationType.CC, CommunicationType.SCANNED)),
# "7": frozenset((CommunicationType.EXPLOITED, CommunicationType.CC)),
n = Node(**node_dict) n = Node(**node_dict)
assert n.dict(simplify=True) == node_dict serialized_node = n.dict(simplify=True)
# NOTE: Comparing these nodes is difficult because sets are not ordered
assert len(serialized_node) == len(node_dict)
for key in serialized_node.keys():
assert key in node_dict
assert len(serialized_node["connections"]) == len(node_dict["connections"])
for key, value in serialized_node["connections"].items():
assert len(value) == len(node_dict["connections"][key])
for comm_type in value:
assert comm_type in node_dict["connections"][key]
def test_machine_id_immutable(): def test_machine_id_immutable():