1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/cuda/CUDAContext.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorUtils.h>
6 #include <ATen/TensorOperators.h>
7 #include <ATen/WrapDimUtils.h>
8 #include <c10/macros/Macros.h>
9
10 #include <ATen/AccumulateType.h>
11 #include <ATen/cuda/NumericLimits.cuh>
12 #include <type_traits>
13
14 #include <ATen/native/cuda/Loops.cuh>
15 #include <ATen/native/cuda/MemoryAccess.cuh>
16 #include <ATen/native/cuda/PersistentSoftmax.cuh>
17 #include <ATen/native/IndexingUtils.h>
18 #include <ATen/native/cuda/block_reduce.cuh>
19
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #include <ATen/NativeFunctions.h>
23 #else
24 #include <ATen/ops/_masked_softmax_native.h>
25 #include <ATen/ops/_log_softmax_native.h>
26 #include <ATen/ops/_log_softmax_backward_data_native.h>
27 #include <ATen/ops/_softmax_native.h>
28 #include <ATen/ops/_softmax_backward_data_native.h>
29 #include <ATen/ops/softmax.h>
30 #include <ATen/ops/_softmax_backward_data.h>
31 #endif
32
33 namespace at::native {
34
35 namespace {
36
// Alignment boundary, in bytes, used by the regular (inner_size == 1) kernels:
// a row's "shift" is how many elements its start lies past this boundary.
constexpr int ALIGN_BYTES{16};
38
// Log-softmax forward epilogue: maps an input element x to
//   x - max_input - log(sum)
// where max_input and sum are supplied by the caller (in this file: the row
// maximum and the row sum of exp(x - max_input)).
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
    : max_input(max_input), logsum(std::log(sum)) {}

  __device__ __forceinline__ OutT operator()(T input) const {
    // Widen first, then subtract the two accumulated terms in order.
    const AccumT centered = static_cast<AccumT>(input) - max_input;
    return static_cast<OutT>(centered - logsum);
  }

  const AccumT max_input;  // value subtracted from every element
  const AccumT logsum;     // log of the sum passed to the constructor
};
51
// Log-softmax backward epilogue: given the upstream gradient g and the forward
// output y (a log-probability), produces g - exp(y) * sum, where `sum` is the
// block-reduced sum of the upstream gradient computed by the caller.
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
    : sum(sum) {}

  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    // exp of the log-softmax output recovers the probability.
    const AccumT prob = std::exp(static_cast<AccumT>(output));
    return static_cast<T>(static_cast<AccumT>(gradOutput) - prob * sum);
  }

  const AccumT sum;  // reduced upstream-gradient sum for this row
};
63
// Softmax forward epilogue: maps an input element x to exp(x - max_input) / sum,
// with max_input and sum supplied by the caller (in this file: the row maximum
// and the row sum of exp(x - max_input)).
template<typename T, typename AccumT, typename OutT>
struct SoftMaxForwardEpilogue {
  __device__ __forceinline__ SoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
    : max_input(max_input)
    , sum(sum) {}

  __device__ __forceinline__ OutT operator()(T input) const {
    const AccumT numerator = std::exp(static_cast<AccumT>(input) - max_input);
    return static_cast<OutT>(numerator / sum);
  }

  const AccumT max_input;  // value subtracted before exponentiating
  const AccumT sum;        // normalizer
};
77
// Softmax backward epilogue: produces gradOutput - output * sum.
template<typename T, typename AccumT, typename OutT>
struct SoftMaxBackwardEpilogue {
  __device__ __forceinline__ SoftMaxBackwardEpilogue(AccumT sum)
    : sum(sum) {}

  // XXX: the gradOutput received here is really gradOutput * output
  // (see the cmul in SoftMax_updateGradInput).
  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    const AccumT scaled = static_cast<AccumT>(output) * sum;
    return static_cast<T>(static_cast<AccumT>(gradOutput) - scaled);
  }

  const AccumT sum;  // reduced gradient sum for this row
};
91
92
93
94
95 ////////////////////////////////////////////////////////////////////////////////
96 // Spatial kernel (fast with large inner_size and small dim_size)
97 ////////////////////////////////////////////////////////////////////////////////
// Let's assume that our input has been flattened to have only three dimensions:
//   outer x dim x inner
// The spatial algorithm tries to parallelize along all of them.
// Within a 2d block, threadIdx.y parallelizes over dim slices (one inner index
// per slice), and the threads that share a threadIdx.y value speed up the
// reduction over dim by splitting it along the x axis.
// The 2d grid parallelizes the inner dimension over its y axis and the outer
// dimension over its x axis.
// Chooses a 2D grid for the spatial kernels: the y extent tiles the inner
// dimension (block.y inner positions per block), the x extent tiles the outer
// dimension, and together they aim for roughly max_active_blocks blocks.
//
// All arithmetic is done in 64 bits and both extents are kept >= 1:
//   - inner_blocks used to be truncated to 32 bits *before* being clamped, so
//     e.g. inner_size == 2^32 with block.y == 1 became 0 and the division below
//     crashed;
//   - max_active_blocks == 0 (or an empty extent) would otherwise also lead to
//     a division by zero or a zero-sized grid dimension.
inline dim3 SpatialSoftMax_getGridSize(
    dim3 block, uint32_t max_active_blocks,
    uint64_t outer_size, uint64_t inner_size) {
  // First, tile as many blocks as we can over the y axis.
  uint64_t inner_blocks = (inner_size + block.y - 1) / block.y;
  if (inner_blocks > max_active_blocks) {
    inner_blocks = max_active_blocks;
  }
  if (inner_blocks == 0) {
    inner_blocks = 1;
  }
  // Fill the x axis with as many blocks as we can fit (a little more is ok too).
  uint64_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
  if (outer_blocks > outer_size) {
    outer_blocks = outer_size;
  }
  if (outer_blocks == 0) {
    outer_blocks = 1;
  }
  // Both values are now <= max(max_active_blocks, 1), so they fit in 32 bits.
  return dim3(static_cast<uint32_t>(outer_blocks), static_cast<uint32_t>(inner_blocks));
}
117
// Upper bound on threads per block used when sizing launches in this file.
constexpr int max_threads{1024};
119
// Chooses the block shape for the spatial kernels.
//   y: one thread per inner position, capped at max_threads.
//   x: 1, unless the inner extent is small (<= 64 threads) and dim is large
//      (>= 64); then x is widened to a power of two so that x * y stays within
//      max_threads and x does not exceed dim_size, letting several threads
//      cooperate on each reduction over dim.
//
// The clamp is done in 64 bits before narrowing: assigning inner_size straight
// to uint32_t wraps modulo 2^32 (inner_size == 2^32 produced block.y == 0), and
// a wrapped small value could also wrongly enable the cooperative x > 1 path.
inline dim3 SpatialSoftMax_getBlockSize(
    uint64_t dim_size, uint64_t inner_size) {
  uint32_t inner_threads = static_cast<uint32_t>(
      std::min(inner_size, static_cast<uint64_t>(max_threads)));
  uint32_t dim_threads = 1;
  if (inner_threads <= 64 && dim_size >= 64) {
    // Grow past the limit, then step back one power of two.
    while (inner_threads * dim_threads <= max_threads && dim_threads <= dim_size)
      dim_threads *= 2;
    dim_threads /= 2;
  }
  return dim3(dim_threads, inner_threads);
}
132
133
// Computes the launch configuration for a spatial softmax kernel `k`.
//   block     - SpatialSoftMax_getBlockSize(dim_size, inner_size)
//   smem_size - one accscalar_t per thread when block.x > 1 (scratch for the
//               reduction along x), otherwise 0
//   grid      - SpatialSoftMax_getGridSize, targeting the occupancy-derived
//               number of resident blocks per SM times the SM count of the
//               current device.
// NOTE(review): the return value of cudaOccupancyMaxActiveBlocksPerMultiprocessor
// is not checked; if the call fails, max_active_blocks may be left unwritten and
// is then used uninitialized — confirm whether an error check is wanted.
template<typename accscalar_t, typename Kernel>
void SpatialSoftMax_getLaunchSizes(
    Kernel k,
    uint64_t outer_size, uint64_t dim_size, uint64_t inner_size,
    dim3& grid, dim3& block, uint32_t& smem_size) {
  block = SpatialSoftMax_getBlockSize(dim_size, inner_size);
  uint32_t block_threads = block.x * block.y;
  // Shared memory is only needed when several x-threads cooperate on one reduction.
  smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t);
  int max_active_blocks;
