1 #include <torch/csrc/jit/tensorexpr/ir_cloner.h>
2
3 #include <torch/csrc/jit/tensorexpr/ir.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5 #include <torch/csrc/jit/tensorexpr/reduction.h>
6
7 #include <c10/util/irange.h>
8
9 namespace torch::jit::tensorexpr {
10
11 template <
12 typename Op,
13 std::enable_if_t<std::is_same_v<
14 decltype(detail::bin_op_deducer(std::declval<Op>())),
15 void>>* = nullptr>
mutate_binary_op(NodePtr<Op> v,IRCloner * cloner,bool option=false)16 static ExprPtr mutate_binary_op(
17 NodePtr<Op> v,
18 IRCloner* cloner,
19 bool option = false) {
20 ExprPtr lhs_new = v->lhs()->accept_mutator(cloner);
21 ExprPtr rhs_new = v->rhs()->accept_mutator(cloner);
22 IRNodeType expr_type = v->expr_type();
23 switch (expr_type) {
24 case IRNodeType::kAdd:
25 return alloc<Add>(lhs_new, rhs_new);
26 case IRNodeType::kSub:
27 return alloc<Sub>(lhs_new, rhs_new);
28 case IRNodeType::kMul:
29 return alloc<Mul>(lhs_new, rhs_new);
30 case IRNodeType::kDiv:
31 return alloc<Div>(lhs_new, rhs_new);
32 case IRNodeType::kMod:
33 return alloc<Mod>(lhs_new, rhs_new);
34 case IRNodeType::kMax:
35 return alloc<Max>(lhs_new, rhs_new, option);
36 case IRNodeType::kMin:
37 return alloc<Min>(lhs_new, rhs_new, option);
38 case IRNodeType::kAnd:
39 return alloc<And>(lhs_new, rhs_new);
40 case IRNodeType::kOr:
41 return alloc<Or>(lhs_new, rhs_new);
42 case IRNodeType::kXor:
43 return alloc<Xor>(lhs_new, rhs_new);
44 case IRNodeType::kLshift:
45 return alloc<Lshift>(lhs_new, rhs_new);
46 case IRNodeType::kRshift:
47 return alloc<Rshift>(lhs_new, rhs_new);
48 default:
49 throw unimplemented_lowering(v);
50 }
51 }
52
mutate(const AddPtr & v)53 ExprPtr IRCloner::mutate(const AddPtr& v) {
54 return mutate_binary_op(v, this);
55 }
56
mutate(const SubPtr & v)57 ExprPtr IRCloner::mutate(const SubPtr& v) {
58 return mutate_binary_op(v, this);
59 }
60
mutate(const MulPtr & v)61 ExprPtr IRCloner::mutate(const MulPtr& v) {
62 return mutate_binary_op(v, this);
63 }
64
mutate(const DivPtr & v)65 ExprPtr IRCloner::mutate(const DivPtr& v) {
66 return mutate_binary_op(v, this);
67 }
68
mutate(const ModPtr & v)69 ExprPtr IRCloner::mutate(const ModPtr& v) {
70 return mutate_binary_op(v, this);
71 }
72
mutate(const AndPtr & v)73 ExprPtr IRCloner::mutate(const AndPtr& v) {
74 return mutate_binary_op(v, this);
75 }
76
mutate(const OrPtr & v)77 ExprPtr IRCloner::mutate(const OrPtr& v) {
78 return mutate_binary_op(v, this);
79 }
80
mutate(const XorPtr & v)81 ExprPtr IRCloner::mutate(const XorPtr& v) {
82 return mutate_binary_op(v, this);
83 }
84
mutate(const LshiftPtr & v)85 ExprPtr IRCloner::mutate(const LshiftPtr& v) {
86 return mutate_binary_op(v, this);
87 }
88
mutate(const RshiftPtr & v)89 ExprPtr IRCloner::mutate(const RshiftPtr& v) {
90 return mutate_binary_op(v, this);
91 }
92
mutate(const MaxPtr & v)93 ExprPtr IRCloner::mutate(const MaxPtr& v) {
94 return mutate_binary_op(v, this, v->propagate_nans());
95 }
96
mutate(const MinPtr & v)97 ExprPtr IRCloner::mutate(const MinPtr& v) {
98 return mutate_binary_op(v, this, v->propagate_nans());
99 }
100
mutate(const CompareSelectPtr & v)101 ExprPtr IRCloner::mutate(const CompareSelectPtr& v) {
102 ExprPtr lhs_new = v->lhs()->accept_mutator(this);
103 ExprPtr rhs_new = v->rhs()->accept_mutator(this);
104 ExprPtr retval1_new = v->ret_val1()->accept_mutator(this);
105 ExprPtr retval2_new = v->ret_val2()->accept_mutator(this);
106 return alloc<CompareSelect>(
107 lhs_new,
108 rhs_new,
109 retval1_new,
110 retval2_new,
111 v->compare_select_op(),
112 v->bias());
113 }
114
115 #define IMM_MUTATE_DEFINE(_1, Name) \
116 ExprPtr IRCloner::mutate(const Name##ImmPtr& v) { \
117 return v; \
118 }
119 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE);
120 #undef IMM_MUTATE_DEFINE
121
mutate(const CastPtr & v)122 ExprPtr IRCloner::mutate(const CastPtr& v) {
123 ExprPtr src_value_new = v->src_value()->accept_mutator(this);
124 return alloc<Cast>(v->dtype(), src_value_new);
125 }
126
mutate(const BitCastPtr & v)127 ExprPtr IRCloner::mutate(const BitCastPtr& v) {
128 ExprPtr src_value_new = v->src_value()->accept_mutator(this);
129 return alloc<BitCast>(v->dtype(), src_value_new);
130 }
131
mutate(const RampPtr & v)132 ExprPtr IRCloner::mutate(const RampPtr& v) {
133 ExprPtr base_new = v->base()->accept_mutator(this);
134 ExprPtr stride_new = v->stride()->accept_mutator(this);
135 return alloc<Ramp>(base_new, stride_new, v->lanes());
136 }
137
mutate(const LoadPtr & v)138 ExprPtr IRCloner::mutate(const LoadPtr& v) {
139 std::vector<ExprPtr> indices_new;
140 indices_new.reserve(v->indices().size());
141 for (const ExprPtr& ind : v->indices()) {
142 indices_new.push_back(ind->accept_mutator(this));
143 }
144 BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
145 return alloc<Load>(v->dtype(), buf_new, indices_new);
146 }
147
148 // We do not clone Vars since the original IR and cloned IR are expected to
149 // share the underlying variables.
mutate(const VarPtr & v)150 ExprPtr IRCloner::mutate(const VarPtr& v) {
151 return v;
152 }
153
154 // We do not clone Bufs since the original IR and cloned IR are expected to
155 // share the underlying Bufs. In spite of Bufs having expressions as dims and
156 // initializers, this is the expected usage of clone at this point.
157 //
158 // TODO: Revisit this if Bufs need to be cloned as well.
mutate(const BufPtr & v)159 ExprPtr IRCloner::mutate(const BufPtr& v) {
160 return v;
161 }
162
mutate(const BroadcastPtr & v)163 ExprPtr IRCloner::mutate(const BroadcastPtr& v) {
164 auto lanes = v->lanes();
165 ExprPtr value_new = v->value()->accept_mutator(this);
166 return alloc<Broadcast>(value_new, lanes);
167 }
168
mutate(const IfThenElsePtr & v)169 ExprPtr IRCloner::mutate(const IfThenElsePtr& v) {
170 ExprPtr condition_new = v->condition()->accept_mutator(this);
171 ExprPtr true_value_new = v->true_value()->accept_mutator(this);
172 ExprPtr false_value_new = v->false_value()->accept_mutator(this);
173
174 return alloc<IfThenElse>(condition_new, true_value_new, false_value_new);
175 }
176
mutate(const IntrinsicsPtr & v)177 ExprPtr IRCloner::mutate(const IntrinsicsPtr& v) {
178 std::vector<ExprPtr> params_new;
179 params_new.reserve(v->nparams());
180 for (const auto& param : v->params()) {
181 params_new.push_back(param->accept_mutator(this));
182 }
183 return alloc<Intrinsics>(v->op_type(), v->dtype(), params_new);
184 }
185
mutate(const TermPtr & v)186 ExprPtr IRCloner::mutate(const TermPtr& v) {
187 ExprPtr scalar_new = v->scalar()->accept_mutator(this);
188
189 std::vector<ExprPtr> variables_new;
190 variables_new.reserve(v->variables().size());
191 for (const auto& t : v->variables()) {
192 variables_new.push_back(t->accept_mutator(this));
193 }
194 return alloc<Term>(v->hasher(), scalar_new, variables_new);
195 }
196
mutate(const PolynomialPtr & v)197 ExprPtr IRCloner::mutate(const PolynomialPtr& v) {
198 ExprPtr scalar_new = v->scalar()->accept_mutator(this);
199
200 std::vector<TermPtr> variables_new;
201 variables_new.reserve(v->variables().size());
202 for (const auto& t : v->variables()) {
203 variables_new.push_back(static_to<Term>(t->accept_mutator(this)));
204 }
205 return alloc<Polynomial>(v->hasher(), scalar_new, variables_new);
206 }
207
mutate(const RoundOffPtr & v)208 ExprPtr IRCloner::mutate(const RoundOffPtr& v) {
209 return alloc<RoundOff>(
210 v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this));
211 }
212
mutate(const MaxTermPtr & v)213 ExprPtr IRCloner::mutate(const MaxTermPtr& v) {
214 ExprPtr scalar_new =
215 v->scalar() ? v->scalar()->accept_mutator(this) : nullptr;
216
217 std::vector<ExprPtr> variables_new;
218 variables_new.reserve(v->variables().size());
219 for (const auto& t : v->variables()) {
220 variables_new.push_back(t->accept_mutator(this));
221 }
222 return alloc<MaxTerm>(
223 v->hasher(), scalar_new, v->propagate_nans(), variables_new);
224 }
225
mutate(const MinTermPtr & v)226 ExprPtr IRCloner::mutate(const MinTermPtr& v) {
227 ExprPtr scalar_new =
228 v->scalar() ? v->scalar()->accept_mutator(this) : nullptr;
229
230 std::vector<ExprPtr> variables_new;
231 variables_new.reserve(v->variables().size());
232 for (const auto& t : v->variables()) {
233 variables_new.push_back(t->accept_mutator(this));
234 }
235 return alloc<MinTerm>(
236 v->hasher(), scalar_new, v->propagate_nans(), variables_new);
237 }
238
mutate(const ReduceOpPtr & v)239 ExprPtr IRCloner::mutate(const ReduceOpPtr& v) {
240 ExprPtr body_new = v->body()->accept_mutator(this);
241
242 std::vector<VarPtr> reduce_args_new;
243 reduce_args_new.reserve(v->reduce_args().size());
244 for (const auto& r : v->reduce_args()) {
245 reduce_args_new.push_back(static_to<Var>(r->accept_mutator(this)));
246 }
247
248 return alloc<ReduceOp>(body_new, reduce_args_new, v->reducer());
249 }
250
mutate(const ForPtr & v)251 StmtPtr IRCloner::mutate(const ForPtr& v) {
252 auto start_new = v->start()->accept_mutator(this);
253 auto stop_new = v->stop()->accept_mutator(this);
254 auto body_new = v->body()->accept_mutator(this);
255
256 return alloc<For>(v->var(), start_new, stop_new, body_new, v->loop_options());
257 }
258
mutate(const BlockPtr & v)259 StmtPtr IRCloner::mutate(const BlockPtr& v) {
260 std::vector<StmtPtr> stmts_new;
261 stmts_new.reserve(v->nstmts());
262 for (const StmtPtr& stmt : *v) {
263 stmts_new.push_back(stmt->accept_mutator(this));
264 }
265 return alloc<Block>(stmts_new);
266 }
267
mutate(const StorePtr & v)268 StmtPtr IRCloner::mutate(const StorePtr& v) {
269 std::vector<ExprPtr> indices_new;
270 indices_new.reserve(v->indices().size());
271 for (const auto& ind : v->indices()) {
272 indices_new.push_back(ind->accept_mutator(this));
273 }
274 auto value_new = v->value()->accept_mutator(this);
275 BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
276 return alloc<Store>(buf_new, indices_new, value_new);
277 }
278
mutate(const AtomicAddPtr & v)279 StmtPtr IRCloner::mutate(const AtomicAddPtr& v) {
280 std::vector<ExprPtr> indices_new;
281 indices_new.reserve(v->indices().size());
282 for (const auto& ind : v->indices()) {
283 indices_new.push_back(ind->accept_mutator(this));
284 }
285 auto value_new = v->value()->accept_mutator(this);
286 BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
287 return alloc<AtomicAdd>(buf_new, indices_new, value_new);
288 }
289
mutate(const AllocatePtr & v)290 StmtPtr IRCloner::mutate(const AllocatePtr& v) {
291 BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
292 return alloc<Allocate>(buf_new);
293 }
294
mutate(const FreePtr & v)295 StmtPtr IRCloner::mutate(const FreePtr& v) {
296 BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
297 return alloc<Free>(buf_new);
298 }
299
mutate(const SyncThreadsPtr & v)300 StmtPtr IRCloner::mutate(const SyncThreadsPtr& v) {
301 return alloc<SyncThreads>();
302 }
303
mutate(const ExternalCallPtr & v)304 StmtPtr IRCloner::mutate(const ExternalCallPtr& v) {
305 BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
306
307 std::vector<BufPtr> buf_args_new;
308 buf_args_new.reserve(v->buf_args().size());
309 for (const BufPtr& buf_arg : v->buf_args()) {
310 buf_args_new.push_back(to<Buf>(buf_arg->accept_mutator(this)));
311 }
312 std::vector<ExprPtr> args_new;
313 args_new.reserve(v->args().size());
314 for (const ExprPtr& arg : v->args()) {
315 args_new.push_back(arg->accept_mutator(this));
316 }
317
318 return alloc<ExternalCall>(buf_new, v->func_name(), buf_args_new, args_new);
319 }
320
mutate(const ExternalCallWithAllocPtr & v)321 StmtPtr IRCloner::mutate(const ExternalCallWithAllocPtr& v) {
322 std::vector<BufPtr> buf_out_args_new;
323 buf_out_args_new.reserve(v->buf_out_args().size());
324 for (const auto& buf_out_arg : v->buf_out_args()) {
325 buf_out_args_new.push_back(to<Buf>(buf_out_arg->accept_mutator(this)));
326 }
327
328 std::vector<BufPtr> buf_args_new;
329 buf_args_new.reserve(v->buf_args().size());
330 for (const auto& buf_arg : v->buf_args()) {
331 buf_args_new.push_back(to<Buf>(buf_arg->accept_mutator(this)));
332 }
333 std::vector<ExprPtr> args_new;
334 args_new.reserve(v->args().size());
335 for (const auto& arg : v->args()) {
336 args_new.push_back(arg->accept_mutator(this));
337 }
338
339 return alloc<ExternalCallWithAlloc>(
340 v->func_name(), buf_out_args_new, buf_args_new, args_new);
341 }
342
mutate(const LetPtr & v)343 StmtPtr IRCloner::mutate(const LetPtr& v) {
344 auto value_new = v->value()->accept_mutator(this);
345 return alloc<Let>(v->var(), value_new);
346 }
347
mutate(const CondPtr & v)348 StmtPtr IRCloner::mutate(const CondPtr& v) {
349 auto condition_new = v->condition()->accept_mutator(this);
350 StmtPtr true_old = v->true_stmt();
351 StmtPtr false_old = v->false_stmt();
352 StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old;
353 StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old;
354 return alloc<Cond>(condition_new, true_new, false_new);
355 }
356
// Returns a clone of statement `s` produced by a fresh IRCloner. The parent
// of the resulting statement is reset to nullptr before it is returned.
StmtPtr Stmt::clone(const StmtPtr& s) {
  IRCloner ir_cloner;
  StmtPtr copy = s->accept_mutator(&ir_cloner);
  set_parent(copy, nullptr);
  return copy;
}
363
// Returns a clone of expression `e` produced by a fresh IRCloner.
ExprPtr Expr::clone(const ExprPtr& e) {
  IRCloner ir_cloner;
  ExprPtr copy = e->accept_mutator(&ir_cloner);
  return copy;
}
368
369 } // namespace torch::jit::tensorexpr
370