// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/SoftMax.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/cuda/CUDAContext.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorUtils.h>
6 #include <ATen/TensorOperators.h>
7 #include <ATen/WrapDimUtils.h>
8 #include <c10/macros/Macros.h>
9 
10 #include <ATen/AccumulateType.h>
11 #include <ATen/cuda/NumericLimits.cuh>
12 #include <type_traits>
13 
14 #include <ATen/native/cuda/Loops.cuh>
15 #include <ATen/native/cuda/MemoryAccess.cuh>
16 #include <ATen/native/cuda/PersistentSoftmax.cuh>
17 #include <ATen/native/IndexingUtils.h>
18 #include <ATen/native/cuda/block_reduce.cuh>
19 
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #include <ATen/NativeFunctions.h>
23 #else
24 #include <ATen/ops/_masked_softmax_native.h>
25 #include <ATen/ops/_log_softmax_native.h>
26 #include <ATen/ops/_log_softmax_backward_data_native.h>
27 #include <ATen/ops/_softmax_native.h>
28 #include <ATen/ops/_softmax_backward_data_native.h>
29 #include <ATen/ops/softmax.h>
30 #include <ATen/ops/_softmax_backward_data.h>
31 #endif
32 
33 namespace at::native {
34 
35 namespace {
36 
37 constexpr int ALIGN_BYTES = 16;
38 
39 template<typename T, typename AccumT, typename OutT>
40 struct LogSoftMaxForwardEpilogue {
41   __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
42     : max_input(max_input),  logsum(std::log(sum)) {}
43 
44   __device__ __forceinline__ OutT operator()(T input) const {
45     return static_cast<OutT>(input - max_input - logsum);
46 }
47 
48   const AccumT max_input;
49   const AccumT logsum;
50 };
51 
52 template<typename T, typename AccumT, typename OutT>
53 struct LogSoftMaxBackwardEpilogue {
54   __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
55     : sum(sum) {}
56 
57   __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
58     return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
59   }
60 
61   const AccumT sum;
62 };
63 
64 template<typename T, typename AccumT, typename OutT>
65 struct SoftMaxForwardEpilogue {
66   __device__ __forceinline__ SoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
67     : max_input(max_input)
68     , sum(sum) {}
69 
70   __device__ __forceinline__ OutT operator()(T input) const {
71     return static_cast<OutT>(std::exp(input - max_input) / sum);
72   }
73 
74   const AccumT max_input;
75   const AccumT sum;
76 };
77 
78 template<typename T, typename AccumT, typename OutT>
79 struct SoftMaxBackwardEpilogue {
80   __device__ __forceinline__ SoftMaxBackwardEpilogue(AccumT sum)
81     : sum(sum) {}
82 
83   // XXX: gradOutput that we get here is really gradOutput * output
84   // (the product is formed on the host side; see "grad * output" in softmax_backward_cuda_out below)
85   __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
86     return static_cast<T>(gradOutput - output * sum);
87   }
88 
89   const AccumT sum;
90 };
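// For intuition, the epilogues above implement, for a row x with m = max_i x_i and
// s = sum_i exp(x_i - m):
//   softmax forward:       y_i  = exp(x_i - m) / s
//   log-softmax forward:   y_i  = x_i - m - log(s)
//   softmax backward:      dx_i = g_i - y_i * sum_j g_j, with g = dy * y pre-multiplied
//                          (see the XXX note above)
//   log-softmax backward:  dx_i = dy_i - exp(y_i) * sum_j dy_j
// A minimal single-row CPU sketch of the forward math (illustrative only, not used by
// the kernels below):
//   float m = *std::max_element(x, x + n);
//   float s = 0.f;
//   for (int i = 0; i < n; ++i) s += std::exp(x[i] - m);
//   for (int i = 0; i < n; ++i)
//     y[i] = log_softmax ? x[i] - m - std::log(s) : std::exp(x[i] - m) / s;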
91 
92 
93 
94 
95 ////////////////////////////////////////////////////////////////////////////////
96 // Spatial kernel (fast with large inner_size and small dim_size)
97 ////////////////////////////////////////////////////////////////////////////////
98 // Let's assume that our input has been flattened to have only three dimension:
99 //     outer x dim x inner
100 // The spatial algorithm tries to parallelize along all of them.
101 // Within a 2d block, threadIdx.y indexes different (outer, inner) slices, and threads
102 // that share a threadIdx.y cooperate along axis x to speed up reductions over dim.
103 // The 2d grid is used to parallelize inner dimension over y axis and outer over x.
104 inline dim3 SpatialSoftMax_getGridSize(
105     dim3 block, uint32_t max_active_blocks,
106     uint64_t outer_size, uint64_t inner_size) {
107   // First, tile as many blocks as we can over the y axis
108   uint32_t inner_blocks = (inner_size + block.y - 1) / block.y;
109   if (inner_blocks > max_active_blocks)
110     inner_blocks = max_active_blocks;
111   // Fill the x axis with as many blocks as we can fit (a little more is ok too)
112   uint32_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
113   if (outer_blocks > outer_size)
114     outer_blocks = outer_size;
115   return dim3(outer_blocks, inner_blocks);
116 }
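// Worked example for the heuristic above (illustrative numbers): with block = (1, 128),
// max_active_blocks = 160, outer_size = 32 and inner_size = 10000, we get
// inner_blocks = ceil(10000 / 128) = 79 (under the 160 cap) and
// outer_blocks = min(ceil(160 / 79), 32) = 3, i.e. grid = dim3(3, 79).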
117 
118 const int max_threads = 1024;
119 
120 inline dim3 SpatialSoftMax_getBlockSize(
121   uint64_t dim_size, uint64_t inner_size) {
122   uint32_t inner_threads = inner_size;
123   inner_threads = std::min(inner_threads, static_cast<uint32_t>(max_threads));
124   uint32_t dim_threads = 1;
125   if (inner_threads <= 64 && dim_size >= 64) {
126     while (inner_threads * dim_threads <= max_threads && dim_threads <= dim_size)
127       dim_threads *= 2;
128     dim_threads /= 2;
129   }
130   return dim3(dim_threads, inner_threads);
131 }
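// Worked example: dim_size = 128, inner_size = 32 gives inner_threads = 32; since
// inner_threads <= 64 and dim_size >= 64, dim_threads doubles while the product stays
// within max_threads and ends at 32, so the block is dim3(32, 32). For a large
// inner_size (say 10000), inner_threads is capped at 1024, dim_threads stays 1, and the
// block is dim3(1, 1024).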
132 
133 
134 template<typename accscalar_t, typename Kernel>
135 void SpatialSoftMax_getLaunchSizes(
136     Kernel k,
137     uint64_t outer_size, uint64_t dim_size, uint64_t inner_size,
138     dim3& grid, dim3& block, uint32_t& smem_size) {
139   block = SpatialSoftMax_getBlockSize(dim_size, inner_size);
140   uint32_t block_threads = block.x * block.y;
141   smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t);
142   int max_active_blocks;
143 #if defined(USE_ROCM) && TORCH_HIP_VERSION < 305
144   // HIP function signature is not compatible yet.
145   uint32_t max_blocks;
146   cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks,
147                                                 k, block_threads, smem_size);
148   max_active_blocks = max_blocks;
149 #else
150   cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
151                                                 k, block_threads, smem_size);
152 #endif
153   max_active_blocks *= at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
154   grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, inner_size);
155 }
156 
157 inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
158   uint64_t block_size = 1;
159   uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
160 
161   // In the vectorized case we want to trade off allowing more of the buffers to be accessed
162   // in a vectorized way against wanting a larger block size to get better utilisation.
163   // In general with ILP you can have (ILP-1)/ILP of the buffer accessed vectorised, at the risk
164   // of having a very small block size. We choose to keep >= 1/2 of the buffer vectorised while
165   // allowing a larger block size.
166   if (ILP > 1) {
167     max_block_size /= 2;
168   }
169 
170   while (block_size < (max_block_size)) block_size *= 2;
171   // Launch at least a single warp - the kernel assumes that.
172   block_size = std::max(block_size, static_cast<uint64_t>(at::cuda::warp_size()));
173   return dim3(block_size);
174 }
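// Worked example: for float with ILP = 4 and dim_size = 2048, max_block_size =
// min(2048 / 4, 1024) = 512, halved to 256 because ILP > 1; block_size then grows by
// powers of two up to 256 and is clamped to at least one warp, giving dim3(256).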
175 
176 inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) {
177   uint64_t block_size = 1;
178   uint64_t max_block_size = std::min(dim_size, static_cast<uint64_t>(max_threads));
179 
180   // We need a block size that is a multiple of C10_WARP_SIZE in order
181   // to perform block size reductions using warp shuffle instructions.
182   // Since max_threads is also a multiple of C10_WARP_SIZE we do not
183   // risk creating a block size larger than the limit.
184 
185   if (max_block_size % C10_WARP_SIZE == 0) {
186     block_size = max_block_size;
187   } else {
188     block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE;
189   }
190 
191   return dim3(block_size);
192 }
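// Worked example: dim_size = 1000 (with C10_WARP_SIZE = 32) gives max_block_size = 1000,
// not a multiple of the warp size, so block_size rounds up to (1000 / 32 + 1) * 32 = 1024;
// for dim_size >= 1024 the block is simply the 1024-thread maximum.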
193 
194 template<typename T>
195 struct Add {
196   __device__ __forceinline__ T operator()(T a, T b) const {
197     return a + b;
198   }
199 
200   __device__ __forceinline__ T combine(T a, T b) const {
201     return a + b;
202   }
203 
204   // Needed to allow warp level reduction as a first step in the
205   // thread block reduction
206   __device__ __forceinline__ T warp_shfl_down(T data, int offset) const {
207     return WARP_SHFL_DOWN(data, offset);
208   }
209 };
210 
211 template<typename T>
212 struct Max {
213   __device__ __forceinline__ T operator()(T a, T b) const {
214     return a < b ? b : a;
215   }
216 
217   __device__ __forceinline__ T combine(T a, T b) const {
218     return a < b ? b : a;
219   }
220 
221   // Needed to allow warp level reduction as a first step in the
222   // thread block reduction
223   __device__ __forceinline__ T warp_shfl_down(T data, int offset) const {
224     return WARP_SHFL_DOWN(data, offset);
225   }
226 };
227 
228 // Note that it's not a complete block-wide reduction.
229 // Only threads that share threadIdx.y reduce values.
230 template<typename T, template<typename> class ReduceOp>
231 __forceinline__ __device__
232 T spatialBlockReduceX(T *shared, T val) {
233   ReduceOp<T> r;
234   shared += threadIdx.y * blockDim.x;
235 
236   __syncthreads();
237 
238   shared[threadIdx.x] = val;
239 
240   // NOTE: loop starts with __syncthreads()
241   int offset = blockDim.x / 2;
242   while (offset > 0) {
243     __syncthreads();
244     if (threadIdx.x < offset)
245       shared[threadIdx.x] = r(shared[threadIdx.x], shared[threadIdx.x + offset]);
246     offset /= 2;
247   }
248 
249   __syncthreads();
250 
251   return shared[0];
252 }
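// Example of the halving tree above for blockDim.x = 8: threads 0-3 combine pairs
// (0,4) (1,5) (2,6) (3,7), threads 0-1 then combine (0,2) (1,3), thread 0 combines (0,1),
// and every thread of the row reads the result back from shared[0]. This assumes
// blockDim.x is a power of two, which SpatialSoftMax_getBlockSize guarantees.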
253 
254 template <typename scalar_t, typename accscalar_t, typename outscalar_t, typename index_t, template<typename, typename, typename> class Epilogue>
255 __global__ void cunn_SpatialSoftMaxForward(
256     outscalar_t *output, const scalar_t *input,
257     index_t outer_size, index_t dim_size, index_t inner_size)
258 {
259   extern __shared__ unsigned char smem[];
260   auto sdata = reinterpret_cast<accscalar_t*>(smem);
261   const index_t outer_stride = inner_size * dim_size;
262   const index_t dim_stride = inner_size;
263 
264   for (index_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
265     const index_t outer_offset = outer_index * outer_stride;
266     for (index_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) {
267       const index_t data_offset = outer_offset + inner_index;
268       ////////////////////////////////////////////////////////////
269       // These two blocks are really equivalent, but specializing on
270       // blockDim.x == 1 makes the kernel faster when it's unused.
271       // I didn't want to thread an extra template parameter, and nvcc
272       // seems to be smart enough to hoist the if outside of the loops.
273       ////////////////////////////////////////////////////////////
274 
275       if (blockDim.x > 1) {
276         accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
277         for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
278           const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
279           max_input = Max<accscalar_t>()(max_input, value);
280         }
281         max_input = spatialBlockReduceX<accscalar_t, Max>(sdata,max_input);
282 
283         accscalar_t sum = 0;
284         for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
285           sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
286                  - max_input);
287         sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum);
288 
289         Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum);
290         for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
291           output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
292       } else {
293         accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
294         for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
295           const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
296           max_input = Max<accscalar_t>()(max_input, value);
297         }
298         accscalar_t sum = 0;
299         for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
300           sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
301                  - max_input);
302         Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum);
303         for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
304           output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
305       }
306     }
307   }
308 }
309 
310 
311 
312 template <typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
313 __global__ void cunn_SpatialSoftMaxBackward(
314     scalar_t *gradInput, const outscalar_t *output, const outscalar_t *gradOutput,
315     uint32_t outer_size, uint32_t dim_size, uint32_t inner_size)
316 {
317   extern __shared__ unsigned char smem[];
318   auto sdata = reinterpret_cast<accscalar_t*>(smem);
319   const uint32_t outer_stride = inner_size * dim_size;
320   const uint32_t dim_stride = inner_size;
321 
322   for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
323     const uint32_t outer_offset = outer_index * outer_stride;
324     for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) {
325       const uint32_t data_offset = outer_offset + inner_index;
326       // See the comment in forward kernel
327       if (blockDim.x > 1) {
328         accscalar_t sum = 0;
329         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
330           sum += gradOutput[data_offset + d * dim_stride];
331         sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum);
332 
333         Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum);
334         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
335           gradInput[data_offset + d * dim_stride] =
336             epilogue(gradOutput[data_offset + d * dim_stride],
337                     output[data_offset + d * dim_stride]);
338         }
339       } else {
340         accscalar_t sum = 0;
341         for (uint32_t d = 0; d < dim_size; d++)
342           sum += gradOutput[data_offset + d * dim_stride];
343 
344         Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum);
345         for (uint32_t d = 0; d < dim_size; d++) {
346           gradInput[data_offset + d * dim_stride] =
347             epilogue(gradOutput[data_offset + d * dim_stride],
348                     output[data_offset + d * dim_stride]);
349         }
350       }
351     }
352   }
353 }
354 
355 
356 ////////////////////////////////////////////////////////////////////////////////
357 // Regular kernel (fast when dim_size is large; requires inner_size == 1)
358 ////////////////////////////////////////////////////////////////////////////////
359 
360 
361 template <typename T, typename AccumT>
362 struct MaxFloat
363 {
364   __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
365     return ::max(max, (AccumT)v);
366   }
367 };
368 
369 template<typename T, typename AccumT>
370 struct AddFloat
371 {
372   __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
373     return sum + v;
374   }
375 };
376 
377 template<typename T, typename AccumT>
378 struct SumExpFloat
379 {
380   __device__ __forceinline__ SumExpFloat(AccumT v)
381     : max_k(v) {}
382 
383   __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
384     return sum + std::exp(v - max_k);
385   }
386 
387   const AccumT max_k;
388 };
389 
390 template <template<typename> class Reduction, typename AccumT>
391 __device__ __forceinline__ AccumT
392 blockReduce(AccumT* smem, AccumT val,
393             const Reduction<AccumT>& r,
394             AccumT defaultVal)
395 {
396   // To avoid RaW races from chaining blockReduce calls together, we need a sync here
397   __syncthreads();
398 
399   smem[threadIdx.x] = val;
400 
401   __syncthreads();
402 
403   AccumT warpVal = defaultVal;
404 
405   // First warp will perform per-warp reductions for the remaining warps
406   uint32_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1;
407   if (threadIdx.x < C10_WARP_SIZE) {
408     int lane = threadIdx.x % C10_WARP_SIZE;
409     if (lane < blockDim.x / C10_WARP_SIZE) {
410 #pragma unroll
411       for (int i = 0; i < C10_WARP_SIZE; ++i) {
412         warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]);
413       }
414 #if !defined(USE_ROCM)
415       __syncwarp(mask);
416 #endif
417       smem[lane] = warpVal;
418     }
419   }
420 
421   __syncthreads();
422 
423   // First thread will perform a reduction of the above per-warp reductions
424   AccumT blockVal = defaultVal;
425 
426   if (threadIdx.x == 0) {
427     for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) {
428       blockVal = r(blockVal, smem[i]);
429     }
430     smem[0] = blockVal;
431   }
432 
433   // Sync and broadcast
434   __syncthreads();
435   return smem[0];
436 }
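// The reduction above is two-stage: each lane of the first warp serially combines one
// C10_WARP_SIZE-wide chunk of smem, then thread 0 combines the per-warp partials.
// E.g. with blockDim.x = 256 and a 32-wide warp, lanes 0-7 each reduce 32 values and
// thread 0 then folds the 8 partials left in smem[0..7] before broadcasting via smem[0].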
437 
438 // Performs a thread block reduction with a given functor but uses
439 // warp shuffles as the first step in the reduction
440 template <template<typename> class Reduction, typename T>
441 __device__ __forceinline__
442 T blockReduceWarp(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
443 {
444   T result = cuda_utils::BlockReduce<T, Reduction<T>>(value, op, defaultVal, smem_cache);
445   if (threadIdx.x == 0) {
446     smem_cache[0] = result;
447   }
448   __syncthreads();
449   return smem_cache[0];
450 }
451 
452 template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT, typename index_t=int>
453 __device__ __forceinline__ AccumT
454 ilpReduce(index_t shift,
455           const T* data,
456           index_t size,
457           const Reduction<T, AccumT>& r,
458           AccumT defaultVal)
459 {
460   using LoadT = at::native::memory::aligned_vector<T, ILP>;
461   AccumT threadVal = defaultVal;
462   index_t offset = threadIdx.x;
463 
464   // Handle the misaligned prefix one element per thread so the main loop below can use aligned vector loads
465   if(shift > 0){
466     data -= shift;
467     size += shift;
468     if(threadIdx.x >= shift){
469       threadVal = r(threadVal, data[offset]);
470     }
471     size -= blockDim.x;
472     data += blockDim.x;
473   }
474   index_t last = size % (ILP * blockDim.x);
475 
476   T v[ILP];
477   LoadT* value = reinterpret_cast<LoadT*>(&v);
478 
479   for (; offset * ILP < (size - last); offset += blockDim.x) {
480     *value = reinterpret_cast<const LoadT*>(data)[offset];
481 
482     #pragma unroll
483     for (int j = 0; j < ILP; ++j) {
484       threadVal = r(threadVal, v[j]);
485     }
486   }
487 
488   offset = size - last + threadIdx.x;
489   // Epilogue
490   for (; offset < size; offset += blockDim.x)
491     threadVal = r(threadVal, data[offset]);
492 
493   return threadVal;
494 }
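// Worked example of the three phases above (illustrative sizes): with blockDim.x = 128,
// ILP = 4, size = 1000 and a misalignment of shift = 2 elements, threads 2..127 first
// reduce elements 0..125 one at a time; the remaining 874 elements then start at an
// aligned address, the first 512 of them (a multiple of ILP * blockDim.x) are read with
// vector loads, and the trailing 362 are reduced scalar in the epilogue loop.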
495 
496 /**
497  * This will apply the Epilogue with vectorized reads & writes when input & output have the same shift
498  */
499 template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
500 __device__ __forceinline__ void
501 WriteFpropResultsVectorized(
502              int size,
503              const int shift,
504              const scalar_t *input,
505              outscalar_t *output,
506              Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
507   using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>;
508   using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>;
509 
510   int offset = threadIdx.x;
511 
512   // if unaligned, do one value / thread and move on, guaranteeing aligned reads/writes later
513   if (shift > 0) {
514     input -= shift;
515     output -= shift;
516     size += shift;
517 
518     if (threadIdx.x >= shift) {
519       output[offset] = epilogue(input[offset]);
520     }
521     size -= blockDim.x;
522     input += blockDim.x;
523     output += blockDim.x;
524   }
525 
526   const int last = size % (ILP * blockDim.x);
527 
528   scalar_t in_v[ILP];
529   LoadT* in_value = reinterpret_cast<LoadT*>(&in_v);
530 
531   outscalar_t out_v[ILP];
532   const StoreT* out_value = reinterpret_cast<const StoreT*>(&out_v);
533 
534   for (; offset * ILP < (size - last); offset += blockDim.x) {
535     *in_value = reinterpret_cast<const LoadT*>(input)[offset];
536 
537     #pragma unroll
538     for (int j = 0; j < ILP; ++j) {
539       out_v[j] = epilogue(in_v[j]);
540     }
541 
542     reinterpret_cast<StoreT*>(output)[offset] = *out_value;
543   }
544 
545   offset = size - last + threadIdx.x;
546   // handle the tail
547   for (; offset < size; offset += blockDim.x) {
548     output[offset] = epilogue(input[offset]);
549   }
550 }
551 
552 template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue, typename index_t = int32_t>
553 __device__ __forceinline__ void
554 WriteBpropResultsVectorized(
555              index_t size,
556              const index_t shift,
557              scalar_t *gradInput,
558              const outscalar_t *output,
559              const outscalar_t *gradOutput,
560              Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
561   using gradInputT = at::native::memory::aligned_vector<scalar_t, ILP>;
562   using outputT = at::native::memory::aligned_vector<outscalar_t, ILP>;
563 
564   index_t offset = threadIdx.x;
565 
566   // if unaligned, do one value / thread and move on, guaranteeing aligned reads/writes later
567   if (shift > 0) {
568     gradInput -= shift;
569     output -= shift;
570     gradOutput -= shift;
571     size += shift;
572 
573     if (threadIdx.x >= shift) {
574       gradInput[offset] = epilogue(gradOutput[offset], output[offset]);
575     }
576     size -= blockDim.x;
577     gradInput += blockDim.x;
578     output += blockDim.x;
579     gradOutput += blockDim.x;
580   }
581 
582   const index_t last = size % (ILP * blockDim.x);
583 
584   scalar_t dX[ILP];
585   gradInputT *dX_v = reinterpret_cast<gradInputT*>(&dX);
586 
587   outscalar_t Y[ILP];
588   outputT *Y_v = reinterpret_cast<outputT*>(&Y);
589 
590   outscalar_t dY[ILP];
591   outputT *dY_v = reinterpret_cast<outputT*>(&dY);
592 
593   for (; offset * ILP < (size - last); offset += blockDim.x) {
594     *Y_v = reinterpret_cast<const outputT*>(output)[offset];
595     *dY_v = reinterpret_cast<const outputT*>(gradOutput)[offset];
596 
597     #pragma unroll
598     for (int j = 0; j < ILP; ++j) {
599       dX[j] = epilogue(dY[j], Y[j]);
600     }
601 
602     reinterpret_cast<gradInputT*>(gradInput)[offset] = *dX_v;
603   }
604 
605   offset = size - last + threadIdx.x;
606   for (; offset < size; offset += blockDim.x) {
607     gradInput[offset] = epilogue(gradOutput[offset], output[offset]);
608   }
609 }
610 
611 /**
612  * This will apply the Epilogue with non-vectorized reads & writes for the general case
613  */
614 template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
615 __device__ __forceinline__ void
616 WriteFpropResults(
617              int classes,
618              const scalar_t *input,
619              outscalar_t *output,
620              Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
621   for (int offset = threadIdx.x; offset < classes; offset += blockDim.x) {
622     output[offset] = epilogue(input[offset]);
623   }
624 }
625 
626 template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue, typename index_t>
627 __device__ __forceinline__ void
628 WriteBpropResults(
629              int classes,
630              scalar_t *gradInput,
631              const outscalar_t *output,
632              const outscalar_t *gradOutput,
633              Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
634 
635   index_t offset = threadIdx.x;
636 
637   index_t last = classes % (ILP * blockDim.x);
638 
639   for (; offset < classes - last; offset += blockDim.x * ILP) {
640     outscalar_t tmpOutput[ILP];
641     outscalar_t tmpGradOutput[ILP];
642 
643     #pragma unroll
644     for (int j = 0; j < ILP; ++j) {
645       tmpOutput[j] = output[offset + j * blockDim.x];
646       tmpGradOutput[j] = gradOutput[offset + j * blockDim.x];
647     }
648 
649     #pragma unroll
650     for (int j = 0; j < ILP; ++j) {
651       gradInput[offset + j * blockDim.x] = epilogue(tmpGradOutput[j], tmpOutput[j]);
652     }
653   }
654 
655   // Remainder - no ILP
656   for (; offset < classes; offset += blockDim.x) {
657     gradInput[offset] = epilogue(gradOutput[offset], output[offset]);
658   }
659 }
660 
661 template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
662 __global__ void
663 cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes)
664 {
665   extern __shared__ unsigned char smem[];
666   auto sdata = reinterpret_cast<accscalar_t*>(smem);
667 
668   // forward pointers to batch[blockIdx.x]
669   // each block handles a sample in the mini-batch
670   input += static_cast<int64_t>(blockIdx.x) * classes;
671   output += static_cast<int64_t>(blockIdx.x) * classes;
672 
673   const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
674   const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t);
675 
676   // find the max
677   accscalar_t threadMax = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>(
678     shift, input, classes, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
679   accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(sdata, threadMax,
680     Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
681 
682   // reduce all values
683   accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
684     shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
685   accscalar_t sumAll = blockReduceWarp<Add, accscalar_t>(sdata, threadExp,
686     Add<accscalar_t>(), static_cast<accscalar_t>(0));
687 
688   Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
689 
690   if (shift == output_shift) {
691     WriteFpropResultsVectorized<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>(classes, shift, input, output, epilogue);
692   } else {
693     WriteFpropResults<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>(classes, input, output, epilogue);
694   }
695 }
696 
697 template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
698   template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
699 __global__ void
700 cunn_SoftMaxForwardSmem(outscalar_t *output, const scalar_t *input, index_t classes)
701 {
702   // Each thread block processes a sample in the batch
703   input += static_cast<int64_t>(blockIdx.x) * classes;
704   output += static_cast<int64_t>(blockIdx.x) * classes;
705 
706   accscalar_t threadMax = -at::numeric_limits<accscalar_t>::max();
707   accscalar_t threadExp = static_cast<accscalar_t>(0);
708 
709   // The first smem segment is used to cache input values and the last
710   // segment is used for thread block reductions
711   extern __shared__ unsigned char smem[];
712   auto smem_input_cache = reinterpret_cast<scalar_t*>(smem);
713   auto smem_reduction_cache = reinterpret_cast<accscalar_t*>(smem +
714     classes * sizeof(scalar_t));
715 
716   using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>;
717   const LoadT* const input_vec_ptr = reinterpret_cast<const LoadT*>(input);
718   LoadT* const smem_input_cache_vec_ptr = reinterpret_cast<LoadT*>(smem_input_cache);
719 
720   // Download inputs to shared memory while doing the first step
721   // in max calculation
722   MaxFloat<scalar_t, accscalar_t> maxFunc;
723   for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
724     LoadT crnt_vec = input_vec_ptr[offset];
725     smem_input_cache_vec_ptr[offset] = crnt_vec;
726 
727     #pragma unroll
728     for (int i = 0; i < ILP; ++i) {
729       threadMax = maxFunc(threadMax, crnt_vec.val[i]);
730     }
731   }
732 
733   accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(smem_reduction_cache, threadMax,
734     Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
735 
736   // Reload input from shared memory to compute the sum. The previous
737   // reduce has performed a __syncthreads() so the smem contents are populated.
738   SumExpFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
739   for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
740     LoadT crnt_vec = smem_input_cache_vec_ptr[offset];
741 
742     #pragma unroll
743     for (int i = 0; i < ILP; ++i) {
744       threadExp = sumExpFunc(threadExp, crnt_vec.val[i]);
745     }
746   }
747 
748   accscalar_t sumAll = blockReduceWarp<Add, accscalar_t>(smem_reduction_cache, threadExp,
749     Add<accscalar_t>(), static_cast<accscalar_t>(0));
750 
751   Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
752 
753   // Use vectorized stores to save the output
754   using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>;
755   StoreT* output_vec_ptr = reinterpret_cast<StoreT*>(output);
756   for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
757     LoadT crnt_vec = smem_input_cache_vec_ptr[offset];
758     StoreT out_vec;
759 
760     #pragma unroll
761     for (int i = 0; i < ILP; ++i) {
762       out_vec.val[i] = epilogue(crnt_vec.val[i]);
763     }
764 
765     output_vec_ptr[offset] = out_vec;
766   }
767 }
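// A rough shared-memory budget for cunn_SoftMaxForwardSmem (the launch site below checks
// it against sharedMemPerBlock before picking this kernel):
//   smem_sz = classes * sizeof(scalar_t) + (blockDim.x / C10_WARP_SIZE) * sizeof(accscalar_t)
// e.g. float input, classes = 8192, a 1024-thread block: 32 KiB + 128 B, which fits the
// common 48 KiB default limit; larger rows fall back to cunn_SoftMaxForward.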
768 
769 C10_DEVICE bool inline is_32bit_representable(const int64_t value) {
770   return value < static_cast<int64_t>(std::numeric_limits<int32_t>::max());
771 }
772 
773 template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
774 __global__ void
775 cunn_SoftMaxBackward(scalar_t *gradInput, const outscalar_t *output, const outscalar_t *gradOutput, int64_t classes)
776 {
777   using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>;
778   using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>;
779 
780   extern __shared__ unsigned char smem[];
781   auto sdata = reinterpret_cast<accscalar_t*>(smem);
782   gradInput += static_cast<int64_t>(blockIdx.x) * classes;
783   output += static_cast<int64_t>(blockIdx.x) * classes;
784   gradOutput += static_cast<int64_t>(blockIdx.x) * classes;
785 
786   const int64_t shift = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
787   const int64_t output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t);
788   const int64_t grad_output_shift = ((uint64_t)gradOutput) % ALIGN_BYTES / sizeof(outscalar_t);
789 
790   const bool can_use_32bit_indexing = is_32bit_representable(shift) && is_32bit_representable(output_shift) && is_32bit_representable(grad_output_shift) && is_32bit_representable(classes);
791   accscalar_t threadSum;
792   if (can_use_32bit_indexing) {
793     threadSum = ilpReduce<AddFloat, ILP, outscalar_t, accscalar_t, int32_t>(
794         static_cast<int32_t>(grad_output_shift), gradOutput, classes, AddFloat<outscalar_t, accscalar_t>(), accscalar_t(0));
795   } else {
796     threadSum = ilpReduce<AddFloat, ILP, outscalar_t, accscalar_t, int64_t>(
797         grad_output_shift, gradOutput, classes, AddFloat<outscalar_t, accscalar_t>(), accscalar_t(0));
798   }
799   accscalar_t sum_k = blockReduce<Add, accscalar_t>(
800         sdata, threadSum, Add<accscalar_t>(), accscalar_t(0));
801 
802   Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum_k);
803 
804   if (shift == output_shift && shift == grad_output_shift) {
805     if (can_use_32bit_indexing) {
806       WriteBpropResultsVectorized<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue, int32_t>(classes, static_cast<int32_t>(shift), gradInput, output, gradOutput, epilogue);
807     } else {
808       WriteBpropResultsVectorized<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue, int64_t>(classes, shift, gradInput, output, gradOutput, epilogue);
809     }
810   } else {
811     if (can_use_32bit_indexing) {
812       WriteBpropResults<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue, int32_t>(classes, gradInput, output, gradOutput, epilogue);
813     } else {
814       WriteBpropResults<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue, int64_t>(classes, gradInput, output, gradOutput, epilogue);
815     }
816   }
817 }
818 
819 template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
820 Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float, const Tensor& output){
821   if (half_to_float) {
822     TORCH_CHECK(input_.scalar_type() == ScalarType::Half, "conversion is supported for Half type only");
823   }
824   auto input = input_.contiguous();
825   static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float");
826   if (input.dim() == 0) input = input.view(1);
827   int64_t dim = maybe_wrap_dim(dim_, input.dim());
828   TORCH_CHECK(dim >=0 && dim < input.dim(), "dim must be non-negative and less than input dimensions");
829   int64_t outer_size = 1;
830   int64_t dim_size = input.size(dim);
831 
832   if (input.numel() > 0) {
833     int64_t inner_size = 1;
834     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
835     for (int64_t i = 0; i < dim; ++i)
836       outer_size *= input.size(i);
837     for (int64_t i = dim + 1; i < input.dim(); ++i)
838       inner_size *= input.size(i);
839     // This kernel spawns a block per each element in the batch.
840     // XXX: it assumes that inner_size == 1
841 
842     if (inner_size == 1) {
843       dim3 grid(outer_size);
844       AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] {
845         using accscalar_t = acc_type<scalar_t, true>;
846         if (!half_to_float) {
847           auto output_ptr = output.mutable_data_ptr<scalar_t>();
848           auto input_ptr = input.const_data_ptr<scalar_t>();
849           if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
850             int64_t remaining = outer_size;
851             int64_t chunk_size = (1L << 30L) / dim_size;
852             while(remaining > 0) {
853               dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, is_log_softmax, false>(
854                 output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size), nullptr/* not masked */);
855               input_ptr += chunk_size * dim_size;
856               output_ptr += chunk_size * dim_size;
857               remaining -= chunk_size;
858             }
859           } else {
860             constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
861             dim3 block = SoftMaxForward_getBlockSize(dim_size);
862             size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
863             auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
864               smem_reduction_sz) / sizeof(scalar_t);
865 
866             bool can_use_smem = (size_t) dim_size < max_elements_per_smem;
867             can_use_smem &= !(reinterpret_cast<uintptr_t>(input_ptr) % ALIGN_BYTES);
868             can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
869             can_use_smem &= !(dim_size % ILP);
870 
871             if (can_use_smem) {
872               size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
873               cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
874                 <<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
875             } else {
876               cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
877                 <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
878             }
879 
880             C10_CUDA_KERNEL_LAUNCH_CHECK();
881           }
882         } else {
883           auto output_ptr = output.mutable_data_ptr<accscalar_t>();
884           auto input_ptr = input.const_data_ptr<scalar_t>();
885           if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
886             int64_t remaining = outer_size;
887             int64_t chunk_size = (1<<30) / dim_size;
888             while(remaining > 0) {
889               dispatch_softmax_forward<scalar_t, accscalar_t, accscalar_t, is_log_softmax, false>(
890                   output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size), nullptr/* not masked */);
891               input_ptr += chunk_size * dim_size;
892               output_ptr += chunk_size * dim_size;
893               remaining -= chunk_size;
894             }
895           } else {
896             constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
897             dim3 block = SoftMaxForward_getBlockSize(dim_size);
898             size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
899             auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
900               smem_reduction_sz) / sizeof(scalar_t);
901 
902             bool can_use_smem = (size_t) dim_size < max_elements_per_smem;
903             can_use_smem &= !(reinterpret_cast<uintptr_t>(input_ptr) % ALIGN_BYTES);
904             can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
905             can_use_smem &= !(dim_size % ILP);
906 
907             if (can_use_smem) {
908               size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
909               cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
910                 <<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
911             } else {
912               cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
913                 <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
914             }
915 
916             C10_CUDA_KERNEL_LAUNCH_CHECK();
917           }
918         }
919       });
920     // This kernel runs in a 2D grid: dimension x is parallel over outer_size, while dimension y
921     // (together with blockDim.y) is parallel over inner_size. Reductions over dim use blockDim.x
922     // threads, which degenerates to a single thread when inner_size is large (see SpatialSoftMax_getBlockSize).
923     } else {
924       uint32_t smem_size;
925       dim3 grid, block;
926       AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] {
927         using accscalar_t = acc_type<scalar_t, true>;
928         AT_DISPATCH_INDEX_TYPES(
929             at::native::canUse32BitIndexMath(input, INT_MAX) ? ScalarType::Int : ScalarType::Long,
930         "host_softmax_launcher", [&] {
931             if (!half_to_float) {
932                 SpatialSoftMax_getLaunchSizes<accscalar_t>(
933                     &cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, index_t, Epilogue>,
934                     outer_size, dim_size, inner_size,
935                     grid, block, smem_size);
936                 cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, index_t, Epilogue>
937                   <<<grid, block, smem_size, stream>>>(
938                   output.mutable_data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), outer_size, dim_size, inner_size);
939                 C10_CUDA_KERNEL_LAUNCH_CHECK();
940             } else {
941                 SpatialSoftMax_getLaunchSizes<accscalar_t>(
942                     &cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, index_t, Epilogue>,
943                     outer_size, dim_size, inner_size,
944                     grid, block, smem_size);
945                 cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, index_t, Epilogue>
946                   <<<grid, block, smem_size, stream>>>(
947                   output.mutable_data_ptr<accscalar_t>(), input.const_data_ptr<scalar_t>(), outer_size, dim_size, inner_size);
948                 C10_CUDA_KERNEL_LAUNCH_CHECK();
949             }
950          });
951       });
952     }
953   }
954   return output;
955 }
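// Dispatch summary for host_softmax (thresholds as written above): when inner_size == 1 and
// dim_size <= 1024 with dim_size * sizeof(scalar_t) <= 4096 (e.g. up to 1024 half/float or
// 512 double elements per row), the persistent dispatch_softmax_forward kernel is used in
// chunks of at most 2^30 / dim_size rows; otherwise one block per row runs
// cunn_SoftMaxForwardSmem or cunn_SoftMaxForward; and when inner_size > 1 the spatial
// kernel handles the reduction instead.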
956 
957 template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
958 void host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t dim_, bool half_to_float, const Tensor &gI){
959   int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
960   if (grad_.numel() == 0) {
961     return;
962   }
963   auto grad = grad_.contiguous();
964   static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float");
965   if (grad.dim() == 0) grad = grad.view(1);
966   TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions");
967   auto output = output_.contiguous();
968   if (output.dim() == 0) output = output.view(1);
969   int64_t outer_size = 1;
970   int64_t dim_size = output.size(dim);
971   int64_t inner_size = 1;
972   for (int64_t i = 0; i < dim; ++i)
973     outer_size *= output.size(i);
974   for (int64_t i = dim + 1; i < output.dim(); ++i)
975     inner_size *= output.size(i);
976 // See descriptions of kernels above.
977   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
978   if (inner_size == 1) {
979     dim3 grid(outer_size);
980     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] {
981     using accscalar_t = acc_type<scalar_t, true>;
982     if (!half_to_float) {
983       if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
984         auto gI_ptr = gI.mutable_data_ptr<scalar_t>();
985         auto grad_ptr = grad.const_data_ptr<scalar_t>();
986         auto output_ptr = output.const_data_ptr<scalar_t>();
987         int64_t remaining = outer_size;
988         int64_t chunk_size = (1<<30) / dim_size;
989         while(remaining > 0) {
990           dispatch_softmax_backward<scalar_t, scalar_t, accscalar_t, is_log_softmax, false /* masked_softmax */>(
991             gI_ptr, grad_ptr, output_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size));
992           gI_ptr += chunk_size * dim_size;
993           grad_ptr += chunk_size * dim_size;
994           output_ptr += chunk_size * dim_size;
995           remaining -= chunk_size;
996         }
997       } else {
998         constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
999         dim3 block = SoftMax_getBlockSize(ILP, dim_size);
1000         cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
1001          <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
1002             gI.mutable_data_ptr<scalar_t>(), output.const_data_ptr<scalar_t>(), grad.const_data_ptr<scalar_t>(), dim_size
1003         );
1004         C10_CUDA_KERNEL_LAUNCH_CHECK();
1005       }
1006     } else {
1007       if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
1008         auto gI_ptr = gI.mutable_data_ptr<scalar_t>();
1009         auto grad_ptr = grad.const_data_ptr<accscalar_t>();
1010         auto output_ptr = output.const_data_ptr<accscalar_t>();
1011         int64_t remaining = outer_size;
1012         int64_t chunk_size = (1<<30) / dim_size;
1013         while(remaining > 0) {
1014           dispatch_softmax_backward<accscalar_t, scalar_t, accscalar_t, is_log_softmax, false /* masked_softmax */>(
1015             gI_ptr, grad_ptr, output_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size));
1016           gI_ptr += chunk_size * dim_size;
1017           grad_ptr += chunk_size * dim_size;
1018           output_ptr += chunk_size * dim_size;
1019           remaining -= chunk_size;
1020         }
1021       } else {
1022         constexpr int ILP = sizeof(float4) / sizeof(accscalar_t);
1023         dim3 block = SoftMax_getBlockSize(ILP, dim_size);
1024         cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
1025          <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
1026             gI.mutable_data_ptr<scalar_t>(), output.const_data_ptr<accscalar_t>(), grad.const_data_ptr<accscalar_t>(), dim_size
1027         );
1028         C10_CUDA_KERNEL_LAUNCH_CHECK();
1029       }
1030     }
1031     });
1032   } else {
1033     uint32_t smem_size;
1034     dim3 grid, block;
1035     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] {
1036       using accscalar_t = acc_type<scalar_t, true>;
1037       if (!half_to_float) {
1038           SpatialSoftMax_getLaunchSizes<accscalar_t>(
1039               &cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, scalar_t, Epilogue>,
1040               outer_size, dim_size, inner_size,
1041               grid, block, smem_size);
1042 
1043           cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, scalar_t, Epilogue>
1044             <<<grid, block, smem_size, stream>>>(
1045               gI.mutable_data_ptr<scalar_t>(), output.const_data_ptr<scalar_t>(), grad.const_data_ptr<scalar_t>(),
1046               outer_size, dim_size, inner_size
1047           );
1048           C10_CUDA_KERNEL_LAUNCH_CHECK();
1049       } else {
1050           SpatialSoftMax_getLaunchSizes<accscalar_t>(
1051               &cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, accscalar_t, Epilogue>,
1052               outer_size, dim_size, inner_size,
1053               grid, block, smem_size);
1054 
1055           cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, accscalar_t, Epilogue>
1056             <<<grid, block, smem_size, stream>>>(
1057               gI.mutable_data_ptr<scalar_t>(), output.const_data_ptr<accscalar_t>(), grad.const_data_ptr<accscalar_t>(),
1058               outer_size, dim_size, inner_size
1059           );
1060           C10_CUDA_KERNEL_LAUNCH_CHECK();
1061       }
1062     });
1063   }
1064 }
1065 }
1066 
1067 TORCH_IMPL_FUNC(log_softmax_cuda_out) (
1068   const Tensor &input,
1069   const int64_t dim,
1070   const bool half_to_float,
1071   const Tensor &output) {
1072   host_softmax<LogSoftMaxForwardEpilogue,true>(input, dim, half_to_float, output);
1073 }
1074 
1075 TORCH_IMPL_FUNC(log_softmax_backward_cuda_out) (
1076   const Tensor& grad,
1077   const Tensor& output,
1078   int64_t dim,
1079   ScalarType input_dtype,
1080   const Tensor& grad_input) {
1081   bool half_to_float = grad.scalar_type() != input_dtype;
1082   if (half_to_float) {
1083     TORCH_CHECK(
1084         (grad.scalar_type() == ScalarType::Float &&
1085          input_dtype == ScalarType::Half),
1086         "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
1087   }
1088   host_softmax_backward<LogSoftMaxBackwardEpilogue, true>(grad, output, dim, half_to_float, grad_input);
1089 }
1090 
1091 TORCH_IMPL_FUNC(softmax_cuda_out) (
1092   const Tensor &input,
1093   const int64_t dim,
1094   const bool half_to_float,
1095   const Tensor &output) {
1096   host_softmax<SoftMaxForwardEpilogue,false>(input, dim, half_to_float, output);
1097 }
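// A minimal host-side sketch of how these entry points are reached through the dispatcher
// (illustrative; exact routing depends on the build and autograd):
//   at::Tensor x  = at::randn({32, 1000}, at::kCUDA);
//   at::Tensor p  = at::softmax(x, /*dim=*/-1);       // ends up in softmax_cuda_out
//   at::Tensor lp = at::log_softmax(x, /*dim=*/-1);   // ends up in log_softmax_cuda_out
// and the corresponding backward ops land in the *_backward_cuda_out functions nearby.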
1098 
1099 TORCH_IMPL_FUNC(softmax_backward_cuda_out)
1100 (const Tensor& grad,
1101  const Tensor& output,
1102  int64_t dim,
1103  ScalarType input_dtype,
1104  const Tensor& grad_input) {
1105   bool half_to_float = grad.scalar_type() != input_dtype;
1106   if (half_to_float) {
1107     TORCH_CHECK(
1108         (grad.scalar_type() == ScalarType::Float &&
1109          input_dtype == ScalarType::Half),
1110         "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
1111   }
1112   Tensor tmp = grad * output;
1113   host_softmax_backward<SoftMaxBackwardEpilogue, false>(tmp, output, dim, half_to_float, grad_input);
1114 }
1115 
1116 Tensor masked_softmax_cuda(const Tensor& input_, const Tensor& mask_, const std::optional<int64_t> dim_, const std::optional<int64_t> mask_type_) {
1117   Tensor output = at::empty_like(input_, input_.options());
1118   TORCH_CHECK(mask_.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor");
1119 
1120   TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
1121   int64_t mask_type = mask_type_.value();
1122   TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)");
1123 
1124   // If input is [B, H, T, T] and mask is [B, T]
1125   // we have special fast kernel
1126   // mask_type == 1 => mask_ is a src_key_padding_mask
1127   bool is_BxT_mask = (mask_type == 1) && (input_.dim() == 4 && mask_.dim() == 2 && input_.size(0) == mask_.size(0) && input_.size(2) == mask_.size(1) && input_.size(3) == mask_.size(1));
1128 
1129   // If input is [B, H, T, T] and mask is [T, T]
1130   // expand mask to [B, H, T, T] and treat it like regular mask
1131   // TODO We should have special fast kernel for TxT mask as well
1132   // mask_type == 0 => mask_ is a src_mask
1133   bool is_TxT_mask = (mask_type == 0) && input_.dim() == 4 && mask_.dim() == 2 && input_.size(3) == mask_.size(1) && input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1);
1134   // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
1135   TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input. mask: ", mask_.sizes(), " input: ", input_.sizes());
1136 
1137   auto input = input_.dim() == 0 ? input_.view(1) : input_;
1138   auto mask = mask_.dim() == 0 ? mask_.view(1) : mask_;
1139   if (is_TxT_mask) {
1140     mask = mask.expand(input.sizes());
1141   }
1142   int64_t dim = dim_.has_value() ? dim_.value() : input.dim() - 1;
1143 
1144   int softmax_elements = input.size(dim);
1145   // Persistent softmax is only supported when all of the conditions are held:
1146   //     1) softmax_elements <= 1024
1147   //     2) softmax_elements * input.element_size() <= 4096
1148   //     3) mask.is_contiguous()
1149   //     4) dim == input.dim() - 1
1150   // Otherwise, we fallback to vanilla softmax (where we do not support transformer_mask since converting the mask is expensive)
1151   if (softmax_elements > 1024 || softmax_elements * input.element_size() > 4096 || !mask.is_contiguous() || dim < input.dim()-1) {
1152     if (is_BxT_mask) {
1153       mask = mask.view({mask_.size(0), 1, 1, mask_.size(1)}).expand(input.sizes());
1154     }
1155     AT_DISPATCH_FLOATING_TYPES_AND2(
1156       ScalarType::Half,
1157       ScalarType::BFloat16,
1158       input.scalar_type(),
1159       "masked_softmax",
1160       [&] {
1161         output = at::softmax(input.masked_fill(mask, -std::numeric_limits<scalar_t>::infinity()), dim);
1162       });
1163     return output;
1164   }
1165   int batch_count = input.numel() / softmax_elements;
1166   int chunk_size = input.numel() / input.size(0);
1167   if (is_BxT_mask) {
1168     // Only support when num_heads is even in transformer
1169     TORCH_CHECK(input.size(1) % 2 == 0, "Only support when num_heads is even in transformer");
1170     AT_DISPATCH_FLOATING_TYPES_AND2(
1171       ScalarType::Half,
1172       ScalarType::BFloat16,
1173       input.scalar_type(),
1174       "masked_softmax",
1175       [&] {
1176         using accscalar_t = acc_type<scalar_t, true>;
1177         dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, false/* is_log_softmax */, true/* is_masked */>(
1178           output.mutable_data_ptr<scalar_t>(),    // dst
1179           input.const_data_ptr<scalar_t>(),       // src
1180           softmax_elements,
1181           softmax_elements,
1182           batch_count,
1183           mask.const_data_ptr<bool>(),
1184           chunk_size,
1185           true // is_transformer_mask
1186         );
1187       });
1188 
1189   } else {
1190     AT_DISPATCH_FLOATING_TYPES_AND2(
1191       ScalarType::Half,
1192       ScalarType::BFloat16,
1193       input.scalar_type(),
1194       "masked_softmax",
1195       [&] {
1196         using accscalar_t = acc_type<scalar_t, true>;
1197         dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, false/* is_log_softmax */, true/* is_masked */>(
1198           output.mutable_data_ptr<scalar_t>(),    // dst
1199           input.const_data_ptr<scalar_t>(),       // src
1200           softmax_elements,
1201           softmax_elements,
1202           batch_count,
1203           mask.const_data_ptr<bool>()
1204         );
1205       });
1206   }
1207   return output;
1208 }
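// Shape summary for masked_softmax_cuda (mask_type as checked above), for an attention-style
// input of shape [B, H, T, T]:
//   mask_type == 1, mask [B, T]  -> fast key-padding path (is_BxT_mask)
//   mask_type == 0, mask [T, T]  -> mask expanded to [B, H, T, T]  (is_TxT_mask)
//   mask_type == 2               -> mask must match the input shape exactly
// Rows longer than 1024 elements (or 4096 bytes), non-contiguous masks, or dim not being the
// last dimension fall back to at::softmax over input.masked_fill(mask, -inf).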
1209 
1210 Tensor masked_softmax_backward_cuda(
1211     const Tensor& grad_,
1212     const Tensor& output_,
1213     const Tensor& mask_,
1214     const std::optional<int64_t> dim_) {
1215   Tensor grad_input = at::empty_like(grad_, grad_.options());
1216   if (grad_.numel() == 0) {
1217     return grad_input;
1218   }
1219 
1220   auto grad = grad_.contiguous();
1221   auto output = output_.contiguous();
1222   auto mask = mask_.contiguous();
1223   int64_t dim = dim_.has_value() ? maybe_wrap_dim(dim_.value(), output.dim()) : output.dim() - 1;
1224 
1225   grad = grad.dim() == 0 ? grad.view(1) : grad;
1226   mask = mask.dim() == 0 ? mask.view(1) : mask;
1227   output = output.dim() == 0 ? output.view(1) : output;
1228 
1229   TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions");
1230   TORCH_CHECK(grad.sizes() == mask.sizes(), "Mask shape should match grad shape");
1231   TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor");
1232 
1233   int softmax_elements = output.size(dim);
1234   int64_t batch_count = grad.numel() / softmax_elements;
1235 
1236   if (softmax_elements > 1024 || softmax_elements * grad.element_size() > 4096 || dim < grad.dim()-1) {
1237     AT_DISPATCH_FLOATING_TYPES_AND2(
1238       ScalarType::Half,
1239       ScalarType::BFloat16,
1240       grad_input.scalar_type(),
1241       "masked_softmax_backward",
1242       [&] {
1243         grad_input = at::_softmax_backward_data(
1244           grad,
1245           output.masked_fill(mask, 0),
1246           dim,
1247           grad.scalar_type()
1248         );
1249       });
1250   } else {
1251     grad = grad * output;
1252     AT_DISPATCH_FLOATING_TYPES_AND2(
1253       ScalarType::Half,
1254       ScalarType::BFloat16,
1255       grad_input.scalar_type(),
1256       "masked_softmax_backward",
1257       [&] {
1258         using accscalar_t = acc_type<scalar_t, true>;
1259         dispatch_softmax_backward<scalar_t, scalar_t, accscalar_t, false, true /* masked_softmax */>(
1260           grad_input.mutable_data_ptr<scalar_t>(),  // gI_ptr
1261           grad.const_data_ptr<scalar_t>(),  // grad_ptr
1262           output.const_data_ptr<scalar_t>(),  // output_ptr
1263           softmax_elements,  // softmax_elements
1264           softmax_elements,   // softmax_elements_stride
1265           batch_count,  // batch_count
1266           mask.const_data_ptr<bool>()  /* mask */
1267         );
1268       });
1269   }
1270   return grad_input;
1271 }
1272 
1273 } // namespace at::native
1274