#if defined(USE_ROCM) && TORCH_HIP_VERSION < 305
  // HIP function signature is not compatible yet.
  uint32_t max_blocks;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks,
      k, block_threads, smem_size);
  max_active_blocks = max_blocks;
#else
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
      k, block_threads, smem_size);
#endif
  // Per-SM occupancy -> whole-device number of concurrently resident blocks.
  max_active_blocks *= at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, inner_size);
}
156
// Block size for the regular (inner_size == 1) kernels: the smallest power of two
// that reaches min(dim_size / ILP, max_threads) (halved when ILP > 1), but never
// less than one warp.
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t target = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));

  // With vectorized (ILP > 1) accesses, only (ILP-1)/ILP of a buffer can be
  // vectorized in general, and a small block hurts utilisation. Halving the
  // target keeps at least half of the buffer vectorized while still allowing
  // a reasonably large block.
  if (ILP > 1) {
    target /= 2;
  }

  uint64_t threads = 1;
  while (threads < target) {
    threads <<= 1;
  }

  // Launch at least a single warp - the kernel assumes that.
  const uint64_t one_warp = static_cast<uint64_t>(at::cuda::warp_size());
  return dim3(threads < one_warp ? one_warp : threads);
}
175
// Block size for the forward kernels: min(dim_size, max_threads) rounded up to
// the next multiple of C10_WARP_SIZE.
inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) {
  const uint64_t capped = std::min(dim_size, static_cast<uint64_t>(max_threads));

  // Block-wide reductions use warp shuffles, so the block size must be a
  // multiple of C10_WARP_SIZE. max_threads is itself a multiple of
  // C10_WARP_SIZE, so rounding up cannot exceed the limit.
  const uint64_t warps = (capped + C10_WARP_SIZE - 1) / C10_WARP_SIZE;
  return dim3(warps * C10_WARP_SIZE);
}
193
// Addition functor with the interface used by the reductions in this file:
// a call operator and combine() for merging two partials, plus a warp
// shuffle-down hook.
template<typename T>
struct Add {
  __device__ __forceinline__ T combine(T lhs, T rhs) const {
    return lhs + rhs;
  }

  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return combine(lhs, rhs);
  }

  // Needed so a thread block reduction can start with a warp-level reduction.
  __device__ __forceinline__ T warp_shfl_down(T data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
};
210
// Maximum functor with the same interface as Add. Note the comparison form:
// the first argument is returned unless it compares less than the second, so a
// NaN in the second argument is dropped while a NaN in the first is kept.
template<typename T>
struct Max {
  __device__ __forceinline__ T combine(T lhs, T rhs) const {
    return (lhs < rhs) ? rhs : lhs;
  }

  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return combine(lhs, rhs);
  }

  // Needed so a thread block reduction can start with a warp-level reduction.
  __device__ __forceinline__ T warp_shfl_down(T data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
};
227
// Reduces `val` across the threads of a block that share threadIdx.y (i.e. along
// x only); this is not a full block-wide reduction. Every thread returns its
// row's result.
//
// `shared` must hold blockDim.x * blockDim.y elements: row y uses
// shared[y * blockDim.x, (y + 1) * blockDim.x).
// The halving loop covers every element only when blockDim.x is a power of two
// (as produced by SpatialSoftMax_getBlockSize). Contains __syncthreads(), so it
// must be reached by all threads of the block.
template<typename T, template<typename> class ReduceOp>
__forceinline__ __device__
T spatialBlockReduceX(T *shared, T val) {
  ReduceOp<T> reduce_op;
  T* const row = shared + threadIdx.y * blockDim.x;

  // Ensure nobody is still reading the scratch space from a previous call.
  __syncthreads();

  row[threadIdx.x] = val;

  // Tree reduction; each step begins with a barrier so the previous step's
  // writes are visible.
  for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (threadIdx.x < stride) {
      row[threadIdx.x] = reduce_op(row[threadIdx.x], row[threadIdx.x + stride]);
    }
  }

  // Publish row[0] to every thread of the row.
  __syncthreads();

  return row[0];
}
253
// Spatial softmax / log-softmax forward kernel. The contiguous input is viewed as
// outer_size x dim_size x inner_size and reduced along dim; element
// (o, d, i) lives at o * outer_stride + d * dim_stride + i.
//
// Work distribution (as read by the loops below):
//   - blockIdx.x strides over outer slices (step gridDim.x);
//   - blockIdx.y * blockDim.y + threadIdx.y strides over inner positions;
//   - threadIdx.x strides over dim.
// Dynamic shared memory is only touched when blockDim.x > 1 (through
// spatialBlockReduceX, which needs blockDim.x * blockDim.y accscalar_t values).
// NOTE(review): spatialBlockReduceX contains __syncthreads(), so when
// blockDim.x > 1 all threads of a block must run the same number of inner_index
// iterations. This appears to rely on SpatialSoftMax_getBlockSize only choosing
// blockDim.x > 1 when the whole inner extent fits in one block — confirm
// against the launch code.
template <typename scalar_t, typename accscalar_t, typename outscalar_t, typename index_t, template<typename, typename, typename> class Epilogue>
__global__ void cunn_SpatialSoftMaxForward(
    outscalar_t *output, const scalar_t *input,
    index_t outer_size, index_t dim_size, index_t inner_size)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  const index_t outer_stride = inner_size * dim_size;
  const index_t dim_stride = inner_size;

  for (index_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
    const index_t outer_offset = outer_index * outer_stride;
    for (index_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) {
      // Offset of element (outer_index, 0, inner_index); step by dim_stride along dim.
      const index_t data_offset = outer_offset + inner_index;
      ////////////////////////////////////////////////////////////
      // These two blocks are really equivalent, but specializing on
      // blockDim.x == 1 makes the kernel faster when it's unused.
      // I didn't want to thread an extra template parameter, and nvcc
      // seems to be smart enough to hoist the if outside of the loops.
      ////////////////////////////////////////////////////////////

      if (blockDim.x > 1) {
        // Pass 1: maximum along dim (per-thread partials, then reduced over x).
        accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
        for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
          max_input = Max<accscalar_t>()(max_input, value);
        }
        max_input = spatialBlockReduceX<accscalar_t, Max>(sdata,max_input);

        // Pass 2: sum of exp(x - max) along dim.
        accscalar_t sum = 0;
        for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
                 - max_input);
        sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum);

        // Pass 3: write Epilogue(max, sum) applied to every element.
        Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum);
        for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
      } else {
        // Single thread per dim slice: same three passes without any reduction
        // across threads (threadIdx.x is 0 and blockDim.x is 1 here).
        accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
        for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
          max_input = Max<accscalar_t>()(max_input, value);
        }
        accscalar_t sum = 0;
        for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
                 - max_input);
        Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum);
        for (index_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
      }
    }
  }
}
309
310
311
// Spatial softmax / log-softmax backward kernel over the same
// outer_size x dim_size x inner_size view and work distribution as
// cunn_SpatialSoftMaxForward. For each (outer, inner) slice it sums gradOutput
// along dim, then writes gradInput = Epilogue(sum)(gradOutput, output).
// Dynamic shared memory is only touched when blockDim.x > 1 (spatialBlockReduceX).
// NOTE(review): all offsets here are uint32_t, so
// outer_size * dim_size * inner_size must fit in 32 bits; presumably the host
// guarantees this — confirm. The same __syncthreads() caveat as the forward
// kernel applies when blockDim.x > 1.
template <typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void cunn_SpatialSoftMaxBackward(
    scalar_t *gradInput, const outscalar_t *output, const outscalar_t *gradOutput,
    uint32_t outer_size, uint32_t dim_size, uint32_t inner_size)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  const uint32_t outer_stride = inner_size * dim_size;
  const uint32_t dim_stride = inner_size;

  for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
    const uint32_t outer_offset = outer_index * outer_stride;
    for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) {
      const uint32_t data_offset = outer_offset + inner_index;
      // See the comment in forward kernel
      if (blockDim.x > 1) {
        // Cooperative path: per-thread partial sums reduced along x.
        accscalar_t sum = 0;
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += gradOutput[data_offset + d * dim_stride];
        sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum);

        Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum);
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          gradInput[data_offset + d * dim_stride] =
            epilogue(gradOutput[data_offset + d * dim_stride],
                     output[data_offset + d * dim_stride]);
        }
      } else {
        // One thread walks the whole dim slice serially.
        accscalar_t sum = 0;
        for (uint32_t d = 0; d < dim_size; d++)
          sum += gradOutput[data_offset + d * dim_stride];

        Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum);
        for (uint32_t d = 0; d < dim_size; d++) {
          gradInput[data_offset + d * dim_stride] =
            epilogue(gradOutput[data_offset + d * dim_stride],
                     output[data_offset + d * dim_stride]);
        }
      }
    }
  }
}
354
355
356 ////////////////////////////////////////////////////////////////////////////////
357 // Regular kernel (fast when dim_size is large; requires inner_size == 1)
358 ////////////////////////////////////////////////////////////////////////////////
359
360
// Folds one element of type T into a running maximum held in AccumT.
template <typename T, typename AccumT>
struct MaxFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT running_max, T v) const {
    const AccumT widened = static_cast<AccumT>(v);
    return ::max(running_max, widened);
  }
};
368
// Folds one element of type T into a running sum held in AccumT.
template<typename T, typename AccumT>
struct AddFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT running_sum, T v) const {
    return running_sum + static_cast<AccumT>(v);
  }
};
376
// Folds exp(v - max_k) into a running sum held in AccumT; max_k is fixed at
// construction (the row maximum in this file's forward kernels).
template<typename T, typename AccumT>
struct SumExpFloat
{
  __device__ __forceinline__ SumExpFloat(AccumT v)
    : max_k(v) {}

