Island: Use Mapping for tracking node connections

This commit is contained in:
Mike Salvatore 2022-09-07 00:38:10 -04:00
parent e7aca8326e
commit 094a0b1a8d
3 changed files with 29 additions and 44 deletions

View File

@ -11,5 +11,5 @@ from .simulation import Simulation, SimulationSchema, IslandMode
from .user_credentials import UserCredentials from .user_credentials import UserCredentials
from .machine import Machine, MachineID from .machine import Machine, MachineID
from .communication_type import CommunicationType from .communication_type import CommunicationType
from .node import Node, ConnectionTarget from .node import Node
from .agent import Agent, AgentID from .agent import Agent, AgentID

View File

@ -1,19 +1,14 @@
from typing import Collection, Sequence, Tuple, TypeAlias from typing import Mapping, Tuple, TypeAlias
from pydantic import Field, validator from pydantic import Field, validator
from common.base_models import MutableInfectionMonkeyBaseModel from common.base_models import MutableInfectionMonkeyBaseModel
from common.transforms import make_immutable_nested_sequence
from . import CommunicationType, MachineID from . import CommunicationType, MachineID
ConnectionTarget: TypeAlias = Tuple[MachineID, Sequence[CommunicationType]] NodeConnections: TypeAlias = Mapping[MachineID, Tuple[CommunicationType, ...]]
class Node(MutableInfectionMonkeyBaseModel): class Node(MutableInfectionMonkeyBaseModel):
machine_id: MachineID = Field(..., allow_mutation=False) machine_id: MachineID = Field(..., allow_mutation=False)
connections: Collection[ConnectionTarget] connections: NodeConnections
_make_immutable_nested_sequence = validator("connections", pre=True, allow_reuse=True)(
make_immutable_nested_sequence
)

View File

@ -1,16 +1,16 @@
from typing import MutableSequence from typing import MutableMapping, MutableSequence
import pytest import pytest
from monkey_island.cc.models import CommunicationType, Node from monkey_island.cc.models import CommunicationType, MachineID, Node
def test_constructor(): def test_constructor():
machine_id = 1 machine_id = 1
connections = ( connections = {
(6, (CommunicationType.SCANNED,)), 6: (CommunicationType.SCANNED,),
(7, (CommunicationType.SCANNED, CommunicationType.EXPLOITED)), 7: (CommunicationType.SCANNED, CommunicationType.EXPLOITED),
) }
n = Node( n = Node(
machine_id=1, machine_id=1,
connections=connections, connections=connections,
@ -23,13 +23,10 @@ def test_constructor():
def test_serialization(): def test_serialization():
node_dict = { node_dict = {
"machine_id": 1, "machine_id": 1,
"connections": [ "connections": {
[ "6": ["cc", "scanned"],
6, "7": ["exploited", "cc"],
["cc", "scanned"], },
],
[7, ["exploited", "cc_tunnel"]],
],
} }
n = Node(**node_dict) n = Node(**node_dict)
@ -37,7 +34,7 @@ def test_serialization():
def test_machine_id_immutable(): def test_machine_id_immutable():
n = Node(machine_id=1, connections=[]) n = Node(machine_id=1, connections={})
with pytest.raises(TypeError): with pytest.raises(TypeError):
n.machine_id = 2 n.machine_id = 2
@ -45,52 +42,45 @@ def test_machine_id_immutable():
def test_machine_id__invalid_type(): def test_machine_id__invalid_type():
with pytest.raises(TypeError): with pytest.raises(TypeError):
Node(machine_id=None, connections=[]) Node(machine_id=None, connections={})
def test_machine_id__invalid_value(): def test_machine_id__invalid_value():
with pytest.raises(ValueError): with pytest.raises(ValueError):
Node(machine_id=-5, connections=[]) Node(machine_id=-5, connections={})
def test_connections__mutable(): def test_connections__mutable():
n = Node(machine_id=1, connections=[]) n = Node(machine_id=1, connections={})
# Raises exception on failure # Raises exception on failure
n.connections = [(5, []), (7, [])] n.connections = {5: [], 7: []}
def test_connections__invalid_machine_id(): def test_connections__invalid_machine_id():
n = Node(machine_id=1, connections=[]) n = Node(machine_id=1, connections={})
with pytest.raises(ValueError): with pytest.raises(ValueError):
n.connections = [(5, []), (-5, [])] n.connections = {5: [], -5: []}
def test_connections__recursively_immutable(): def test_connections__recursively_immutable():
n = Node( n = Node(
machine_id=1, machine_id=1,
connections=[ connections={
[6, [CommunicationType.SCANNED]], 6: [CommunicationType.SCANNED],
[7, [CommunicationType.SCANNED, CommunicationType.EXPLOITED]], 7: [CommunicationType.SCANNED, CommunicationType.EXPLOITED],
], },
) )
assert not isinstance(n.connections, MutableSequence) for connections in n.connections.values():
assert not isinstance(n.connections[0], MutableSequence) assert not isinstance(connections, MutableSequence)
assert not isinstance(n.connections[1], MutableSequence)
assert not isinstance(n.connections[0][1], MutableSequence)
assert not isinstance(n.connections[1][1], MutableSequence)
def test_connections__set_invalid_communications_type(): def test_connections__set_invalid_communications_type():
connections = ( connections = {8: [CommunicationType.SCANNED, "invalid_comm_type"]}
[
[8, [CommunicationType.SCANNED, "invalid_comm_type"]],
],
)
n = Node(machine_id=1, connections=[]) n = Node(machine_id=1, connections={})
with pytest.raises(ValueError): with pytest.raises(ValueError):
n.connections = connections n.connections = connections