xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir_cloner.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/ir_cloner.h>
2 
3 #include <torch/csrc/jit/tensorexpr/ir.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5 #include <torch/csrc/jit/tensorexpr/reduction.h>
6 
7 #include <c10/util/irange.h>
8 
9 namespace torch::jit::tensorexpr {
10 
11 template <
12     typename Op,
13     std::enable_if_t<std::is_same_v<
14         decltype(detail::bin_op_deducer(std::declval<Op>())),
15         void>>* = nullptr>
mutate_binary_op(NodePtr<Op> v,IRCloner * cloner,bool option=false)16 static ExprPtr mutate_binary_op(
17     NodePtr<Op> v,
18     IRCloner* cloner,
19     bool option = false) {
20   ExprPtr lhs_new = v->lhs()->accept_mutator(cloner);
21   ExprPtr rhs_new = v->rhs()->accept_mutator(cloner);
22   IRNodeType expr_type = v->expr_type();
23   switch (expr_type) {
24     case IRNodeType::kAdd:
25       return alloc<Add>(lhs_new, rhs_new);
26     case IRNodeType::kSub:
27       return alloc<Sub>(lhs_new, rhs_new);
28     case IRNodeType::kMul:
29       return alloc<Mul>(lhs_new, rhs_new);
30     case IRNodeType::kDiv:
31       return alloc<Div>(lhs_new, rhs_new);
32     case IRNodeType::kMod:
33       return alloc<Mod>(lhs_new, rhs_new);
34     case IRNodeType::kMax:
35       return alloc<Max>(lhs_new, rhs_new, option);
36     case IRNodeType::kMin:
37       return alloc<Min>(lhs_new, rhs_new, option);
38     case IRNodeType::kAnd:
39       return alloc<And>(lhs_new, rhs_new);
40     case IRNodeType::kOr:
41       return alloc<Or>(lhs_new, rhs_new);
42     case IRNodeType::kXor:
43       return alloc<Xor>(lhs_new, rhs_new);
44     case IRNodeType::kLshift:
45       return alloc<Lshift>(lhs_new, rhs_new);
46     case IRNodeType::kRshift:
47       return alloc<Rshift>(lhs_new, rhs_new);
48     default:
49       throw unimplemented_lowering(v);
50   }
51 }
52 
mutate(const AddPtr & v)53 ExprPtr IRCloner::mutate(const AddPtr& v) {
54   return mutate_binary_op(v, this);
55 }
56 
mutate(const SubPtr & v)57 ExprPtr IRCloner::mutate(const SubPtr& v) {
58   return mutate_binary_op(v, this);
59 }
60 
mutate(const MulPtr & v)61 ExprPtr IRCloner::mutate(const MulPtr& v) {
62   return mutate_binary_op(v, this);
63 }
64 
mutate(const DivPtr & v)65 ExprPtr IRCloner::mutate(const DivPtr& v) {
66   return mutate_binary_op(v, this);
67 }
68 
mutate(const ModPtr & v)69 ExprPtr IRCloner::mutate(const ModPtr& v) {
70   return mutate_binary_op(v, this);
71 }
72 
mutate(const AndPtr & v)73 ExprPtr IRCloner::mutate(const AndPtr& v) {
74   return mutate_binary_op(v, this);
75 }
76 
mutate(const OrPtr & v)77 ExprPtr IRCloner::mutate(const OrPtr& v) {
78   return mutate_binary_op(v, this);
79 }
80 
mutate(const XorPtr & v)81 ExprPtr IRCloner::mutate(const XorPtr& v) {
82   return mutate_binary_op(v, this);
83 }
84 
mutate(const LshiftPtr & v)85 ExprPtr IRCloner::mutate(const LshiftPtr& v) {
86   return mutate_binary_op(v, this);
87 }
88 
mutate(const RshiftPtr & v)89 ExprPtr IRCloner::mutate(const RshiftPtr& v) {
90   return mutate_binary_op(v, this);
91 }
92 
mutate(const MaxPtr & v)93 ExprPtr IRCloner::mutate(const MaxPtr& v) {
94   return mutate_binary_op(v, this, v->propagate_nans());
95 }
96 
mutate(const MinPtr & v)97 ExprPtr IRCloner::mutate(const MinPtr& v) {
98   return mutate_binary_op(v, this, v->propagate_nans());
99 }
100 
mutate(const CompareSelectPtr & v)101 ExprPtr IRCloner::mutate(const CompareSelectPtr& v) {
102   ExprPtr lhs_new = v->lhs()->accept_mutator(this);
103   ExprPtr rhs_new = v->rhs()->accept_mutator(this);
104   ExprPtr retval1_new = v->ret_val1()->accept_mutator(this);
105   ExprPtr retval2_new = v->ret_val2()->accept_mutator(this);
106   return alloc<CompareSelect>(
107       lhs_new,
108       rhs_new,
109       retval1_new,
110       retval2_new,
111       v->compare_select_op(),
112       v->bias());
113 }
114 
115 #define IMM_MUTATE_DEFINE(_1, Name)                 \
116   ExprPtr IRCloner::mutate(const Name##ImmPtr& v) { \
117     return v;                                       \
118   }
119 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
120 #undef IMM_MUTATE_DEFINE
121 
mutate(const CastPtr & v)122 ExprPtr IRCloner::mutate(const CastPtr& v) {
123   ExprPtr src_value_new = v->src_value()->accept_mutator(this);
124   return alloc<Cast>(v->dtype(), src_value_new);
125 }
126 
mutate(const BitCastPtr & v)127 ExprPtr IRCloner::mutate(const BitCastPtr& v) {
128   ExprPtr src_value_new = v->src_value()->accept_mutator(this);
129   return alloc<BitCast>(v->dtype(), src_value_new);
130 }
131 
mutate(const RampPtr & v)132 ExprPtr IRCloner::mutate(const RampPtr& v) {
133   ExprPtr base_new = v->base()->accept_mutator(this);
134   ExprPtr stride_new = v->stride()->accept_mutator(this);
135   return alloc<Ramp>(base_new, stride_new, v->lanes());
136 }
137 
mutate(const LoadPtr & v)138 ExprPtr IRCloner::mutate(const LoadPtr& v) {
139   std::vector<ExprPtr> indices_new;
140   indices_new.reserve(v->indices().size());
141   for (const ExprPtr& ind : v->indices()) {
142     indices_new.push_back(ind->accept_mutator(this));
143   }
144   BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
145   return alloc<Load>(v->dtype(), buf_new, indices_new);
146 }
147 
148 // We do not clone Vars since the original IR and cloned IR are expected to
149 // share the underlying variables.
mutate(const VarPtr & v)150 ExprPtr IRCloner::mutate(const VarPtr& v) {
151   return v;
152 }
153 
154 // We do not clone Bufs since the original IR and cloned IR are expected to
155 // share the underlying Bufs. In spite of Bufs having expressions as dims and
156 // initializers, this is the expected usage of clone at this point.
157 //
158 // TODO: Revisit this if Bufs need to be cloned as well.
mutate(const BufPtr & v)159 ExprPtr IRCloner::mutate(const BufPtr& v) {
160   return v;
161 }
162 
mutate(const BroadcastPtr & v)163 ExprPtr IRCloner::mutate(const BroadcastPtr& v) {
164   auto lanes = v->lanes();
165   ExprPtr value_new = v->value()->accept_mutator(this);
166   return alloc<Broadcast>(value_new, lanes);
167 }
168 
mutate(const IfThenElsePtr & v)169 ExprPtr IRCloner::mutate(const IfThenElsePtr& v) {
170   ExprPtr condition_new = v->condition()->accept_mutator(this);
171   ExprPtr true_value_new = v->true_value()->accept_mutator(this);
172   ExprPtr false_value_new = v->false_value()->accept_mutator(this);
173 
174   return alloc<IfThenElse>(condition_new, true_value_new, false_value_new);
175 }
176 
mutate(const IntrinsicsPtr & v)177 ExprPtr IRCloner::mutate(const IntrinsicsPtr& v) {
178   std::vector<ExprPtr> params_new;
179   params_new.reserve(v->nparams());
180   for (const auto& param : v->params()) {
181     params_new.push_back(param->accept_mutator(this));
182   }
183   return alloc<Intrinsics>(v->op_type(), v->dtype(), params_new);
184 }
185 
mutate(const TermPtr & v)186 ExprPtr IRCloner::mutate(const TermPtr& v) {
187   ExprPtr scalar_new = v->scalar()->accept_mutator(this);
188 
189   std::vector<ExprPtr> variables_new;
190   variables_new.reserve(v->variables().size());
191   for (const auto& t : v->variables()) {
192     variables_new.push_back(t->accept_mutator(this));
193   }
194   return alloc<Term>(v->hasher(), scalar_new, variables_new);
195 }
196 
mutate(const PolynomialPtr & v)197 ExprPtr IRCloner::mutate(const PolynomialPtr& v) {
198   ExprPtr scalar_new = v->scalar()->accept_mutator(this);
199 
200   std::vector<TermPtr> variables_new;
201   variables_new.reserve(v->variables().size());
202   for (const auto& t : v->variables()) {
203     variables_new.push_back(static_to<Term>(t->accept_mutator(this)));
204   }
205   return alloc<Polynomial>(v->hasher(), scalar_new, variables_new);
206 }
207 
mutate(const RoundOffPtr & v)208 ExprPtr IRCloner::mutate(const RoundOffPtr& v) {
209   return alloc<RoundOff>(
210       v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this));
211 }
212 
mutate(const MaxTermPtr & v)213 ExprPtr IRCloner::mutate(const MaxTermPtr& v) {
214   ExprPtr scalar_new =
215       v->scalar() ? v->scalar()->accept_mutator(this) : nullptr;
216 
217   std::vector<ExprPtr> variables_new;
218   variables_new.reserve(v->variables().size());
219   for (const auto& t : v->variables()) {
220     variables_new.push_back(t->accept_mutator(this));
221   }
222   return alloc<MaxTerm>(
223       v->hasher(), scalar_new, v->propagate_nans(), variables_new);
224 }
225 
mutate(const MinTermPtr & v)226 ExprPtr IRCloner::mutate(const MinTermPtr& v) {
227   ExprPtr scalar_new =
228       v->scalar() ? v->scalar()->accept_mutator(this) : nullptr;
229 
230   std::vector<ExprPtr> variables_new;
231   variables_new.reserve(v->variables().size());
232   for (const auto& t : v->variables()) {
233     variables_new.push_back(t->accept_mutator(this));
234   }
235   return alloc<MinTerm>(
236       v->hasher(), scalar_new, v->propagate_nans(), variables_new);
237 }
238 
mutate(const ReduceOpPtr & v)239 ExprPtr IRCloner::mutate(const ReduceOpPtr& v) {
240   ExprPtr body_new = v->body()->accept_mutator(this);
241 
242   std::vector<VarPtr> reduce_args_new;
243   reduce_args_new.reserve(v->reduce_args().size());
244   for (const auto& r : v->reduce_args()) {
245     reduce_args_new.push_back(static_to<Var>(r->accept_mutator(this)));
246   }
247 
248   return alloc<ReduceOp>(body_new, reduce_args_new, v->reducer());
249 }
250 
mutate(const ForPtr & v)251 StmtPtr IRCloner::mutate(const ForPtr& v) {
252   auto start_new = v->start()->accept_mutator(this);
253   auto stop_new = v->stop()->accept_mutator(this);
254   auto body_new = v->body()->accept_mutator(this);
255 
256   return alloc<For>(v->var(), start_new, stop_new, body_new, v->loop_options());
257 }
258 
mutate(const BlockPtr & v)259 StmtPtr IRCloner::mutate(const BlockPtr& v) {
260   std::vector<StmtPtr> stmts_new;
261   stmts_new.reserve(v->nstmts());
262   for (const StmtPtr& stmt : *v) {
263     stmts_new.push_back(stmt->accept_mutator(this));
264   }
265   return alloc<Block>(stmts_new);
266 }
267 
mutate(const StorePtr & v)268 StmtPtr IRCloner::mutate(const StorePtr& v) {
269   std::vector<ExprPtr> indices_new;
270   indices_new.reserve(v->indices().size());
271   for (const auto& ind : v->indices()) {
272     indices_new.push_back(ind->accept_mutator(this));
273   }
274   auto value_new = v->value()->accept_mutator(this);
275   BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
276   return alloc<Store>(buf_new, indices_new, value_new);
277 }
278 
mutate(const AtomicAddPtr & v)279 StmtPtr IRCloner::mutate(const AtomicAddPtr& v) {
280   std::vector<ExprPtr> indices_new;
281   indices_new.reserve(v->indices().size());
282   for (const auto& ind : v->indices()) {
283     indices_new.push_back(ind->accept_mutator(this));
284   }
285   auto value_new = v->value()->accept_mutator(this);
286   BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
287   return alloc<AtomicAdd>(buf_new, indices_new, value_new);
288 }
289 
mutate(const AllocatePtr & v)290 StmtPtr IRCloner::mutate(const AllocatePtr& v) {
291   BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
292   return alloc<Allocate>(buf_new);
293 }
294 
mutate(const FreePtr & v)295 StmtPtr IRCloner::mutate(const FreePtr& v) {
296   BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
297   return alloc<Free>(buf_new);
298 }
299 
mutate(const SyncThreadsPtr & v)300 StmtPtr IRCloner::mutate(const SyncThreadsPtr& v) {
301   return alloc<SyncThreads>();
302 }
303 
mutate(const ExternalCallPtr & v)304 StmtPtr IRCloner::mutate(const ExternalCallPtr& v) {
305   BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
306 
307   std::vector<BufPtr> buf_args_new;
308   buf_args_new.reserve(v->buf_args().size());
309   for (const BufPtr& buf_arg : v->buf_args()) {
310     buf_args_new.push_back(to<Buf>(buf_arg->accept_mutator(this)));
311   }
312   std::vector<ExprPtr> args_new;
313   args_new.reserve(v->args().size());
314   for (const ExprPtr& arg : v->args()) {
315     args_new.push_back(arg->accept_mutator(this));
316   }
317 
318   return alloc<ExternalCall>(buf_new, v->func_name(), buf_args_new, args_new);
319 }
320 
mutate(const ExternalCallWithAllocPtr & v)321 StmtPtr IRCloner::mutate(const ExternalCallWithAllocPtr& v) {
322   std::vector<BufPtr> buf_out_args_new;
323   buf_out_args_new.reserve(v->buf_out_args().size());
324   for (const auto& buf_out_arg : v->buf_out_args()) {
325     buf_out_args_new.push_back(to<Buf>(buf_out_arg->accept_mutator(this)));
326   }
327 
328   std::vector<BufPtr> buf_args_new;
329   buf_args_new.reserve(v->buf_args().size());
330   for (const auto& buf_arg : v->buf_args()) {
331     buf_args_new.push_back(to<Buf>(buf_arg->accept_mutator(this)));
332   }
333   std::vector<ExprPtr> args_new;
334   args_new.reserve(v->args().size());
335   for (const auto& arg : v->args()) {
336     args_new.push_back(arg->accept_mutator(this));
337   }
338 
339   return alloc<ExternalCallWithAlloc>(
340       v->func_name(), buf_out_args_new, buf_args_new, args_new);
341 }
342 
mutate(const LetPtr & v)343 StmtPtr IRCloner::mutate(const LetPtr& v) {
344   auto value_new = v->value()->accept_mutator(this);
345   return alloc<Let>(v->var(), value_new);
346 }
347 
mutate(const CondPtr & v)348 StmtPtr IRCloner::mutate(const CondPtr& v) {
349   auto condition_new = v->condition()->accept_mutator(this);
350   StmtPtr true_old = v->true_stmt();
351   StmtPtr false_old = v->false_stmt();
352   StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old;
353   StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old;
354   return alloc<Cond>(condition_new, true_new, false_new);
355 }
356 
// Clones the statement `s` with a fresh IRCloner and returns the copy after
// setting its parent to nullptr.
StmtPtr Stmt::clone(const StmtPtr& s) {
  IRCloner ir_cloner;
  StmtPtr copy = s->accept_mutator(&ir_cloner);
  set_parent(copy, nullptr);
  return copy;
}
363 
// Clones the expression `e` with a fresh IRCloner and returns the copy.
ExprPtr Expr::clone(const ExprPtr& e) {
  IRCloner ir_cloner;
  ExprPtr copy = e->accept_mutator(&ir_cloner);
  return copy;
}
368 
369 } // namespace torch::jit::tensorexpr
370