# mypy: allow-untyped-defs
import cmath
import math
import warnings
from collections import OrderedDict
from typing import Dict, Optional

import torch
import torch.backends.cudnn as cudnn
from torch.nn.modules.utils import (
    _list_with_default,
    _pair,
    _quadruple,
    _single,
    _triple,
)


_builtin_table: Optional[Dict[int, str]] = None

_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special)  # type: ignore[attr-defined] # noqa: B950
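# Every callable reachable via dir() on the modules above is auto-registered as
# an "aten::<name>" builtin by _get_builtin_table() below, in addition to the
# explicit (function, op_name) pairs listed in _builtin_ops.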

_builtin_ops = [
    # Pairs of (function, op_name)
    (_pair, "aten::_pair"),
    (_quadruple, "aten::_quadruple"),
    (_single, "aten::_single"),
    (_triple, "aten::_triple"),
    (_list_with_default, "aten::list_with_default"),
    (OrderedDict, "aten::dict"),
    (dict, "aten::dict"),
    (cudnn.is_acceptable, "aten::cudnn_is_acceptable"),
    (math.ceil, "aten::ceil"),
    (math.copysign, "aten::copysign"),
    (math.erf, "aten::erf"),
    (math.erfc, "aten::erfc"),
    (math.exp, "aten::exp"),
    (math.expm1, "aten::expm1"),
    (math.fabs, "aten::fabs"),
    (math.floor, "aten::floor"),
    (math.gamma, "aten::gamma"),
    (math.lgamma, "aten::lgamma"),
    (math.log, "aten::log"),
    (math.log10, "aten::log10"),
    (math.log1p, "aten::log1p"),
    (math.pow, "aten::pow"),
    (math.sqrt, "aten::sqrt"),
    (math.isnan, "aten::isnan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.cosh, "aten::cosh"),
    (math.sinh, "aten::sinh"),
    (math.tanh, "aten::tanh"),
    (math.acos, "aten::acos"),
    (math.asin, "aten::asin"),
    (math.atan, "aten::atan"),
    (math.atan2, "aten::atan2"),
    (math.cos, "aten::cos"),
    (math.sin, "aten::sin"),
    (math.tan, "aten::tan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.acosh, "aten::acosh"),
    (math.fmod, "aten::fmod"),
    (math.modf, "aten::modf"),
    (math.factorial, "aten::factorial"),
    (math.frexp, "aten::frexp"),
    (math.isinf, "aten::isinf"),
    (math.degrees, "aten::degrees"),
    (math.radians, "aten::radians"),
    (cmath.isnan, "aten::isnan"),
    (cmath.isfinite, "aten::isfinite"),
    (cmath.isinf, "aten::isinf"),
    (cmath.phase, "aten::angle"),
    (cmath.rect, "aten::polar"),
    (cmath.log, "aten::log"),
    (cmath.log10, "aten::log10"),
    (cmath.sqrt, "aten::sqrt"),
    (cmath.exp, "aten::exp"),
    (cmath.sin, "aten::sin"),
    (cmath.tan, "aten::tan"),
    (cmath.cos, "aten::cos"),
    (cmath.asin, "aten::asin"),
    (cmath.acos, "aten::acos"),
    (cmath.atan, "aten::atan"),
    (cmath.sinh, "aten::sinh"),
    (cmath.cosh, "aten::cosh"),
    (cmath.tanh, "aten::tanh"),
    (cmath.asinh, "aten::asinh"),
    (cmath.acosh, "aten::acosh"),
    (cmath.atanh, "aten::atanh"),
    (math.ldexp, "aten::ldexp"),
    (torch._assert, "aten::_assert"),
    (torch.autograd.grad, "aten::grad"),
    (torch.autograd.backward, "aten::backward"),
    (torch._C._infer_size, "aten::_infer_size"),
    (torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"),  # type: ignore[attr-defined]
    (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
    (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
    (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
    (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
    (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
    (torch._C._get_tracing_state, "aten::_get_tracing_state"),
    (torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
    (warnings.warn, "aten::warn"),
    (torch._VF.stft, "aten::stft"),  # type: ignore[attr-defined]
    (torch._VF.istft, "aten::istft"),  # type: ignore[attr-defined]
    (torch._VF.cdist, "aten::cdist"),  # type: ignore[attr-defined]
    (torch._VF.norm, "aten::norm"),  # type: ignore[attr-defined]
    (torch._VF.unique_dim, "aten::unique_dim"),
    (torch._VF.unique_consecutive, "aten::unique_consecutive"),  # type: ignore[attr-defined]
    (torch._VF.nuclear_norm, "aten::nuclear_norm"),
    (torch._VF.frobenius_norm, "aten::frobenius_norm"),
    (torch._VF.tensordot, "aten::tensordot"),  # type: ignore[attr-defined]
]
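
# Each pair above maps a Python callable to the TorchScript builtin op it
# lowers to; for example, once the table below has been built,
# _find_builtin(math.floor) (defined at the bottom of this file) returns
# "aten::floor".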

# ops in torch.functional are bound to torch
# in these cases, we want to resolve the function to its Python implementation
# instead of looking up a builtin "aten::" schema


def _gen_torch_functional_registered_ops():
    # eventually ops should encompass all of torch/functional.py (torch.functional.__all__),
    # but we are currently only able to compile some of the functions. additionally,
    # some functions directly map to their aten:: implementations.
    # TODO: add support for more ops
    ops = [
        "stft",
        "istft",
        "lu",
        "cdist",
        "norm",
        "unique",
        "unique_consecutive",
        "tensordot",
    ]
    return {getattr(torch.functional, name) for name in ops}


_functional_registered_ops = _gen_torch_functional_registered_ops()


def _is_special_functional_bound_op(fn):
    return fn in _functional_registered_ops
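
# For example, torch.norm is bound into the torch namespace from
# torch.functional, so _is_special_functional_bound_op(torch.norm) is True;
# register_all() below therefore skips it, and TorchScript compiles its Python
# body rather than emitting a single builtin "aten::norm" call.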


# lazily built to ensure the correct initialization order
def _get_builtin_table():
    global _builtin_table
    if _builtin_table is not None:
        return _builtin_table
    _builtin_table = {}

    def register_all(mod):
        for name in dir(mod):
            v = getattr(mod, name)
            if (
                callable(v)
                and not _is_special_functional_bound_op(v)
                and v is not torch.no_grad
                and v is not torch.autocast
            ):
                # Fix up a naming inconsistency: the Python binding is called
                # _segment_reduce, but the op is registered as aten::segment_reduce
                if name == "_segment_reduce":
                    name = name[1:]
                _builtin_ops.append((v, "aten::" + name))

    for mod in _modules_containing_builtins:
        register_all(mod)

    _builtin_ops.append((math.gcd, "aten::gcd"))
    _builtin_ops.append((math.isfinite, "aten::isfinite"))
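    # math.remainder is IEEE 754-style remainder, which differs from Python's
    # % operator (and from torch.remainder); presumably that is why it is
    # registered under the distinct op name "aten::mathremainder" below.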
    _builtin_ops.append((math.remainder, "aten::mathremainder"))  # type: ignore[attr-defined]

    import torch.distributed.autograd as dist_autograd

    if dist_autograd.is_available():
        _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients"))
        _builtin_ops.append((dist_autograd.backward, "aten::dist_backward"))

    # populate the _builtin_table from _builtin_ops
    for builtin, aten_op in _builtin_ops:
        _builtin_table[id(builtin)] = aten_op

    return _builtin_table


def _register_builtin(fn, op):
    _get_builtin_table()[id(fn)] = op


def _find_builtin(fn):
    return _get_builtin_table().get(id(fn))
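

# A minimal usage sketch, guarded so nothing runs on import; it assumes torch
# is installed and simply exercises the helpers defined above: the id-keyed
# table resolves known callables, returns None for unknown ones, and
# _register_builtin adds a new mapping. "aten::my_op" is a made-up name used
# purely for illustration.
if __name__ == "__main__":
    table = _get_builtin_table()

    # Explicitly listed pair from _builtin_ops.
    print(_find_builtin(math.floor))  # "aten::floor"

    # Auto-registered from _modules_containing_builtins via dir(torch).
    print(_find_builtin(torch.add))  # "aten::add"

    # torch.norm comes from torch.functional, so it is deliberately excluded
    # and resolves to its Python implementation instead of a builtin.
    print(_find_builtin(torch.norm))  # None

    # Arbitrary Python callables are not builtins.
    print(_find_builtin(len))  # None

    # _register_builtin keys the table by id() of the callable.
    def _my_fn(x):
        return x

    _register_builtin(_my_fn, "aten::my_op")
    print(_find_builtin(_my_fn))  # "aten::my_op"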
194