Add OrderedSet class (#84134)

This commit is contained in:
sivel / Matt Martz
2026-04-30 14:56:26 -05:00
committed by GitHub
parent f074c20929
commit ff6b7e404a
4 changed files with 447 additions and 26 deletions
+2
View File
@@ -0,0 +1,2 @@
minor_changes:
- Add new OrderedSet class for situations where a unique ordered list is needed
+120 -20
View File
@@ -5,31 +5,41 @@
from __future__ import annotations
from collections.abc import Hashable, Mapping, MutableMapping, Sequence # pylint: disable=unused-import
import collections.abc as _c
import typing as _t
from ansible.module_utils._internal import _no_six
from ansible.module_utils.common import warnings as _warnings
# Type variables used by the generic containers in this module.
# Keys and set elements are bound to Hashable because both containers store them as dict keys.
_KT = _t.TypeVar('_KT', bound=_c.Hashable)  # key type of ImmutableDict
_VT = _t.TypeVar('_VT')  # value type of ImmutableDict
_T = _t.TypeVar('_T', bound=_c.Hashable)  # element type of OrderedSet
class ImmutableDict(Hashable, Mapping):
class ImmutableDict(_c.Hashable, _c.Mapping[_KT, _VT], _t.Generic[_KT, _VT]):
"""Dictionary that cannot be updated"""
def __init__(self, *args, **kwargs):
self._store = dict(*args, **kwargs)
def __init__(
self,
*args: _c.Mapping[_KT, _VT] | _c.Iterable[tuple[_KT, _VT]],
**kwargs: _VT,
) -> None:
self._store: dict[_KT, _VT] = dict(*args) if args else {}
if kwargs:
self._store.update(_t.cast(_c.Mapping[_KT, _VT], kwargs))
def __getitem__(self, key):
def __getitem__(self, key: _KT) -> _VT:
return self._store[key]
def __iter__(self):
def __iter__(self) -> _c.Iterator[_KT]:
return self._store.__iter__()
def __len__(self):
def __len__(self) -> int:
return self._store.__len__()
def __hash__(self):
def __hash__(self) -> int:
return hash(frozenset(self.items()))
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
try:
if self.__hash__() == hash(other):
return True
@@ -38,10 +48,10 @@ class ImmutableDict(Hashable, Mapping):
return False
def __repr__(self):
def __repr__(self) -> str:
return 'ImmutableDict({0})'.format(repr(self._store))
def union(self, overriding_mapping):
def union(self, overriding_mapping: _c.Mapping[_KT, _VT]) -> ImmutableDict[_KT, _VT]:
"""
Create an ImmutableDict as a combination of the original and overriding_mapping
@@ -51,9 +61,11 @@ class ImmutableDict(Hashable, Mapping):
If any of the keys in overriding_mapping are already present in the original ImmutableDict,
the overriding_mapping item replaces the one in the original ImmutableDict.
"""
return ImmutableDict(self._store, **overriding_mapping)
result = dict(self._store)
result.update(overriding_mapping)
return ImmutableDict(result)
def difference(self, subtractive_iterable):
def difference(self, subtractive_iterable: _c.Iterable) -> ImmutableDict[_KT, _VT]:
"""
Create an ImmutableDict as a combination of the original minus keys in subtractive_iterable
@@ -66,12 +78,94 @@ class ImmutableDict(Hashable, Mapping):
return ImmutableDict((k, self._store[k]) for k in keys)
def is_string(seq):
class OrderedSet(_c.MutableSet[_T], _t.Generic[_T]):
    """Mutable set that iterates in first-insertion order.

    Elements are kept as the keys of a ``dict``, so they must be hashable, membership tests are
    hash lookups, and adding an element that is already present leaves its position unchanged.

    Equality is order-sensitive and only defined against another ``OrderedSet``; any other type
    compares unequal. The subset/superset comparisons ignore order. Because ``__eq__`` is defined
    without ``__hash__``, instances are unhashable.

    Binary operations return a new ``OrderedSet``. ``&`` and ``-`` keep the order of the operand
    being filtered; ``|`` and ``^`` keep the order of the operand the result starts from and
    append new elements in the order the other operand yields them.
    """

    def __init__(
        self,
        iterable: _c.Iterable[_T] | None = None,
        /
    ) -> None:
        """Create the set from ``iterable``, keeping the first occurrence of each element.

        Raises ``TypeError`` if ``iterable`` is not iterable or yields an unhashable element.
        """
        # Test against ``None`` rather than truthiness: an iterable whose ``bool()`` is False or
        # ambiguous must still be consumed instead of being silently treated as empty.
        self._data: dict[_T, None] = dict.fromkeys(() if iterable is None else iterable)

    def __repr__(self, /) -> str:
        return f'OrderedSet({list(self)!r})'

    def __eq__(self, other: object, /) -> bool:
        """Return True only for another ``OrderedSet`` holding the same elements in the same order."""
        if not isinstance(other, OrderedSet):
            return NotImplemented
        return tuple(self._data) == tuple(other._data)

    def __contains__(self, x: object, /) -> bool:
        return x in self._data

    def __iter__(self, /) -> _c.Iterator[_T]:
        return iter(self._data)

    def __len__(self, /) -> int:
        return len(self._data)

    def add(self, value: _T) -> None:
        """Append ``value`` if it is not present; an existing element keeps its position."""
        self._data[value] = None

    def discard(self, value: _T) -> None:
        """Remove ``value`` if present; do nothing otherwise."""
        self._data.pop(value, None)

    def clear(self) -> None:
        """Remove all elements."""
        self._data.clear()

    def copy(self) -> OrderedSet[_T]:
        """Return a shallow copy with the same elements in the same order."""
        result: OrderedSet[_T] = OrderedSet()
        result._data = self._data.copy()
        return result

    @staticmethod
    def _as_container(other: _c.Container | _c.Iterable) -> _c.Container:
        """Return something that can be probed with ``in`` repeatedly.

        Containers are returned unchanged. A plain iterable (generator, iterator, ...) would be
        consumed by the first ``in`` test, so it is materialized once into a ``frozenset``.
        """
        if isinstance(other, _c.Container):
            return other
        return frozenset(other)

    def __and__(self, other: _c.Container | _c.Iterable, /) -> OrderedSet[_T]:
        """Return the elements of this set that are also in ``other``, in this set's order."""
        # overridden, because the ABC produces an arguably unexpected sorting
        candidates = self._as_container(other)
        return OrderedSet(value for value in self if value in candidates)

    def __sub__(self, other: _c.Container | _c.Iterable, /) -> OrderedSet[_T]:
        """Return the elements of this set that are not in ``other``, in this set's order."""
        excluded = self._as_container(other)
        return OrderedSet(value for value in self if value not in excluded)

    def __or__(self, other: _c.Iterable[_T], /) -> OrderedSet[_T]:  # type: ignore[override]
        """Return this set's elements followed by the elements of ``other`` not already present."""
        result = self.copy()
        for value in other:
            result._data[value] = None
        return result

    def __xor__(self, other: _c.Iterable[_T], /) -> OrderedSet[_T]:  # type: ignore[override]
        """Return the elements found in exactly one of this set and ``other``.

        Elements unique to this set come first, followed by elements unique to ``other``.
        """
        # Each distinct element must be toggled exactly once. A Set is already duplicate-free;
        # any other iterable is de-duplicated first so a repeated value does not cancel itself out.
        unique = other if isinstance(other, _c.Set) else dict.fromkeys(other)
        result = self.copy()
        for value in unique:
            if value in result._data:
                del result._data[value]
            else:
                result._data[value] = None
        return result

    def __rsub__(self, other: _c.Iterable[_T], /) -> OrderedSet[_T]:
        return OrderedSet(other).__sub__(self)

    def __rxor__(self, other: _c.Iterable[_T], /) -> OrderedSet[_T]:
        return OrderedSet(other).__xor__(self)

    def issubset(self, other: _c.Iterable[object]) -> bool:
        """Return True if every element of this set is in ``other`` (any iterable), ignoring order."""
        # The ABC comparison returns NotImplemented for anything that is not a Set, which would
        # leak a truthy sentinel to the caller; coerce plain iterables first.
        if not isinstance(other, _c.Set):
            other = frozenset(other)
        return self <= other

    def issuperset(self, other: _c.Iterable[object]) -> bool:
        """Return True if every element of ``other`` (any iterable) is in this set, ignoring order."""
        if isinstance(other, _c.Set):
            return self >= other
        return all(value in self._data for value in other)

    difference = __sub__
    intersection = __and__
    __rand__ = __and__
    __ror__ = __or__
    symmetric_difference = __xor__
    union = __or__

    # The in-place variants reuse the ABC's in-place operators, which mutate this set via
    # add()/discard() and return ``self``.
    difference_update = _c.MutableSet.__isub__
    intersection_update = _c.MutableSet.__iand__
    symmetric_difference_update = _c.MutableSet.__ixor__
    update = _c.MutableSet.__ior__
def is_string(seq: _c.Iterable) -> bool:
    """Report whether ``seq`` is a text or byte string.

    Only ``str`` and ``bytes`` instances (subclasses included) qualify.
    """
    string_types = (str, bytes)
    return isinstance(seq, string_types)
def is_iterable(seq, include_strings=False):
def is_iterable(seq: _c.Iterable, include_strings: bool = False) -> bool:
"""Identify whether the input is an iterable."""
if not include_strings and is_string(seq):
return False
@@ -83,7 +177,7 @@ def is_iterable(seq, include_strings=False):
return False
def is_sequence(seq, include_strings=False):
def is_sequence(seq: _c.Iterable, include_strings: bool = False) -> bool:
"""Identify whether the input is a sequence.
Strings and bytes are not sequences here,
@@ -94,10 +188,10 @@ def is_sequence(seq, include_strings=False):
if not include_strings and is_string(seq):
return False
return isinstance(seq, Sequence)
return isinstance(seq, _c.Sequence)
def count(seq):
def count(seq: _c.Iterable) -> dict[_c.Hashable, int]:
"""Returns a dictionary with the number of appearances of each element of the iterable.
Resembles the collections.Counter class functionality. It is meant to be used when the
@@ -111,11 +205,17 @@ def count(seq):
)
if not is_iterable(seq):
raise Exception('Argument provided is not an iterable')
counters = dict()
counters: dict[_c.Hashable, int] = {}
for elem in seq:
counters[elem] = counters.get(elem, 0) + 1
return counters
# Module-level aliases for the collections.abc types, so that these names stay importable from this
# module (e.g. ``from ansible.module_utils.common.collections import Mapping``).
Hashable = _c.Hashable
Mapping = _c.Mapping
MutableMapping = _c.MutableMapping
Sequence = _c.Sequence
def __getattr__(importable_name):
    """Module-level attribute hook, called for names not found on this module by normal lookup.

    The lookup is delegated to ``_no_six.deprecate`` together with this module's name and the
    names ``binary_type`` and ``text_type``.

    NOTE(review): presumably those two are the legacy six-style aliases that are still served with a
    deprecation, and any other name results in ``AttributeError`` - confirm against ``_no_six``.
    """
    return _no_six.deprecate(importable_name, __name__, "binary_type", "text_type")
