xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/unification_tools.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3import operator
4from functools import reduce
5from collections.abc import Mapping
6
# Public API of this module; ``getter``, ``groupby`` and ``first`` below are
# internal helpers and intentionally not exported.
__all__ = [
    'merge',
    'merge_with',
    'valmap',
    'keymap',
    'itemmap',
    'valfilter',
    'keyfilter',
    'itemfilter',
    'assoc',
    'dissoc',
    'assoc_in',
    'update_in',
    'get_in',
]
10
11
12def _get_factory(f, kwargs):
13    factory = kwargs.pop('factory', dict)
14    if kwargs:
15        raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'")
16    return factory
17
18
def merge(*dicts, **kwargs):
    """ Merge a collection of dictionaries

    The dictionaries may be passed as separate positional arguments or as a
    single non-mapping iterable of dictionaries. The only accepted keyword is
    ``factory``, the constructor for the result (default ``dict``).

    >>> merge({1: 'one'}, {2: 'two'})
    {1: 'one', 2: 'two'}

    Later dictionaries have precedence

    >>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
    {1: 2, 3: 3, 4: 4}

    See Also:
        merge_with
    """
    single_argument = len(dicts) == 1
    if single_argument and not isinstance(dicts[0], Mapping):
        # One non-mapping argument: treat it as an iterable of dictionaries.
        dicts = dicts[0]
    factory = _get_factory(merge, kwargs)

    merged = factory()
    for mapping in dicts:
        # ``update`` overwrites, so later mappings win on key collisions.
        merged.update(mapping)
    return merged
41
42
def merge_with(func, *dicts, **kwargs):
    """ Merge dictionaries and apply function to combined values

    Every value stored under the same key across the inputs is collected,
    in input order, into a list; the key then maps to ``func(that_list)``,
    such as func([val1, val2, ...]). As with ``merge``, the dictionaries may
    be given positionally or as one non-mapping iterable, and an optional
    ``factory`` keyword selects the result type (default ``dict``).

    >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
    {1: 11, 2: 22}

    >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30})  # doctest: +SKIP
    {1: 1, 2: 2, 3: 30}

    See Also:
        merge
    """
    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
        # One non-mapping argument: treat it as an iterable of dictionaries.
        dicts = dicts[0]
    factory = _get_factory(merge_with, kwargs)

    collected = factory()
    for mapping in dicts:
        for key, value in mapping.items():
            if key in collected:
                collected[key].append(value)
            else:
                collected[key] = [value]
    return valmap(func, collected, factory)
70
71
def valmap(func, d, factory=dict):
    """ Apply function to values of dictionary

    Returns a new mapping, built with ``factory`` (default ``dict``), whose
    keys are those of ``d`` and whose values are ``func(value)``. ``d`` is
    not modified.

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> valmap(sum, bills)  # doctest: +SKIP
    {'Alice': 65, 'Bob': 45}

    See Also:
        keymap
        itemmap
    """
    result = factory()
    original_keys = d.keys()
    new_values = map(func, d.values())
    result.update(zip(original_keys, new_values))
    return result
86
87
def keymap(func, d, factory=dict):
    """ Apply function to keys of dictionary

    Returns a new mapping, built with ``factory`` (default ``dict``), that
    maps ``func(key)`` to the original value. If ``func`` sends two keys to
    the same result, the one iterated later wins. ``d`` is not modified.

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> keymap(str.lower, bills)  # doctest: +SKIP
    {'alice': [20, 15, 30], 'bob': [10, 35]}

    See Also:
        valmap
        itemmap
    """
    result = factory()
    new_keys = map(func, d.keys())
    result.update(zip(new_keys, d.values()))
    return result
102
103
def itemmap(func, d, factory=dict):
    """ Apply function to items of dictionary

    ``func`` receives each ``(key, value)`` pair of ``d`` and must return a
    two-element iterable ``(new_key, new_value)``. The pairs are collected
    into a new mapping built with ``factory`` (default ``dict``).

    >>> accountids = {"Alice": 10, "Bob": 20}
    >>> itemmap(reversed, accountids)  # doctest: +SKIP
    {10: 'Alice', 20: 'Bob'}

    See Also:
        keymap
        valmap
    """
    result = factory()
    result.update(func(pair) for pair in d.items())
    return result
118
119
def valfilter(predicate, d, factory=dict):
    """ Filter items in dictionary by value

    Returns a new mapping, built with ``factory`` (default ``dict``), holding
    only the entries of ``d`` whose value satisfies ``predicate``.

    >>> iseven = lambda x: x % 2 == 0
    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> valfilter(iseven, d)
    {1: 2, 3: 4}

    See Also:
        keyfilter
        itemfilter
        valmap
    """
    kept = factory()
    for key, value in d.items():
        if not predicate(value):
            continue
        kept[key] = value
    return kept
