Island: Add transform functions to make immutable copies of sequences

This commit is contained in:
Mike Salvatore 2022-08-18 10:33:45 -04:00
parent 3fd7051869
commit b3bfc598a3
2 changed files with 89 additions and 0 deletions

View File

@ -0,0 +1,36 @@
from typing import Any, MutableSequence, Sequence, Union
def make_immutable_nested_sequence(sequence_or_element: Union[Sequence, Any]) -> Sequence:
"""
Take a Sequence of Sequences (or other types) and return an immutable copy
Takes a Sequence of Sequences, for example `List[List[int, float]]]` and returns an immutable
copy. Note that if the Sequence does not contain other sequences, `make_sequence_immutable()`
will be more performant.
:param sequence_or_element: A nested sequence or an element from within a nested sequence
:return: An immutable copy of the sequence if `sequence_or_element` is a Sequence, otherwise
just return `sequence_or_element`
"""
if isinstance(sequence_or_element, str):
return sequence_or_element
if isinstance(sequence_or_element, Sequence):
return tuple(map(make_immutable_nested_sequence, sequence_or_element))
return sequence_or_element
def make_immutable_sequence(sequence: Sequence):
"""
Take a Sequence and return an immutable copy
:param sequence: A Sequence to create an immutable copy from
:return: An immutable copy of `sequence`
"""
if isinstance(sequence, MutableSequence):
return tuple(sequence)
return sequence

View File

@ -0,0 +1,53 @@
from itertools import zip_longest
import pytest
from typing import MutableSequence, Sequence
from monkey_island.cc.models.transforms import (
make_immutable_nested_sequence,
make_immutable_sequence,
)
def test_make_immutable_sequence__list():
mutable_sequence = [1, 2, 3]
immutable_sequence = make_immutable_sequence(mutable_sequence)
assert isinstance(immutable_sequence, Sequence)
assert not isinstance(immutable_sequence, MutableSequence)
assert_sequences_equal(mutable_sequence, immutable_sequence)
@pytest.mark.parametrize(
"mutable_sequence", [
[1, 2, 3],
[[1, 2, 3], [4, 5, 6]],
[[1, 2, 3, [4, 5, 6]], [4, 5, 6]],
[8, [5.3, "invalid_comm_type"]]]
)
def test_make_immutable_nested_sequence(mutable_sequence):
immutable_sequence = make_immutable_nested_sequence(mutable_sequence)
assert isinstance(immutable_sequence, Sequence)
assert not isinstance(immutable_sequence, MutableSequence)
assert_sequences_equal(mutable_sequence, immutable_sequence)
def assert_sequence_immutable_recursive(sequence: Sequence):
assert not isinstance(sequence, MutableSequence)
for s in sequence:
if isinstance(s, str):
continue
if isinstance(s, Sequence):
assert_sequence_immutable_recursive(s)
assert not isinstance(s, MutableSequence)
def assert_sequences_equal(a: Sequence, b: Sequence):
assert len(a) == len(b)
for i, j in zip_longest(a, b):
if isinstance(i, str) or not isinstance(i, Sequence):
assert i == j
else:
assert_sequences_equal(i, j)