xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/reduction.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/tensorexpr/expr.h>
4 #include <torch/csrc/jit/tensorexpr/ir.h>
5 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
6 #include <torch/csrc/jit/tensorexpr/stmt.h>
7 #include <torch/csrc/jit/tensorexpr/types.h>
8 
9 #include <functional>
10 #include <utility>
11 #include <vector>
12 
13 namespace torch::jit::tensorexpr {
14 
15 using ParameterList = const std::vector<VarHandle>;
16 using ReduceInteraction = std::function<ExprHandle(ExprHandle, ExprHandle)>;
17 
18 // A Reducer is a user interface describing a particular reduction
19 // operation. It has three components: An initialization value, a way of
20 // interacting each value with the accumulation, and a method for obtaining the
21 // current value to be reduced. It is materialized into a ReduceOp when loop
22 // variables are known.
23 class TORCH_API Reducer {
24  public:
Reducer(ExprHandle init,ReduceInteraction & interaction)25   Reducer(ExprHandle init, ReduceInteraction& interaction)
26       : init_(init.node()), interaction_(interaction) {}
27 
28   template <typename RI>
Reducer(ExprHandle init,RI interaction)29   Reducer(ExprHandle init, RI interaction)
30       : init_(init.node()), interaction_(std::move(interaction)) {}
31 
initializer()32   ExprPtr initializer() const {
33     return init_;
34   }
35 
36   ExprHandle operator()(
37       const BufHandle& result_buf,
38       ExprHandle body,
39       const std::vector<ExprHandle>& output,
40       const std::vector<VarHandle>& inner) const;
41 
42   ReduceOpPtr operator()(
43       const BufPtr& result_buf,
44       ExprPtr body,
45       const std::vector<ExprPtr>& output,
46       const std::vector<VarPtr>& inner) const;
47 
48   ExprHandle operator()(
49       const BufHandle& result_buf,
50       BufHandle acc_buf,
51       const ExprHandle& body,
52       const std::vector<ExprHandle>& output,
53       const std::vector<VarHandle>& inner) const;
54 
55   // Polymorphic handling of Body functions with a variety of parameters.
getReduceBody(const std::function<ExprHandle (ParameterList &)> & func,const std::vector<VarHandle> & vars)56   static ExprHandle getReduceBody(
57       const std::function<ExprHandle(ParameterList&)>& func,
58       const std::vector<VarHandle>& vars) {
59     return func(vars);
60   }
61 
getReduceBody(const std::function<ExprHandle (const VarHandle &)> & func,const std::vector<VarHandle> & vars)62   static ExprHandle getReduceBody(
63       const std::function<ExprHandle(const VarHandle&)>& func,
64       const std::vector<VarHandle>& vars) {
65     if (vars.size() != 1) {
66       throw malformed_input("mismatch between reduce body and arg size (1)");
67     }
68 
69     return func(vars[0]);
70   }
71 
getReduceBody(const std::function<ExprHandle (const VarHandle &,const VarHandle &)> & func,const std::vector<VarHandle> & vars)72   static ExprHandle getReduceBody(
73       const std::function<ExprHandle(const VarHandle&, const VarHandle&)>& func,
74       const std::vector<VarHandle>& vars) {
75     if (vars.size() != 2) {
76       throw malformed_input("mismatch between reduce body and arg size (2)");
77     }
78     return func(vars[0], vars[1]);
79   }
80 
getReduceBody(const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &)> & func,const std::vector<VarHandle> & vars)81   static ExprHandle getReduceBody(
82       const std::function<
83           ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
84           func,
85       const std::vector<VarHandle>& vars) {
86     if (vars.size() != 3) {
87       throw malformed_input("mismatch between reduce body and arg size (3)");
88     }
89     return func(vars[0], vars[1], vars[2]);
90   }
91 
getReduceBody(const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &,const VarHandle &)> & func,const std::vector<VarHandle> & vars)92   static ExprHandle getReduceBody(
93       const std::function<ExprHandle(
94           const VarHandle&,
95           const VarHandle&,
96           const VarHandle&,
97           const VarHandle&)>& func,
98       const std::vector<VarHandle>& vars) {
99     if (vars.size() != 4) {
100       throw malformed_input("mismatch between reduce body and arg size (4)");
101     }
102     return func(vars[0], vars[1], vars[2], vars[3]);
103   }
104 
105   // Completes the reduction operator by applying the interaction function to
106   // the accumulation and the body expression.
complete(const BufPtr & accumulator,const ReduceInteraction & interaction,ExprHandle body,const std::vector<ExprPtr> & output_args,const std::vector<VarPtr> & reduce_args)107   static ExprPtr complete(
108       const BufPtr& accumulator,
109       const ReduceInteraction& interaction,
110       ExprHandle body,
111       const std::vector<ExprPtr>& output_args,
112       const std::vector<VarPtr>& reduce_args) {
113     ExprHandle accum =
114         ExprHandle(alloc<Load>(body.dtype(), accumulator, output_args));
115     auto e = interaction(std::move(accum), std::move(body));
116     return e.node();
117   }
complete(const BufHandle & accumulator,const ReduceInteraction & interaction,ExprHandle body,const std::vector<ExprHandle> & output_args,const std::vector<VarHandle> & reduce_args)118   static ExprHandle complete(
119       const BufHandle& accumulator,
120       const ReduceInteraction& interaction,
121       ExprHandle body,
122       const std::vector<ExprHandle>& output_args,
123       const std::vector<VarHandle>& reduce_args) {
124     ExprHandle accum = Load::make(body.dtype(), accumulator, output_args);
125     auto e = interaction(std::move(accum), std::move(body));
126     return e;
127   }
128 
129  private:
130   ExprPtr init_;
131   ReduceInteraction interaction_;
132 };
133 
134 // An expression representing a Reduction operation (e.g. Sum, Max) broken into
135 // it's component parts: initialization, accumulation var, acquisition of value
136 // to be reduced and interaction.
137 //
138 // This is intended to be expanded in the loopnest and not make it to codegen.
139 class TORCH_API ReduceOp : public ExprNode<ReduceOp> {
140  public:
ReduceOp(const ExprPtr & body,std::vector<VarPtr> reduce_args,Reducer reducer)141   ReduceOp(
142       const ExprPtr& body,
143       std::vector<VarPtr> reduce_args,
144       Reducer reducer)
145       : ExprNodeBase(body->dtype()),
146         body_(body),
147         reduce_args_(std::move(reduce_args)),
148         reducer_(std::move(reducer)) {
149     result_buf_ = nullptr;
150     acc_buf_ = nullptr;
151     ri_operand_ = nullptr;
152   }
153 
ReduceOp(const ExprPtr & body,std::vector<VarPtr> reduce_args,BufPtr result_buf,BufPtr acc_buf,ExprPtr ri_operand,Reducer reducer)154   ReduceOp(
155       const ExprPtr& body,
156       std::vector<VarPtr> reduce_args,
157       BufPtr result_buf,
158       BufPtr acc_buf,
159       ExprPtr ri_operand,
160       Reducer reducer)
161       : ExprNodeBase(body->dtype()),
162         body_(body),
163         reduce_args_(std::move(reduce_args)),
164         result_buf_(std::move(result_buf)),
165         acc_buf_(std::move(acc_buf)),
166         ri_operand_(std::move(ri_operand)),
167         reducer_(std::move(reducer)) {}
168 
169   static ExprHandle make(
170       ExprHandle body,
171       const std::vector<VarHandle>& reduce_args,
172       const Reducer& reducer);
173 
174   static ExprHandle make(
175       ExprHandle body,
176       const std::vector<VarHandle>& reduce_args,
177       BufHandle result_buf,
178       BufHandle acc_buf,
179       ExprHandle ri_operand,
180       const Reducer& reducer);
181 
182   // return the body expression which obtains the value to be reduced.
body()183   ExprPtr body() const {
184     return body_;
185   }
186 
187   // Returns the original Reducer factory that can create ReduceOps.
reducer()188   const Reducer& reducer() const {
189     return reducer_;
190   }
191 
192   // returns variables associated with the axes of reduction.
reduce_args()193   const std::vector<VarPtr>& reduce_args() const {
194     return reduce_args_;
195   }
196 
setAccBuf(BufHandle acc_buf)197   void setAccBuf(BufHandle acc_buf) {
198     acc_buf_ = acc_buf.node();
199   }
getAccBuf()200   BufPtr getAccBuf() {
201     return acc_buf_;
202   }
203 
setResultBuf(BufHandle buf)204   void setResultBuf(BufHandle buf) {
205     result_buf_ = buf.node();
206   }
getResultBuf()207   BufPtr getResultBuf() {
208     return result_buf_;
209   }
210 
setRiOperand(ExprHandle ri_operand)211   void setRiOperand(ExprHandle ri_operand) {
212     ri_operand_ = ri_operand.node();
213   }
getRiOperand()214   ExprPtr getRiOperand() {
215     return ri_operand_;
216   }
217 
218  private:
219   // body_ = reducer_->interaction_(result_buf_, ri_operand_)
220   ExprPtr body_;
221   std::vector<VarPtr> reduce_args_;
222 
223   BufPtr result_buf_;
224   BufPtr acc_buf_;
225   ExprPtr ri_operand_;
226 
227   const Reducer reducer_;
228 };
229 
230 class Sum : public Reducer {
231  public:
Sum()232   Sum()
233       : Reducer(ExprHandle(0), [](const ExprHandle& a, const ExprHandle& b) {
234           return a + b;
235         }) {}
236 };
237 
maximumVal(ScalarType type)238 inline ExprHandle maximumVal(ScalarType type) {
239   switch (type) {
240 #define MAX_BY_TYPE_CASE(Type, Name) \
241   case ScalarType::Name:             \
242     return ExprHandle(std::numeric_limits<Type>::max());
243     AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE)
244 #undef MAX_BY_TYPE_CASE
245     default:
246       throw unsupported_dtype();
247   }
248   return ExprHandle();
249 }
250 
minimumVal(ScalarType type)251 inline ExprHandle minimumVal(ScalarType type) {
252   switch (type) {
253 #define MAX_BY_TYPE_CASE(Type, Name) \
254   case ScalarType::Name:             \
255     return ExprHandle(std::numeric_limits<Type>::min());
256     AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE)
257 #undef MAX_BY_TYPE_CASE
258     default:
259       throw unsupported_dtype();
260   }
261 }
262 
263 class Maximum : public Reducer {
264  public:
265   // TODO possible to remove this arg by deferring the init value until we
266   // know the dtype of the body.
Maximum(Dtype dtype)267   Maximum(Dtype dtype)
268       : Reducer(
269             minimumVal(dtype.scalar_type()),
270             [](const ExprHandle& a, const ExprHandle& b) {
271               return Max::make(a, b, true);
272             }) {}
Maximum(ExprHandle initializer)273   Maximum(ExprHandle initializer)
274       : Reducer(
275             std::move(initializer),
276             [](const ExprHandle& a, const ExprHandle& b) {
277               return Max::make(a, b, true);
278             }) {}
279 };
280 
281 class Minimum : public Reducer {
282  public:
Minimum(Dtype dtype)283   Minimum(Dtype dtype)
284       : Reducer(
285             maximumVal(dtype.scalar_type()),
286             [](const ExprHandle& a, const ExprHandle& b) {
287               return Min::make(a, b, true);
288             }) {}
Minimum(const ExprHandle & initializer)289   Minimum(const ExprHandle& initializer)
290       : Reducer(initializer, [](const ExprHandle& a, const ExprHandle& b) {
291           return Min::make(a, b, true);
292         }) {}
293 };
294 
295 class ReductionExpander : public IRMutator {
296  public:
expand(const StmtPtr & s)297   StmtPtr expand(const StmtPtr& s) {
298     return s->accept_mutator(this);
299   }
300 
mutate(const ReduceOpPtr & v)301   ExprPtr mutate(const ReduceOpPtr& v) override {
302     return v->body();
303   }
304 };
305 
306 } // namespace torch::jit::tensorexpr
307