  __device__ __forceinline__ AccumT operator()(AccumT running_sum, T v) const {
    const AccumT shifted = static_cast<AccumT>(v) - max_k;
    return running_sum + std::exp(shifted);
  }

  const AccumT max_k;
};
389
// Block-wide reduction of `val` with `r` through shared memory; every thread
// returns the block result.
//   smem        - shared scratch of at least blockDim.x AccumT values
//   defaultVal  - identity element of r
// Stages: each thread stores its value; lanes of the first warp each fold one
// warp's C10_WARP_SIZE values into smem[lane]; thread 0 folds those partials.
// Assumes blockDim.x is a multiple of C10_WARP_SIZE and at most
// C10_WARP_SIZE * C10_WARP_SIZE (only lanes < C10_WARP_SIZE gather warp partials).
// Contains __syncthreads(), so all threads of the block must call it.
template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
            const Reduction<AccumT>& r,
            AccumT defaultVal)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val;

  __syncthreads();

  AccumT warpVal = defaultVal;

  // First warp will perform per-warp reductions for the remaining warps
  // The mask has one bit per participating lane (one lane per warp in the block);
  // the shift is done in 64 bits so that 32 warps does not shift a 32-bit 1 by 32.
  uint32_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1;
  if (threadIdx.x < C10_WARP_SIZE) {
    int lane = threadIdx.x % C10_WARP_SIZE;
    if (lane < blockDim.x / C10_WARP_SIZE) {
#pragma unroll
      for (int i = 0; i < C10_WARP_SIZE; ++i) {
        warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]);
      }
      // Lane 0 reads smem[1..C10_WARP_SIZE-1], which other lanes overwrite just
      // below; wait until every participating lane has finished reading.
#if !defined(USE_ROCM)
      __syncwarp(mask);
