xref: /aosp_15_r20/external/pytorch/torch/_dynamo/polyfills/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2Python polyfills for common builtins.
3"""
4
5# NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports.
6#       2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py.
7#          Add it in the TYPE_CHECKING block below as well.
8
9# mypy: allow-untyped-defs
10
11from typing import Any, Callable, Sequence, TYPE_CHECKING
12
13import torch
14
15
16if TYPE_CHECKING:
    # Loaded by torch._dynamo.polyfills.loader
18    # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py
19    # Put the submodules here to avoid circular imports
20    from . import (
21        builtins as builtins,
22        functools as functools,
23        itertools as itertools,
24        os as os,
25        sys as sys,
26    )
27
28
def index(iterator, item, start=0, end=None):
    """Return the position of the first element of `iterator` equal to `item`.

    Positions are counted from zero over the whole iterable; only positions
    selected by ``islice(..., start, end)`` are examined.

    Raises:
        ValueError: if no examined element compares equal to `item`.
    """
    from itertools import islice

    numbered = enumerate(iterator)
    for position, candidate in islice(numbered, start, end):
        # Keep `item` on the left so its __eq__ is consulted first.
        if item == candidate:
            return position
    # This will not run in dynamo
    raise ValueError(f"{item} is not in {type(iterator)}")
37
38
def repeat(item, count):
    """Yield `item` once for every value produced by ``range(count)``."""
    for _ in range(count):
        yield item
42
43
def radians(x):
    """Convert `x` from degrees to radians by multiplying with pi / 180."""
    import math

    # Compute the factor first so the evaluation order stays (pi / 180.0) * x.
    degrees_to_radians = math.pi / 180.0
    return degrees_to_radians * x
48
49
def accumulate_grad(x, new_grad):
    """Accumulate a clone of `new_grad` into ``x.grad``.

    When ``x.grad`` is ``None`` the clone becomes the gradient; otherwise it
    is added to the existing gradient in place.
    """
    incoming = torch.clone(new_grad)
    if x.grad is not None:
        x.grad.add_(incoming)
        return
    x.grad = incoming
56
57
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
    """emulate sequence comparisons such as `(1,2,3) > (1,2)`

    The first pair of unequal elements decides the result via `op`; if every
    paired element is equal, `op` is applied to the two lengths instead.
    """
    for lhs, rhs in zip(left, right):
        if lhs != rhs:
            return op(lhs, rhs)
    # No differing pair among the shared positions: the lengths decide.
    left_size, right_size = len(left), len(right)
    return op(left_size, right_size)
64
65
def set_isdisjoint(set1, set2):
    """Return True when no element of `set1` is contained in `set2`."""
    disjoint = True
    for member in set1:
        if member in set2:
            # One shared element is enough to decide; stop scanning.
            disjoint = False
            break
    return disjoint
71
72
def set_intersection(set1, set2):
    """Return a new set holding the elements of `set1` that are also in `set2`."""
    shared = set()
    for member in set1:
        if member not in set2:
            continue
        shared.add(member)
    return shared
79
80
def set_union(set1, set2):
    """Return a copy of `set1` extended with every element of `set2`.

    `set1` is copied via its ``copy()`` method, so the input is not modified.
    """
    merged = set1.copy()
    for member in set2:
        if member in merged:
            continue
        merged.add(member)
    return merged
87
88
def set_difference(set1, set2):
    """Return a new set holding the elements of `set1` that are not in `set2`."""
    remaining = set()
    for member in set1:
        if member in set2:
            continue
        remaining.add(member)
    return remaining
95
96
def dropwhile(predicate, iterable):
    """Skip leading elements while `predicate` is truthy, then yield the rest.

    Example: dropwhile(lambda x: x<5, [1,4,6,4,1]) -> 6 4 1
    """
    remaining = iter(iterable)
    for candidate in remaining:
        if predicate(candidate):
            continue
        # First element that fails the predicate: emit it and stop testing.
        yield candidate
        break
    # Everything left on the shared iterator is passed through untested.
    yield from remaining
105
106
def zip_longest(*iterables, fillvalue=None):
    """Zip `iterables` together, padding exhausted ones with `fillvalue`.

    The result is built eagerly and returned as a list of tuples. Rows are
    produced until a full pass over all iterators yields no real value.
    """
    sources = [iter(it) for it in iterables]
    rows = []
    while True:
        current = []
        produced_any = False
        for source in sources:
            try:
                current.append(next(source))
            except StopIteration:
                # This iterator has run out; pad its slot.
                current.append(fillvalue)
            else:
                produced_any = True
        if not produced_any:
            # Every iterator was exhausted on this pass; drop the padded row.
            return rows
        rows.append(tuple(current))
127
128
def getattr_and_trace(*args, **kwargs):
    """Look up an attribute on an object and call it.

    ``args[0]`` is the object, ``args[1]`` the attribute name; any further
    positional arguments and all keyword arguments are forwarded to the call.
    """
    target = getattr(args[0], args[1])
    forwarded = args[2:]
    return target(*forwarded, **kwargs)
134
135
def mapping_get(obj, key, value=None):
    """Return ``obj.__getitem__(key)``, or `value` if that raises KeyError."""
    try:
        found = obj.__getitem__(key)
    except KeyError:
        return value
    return found
141
142
def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
    """Create an object via ``cls.__new__`` and initialise it when appropriate.

    ``__init__`` is invoked with the same arguments only if ``__new__``
    returned an instance of `cls`; otherwise the object is returned untouched.
    Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673
    """
    instance = cls.__new__(cls, *args, **kwargs)
    if not isinstance(instance, cls):
        # __new__ returned an object of another type: skip __init__.
        return instance
    instance.__init__(*args, **kwargs)
    return instance
151
152
def foreach_lerp_inplace(self, end, weight):
    """In-place foreach lerp expressed as sub, mul and in-place add.

    Decomposing foreach lerp into its constituent ops prevents a graph break
    due to converting a value to a scalar when arg[2] is a single tensor.
    """
    delta = torch._foreach_sub(end, self)
    scaled_delta = torch._foreach_mul(delta, weight)
    return torch._foreach_add_(self, scaled_delta)
159
160
161def foreach_pow_scalar(scalar, exps):
162    return torch._foreach_pow([scalar for _ in exps], exps)
163
164
def addcmul_inplace(self, tensor1, tensor2, value):
    """Add ``tensor1 * tensor2 * value`` to `self` in place and return the result of ``add_``."""
    product = tensor1 * tensor2
    scaled = product * value
    return self.add_(scaled)
167