138
139
def keyfilter(predicate, d, factory=dict):
    """ Filter items in dictionary by key

    Returns a new mapping, built with ``factory`` (default ``dict``), holding
    only the entries of ``d`` whose key satisfies ``predicate``.

    >>> iseven = lambda x: x % 2 == 0
    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> keyfilter(iseven, d)
    {2: 3, 4: 5}

    See Also:
        valfilter
        itemfilter
        keymap
    """
    kept = factory()
    for key, value in d.items():
        if not predicate(key):
            continue
        kept[key] = value
    return kept
158
159
def itemfilter(predicate, d, factory=dict):
    """ Filter items in dictionary by item

    ``predicate`` receives each ``(key, value)`` pair of ``d``; pairs for
    which it is truthy are copied into a new mapping built with ``factory``
    (default ``dict``).

    >>> def isvalid(item):
    ...     k, v = item
    ...     return k % 2 == 0 and v < 4

    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> itemfilter(isvalid, d)
    {2: 3}

    See Also:
        keyfilter
        valfilter
        itemmap
    """
    kept = factory()
    for pair in d.items():
        if not predicate(pair):
            continue
        key, value = pair
        kept[key] = value
    return kept
182
183
def assoc(d, key, value, factory=dict):
    """ Return a new dict with new key value pair

    The result is a shallow copy of ``d`` (built with ``factory``, default
    ``dict``) in which ``key`` maps to ``value``. The input is left untouched.

    >>> assoc({'x': 1}, 'x', 2)
    {'x': 2}
    >>> assoc({'x': 1}, 'y', 3)   # doctest: +SKIP
    {'x': 1, 'y': 3}
    """
    copy = factory()
    copy.update(d)
    copy[key] = value
    return copy
198
199
def dissoc(d, *keys, **kwargs):
    """ Return a new dict with the given key(s) removed.

    New dict has d[key] deleted for each supplied key.
    Does not modify the initial dictionary. Missing keys are ignored, and
    the surviving entries keep the iteration order of ``d``. The only
    accepted keyword is ``factory`` (default ``dict``), the result type.

    >>> dissoc({'x': 1, 'y': 2}, 'y')
    {'x': 1}
    >>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
    {}
    >>> dissoc({'x': 1}, 'y') # Ignores missing keys
    {'x': 1}
    >>> list(dissoc({'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5}, 'a', 'b', 'c'))
    ['d', 'e']
    """
    factory = _get_factory(dissoc, kwargs)
    d2 = factory()

    if len(keys) < len(d) * .6:
        # Few keys to drop: copy everything, then delete the unwanted ones.
        d2.update(d)
        for key in keys:
            if key in d2:
                del d2[key]
    else:
        # Many keys to drop: copy only the survivors. Iterate ``d`` itself
        # rather than a set of its keys so the result keeps ``d``'s order.
        drop = set(keys)
        for k in d:
            if k not in drop:
                d2[k] = d[k]
    return d2
227
228
def assoc_in(d, keys, value, factory=dict):
    """ Return a new dict with new, potentially nested, key value pair

    Equivalent to ``update_in`` with a function that ignores its argument
    and returns ``value``. Any missing intermediate levels along ``keys``
    are therefore created with ``factory``, and ``d`` is not mutated.

    >>> purchase = {'name': 'Alice',
    ...             'order': {'items': ['Apple', 'Orange'],
    ...                       'costs': [0.50, 1.25]},
    ...             'credit card': '5555-1234-1234-1234'}
    >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
    """
    def _constant(_ignored):
        return value

    return update_in(d, keys, _constant, value, factory)
