diff --git a/changelogs/fragments/orderedset.yml b/changelogs/fragments/orderedset.yml new file mode 100644 index 00000000000..52dd1941d9a --- /dev/null +++ b/changelogs/fragments/orderedset.yml @@ -0,0 +1,2 @@ +minor_changes: +- Add new OrderedSet class for situations a unique ordered list is needed diff --git a/lib/ansible/module_utils/common/collections.py b/lib/ansible/module_utils/common/collections.py index 4fdc874269b..8ae08ffd910 100644 --- a/lib/ansible/module_utils/common/collections.py +++ b/lib/ansible/module_utils/common/collections.py @@ -5,31 +5,41 @@ from __future__ import annotations - -from collections.abc import Hashable, Mapping, MutableMapping, Sequence # pylint: disable=unused-import +import collections.abc as _c +import typing as _t from ansible.module_utils._internal import _no_six from ansible.module_utils.common import warnings as _warnings +_KT = _t.TypeVar('_KT', bound=_c.Hashable) +_VT = _t.TypeVar('_VT') +_T = _t.TypeVar('_T', bound=_c.Hashable) -class ImmutableDict(Hashable, Mapping): + +class ImmutableDict(_c.Hashable, _c.Mapping[_KT, _VT], _t.Generic[_KT, _VT]): """Dictionary that cannot be updated""" - def __init__(self, *args, **kwargs): - self._store = dict(*args, **kwargs) + def __init__( + self, + *args: _c.Mapping[_KT, _VT] | _c.Iterable[tuple[_KT, _VT]], + **kwargs: _VT, + ) -> None: + self._store: dict[_KT, _VT] = dict(*args) if args else {} + if kwargs: + self._store.update(_t.cast(_c.Mapping[_KT, _VT], kwargs)) - def __getitem__(self, key): + def __getitem__(self, key: _KT) -> _VT: return self._store[key] - def __iter__(self): + def __iter__(self) -> _c.Iterator[_KT]: return self._store.__iter__() - def __len__(self): + def __len__(self) -> int: return self._store.__len__() - def __hash__(self): + def __hash__(self) -> int: return hash(frozenset(self.items())) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: try: if self.__hash__() == hash(other): return True @@ -38,10 +48,10 @@ class ImmutableDict(Hashable, Mapping): return False - def __repr__(self): + def __repr__(self) -> str: return 'ImmutableDict({0})'.format(repr(self._store)) - def union(self, overriding_mapping): + def union(self, overriding_mapping: _c.Mapping[_KT, _VT]) -> ImmutableDict[_KT, _VT]: """ Create an ImmutableDict as a combination of the original and overriding_mapping @@ -51,9 +61,11 @@ class ImmutableDict(Hashable, Mapping): If any of the keys in overriding_mapping are already present in the original ImmutableDict, the overriding_mapping item replaces the one in the original ImmutableDict. """ - return ImmutableDict(self._store, **overriding_mapping) + result = dict(self._store) + result.update(overriding_mapping) + return ImmutableDict(result) - def difference(self, subtractive_iterable): + def difference(self, subtractive_iterable: _c.Iterable) -> ImmutableDict[_KT, _VT]: """ Create an ImmutableDict as a combination of the original minus keys in subtractive_iterable @@ -66,12 +78,94 @@ class ImmutableDict(Hashable, Mapping): return ImmutableDict((k, self._store[k]) for k in keys) -def is_string(seq): +class OrderedSet(_c.MutableSet[_T], _t.Generic[_T]): + def __init__( + self, + iterable: _c.Iterable[_T] | None = None, + / + ) -> None: + + self._data: dict[_T, None] = dict.fromkeys(iterable or ()) + + def __repr__(self, /) -> str: + return f'OrderedSet({list(self)!r})' + + def __eq__(self, other: object, /) -> bool: + if not isinstance(other, OrderedSet): + return NotImplemented + return tuple(self._data) == tuple(other._data) + + def __contains__(self, x: object, /) -> bool: + return x in self._data + + def __iter__(self, /) -> _c.Iterator[_T]: + return self._data.__iter__() + + def __len__(self, /) -> int: + return self._data.__len__() + + def add(self, value: _T) -> None: + self._data[value] = None + + def discard(self, value: _T) -> None: + self._data.pop(value, None) + + def clear(self) -> None: + self._data.clear() + + def copy(self) -> OrderedSet[_T]: + result: OrderedSet[_T] = OrderedSet() + result._data = self._data.copy() + return result + + def __and__(self, other: _c.Container, /) -> OrderedSet[_T]: + # overridden, because the ABC produces an arguably unexpected sorting + return OrderedSet(value for value in self if value in other) + + def __sub__(self, other: _c.Container, /) -> OrderedSet[_T]: + return OrderedSet(value for value in self if value not in other) + + def __or__(self, other: _c.Set[_T], /) -> OrderedSet[_T]: # type: ignore[override] + result = self.copy() + for value in other: + result._data[value] = None + return result + + def __xor__(self, other: _c.Set[_T], /) -> OrderedSet[_T]: # type: ignore[override] + result = self.copy() + for value in other: + if value in result._data: + del result._data[value] + else: + result._data[value] = None + return result + + def __rsub__(self, other: _c.Iterable[_T], /) -> OrderedSet[_T]: + return OrderedSet(other).__sub__(self) + + def __rxor__(self, other: _c.Iterable[_T], /) -> OrderedSet[_T]: + return OrderedSet(other).__xor__(self) + + difference = __sub__ + difference_update = _c.MutableSet.__isub__ + intersection = __and__ + __rand__ = __and__ + __ror__ = __or__ + intersection_update = _c.MutableSet.__iand__ + issubset = _c.MutableSet.__le__ + issuperset = _c.MutableSet.__ge__ + symmetric_difference = __xor__ + symmetric_difference_update = _c.MutableSet.__ixor__ + union = __or__ + update = _c.MutableSet.__ior__ + + +def is_string(seq: _c.Iterable) -> bool: """Identify whether the input has a string-like type (including bytes).""" return isinstance(seq, (str, bytes)) -def is_iterable(seq, include_strings=False): +def is_iterable(seq: _c.Iterable, include_strings: bool = False) -> bool: """Identify whether the input is an iterable.""" if not include_strings and is_string(seq): return False @@ -83,7 +177,7 @@ def is_iterable(seq, include_strings=False): return False -def is_sequence(seq, include_strings=False): +def is_sequence(seq: _c.Iterable, include_strings: bool = False) -> bool: """Identify whether the input is a sequence. Strings and bytes are not sequences here, @@ -94,10 +188,10 @@ def is_sequence(seq, include_strings=False): if not include_strings and is_string(seq): return False - return isinstance(seq, Sequence) + return isinstance(seq, _c.Sequence) -def count(seq): +def count(seq: _c.Iterable) -> dict[_c.Hashable, int]: """Returns a dictionary with the number of appearances of each element of the iterable. Resembles the collections.Counter class functionality. It is meant to be used when the @@ -111,11 +205,17 @@ def count(seq): ) if not is_iterable(seq): raise Exception('Argument provided is not an iterable') - counters = dict() + counters: dict[_c.Hashable, int] = {} for elem in seq: counters[elem] = counters.get(elem, 0) + 1 return counters +Hashable = _c.Hashable +Mapping = _c.Mapping +MutableMapping = _c.MutableMapping +Sequence = _c.Sequence + + def __getattr__(importable_name): return _no_six.deprecate(importable_name, __name__, "binary_type", "text_type") diff --git a/lib/ansible/module_utils/urls.py b/lib/ansible/module_utils/urls.py index fc7f8f27a27..bc9cebd9ad9 100644 --- a/lib/ansible/module_utils/urls.py +++ b/lib/ansible/module_utils/urls.py @@ -63,7 +63,7 @@ else: GzipFile = gzip.GzipFile # type: ignore[assignment,misc] from ansible.module_utils.basic import missing_required_lib -from ansible.module_utils.common.collections import Mapping, is_sequence +from ansible.module_utils.common.collections import Mapping, OrderedSet, is_sequence from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.compat import typing as _t @@ -513,9 +513,8 @@ def get_ca_certs(cafile=None, capath=None): # tries to find a valid CA cert in one of the # standard locations for the current distribution - # Using a dict, instead of a set for order, the value is meaningless and will be None # Not directly using a bytearray to avoid duplicates with fast lookup - cadata = {} + cadata = OrderedSet() # If cafile is passed, we are only using that for verification, # don't add additional ca certs @@ -524,7 +523,7 @@ def get_ca_certs(cafile=None, capath=None): with open(to_bytes(cafile, errors='surrogate_or_strict'), 'r', errors='surrogateescape') as f: for pem in extract_pem_certs(f.read()): b_der = ssl.PEM_cert_to_DER_cert(pem) - cadata[b_der] = None + cadata.add(b_der) return bytearray().join(cadata), paths_checked default_verify_paths = ssl.get_default_verify_paths() @@ -575,7 +574,7 @@ def get_ca_certs(cafile=None, capath=None): try: for pem in extract_pem_certs(cert): b_der = ssl.PEM_cert_to_DER_cert(pem) - cadata[b_der] = None + cadata.add(b_der) except Exception: continue except OSError: diff --git a/test/units/module_utils/common/test_collections.py b/test/units/module_utils/common/test_collections.py index 2e31b6b4787..efbc8fe5044 100644 --- a/test/units/module_utils/common/test_collections.py +++ b/test/units/module_utils/common/test_collections.py @@ -8,7 +8,7 @@ from __future__ import annotations import pytest from collections.abc import Sequence -from ansible.module_utils.common.collections import ImmutableDict, is_iterable, is_sequence +from ansible.module_utils.common.collections import ImmutableDict, OrderedSet, is_iterable, is_sequence class SeqStub: @@ -146,3 +146,323 @@ class TestImmutableDict: actual_repr = repr(imdict) expected_repr = "ImmutableDict({0})".format(initial_data_repr) assert actual_repr == expected_repr + + +class TestOrderedSet: + def test_init_empty(self): + o = OrderedSet() + assert len(o) == 0 + assert list(o) == [] + + def test_init_with_iterable(self): + o = OrderedSet(['foo', 'bar', 'baz']) + assert list(o) == ['foo', 'bar', 'baz'] + + def test_init_deduplication(self): + o = OrderedSet([1, 2, 1, 3, 2, 4]) + assert list(o) == [1, 2, 3, 4] + + def test_repr(self): + o = OrderedSet([1, 2, 3]) + assert repr(o) == "OrderedSet([1, 2, 3])" + + def test_repr_empty(self): + o = OrderedSet() + assert repr(o) == "OrderedSet([])" + + def test_len(self): + o = OrderedSet([1, 2, 3]) + assert len(o) == 3 + + def test_len_empty(self): + o = OrderedSet() + assert len(o) == 0 + + @pytest.mark.parametrize('value,expected', [ + ('foo', True), + ('missing', False), + (1, False), + ]) + def test_contains(self, value, expected): + o = OrderedSet(['foo', 'bar', 'baz']) + assert (value in o) == expected + + def test_iter_preserves_order(self): + expected = ['foo', 'bar', 'baz'] + o = OrderedSet(expected) + assert list(o) == expected + + def test_add(self): + o = OrderedSet() + o.add('foo') + assert 'foo' in o + assert list(o) == ['foo'] + + def test_add_duplicate(self): + o = OrderedSet(['foo', 'bar']) + o.add('foo') + assert list(o) == ['foo', 'bar'] + + def test_discard_existing(self): + o = OrderedSet(['foo', 'bar', 'baz']) + o.discard('bar') + assert list(o) == ['foo', 'baz'] + + def test_discard_missing(self): + o = OrderedSet(['foo', 'bar']) + o.discard('missing') + assert list(o) == ['foo', 'bar'] + + def test_clear(self): + o = OrderedSet(['foo', 'bar', 'baz']) + o.clear() + assert len(o) == 0 + assert list(o) == [] + + def test_copy(self): + o1 = OrderedSet(['foo', 'bar', 'baz']) + o2 = o1.copy() + assert o1 == o2 + assert o1 is not o2 + + def test_copy_independence(self): + o1 = OrderedSet(['foo', 'bar']) + o2 = o1.copy() + o2.add('baz') + assert list(o1) == ['foo', 'bar'] + assert list(o2) == ['foo', 'bar', 'baz'] + + def test_eq_same_order(self): + o1 = OrderedSet([1, 2, 3]) + o2 = OrderedSet([1, 2, 3]) + assert o1 == o2 + + def test_eq_different_order(self): + o1 = OrderedSet([1, 2, 3]) + o2 = OrderedSet([3, 2, 1]) + assert o1 != o2 + + def test_eq_different_elements(self): + o1 = OrderedSet([1, 2, 3]) + o2 = OrderedSet([1, 2, 4]) + assert o1 != o2 + + def test_eq_different_length(self): + o1 = OrderedSet([1, 2, 3]) + o2 = OrderedSet([1, 2]) + assert o1 != o2 + + @pytest.mark.parametrize('other', [ + set([1, 2, 3]), + [1, 2, 3], + {1: 2, 2: 3, 3: 4}, + 'abc', + ]) + def test_eq_with_non_orderedset(self, other): + o = OrderedSet([1, 2, 3]) + assert (o == other) is False + + def test_difference(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham']) + result = o1 - o2 + assert list(result) == ['foo', 'baz'] + + def test_difference_method(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham']) + result = o1.difference(o2) + assert list(result) == ['foo', 'baz'] + + def test_difference_update(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham']) + o1 -= o2 + assert list(o1) == ['foo', 'baz'] + + def test_difference_update_method(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham']) + o1.difference_update(o2) + assert list(o1) == ['foo', 'baz'] + + def test_intersection(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham']) + result = o1 & o2 + assert list(result) == ['bar', 'qux'] + + def test_intersection_method(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham']) + result = o1.intersection(o2) + assert list(result) == ['bar', 'qux'] + + def test_intersection_update(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham']) + o1 &= o2 + assert list(o1) == ['bar', 'qux'] + + def test_intersection_update_method(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham']) + o1.intersection_update(o2) + assert list(o1) == ['bar', 'qux'] + + def test_union(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich']) + result = o1 | o2 + assert list(result) == ['foo', 'bar', 'baz', 'qux', 'ham', 'sandwich'] + + def test_union_method(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich']) + result = o1.union(o2) + assert list(result) == ['foo', 'bar', 'baz', 'qux', 'ham', 'sandwich'] + + def test_update(self): + o1 = OrderedSet(['foo', 'bar']) + o1 |= ['baz', 'qux'] + assert list(o1) == ['foo', 'bar', 'baz', 'qux'] + + def test_update_method(self): + o1 = OrderedSet(['foo', 'bar']) + o1.update(['baz', 'qux']) + assert list(o1) == ['foo', 'bar', 'baz', 'qux'] + + def test_symmetric_difference(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich']) + result = o1 ^ o2 + assert list(result) == ['foo', 'baz', 'ham', 'sandwich'] + + def test_symmetric_difference_method(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich']) + result = o1.symmetric_difference(o2) + assert list(result) == ['foo', 'baz', 'ham', 'sandwich'] + + def test_symmetric_difference_update(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich']) + o1 ^= o2 + assert list(o1) == ['foo', 'baz', 'ham', 'sandwich'] + + def test_symmetric_difference_update_method(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich']) + o1.symmetric_difference_update(o2) + assert list(o1) == ['foo', 'baz', 'ham', 'sandwich'] + + def test_issubset_true(self): + o1 = OrderedSet([1, 2]) + o2 = OrderedSet([1, 2, 3]) + assert o1.issubset(o2) + assert o1 <= o2 + + def test_issubset_different_order(self): + o1 = OrderedSet([2, 1]) + o2 = OrderedSet([1, 2, 3]) + assert o1.issubset(o2) + assert o1 <= o2 + + def test_issubset_false(self): + o1 = OrderedSet([1, 2, 4]) + o2 = OrderedSet([1, 2, 3]) + assert not o1.issubset(o2) + assert not (o1 <= o2) # pylint: disable=unnecessary-negation + + def test_issubset_equal(self): + o1 = OrderedSet([1, 2, 3]) + o2 = OrderedSet([1, 2, 3]) + assert o1.issubset(o2) + assert o1 <= o2 + + def test_issuperset_true(self): + o1 = OrderedSet([1, 2, 3]) + o2 = OrderedSet([1, 2]) + assert o1.issuperset(o2) + assert o1 >= o2 + + def test_issuperset_different_order(self): + o1 = OrderedSet([1, 2, 3]) + o2 = OrderedSet([2, 1]) + assert o1.issuperset(o2) + assert o1 >= o2 + + def test_issuperset_false(self): + o1 = OrderedSet([1, 2, 3]) + o2 = OrderedSet([1, 2, 4]) + assert not o1.issuperset(o2) + assert not (o1 >= o2) # pylint: disable=unnecessary-negation + + def test_issuperset_equal(self): + o1 = OrderedSet([1, 2, 3]) + o2 = OrderedSet([1, 2, 3]) + assert o1.issuperset(o2) + assert o1 >= o2 + + def test_rand_intersection(self): + o = OrderedSet(['bar', 'qux']) + s = {'foo', 'bar', 'baz', 'qux'} + result = s & o + assert isinstance(result, OrderedSet) + assert list(result) == ['bar', 'qux'] + + def test_ror_union(self): + o = OrderedSet(['foo', 'bar', 'baz', 'qux']) + s = {'qux', 'bar', 'ham'} + result = s | o + assert isinstance(result, OrderedSet) + assert list(result) == ['foo', 'bar', 'baz', 'qux', 'ham'] + + def test_rsub_difference(self): + o = OrderedSet(['foo', 'bar', 'baz', 'qux']) + s = {'qux', 'bar', 'ham'} + result = s - o + assert isinstance(result, OrderedSet) + assert list(result) == ['ham'] + + def test_rxor_symmetric_difference(self): + o = OrderedSet(['foo', 'bar', 'baz', 'qux']) + s = {'qux', 'bar', 'ham'} + result = s ^ o + assert isinstance(result, OrderedSet) + assert set(result) == {'foo', 'baz', 'ham'} + + def test_intersection_with_regular_set(self): + o = OrderedSet(['foo', 'bar', 'baz', 'qux']) + s = {'qux', 'bar', 'ham'} + result = o & s + assert list(result) == ['bar', 'qux'] + + def test_difference_with_regular_set(self): + o = OrderedSet(['foo', 'bar', 'baz', 'qux']) + s = {'qux', 'bar', 'ham'} + result = o - s + assert list(result) == ['foo', 'baz'] + + def test_union_with_regular_set(self): + o = OrderedSet(['foo', 'bar', 'baz', 'qux']) + s = {'qux', 'bar', 'ham'} + result = o | s + assert list(result) == ['foo', 'bar', 'baz', 'qux', 'ham'] + + def test_symmetric_difference_with_regular_set(self): + o = OrderedSet(['foo', 'bar', 'baz', 'qux']) + s = {'qux', 'bar', 'ham'} + result = o ^ s + assert set(result) == {'foo', 'baz', 'ham'} + + def test_union_preserves_left_order_for_duplicates(self): + o1 = OrderedSet([1, 2, 3, 4]) + o2 = OrderedSet([3, 5, 1, 6]) + result = o1 | o2 + assert list(result) == [1, 2, 3, 4, 5, 6] + + def test_update_with_duplicates(self): + o = OrderedSet([1, 2, 3]) + o.update([3, 4, 1, 5]) + assert list(o) == [1, 2, 3, 4, 5]