xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/ir.h>
2 #include <torch/csrc/jit/tensorexpr/stmt.h>
3 
4 #include <c10/util/irange.h>
5 
6 #include <utility>
7 
8 namespace torch::jit::tensorexpr {
9 
// Builds the dtype of an access: the element type comes from `buffer_dtype`
// and the lane count comes from `index_dtype`.
static Dtype ChooseDtype(const Dtype& buffer_dtype, const Dtype& index_dtype) {
  const auto lane_count = index_dtype.lanes();
  return Dtype(buffer_dtype, lane_count);
}
13 
dtypeOfIndices(const std::vector<ExprPtr> & indices)14 static Dtype dtypeOfIndices(const std::vector<ExprPtr>& indices) {
15   if (indices.empty()) {
16     // Return something so we can handle scalar buffers.
17     return kInt;
18   }
19   return indices.at(0)->dtype();
20 }
21 
castIndicesToInts(std::vector<ExprPtr> & indices)22 static void castIndicesToInts(std::vector<ExprPtr>& indices) {
23   // Cast all indices to either Int or Long
24   auto index_dtype = ScalarType::Int;
25   for (auto& index : indices) {
26     if (index->dtype().scalar_type() == ScalarType::Long) {
27       // If any of the indexes is Long, cast all of them to Long
28       index_dtype = ScalarType::Long;
29       break;
30     }
31   }
32 
33   for (auto& index : indices) {
34     const Dtype& dt = index->dtype();
35     if (c10::isIntegralType(dt.scalar_type(), true) &&
36         dt.scalar_type() != index_dtype) {
37       index = alloc<Cast>(Dtype(index_dtype, dt.lanes()), index);
38     }
39   }
40 }
41 
Load(Dtype dtype,BufPtr buf,std::vector<ExprPtr> indices)42 Load::Load(Dtype dtype, BufPtr buf, std::vector<ExprPtr> indices)
43     : ExprNodeBase(dtype), buf_(std::move(buf)), indices_(std::move(indices)) {
44   castIndicesToInts(indices_);
45 }
46 
Load(const BufPtr & buf,const std::vector<ExprPtr> & indices)47 Load::Load(const BufPtr& buf, const std::vector<ExprPtr>& indices)
48     : Load(ChooseDtype(buf->dtype(), dtypeOfIndices(indices)), buf, indices) {}
49 
make(Dtype dtype,const BufHandle & buf,const std::vector<ExprHandle> & indices)50 ExprHandle Load::make(
51     Dtype dtype,
52     const BufHandle& buf,
53     const std::vector<ExprHandle>& indices) {
54   return ExprHandle(
55       alloc<Load>(dtype, buf.node(), ExprHandleVectorToExprVector(indices)));
56 }
57 
make(const BufHandle & buf,const std::vector<ExprHandle> & indices)58 ExprHandle Load::make(
59     const BufHandle& buf,
60     const std::vector<ExprHandle>& indices) {
61   return Load::make(buf.dtype(), buf, indices);
62 }
63 
Store(BufPtr buf,std::vector<ExprPtr> indices,ExprPtr value)64 Store::Store(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value)
65     : buf_(std::move(buf)),
66       indices_(std::move(indices)),
67       value_(std::move(value)) {
68   castIndicesToInts(indices_);
69 }
70 
make(const BufHandle & buf,const std::vector<ExprHandle> & indices,const ExprHandle & value)71 StorePtr Store::make(
72     const BufHandle& buf,
73     const std::vector<ExprHandle>& indices,
74     const ExprHandle& value) {
75   return alloc<Store>(
76       buf.node(), ExprHandleVectorToExprVector(indices), value.node());
77 }
78 
store(const std::vector<ExprHandle> & args,const ExprHandle & value) const79 StorePtr BufHandle::store(
80     const std::vector<ExprHandle>& args,
81     const ExprHandle& value) const {
82   return Store::make(*this, args, value);
83 }
84 
flatten_index(const std::vector<ExprPtr> & dims,const std::vector<ExprPtr> & indices,const std::vector<ExprPtr> & strides)85 ExprPtr flatten_index(
86     const std::vector<ExprPtr>& dims,
87     const std::vector<ExprPtr>& indices,
88     const std::vector<ExprPtr>& strides) {
89   // Handle already flattened indices first
90   if (indices.size() == 1) {
91     return indices[0];
92   }
93 
94   size_t ndim = dims.size();
95   if (ndim != indices.size()) {
96     throw malformed_input("dimensions mismatch in flatten_index");
97   }
98   if (ndim != strides.size()) {
99     throw malformed_input("strides mismatch in flatten_index");
100   }
101   if (ndim == 0) {
102     return alloc<LongImm>(0);
103   }
104   ExprPtr total_index = immLike(indices[0], 0);
105   for (const auto i : c10::irange(ndim)) {
106     total_index = alloc<Add>(total_index, alloc<Mul>(indices[i], strides[i]));
107   }
108   return total_index;
109 }
110 
IntrinsicsDtype(IntrinsicsOp op_type,Dtype dt1)111 Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1) {
112   if (op_type == kIsNan) {
113     return dt1.cloneWithScalarType(ScalarType::Int);
114   }
115   // TODO: check the op_type and make a real decision
116   return dt1;
117 }
118 
IntrinsicsDtype(IntrinsicsOp op_type,Dtype dt1,Dtype dt2)119 Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2) {
120   // TODO: check the op_type and make a real decision
121   return dt1;
122 }
123 
IntrinsicsDtype(IntrinsicsOp op_type,const std::vector<ExprPtr> & params)124 Dtype Intrinsics::IntrinsicsDtype(
125     IntrinsicsOp op_type,
126     const std::vector<ExprPtr>& params) {
127   // TODO: check the op_type and make a real decision
128   // Doesnt this fail with kRand?
129   if (params.empty()) {
130     throw malformed_input("invalid params in Intrinsics");
131   } else if (params.size() == 1) {
132     return IntrinsicsDtype(op_type, params[0]->dtype());
133   } else if (params.size() == 2) {
134     return IntrinsicsDtype(op_type, params[0]->dtype(), params[1]->dtype());
135   }
136   return params[0]->dtype();
137 }
138 
OpArgCount(IntrinsicsOp op_type)139 size_t Intrinsics::OpArgCount(IntrinsicsOp op_type) {
140   switch (op_type) {
141     case kSin:
142     case kCos:
143     case kTan:
144     case kAsin:
145     case kAcos:
146     case kAtan:
147     case kSinh:
148     case kCosh:
149     case kTanh:
150     case kSigmoid:
151     case kExp:
152     case kExpm1:
153     case kAbs:
154     case kLog:
155     case kLog2:
156     case kLog10:
157     case kLog1p:
158     case kErf:
159     case kErfc:
160     case kSqrt:
161     case kRsqrt:
162     case kCeil:
163     case kFloor:
164     case kRound:
165     case kTrunc:
166     case kFrac:
167     case kLgamma:
168     case kIsNan:
169       return 1;
170     case kRand:
171       return 0;
172     case kAtan2:
173     case kFmod:
174     case kPow:
175     case kRemainder:
176       return 2;
177     default:
178       throw std::runtime_error("invalid op_type: " + std::to_string(op_type));
179   }
180 }
181 
make(BufHandle buf,const std::string & func_name,const std::vector<BufHandle> & buf_args,const std::vector<ExprHandle> & args)182 ExternalCallPtr ExternalCall::make(
183     BufHandle buf,
184     const std::string& func_name,
185     const std::vector<BufHandle>& buf_args,
186     const std::vector<ExprHandle>& args) {
187   std::vector<BufPtr> buf_arg_nodes;
188   buf_arg_nodes.reserve(buf_args.size());
189   for (const BufHandle& buf_arg : buf_args) {
190     buf_arg_nodes.push_back(buf_arg.node());
191   }
192   return alloc<ExternalCall>(
193       buf.node(), func_name, buf_arg_nodes, ExprHandleVectorToExprVector(args));
194 }
195 
make(const std::string & func_name,const std::vector<BufHandle> & buf_out_args,const std::vector<BufHandle> & buf_args,const std::vector<ExprHandle> & args)196 ExternalCallWithAllocPtr ExternalCallWithAlloc::make(
197     const std::string& func_name,
198     const std::vector<BufHandle>& buf_out_args,
199     const std::vector<BufHandle>& buf_args,
200     const std::vector<ExprHandle>& args) {
201   std::vector<BufPtr> buf_out_arg_nodes;
202   buf_out_arg_nodes.reserve(buf_out_args.size());
203   for (const BufHandle& buf_out_arg : buf_out_args) {
204     buf_out_arg_nodes.push_back(buf_out_arg.node());
205   }
206 
207   std::vector<BufPtr> buf_arg_nodes;
208   buf_arg_nodes.reserve(buf_args.size());
209   for (const BufHandle& buf_arg : buf_args) {
210     buf_arg_nodes.push_back(buf_arg.node());
211   }
212   return alloc<ExternalCallWithAlloc>(
213       func_name,
214       buf_out_arg_nodes,
215       buf_arg_nodes,
216       ExprHandleVectorToExprVector(args));
217 }
218 
make(const std::vector<BufHandle> & bufs)219 FreeExtPtr FreeExt::make(const std::vector<BufHandle>& bufs) {
220   std::vector<BufPtr> buf_nodes;
221   buf_nodes.reserve(bufs.size());
222   for (const BufHandle& buf : bufs) {
223     buf_nodes.push_back(buf.node());
224   }
225   return alloc<FreeExt>(buf_nodes);
226 }
227 
ExprHandleVectorToExprVector(const std::vector<ExprHandle> & v)228 std::vector<ExprPtr> ExprHandleVectorToExprVector(
229     const std::vector<ExprHandle>& v) {
230   std::vector<ExprPtr> result(v.size());
231   for (const auto i : c10::irange(v.size())) {
232     result[i] = v[i].node();
233   }
234   return result;
235 }
236 
ExprVectorToExprHandleVector(const std::vector<ExprPtr> & v)237 std::vector<ExprHandle> ExprVectorToExprHandleVector(
238     const std::vector<ExprPtr>& v) {
239   std::vector<ExprHandle> result(v.size());
240   for (const auto i : c10::irange(v.size())) {
241     result[i] = ExprHandle(v[i]);
242   }
243   return result;
244 }
245 
VarHandleVectorToVarVector(const std::vector<VarHandle> & v)246 std::vector<VarPtr> VarHandleVectorToVarVector(
247     const std::vector<VarHandle>& v) {
248   std::vector<VarPtr> result(v.size());
249   for (const auto i : c10::irange(v.size())) {
250     result[i] = v[i].node();
251   }
252   return result;
253 }
254 
VarVectorToVarHandleVector(const std::vector<VarPtr> & v)255 std::vector<VarHandle> VarVectorToVarHandleVector(
256     const std::vector<VarPtr>& v) {
257   std::vector<VarHandle> result(v.size());
258   for (const auto i : c10::irange(v.size())) {
259     result[i] = VarHandle(v[i]);
260   }
261   return result;
262 }
263 
immediateIsNegative(const ExprPtr & e)264 bool immediateIsNegative(const ExprPtr& e) {
265 #define TYPE_CASE(Type, Name)                \
266   if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
267     return imm->value() < 0;                 \
268   }
269   AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
270 #undef TYPE_CASE
271   return false;
272 }
273 
immediateIsPositive(const ExprPtr & e)274 bool immediateIsPositive(const ExprPtr& e) {
275 #define TYPE_CASE(Type, Name)                \
276   if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
277     return imm->value() > 0;                 \
278   }
279   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
280 #undef TYPE_CASE
281   return false;
282 }
283 
immediateIsZero(const ExprPtr & e)284 bool immediateIsZero(const ExprPtr& e) {
285 #define TYPE_CASE(Type, Name)                \
286   if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
287     return imm->value() == 0;                \
288   }
289   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
290 #undef TYPE_CASE
291   return false;
292 }
293 
294 } // namespace torch::jit::tensorexpr
295