242
243
def update_in(d, keys, func, default=None, factory=dict):
    """ Update value in a (potentially) nested dictionary

    Args:
        d: mapping to operate on; it is never mutated.
        keys: non-empty list or tuple ``[k0, ..., kX]`` giving the path to
            the value to change.
        func: one-argument function applied to the value at that path.
        default: value passed to ``func`` when the path does not exist.
        factory: constructor used for the returned top-level copy and for
            every mapping copied or created along the path.

    If d[k0]..[kX] == v, update_in returns a copy of the original dictionary
    with v replaced by func(v). Each mapping along the path is shallow-copied
    with ``factory``; mappings off the path are shared with ``d``.

    Wherever a key along the path is missing, empty mappings are created down
    to the requested depth and the innermost value is set to func(default).
    An empty ``keys`` raises ``StopIteration`` (from ``next`` on the keys).

    >>> inc = lambda x: x + 1
    >>> update_in({'a': 0}, ['a'], inc)
    {'a': 1}

    >>> transaction = {'name': 'Alice',
    ...                'purchase': {'items': ['Apple', 'Orange'],
    ...                             'costs': [0.50, 1.25]},
    ...                'credit card': '5555-1234-1234-1234'}
    >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}

    >>> # updating a value when k0 is not in d
    >>> update_in({}, [1, 2, 3], str, default="bar")
    {1: {2: {3: 'bar'}}}
    >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
    {1: 'foo', 2: {3: {4: 1}}}
    """
    remaining_keys = iter(keys)
    current_key = next(remaining_keys)

    result = factory()
    result.update(d)
    level = result   # the copy currently being written into
    source = d       # the corresponding level of the original input

    for next_key in remaining_keys:
        if current_key in source:
            # Descend into the existing sub-mapping and shallow-copy it.
            source = source[current_key]
            copied = factory()
            copied.update(source)
        else:
            # Path breaks here: create a fresh, empty level. ``source``
            # aliases it, so every deeper lookup simply misses.
            source = copied = factory()
        level[current_key] = copied
        level = copied
        current_key = next_key

    if current_key in source:
        level[current_key] = func(source[current_key])
    else:
        level[current_key] = func(default)
    return result
300
301
def get_in(keys, coll, default=None, no_default=False):
    """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.

    A generalization of ``operator.getitem`` to nested containers such as
    dictionaries and lists. If any lookup step raises ``KeyError``,
    ``IndexError`` or ``TypeError``, ``default`` is returned, unless
    ``no_default`` is true, in which case that exception propagates.
    An empty ``keys`` returns ``coll`` itself.

    >>> transaction = {'name': 'Alice',
    ...                'purchase': {'items': ['Apple', 'Orange'],
    ...                             'costs': [0.50, 1.25]},
    ...                'credit card': '5555-1234-1234-1234'}
    >>> get_in(['purchase', 'items', 0], transaction)
    'Apple'
    >>> get_in(['name'], transaction)
    'Alice'
    >>> get_in(['purchase', 'total'], transaction)
    >>> get_in(['purchase', 'items', 'apple'], transaction)
    >>> get_in(['purchase', 'items', 10], transaction)
    >>> get_in(['purchase', 'total'], transaction, 0)
    0
    >>> get_in(['y'], {}, no_default=True)
    Traceback (most recent call last):
        ...
    KeyError: 'y'

    See Also:
        operator.getitem
    """
    current = coll
    try:
        # Iterating ``keys`` is inside the ``try`` too, so a non-iterable
        # ``keys`` (TypeError) also falls back to ``default``.
        for step in keys:
            current = current[step]
    except (KeyError, IndexError, TypeError):
        if no_default:
            raise
        return default
    return current
339
340
def getter(index):
    """Return a callable that extracts ``index`` from its argument.

    * non-list ``index`` -> ``operator.itemgetter(index)`` (bare item);
    * empty list -> a callable that always returns ``()``;
    * one-element list ``[i]`` -> returns the 1-tuple ``(x[i],)``;
    * longer list -> ``operator.itemgetter(*index)`` (a tuple).

    Only ``list`` gets the multi-index treatment (a tuple is used as a single
    key), and every list form yields a tuple.
    """
    if not isinstance(index, list):
        return operator.itemgetter(index)
    if not index:
        return lambda x: ()
    if len(index) == 1:
        (only,) = index
        # itemgetter with one argument returns the bare item, so build the
        # 1-tuple by hand to stay consistent with the multi-index case.
        return lambda x: (x[only],)
    return operator.itemgetter(*index)
352
353
def groupby(key, seq):
    """ Group a collection by a key function

    Returns a plain ``dict`` mapping each key value to the list of items that
    produced it. Groups appear in the order their key was first seen, and the
    items within a group keep their order from ``seq``.

    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}

    >>> iseven = lambda x: x % 2 == 0
    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}

    Non-callable keys imply grouping on a member.

    >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
    ...                    {'name': 'Bob', 'gender': 'M'},
    ...                    {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
    {'F': [{'gender': 'F', 'name': 'Alice'}],
     'M': [{'gender': 'M', 'name': 'Bob'},
           {'gender': 'M', 'name': 'Charlie'}]}

    Not to be confused with ``itertools.groupby``, which only groups
    consecutive items.
    """
    if not callable(key):
        # An index, mapping key, or list of them: extract it with ``getter``.
        key = getter(key)
    groups: dict = collections.defaultdict(list)
    for item in seq:
        groups[key(item)].append(item)
    # Hand back a plain dict so lookups of absent keys raise KeyError instead
    # of silently inserting empty groups.
    return dict(groups)
388
389
def first(seq):
    """Return the first element of an iterable.

    Raises ``StopIteration`` when ``seq`` is empty.

    >>> first('ABC')
    'A'
    """
    iterator = iter(seq)
    return next(iterator)
397