xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/ir_emitter.h>
2 #include <torch/csrc/jit/jit_log.h>
3 #include <torch/csrc/jit/passes/inliner.h>
4 #include <torch/csrc/jit/runtime/operator.h>
5 #include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
6 #include <unordered_map>
7 
8 namespace torch::jit {
9 
get_tensorexpr_elementwise_set()10 const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
11   // clang-format off
12  static const OperatorMap<std::string> tensorexpr_elementwise_set{
13       {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
14       {"aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor", "unary"},
15       {"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
16       {"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
17       {"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
18       {"aten::eq.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
19       {"aten::ne.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
20       {"aten::ge.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
21       {"aten::gt.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
22       {"aten::le.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
23       {"aten::lt.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
24       {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", "unary"},
25       {"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "unary"},
26       {"aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", "unary"},
27       {"aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", "unary"},
28       {"aten::to.dtype_layout(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", "unary"},
29       {"aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", "unary"},
30       {"aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", "unary"},
31       {"aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)", "unary"},
32       {"aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)", "unary"},
33       {"aten::isnan(Tensor self) -> Tensor", "unary"},
34       {"aten::lgamma(Tensor self) -> Tensor", "unary"},
35       {"aten::log10(Tensor self) -> Tensor", "unary"},
36       {"aten::log(Tensor self) -> Tensor", "unary"},
37       {"aten::log2(Tensor self) -> Tensor", "unary"},
38       {"aten::log1p(Tensor self) -> Tensor", "unary"},
39       {"aten::exp(Tensor self) -> Tensor", "unary"},
40       {"aten::erf(Tensor self) -> Tensor", "unary"},
41       {"aten::erfc(Tensor self) -> Tensor", "unary"},
42       // TODO: uncomment when we properly support pow
43       // "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor",
44       // "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor",
45       // TODO: support clamp_min, clamp_max
46       // "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor",
47       // "aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor", TODO: requires 0-dim Tensor
48       // "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
49       // TODO: uncomment once we can handle rand+broadcasts
50       // "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
51       {"aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
52       {"aten::cos(Tensor self) -> Tensor", "unary"},
53       {"aten::sin(Tensor self) -> Tensor", "unary"},
54       {"aten::tan(Tensor self) -> Tensor", "unary"},
55       {"aten::acos(Tensor self) -> Tensor", "unary"},
56       {"aten::asin(Tensor self) -> Tensor", "unary"},
57       {"aten::atan(Tensor self) -> Tensor", "unary"},
58       {"aten::cosh(Tensor self) -> Tensor", "unary"},
59       {"aten::sinh(Tensor self) -> Tensor", "unary"},
60       {"aten::tanh(Tensor self) -> Tensor", "unary"},
61       {"aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor", "unary"},
62       {"aten::hardsigmoid(Tensor self) -> Tensor", "unary"},
63       {"aten::hardswish(Tensor self) -> Tensor", "unary"},
64       {"aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor", "unary"},
65       {"aten::sqrt(Tensor self) -> Tensor", "unary"},
66       {"aten::rsqrt(Tensor self) -> Tensor", "unary"},
67       {"aten::abs(Tensor self) -> Tensor", "unary"},
68       {"aten::floor(Tensor self) -> Tensor", "unary"},
69       {"aten::ceil(Tensor self) -> Tensor", "unary"},
70       {"aten::round(Tensor self) -> Tensor", "unary"},
71       {"aten::trunc(Tensor self) -> Tensor", "unary"},
72       {"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", "unary"},
73       {"aten::sigmoid(Tensor self) -> Tensor", "unary"},
74       {"aten::relu(Tensor self) -> Tensor", "unary"},
75       {"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "unary"},
76       {"aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor", "unary"},
77       {"aten::mish(Tensor self) -> Tensor", "unary"},
78       {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor", "unary"},
79       {"aten::relu6(Tensor self) -> Tensor", "unary"},
80       {"aten::gelu(Tensor self, *, str approximate='none') -> Tensor", "unary"},
81       {"aten::silu(Tensor self) -> Tensor", "unary"},
82       {"aten::neg(Tensor self) -> Tensor", "unary"},
83       {"aten::reciprocal(Tensor self) -> Tensor", "unary"},
84       {"aten::expm1(Tensor self) -> Tensor", "unary"},
85       {"aten::frac(Tensor self) -> Tensor", "unary"},
86       {"aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
87       {"aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
88       {"aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
89       {"aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
90       {"aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
91       {"aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor", "unary"},
92       {"aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "broadcast"},
93       {"aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor", "broadcast"},
94       {"aten::type_as(Tensor self, Tensor other) -> Tensor", "unary"},
95       {"aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "broadcast"},
96       {"aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
97       {"aten::div.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
98       {"aten::eq.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
99       {"aten::ne.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
100       {"aten::ge.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
101       {"aten::gt.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
102       {"aten::le.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
103       {"aten::lt.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
104       {"aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor", "broadcast"},
105       {"aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
106       {"aten::atan2(Tensor self, Tensor other) -> Tensor", "broadcast"},
107       {"aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
108       {"aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
109       {"aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
110       {"aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
111       {"aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
112       {"aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
113       // TODO: enable other min/max variants, operators that can be both
114       // elementwise or reductions:
115       {"aten::min.other(Tensor self, Tensor other) -> Tensor", "broadcast"},
116       {"aten::max.other(Tensor self, Tensor other) -> Tensor", "broadcast"},
117       {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
118       {"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "broadcast_three"},
119       {"aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor", "broadcast_three"},
120       {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
121       // TODO: enable slice, shape inference is not implemented for this op yet
122   };
123   // clang-format on
124   return tensorexpr_elementwise_set;
125 }
126 
127 } // namespace torch::jit
128