#endif
      smem[lane] = warpVal;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal = defaultVal;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) {
      blockVal = r(blockVal, smem[i]);
    }
    smem[0] = blockVal;
  }

  // Sync and broadcast
  __syncthreads();
  return smem[0];
}
437
// Block-wide reduction of `value` with `op` that starts with warp shuffles
// (via cuda_utils::BlockReduce), then broadcasts the block result so that every
// thread returns it.
//   smem_cache - shared scratch forwarded to cuda_utils::BlockReduce; slot 0 is
//                reused for the broadcast
//   defaultVal - identity element of op
// Contains __syncthreads(), so all threads of the block must call it.
template <template<typename> class Reduction, typename T>
__device__ __forceinline__
T blockReduceWarp(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
{
  const T reduced =
      cuda_utils::BlockReduce<T, Reduction<T>>(value, op, defaultVal, smem_cache);

  // Only thread 0's result is relied upon; publish it for the whole block.
  const bool is_leader = (threadIdx.x == 0);
  if (is_leader) {
    smem_cache[0] = reduced;
  }
  __syncthreads();

  return smem_cache[0];
}
451
// Per-thread partial reduction over one row data[0, size) for the regular
// (inner_size == 1) kernels, using ILP-wide vector loads for the aligned bulk.
//   shift      - elements by which `data` lies past the preceding aligned
//                boundary (callers in this file compute
//                (address % ALIGN_BYTES) / sizeof(T)); 0 when aligned
//   size       - number of elements in the row
//   r          - functor (AccumT accumulator, T element) -> AccumT
//   defaultVal - identity of r; returned by threads that see no element
// Returns only this thread's partial value; callers combine across the block.
// Contains no barriers.
//
// The head pass now also checks threadIdx.x < size. Without that check, a row
// shorter than one block (size + shift < blockDim.x) made the extra lanes read
// past the row end and fold those values into the result. If the head pass
// consumed the whole row, the function returns immediately, so a negative
// `size` never reaches the signed/unsigned `%` below.
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT, typename index_t=int>
__device__ __forceinline__ AccumT
ilpReduce(index_t shift,
          const T* data,
          index_t size,
          const Reduction<T, AccumT>& r,
          AccumT defaultVal)
{
  using LoadT = at::native::memory::aligned_vector<T, ILP>;
  AccumT threadVal = defaultVal;
  index_t offset = threadIdx.x;

  // Unaligned head: step back to the aligned boundary and let each thread
  // consume at most one element, so the vector loads below are aligned.
  if (shift > 0) {
    data -= shift;
    size += shift;
    // Lanes below `shift` precede the row; lanes at or past `size` follow it.
    if (threadIdx.x >= shift && threadIdx.x < size) {
      threadVal = r(threadVal, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  if (size <= 0) {
    return threadVal;
  }

  // Elements that do not fill a complete ILP * blockDim.x tile.
  index_t last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  // Aligned bulk: ILP consecutive elements per load.
  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<const LoadT*>(data)[offset];

    #pragma unroll
    for (int j = 0; j < ILP; ++j) {
      threadVal = r(threadVal, v[j]);
    }
  }

  // Tail: remaining elements, one per thread.
  offset = size - last + threadIdx.x;
  for (; offset < size; offset += blockDim.x)
    threadVal = r(threadVal, data[offset]);

  return threadVal;
}
495
/**
 * Applies the forward Epilogue to input[0, size) and writes output[0, size) with
 * vectorized (ILP-wide) loads and stores. Requires input and output to share the
 * same element `shift` relative to the aligned boundary.
 *
 * The unaligned head pass is bounded by threadIdx.x < size. Without that bound, a
 * row shorter than one block made the extra lanes read and write past the row
 * end. If the head pass covered the whole row, the function returns before the
 * signed/unsigned `%` can see a negative size.
 */
template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__device__ __forceinline__ void
WriteFpropResultsVectorized(
             int size,
             const int shift,
             const scalar_t *input,
             outscalar_t *output,
             Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
  using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>;
  using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>;

  int offset = threadIdx.x;

  // if unaligned, do one value / thread and move on, guaranteeing aligned reads/writes later
  if (shift > 0) {
    input -= shift;
    output -= shift;
    size += shift;

    if (threadIdx.x >= shift && threadIdx.x < size) {
      output[offset] = epilogue(input[offset]);
    }
    size -= blockDim.x;
    input += blockDim.x;
    output += blockDim.x;
  }
  if (size <= 0) {
    return;
  }

  // Elements that do not fill a complete ILP * blockDim.x tile.
  const int last = size % (ILP * blockDim.x);

  scalar_t in_v[ILP];
  LoadT* in_value = reinterpret_cast<LoadT*>(&in_v);

  outscalar_t out_v[ILP];
  const StoreT* out_value = reinterpret_cast<const StoreT*>(&out_v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *in_value = reinterpret_cast<const LoadT*>(input)[offset];

    #pragma unroll
    for (int j = 0; j < ILP; ++j) {
      out_v[j] = epilogue(in_v[j]);
    }

    reinterpret_cast<StoreT*>(output)[offset] = *out_value;
  }

  offset = size - last + threadIdx.x;
  // handle the tail
  for (; offset < size; offset += blockDim.x) {
    output[offset] = epilogue(input[offset]);
  }
}
551
// Backward counterpart of WriteFpropResultsVectorized:
//   gradInput[i] = epilogue(gradOutput[i], output[i]) for i in [0, size),
// using ILP-wide vector loads/stores. Requires gradInput, output and gradOutput
// to share the same element `shift` relative to the aligned boundary.
//
// The unaligned head pass is bounded by threadIdx.x < size. Without that bound,
// rows shorter than one block made the extra lanes access memory past the row
// end. If the head pass covered the whole row, the function returns
// immediately.
template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue, typename index_t = int32_t>
__device__ __forceinline__ void
WriteBpropResultsVectorized(
             index_t size,
             const index_t shift,
             scalar_t *gradInput,
             const outscalar_t *output,
             const outscalar_t *gradOutput,
             Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
  using gradInputT = at::native::memory::aligned_vector<scalar_t, ILP>;
  using outputT = at::native::memory::aligned_vector<outscalar_t, ILP>;

  index_t offset = threadIdx.x;

  // if unaligned, do one value / thread and move on, guaranteeing aligned reads/writes later
  if (shift > 0) {
    gradInput -= shift;
    output -= shift;
    gradOutput -= shift;
    size += shift;

    if (threadIdx.x >= shift && threadIdx.x < size) {
      gradInput[offset] = epilogue(gradOutput[offset], output[offset]);
    }
    size -= blockDim.x;
    gradInput += blockDim.x;
    output += blockDim.x;
    gradOutput += blockDim.x;
  }
  if (size <= 0) {
    return;
  }

  // Elements that do not fill a complete ILP * blockDim.x tile.
  const index_t last = size % (ILP * blockDim.x);

  scalar_t dX[ILP];
  gradInputT *dX_v = reinterpret_cast<gradInputT*>(&dX);

  outscalar_t Y[ILP];
  outputT *Y_v = reinterpret_cast<outputT*>(&Y);

  outscalar_t dY[ILP];
  outputT *dY_v = reinterpret_cast<outputT*>(&dY);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *Y_v = reinterpret_cast<const outputT*>(output)[offset];
    *dY_v = reinterpret_cast<const outputT*>(gradOutput)[offset];

    #pragma unroll
    for (int j = 0; j < ILP; ++j) {
      dX[j] = epilogue(dY[j], Y[j]);
    }

    reinterpret_cast<gradInputT*>(gradInput)[offset] = *dX_v;
  }

  // Tail: remaining elements, one per thread.
  offset = size - last + threadIdx.x;
  for (; offset < size; offset += blockDim.x) {
    gradInput[offset] = epilogue(gradOutput[offset], output[offset]);
  }
}
610
/**
 * Scalar (non-vectorized) forward write: output[i] = epilogue(input[i]) for
 * i in [0, classes), with thread t handling t, t + blockDim.x, ...
 * cunn_SoftMaxForward uses this when input and output have different shifts.
 * ILP is accepted only for symmetry with the vectorized variant and is unused.
 */
template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__device__ __forceinline__ void
WriteFpropResults(
             int classes,
             const scalar_t *input,
             outscalar_t *output,
             Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
  int idx = threadIdx.x;
  while (idx < classes) {
    output[idx] = epilogue(input[idx]);
    idx += blockDim.x;
  }
}
625
// Non-vectorized backward write:
//   gradInput[i] = epilogue(gradOutput[i], output[i]) for i in [0, classes).
// Full tiles of ILP * blockDim.x elements are handled with ILP independent
// block-strided scalar loads per thread, and the remainder one element per thread.
//
// `classes` has type index_t rather than int. cunn_SoftMaxBackward instantiates
// this with index_t = int64_t precisely when the row length may not fit in 32
// bits, and an int parameter silently truncated that count on entry.
template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue, typename index_t>
__device__ __forceinline__ void
WriteBpropResults(
             index_t classes,
             scalar_t *gradInput,
             const outscalar_t *output,
             const outscalar_t *gradOutput,
             Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {

  index_t offset = threadIdx.x;

  // Elements that do not fill a complete ILP * blockDim.x tile.
  index_t last = classes % (ILP * blockDim.x);

  for (; offset < classes - last; offset += blockDim.x * ILP) {
    outscalar_t tmpOutput[ILP];
    outscalar_t tmpGradOutput[ILP];

    // Issue all loads first to expose instruction-level parallelism.
    #pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmpOutput[j] = output[offset + j * blockDim.x];
      tmpGradOutput[j] = gradOutput[offset + j * blockDim.x];
    }

    #pragma unroll
    for (int j = 0; j < ILP; ++j) {
      gradInput[offset + j * blockDim.x] = epilogue(tmpGradOutput[j], tmpOutput[j]);
    }
  }

  // Remainder - no ILP
  for (; offset < classes; offset += blockDim.x) {
    gradInput[offset] = epilogue(gradOutput[offset], output[offset]);
  }
}
660
// Regular softmax / log-softmax forward kernel (inner_size == 1): block b
// normalizes row b, i.e. input[b * classes, (b + 1) * classes).
// Dynamic shared memory serves only as scratch for blockReduceWarp.
// The output is written vectorized when input and output rows have the same
// element shift relative to ALIGN_BYTES, and element by element otherwise.
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes)
{
  extern __shared__ unsigned char smem[];
  accscalar_t* const scratch = reinterpret_cast<accscalar_t*>(smem);

  // Advance both pointers to this block's row.
  const int64_t row_start = static_cast<int64_t>(blockIdx.x) * classes;
  input += row_start;
  output += row_start;

  // How many elements each row start lies past the previous ALIGN_BYTES boundary.
  const int in_shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
  const int out_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t);

  // Pass 1: row maximum, so every exp() argument below is <= 0.
  const accscalar_t max_identity = -at::numeric_limits<accscalar_t>::max();
  const accscalar_t partial_max = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>(
      in_shift, input, classes, MaxFloat<scalar_t, accscalar_t>(), max_identity);
  const accscalar_t row_max = blockReduceWarp<Max, accscalar_t>(
      scratch, partial_max, Max<accscalar_t>(), max_identity);

  // Pass 2: row sum of exp(x - row_max).
  const accscalar_t zero = static_cast<accscalar_t>(0);
  const accscalar_t partial_sum = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
      in_shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(row_max), zero);
  const accscalar_t row_sum = blockReduceWarp<Add, accscalar_t>(
      scratch, partial_sum, Add<accscalar_t>(), zero);

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(row_max, row_sum);

  // Pass 3: write the result.
  if (in_shift != out_shift) {
    WriteFpropResults<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>(classes, input, output, epilogue);
  } else {
    WriteFpropResultsVectorized<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>(classes, in_shift, input, output, epilogue);
  }
}
696
// Regular softmax / log-softmax forward kernel (inner_size == 1) that caches the
// row in shared memory, so the input is read from global memory only once.
// Block b handles row b.
// Dynamic shared memory layout: classes scalar_t values (the row cache),
// followed by the scratch used by blockReduceWarp.
// NOTE(review): every loop below steps in whole ILP-wide vectors with no scalar
// tail, and input/output are reinterpreted as aligned_vector pointers. This
// presumably requires classes % ILP == 0 and aligned row starts, and also
// relies on smem + classes * sizeof(scalar_t) being suitably aligned for
// accscalar_t. Confirm that the host only selects this kernel under those
// conditions.
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
  template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
