diff --git a/monkey/monkey_island/cc/models/transforms.py b/monkey/monkey_island/cc/models/transforms.py new file mode 100644 index 000000000..c8437c038 --- /dev/null +++ b/monkey/monkey_island/cc/models/transforms.py @@ -0,0 +1,36 @@ +from typing import Any, MutableSequence, Sequence, Union + + +def make_immutable_nested_sequence(sequence_or_element: Union[Sequence, Any]) -> Sequence: + """ + Take a Sequence of Sequences (or other types) and return an immutable copy + + Takes a Sequence of Sequences, for example `List[List[int, float]]]` and returns an immutable + copy. Note that if the Sequence does not contain other sequences, `make_sequence_immutable()` + will be more performant. + + :param sequence_or_element: A nested sequence or an element from within a nested sequence + :return: An immutable copy of the sequence if `sequence_or_element` is a Sequence, otherwise + just return `sequence_or_element` + """ + if isinstance(sequence_or_element, str): + return sequence_or_element + + if isinstance(sequence_or_element, Sequence): + return tuple(map(make_immutable_nested_sequence, sequence_or_element)) + + return sequence_or_element + + +def make_immutable_sequence(sequence: Sequence): + """ + Take a Sequence and return an immutable copy + + :param sequence: A Sequence to create an immutable copy from + :return: An immutable copy of `sequence` + """ + + if isinstance(sequence, MutableSequence): + return tuple(sequence) + + return sequence diff --git a/monkey/tests/unit_tests/monkey_island/cc/models/test_transforms.py b/monkey/tests/unit_tests/monkey_island/cc/models/test_transforms.py new file mode 100644 index 000000000..9d8f125e2 --- /dev/null +++ b/monkey/tests/unit_tests/monkey_island/cc/models/test_transforms.py @@ -0,0 +1,53 @@ +from itertools import zip_longest +import pytest +from typing import MutableSequence, Sequence + +from monkey_island.cc.models.transforms import ( + make_immutable_nested_sequence, + make_immutable_sequence, +) + + +def test_make_immutable_sequence__list(): + mutable_sequence = [1, 2, 3] + immutable_sequence = make_immutable_sequence(mutable_sequence) + + assert isinstance(immutable_sequence, Sequence) + assert not isinstance(immutable_sequence, MutableSequence) + assert_sequences_equal(mutable_sequence, immutable_sequence) + + +@pytest.mark.parametrize( + "mutable_sequence", [ + [1, 2, 3], + [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3, [4, 5, 6]], [4, 5, 6]], + [8, [5.3, "invalid_comm_type"]]] +) +def test_make_immutable_nested_sequence(mutable_sequence): + immutable_sequence = make_immutable_nested_sequence(mutable_sequence) + + assert isinstance(immutable_sequence, Sequence) + assert not isinstance(immutable_sequence, MutableSequence) + assert_sequences_equal(mutable_sequence, immutable_sequence) + + +def assert_sequence_immutable_recursive(sequence: Sequence): + assert not isinstance(sequence, MutableSequence) + + for s in sequence: + if isinstance(s, str): + continue + + if isinstance(s, Sequence): + assert_sequence_immutable_recursive(s) + assert not isinstance(s, MutableSequence) + + +def assert_sequences_equal(a: Sequence, b: Sequence): + assert len(a) == len(b) + for i, j in zip_longest(a, b): + if isinstance(i, str) or not isinstance(i, Sequence): + assert i == j + else: + assert_sequences_equal(i, j)