#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorIterator.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/cpu/SoftmaxKernel.h>
#include <ATen/NamedTensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_log_softmax.h>
#include <ATen/ops/_log_softmax_backward_data_native.h>
#include <ATen/ops/_log_softmax_native.h>
#include <ATen/ops/_masked_softmax_backward_native.h>
#include <ATen/ops/_masked_softmax_native.h>
#include <ATen/ops/_softmax.h>
#include <ATen/ops/_softmax_backward_data_native.h>
#include <ATen/ops/_softmax_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/log_softmax.h>
#include <ATen/ops/log_softmax_native.h>
#include <ATen/ops/softmax.h>
#include <ATen/ops/softmax_native.h>
#include <ATen/ops/special_log_softmax_native.h>
#include <ATen/ops/special_softmax_native.h>
#endif

#include <c10/core/TensorOptions.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>

namespace at::meta {
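// The meta functions below do shape/dtype inference for the structured
// softmax ops: they wrap and validate `dim`, pick the output dtype
// (promoting to float when `half_to_float` is set, which only the CUDA
// kernels support), and allocate a contiguous output with the same sizes
// as the input.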
TORCH_META_FUNC(_softmax)
(const Tensor& input, const int64_t dim, const bool half_to_float) {
  int64_t dim_ = maybe_wrap_dim(dim, input.dim());

  auto output_options =
      input.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  if (half_to_float) {
    output_options = output_options.dtype(ScalarType::Float);
  }

  int64_t input_dim = input.dim() > 0 ? input.dim() : 1;
  TORCH_CHECK(
      dim_ >= 0 && dim_ < input_dim,
      "dim must be non-negative and less than input dimensions");

  set_output_raw_strided(0, input.sizes(), {}, output_options);
}

TORCH_META_FUNC(_log_softmax) (
    const Tensor& input,
    const int64_t dim,
    const bool half_to_float) {
  int64_t dim_ = maybe_wrap_dim(dim, input.dim());

  auto output_options =
      input.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  if (half_to_float) {
    output_options = output_options.dtype(ScalarType::Float);
  }

  int64_t input_dim = input.dim() > 0 ? input.dim() : 1;
  TORCH_CHECK(
      dim_ >= 0 && dim_ < input_dim,
      "dim must be non-negative and less than input dimensions");

  set_output_raw_strided(0, input.sizes(), {}, output_options);
}

TORCH_META_FUNC(_softmax_backward_data)
(const Tensor& grad,
 const Tensor& output,
 int64_t dim,
 ScalarType input_dtype) {
  TensorArg grad_arg{grad, "grad", 1}, output_arg{output, "output", 2};
  checkSameSize("softmax_backward", grad_arg, output_arg);

  int64_t dim_ = maybe_wrap_dim(dim, grad.dim());

  auto grad_input_options =
      grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  bool half_to_float = grad.scalar_type() != input_dtype;
  if (half_to_float) {
    // The code below is only valid for the CUDA implementation. It's "okay"
    // to put it here because half-to-float conversion is not supported by
    // the CPU implementation of _softmax. There is a TORCH_CHECK in the CUDA
    // implementation that should ideally go here as well, but there is at
    // least one test in which the grad and input dtypes do not match for the
    // CPU implementation of this kernel and it is not true that the grad type
    // is float and the input dtype is half (see #63057).
    if (grad.scalar_type() == ScalarType::Float &&
        input_dtype == ScalarType::Half) {
      grad_input_options = grad_input_options.dtype(ScalarType::Half);
    }
  }

  int64_t grad_dim = grad.dim() > 0 ? grad.dim() : 1;
  TORCH_CHECK(
      dim_ >= 0 && dim_ < grad_dim,
      "dim must be non-negative and less than input dimensions");

  set_output_raw_strided(0, grad.sizes(), {}, grad_input_options);
}

TORCH_META_FUNC(_log_softmax_backward_data)
(const Tensor& grad,
 const Tensor& output,
 int64_t dim,
 ScalarType input_dtype) {
  int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
  TensorOptions grad_input_options(
      grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));

  bool half_to_float = grad.scalar_type() != input_dtype;
  if (half_to_float) {
    // The code below is only valid for the CUDA implementation. It's "okay"
    // to put it here because half-to-float conversion is not supported by
    // the CPU implementation of _softmax. There is a TORCH_CHECK in the CUDA
    // implementation that should ideally go here as well, but there is at
    // least one test in which the grad and input dtypes do not match for the
    // CPU implementation of this kernel and it is not true that the grad type
    // is float and the input dtype is half (see #63057).
    if (grad.scalar_type() == ScalarType::Float &&
        input_dtype == ScalarType::Half) {
      grad_input_options = grad_input_options.dtype(ScalarType::Half);
    }
  }

  int64_t grad_dim = grad.dim() > 0 ? grad.dim() : 1;
  TORCH_CHECK(
      dim_ >= 0 && dim_ < grad_dim,
      "dim must be non-negative and less than input dimensions");

  set_output_raw_strided(0, grad.sizes(), {}, grad_input_options);
}
} // namespace at::meta

namespace at::native {
namespace {

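// host_softmax computes (masked) softmax or log-softmax over `dim` on CPU.
// The contiguous input is treated as a 3-d view (outer_size, dim_size,
// inner_size) and each (outer, inner) slice is reduced independently, in
// parallel:
//   1) find the max along `dim` (skipping masked-out elements in the masked
//      variant) for numerical stability,
//   2) accumulate sum_j exp(x_j - max),
//   3) write exp(x_i - max) / sum for softmax, or x_i - max - log(sum) for
//      log-softmax. Masked-out positions produce 0; a fully masked slice
//      yields NaN because the sum is 0.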
template <typename scalar_t, bool LogSoftMax, bool MaskedSoftMax = false>
void host_softmax(
    Tensor output,
    const Tensor& input,
    const int64_t dim,
    bool* mask = nullptr,
    const std::optional<int64_t> mask_type_ = {}) {

  if (MaskedSoftMax) {
    TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
    int64_t mask_type = mask_type_.value();
    // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
    TORCH_CHECK(
        (mask_type == 0) || (mask_type == 1) || (mask_type == 2),
        "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask), or 2 (default_mask)");
  }

  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  for (const auto i : c10::irange(dim)) {
    outer_size *= input.size(i);
  }
  for (int64_t i = dim + 1; i < input.dim(); ++i) {
    inner_size *= input.size(i);
  }
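  // In the flattened (outer, d, inner) view of the contiguous buffer, element
  // (outer, d, inner) lives at outer * outer_stride + d * dim_stride + inner.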
  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
  scalar_t* input_data_base = input.data_ptr<scalar_t>();
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
  bool* mask_data_base = mask;
  int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
  parallel_for(
      0, outer_size * inner_size, grain_size,
      [&](int64_t begin, int64_t end) __ubsan_ignore_float_divide_by_zero__ {
        for (const auto i : c10::irange(begin, end)) {
          int64_t outer_idx = i / inner_size;
          int64_t inner_idx = i % inner_size;
          scalar_t* input_data =
              input_data_base + outer_idx * outer_stride + inner_idx;
          scalar_t* output_data =
              output_data_base + outer_idx * outer_stride + inner_idx;
          bool* mask_data = nullptr;
          if (MaskedSoftMax) {
            // Process the mask differently depending on the type:
            // for a generic mask of mask_type == 2, the mask shape is the same
            // as the input shape, so indexing is the same.
            auto mask_outer_idx = outer_idx;
            if (mask_type_ == 0) {
              // Optimized case: attention mask of shape LxL.
              // outer_idx goes over BxHxL, mask_outer_idx goes over L.
              mask_outer_idx = outer_idx % input.size(2);
            } else if (mask_type_ == 1) {
              // Optimized case: padding mask of shape BxL.
              // outer_idx goes over BxHxL, mask_outer_idx goes over B.
              mask_outer_idx = outer_idx / (input.size(1) * input.size(2));
            }

            mask_data =
                mask_data_base + mask_outer_idx * outer_stride + inner_idx;
          }

          // Calc max in softmax dim
          bool is_meaningful_max = false;
          scalar_t max_input = input_data[0];
          if (!MaskedSoftMax) {
            for (const auto d : c10::irange(1, dim_size)) {
              max_input = std::max(max_input, input_data[d * dim_stride]);
            }
          } else {
            for (const auto d : c10::irange(0, dim_size)) {
              if (!mask_data[d * dim_stride]) {
                max_input = is_meaningful_max
                    ? std::max(max_input, input_data[d * dim_stride])
                    : input_data[d * dim_stride];
                is_meaningful_max = true;
              }
            }
          }

          // Calc sum in softmax dim
          acc_type<scalar_t, false> tmpsum = 0;
          for (const auto d : c10::irange(dim_size)) {
            scalar_t z{};
            if (!MaskedSoftMax || !mask_data[d * dim_stride]) {
              z = std::exp(input_data[d * dim_stride] - max_input);
            } else {
              z = 0;
            }
            if (!LogSoftMax) {
              output_data[d * dim_stride] = z;
            }
            tmpsum += z;
          }

          if (LogSoftMax) {
            tmpsum = std::log(tmpsum);
          } else if (tmpsum == 0) {
            tmpsum = std::numeric_limits<scalar_t>::quiet_NaN();
          } else {
            tmpsum = 1 / tmpsum;
          }

          // update output
          for (const auto d : c10::irange(dim_size)) {
            // LogSoftMax and MaskedSoftMax should not both be true
            if (LogSoftMax) {
              output_data[d * dim_stride] =
                  input_data[d * dim_stride] - max_input - tmpsum;
            } else {
              output_data[d * dim_stride] *= tmpsum;
            }
          }
        }
      });
}

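// host_softmax_backward computes the CPU (masked) softmax / log-softmax
// gradient over `dim`, using the same (outer, dim, inner) flattening as
// host_softmax. For each slice, with g = grad output and y = forward output:
//   softmax:     grad_input_i = y_i * (g_i - sum_j g_j * y_j)
//   log-softmax: grad_input_i = g_i - exp(y_i) * sum_j g_j
// Masked-out positions get a zero gradient and are excluded from the sum.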
template <typename scalar_t, bool LogSoftMax, bool MaskedSoftMax = false>
void host_softmax_backward(
    const Tensor& gI,
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    bool* mask = nullptr) {

  int64_t outer_size = 1;
  int64_t dim_size = grad.size(dim);
  int64_t inner_size = 1;
  for (const auto i : c10::irange(dim)) {
    outer_size *= grad.size(i);
  }
  for (int64_t i = dim + 1; i < grad.dim(); ++i) {
    inner_size *= grad.size(i);
  }
  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
  scalar_t* gradInput_data_base = gI.data_ptr<scalar_t>();
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
  scalar_t* gradOutput_data_base = grad.data_ptr<scalar_t>();
  bool* mask_data_base = mask;
  int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
  parallel_for(
      0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
        for (const auto i : c10::irange(begin, end)) {
          int64_t outer_idx = i / inner_size;
          int64_t inner_idx = i % inner_size;
          scalar_t* gradInput_data =
              gradInput_data_base + outer_idx * outer_stride + inner_idx;
          scalar_t* output_data =
              output_data_base + outer_idx * outer_stride + inner_idx;
          const scalar_t* gradOutput_data =
              gradOutput_data_base + outer_idx * outer_stride + inner_idx;
          bool* mask_data = nullptr;
          if (MaskedSoftMax) {
            mask_data = mask_data_base + outer_idx * outer_stride + inner_idx;
          }

          acc_type<scalar_t, false> sum = 0;
          for (const auto d : c10::irange(dim_size)) {
            if (!MaskedSoftMax || !mask_data[d * dim_stride]) {
              if (LogSoftMax) {
                sum += gradOutput_data[d * dim_stride];
              } else {
                sum +=
                    gradOutput_data[d * dim_stride] * output_data[d * dim_stride];
              }
            }
          }

          for (const auto d : c10::irange(dim_size)) {
            if (MaskedSoftMax && mask_data[d * dim_stride]) {
              gradInput_data[d * dim_stride] = 0;
            } else if (LogSoftMax) {
              gradInput_data[d * dim_stride] = gradOutput_data[d * dim_stride] -
                  std::exp(output_data[d * dim_stride]) * sum;
            } else {
              gradInput_data[d * dim_stride] = output_data[d * dim_stride] *
                  (gradOutput_data[d * dim_stride] - sum);
            }
          }
        }
      });
}
} // namespace

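// The structured CPU kernels below reject half_to_float (a CUDA-only fast
// path) and dispatch to a specialized contiguous last-dim kernel when the
// softmax dimension is the innermost one, falling back to the generic
// strided kernel otherwise.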
TORCH_IMPL_FUNC(softmax_cpu_out)
(const Tensor& input,
 const int64_t dim,
 const bool half_to_float,
 const Tensor& output) {
  TORCH_CHECK(
      !half_to_float,
      "softmax with half to float conversion is not supported on CPU");

  if (input.numel() == 0) {
    return;
  }

  auto input_ = input.contiguous();
  int64_t dim_ = maybe_wrap_dim(dim, input_.dim());

  if (input_.dim() == 0) {
    input_ = input_.view(1);
  }

  TORCH_CHECK(
      dim_ >= 0 && dim_ < input_.dim(),
      "dim must be non-negative and less than input dimensions");
  if (input_.ndimension() > 0 && dim_ == input_.ndimension() - 1) {
    softmax_lastdim_kernel(kCPU, output, input_);
  } else {
    softmax_kernel(kCPU, output, input_, dim_);
  }
}

TORCH_IMPL_FUNC(log_softmax_cpu_out)
(const Tensor& input,
 const int64_t dim,
 const bool half_to_float,
 const Tensor& output) {
  TORCH_CHECK(
      !half_to_float,
      "softmax with half to float conversion is not supported on CPU");

  if (input.numel() == 0) {
    return;
  }

  auto input_ = input.contiguous();
  int64_t dim_ = maybe_wrap_dim(dim, input_.dim());

  if (input_.dim() == 0) {
    input_ = input_.view(1);
  }

  if (input_.ndimension() > 0 && dim_ == input_.ndimension() - 1) {
    log_softmax_lastdim_kernel(kCPU, output, input_);
  } else {
    log_softmax_kernel(kCPU, output, input_, dim_);
  }
}

TORCH_IMPL_FUNC(softmax_backward_cpu_out)
(const Tensor& grad,
 const Tensor& output,
 int64_t dim,
 ScalarType input_dtype,
 const Tensor& grad_input) {
  int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
  auto grad_ = grad.contiguous();
  auto output_ = output.contiguous();

  if (output.numel() == 0) {
    return;
  }

  if (grad_.dim() == 0) {
    grad_ = grad_.view(1);
  }

  if (output_.dim() == 0) {
    output_ = output_.view(1);
  }

  if (grad_.ndimension() > 0 && dim_ == grad_.ndimension() - 1) {
    softmax_backward_lastdim_kernel(kCPU, grad_input, grad_, output_);
  } else {
    softmax_backward_kernel(kCPU, grad_input, grad_, output_, dim_);
  }
}

TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) (
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    ScalarType input_dtype,
    const Tensor& grad_input) {
  int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
  auto grad_ = grad.contiguous();
  auto output_ = output.contiguous();

  if (output.numel() != 0) {
    if (grad_.dim() == 0) {
      grad_ = grad_.view(1);
    }
    if (output_.dim() == 0) {
      output_ = output_.view(1);
    }
    if (grad_.ndimension() > 0 && dim_ == grad_.ndimension() - 1) {
      log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad_, output_);
    } else {
      log_softmax_backward_kernel(kCPU, grad_input, grad_, output_, dim_);
    }
  }
}

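// Composite softmax op. The CUDA-only fast path feeds half inputs directly to
// _softmax with half_to_float=true (producing a float output without a
// separate cast); otherwise the input is converted to the requested dtype
// first and _softmax runs with half_to_float=false.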
Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
        dtype == ScalarType::Float) {
      return at::_softmax(input_, dim_, true);
    } else {
      Tensor converted =
          dtype.has_value() ? input_.toType(dtype.value()) : input_;
      return at::_softmax(converted, dim_, false);
    }
  }();
  namedinference::propagate_names(result, input_);
  return result;
}

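// out= variant (and log_softmax_out below): if the provided output is not
// contiguous, the kernel writes into a contiguous temporary which is then
// resized and copied back into the user-supplied tensor.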
Tensor& softmax_out(
    const Tensor& input_,
    const int64_t dim_,
    std::optional<ScalarType> dtype,
    Tensor& output_) {
  Tensor output_temp;
  if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
      dtype == ScalarType::Float) {
    if (!output_.is_contiguous()) {
      auto options =
          TensorOptions().dtype(output_.dtype()).device(output_.device());
      output_temp = at::empty(output_.sizes(), options);
      at::_softmax_out(output_temp, input_, dim_, true);
    } else {
      at::_softmax_out(output_, input_, dim_, true);
    }
  } else {
    Tensor converted =
        dtype.has_value() ? input_.toType(dtype.value()) : input_;
    if (!output_.is_contiguous()) {
      auto options =
          TensorOptions().dtype(output_.dtype()).device(output_.device());
      output_temp = at::empty(output_.sizes(), options);
      at::_softmax_out(output_temp, converted, dim_, false);
    } else {
      at::_softmax_out(output_, converted, dim_, false);
    }
  }

  if (!output_.is_contiguous()) {
    output_.resize_(output_temp.sizes());
    output_.copy_(output_temp);
  }

  return output_;
}

// special_softmax, alias for softmax
Tensor special_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  return at::softmax(input_, dim_, dtype);
}

Tensor log_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
        dtype == ScalarType::Float) {
      return at::_log_softmax(input_, dim_, true);
    } else {
      Tensor converted =
          dtype.has_value() ? input_.toType(dtype.value()) : input_;
      return at::_log_softmax(converted, dim_, false);
    }
  }();
  namedinference::propagate_names(result, input_);
  return result;
}

Tensor& log_softmax_out(
    const Tensor& input_,
    const int64_t dim_,
    std::optional<ScalarType> dtype,
    Tensor& output_) {
  Tensor output_temp;
  if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
      dtype == ScalarType::Float) {
    if (!output_.is_contiguous()) {
      auto options =
          TensorOptions().dtype(output_.dtype()).device(output_.device());
      output_temp = at::empty(output_.sizes(), options);
      at::_log_softmax_out(output_temp, input_, dim_, true);
    } else {
      at::_log_softmax_out(output_, input_, dim_, true);
    }
  } else {
    Tensor converted =
        dtype.has_value() ? input_.toType(dtype.value()) : input_;
    if (!output_.is_contiguous()) {
      auto options =
          TensorOptions().dtype(output_.dtype()).device(output_.device());
      output_temp = at::empty(output_.sizes(), options);
      at::_log_softmax_out(output_temp, converted, dim_, false);
    } else {
      at::_log_softmax_out(output_, converted, dim_, false);
    }
  }

  if (!output_.is_contiguous()) {
    output_.resize_(output_temp.sizes());
    output_.copy_(output_temp);
  }

  return output_;
}

Tensor special_log_softmax(const Tensor& input, const int64_t dim, std::optional<ScalarType> dtype) {
  return at::log_softmax(input, dim, dtype);
}

DEFINE_DISPATCH(softmax_lastdim_kernel);
DEFINE_DISPATCH(log_softmax_lastdim_kernel);
DEFINE_DISPATCH(softmax_backward_lastdim_kernel);
DEFINE_DISPATCH(log_softmax_backward_lastdim_kernel);

DEFINE_DISPATCH(softmax_kernel);
DEFINE_DISPATCH(log_softmax_kernel);
DEFINE_DISPATCH(softmax_backward_kernel);
DEFINE_DISPATCH(log_softmax_backward_kernel);

Tensor softmax(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
  return at::softmax(self, dimname_to_position(self, dim), dtype);
}

Tensor log_softmax(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
  return at::log_softmax(self, dimname_to_position(self, dim), dtype);
}

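// CPU entry point for _masked_softmax. Mask elements set to true are excluded
// from the softmax. Mask types:
//   0: src_mask of shape (L, L), shared across batch and heads,
//   1: src_key_padding_mask of shape (B, L),
//   2: a mask with the same shape as the (B, H, L, L) input.
// Types 0 and 1 only take the optimized path for 4D inputs with softmax over
// the last dimension; otherwise the mask is expanded to the input shape and
// treated as type 2.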
Tensor masked_softmax_cpu(const Tensor& input_, const Tensor& mask_, const std::optional<int64_t> dim_, const std::optional<int64_t> mask_type_) {

  auto mask = mask_.contiguous();
  auto mask_type = mask_type_; // Mask type might get transformed below

  TORCH_CHECK(
      mask_.scalar_type() == ScalarType::Bool,
      "Mask should be a boolean tensor");

  if ((mask.dim() != 2) || (input_.dim() != 4)) {
    // Mask types 0 and 1 are only allowed for 2D masks and 4D inputs
    mask_type = 2;
  }

  if (mask_type == 2) {
    TORCH_CHECK(
        input_.sizes() == mask.sizes(),
        "For mask_type == 2 mask shape should match input shape");
  } else if (mask_type == 1) {
    // Padding mask of shape (B, L)
    TORCH_CHECK(
        (input_.sizes()[0] == mask.sizes()[0]) &&
            (input_.sizes()[2] == mask.sizes()[1]),
        "For mask_type == 1 mask shape should be (B, L)");
    if (dim_ != input_.dim() - 1) {
      // We only process the padding mask in the optimized way if softmax is
      // applied along the last dimension; otherwise we need to expand the
      // mask into a generic 4D one
      mask = mask_.view({input_.sizes()[0], 1, 1, input_.sizes()[2]});
      mask = mask.expand(input_.sizes()).contiguous();
      mask_type = 2;
    }
  } else if (mask_type == 0) {
    // Attention mask of shape (L, L)
    TORCH_CHECK(
        (mask.dim() == 2) && (input_.sizes()[2] == mask.sizes()[0]) &&
            (input_.sizes()[2] == mask.sizes()[1]),
        "For mask_type == 0 mask shape should be (L, L)");
    if (dim_ != input_.dim() - 1) {
      // We only process the attention mask in an optimized way if softmax is
      // applied along the last dimension; otherwise we need to expand the
      // mask into a generic 4D one
      mask = mask.view({1, 1, input_.sizes()[2], input_.sizes()[2]});
      mask = mask.expand(input_.sizes()).contiguous();
      mask_type = 2;
    }
  }

  Tensor output = at::empty_like(input_, input_.options());
  auto input = input_.contiguous();
  int64_t dim = dim_.has_value() ? dim_.value() : input.dim() - 1;
  dim = maybe_wrap_dim(dim, input_.dim());

  if (input.dim() == 0) {
    input = input.view(1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "masked_softmax", [&] {
        host_softmax<
            scalar_t,
            false /* LogSoftMax */,
            true /* MaskedSoftMax */>(
            output, input, dim, mask.data_ptr<bool>(), mask_type);
      });
  return output;
}

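// CPU entry point for the _masked_softmax backward: the mask must have the
// same shape as grad, and masked positions receive a zero gradient.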
Tensor masked_softmax_backward_cpu(
    const Tensor& grad_,
    const Tensor& output_,
    const Tensor& mask_,
    const std::optional<int64_t> dim_) {
  TORCH_CHECK(
      grad_.sizes() == mask_.sizes(), "Mask shape should match grad shape");
  TORCH_CHECK(
      mask_.scalar_type() == ScalarType::Bool,
      "Mask should be a boolean tensor");
  auto grad = grad_.contiguous();
  auto output = output_.contiguous();
  auto mask = mask_.contiguous();

  int64_t dim = dim_.has_value() ? dim_.value() : output.dim() - 1;
  dim = maybe_wrap_dim(dim, grad.dim());

  grad = grad.dim() == 0 ? grad.view(1) : grad;
  output = output.dim() == 0 ? output.view(1) : output;
  mask = mask.dim() == 0 ? mask.view(1) : mask;

  Tensor grad_input = at::empty_like(grad, grad.options());
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, grad.scalar_type(), "masked_softmax_backward", [&] {
        host_softmax_backward<
            scalar_t,
            false /* LogSoftMax */,
            true /* MaskedSoftmax */>(grad_input, grad, output, dim, mask.data_ptr<bool>());
      });
  return grad_input;
}
} // namespace at::native