1 #pragma once
2
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#include <functional>
#include <limits>
#include <utility>
#include <vector>
12
13 namespace torch::jit::tensorexpr {
14
15 using ParameterList = const std::vector<VarHandle>;
16 using ReduceInteraction = std::function<ExprHandle(ExprHandle, ExprHandle)>;
17
18 // A Reducer is a user interface describing a particular reduction
19 // operation. It has three components: An initialization value, a way of
20 // interacting each value with the accumulation, and a method for obtaining the
21 // current value to be reduced. It is materialized into a ReduceOp when loop
22 // variables are known.
23 class TORCH_API Reducer {
24 public:
Reducer(ExprHandle init,ReduceInteraction & interaction)25 Reducer(ExprHandle init, ReduceInteraction& interaction)
26 : init_(init.node()), interaction_(interaction) {}
27
28 template <typename RI>
Reducer(ExprHandle init,RI interaction)29 Reducer(ExprHandle init, RI interaction)
30 : init_(init.node()), interaction_(std::move(interaction)) {}
31
initializer()32 ExprPtr initializer() const {
33 return init_;
34 }
35
36 ExprHandle operator()(
37 const BufHandle& result_buf,
38 ExprHandle body,
39 const std::vector<ExprHandle>& output,
40 const std::vector<VarHandle>& inner) const;
41
42 ReduceOpPtr operator()(
43 const BufPtr& result_buf,
44 ExprPtr body,
45 const std::vector<ExprPtr>& output,
46 const std::vector<VarPtr>& inner) const;
47
48 ExprHandle operator()(
49 const BufHandle& result_buf,
50 BufHandle acc_buf,
51 const ExprHandle& body,
52 const std::vector<ExprHandle>& output,
53 const std::vector<VarHandle>& inner) const;
54
55 // Polymorphic handling of Body functions with a variety of parameters.
getReduceBody(const std::function<ExprHandle (ParameterList &)> & func,const std::vector<VarHandle> & vars)56 static ExprHandle getReduceBody(
57 const std::function<ExprHandle(ParameterList&)>& func,
58 const std::vector<VarHandle>& vars) {
59 return func(vars);
60 }
61
getReduceBody(const std::function<ExprHandle (const VarHandle &)> & func,const std::vector<VarHandle> & vars)62 static ExprHandle getReduceBody(
63 const std::function<ExprHandle(const VarHandle&)>& func,
64 const std::vector<VarHandle>& vars) {
65 if (vars.size() != 1) {
66 throw malformed_input("mismatch between reduce body and arg size (1)");
67 }
68
69 return func(vars[0]);
70 }
71
getReduceBody(const std::function<ExprHandle (const VarHandle &,const VarHandle &)> & func,const std::vector<VarHandle> & vars)72 static ExprHandle getReduceBody(
73 const std::function<ExprHandle(const VarHandle&, const VarHandle&)>& func,
74 const std::vector<VarHandle>& vars) {
75 if (vars.size() != 2) {
76 throw malformed_input("mismatch between reduce body and arg size (2)");
77 }
78 return func(vars[0], vars[1]);
79 }
80
getReduceBody(const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &)> & func,const std::vector<VarHandle> & vars)81 static ExprHandle getReduceBody(
82 const std::function<
83 ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
84 func,
85 const std::vector<VarHandle>& vars) {
86 if (vars.size() != 3) {
87 throw malformed_input("mismatch between reduce body and arg size (3)");
88 }
89 return func(vars[0], vars[1], vars[2]);
90 }
91
getReduceBody(const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &,const VarHandle &)> & func,const std::vector<VarHandle> & vars)92 static ExprHandle getReduceBody(
93 const std::function<ExprHandle(
94 const VarHandle&,
95 const VarHandle&,
96 const VarHandle&,
97 const VarHandle&)>& func,
98 const std::vector<VarHandle>& vars) {
99 if (vars.size() != 4) {
100 throw malformed_input("mismatch between reduce body and arg size (4)");
101 }
102 return func(vars[0], vars[1], vars[2], vars[3]);
103 }
104
105 // Completes the reduction operator by applying the interaction function to
106 // the accumulation and the body expression.
complete(const BufPtr & accumulator,const ReduceInteraction & interaction,ExprHandle body,const std::vector<ExprPtr> & output_args,const std::vector<VarPtr> & reduce_args)107 static ExprPtr complete(
108 const BufPtr& accumulator,
109 const ReduceInteraction& interaction,
110 ExprHandle body,
111 const std::vector<ExprPtr>& output_args,
112 const std::vector<VarPtr>& reduce_args) {
113 ExprHandle accum =
114 ExprHandle(alloc<Load>(body.dtype(), accumulator, output_args));
115 auto e = interaction(std::move(accum), std::move(body));
116 return e.node();
117 }
complete(const BufHandle & accumulator,const ReduceInteraction & interaction,ExprHandle body,const std::vector<ExprHandle> & output_args,const std::vector<VarHandle> & reduce_args)118 static ExprHandle complete(
119 const BufHandle& accumulator,
120 const ReduceInteraction& interaction,
121 ExprHandle body,
122 const std::vector<ExprHandle>& output_args,
123 const std::vector<VarHandle>& reduce_args) {
124 ExprHandle accum = Load::make(body.dtype(), accumulator, output_args);
125 auto e = interaction(std::move(accum), std::move(body));
126 return e;
127 }
128
129 private:
130 ExprPtr init_;
131 ReduceInteraction interaction_;
132 };
133
134 // An expression representing a Reduction operation (e.g. Sum, Max) broken into
135 // it's component parts: initialization, accumulation var, acquisition of value
136 // to be reduced and interaction.
137 //
138 // This is intended to be expanded in the loopnest and not make it to codegen.
139 class TORCH_API ReduceOp : public ExprNode<ReduceOp> {
140 public:
ReduceOp(const ExprPtr & body,std::vector<VarPtr> reduce_args,Reducer reducer)141 ReduceOp(
142 const ExprPtr& body,
143 std::vector<VarPtr> reduce_args,
144 Reducer reducer)
145 : ExprNodeBase(body->dtype()),
146 body_(body),
147 reduce_args_(std::move(reduce_args)),
148 reducer_(std::move(reducer)) {
149 result_buf_ = nullptr;
150 acc_buf_ = nullptr;
151 ri_operand_ = nullptr;
152 }
153
ReduceOp(const ExprPtr & body,std::vector<VarPtr> reduce_args,BufPtr result_buf,BufPtr acc_buf,ExprPtr ri_operand,Reducer reducer)154 ReduceOp(
155 const ExprPtr& body,
156 std::vector<VarPtr> reduce_args,
157 BufPtr result_buf,
158 BufPtr acc_buf,
159 ExprPtr ri_operand,
160 Reducer reducer)
161 : ExprNodeBase(body->dtype()),
162 body_(body),
163 reduce_args_(std::move(reduce_args)),
164 result_buf_(std::move(result_buf)),
165 acc_buf_(std::move(acc_buf)),
166 ri_operand_(std::move(ri_operand)),
167 reducer_(std::move(reducer)) {}
168
169 static ExprHandle make(
170 ExprHandle body,
171 const std::vector<VarHandle>& reduce_args,
172 const Reducer& reducer);
173
174 static ExprHandle make(
175 ExprHandle body,
176 const std::vector<VarHandle>& reduce_args,
177 BufHandle result_buf,
178 BufHandle acc_buf,
179 ExprHandle ri_operand,
180 const Reducer& reducer);
181
182 // return the body expression which obtains the value to be reduced.
body()183 ExprPtr body() const {
184 return body_;
185 }
186
187 // Returns the original Reducer factory that can create ReduceOps.
reducer()188 const Reducer& reducer() const {
189 return reducer_;
190 }
191
192 // returns variables associated with the axes of reduction.
reduce_args()193 const std::vector<VarPtr>& reduce_args() const {
194 return reduce_args_;
195 }
196
setAccBuf(BufHandle acc_buf)197 void setAccBuf(BufHandle acc_buf) {
198 acc_buf_ = acc_buf.node();
199 }
getAccBuf()200 BufPtr getAccBuf() {
201 return acc_buf_;
202 }
203
setResultBuf(BufHandle buf)204 void setResultBuf(BufHandle buf) {
205 result_buf_ = buf.node();
206 }
getResultBuf()207 BufPtr getResultBuf() {
208 return result_buf_;
209 }
210
setRiOperand(ExprHandle ri_operand)211 void setRiOperand(ExprHandle ri_operand) {
212 ri_operand_ = ri_operand.node();
213 }
getRiOperand()214 ExprPtr getRiOperand() {
215 return ri_operand_;
216 }
217
218 private:
219 // body_ = reducer_->interaction_(result_buf_, ri_operand_)
220 ExprPtr body_;
221 std::vector<VarPtr> reduce_args_;
222
223 BufPtr result_buf_;
224 BufPtr acc_buf_;
225 ExprPtr ri_operand_;
226
227 const Reducer reducer_;
228 };
229
230 class Sum : public Reducer {
231 public:
Sum()232 Sum()
233 : Reducer(ExprHandle(0), [](const ExprHandle& a, const ExprHandle& b) {
234 return a + b;
235 }) {}
236 };
237
maximumVal(ScalarType type)238 inline ExprHandle maximumVal(ScalarType type) {
239 switch (type) {
240 #define MAX_BY_TYPE_CASE(Type, Name) \
241 case ScalarType::Name: \
242 return ExprHandle(std::numeric_limits<Type>::max());
243 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE)
244 #undef MAX_BY_TYPE_CASE
245 default:
246 throw unsupported_dtype();
247 }
248 return ExprHandle();
249 }
250
minimumVal(ScalarType type)251 inline ExprHandle minimumVal(ScalarType type) {
252 switch (type) {
253 #define MAX_BY_TYPE_CASE(Type, Name) \
254 case ScalarType::Name: \
255 return ExprHandle(std::numeric_limits<Type>::min());
256 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE)
257 #undef MAX_BY_TYPE_CASE
258 default:
259 throw unsupported_dtype();
260 }
261 }
262
263 class Maximum : public Reducer {
264 public:
265 // TODO possible to remove this arg by deferring the init value until we
266 // know the dtype of the body.
Maximum(Dtype dtype)267 Maximum(Dtype dtype)
268 : Reducer(
269 minimumVal(dtype.scalar_type()),
270 [](const ExprHandle& a, const ExprHandle& b) {
271 return Max::make(a, b, true);
272 }) {}
Maximum(ExprHandle initializer)273 Maximum(ExprHandle initializer)
274 : Reducer(
275 std::move(initializer),
276 [](const ExprHandle& a, const ExprHandle& b) {
277 return Max::make(a, b, true);
278 }) {}
279 };
280
281 class Minimum : public Reducer {
282 public:
Minimum(Dtype dtype)283 Minimum(Dtype dtype)
284 : Reducer(
285 maximumVal(dtype.scalar_type()),
286 [](const ExprHandle& a, const ExprHandle& b) {
287 return Min::make(a, b, true);
288 }) {}
Minimum(const ExprHandle & initializer)289 Minimum(const ExprHandle& initializer)
290 : Reducer(initializer, [](const ExprHandle& a, const ExprHandle& b) {
291 return Min::make(a, b, true);
292 }) {}
293 };
294
295 class ReductionExpander : public IRMutator {
296 public:
expand(const StmtPtr & s)297 StmtPtr expand(const StmtPtr& s) {
298 return s->accept_mutator(this);
299 }
300
mutate(const ReduceOpPtr & v)301 ExprPtr mutate(const ReduceOpPtr& v) override {
302 return v->body();
303 }
304 };
305
306 } // namespace torch::jit::tensorexpr
307