__global__ void
cunn_SoftMaxForwardSmem(outscalar_t *output, const scalar_t *input, index_t classes)
{
  // Each thread block processes a sample in the batch
  input += static_cast<int64_t>(blockIdx.x) * classes;
  output += static_cast<int64_t>(blockIdx.x) * classes;

  accscalar_t threadMax = -at::numeric_limits<accscalar_t>::max();
  accscalar_t threadExp = static_cast<accscalar_t>(0);

  // The first smem segment is used to cache input values and the last
  // segment is used for thread block reductions
  extern __shared__ unsigned char smem[];
  auto smem_input_cache = reinterpret_cast<scalar_t*>(smem);
  auto smem_reduction_cache = reinterpret_cast<accscalar_t*>(smem +
    classes * sizeof(scalar_t));

  using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>;
  const LoadT* const input_vec_ptr = reinterpret_cast<const LoadT*>(input);
  LoadT* const smem_input_cache_vec_ptr = reinterpret_cast<LoadT*>(smem_input_cache);

  // Download inputs to shared memory while doing the first step
  // in max calculation
  MaxFloat<scalar_t, accscalar_t> maxFunc;
  for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
    LoadT crnt_vec = input_vec_ptr[offset];
    smem_input_cache_vec_ptr[offset] = crnt_vec;

    #pragma unroll
    for (int i = 0; i < ILP; ++i) {
      threadMax = maxFunc(threadMax, crnt_vec.val[i]);
    }
  }

  accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(smem_reduction_cache, threadMax,
    Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());

  // Reload input from shared memory to compute the sum. The previous
  // reduce has performed a __syncthreads() so the smem contents are populated.
  SumExpFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
  for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
    LoadT crnt_vec = smem_input_cache_vec_ptr[offset];

    #pragma unroll
    for (int i = 0; i < ILP; ++i) {
      threadExp = sumExpFunc(threadExp, crnt_vec.val[i]);
    }
  }

  accscalar_t sumAll = blockReduceWarp<Add, accscalar_t>(smem_reduction_cache, threadExp,
    Add<accscalar_t>(), static_cast<accscalar_t>(0));

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

  // Use vectorized stores to save the output
  using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>;
  StoreT* output_vec_ptr = reinterpret_cast<StoreT*>(output);
  for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
    LoadT crnt_vec = smem_input_cache_vec_ptr[offset];
    StoreT out_vec;

    #pragma unroll
    for (int i = 0; i < ILP; ++i) {
      out_vec.val[i] = epilogue(crnt_vec.val[i]);
    }

    output_vec_ptr[offset] = out_vec;
  }
}
768
// True when `value` is strictly below INT32_MAX. The strict comparison means
// INT32_MAX itself is rejected, and no lower bound is checked (callers in this
// file pass non-negative shifts and sizes).
C10_DEVICE bool inline is_32bit_representable(const int64_t value) {
  const int64_t int32_limit = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
  return !(value >= int32_limit);
}
772
// Regular softmax / log-softmax backward kernel (inner_size == 1): block b
// handles row b of gradInput, output and gradOutput (each `classes` long).
//   1. sum_k = sum of this row of gradOutput, combined across the block with
//      blockReduce (which needs blockDim.x accscalar_t of dynamic shared memory).
//   2. gradInput[i] = Epilogue(sum_k)(gradOutput[i], output[i]), vectorized when
//      all three rows share the same element shift relative to ALIGN_BYTES.
// 32-bit index arithmetic is used when every shift and `classes` pass
// is_32bit_representable; otherwise the int64_t instantiations are used.
// (Removed the unused LoadT/StoreT aliases that the original body declared.)
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxBackward(scalar_t *gradInput, const outscalar_t *output, const outscalar_t *gradOutput, int64_t classes)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);

  // Advance all three pointers to this block's row.
  gradInput += static_cast<int64_t>(blockIdx.x) * classes;
  output += static_cast<int64_t>(blockIdx.x) * classes;
  gradOutput += static_cast<int64_t>(blockIdx.x) * classes;

  // How many elements each row start lies past the previous ALIGN_BYTES boundary.
  const int64_t shift = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
  const int64_t output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t);
  const int64_t grad_output_shift = ((uint64_t)gradOutput) % ALIGN_BYTES / sizeof(outscalar_t);

  const bool can_use_32bit_indexing = is_32bit_representable(shift) && is_32bit_representable(output_shift) && is_32bit_representable(grad_output_shift) && is_32bit_representable(classes);
  accscalar_t threadSum;
  if (can_use_32bit_indexing) {
    threadSum = ilpReduce<AddFloat, ILP, outscalar_t, accscalar_t, int32_t>(
        static_cast<int32_t>(grad_output_shift), gradOutput, classes, AddFloat<outscalar_t, accscalar_t>(), accscalar_t(0));
  } else {
    threadSum = ilpReduce<AddFloat, ILP, outscalar_t, accscalar_t, int64_t>(
        grad_output_shift, gradOutput, classes, AddFloat<outscalar_t, accscalar_t>(), accscalar_t(0));
  }
  accscalar_t sum_k = blockReduce<Add, accscalar_t>(
        sdata, threadSum, Add<accscalar_t>(), accscalar_t(0));

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum_k);

  if (shift == output_shift && shift == grad_output_shift) {
    if (can_use_32bit_indexing) {
      WriteBpropResultsVectorized<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue, int32_t>(classes, static_cast<int32_t>(shift), gradInput, output, gradOutput, epilogue);
    } else {
      WriteBpropResultsVectorized<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue, int64_t>(classes, shift, gradInput, output, gradOutput, epilogue);
    }
  } else {
    if (can_use_32bit_indexing) {
      WriteBpropResults<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue, int32_t>(classes, gradInput, output, gradOutput, epilogue);
    } else {
      WriteBpropResults<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue, int64_t>(classes, gradInput, output, gradOutput, epilogue);
    }
  }
}
818
// Launches the forward softmax / log-softmax when the reduced dimension is innermost
// (inner_size == 1): `outer_size` rows of `dim_size` contiguous elements each.
//
// `outscalar_t` is `scalar_t` for same-dtype output, or `accscalar_t` when host_softmax
// runs with half_to_float. Shared by both output dtypes so the two paths cannot drift.
//   * Small rows (dim_size <= 1024 and at most 4096 bytes): dispatch_softmax_forward,
//     launched repeatedly so that each launch covers at most 2^30 elements.
//   * Otherwise: grid.x = outer_size blocks. cunn_SoftMaxForwardSmem is used when the
//     row fits in shared memory and pointers/length allow ILP-wide vector access;
//     cunn_SoftMaxForward is used in all other cases.
template <template<typename, typename, typename> class Epilogue, bool is_log_softmax,
          typename scalar_t, typename accscalar_t, typename outscalar_t>
