xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/tensor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <functional>
5 #include <utility>
6 #include <vector>
7 
8 #include <torch/csrc/jit/tensorexpr/expr.h>
9 #include <torch/csrc/jit/tensorexpr/reduction.h>
10 
11 namespace torch::jit::tensorexpr {
12 
13 class TORCH_API Tensor {
14  public:
Tensor(BufPtr buf,const std::vector<VarPtr> & args,const ExprPtr & body)15   Tensor(BufPtr buf, const std::vector<VarPtr>& args, const ExprPtr& body)
16       : buf_(std::move(buf)) {
17     stmt_ = constructStmt(args, body, {}, {});
18   }
Tensor(BufHandle buf,const std::vector<VarHandle> & args,ExprHandle body)19   Tensor(BufHandle buf, const std::vector<VarHandle>& args, ExprHandle body)
20       : Tensor(buf.node(), VarHandleVectorToVarVector(args), body.node()) {}
21 
Tensor(BufPtr buf,const std::vector<VarPtr> & args,const std::vector<ExprPtr> & reduce_dims,const std::vector<VarPtr> & reduce_args,const ExprPtr & body)22   Tensor(
23       BufPtr buf,
24       const std::vector<VarPtr>& args,
25       const std::vector<ExprPtr>& reduce_dims,
26       const std::vector<VarPtr>& reduce_args,
27       const ExprPtr& body)
28       : buf_(std::move(buf)) {
29     stmt_ = constructStmt(args, body, reduce_dims, reduce_args);
30   }
Tensor(BufHandle buf,const std::vector<VarHandle> & args,const std::vector<ExprHandle> & reduce_dims,const std::vector<VarHandle> & reduce_args,ExprHandle body)31   Tensor(
32       BufHandle buf,
33       const std::vector<VarHandle>& args,
34       const std::vector<ExprHandle>& reduce_dims,
35       const std::vector<VarHandle>& reduce_args,
36       ExprHandle body)
37       : Tensor(
38             buf.node(),
39             VarHandleVectorToVarVector(args),
40             ExprHandleVectorToExprVector(reduce_dims),
41             VarHandleVectorToVarVector(reduce_args),
42             body.node()) {}
43 
Tensor(BufPtr buf,StmtPtr stmt)44   Tensor(BufPtr buf, StmtPtr stmt)
45       : buf_(std::move(buf)), stmt_(std::move(stmt)) {}
46 
buf()47   BufPtr buf() const {
48     return buf_;
49   }
50 
stmt()51   StmtPtr stmt() const {
52     return stmt_;
53   }
54 
55   template <typename T>
56   inline ExprHandle load(const std::vector<T>& args) const;
57   template <typename... Ts>
58   inline ExprHandle load(const Ts&... ts) const;
59 
60  private:
61   StmtPtr constructStmt(
62       const std::vector<VarPtr>& args,
63       const ExprPtr& body,
64       const std::vector<ExprPtr>& reduce_dims,
65       const std::vector<VarPtr>& reduce_args) const;
66 
67   BufPtr buf_;
68   StmtPtr stmt_;
69 };
70 
71 TORCH_API Tensor Compute(
72     const std::string& func_name,
73     const std::vector<ExprHandle>& dims,
74     const std::optional<std::vector<ExprHandle>>& strides,
75     const std::function<ExprHandle(const VarHandle&)>& body_func);
76 TORCH_API Tensor Compute(
77     const std::string& func_name,
78     const std::vector<ExprHandle>& dims,
79     const std::function<ExprHandle(const VarHandle&)>& body_func);
80 TORCH_API Tensor Compute(
81     const std::string& func_name,
82     const std::vector<ExprHandle>& dims,
83     const std::optional<std::vector<ExprHandle>>& strides,
84     const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
85         body_func);
86 TORCH_API Tensor Compute(
87     const std::string& func_name,
88     const std::vector<ExprHandle>& dims,
89     const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
90         body_func);
91 TORCH_API Tensor Compute(
92     const std::string& func_name,
93     const std::vector<ExprHandle>& dims,
94     const std::optional<std::vector<ExprHandle>>& strides,
95     const std::function<
96         ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
97         body_func);
98 TORCH_API Tensor Compute(
99     const std::string& func_name,
100     const std::vector<ExprHandle>& dims,
101     const std::function<
102         ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
103         body_func);
104 TORCH_API Tensor Compute(
105     const std::string& func_name,
106     const std::vector<ExprHandle>& dims,
107     const std::optional<std::vector<ExprHandle>>& strides,
108     const std::function<ExprHandle(
109         const VarHandle&,
110         const VarHandle&,
111         const VarHandle&,
112         const VarHandle&)>& body_func);
113 TORCH_API Tensor Compute(
114     const std::string& func_name,
115     const std::vector<ExprHandle>& dims,
116     const std::function<ExprHandle(
117         const VarHandle&,
118         const VarHandle&,
119         const VarHandle&,
120         const VarHandle&)>& body_func);
121 TORCH_API Tensor Compute(
122     const std::string& func_name,
123     const std::vector<ExprHandle>& dims,
124     const std::optional<std::vector<ExprHandle>>& strides,
125     const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func);
126 TORCH_API Tensor Compute(
127     const std::string& func_name,
128     const std::vector<ExprHandle>& dims,
129     const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func);
130 
create_index_vars(const std::vector<ExprHandle> & dims)131 inline std::vector<VarHandle> create_index_vars(
132     const std::vector<ExprHandle>& dims) {
133   std::vector<VarHandle> vars;
134   vars.reserve(dims.size());
135   for (const ExprHandle& dim : dims) {
136     vars.emplace_back(alloc<Var>(
137         "i", dim.dtype().scalar_type() == ScalarType::Long ? kLong : kInt));
138   }
139   return vars;
140 }
141 
142 // Handle reductions over a Reducer and a body_func which produces values.
143 template <typename InitFunc, typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const Reducer & reducer,const InitFunc & init_func,const BodyFunc & body_func,const std::vector<ExprHandle> & reduce_dims)144 Tensor Reduce(
145     const std::string& func_name,
146     const std::vector<ExprHandle>& dims,
147     const std::optional<std::vector<ExprHandle>>& strides,
148     const Reducer& reducer,
149     const InitFunc& init_func,
150     const BodyFunc& body_func,
151     const std::vector<ExprHandle>& reduce_dims) {
152   std::vector<VarHandle> vars = create_index_vars(dims);
153   std::vector<VarHandle> reduce_vars = create_index_vars(reduce_dims);
154 
155   // If reduce_vars is empty, then it's not a reduction, but rather a simple
156   // copy
157   if (reduce_vars.empty()) {
158     ExprHandle body = Reducer::getReduceBody(body_func, vars);
159     BufHandle func_result =
160         Buf::make(func_name, dims, body.dtype(), std::nullopt, strides);
161     return Tensor(std::move(func_result), vars, std::move(body));
162   }
163 
164   std::vector<VarHandle> all_vars;
165   all_vars.insert(all_vars.end(), vars.begin(), vars.end());
166   all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end());
167 
168   ExprHandle body = Reducer::getReduceBody(body_func, all_vars);
169   std::vector<ExprHandle> output_args(vars.begin(), vars.end());
170   ExprHandle init_expr = Cast::make(body.dtype(), init_func(vars));
171   BufHandle func_result = Buf::make(func_name, dims, body.dtype(), init_expr);
172 
173   ExprHandle reduce_op = reducer(func_result, body, output_args, reduce_vars);
174   if (body.dtype() == kBFloat16) {
175     ExprHandle init_expr_acc = Cast::make(kFloat, init_func(vars));
176     BufHandle func_result_acc =
177         Buf::make(func_name + "_acc", dims, kFloat, init_expr_acc);
178     reduce_op = reducer(
179         func_result,
180         std::move(func_result_acc),
181         body,
182         output_args,
183         reduce_vars);
184   }
185 
186   Tensor t = Tensor(
187       std::move(func_result),
188       vars,
189       reduce_dims,
190       reduce_vars,
191       std::move(reduce_op));
192   return t;
193 }
194 template <typename InitFunc, typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const Reducer & reducer,const InitFunc & init_func,const BodyFunc & body_func,const std::vector<ExprHandle> & reduce_dims)195 Tensor Reduce(
196     const std::string& func_name,
197     const std::vector<ExprHandle>& dims,
198     const Reducer& reducer,
199     const InitFunc& init_func,
200     const BodyFunc& body_func,
201     const std::vector<ExprHandle>& reduce_dims) {
202   return Reduce<InitFunc, BodyFunc>(
203       func_name,
204       dims,
205       std::nullopt,
206       reducer,
207       init_func,
208       body_func,
209       reduce_dims);
210 }
211 
212 template <typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const Reducer & reducer,const BodyFunc & body_func,const std::vector<ExprHandle> & reduce_dims)213 Tensor Reduce(
214     const std::string& func_name,
215     const std::vector<ExprHandle>& dims,
216     const std::optional<std::vector<ExprHandle>>& strides,
217     const Reducer& reducer,
218     const BodyFunc& body_func,
219     const std::vector<ExprHandle>& reduce_dims) {
220   return Reduce(
221       func_name,
222       dims,
223       strides,
224       reducer,
225       [&](ParameterList& p [[maybe_unused]]) {
226         return ExprHandle(reducer.initializer());
227       },
228       body_func,
229       reduce_dims);
230 }
231 template <typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const Reducer & reducer,const BodyFunc & body_func,const std::vector<ExprHandle> & reduce_dims)232 Tensor Reduce(
233     const std::string& func_name,
234     const std::vector<ExprHandle>& dims,
235     const Reducer& reducer,
236     const BodyFunc& body_func,
237     const std::vector<ExprHandle>& reduce_dims) {
238   return Reduce<BodyFunc>(
239       func_name, dims, std::nullopt, reducer, body_func, reduce_dims);
240 }
241 
242 // Overload which allows inline lambda functions for the body_func.
243 template <typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const Reducer & reducer,const BodyFunc && body_func,const std::vector<ExprHandle> & reduce_dims)244 Tensor Reduce(
245     const std::string& func_name,
246     const std::vector<ExprHandle>& dims,
247     const std::optional<std::vector<ExprHandle>>& strides,
248     const Reducer& reducer,
249     const BodyFunc&& body_func,
250     const std::vector<ExprHandle>& reduce_dims) {
251   return Reduce(func_name, dims, strides, reducer, body_func, reduce_dims);
252 }
253 template <typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const Reducer & reducer,const BodyFunc && body_func,const std::vector<ExprHandle> & reduce_dims)254 Tensor Reduce(
255     const std::string& func_name,
256     const std::vector<ExprHandle>& dims,
257     const Reducer& reducer,
258     const BodyFunc&& body_func,
259     const std::vector<ExprHandle>& reduce_dims) {
260   return Reduce(func_name, dims, std::nullopt, reducer, body_func, reduce_dims);
261 }
262 
263 TORCH_API Tensor Reduce(
264     const std::string& name,
265     const std::vector<ExprHandle>& dims,
266     const std::optional<std::vector<ExprHandle>>& strides,
267     const Reducer& reducer,
268     const BufHandle& buffer,
269     const std::vector<ExprHandle>& reduce_dims);
270 TORCH_API Tensor Reduce(
271     const std::string& name,
272     const std::vector<ExprHandle>& dims,
273     const Reducer& reducer,
274     const BufHandle& buffer,
275     const std::vector<ExprHandle>& reduce_dims);
276 
277 // Overload for the common case of all dimensions of a previously Computed
278 // Tensor.
279 TORCH_API Tensor Reduce(
280     const std::string& func_name,
281     const std::vector<ExprHandle>& dims,
282     const std::optional<std::vector<ExprHandle>>& strides,
283     const Reducer& reducer,
284     const Tensor& tensor,
285     const std::vector<ExprHandle>& reduce_dims);
286 TORCH_API Tensor Reduce(
287     const std::string& func_name,
288     const std::vector<ExprHandle>& dims,
289     const Reducer& reducer,
290     const Tensor& tensor,
291     const std::vector<ExprHandle>& reduce_dims);
292 
293 template <typename... Ts>
load(const Ts &...ts)294 inline ExprHandle Tensor::load(const Ts&... ts) const {
295   std::vector<ExprHandle> params({ExprHandle(ts)...});
296   return Load::make(BufHandle(this->buf()), params);
297 }
298 
299 template <typename T>
load(const std::vector<T> & args)300 inline ExprHandle Tensor::load(const std::vector<T>& args) const {
301   std::vector<ExprHandle> params(args.begin(), args.end());
302   return Load::make(BufHandle(this->buf()), params);
303 }
304 
305 template <typename... Ts>
load(const Ts &...ts)306 inline ExprHandle BufHandle::load(const Ts&... ts) const {
307   std::vector<ExprHandle> params({ExprHandle(ts)...});
308   return ExprHandle(alloc<Load>(node(), ExprHandleVectorToExprVector(params)));
309 }
310 
311 template <typename T>
load(const std::vector<T> & args)312 inline ExprHandle BufHandle::load(const std::vector<T>& args) const {
313   std::vector<ExprHandle> params(args.begin(), args.end());
314   return ExprHandle(alloc<Load>(node(), ExprHandleVectorToExprVector(params)));
315 }
316 
load(const std::vector<ExprHandle> & args)317 inline ExprHandle BufHandle::load(const std::vector<ExprHandle>& args) const {
318   return this->template load<ExprHandle>(args);
319 }
320 
321 } // namespace torch::jit::tensorexpr
322