1 #pragma once
2
3 #include <torch/csrc/Export.h>
4 #include <functional>
5 #include <utility>
6 #include <vector>
7
8 #include <torch/csrc/jit/tensorexpr/expr.h>
9 #include <torch/csrc/jit/tensorexpr/reduction.h>
10
11 namespace torch::jit::tensorexpr {
12
13 class TORCH_API Tensor {
14 public:
Tensor(BufPtr buf,const std::vector<VarPtr> & args,const ExprPtr & body)15 Tensor(BufPtr buf, const std::vector<VarPtr>& args, const ExprPtr& body)
16 : buf_(std::move(buf)) {
17 stmt_ = constructStmt(args, body, {}, {});
18 }
Tensor(BufHandle buf,const std::vector<VarHandle> & args,ExprHandle body)19 Tensor(BufHandle buf, const std::vector<VarHandle>& args, ExprHandle body)
20 : Tensor(buf.node(), VarHandleVectorToVarVector(args), body.node()) {}
21
Tensor(BufPtr buf,const std::vector<VarPtr> & args,const std::vector<ExprPtr> & reduce_dims,const std::vector<VarPtr> & reduce_args,const ExprPtr & body)22 Tensor(
23 BufPtr buf,
24 const std::vector<VarPtr>& args,
25 const std::vector<ExprPtr>& reduce_dims,
26 const std::vector<VarPtr>& reduce_args,
27 const ExprPtr& body)
28 : buf_(std::move(buf)) {
29 stmt_ = constructStmt(args, body, reduce_dims, reduce_args);
30 }
Tensor(BufHandle buf,const std::vector<VarHandle> & args,const std::vector<ExprHandle> & reduce_dims,const std::vector<VarHandle> & reduce_args,ExprHandle body)31 Tensor(
32 BufHandle buf,
33 const std::vector<VarHandle>& args,
34 const std::vector<ExprHandle>& reduce_dims,
35 const std::vector<VarHandle>& reduce_args,
36 ExprHandle body)
37 : Tensor(
38 buf.node(),
39 VarHandleVectorToVarVector(args),
40 ExprHandleVectorToExprVector(reduce_dims),
41 VarHandleVectorToVarVector(reduce_args),
42 body.node()) {}
43
Tensor(BufPtr buf,StmtPtr stmt)44 Tensor(BufPtr buf, StmtPtr stmt)
45 : buf_(std::move(buf)), stmt_(std::move(stmt)) {}
46
buf()47 BufPtr buf() const {
48 return buf_;
49 }
50
stmt()51 StmtPtr stmt() const {
52 return stmt_;
53 }
54
55 template <typename T>
56 inline ExprHandle load(const std::vector<T>& args) const;
57 template <typename... Ts>
58 inline ExprHandle load(const Ts&... ts) const;
59
60 private:
61 StmtPtr constructStmt(
62 const std::vector<VarPtr>& args,
63 const ExprPtr& body,
64 const std::vector<ExprPtr>& reduce_dims,
65 const std::vector<VarPtr>& reduce_args) const;
66
67 BufPtr buf_;
68 StmtPtr stmt_;
69 };
70
71 TORCH_API Tensor Compute(
72 const std::string& func_name,
73 const std::vector<ExprHandle>& dims,
74 const std::optional<std::vector<ExprHandle>>& strides,
75 const std::function<ExprHandle(const VarHandle&)>& body_func);
76 TORCH_API Tensor Compute(
77 const std::string& func_name,
78 const std::vector<ExprHandle>& dims,
79 const std::function<ExprHandle(const VarHandle&)>& body_func);
80 TORCH_API Tensor Compute(
81 const std::string& func_name,
82 const std::vector<ExprHandle>& dims,
83 const std::optional<std::vector<ExprHandle>>& strides,
84 const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
85 body_func);
86 TORCH_API Tensor Compute(
87 const std::string& func_name,
88 const std::vector<ExprHandle>& dims,
89 const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
90 body_func);
91 TORCH_API Tensor Compute(
92 const std::string& func_name,
93 const std::vector<ExprHandle>& dims,
94 const std::optional<std::vector<ExprHandle>>& strides,
95 const std::function<
96 ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
97 body_func);
98 TORCH_API Tensor Compute(
99 const std::string& func_name,
100 const std::vector<ExprHandle>& dims,
101 const std::function<
102 ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
103 body_func);
104 TORCH_API Tensor Compute(
105 const std::string& func_name,
106 const std::vector<ExprHandle>& dims,
107 const std::optional<std::vector<ExprHandle>>& strides,
108 const std::function<ExprHandle(
109 const VarHandle&,
110 const VarHandle&,
111 const VarHandle&,
112 const VarHandle&)>& body_func);
113 TORCH_API Tensor Compute(
114 const std::string& func_name,
115 const std::vector<ExprHandle>& dims,
116 const std::function<ExprHandle(
117 const VarHandle&,
118 const VarHandle&,
119 const VarHandle&,
120 const VarHandle&)>& body_func);
121 TORCH_API Tensor Compute(
122 const std::string& func_name,
123 const std::vector<ExprHandle>& dims,
124 const std::optional<std::vector<ExprHandle>>& strides,
125 const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func);
126 TORCH_API Tensor Compute(
127 const std::string& func_name,
128 const std::vector<ExprHandle>& dims,
129 const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func);
130
create_index_vars(const std::vector<ExprHandle> & dims)131 inline std::vector<VarHandle> create_index_vars(
132 const std::vector<ExprHandle>& dims) {
133 std::vector<VarHandle> vars;
134 vars.reserve(dims.size());
135 for (const ExprHandle& dim : dims) {
136 vars.emplace_back(alloc<Var>(
137 "i", dim.dtype().scalar_type() == ScalarType::Long ? kLong : kInt));
138 }
139 return vars;
140 }
141
142 // Handle reductions over a Reducer and a body_func which produces values.
143 template <typename InitFunc, typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const Reducer & reducer,const InitFunc & init_func,const BodyFunc & body_func,const std::vector<ExprHandle> & reduce_dims)144 Tensor Reduce(
145 const std::string& func_name,
146 const std::vector<ExprHandle>& dims,
147 const std::optional<std::vector<ExprHandle>>& strides,
148 const Reducer& reducer,
149 const InitFunc& init_func,
150 const BodyFunc& body_func,
151 const std::vector<ExprHandle>& reduce_dims) {
152 std::vector<VarHandle> vars = create_index_vars(dims);
153 std::vector<VarHandle> reduce_vars = create_index_vars(reduce_dims);
154
155 // If reduce_vars is empty, then it's not a reduction, but rather a simple
156 // copy
157 if (reduce_vars.empty()) {
158 ExprHandle body = Reducer::getReduceBody(body_func, vars);
159 BufHandle func_result =
160 Buf::make(func_name, dims, body.dtype(), std::nullopt, strides);
161 return Tensor(std::move(func_result), vars, std::move(body));
162 }
163
164 std::vector<VarHandle> all_vars;
165 all_vars.insert(all_vars.end(), vars.begin(), vars.end());
166 all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end());
167
168 ExprHandle body = Reducer::getReduceBody(body_func, all_vars);
169 std::vector<ExprHandle> output_args(vars.begin(), vars.end());
170 ExprHandle init_expr = Cast::make(body.dtype(), init_func(vars));
171 BufHandle func_result = Buf::make(func_name, dims, body.dtype(), init_expr);
172
173 ExprHandle reduce_op = reducer(func_result, body, output_args, reduce_vars);
174 if (body.dtype() == kBFloat16) {
175 ExprHandle init_expr_acc = Cast::make(kFloat, init_func(vars));
176 BufHandle func_result_acc =
177 Buf::make(func_name + "_acc", dims, kFloat, init_expr_acc);
178 reduce_op = reducer(
179 func_result,
180 std::move(func_result_acc),
181 body,
182 output_args,
183 reduce_vars);
184 }
185
186 Tensor t = Tensor(
187 std::move(func_result),
188 vars,
189 reduce_dims,
190 reduce_vars,
191 std::move(reduce_op));
192 return t;
193 }
194 template <typename InitFunc, typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const Reducer & reducer,const InitFunc & init_func,const BodyFunc & body_func,const std::vector<ExprHandle> & reduce_dims)195 Tensor Reduce(
196 const std::string& func_name,
197 const std::vector<ExprHandle>& dims,
198 const Reducer& reducer,
199 const InitFunc& init_func,
200 const BodyFunc& body_func,
201 const std::vector<ExprHandle>& reduce_dims) {
202 return Reduce<InitFunc, BodyFunc>(
203 func_name,
204 dims,
205 std::nullopt,
206 reducer,
207 init_func,
208 body_func,
209 reduce_dims);
210 }
211
212 template <typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const Reducer & reducer,const BodyFunc & body_func,const std::vector<ExprHandle> & reduce_dims)213 Tensor Reduce(
214 const std::string& func_name,
215 const std::vector<ExprHandle>& dims,
216 const std::optional<std::vector<ExprHandle>>& strides,
217 const Reducer& reducer,
218 const BodyFunc& body_func,
219 const std::vector<ExprHandle>& reduce_dims) {
220 return Reduce(
221 func_name,
222 dims,
223 strides,
224 reducer,
225 [&](ParameterList& p [[maybe_unused]]) {
226 return ExprHandle(reducer.initializer());
227 },
228 body_func,
229 reduce_dims);
230 }
231 template <typename BodyFunc>
Reduce(const std::string & func_name,const std::vector<ExprHandle> & dims,const Reducer & reducer,const BodyFunc & body_func,const std::vector<ExprHandle> & reduce_dims)232 Tensor Reduce(
233 const std::string& func_name,
234 const std::vector<ExprHandle>& dims,
235 const Reducer& reducer,
236 const BodyFunc& body_func,
237 const std::vector<ExprHandle>& reduce_dims) {
238 return Reduce<BodyFunc>(
239 func_name, dims, std::nullopt, reducer, body_func, reduce_dims);
240 }
241
// Overload which allows inline lambda functions for the body_func.
// Selected when body_func is passed as an rvalue (e.g. a lambda written
// inline at the call site). It re-dispatches with body_func as a named (const
// lvalue) argument.
// The call is deliberately unqualified and has no explicit template
// arguments, so overload resolution chooses the target for this body_func
// type (including the non-template Reduce overloads).
// NOTE(review): the `const BodyFunc&` overload can also bind rvalues; confirm
// whether this overload is still required before removing it.
template <typename BodyFunc>
Tensor Reduce(
    const std::string& func_name,
    const std::vector<ExprHandle>& dims,
    const std::optional<std::vector<ExprHandle>>& strides,
    const Reducer& reducer,
    const BodyFunc&& body_func,
    const std::vector<ExprHandle>& reduce_dims) {
  return Reduce(func_name, dims, strides, reducer, body_func, reduce_dims);
}
// Stride-less rvalue-body_func overload: forwards with std::nullopt strides.
// As with the strided rvalue overload, the call is left unqualified so that
// overload resolution picks the target for this body_func type.
template <typename BodyFunc>
Tensor Reduce(
    const std::string& func_name,
    const std::vector<ExprHandle>& dims,
    const Reducer& reducer,
    const BodyFunc&& body_func,
    const std::vector<ExprHandle>& reduce_dims) {
  return Reduce(func_name, dims, std::nullopt, reducer, body_func, reduce_dims);
}
262
263 TORCH_API Tensor Reduce(
264 const std::string& name,
265 const std::vector<ExprHandle>& dims,
266 const std::optional<std::vector<ExprHandle>>& strides,
267 const Reducer& reducer,
268 const BufHandle& buffer,
269 const std::vector<ExprHandle>& reduce_dims);
270 TORCH_API Tensor Reduce(
271 const std::string& name,
272 const std::vector<ExprHandle>& dims,
273 const Reducer& reducer,
274 const BufHandle& buffer,
275 const std::vector<ExprHandle>& reduce_dims);
276
277 // Overload for the common case of all dimensions of a previously Computed
278 // Tensor.
279 TORCH_API Tensor Reduce(
280 const std::string& func_name,
281 const std::vector<ExprHandle>& dims,
282 const std::optional<std::vector<ExprHandle>>& strides,
283 const Reducer& reducer,
284 const Tensor& tensor,
285 const std::vector<ExprHandle>& reduce_dims);
286 TORCH_API Tensor Reduce(
287 const std::string& func_name,
288 const std::vector<ExprHandle>& dims,
289 const Reducer& reducer,
290 const Tensor& tensor,
291 const std::vector<ExprHandle>& reduce_dims);
292
293 template <typename... Ts>
load(const Ts &...ts)294 inline ExprHandle Tensor::load(const Ts&... ts) const {
295 std::vector<ExprHandle> params({ExprHandle(ts)...});
296 return Load::make(BufHandle(this->buf()), params);
297 }
298
299 template <typename T>
load(const std::vector<T> & args)300 inline ExprHandle Tensor::load(const std::vector<T>& args) const {
301 std::vector<ExprHandle> params(args.begin(), args.end());
302 return Load::make(BufHandle(this->buf()), params);
303 }
304
305 template <typename... Ts>
load(const Ts &...ts)306 inline ExprHandle BufHandle::load(const Ts&... ts) const {
307 std::vector<ExprHandle> params({ExprHandle(ts)...});
308 return ExprHandle(alloc<Load>(node(), ExprHandleVectorToExprVector(params)));
309 }
310
311 template <typename T>
load(const std::vector<T> & args)312 inline ExprHandle BufHandle::load(const std::vector<T>& args) const {
313 std::vector<ExprHandle> params(args.begin(), args.end());
314 return ExprHandle(alloc<Load>(node(), ExprHandleVectorToExprVector(params)));
315 }
316
load(const std::vector<ExprHandle> & args)317 inline ExprHandle BufHandle::load(const std::vector<ExprHandle>& args) const {
318 return this->template load<ExprHandle>(args);
319 }
320
321 } // namespace torch::jit::tensorexpr
322