void host_softmax_lastdim_forward(
    outscalar_t* output_ptr,
    const scalar_t* input_ptr,
    int64_t outer_size,
    int64_t dim_size,
    cudaStream_t stream) {
  if (dim_size <= 1024 && dim_size * sizeof(scalar_t) <= 4096) {
    // chunk_size rows * dim_size elements <= 2^30 elements per launch.
    const int64_t chunk_size = (int64_t{1} << 30) / dim_size;
    for (int64_t remaining = outer_size; remaining > 0; remaining -= chunk_size) {
      dispatch_softmax_forward<scalar_t, outscalar_t, accscalar_t, is_log_softmax, false>(
          output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size), nullptr/* not masked */);
      input_ptr += chunk_size * dim_size;
      output_ptr += chunk_size * dim_size;
    }
    return;
  }

  // Number of scalar_t elements that fit in one 16-byte float4.
  constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
  dim3 grid(outer_size);
  dim3 block = SoftMaxForward_getBlockSize(dim_size);
  // Reduction scratch: one accscalar_t per warp of the block.
  size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
  auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
      smem_reduction_sz) / sizeof(scalar_t);

  // The shared-memory variant reserves space for the whole input row and uses
  // ILP-wide vector accesses, so it needs room for the row, ALIGN_BYTES-aligned
  // pointers and a row length divisible by ILP.
  bool can_use_smem = (size_t) dim_size < max_elements_per_smem;
  can_use_smem &= !(reinterpret_cast<uintptr_t>(input_ptr) % ALIGN_BYTES);
  can_use_smem &= !(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES);
  can_use_smem &= !(dim_size % ILP);

  if (can_use_smem) {
    size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
    cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>
        <<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
  } else {
    cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>
        <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
  }

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Computes softmax (or log-softmax when is_log_softmax) of `input_` along `dim_` into
// the preallocated `output`, and returns `output`.
//
// half_to_float: input must be Half and `output` is written as acc_type<Half> (float).
// Raises if half_to_float is set for a non-Half input, or if `dim_` is out of range.
// Empty inputs launch nothing.
template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float, const Tensor& output){
  if (half_to_float) {
    TORCH_CHECK(input_.scalar_type() == ScalarType::Half, "conversion is supported for Half type only");
  }
  auto input = input_.contiguous();
  static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float");
  if (input.dim() == 0) input = input.view(1);
  int64_t dim = maybe_wrap_dim(dim_, input.dim());
  TORCH_CHECK(dim >=0 && dim < input.dim(), "dim must be non-negative and less than input dimensions");
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);

  if (input.numel() > 0) {
    int64_t inner_size = 1;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    for (int64_t i = 0; i < dim; ++i)
      outer_size *= input.size(i);
    for (int64_t i = dim + 1; i < input.dim(); ++i)
      inner_size *= input.size(i);

    if (inner_size == 1) {
      // Reduced dimension is innermost: rows are contiguous (see host_softmax_lastdim_forward).
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        auto input_ptr = input.const_data_ptr<scalar_t>();
        if (!half_to_float) {
          host_softmax_lastdim_forward<Epilogue, is_log_softmax, scalar_t, accscalar_t, scalar_t>(
              output.mutable_data_ptr<scalar_t>(), input_ptr, outer_size, dim_size, stream);
        } else {
          host_softmax_lastdim_forward<Epilogue, is_log_softmax, scalar_t, accscalar_t, accscalar_t>(
              output.mutable_data_ptr<accscalar_t>(), input_ptr, outer_size, dim_size, stream);
        }
      });
    // This kernel runs in a 2D grid, where each application along y dimension has a fixed
    // outer_size, and runs in parallel over inner_size. Dimension x is parallel over outer_size.
    // Reductions over dim are done in a single-threaded manner.
    } else {
      uint32_t smem_size;
      dim3 grid, block;
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        AT_DISPATCH_INDEX_TYPES(
          at::native::canUse32BitIndexMath(input, INT_MAX) ? ScalarType::Int : ScalarType::Long,
          "host_softmax_launcher", [&] {
            if (!half_to_float) {
              SpatialSoftMax_getLaunchSizes<accscalar_t>(
                  &cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, index_t, Epilogue>,
                  outer_size, dim_size, inner_size,
                  grid, block, smem_size);
              cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, index_t, Epilogue>
                <<<grid, block, smem_size, stream>>>(
                  output.mutable_data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), outer_size, dim_size, inner_size);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            } else {
              SpatialSoftMax_getLaunchSizes<accscalar_t>(
                  &cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, index_t, Epilogue>,
                  outer_size, dim_size, inner_size,
                  grid, block, smem_size);
              cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, index_t, Epilogue>
                <<<grid, block, smem_size, stream>>>(
                  output.mutable_data_ptr<accscalar_t>(), input.const_data_ptr<scalar_t>(), outer_size, dim_size, inner_size);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            }
          });
      });
    }
  }
  return output;
}
956
// Launches the softmax / log-softmax backward when the reduced dimension is innermost
// (inner_size == 1): `outer_size` rows of `dim_size` contiguous elements each.
//
// scalar_t is the element type of gI. outscalar_t is the element type of grad/output:
// scalar_t normally, or accscalar_t when host_softmax_backward runs with half_to_float.
// Shared by both dtypes so the two paths cannot drift.
//   * Small rows (dim_size <= 1024 and at most 4096 bytes of scalar_t): the
//     dispatch_softmax_backward path, launched so each launch covers at most 2^30 elements.
//   * Otherwise: cunn_SoftMaxBackward with grid.x = outer_size and
//     block.x * sizeof(accscalar_t) bytes of dynamic shared memory.
template <template<typename, typename, typename> class Epilogue, bool is_log_softmax,
          typename scalar_t, typename accscalar_t, typename outscalar_t>
