1 #include <torch/csrc/jit/tensorexpr/ir.h>
2 #include <torch/csrc/jit/tensorexpr/stmt.h>
3
4 #include <c10/util/irange.h>
5
6 #include <utility>
7
8 namespace torch::jit::tensorexpr {
9
// Dtype of a load/access: the element (scalar) type comes from the buffer,
// while the lane count comes from the index expression's dtype.
static Dtype ChooseDtype(const Dtype& buffer_dtype, const Dtype& index_dtype) {
  const auto lanes = index_dtype.lanes();
  return Dtype(buffer_dtype, lanes);
}
13
dtypeOfIndices(const std::vector<ExprPtr> & indices)14 static Dtype dtypeOfIndices(const std::vector<ExprPtr>& indices) {
15 if (indices.empty()) {
16 // Return something so we can handle scalar buffers.
17 return kInt;
18 }
19 return indices.at(0)->dtype();
20 }
21
castIndicesToInts(std::vector<ExprPtr> & indices)22 static void castIndicesToInts(std::vector<ExprPtr>& indices) {
23 // Cast all indices to either Int or Long
24 auto index_dtype = ScalarType::Int;
25 for (auto& index : indices) {
26 if (index->dtype().scalar_type() == ScalarType::Long) {
27 // If any of the indexes is Long, cast all of them to Long
28 index_dtype = ScalarType::Long;
29 break;
30 }
31 }
32
33 for (auto& index : indices) {
34 const Dtype& dt = index->dtype();
35 if (c10::isIntegralType(dt.scalar_type(), true) &&
36 dt.scalar_type() != index_dtype) {
37 index = alloc<Cast>(Dtype(index_dtype, dt.lanes()), index);
38 }
39 }
40 }
41
Load(Dtype dtype,BufPtr buf,std::vector<ExprPtr> indices)42 Load::Load(Dtype dtype, BufPtr buf, std::vector<ExprPtr> indices)
43 : ExprNodeBase(dtype), buf_(std::move(buf)), indices_(std::move(indices)) {
44 castIndicesToInts(indices_);
45 }
46
Load(const BufPtr & buf,const std::vector<ExprPtr> & indices)47 Load::Load(const BufPtr& buf, const std::vector<ExprPtr>& indices)
48 : Load(ChooseDtype(buf->dtype(), dtypeOfIndices(indices)), buf, indices) {}
49
make(Dtype dtype,const BufHandle & buf,const std::vector<ExprHandle> & indices)50 ExprHandle Load::make(
51 Dtype dtype,
52 const BufHandle& buf,
53 const std::vector<ExprHandle>& indices) {
54 return ExprHandle(
55 alloc<Load>(dtype, buf.node(), ExprHandleVectorToExprVector(indices)));
56 }
57
make(const BufHandle & buf,const std::vector<ExprHandle> & indices)58 ExprHandle Load::make(
59 const BufHandle& buf,
60 const std::vector<ExprHandle>& indices) {
61 return Load::make(buf.dtype(), buf, indices);
62 }
63
Store(BufPtr buf,std::vector<ExprPtr> indices,ExprPtr value)64 Store::Store(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value)
65 : buf_(std::move(buf)),
66 indices_(std::move(indices)),
67 value_(std::move(value)) {
68 castIndicesToInts(indices_);
69 }
70
make(const BufHandle & buf,const std::vector<ExprHandle> & indices,const ExprHandle & value)71 StorePtr Store::make(
72 const BufHandle& buf,
73 const std::vector<ExprHandle>& indices,
74 const ExprHandle& value) {
75 return alloc<Store>(
76 buf.node(), ExprHandleVectorToExprVector(indices), value.node());
77 }
78
store(const std::vector<ExprHandle> & args,const ExprHandle & value) const79 StorePtr BufHandle::store(
80 const std::vector<ExprHandle>& args,
81 const ExprHandle& value) const {
82 return Store::make(*this, args, value);
83 }
84
flatten_index(const std::vector<ExprPtr> & dims,const std::vector<ExprPtr> & indices,const std::vector<ExprPtr> & strides)85 ExprPtr flatten_index(
86 const std::vector<ExprPtr>& dims,
87 const std::vector<ExprPtr>& indices,
88 const std::vector<ExprPtr>& strides) {
89 // Handle already flattened indices first
90 if (indices.size() == 1) {
91 return indices[0];
92 }
93
94 size_t ndim = dims.size();
95 if (ndim != indices.size()) {
96 throw malformed_input("dimensions mismatch in flatten_index");
97 }
98 if (ndim != strides.size()) {
99 throw malformed_input("strides mismatch in flatten_index");
100 }
101 if (ndim == 0) {
102 return alloc<LongImm>(0);
103 }
104 ExprPtr total_index = immLike(indices[0], 0);
105 for (const auto i : c10::irange(ndim)) {
106 total_index = alloc<Add>(total_index, alloc<Mul>(indices[i], strides[i]));
107 }
108 return total_index;
109 }
110
IntrinsicsDtype(IntrinsicsOp op_type,Dtype dt1)111 Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1) {
112 if (op_type == kIsNan) {
113 return dt1.cloneWithScalarType(ScalarType::Int);
114 }
115 // TODO: check the op_type and make a real decision
116 return dt1;
117 }
118
IntrinsicsDtype(IntrinsicsOp op_type,Dtype dt1,Dtype dt2)119 Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2) {
120 // TODO: check the op_type and make a real decision
121 return dt1;
122 }
123
IntrinsicsDtype(IntrinsicsOp op_type,const std::vector<ExprPtr> & params)124 Dtype Intrinsics::IntrinsicsDtype(
125 IntrinsicsOp op_type,
126 const std::vector<ExprPtr>& params) {
127 // TODO: check the op_type and make a real decision
128 // Doesnt this fail with kRand?
129 if (params.empty()) {
130 throw malformed_input("invalid params in Intrinsics");
131 } else if (params.size() == 1) {
132 return IntrinsicsDtype(op_type, params[0]->dtype());
133 } else if (params.size() == 2) {
134 return IntrinsicsDtype(op_type, params[0]->dtype(), params[1]->dtype());
135 }
136 return params[0]->dtype();
137 }
138
OpArgCount(IntrinsicsOp op_type)139 size_t Intrinsics::OpArgCount(IntrinsicsOp op_type) {
140 switch (op_type) {
141 case kSin:
142 case kCos:
143 case kTan:
144 case kAsin:
145 case kAcos:
146 case kAtan:
147 case kSinh:
148 case kCosh:
149 case kTanh:
150 case kSigmoid:
151 case kExp:
152 case kExpm1:
153 case kAbs:
154 case kLog:
155 case kLog2:
156 case kLog10:
157 case kLog1p:
158 case kErf:
159 case kErfc:
160 case kSqrt:
161 case kRsqrt:
162 case kCeil:
163 case kFloor:
164 case kRound:
165 case kTrunc:
166 case kFrac:
167 case kLgamma:
168 case kIsNan:
169 return 1;
170 case kRand:
171 return 0;
172 case kAtan2:
173 case kFmod:
174 case kPow:
175 case kRemainder:
176 return 2;
177 default:
178 throw std::runtime_error("invalid op_type: " + std::to_string(op_type));
179 }
180 }
181
make(BufHandle buf,const std::string & func_name,const std::vector<BufHandle> & buf_args,const std::vector<ExprHandle> & args)182 ExternalCallPtr ExternalCall::make(
183 BufHandle buf,
184 const std::string& func_name,
185 const std::vector<BufHandle>& buf_args,
186 const std::vector<ExprHandle>& args) {
187 std::vector<BufPtr> buf_arg_nodes;
188 buf_arg_nodes.reserve(buf_args.size());
189 for (const BufHandle& buf_arg : buf_args) {
190 buf_arg_nodes.push_back(buf_arg.node());
191 }
192 return alloc<ExternalCall>(
193 buf.node(), func_name, buf_arg_nodes, ExprHandleVectorToExprVector(args));
194 }
195
make(const std::string & func_name,const std::vector<BufHandle> & buf_out_args,const std::vector<BufHandle> & buf_args,const std::vector<ExprHandle> & args)196 ExternalCallWithAllocPtr ExternalCallWithAlloc::make(
197 const std::string& func_name,
198 const std::vector<BufHandle>& buf_out_args,
199 const std::vector<BufHandle>& buf_args,
200 const std::vector<ExprHandle>& args) {
201 std::vector<BufPtr> buf_out_arg_nodes;
202 buf_out_arg_nodes.reserve(buf_out_args.size());
203 for (const BufHandle& buf_out_arg : buf_out_args) {
204 buf_out_arg_nodes.push_back(buf_out_arg.node());
205 }
206
207 std::vector<BufPtr> buf_arg_nodes;
208 buf_arg_nodes.reserve(buf_args.size());
209 for (const BufHandle& buf_arg : buf_args) {
210 buf_arg_nodes.push_back(buf_arg.node());
211 }
212 return alloc<ExternalCallWithAlloc>(
213 func_name,
214 buf_out_arg_nodes,
215 buf_arg_nodes,
216 ExprHandleVectorToExprVector(args));
217 }
218
make(const std::vector<BufHandle> & bufs)219 FreeExtPtr FreeExt::make(const std::vector<BufHandle>& bufs) {
220 std::vector<BufPtr> buf_nodes;
221 buf_nodes.reserve(bufs.size());
222 for (const BufHandle& buf : bufs) {
223 buf_nodes.push_back(buf.node());
224 }
225 return alloc<FreeExt>(buf_nodes);
226 }
227
ExprHandleVectorToExprVector(const std::vector<ExprHandle> & v)228 std::vector<ExprPtr> ExprHandleVectorToExprVector(
229 const std::vector<ExprHandle>& v) {
230 std::vector<ExprPtr> result(v.size());
231 for (const auto i : c10::irange(v.size())) {
232 result[i] = v[i].node();
233 }
234 return result;
235 }
236
ExprVectorToExprHandleVector(const std::vector<ExprPtr> & v)237 std::vector<ExprHandle> ExprVectorToExprHandleVector(
238 const std::vector<ExprPtr>& v) {
239 std::vector<ExprHandle> result(v.size());
240 for (const auto i : c10::irange(v.size())) {
241 result[i] = ExprHandle(v[i]);
242 }
243 return result;
244 }
245
VarHandleVectorToVarVector(const std::vector<VarHandle> & v)246 std::vector<VarPtr> VarHandleVectorToVarVector(
247 const std::vector<VarHandle>& v) {
248 std::vector<VarPtr> result(v.size());
249 for (const auto i : c10::irange(v.size())) {
250 result[i] = v[i].node();
251 }
252 return result;
253 }
254
VarVectorToVarHandleVector(const std::vector<VarPtr> & v)255 std::vector<VarHandle> VarVectorToVarHandleVector(
256 const std::vector<VarPtr>& v) {
257 std::vector<VarHandle> result(v.size());
258 for (const auto i : c10::irange(v.size())) {
259 result[i] = VarHandle(v[i]);
260 }
261 return result;
262 }
263
immediateIsNegative(const ExprPtr & e)264 bool immediateIsNegative(const ExprPtr& e) {
265 #define TYPE_CASE(Type, Name) \
266 if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
267 return imm->value() < 0; \
268 }
269 AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
270 #undef TYPE_CASE
271 return false;
272 }
273
immediateIsPositive(const ExprPtr & e)274 bool immediateIsPositive(const ExprPtr& e) {
275 #define TYPE_CASE(Type, Name) \
276 if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
277 return imm->value() > 0; \
278 }
279 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
280 #undef TYPE_CASE
281 return false;
282 }
283
immediateIsZero(const ExprPtr & e)284 bool immediateIsZero(const ExprPtr& e) {
285 #define TYPE_CASE(Type, Name) \
286 if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
287 return imm->value() == 0; \
288 }
289 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
290 #undef TYPE_CASE
291 return false;
292 }
293
294 } // namespace torch::jit::tensorexpr
295