#pragma once
#include <ATen/OpMathType.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/Pow.cuh>

namespace at::native {

namespace {

// TODO(crcrpar): Handle version bump in codegen.
// rel:
// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
inline void increment_version(TensorList tensors) {
  for (const auto& t : tensors) {
    t.unsafeGetTensorImpl()->bump_version();
  }
}
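
// A minimal usage sketch (illustrative only, not part of this header). Since
// codegen does not bump version counters for the in-place foreach kernels yet
// (see the TODO above), the launchers are expected to call increment_version
// themselves after enqueueing the kernel, roughly like:
//
//   // hypothetical in-place launcher; `self` is the mutated TensorList and
//   // SomeInplaceFunctor stands in for one of the functors defined below
//   multi_tensor_apply<1>(tensor_lists, SomeInplaceFunctor(), op, scalar);
//   increment_version(self);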

// Initializes args and checks if all args are aligned
template <int depth, typename T>
__device__ bool init_args(
    T** args,
    TensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

// Initializes args and checks if all args are aligned
template <int depth, typename T, typename T2>
__device__ bool init_args(
    T** args,
    TensorListScalarListMetadata<T2, depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template <int depth, typename T>
__device__ bool init_args(
    T** args,
    FusedOptimizerTensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template <int depth, typename T>
__device__ void load_args(
    T r_args[][kILP],
    T** args,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const auto i = i_start + threadIdx.x + ii * blockDim.x;
    for (int r_index = 0; r_index < depth; r_index++) {
      r_args[r_index][ii] = 0;
      if (i < n && i < chunk_size) {
        r_args[r_index][ii] = args[r_index][i];
      }
    }
  }
}

template <typename T>
__device__ void store_args(
    T* dst,
    T* src,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
    if (i < n && i < chunk_size)
      dst[i] = src[ii];
  }
}

template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void binary_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put the aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless of whether depth is 1 (for in-place) or 2 (for
      // out-of-place), r_args has depth 1
      load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}

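// Computes, elementwise over the chunk,
//   out[i] = a[i] + scalar * op(b[i], c[i])
// where a, b, c are the first three tensor lists in `args` and the result is
// written to args[res_arg_index]; this is the addcmul / addcdiv pattern
// (op = multiplication or division, respectively) used by the pointwise
// functors below.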
template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put the aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
      load_store(r_args[1], args[1], 0, i_start);
      load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless of whether depth is 3 (for in-place) or 4 (for
      // out-of-place), r_args has depth 3
      load_args<3>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}

//
// Binary Functors
//
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
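
// A rough dispatch sketch for BinaryOpScalarFunctor (illustrative only; the
// actual launchers live elsewhere, e.g. the ForeachBinaryOp*.cu files). For an
// out-of-place op the depth is 2 ({inputs, results}) and the result slot is
// index 1; an in-place op would use depth 1 and res_arg_index 0 instead:
//
//   // tensor_lists = {inputs, results};
//   // Op maps (opmath_t, opmath_t) -> opmath_t; scalar is already an opmath_t
//   multi_tensor_apply<2>(
//       tensor_lists,
//       BinaryOpScalarFunctor<scalar_t, /*depth=*/2, /*r_args_depth=*/1,
//                             /*res_arg_index=*/1>(),
//       Op(),
//       scalar);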

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarTensorFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      T* scalar,
      opmath_t alpha) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        // Regardless of whether depth is 1 (for in-place) or 2 (for
        // out-of-place), r_args has depth 1
        load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

//
// Unary Functors
//

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<1>& tl) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const auto all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        // store
        load_store(args[0], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
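
// Illustrative launch for ZeroFunctor (the concrete call site lives in the
// foreach zero_ launcher, assuming it follows the usual pattern): a single
// in-place list, so depth 1 and the result written back over args[0]:
//
//   // tensor_lists = {self}
//   multi_tensor_apply<1>(
//       tensor_lists,
//       ZeroFunctor<scalar_t, /*depth=*/1, /*r_args_depth=*/1,
//                   /*res_arg_index=*/0>());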

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
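
// Illustrative out-of-place unary launch (the concrete launchers live in the
// foreach unary-op .cu files): depth 2 holds {inputs, results}, only one
// register slot is needed, and the result goes to args[1]; an in-place
// variant would use depth 1 and res_arg_index 0 instead.
//
//   multi_tensor_apply<2>(
//       tensor_lists,
//       UnaryOpFunctor<scalar_t, /*depth=*/2, /*r_args_depth=*/1,
//                      /*res_arg_index=*/1>(),
//       Op()); // any callable mapping opmath_t -> opmath_t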

//
// Pointwise Functors
//

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth>
struct PointwiseOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[depth - 1][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[2], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[2], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    static_assert(depth == 3 || depth == 4, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
        load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    static_assert(depth == 2 || depth == 3, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T>
struct power_functor {
  C10_DEVICE T operator()(const T& a, const T& b) const {
    return at::native::pow_(a, b);
  }
};

template <typename T>
struct reverse_power_functor {
  C10_DEVICE T operator()(const T& a, const T& b) const {
    return at::native::pow_(b, a);
  }
};
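
// Both functors defer to the device-side pow_ from Pow.cuh. power_functor
// computes pow(a, b) directly, while reverse_power_functor swaps its
// arguments; the latter presumably exists so that a Scalar-base pow (i.e.
// pow(scalar, exponent_list)) can reuse the scalar binary-op machinery above,
// which always passes the tensor element as the first argument.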

} // namespace
} // namespace at::native