# xref: /aosp_15_r20/prebuilts/build-tools/common/py3-stdlib/_weakrefset.py (revision cda5da8d549138a6648c5ee6d7a49cf8f4a657be)
# Access WeakSet through the weakref module.
# This code is kept in a separate module because abc.py needs it
# to load everything else at startup.
4
5from _weakref import ref
6from types import GenericAlias
7
# Public API of this module; everything else is an implementation detail.
__all__ = ["WeakSet"]
9
10
11class _IterationGuard:
12    # This context manager registers itself in the current iterators of the
13    # weak container, such as to delay all removals until the context manager
14    # exits.
15    # This technique should be relatively thread-safe (since sets are).
16
17    def __init__(self, weakcontainer):
18        # Don't create cycles
19        self.weakcontainer = ref(weakcontainer)
20
21    def __enter__(self):
22        w = self.weakcontainer()
23        if w is not None:
24            w._iterating.add(self)
25        return self
26
27    def __exit__(self, e, t, b):
28        w = self.weakcontainer()
29        if w is not None:
30            s = w._iterating
31            s.remove(self)
32            if not s:
33                w._commit_removals()
34
35
class WeakSet:
    """A set that holds only weak references to its elements.

    Every element must support weak references.  When an element is garbage
    collected, its dead reference is dropped from the set automatically:
    immediately when no iteration is running, otherwise it is queued in
    ``_pending_removals`` and discarded later (when the last iteration
    finishes, or at the start of the next mutating call).
    """

    def __init__(self, data=None):
        """Create a WeakSet, optionally populated from the iterable *data*."""
        # Set of weak references, each created with ``_remove`` as callback.
        self.data = set()

        def _remove(item, selfref=ref(self)):
            # Weakref callback fired when an element dies.  The set is only
            # reachable through ``selfref`` so the callback neither keeps it
            # alive nor forms a reference cycle.
            self = selfref()
            if self is not None:
                if self._iterating:
                    # Changing the size of ``data`` while it is being
                    # iterated would break the iterator; defer instead.
                    self._pending_removals.append(item)
                else:
                    self.data.discard(item)
        self._remove = _remove
        # A list of dead references whose removal from ``data`` is deferred.
        self._pending_removals = []
        # Iteration guards currently active over this set.
        self._iterating = set()
        if data is not None:
            self.update(data)

    def _commit_removals(self):
        """Discard every deferred dead reference from ``data``."""
        # Pop until empty; popping (rather than iterating the list) means
        # any entry appended meanwhile is handled too.
        pop = self._pending_removals.pop
        discard = self.data.discard
        while True:
            try:
                item = pop()
            except IndexError:
                return
            discard(item)

    def __iter__(self):
        """Yield the live elements; element deaths during iteration are deferred."""
        with _IterationGuard(self):
            for itemref in self.data:
                item = itemref()
                if item is not None:
                    # Caveat: the iterator will keep a strong reference to
                    # `item` until it is resumed or closed.
                    yield item

    def __len__(self):
        # Deferred removals are dead references still sitting in ``data``.
        return len(self.data) - len(self._pending_removals)

    def __contains__(self, item):
        try:
            wr = ref(item)
        except TypeError:
            # Objects that cannot be weakly referenced are never members.
            return False
        return wr in self.data

    def __reduce__(self):
        """Support ``copy.copy``, ``copy.deepcopy`` and ``pickle``.

        The copy is rebuilt by calling the class with the live elements.
        Only instance attributes added on top of the set (e.g. by a
        subclass) are carried over as state.  The set's own bookkeeping
        (``data``, ``_remove``, ``_pending_removals``, ``_iterating``)
        belongs to this instance.  It must not be copied onto the new
        instance, otherwise the copy would share (and mutate) this set's
        storage.
        """
        internal = ('data', '_remove', '_pending_removals', '_iterating')

        def strip(attrs):
            kept = {k: v for k, v in attrs.items() if k not in internal}
            return kept or None

        state = self.__getstate__()
        if isinstance(state, dict):
            state = strip(state)
        elif (isinstance(state, tuple) and len(state) == 2
              and isinstance(state[0], dict)):
            # Subclasses with __slots__ report (instance_dict, slot_values);
            # only the dict part holds the internal attributes.
            state = (strip(state[0]), state[1])
        return self.__class__, (list(self),), state

    def add(self, item):
        """Add *item*; ``ref`` raises TypeError if it is not weak-referenceable."""
        if self._pending_removals:
            self._commit_removals()
        self.data.add(ref(item, self._remove))

    def clear(self):
        """Remove all elements."""
        if self._pending_removals:
            self._commit_removals()
        self.data.clear()

    def copy(self):
        """Return a new set of the same class containing the live elements."""
        return self.__class__(self)

    def pop(self):
        """Remove and return an arbitrary live element.

        Raises KeyError if the set holds no live element.
        """
        if self._pending_removals:
            self._commit_removals()
        while True:
            try:
                itemref = self.data.pop()
            except KeyError:
                raise KeyError('pop from empty WeakSet') from None
            item = itemref()
            if item is not None:
                return item
            # Dead reference popped: keep looking for a live one.

    def remove(self, item):
        """Remove *item*; raises KeyError if it is not a member."""
        if self._pending_removals:
            self._commit_removals()
        self.data.remove(ref(item))

    def discard(self, item):
        """Remove *item* if present."""
        if self._pending_removals:
            self._commit_removals()
        self.data.discard(ref(item))

    def update(self, other):
        """Add every element of the iterable *other*."""
        if self._pending_removals:
            self._commit_removals()
        for element in other:
            self.add(element)

    def __ior__(self, other):
        self.update(other)
        return self

    def difference(self, other):
        """Return a new set with the elements not in *other*."""
        newset = self.copy()
        newset.difference_update(other)
        return newset
    __sub__ = difference

    def difference_update(self, other):
        """Remove every element of *other* from this set."""
        self.__isub__(other)

    def __isub__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            self.data.difference_update(ref(item) for item in other)
        return self

    def intersection(self, other):
        """Return a new set with the elements common to this set and *other*."""
        return self.__class__(item for item in other if item in self)
    __and__ = intersection

    def intersection_update(self, other):
        """Keep only the elements that are also in *other*."""
        self.__iand__(other)

    def __iand__(self, other):
        if self._pending_removals:
            self._commit_removals()
        self.data.intersection_update(ref(item) for item in other)
        return self

    def issubset(self, other):
        """Return True if every element of this set is in *other*."""
        return self.data.issubset(ref(item) for item in other)
    __le__ = issubset

    def __lt__(self, other):
        return self.data < set(map(ref, other))

    def issuperset(self, other):
        """Return True if every element of *other* is in this set."""
        return self.data.issuperset(ref(item) for item in other)
    __ge__ = issuperset

    def __gt__(self, other):
        return self.data > set(map(ref, other))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.data == set(map(ref, other))

    def symmetric_difference(self, other):
        """Return a new set with the elements in exactly one of the two sets."""
        newset = self.copy()
        newset.symmetric_difference_update(other)
        return newset
    __xor__ = symmetric_difference

    def symmetric_difference_update(self, other):
        """Keep the elements that are in exactly one of this set and *other*."""
        self.__ixor__(other)

    def __ixor__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            # Newly added references need the callback so they clean up.
            self.data.symmetric_difference_update(ref(item, self._remove) for item in other)
        return self

    def union(self, other):
        """Return a new set with the elements of both sets."""
        return self.__class__(e for s in (self, other) for e in s)
    __or__ = union

    def isdisjoint(self, other):
        """Return True if no element of *other* is in this set."""
        # Test membership directly instead of materializing an intersection.
        # This stops at the first common element and never calls the
        # (possibly subclass-overridden) constructor.
        return not any(item in self for item in other)

    def __repr__(self):
        return repr(self.data)

    __class_getitem__ = classmethod(GenericAlias)
206