+4 -5
View File
@@ -63,7 +63,7 @@ else:
GzipFile = gzip.GzipFile # type: ignore[assignment,misc]
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.common.collections import Mapping, is_sequence
from ansible.module_utils.common.collections import Mapping, OrderedSet, is_sequence
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible.module_utils.compat import typing as _t
@@ -513,9 +513,8 @@ def get_ca_certs(cafile=None, capath=None):
# tries to find a valid CA cert in one of the
# standard locations for the current distribution
# Using a dict, instead of a set for order, the value is meaningless and will be None
# Not directly using a bytearray to avoid duplicates with fast lookup
cadata = {}
cadata = OrderedSet()
# If cafile is passed, we are only using that for verification,
# don't add additional ca certs
@@ -524,7 +523,7 @@ def get_ca_certs(cafile=None, capath=None):
with open(to_bytes(cafile, errors='surrogate_or_strict'), 'r', errors='surrogateescape') as f:
for pem in extract_pem_certs(f.read()):
b_der = ssl.PEM_cert_to_DER_cert(pem)
cadata[b_der] = None
cadata.add(b_der)
return bytearray().join(cadata), paths_checked
default_verify_paths = ssl.get_default_verify_paths()
@@ -575,7 +574,7 @@ def get_ca_certs(cafile=None, capath=None):
try:
for pem in extract_pem_certs(cert):
b_der = ssl.PEM_cert_to_DER_cert(pem)
cadata[b_der] = None
cadata.add(b_der)
except Exception:
continue
except OSError:
@@ -8,7 +8,7 @@ from __future__ import annotations
import pytest
from collections.abc import Sequence
from ansible.module_utils.common.collections import ImmutableDict, is_iterable, is_sequence
from ansible.module_utils.common.collections import ImmutableDict, OrderedSet, is_iterable, is_sequence
class SeqStub:
@@ -146,3 +146,323 @@ class TestImmutableDict:
actual_repr = repr(imdict)
expected_repr = "ImmutableDict({0})".format(initial_data_repr)
assert actual_repr == expected_repr
class TestOrderedSet:
    """Behavioral tests for ``OrderedSet``: construction, ordering, equality and the set operations."""

    # --- construction, repr, len, membership, iteration ---

    def test_init_empty(self):
        o = OrderedSet()
        assert len(o) == 0
        assert list(o) == []

    def test_init_with_iterable(self):
        o = OrderedSet(['foo', 'bar', 'baz'])
        assert list(o) == ['foo', 'bar', 'baz']

    def test_init_deduplication(self):
        # Duplicates collapse onto the position of their first occurrence.
        o = OrderedSet([1, 2, 1, 3, 2, 4])
        assert list(o) == [1, 2, 3, 4]

    def test_repr(self):
        o = OrderedSet([1, 2, 3])
        assert repr(o) == "OrderedSet([1, 2, 3])"

    def test_repr_empty(self):
        o = OrderedSet()
        assert repr(o) == "OrderedSet([])"

    def test_len(self):
        o = OrderedSet([1, 2, 3])
        assert len(o) == 3

    def test_len_empty(self):
        o = OrderedSet()
        assert len(o) == 0

    @pytest.mark.parametrize('value,expected', [
        ('foo', True),
        ('missing', False),
        (1, False),
    ])
    def test_contains(self, value, expected):
        o = OrderedSet(['foo', 'bar', 'baz'])
        assert (value in o) == expected

    def test_iter_preserves_order(self):
        expected = ['foo', 'bar', 'baz']
        o = OrderedSet(expected)
        assert list(o) == expected

    # --- mutation: add / discard / clear / copy ---

    def test_add(self):
        o = OrderedSet()
        o.add('foo')
        assert 'foo' in o
        assert list(o) == ['foo']

    def test_add_duplicate(self):
        # Re-adding an existing element must not move it.
        o = OrderedSet(['foo', 'bar'])
        o.add('foo')
        assert list(o) == ['foo', 'bar']

    def test_discard_existing(self):
        o = OrderedSet(['foo', 'bar', 'baz'])
        o.discard('bar')
        assert list(o) == ['foo', 'baz']

    def test_discard_missing(self):
        # Discarding an absent element is a silent no-op.
        o = OrderedSet(['foo', 'bar'])
        o.discard('missing')
        assert list(o) == ['foo', 'bar']

    def test_clear(self):
        o = OrderedSet(['foo', 'bar', 'baz'])
        o.clear()
        assert len(o) == 0
        assert list(o) == []

    def test_copy(self):
        o1 = OrderedSet(['foo', 'bar', 'baz'])
        o2 = o1.copy()
        assert o1 == o2
        assert o1 is not o2

    def test_copy_independence(self):
        # Mutating the copy must not leak into the original.
        o1 = OrderedSet(['foo', 'bar'])
        o2 = o1.copy()
        o2.add('baz')
        assert list(o1) == ['foo', 'bar']
        assert list(o2) == ['foo', 'bar', 'baz']

    # --- equality: order-sensitive, and only between OrderedSet instances ---

    def test_eq_same_order(self):
        o1 = OrderedSet([1, 2, 3])
        o2 = OrderedSet([1, 2, 3])
        assert o1 == o2

    def test_eq_different_order(self):
        o1 = OrderedSet([1, 2, 3])
        o2 = OrderedSet([3, 2, 1])
        assert o1 != o2

    def test_eq_different_elements(self):
        o1 = OrderedSet([1, 2, 3])
        o2 = OrderedSet([1, 2, 4])
        assert o1 != o2

    def test_eq_different_length(self):
        o1 = OrderedSet([1, 2, 3])
        o2 = OrderedSet([1, 2])
        assert o1 != o2

    @pytest.mark.parametrize('other', [
        set([1, 2, 3]),
        [1, 2, 3],
        {1: 2, 2: 3, 3: 4},
        'abc',
    ])
    def test_eq_with_non_orderedset(self, other):
        # Any non-OrderedSet compares unequal, even a builtin set holding the same elements.
        o = OrderedSet([1, 2, 3])
        assert (o == other) is False

    # --- difference: operator, method and in-place forms; results keep the left operand's order ---

    def test_difference(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham'])
        result = o1 - o2
        assert list(result) == ['foo', 'baz']

    def test_difference_method(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham'])
        result = o1.difference(o2)
        assert list(result) == ['foo', 'baz']

    def test_difference_update(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham'])
        o1 -= o2
        assert list(o1) == ['foo', 'baz']

    def test_difference_update_method(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham'])
        o1.difference_update(o2)
        assert list(o1) == ['foo', 'baz']

    # --- intersection: the result follows o1's order, not o2's ---

    def test_intersection(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham'])
        result = o1 & o2
        assert list(result) == ['bar', 'qux']

    def test_intersection_method(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham'])
        result = o1.intersection(o2)
        assert list(result) == ['bar', 'qux']

    def test_intersection_update(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham'])
        o1 &= o2
        assert list(o1) == ['bar', 'qux']

    def test_intersection_update_method(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham'])
        o1.intersection_update(o2)
        assert list(o1) == ['bar', 'qux']

    # --- union / update: left elements first, then new elements in the right operand's order ---

    def test_union(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich'])
        result = o1 | o2
        assert list(result) == ['foo', 'bar', 'baz', 'qux', 'ham', 'sandwich']

    def test_union_method(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich'])
        result = o1.union(o2)
        assert list(result) == ['foo', 'bar', 'baz', 'qux', 'ham', 'sandwich']

    def test_update(self):
        # The in-place operator accepts a plain list, not only sets.
        o1 = OrderedSet(['foo', 'bar'])
        o1 |= ['baz', 'qux']
        assert list(o1) == ['foo', 'bar', 'baz', 'qux']

    def test_update_method(self):
        o1 = OrderedSet(['foo', 'bar'])
        o1.update(['baz', 'qux'])
        assert list(o1) == ['foo', 'bar', 'baz', 'qux']

    # --- symmetric difference: elements unique to the left come first, then those unique to the right ---

    def test_symmetric_difference(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich'])
        result = o1 ^ o2
        assert list(result) == ['foo', 'baz', 'ham', 'sandwich']

    def test_symmetric_difference_method(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich'])
        result = o1.symmetric_difference(o2)
        assert list(result) == ['foo', 'baz', 'ham', 'sandwich']

    def test_symmetric_difference_update(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich'])
        o1 ^= o2
        assert list(o1) == ['foo', 'baz', 'ham', 'sandwich']

    def test_symmetric_difference_update_method(self):
        o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich'])
        o1.symmetric_difference_update(o2)
        assert list(o1) == ['foo', 'baz', 'ham', 'sandwich']

    # --- subset / superset: unlike equality, these comparisons ignore order ---

    def test_issubset_true(self):
        o1 = OrderedSet([1, 2])
        o2 = OrderedSet([1, 2, 3])
        assert o1.issubset(o2)
        assert o1 <= o2

    def test_issubset_different_order(self):
        o1 = OrderedSet([2, 1])
        o2 = OrderedSet([1, 2, 3])
        assert o1.issubset(o2)
        assert o1 <= o2

    def test_issubset_false(self):
        o1 = OrderedSet([1, 2, 4])
        o2 = OrderedSet([1, 2, 3])
        assert not o1.issubset(o2)
        assert not (o1 <= o2)  # pylint: disable=unnecessary-negation

    def test_issubset_equal(self):
        o1 = OrderedSet([1, 2, 3])
        o2 = OrderedSet([1, 2, 3])
        assert o1.issubset(o2)
        assert o1 <= o2

    def test_issuperset_true(self):
        o1 = OrderedSet([1, 2, 3])
        o2 = OrderedSet([1, 2])
        assert o1.issuperset(o2)
        assert o1 >= o2

    def test_issuperset_different_order(self):
        o1 = OrderedSet([1, 2, 3])
        o2 = OrderedSet([2, 1])
        assert o1.issuperset(o2)
        assert o1 >= o2

    def test_issuperset_false(self):
        o1 = OrderedSet([1, 2, 3])
        o2 = OrderedSet([1, 2, 4])
        assert not o1.issuperset(o2)
        assert not (o1 >= o2)  # pylint: disable=unnecessary-negation

    def test_issuperset_equal(self):
        o1 = OrderedSet([1, 2, 3])
        o2 = OrderedSet([1, 2, 3])
        assert o1.issuperset(o2)
        assert o1 >= o2

    # --- reflected operators: a builtin set on the left still yields an OrderedSet ---

    def test_rand_intersection(self):
        o = OrderedSet(['bar', 'qux'])
        s = {'foo', 'bar', 'baz', 'qux'}
        result = s & o
        assert isinstance(result, OrderedSet)
        assert list(result) == ['bar', 'qux']

    def test_ror_union(self):
        # The result starts with the OrderedSet's elements even though the builtin set is on the left.
        o = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        s = {'qux', 'bar', 'ham'}
        result = s | o
        assert isinstance(result, OrderedSet)
        assert list(result) == ['foo', 'bar', 'baz', 'qux', 'ham']

    def test_rsub_difference(self):
        o = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        s = {'qux', 'bar', 'ham'}
        result = s - o
        assert isinstance(result, OrderedSet)
        assert list(result) == ['ham']

    def test_rxor_symmetric_difference(self):
        o = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        s = {'qux', 'bar', 'ham'}
        result = s ^ o
        assert isinstance(result, OrderedSet)
        # Compared as a set: only membership is asserted here, not the resulting order.
        assert set(result) == {'foo', 'baz', 'ham'}

    # --- OrderedSet on the left, builtin set on the right ---

    def test_intersection_with_regular_set(self):
        o = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        s = {'qux', 'bar', 'ham'}
        result = o & s
        assert list(result) == ['bar', 'qux']

    def test_difference_with_regular_set(self):
        o = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        s = {'qux', 'bar', 'ham'}
        result = o - s
        assert list(result) == ['foo', 'baz']

    def test_union_with_regular_set(self):
        o = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        s = {'qux', 'bar', 'ham'}
        result = o | s
        assert list(result) == ['foo', 'bar', 'baz', 'qux', 'ham']

    def test_symmetric_difference_with_regular_set(self):
        o = OrderedSet(['foo', 'bar', 'baz', 'qux'])
        s = {'qux', 'bar', 'ham'}
        result = o ^ s
        assert set(result) == {'foo', 'baz', 'ham'}

    # --- duplicates in the right operand never move elements that are already present ---

    def test_union_preserves_left_order_for_duplicates(self):
        o1 = OrderedSet([1, 2, 3, 4])
        o2 = OrderedSet([3, 5, 1, 6])
        result = o1 | o2
        assert list(result) == [1, 2, 3, 4, 5, 6]

    def test_update_with_duplicates(self):
        o = OrderedSet([1, 2, 3])
        o.update([3, 4, 1, 5])
        assert list(o) == [1, 2, 3, 4, 5]