#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorIterator.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/cpu/SoftmaxKernel.h>
#include <ATen/NamedTensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_log_softmax.h>
#include <ATen/ops/_log_softmax_backward_data_native.h>
#include <ATen/ops/_log_softmax_native.h>
#include <ATen/ops/_masked_softmax_backward_native.h>
#include <ATen/ops/_masked_softmax_native.h>
#include <ATen/ops/_softmax.h>
#include <ATen/ops/_softmax_backward_data_native.h>
#include <ATen/ops/_softmax_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/log_softmax.h>
#include <ATen/ops/log_softmax_native.h>
#include <ATen/ops/softmax.h>
#include <ATen/ops/softmax_native.h>
#include <ATen/ops/special_log_softmax_native.h>
#include <ATen/ops/special_softmax_native.h>
#endif

#include <c10/core/TensorOptions.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>

namespace at::meta {
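// The TORCH_META_FUNC blocks below only compute output metadata (shape, dtype,
// memory format) for the structured _softmax / _log_softmax operators and
// their backward variants; the actual kernels are the TORCH_IMPL_FUNC bodies
// in at::native further down in this file.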
TORCH_META_FUNC(_softmax)
(const Tensor& input, const int64_t dim, const bool half_to_float) {
  int64_t dim_ = maybe_wrap_dim(dim, input.dim());

  auto output_options =
      input.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  if (half_to_float) {
    output_options = output_options.dtype(ScalarType::Float);
  }

  int64_t input_dim = input.dim() > 0 ? input.dim() : 1;
  TORCH_CHECK(
      dim_ >= 0 && dim_ < input_dim,
      "dim must be non-negative and less than input dimensions");

  set_output_raw_strided(0, input.sizes(), {}, output_options);
}

TORCH_META_FUNC(_log_softmax) (
  const Tensor& input,
  const int64_t dim,
  const bool half_to_float) {
  int64_t dim_ = maybe_wrap_dim(dim, input.dim());

  auto output_options =
      input.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  if (half_to_float) {
    output_options = output_options.dtype(ScalarType::Float);
  }

  int64_t input_dim = input.dim() > 0 ? input.dim() : 1;
  TORCH_CHECK(
      dim_ >= 0 && dim_ < input_dim,
      "dim must be non-negative and less than input dimensions");

  set_output_raw_strided(0, input.sizes(), {}, output_options);
}

TORCH_META_FUNC(_softmax_backward_data)
(const Tensor& grad,
 const Tensor& output,
 int64_t dim,
 ScalarType input_dtype) {
  TensorArg grad_arg{grad, "grad", 1}, output_arg{output, "output", 2};
  checkSameSize("softmax_backward", grad_arg, output_arg);

  int64_t dim_ = maybe_wrap_dim(dim, grad.dim());

  auto grad_input_options =
      grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  bool half_to_float = grad.scalar_type() != input_dtype;
  if (half_to_float) {
    // The code below is only valid for the CUDA implementation. It's "okay"
    // to put it here because half-to-float conversion is not supported by
    // the CPU implementation of _softmax. There is a TORCH_CHECK in the CUDA
    // implementation that should ideally go here as well, but there is at least
    // one test in which the grad and input dtypes do not match for the CPU
    // implementation of this kernel and it is not true that the grad type is
    // float and the input dtype is half (see #63057).
    if (grad.scalar_type() == ScalarType::Float &&
        input_dtype == ScalarType::Half) {
      grad_input_options = grad_input_options.dtype(ScalarType::Half);
    }
  }

  int64_t grad_dim = grad.dim() > 0 ? grad.dim() : 1;
  TORCH_CHECK(
      dim_ >= 0 && dim_ < grad_dim,
      "dim must be non-negative and less than input dimensions");

  set_output_raw_strided(0, grad.sizes(), {}, grad_input_options);
}

TORCH_META_FUNC(_log_softmax_backward_data)
(const Tensor& grad,
 const Tensor& output,
 int64_t dim,
 ScalarType input_dtype) {
  int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
  TensorOptions grad_input_options(
      grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));

  bool half_to_float = grad.scalar_type() != input_dtype;
  if (half_to_float) {
    // The code below is only valid for the CUDA implementation. It's "okay"
    // to put it here because half-to-float conversion is not supported by
    // the CPU implementation of _softmax. There is a TORCH_CHECK in the CUDA
    // implementation that should ideally go here as well, but there is at least
    // one test in which the grad and input dtypes do not match for the CPU
    // implementation of this kernel and it is not true that the grad type is
    // float and the input dtype is half (see #63057).
    if (grad.scalar_type() == ScalarType::Float &&
        input_dtype == ScalarType::Half) {
      grad_input_options = grad_input_options.dtype(ScalarType::Half);
    }
  }

  int64_t grad_dim = grad.dim() > 0 ? grad.dim() : 1;
  TORCH_CHECK(
      dim_ >= 0 && dim_ < grad_dim,
      "dim must be non-negative and less than input dimensions");

  set_output_raw_strided(0, grad.sizes(), {}, grad_input_options);
}
} // namespace at::meta

namespace at::native {
namespace {

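// host_softmax computes (log-)softmax along `dim` by viewing the input as a
// 3-d block of shape [outer_size, dim_size, inner_size] and parallelizing over
// the outer_size * inner_size independent rows. Each row uses the numerically
// stable formulation:
//   softmax:     y[d] = exp(x[d] - max(x)) / sum_j exp(x[j] - max(x))
//   log_softmax: y[d] = x[d] - max(x) - log(sum_j exp(x[j] - max(x)))
// For MaskedSoftMax, elements whose mask entry is true are excluded from the
// max and the sum and produce 0 in the output; a fully masked row makes the
// sum 0, which is turned into NaN for the whole row below.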
template <typename scalar_t, bool LogSoftMax, bool MaskedSoftMax = false>
void host_softmax(
    Tensor output,
    const Tensor& input,
    const int64_t dim,
    bool* mask = nullptr,
    const std::optional<int64_t> mask_type_ = {}) {

  if (MaskedSoftMax) {
    TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
    int64_t mask_type = mask_type_.value();
    // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
    TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask), or 2 (default_mask)");
  }

  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  for (const auto i : c10::irange(dim)) {
    outer_size *= input.size(i);
  }
  for (int64_t i = dim + 1; i < input.dim(); ++i) {
    inner_size *= input.size(i);
  }
  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
  scalar_t* input_data_base = input.data_ptr<scalar_t>();
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
  bool* mask_data_base = mask;
  int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
  parallel_for(
      0, outer_size * inner_size, grain_size,
      [&](int64_t begin, int64_t end) __ubsan_ignore_float_divide_by_zero__ {
        for (const auto i : c10::irange(begin, end)) {
          int64_t outer_idx = i / inner_size;
          int64_t inner_idx = i % inner_size;
          scalar_t* input_data =
              input_data_base + outer_idx * outer_stride + inner_idx;
          scalar_t* output_data =
              output_data_base + outer_idx * outer_stride + inner_idx;
          bool* mask_data = nullptr;
          if (MaskedSoftMax) {
            // Process mask differently depending on the type:
            // For a generic mask of mask_type == 2, mask shape is the same as the input shape,
            // so indexing is the same.
            auto mask_outer_idx = outer_idx;
            if (mask_type_ == 0) {
                // Optimized case: attention mask of shape LxL
                // outer_idx goes over BxHxL, mask_outer_idx goes over L.
                mask_outer_idx = outer_idx % input.size(2);
            } else if (mask_type_ == 1) {
                // Optimized case: padding mask of shape BxL
                // outer_idx goes over BxHxL, mask_outer_idx goes over B.
                mask_outer_idx = outer_idx / (input.size(1) * input.size(2));
            }

            mask_data = mask_data_base + mask_outer_idx * outer_stride + inner_idx;
          }

          // Calc max in softmax dim
          bool is_meaningful_max = false;
          scalar_t max_input = input_data[0];
          if (!MaskedSoftMax) {
            for (const auto d : c10::irange(1, dim_size)) {
              max_input = std::max(max_input, input_data[d * dim_stride]);
            }
          } else {
            for (const auto d : c10::irange(0, dim_size)) {
              if (!mask_data[d * dim_stride]) {
                max_input = is_meaningful_max
                    ? std::max(max_input, input_data[d * dim_stride])
                    : input_data[d * dim_stride];
                is_meaningful_max = true;
              }
            }
          }

          // Calc sum in softmax dim
          acc_type<scalar_t, false> tmpsum = 0;
          for (const auto d : c10::irange(dim_size)) {
            scalar_t z{};
            if (!MaskedSoftMax || !mask_data[d * dim_stride]) {
              z = std::exp(input_data[d * dim_stride] - max_input);
            } else {
              z = 0;
            }
            if (!LogSoftMax) {
              output_data[d * dim_stride] = z;
            }
            tmpsum += z;
          }

          if (LogSoftMax) {
            tmpsum = std::log(tmpsum);
          } else if (tmpsum == 0) {
            tmpsum = std::numeric_limits<scalar_t>::quiet_NaN();
          } else {
            tmpsum = 1 / tmpsum;
          }

          // update output
          for (const auto d : c10::irange(dim_size)) {
            // LogSoftMax and MaskedSoftMax should not both be true
            if (LogSoftMax) {
              output_data[d * dim_stride] =
                  input_data[d * dim_stride] - max_input - tmpsum;
            } else {
              output_data[d * dim_stride] *= tmpsum;
            }
          }
        }
      });
}

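// host_softmax_backward uses the same [outer, dim, inner] decomposition as the
// forward pass. With y the forward output and dy the incoming gradient, the
// per-row gradients computed below are:
//   softmax:     dx[d] = y[d] * (dy[d] - sum_j dy[j] * y[j])
//   log_softmax: dx[d] = dy[d] - exp(y[d]) * sum_j dy[j]
// For MaskedSoftMax, masked positions are skipped in the sum and their
// gradient is set to 0.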
template <typename scalar_t, bool LogSoftMax, bool MaskedSoftMax = false>
void host_softmax_backward(
    const Tensor& gI,
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    bool* mask = nullptr) {

  int64_t outer_size = 1;
  int64_t dim_size = grad.size(dim);
  int64_t inner_size = 1;
  for (const auto i : c10::irange(dim)) {
    outer_size *= grad.size(i);
  }
  for (int64_t i = dim + 1; i < grad.dim(); ++i) {
    inner_size *= grad.size(i);
  }
  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
  scalar_t* gradInput_data_base = gI.data_ptr<scalar_t>();
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
  scalar_t* gradOutput_data_base = grad.data_ptr<scalar_t>();
  bool* mask_data_base = mask;
  int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
  parallel_for(
      0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
        for (const auto i : c10::irange(begin, end)) {
          int64_t outer_idx = i / inner_size;
          int64_t inner_idx = i % inner_size;
          scalar_t* gradInput_data =
              gradInput_data_base + outer_idx * outer_stride + inner_idx;
          scalar_t* output_data =
              output_data_base + outer_idx * outer_stride + inner_idx;
          const scalar_t* gradOutput_data =
              gradOutput_data_base + outer_idx * outer_stride + inner_idx;
          bool* mask_data = nullptr;
          if (MaskedSoftMax) {
            mask_data = mask_data_base + outer_idx * outer_stride + inner_idx;
          }

          acc_type<scalar_t, false> sum = 0;
          for (const auto d : c10::irange(dim_size)) {
            if (!MaskedSoftMax || !mask_data[d * dim_stride]) {
              if (LogSoftMax) {
                sum += gradOutput_data[d * dim_stride];
              } else {
                sum +=
                    gradOutput_data[d * dim_stride] * output_data[d * dim_stride];
              }
            }
          }

          for (const auto d : c10::irange(dim_size)) {
            if (MaskedSoftMax && mask_data[d * dim_stride]) {
              gradInput_data[d * dim_stride] = 0;
            } else if (LogSoftMax) {
              gradInput_data[d * dim_stride] = gradOutput_data[d * dim_stride] -
                  std::exp(output_data[d * dim_stride]) * sum;
            } else {
              gradInput_data[d * dim_stride] = output_data[d * dim_stride] *
                  (gradOutput_data[d * dim_stride] - sum);
            }
          }
        }
      });
}
} // namespace

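// The structured CPU implementations below do not call host_softmax; they
// forward to the dispatch stubs (softmax_lastdim_kernel, softmax_kernel, ...)
// declared via DEFINE_DISPATCH further down, whose vectorized CPU kernels live
// in cpu/SoftmaxKernel.cpp. host_softmax / host_softmax_backward above are
// used only by the masked-softmax entry points at the bottom of this file.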
TORCH_IMPL_FUNC(softmax_cpu_out)
(const Tensor& input,
 const int64_t dim,
 const bool half_to_float,
 const Tensor& output) {
  TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on CPU");

  if (input.numel() == 0) {
    return;
  }

  auto input_ = input.contiguous();
  int64_t dim_ = maybe_wrap_dim(dim, input_.dim());

  if (input_.dim() == 0) {
    input_ = input_.view(1);
  }

  TORCH_CHECK(
      dim_ >= 0 && dim_ < input_.dim(),
      "dim must be non-negative and less than input dimensions");
  if (input_.ndimension() > 0 && dim_ == input_.ndimension() - 1) {
    softmax_lastdim_kernel(kCPU, output, input_);
  } else {
    softmax_kernel(kCPU, output, input_, dim_);
  }
}

TORCH_IMPL_FUNC(log_softmax_cpu_out)
(const Tensor& input,
 const int64_t dim,
 const bool half_to_float,
 const Tensor& output) {
  TORCH_CHECK(
      !half_to_float,
      "softmax with half to float conversion is not supported on CPU");

  if (input.numel() == 0) {
    return;
  }

  auto input_ = input.contiguous();
  int64_t dim_ = maybe_wrap_dim(dim, input_.dim());

  if (input_.dim() == 0) {
    input_ = input_.view(1);
  }

  if (input_.ndimension() > 0 && dim_ == input_.ndimension() - 1) {
    log_softmax_lastdim_kernel(kCPU, output, input_);
  } else {
    log_softmax_kernel(kCPU, output, input_, dim_);
  }
}

TORCH_IMPL_FUNC(softmax_backward_cpu_out)
(const Tensor& grad,
 const Tensor& output,
 int64_t dim,
 ScalarType input_dtype,
 const Tensor& grad_input) {
  int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
  auto grad_ = grad.contiguous();
  auto output_ = output.contiguous();

  if (output.numel() == 0) {
    return;
  }

  if (grad_.dim() == 0) {
    grad_ = grad_.view(1);
  }

  if (output_.dim() == 0) {
    output_ = output_.view(1);
  }

  // Pass the contiguous output_ to the kernels (not the possibly strided
  // output), consistent with log_softmax_backward_cpu_out below.
  if (grad_.ndimension() > 0 && dim_ == grad_.ndimension() - 1) {
    softmax_backward_lastdim_kernel(kCPU, grad_input, grad_, output_);
  } else {
    softmax_backward_kernel(kCPU, grad_input, grad_, output_, dim_);
  }
}

TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) (
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    ScalarType input_dtype,
    const Tensor& grad_input) {
  int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
  auto grad_ = grad.contiguous();
  auto output_ = output.contiguous();

  if (output.numel() != 0) {
    if (grad_.dim() == 0) {
      grad_ = grad_.view(1);
    }
    if (output_.dim() == 0) {
      output_ = output_.view(1);
    }
    if (grad_.ndimension() > 0 && dim_ == grad_.ndimension() - 1) {
      log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad_, output_);
    } else {
      log_softmax_backward_kernel(kCPU, grad_input, grad_, output_, dim_);
    }
  }
}

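// Composite softmax entry point. On CUDA, a Half input with a requested Float
// dtype takes the fused half_to_float path of _softmax; in every other case
// the input is converted to the requested dtype first and _softmax runs with
// half_to_float == false.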
Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) {
        return at::_softmax(input_, dim_, true);
    } else {
        Tensor converted = dtype.has_value() ? input_.toType(dtype.value()) : input_;
        return at::_softmax(converted, dim_, false);
    }
  }();
  namedinference::propagate_names(result, input_);
  return result;
}

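// out= variant. The underlying _softmax kernels produce contiguous results, so
// when the user-provided output buffer is non-contiguous the result is first
// computed into a contiguous temporary and then resized/copied into output_.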
Tensor& softmax_out(
    const Tensor& input_,
    const int64_t dim_,
    std::optional<ScalarType> dtype,
    Tensor& output_) {
  Tensor output_temp;
  if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
      dtype == ScalarType::Float) {
    if (!output_.is_contiguous()) {
      auto options =
          TensorOptions().dtype(output_.dtype()).device(output_.device());
      output_temp = at::empty(output_.sizes(), options);
      at::_softmax_out(output_temp, input_, dim_, true);
    } else {
      at::_softmax_out(output_, input_, dim_, true);
    }
  } else {
    Tensor converted =
        dtype.has_value() ? input_.toType(dtype.value()) : input_;
    if (!output_.is_contiguous()) {
      auto options =
          TensorOptions().dtype(output_.dtype()).device(output_.device());
      output_temp = at::empty(output_.sizes(), options);
      at::_softmax_out(output_temp, converted, dim_, false);
    } else {
      at::_softmax_out(output_, converted, dim_, false);
    }
  }

  if (!output_.is_contiguous()) {
    output_.resize_(output_temp.sizes());
    output_.copy_(output_temp);
  }

  return output_;
}

// special_softmax, alias for softmax
Tensor special_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  return at::softmax(input_, dim_, dtype);
}

Tensor log_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) {
        return at::_log_softmax(input_, dim_, true);
    } else {
        Tensor converted = dtype.has_value() ? input_.toType(dtype.value()) : input_;
        return at::_log_softmax(converted, dim_, false);
    }
  }();
  namedinference::propagate_names(result, input_);
  return result;
}

Tensor& log_softmax_out(
    const Tensor& input_,
    const int64_t dim_,
    std::optional<ScalarType> dtype,
    Tensor& output_) {
  Tensor output_temp;
  if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
      dtype == ScalarType::Float) {
    if (!output_.is_contiguous()) {
      auto options =
          TensorOptions().dtype(output_.dtype()).device(output_.device());
      output_temp = at::empty(output_.sizes(), options);
      at::_log_softmax_out(output_temp, input_, dim_, true);
    } else {
      at::_log_softmax_out(output_, input_, dim_, true);
    }
  } else {
    Tensor converted =
        dtype.has_value() ? input_.toType(dtype.value()) : input_;
    if (!output_.is_contiguous()) {
      auto options =
          TensorOptions().dtype(output_.dtype()).device(output_.device());
      output_temp = at::empty(output_.sizes(), options);
      at::_log_softmax_out(output_temp, converted, dim_, false);
    } else {
      at::_log_softmax_out(output_, converted, dim_, false);
    }
  }

  if (!output_.is_contiguous()) {
    output_.resize_(output_temp.sizes());
    output_.copy_(output_temp);
  }

  return output_;
}

Tensor special_log_softmax(const Tensor& input, const int64_t dim, std::optional<ScalarType> dtype) {
  return at::log_softmax(input, dim, dtype);
}

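// Dispatch stubs for the CPU softmax kernels; the implementations behind them
// are registered in aten/src/ATen/native/cpu/SoftmaxKernel.cpp.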
DEFINE_DISPATCH(softmax_lastdim_kernel);
DEFINE_DISPATCH(log_softmax_lastdim_kernel);
DEFINE_DISPATCH(softmax_backward_lastdim_kernel);
DEFINE_DISPATCH(log_softmax_backward_lastdim_kernel);

DEFINE_DISPATCH(softmax_kernel);
DEFINE_DISPATCH(log_softmax_kernel);
DEFINE_DISPATCH(softmax_backward_kernel);
DEFINE_DISPATCH(log_softmax_backward_kernel);

Tensor softmax(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
  return at::softmax(self, dimname_to_position(self, dim), dtype);
}

Tensor log_softmax(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
  return at::log_softmax(self, dimname_to_position(self, dim), dtype);
}

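// _masked_softmax CPU entry point. mask_type selects how the boolean mask is
// interpreted (a true entry means "masked out", i.e. excluded from softmax):
//   0: src_mask of shape (L, L) applied to a 4-d (B, H, L, L) input
//   1: src_key_padding_mask of shape (B, L) applied to a 4-d (B, H, L, L) input
//   2: generic mask with exactly the same shape as the input
// Types 0 and 1 are kept only when softmax runs over the last dimension of a
// 4-d input; otherwise the mask is expanded to the input shape and handled as
// type 2. A rough usage sketch (tensor names and shapes are illustrative
// assumptions, not checked here):
//   auto out = at::_masked_softmax(
//       scores /* (B, H, L, L) */, mask /* (L, L), true = masked */,
//       /*dim=*/scores.dim() - 1, /*mask_type=*/0);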
Tensor masked_softmax_cpu(const Tensor& input_, const Tensor& mask_, const std::optional<int64_t> dim_, const std::optional<int64_t> mask_type_) {

  auto mask = mask_.contiguous();
  auto mask_type = mask_type_; // Mask type might get transformed below

  TORCH_CHECK(
      mask_.scalar_type() == ScalarType::Bool,
      "Mask should be a boolean tensor");

  if ((mask.dim() != 2) || (input_.dim() != 4)) {
    // Mask types 0 and 1 are only allowed for 2D masks and 4D inputs
    mask_type = 2;
  }

  if (mask_type == 2) {
      TORCH_CHECK(input_.sizes() == mask.sizes(),
                  "For mask_type == 2 mask shape should match input shape");
  } else if (mask_type == 1) {
      // Padding mask of shape (B, L)
      TORCH_CHECK((input_.sizes()[0] == mask.sizes()[0]) && (input_.sizes()[2] == mask.sizes()[1]),
                  "For mask_type == 1 mask shape should be (B, L)");
      if (dim_ != input_.dim() - 1) {
            // We only process the padding mask in the optimized way if softmax is applied along the last dimension,
            // otherwise we need to expand the mask into a generic 4D one
            mask = mask_.view({input_.sizes()[0], 1, 1, input_.sizes()[2]});
            mask = mask.expand(input_.sizes()).contiguous();
            mask_type = 2;
      }
  } else if (mask_type == 0) {
      // Attention mask of shape (L, L)
      TORCH_CHECK((mask.dim() == 2) && (input_.sizes()[2] == mask.sizes()[0]) && (input_.sizes()[2] == mask.sizes()[1]),
                  "For mask_type == 0 mask shape should be (L, L)");
      if (dim_ != input_.dim() - 1) {
            // We only process the attention mask in an optimized way if softmax is applied along the last dimension,
            // otherwise we need to expand the mask into a generic 4D one
            mask = mask.view({1, 1, input_.sizes()[2], input_.sizes()[2]});
            mask = mask.expand(input_.sizes()).contiguous();
            mask_type = 2;
      }
  }

  Tensor output = at::empty_like(input_, input_.options());
  auto input = input_.contiguous();
  int64_t dim = dim_.has_value() ? dim_.value() : input.dim() - 1;
  dim = maybe_wrap_dim(dim, input_.dim());

  if (input.dim() == 0) {
    input = input.view(1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "masked_softmax", [&] {
        host_softmax<
            scalar_t,
            false /* LogSoftMax */,
            true /* MaskedSoftMax */>(
            output, input, dim, mask.data_ptr<bool>(), mask_type);
      });
  return output;
}

Tensor masked_softmax_backward_cpu(
    const Tensor& grad_,
    const Tensor& output_,
    const Tensor& mask_,
    const std::optional<int64_t> dim_) {
  TORCH_CHECK(
      grad_.sizes() == mask_.sizes(), "Mask shape should match grad shape");
  TORCH_CHECK(
      mask_.scalar_type() == ScalarType::Bool,
      "Mask should be a boolean tensor");
  auto grad = grad_.contiguous();
  auto output = output_.contiguous();
  auto mask = mask_.contiguous();

  int64_t dim = dim_.has_value() ? dim_.value() : output.dim() - 1;
  dim = maybe_wrap_dim(dim, grad.dim());

  grad = grad.dim() == 0 ? grad.view(1) : grad;
  output = output.dim() == 0 ? output.view(1) : output;
  mask = mask.dim() == 0 ? mask.view(1) : mask;

  Tensor grad_input = at::empty_like(grad, grad.options());
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, grad.scalar_type(), "masked_softmax_backward", [&] {
        host_softmax_backward<
            scalar_t,
            false /* LogSoftMax */,
            true /* MaskedSoftmax */>(grad_input, grad, output, dim, mask.data_ptr<bool>());
      });
  return grad_input;
}
} // namespace at::native