void host_softmax_lastdim_backward(
    scalar_t* gI_ptr,
    const outscalar_t* grad_ptr,
    const outscalar_t* output_ptr,
    int64_t outer_size,
    int64_t dim_size,
    cudaStream_t stream) {
  if (dim_size <= 1024 && dim_size * sizeof(scalar_t) <= 4096) {
    // chunk_size rows * dim_size elements <= 2^30 elements per launch.
    const int64_t chunk_size = (int64_t{1} << 30) / dim_size;
    for (int64_t remaining = outer_size; remaining > 0; remaining -= chunk_size) {
      dispatch_softmax_backward<outscalar_t, scalar_t, accscalar_t, is_log_softmax, false /* masked_softmax */>(
          gI_ptr, grad_ptr, output_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size));
      gI_ptr += chunk_size * dim_size;
      grad_ptr += chunk_size * dim_size;
      output_ptr += chunk_size * dim_size;
    }
    return;
  }

  // Vector width: number of grad/output elements that fit in one 16-byte float4.
  constexpr int ILP = sizeof(float4) / sizeof(outscalar_t);
  dim3 grid(outer_size);
  dim3 block = SoftMax_getBlockSize(ILP, dim_size);
  cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>
      <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
          gI_ptr, output_ptr, grad_ptr, dim_size
      );
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Backward of softmax / log-softmax along `dim_`, written into the preallocated `gI`.
//
// grad_   - incoming gradient. For softmax, the callers in this file pass grad * output.
// output_ - forward-pass result.
// half_to_float - grad_/output_ hold acc_type values (float) while gI holds the
//                 narrower input dtype.
// Returns immediately for empty gradients. Raises if `dim_` is out of range.
template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
void host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t dim_, bool half_to_float, const Tensor &gI){
  int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
  if (grad_.numel() == 0) {
    return;
  }
  auto grad = grad_.contiguous();
  static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float");
  if (grad.dim() == 0) grad = grad.view(1);
  TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions");
  auto output = output_.contiguous();
  if (output.dim() == 0) output = output.view(1);
  int64_t outer_size = 1;
  int64_t dim_size = output.size(dim);
  int64_t inner_size = 1;
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= output.size(i);
  for (int64_t i = dim + 1; i < output.dim(); ++i)
    inner_size *= output.size(i);
  // See descriptions of kernels above.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (inner_size == 1) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      auto gI_ptr = gI.mutable_data_ptr<scalar_t>();
      if (!half_to_float) {
        host_softmax_lastdim_backward<Epilogue, is_log_softmax, scalar_t, accscalar_t, scalar_t>(
            gI_ptr, grad.const_data_ptr<scalar_t>(), output.const_data_ptr<scalar_t>(),
            outer_size, dim_size, stream);
      } else {
        host_softmax_lastdim_backward<Epilogue, is_log_softmax, scalar_t, accscalar_t, accscalar_t>(
            gI_ptr, grad.const_data_ptr<accscalar_t>(), output.const_data_ptr<accscalar_t>(),
            outer_size, dim_size, stream);
      }
    });
  } else {
    uint32_t smem_size;
    dim3 grid, block;
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      if (!half_to_float) {
        SpatialSoftMax_getLaunchSizes<accscalar_t>(
            &cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, scalar_t, Epilogue>,
            outer_size, dim_size, inner_size,
            grid, block, smem_size);

        cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, scalar_t, Epilogue>
          <<<grid, block, smem_size, stream>>>(
            gI.mutable_data_ptr<scalar_t>(), output.const_data_ptr<scalar_t>(), grad.const_data_ptr<scalar_t>(),
            outer_size, dim_size, inner_size
          );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        SpatialSoftMax_getLaunchSizes<accscalar_t>(
            &cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, accscalar_t, Epilogue>,
            outer_size, dim_size, inner_size,
            grid, block, smem_size);

        cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, accscalar_t, Epilogue>
          <<<grid, block, smem_size, stream>>>(
            gI.mutable_data_ptr<scalar_t>(), output.const_data_ptr<accscalar_t>(), grad.const_data_ptr<accscalar_t>(),
            outer_size, dim_size, inner_size
          );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  }
}
1065 }
1066
// Structured CUDA kernel for log_softmax: forwards to host_softmax with the
// log-softmax epilogue and the log-softmax flag set.
TORCH_IMPL_FUNC(log_softmax_cuda_out) (
    const Tensor &input,
    const int64_t dim,
    const bool half_to_float,
    const Tensor &output) {
  constexpr bool kIsLogSoftmax = true;
  host_softmax<LogSoftMaxForwardEpilogue, kIsLogSoftmax>(
      input, dim, half_to_float, output);
}
1074
// Structured CUDA kernel for log_softmax backward. Differing grad/input dtypes are only
// accepted as Float grad with Half input, which selects the half_to_float path.
TORCH_IMPL_FUNC(log_softmax_backward_cuda_out) (
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    ScalarType input_dtype,
    const Tensor& grad_input) {
  const bool half_to_float = grad.scalar_type() != input_dtype;
  const bool supported_mix =
      grad.scalar_type() == ScalarType::Float && input_dtype == ScalarType::Half;
  TORCH_CHECK(
      !half_to_float || supported_mix,
      "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
  host_softmax_backward<LogSoftMaxBackwardEpilogue, /*is_log_softmax=*/true>(
      grad, output, dim, half_to_float, grad_input);
}
1090
// Structured CUDA kernel for softmax: forwards to host_softmax with the softmax
// epilogue and the log-softmax flag cleared.
TORCH_IMPL_FUNC(softmax_cuda_out) (
    const Tensor &input,
    const int64_t dim,
    const bool half_to_float,
    const Tensor &output) {
  constexpr bool kIsLogSoftmax = false;
  host_softmax<SoftMaxForwardEpilogue, kIsLogSoftmax>(
      input, dim, half_to_float, output);
}
1098
// Structured CUDA kernel for softmax backward. Differing grad/input dtypes are only
// accepted as Float grad with Half input (the half_to_float path). The backward
// kernels are fed grad * output in place of grad.
TORCH_IMPL_FUNC(softmax_backward_cuda_out)
(const Tensor& grad,
 const Tensor& output,
 int64_t dim,
 ScalarType input_dtype,
 const Tensor& grad_input) {
  const bool half_to_float = grad.scalar_type() != input_dtype;
  const bool supported_mix =
      grad.scalar_type() == ScalarType::Float && input_dtype == ScalarType::Half;
  TORCH_CHECK(
      !half_to_float || supported_mix,
      "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
  const Tensor grad_times_output = grad * output;
  host_softmax_backward<SoftMaxBackwardEpilogue, /*is_log_softmax=*/false>(
      grad_times_output, output, dim, half_to_float, grad_input);
}
1115
masked_softmax_cuda(const Tensor & input_,const Tensor & mask_,const std::optional<int64_t> dim_,const std::optional<int64_t> mask_type_)1116 Tensor masked_softmax_cuda(const Tensor& input_, const Tensor& mask_, const std::optional<int64_t> dim_, const std::optional<int64_t> mask_type_) {
1117 Tensor output = at::empty_like(input_, input_.options());
1118 TORCH_CHECK(mask_.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor");
1119
1120 TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
1121 int64_t mask_type = mask_type_.value();
1122 TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)");
1123
1124 // If input is [B, H, T, T] and mask is [B, T]
1125 // we have special fast kernel
1126 // mask_type == 1 => mask_ is a src_key_padding_mask
1127 bool is_BxT_mask = (mask_type == 1) && (input_.dim() == 4 && mask_.dim() == 2 && input_.size(0) == mask_.size(0) && input_.size(2) == mask_.size(1) && input_.size(3) == mask_.size(1));
1128
1129 // If input is [B, H, T, T] and mask is [T, T]
1130 // expand mask to [B, H, T, T] and treat it like regular mask
1131 // TODO We should have special fast kernel for TxT mask as well
1132 // mask_type == 0 => mask_ is a src_mask
1133 bool is_TxT_mask = (mask_type == 0) && input_.dim() == 4 && mask_.dim() == 2 && input_.size(3) == mask_.size(1) && input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1);
1134 // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
1135 TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input. mask: ", mask_.sizes(), " input: ", input_.sizes());
1136
1137 auto input = input_.dim() == 0 ? input_.view(1) : input_;
1138 auto mask = mask_.dim() == 0 ? mask_.view(1) : mask_;
1139 if (is_TxT_mask) {
1140 mask = mask.expand(input.sizes());
1141 }
1142 int64_t dim = dim_.has_value() ? dim_.value() : input.dim() - 1;
1143
1144 int softmax_elements = input.size(dim);
1145 // Persistent softmax is only supported when all of the conditions are held:
1146 // 1) softmax_elements <= 1024
1147 // 2) softmax_elements * input.element_size() <= 4096
1148 // 3) mask.is_contiguous()
1149 // 4) dim == input.dim() - 1
1150 // Otherwise, we fallback to vanilla softmax (where we do not support transformer_mask since converting the mask is expensive)
1151 if (softmax_elements > 1024 || softmax_elements * input.element_size() > 4096 || !mask.is_contiguous() || dim < input.dim()-1) {
1152 if (is_BxT_mask) {
1153 mask = mask.view({mask_.size(0), 1, 1, mask_.size(1)}).expand(input.sizes());
1154 }
1155 AT_DISPATCH_FLOATING_TYPES_AND2(
1156 ScalarType::Half,
1157 ScalarType::BFloat16,
1158 input.scalar_type(),
1159 "masked_softmax",
1160 [&] {
1161 output = at::softmax(input.masked_fill(mask, -std::numeric_limits<scalar_t>::infinity()), dim);
1162 });
1163 return output;
1164 }
1165 int batch_count = input.numel() / softmax_elements;
1166 int chunk_size = input.numel() / input.size(0);
1167 if (is_BxT_mask) {
1168 // Only support when num_heads is even in transformer
1169 TORCH_CHECK(input.size(1) % 2 == 0, "Only support when num_heads is even in transformer");
1170 AT_DISPATCH_FLOATING_TYPES_AND2(
1171 ScalarType::Half,
1172 ScalarType::BFloat16,
1173 input.scalar_type(),
1174 "masked_softmax",
1175 [&] {
1176 using accscalar_t = acc_type<scalar_t, true>;
1177 dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, false/* is_log_softmax */, true/* is_masked */>(
1178 output.mutable_data_ptr<scalar_t>(), // dst
1179 input.const_data_ptr<scalar_t>(), // src
1180 softmax_elements,
1181 softmax_elements,
1182 batch_count,
1183 mask.const_data_ptr<bool>(),
1184 chunk_size,
1185 true // is_transformer_mask
1186 );
1187 });
1188
1189 } else {
1190 AT_DISPATCH_FLOATING_TYPES_AND2(
1191 ScalarType::Half,
1192 ScalarType::BFloat16,
1193 input.scalar_type(),
1194 "masked_softmax",
1195 [&] {
1196 using accscalar_t = acc_type<scalar_t, true>;
1197 dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, false/* is_log_softmax */, true/* is_masked */>(
1198 output.mutable_data_ptr<scalar_t>(), // dst
1199 input.const_data_ptr<scalar_t>(), // src
1200 softmax_elements,
1201 softmax_elements,
1202 batch_count,
1203 mask.const_data_ptr<bool>()
1204 );
1205 });
1206 }
1207 return output;
1208 }
1209
masked_softmax_backward_cuda(const Tensor & grad_,const Tensor & output_,const Tensor & mask_,const std::optional<int64_t> dim_)1210 Tensor masked_softmax_backward_cuda(
1211 const Tensor& grad_,
1212 const Tensor& output_,
1213 const Tensor& mask_,
1214 const std::optional<int64_t> dim_) {
1215 Tensor grad_input = at::empty_like(grad_, grad_.options());
1216 if (grad_.numel() == 0) {
1217 return grad_input;
1218 }
1219
1220 auto grad = grad_.contiguous();
1221 auto output = output_.contiguous();
1222 auto mask = mask_.contiguous();
1223 int64_t dim = dim_.has_value() ? maybe_wrap_dim(dim_.value(), output.dim()) : output.dim() - 1;
1224
1225 grad = grad.dim() == 0 ? grad.view(1) : grad;
1226 mask = mask.dim() == 0 ? mask.view(1) : mask;
1227 output = output.dim() == 0 ? output.view(1) : output;
1228
1229 TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions");
1230 TORCH_CHECK(grad.sizes() == mask.sizes(), "Mask shape should match grad shape");
1231 TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor");
1232
1233 int softmax_elements = output.size(dim);
1234 int64_t batch_count = grad.numel() / softmax_elements;
1235
1236 if (softmax_elements > 1024 || softmax_elements * grad.element_size() > 4096 || dim < grad.dim()-1) {
1237 AT_DISPATCH_FLOATING_TYPES_AND2(
1238 ScalarType::Half,
1239 ScalarType::BFloat16,
1240 grad_input.scalar_type(),
1241 "masked_softmax_backward",
1242 [&] {
1243 grad_input = at::_softmax_backward_data(
1244 grad,
1245 output.masked_fill(mask, 0),
1246 dim,
1247 grad.scalar_type()
1248 );
1249 });
1250 } else {
1251 grad = grad * output;
1252 AT_DISPATCH_FLOATING_TYPES_AND2(
1253 ScalarType::Half,
1254 ScalarType::BFloat16,
1255 grad_input.scalar_type(),
1256 "masked_softmax_backward",
1257 [&] {
1258 using accscalar_t = acc_type<scalar_t, true>;
1259 dispatch_softmax_backward<scalar_t, scalar_t, accscalar_t, false, true /* masked_softmax */>(
1260 grad_input.mutable_data_ptr<scalar_t>(), // gI_ptr
1261 grad.const_data_ptr<scalar_t>(), // grad_ptr
1262 output.const_data_ptr<scalar_t>(), // output_ptr
1263 softmax_elements, // softmax_elements
1264 softmax_elements, // softmax_elements_stride
1265 batch_count, // batch_count
1266 mask.const_data_ptr<bool>() /* not masked */
1267 );
1268 });
1269 }
1270 return grad_input;
1271 }
1272
1273 } // namespace at::native
1274