1# mypy: allow-untyped-defs 2import cmath 3import math 4import warnings 5from collections import OrderedDict 6from typing import Dict, Optional 7 8import torch 9import torch.backends.cudnn as cudnn 10from torch.nn.modules.utils import ( 11 _list_with_default, 12 _pair, 13 _quadruple, 14 _single, 15 _triple, 16) 17 18 19_builtin_table: Optional[Dict[int, str]] = None 20 21_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950 22 23_builtin_ops = [ 24 # Pairs of (function, op_name) 25 (_pair, "aten::_pair"), 26 (_quadruple, "aten::_quadruple"), 27 (_single, "aten::_single"), 28 (_triple, "aten::_triple"), 29 (_list_with_default, "aten::list_with_default"), 30 (OrderedDict, "aten::dict"), 31 (dict, "aten::dict"), 32 (cudnn.is_acceptable, "aten::cudnn_is_acceptable"), 33 (math.ceil, "aten::ceil"), 34 (math.copysign, "aten::copysign"), 35 (math.erf, "aten::erf"), 36 (math.erfc, "aten::erfc"), 37 (math.exp, "aten::exp"), 38 (math.expm1, "aten::expm1"), 39 (math.fabs, "aten::fabs"), 40 (math.floor, "aten::floor"), 41 (math.gamma, "aten::gamma"), 42 (math.lgamma, "aten::lgamma"), 43 (math.log, "aten::log"), 44 (math.log10, "aten::log10"), 45 (math.log1p, "aten::log1p"), 46 (math.pow, "aten::pow"), 47 (math.sqrt, "aten::sqrt"), 48 (math.isnan, "aten::isnan"), 49 (math.asinh, "aten::asinh"), 50 (math.atanh, "aten::atanh"), 51 (math.cosh, "aten::cosh"), 52 (math.sinh, "aten::sinh"), 53 (math.tanh, "aten::tanh"), 54 (math.acos, "aten::acos"), 55 (math.asin, "aten::asin"), 56 (math.atan, "aten::atan"), 57 (math.atan2, "aten::atan2"), 58 (math.cos, "aten::cos"), 59 (math.sin, "aten::sin"), 60 (math.tan, "aten::tan"), 61 (math.asinh, "aten::asinh"), 62 (math.atanh, "aten::atanh"), 63 (math.acosh, "aten::acosh"), 64 (math.fmod, "aten::fmod"), 65 (math.modf, "aten::modf"), 66 (math.factorial, "aten::factorial"), 67 (math.frexp, "aten::frexp"), 68 (math.isinf, "aten::isinf"), 69 (math.degrees, "aten::degrees"), 70 (math.radians, "aten::radians"), 71 (cmath.isnan, "aten::isnan"), 72 (cmath.isfinite, "aten::isfinite"), 73 (cmath.isinf, "aten::isinf"), 74 (cmath.phase, "aten::angle"), 75 (cmath.rect, "aten::polar"), 76 (cmath.log, "aten::log"), 77 (cmath.log10, "aten::log10"), 78 (cmath.sqrt, "aten::sqrt"), 79 (cmath.exp, "aten::exp"), 80 (cmath.sin, "aten::sin"), 81 (cmath.tan, "aten::tan"), 82 (cmath.cos, "aten::cos"), 83 (cmath.asin, "aten::asin"), 84 (cmath.acos, "aten::acos"), 85 (cmath.atan, "aten::atan"), 86 (cmath.sinh, "aten::sinh"), 87 (cmath.cosh, "aten::cosh"), 88 (cmath.tanh, "aten::tanh"), 89 (cmath.asinh, "aten::asinh"), 90 (cmath.acosh, "aten::acosh"), 91 (cmath.atanh, "aten::atanh"), 92 (math.ldexp, "aten::ldexp"), 93 (torch._assert, "aten::_assert"), 94 (torch.autograd.grad, "aten::grad"), 95 (torch.autograd.backward, "aten::backward"), 96 (torch._C._infer_size, "aten::_infer_size"), 97 (torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"), # type: ignore[attr-defined] 98 (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"), 99 (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"), 100 (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"), 101 (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"), 102 (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"), 103 (torch._C._get_tracing_state, "aten::_get_tracing_state"), 104 (torch._C._get_cpu_capability, "aten::_get_cpu_capability"), 105 (warnings.warn, "aten::warn"), 106 (torch._VF.stft, "aten::stft"), # type: ignore[attr-defined] 107 (torch._VF.istft, "aten::istft"), # type: ignore[attr-defined] 108 (torch._VF.cdist, "aten::cdist"), # type: ignore[attr-defined] 109 (torch._VF.norm, "aten::norm"), # type: ignore[attr-defined] 110 (torch._VF.unique_dim, "aten::unique_dim"), 111 (torch._VF.unique_consecutive, "aten::unique_consecutive"), # type: ignore[attr-defined] 112 (torch._VF.nuclear_norm, "aten::nuclear_norm"), 113 (torch._VF.frobenius_norm, "aten::frobenius_norm"), 114 (torch._VF.tensordot, "aten::tensordot"), # type: ignore[attr-defined] 115] 116 117# ops in torch.functional are bound to torch 118# in these cases, we want to resolve the function to their python implementation 119# instead looking up a builtin "aten::" schema 120 121 122def _gen_torch_functional_registered_ops(): 123 # eventually ops should encompass all of torch/functional.py, (torch.functional.__all__) 124 # but we are currently only able to compile some of the functions. additionally, 125 # some functions directly map to their aten:: implementations. 126 # TODO: add support for more ops 127 ops = [ 128 "stft", 129 "istft", 130 "lu", 131 "cdist", 132 "norm", 133 "unique", 134 "unique_consecutive", 135 "tensordot", 136 ] 137 return {getattr(torch.functional, name) for name in ops} 138 139 140_functional_registered_ops = _gen_torch_functional_registered_ops() 141 142 143def _is_special_functional_bound_op(fn): 144 return fn in _functional_registered_ops 145 146 147# lazily built to ensure the correct initialization order 148def _get_builtin_table(): 149 global _builtin_table 150 if _builtin_table is not None: 151 return _builtin_table 152 _builtin_table = {} 153 154 def register_all(mod): 155 for name in dir(mod): 156 v = getattr(mod, name) 157 if ( 158 callable(v) 159 and not _is_special_functional_bound_op(v) 160 and v is not torch.no_grad 161 and v is not torch.autocast 162 ): 163 # Fixup inconsistency in segment_reduce 164 if name == "_segment_reduce": 165 name = name[1:] 166 _builtin_ops.append((v, "aten::" + name)) 167 168 for mod in _modules_containing_builtins: 169 register_all(mod) 170 171 _builtin_ops.append((math.gcd, "aten::gcd")) 172 _builtin_ops.append((math.isfinite, "aten::isfinite")) 173 _builtin_ops.append((math.remainder, "aten::mathremainder")) # type: ignore[attr-defined] 174 175 import torch.distributed.autograd as dist_autograd 176 177 if dist_autograd.is_available(): 178 _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients")) 179 _builtin_ops.append((dist_autograd.backward, "aten::dist_backward")) 180 181 # populate the _builtin_table from _builtin_ops 182 for builtin, aten_op in _builtin_ops: 183 _builtin_table[id(builtin)] = aten_op 184 185 return _builtin_table 186 187 188def _register_builtin(fn, op): 189 _get_builtin_table()[id(fn)] = op 190 191 192def _find_builtin(fn): 193 return _get_builtin_table().get(id(fn)) 194