// external/pytorch/aten/src/ATen/native/cuda/LinearAlgebra.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/BatchLinearAlgebra.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/ReduceOps.h>
#include <c10/core/Scalar.h>

#include <thrust/swap.h>

namespace at::native {

namespace {

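// addr_kernel_cuda: element-wise CUDA path behind the addr operation.
// For each output element it computes beta * self + alpha * vec1 * vec2,
// where vec1/vec2 are the broadcast outer-product operands supplied by the
// TensorIterator. Bool tensors get a dedicated logical (&&/||) code path.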
void addr_kernel_cuda(TensorIterator &iter, const Scalar& beta, const Scalar& alpha) {
  if (iter.dtype() == ScalarType::Bool) {
    using scalar_t = bool;
    auto beta_val = beta.to<scalar_t>();
    auto alpha_val = alpha.to<scalar_t>();

    // when beta is false, values in self should be ignored,
    // nans and infs in self should not propagate.
    if (beta_val == false) {
      gpu_kernel(
        iter,
        [=] GPU_LAMBDA (scalar_t self_val,
                        scalar_t vec1_val, scalar_t vec2_val) -> scalar_t {
          return alpha_val && vec1_val && vec2_val;
        }
      );
    } else {
      gpu_kernel(
        iter,
        [=] GPU_LAMBDA (scalar_t self_val,
                        scalar_t vec1_val, scalar_t vec2_val) -> scalar_t {
          return (beta_val && self_val) || (alpha_val && vec1_val && vec2_val);
        }
      );
    }
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,
                                         iter.dtype(), "addr_cuda", [&] {
    auto beta_val = beta.to<scalar_t>();
    auto alpha_val = alpha.to<scalar_t>();

    scalar_t zero_val(0);
    // when beta==0, values in self should be ignored,
    // nans and infs in self should not propagate.
    if (beta_val == zero_val) {
      gpu_kernel(
        iter,
        [=] GPU_LAMBDA (scalar_t self_val,
                        scalar_t vec1_val, scalar_t vec2_val) -> scalar_t {
          return alpha_val * vec1_val * vec2_val;
        }
      );
    } else {
      gpu_kernel(
        iter,
        [=] GPU_LAMBDA (scalar_t self_val,
                        scalar_t vec1_val, scalar_t vec2_val) -> scalar_t {
          return beta_val * self_val + alpha_val * vec1_val * vec2_val;
        }
      );
    }
  });
}


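// Generic element-wise kernel: each thread block covers a contiguous tile of
// n_threads * n_elems_per_thread indices, with every thread processing
// n_elems_per_thread indices spaced n_threads apart within that tile.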
template <int n_threads, int n_elems_per_thread, typename func_t>
C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
__global__ void _elementwise_kernel(int total_n_elems, func_t f) {
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  int idx = total_work_block * blockIdx.x + threadIdx.x;

  #pragma unroll
  for (int i = 0; i < n_elems_per_thread; ++i) {
    if (idx < total_n_elems) {
      f(idx);
      idx += n_threads;
    }
  }
}

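// Launches _elementwise_kernel on the current CUDA stream with a 1-D grid
// sized so that grid * n_threads * n_elems_per_thread covers total_n_elems.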
template <int n_threads, int n_elems_per_thread, typename func_t>
static void _launch_kernel(int total_n_elems, func_t f) {
  TORCH_INTERNAL_ASSERT(
    total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
  );

  dim3 block(n_threads);
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  dim3 grid((total_n_elems + total_work_block - 1) / total_work_block);

  auto stream = at::cuda::getCurrentCUDAStream();
  _elementwise_kernel<n_threads, n_elems_per_thread, func_t>
    <<<grid, block, 0, stream>>>(total_n_elems, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

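// unpack_pivots_cuda_kernel: expands LAPACK-style pivot sequences (as produced
// by LU factorization) into explicit permutations. For each batch element it
// applies the row swaps encoded in `pivots`, in order, to the permutation
// buffer `perm` provided through the TensorIterator.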
void unpack_pivots_cuda_kernel(TensorIterator& iter, const int64_t dim_size, const int64_t max_pivot) {
  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      unpack_pivots_cuda_kernel(sub_iter, dim_size, max_pivot);
    }
    return;
  }

  const auto offset_calculator = make_offset_calculator<2>(iter);

  const auto perm_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  const auto pivots_ptr = reinterpret_cast<const char*>(iter.data_ptr(1));

  auto loop = [=]C10_DEVICE(const int idx) {
    const auto offsets = offset_calculator.get(idx);

    int64_t* const __restrict__ perm_data = reinterpret_cast<int64_t*>(perm_ptr + offsets[0]);
    const int32_t* const __restrict__ pivots_data = reinterpret_cast<const int32_t*>(pivots_ptr + offsets[1]);

    // QUESTION: can we mix 64bit offsets with 32bit Iterator indexing?
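    // LAPACK pivots are 1-based, hence the -1 when indexing into perm_data.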
    for (int64_t i = 0; i < dim_size; ++i) {
      thrust::swap(
        perm_data[i],
        perm_data[pivots_data[i] - 1]
      );
    }
  };

  _launch_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}
} // anonymous namespace

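// Register the CUDA implementations for the dispatch stubs declared in the
// included native headers (addr_stub, unpack_pivots_stub).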
REGISTER_DISPATCH(unpack_pivots_stub, &unpack_pivots_cuda_kernel);
REGISTER_DISPATCH(addr_stub, &addr_kernel_cuda);
} // namespace at::native