#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <unordered_map>

namespace torch::jit {

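// Maps elementwise aten op schemas to the name of the shape-compute pattern
// used when fusing them in TensorExpr. As a rough guide (an interpretation of
// the category names, not a normative spec): "unary" ops take their output
// shape from the first tensor input; "broadcast" broadcasts two tensor
// inputs; "broadcast_three" broadcasts three tensor inputs; and
// "broadcast_one_three" broadcasts the first and third inputs (the second
// being a Scalar).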
const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
  // clang-format off
  static const OperatorMap<std::string> tensorexpr_elementwise_set{
      {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
      {"aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor", "unary"},
      {"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
      {"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
      {"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
      {"aten::eq.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
      {"aten::ne.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
      {"aten::ge.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
      {"aten::gt.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
      {"aten::le.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
      {"aten::lt.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
      {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", "unary"},
      {"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "unary"},
      {"aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", "unary"},
      {"aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", "unary"},
      {"aten::to.dtype_layout(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", "unary"},
      {"aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", "unary"},
      {"aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", "unary"},
      {"aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)", "unary"},
      {"aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)", "unary"},
      {"aten::isnan(Tensor self) -> Tensor", "unary"},
      {"aten::lgamma(Tensor self) -> Tensor", "unary"},
      {"aten::log10(Tensor self) -> Tensor", "unary"},
      {"aten::log(Tensor self) -> Tensor", "unary"},
      {"aten::log2(Tensor self) -> Tensor", "unary"},
      {"aten::log1p(Tensor self) -> Tensor", "unary"},
      {"aten::exp(Tensor self) -> Tensor", "unary"},
      {"aten::erf(Tensor self) -> Tensor", "unary"},
      {"aten::erfc(Tensor self) -> Tensor", "unary"},
      // TODO: uncomment when we properly support pow
      // "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor",
      // "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor",
      // TODO: support clamp_min, clamp_max
      // "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor",
      // "aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor", TODO: requires 0-dim Tensor
      // "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
      // TODO: uncomment once we can handle rand+broadcasts
      // "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
51 {"aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
52 {"aten::cos(Tensor self) -> Tensor", "unary"},
53 {"aten::sin(Tensor self) -> Tensor", "unary"},
54 {"aten::tan(Tensor self) -> Tensor", "unary"},
55 {"aten::acos(Tensor self) -> Tensor", "unary"},
56 {"aten::asin(Tensor self) -> Tensor", "unary"},
57 {"aten::atan(Tensor self) -> Tensor", "unary"},
58 {"aten::cosh(Tensor self) -> Tensor", "unary"},
59 {"aten::sinh(Tensor self) -> Tensor", "unary"},
60 {"aten::tanh(Tensor self) -> Tensor", "unary"},
61 {"aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor", "unary"},
62 {"aten::hardsigmoid(Tensor self) -> Tensor", "unary"},
63 {"aten::hardswish(Tensor self) -> Tensor", "unary"},
64 {"aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor", "unary"},
65 {"aten::sqrt(Tensor self) -> Tensor", "unary"},
66 {"aten::rsqrt(Tensor self) -> Tensor", "unary"},
67 {"aten::abs(Tensor self) -> Tensor", "unary"},
68 {"aten::floor(Tensor self) -> Tensor", "unary"},
69 {"aten::ceil(Tensor self) -> Tensor", "unary"},
70 {"aten::round(Tensor self) -> Tensor", "unary"},
71 {"aten::trunc(Tensor self) -> Tensor", "unary"},
72 {"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", "unary"},
73 {"aten::sigmoid(Tensor self) -> Tensor", "unary"},
74 {"aten::relu(Tensor self) -> Tensor", "unary"},
75 {"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "unary"},
76 {"aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor", "unary"},
77 {"aten::mish(Tensor self) -> Tensor", "unary"},
78 {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor", "unary"},
79 {"aten::relu6(Tensor self) -> Tensor", "unary"},
80 {"aten::gelu(Tensor self, *, str approximate='none') -> Tensor", "unary"},
81 {"aten::silu(Tensor self) -> Tensor", "unary"},
82 {"aten::neg(Tensor self) -> Tensor", "unary"},
83 {"aten::reciprocal(Tensor self) -> Tensor", "unary"},
84 {"aten::expm1(Tensor self) -> Tensor", "unary"},
85 {"aten::frac(Tensor self) -> Tensor", "unary"},
86 {"aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
87 {"aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
88 {"aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
89 {"aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
90 {"aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor", "unary"},
91 {"aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor", "unary"},
92 {"aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "broadcast"},
93 {"aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor", "broadcast"},
94 {"aten::type_as(Tensor self, Tensor other) -> Tensor", "unary"},
95 {"aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "broadcast"},
96 {"aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
97 {"aten::div.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
98 {"aten::eq.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
99 {"aten::ne.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
100 {"aten::ge.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
101 {"aten::gt.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
102 {"aten::le.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
103 {"aten::lt.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
104 {"aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor", "broadcast"},
105 {"aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
106 {"aten::atan2(Tensor self, Tensor other) -> Tensor", "broadcast"},
107 {"aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
108 {"aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
109 {"aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
110 {"aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
111 {"aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
112 {"aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
      // TODO: enable other min/max variants, i.e. operators that can act as
      // either elementwise ops or reductions:
115 {"aten::min.other(Tensor self, Tensor other) -> Tensor", "broadcast"},
116 {"aten::max.other(Tensor self, Tensor other) -> Tensor", "broadcast"},
117 {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
118 {"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "broadcast_three"},
119 {"aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor", "broadcast_three"},
120 {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
      // TODO: enable slice; shape inference is not implemented for this op yet
  };
  // clang-format on
  return tensorexpr_elementwise_set;
}
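
// A minimal usage sketch (commented out; the caller below is hypothetical,
// and it assumes the OperatorMap<T>::find(const Operator&) lookup declared in
// torch/csrc/jit/runtime/operator.h -- check the exact overload set there):
//
//   if (const Operator* op = node->maybeOperator()) {
//     if (auto category = get_tensorexpr_elementwise_set().find(*op)) {
//       // *category is "unary", "broadcast", "broadcast_three", or
//       // "broadcast_one_three"; dispatch on it to choose a shape function.
//     }
//   }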

} // namespace torch::jit