#pragma once
#include <ATen/OpMathType.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/Pow.cuh>

namespace at::native {

namespace {

// TODO(crcrpar): Handle version bump in codegen.
// rel:
// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
inline void increment_version(TensorList tensors) {
  for (const auto& t : tensors) {
    t.unsafeGetTensorImpl()->bump_version();
  }
}
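// NOTE: increment_version is expected to be called by the host-side wrappers
// of in-place foreach ops after their kernel launch, since (per the TODO
// above) the autograd codegen does not handle the version bump for them yet.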

// Initializes args and checks if all args are aligned
template <int depth, typename T>
__device__ bool init_args(
    T** args,
    TensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

// Initializes args and checks if all args are aligned
template <int depth, typename T, typename T2>
__device__ bool init_args(
    T** args,
    TensorListScalarListMetadata<T2, depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template <int depth, typename T>
__device__ bool init_args(
    T** args,
    FusedOptimizerTensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}
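// All three init_args overloads behave the same way: point args[i] at the i-th
// operand list's data for the tensor this block works on, advance it to the
// current chunk, and report whether every pointer is aligned well enough for
// the vectorized load_store fast path used below.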

template <int depth, typename T>
__device__ void load_args(
    T r_args[][kILP],
    T** args,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const auto i = i_start + threadIdx.x + ii * blockDim.x;
    for (int r_index = 0; r_index < depth; r_index++) {
      r_args[r_index][ii] = 0;
      if (i < n && i < chunk_size) {
        r_args[r_index][ii] = args[r_index][i];
      }
    }
  }
}

template <typename T>
__device__ void store_args(
    T* dst,
    T* src,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
    if (i < n && i < chunk_size)
      dst[i] = src[ii];
  }
}
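// Indexing sketch for load_args/store_args (numbers purely illustrative): with
// kILP == 4 and blockDim.x == 64, thread 0 touches elements
// i_start + {0, 64, 128, 192}, thread 1 touches i_start + {1, 65, 129, 193},
// and so on; out-of-range slots are zero-filled on load and skipped on store.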

template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void binary_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put the aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless of whether depth is 1 (for inplace) or 2 (for out of
      // place), r_args has depth 1
      load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}
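// binary_op_scalar computes op(args[0][i], scalar) in opmath_t precision and
// writes the result to args[res_arg_index]: index 0 for in-place variants,
// the extra output list for out-of-place ones.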

template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put the aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
      load_store(r_args[1], args[1], 0, i_start);
      load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless of whether depth is 3 (for inplace) or 4 (for out of
      // place), r_args has depth 3
      load_args<3>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}
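// pointwise_op_scalar computes args[0][i] + scalar * op(args[1][i], args[2][i])
// (the addcmul/addcdiv-style pattern) and writes the result to
// args[res_arg_index]: the input list for in-place variants, the extra output
// list for out-of-place ones.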

//
// Binary Functors
//
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
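// A rough host-side sketch of how this functor is typically driven (the real
// call sites live in the foreach binary-op kernels; template arguments and the
// std::plus op below are illustrative only):
//
//   using opmath_t = at::opmath_type<scalar_t>;
//   // out-of-place: depth 2 = input list + output list, result -> args[1]
//   multi_tensor_apply<2>(
//       tensor_lists,
//       BinaryOpScalarFunctor<scalar_t, /*depth*/ 2, /*r_args_depth*/ 1, /*res_arg_index*/ 1>(),
//       std::plus<opmath_t>(),
//       scalar.to<opmath_t>());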

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
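// BinaryOpScalarListFunctor is the per-tensor-scalar sibling of
// BinaryOpScalarFunctor (the scalar comes from tl.scalar_vals), while
// BinaryOpListAlphaFunctor computes op(args[0][i], alpha * args[1][i]), i.e.
// the alpha-scaled add/sub pattern over two tensor lists.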

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarTensorFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      T* scalar,
      opmath_t alpha) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        // Regardless of whether depth is 1 (for inplace) or 2 (for out of
        // place), r_args has depth 1
        load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
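// BinaryOpScalarTensorFunctor reads its scalar operand through a device
// pointer (typically the data pointer of a 0-dim CUDA tensor) instead of
// taking it by value, so alpha * (*scalar) is applied to every element of
// every tensor in the list.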

//
// Unary Functors
//

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<1>& tl) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const auto all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        // store
        load_store(args[0], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
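// ZeroFunctor fills a single tensor list (depth 1) with zeros in place; it is
// the kernel body behind _foreach_zero_-style calls driven by
// multi_tensor_apply.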

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
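// A rough host-side sketch for a unary foreach op (illustrative only; the real
// call sites live in the foreach unary-op kernels and pass functor objects
// rather than a lambda):
//
//   // out-of-place: depth 2 = input list + output list, result -> args[1]
//   multi_tensor_apply<2>(
//       tensor_lists,
//       UnaryOpFunctor<scalar_t, /*depth*/ 2, /*r_args_depth*/ 1, /*res_arg_index*/ 1>(),
//       [] __device__(at::opmath_type<scalar_t> x) { return ::sqrt(x); });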

//
// Pointwise Functors
//

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth>
struct PointwiseOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[depth - 1][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[2], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[2], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
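// PointwiseOpScalarFunctor and PointwiseOpScalarListFunctor wrap
// pointwise_op_scalar with a single scalar or a per-tensor scalar list,
// respectively, while PointwiseOpListFunctor applies a plain binary op to two
// lists and always writes the result to the third list (args[2]).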

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    static_assert(depth == 3 || depth == 4, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
        load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
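// TernaryOpListFunctor consumes three tensor lists elementwise (e.g. the
// tensor-weight variant of _foreach_lerp); per the static_asserts above,
// depth 3 corresponds to the in-place form and depth 4 to the out-of-place
// form with a separate output list.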

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    static_assert(depth == 2 || depth == 3, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
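// TernaryOpScalarFunctor is the scalar counterpart: the third operand (alpha)
// is broadcast to every element instead of being read from a third list, e.g.
// the Scalar-weight variant of _foreach_lerp.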

template <typename T>
struct power_functor {
  C10_DEVICE T operator()(const T& a, const T& b) const {
    return at::native::pow_(a, b);
  }
};

template <typename T>
struct reverse_power_functor {
  C10_DEVICE T operator()(const T& a, const T& b) const {
    return at::native::pow_(b, a);
  }
};
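// power_functor and reverse_power_functor plug into the binary foreach
// functors above to implement _foreach_pow; the "reverse" version swaps the
// operands so the Scalar-base / TensorList-exponent overload can reuse the
// same machinery.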

} // namespace
} // namespace at::native