#pragma once

#include <ATen/Dispatch.h>
#include <torch/nn/functional/dropout.h>
#include <torch/nn/functional/linear.h>
#include <torch/nn/options/activation.h>
#include <torch/nn/options/dropout.h>
#include <torch/nn/options/linear.h>
#include <torch/types.h>
#include <limits>
#include <utility>

namespace torch {
namespace nn {
namespace functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor elu(Tensor input, double alpha, bool inplace) {
  if (inplace) {
    return torch::elu_(input, alpha);
  } else {
    return torch::elu(input, alpha);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.elu
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::ELUFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::elu(x, F::ELUFuncOptions().alpha(0.42).inplace(true));
/// ```
inline Tensor elu(Tensor input, const ELUFuncOptions& options = {}) {
  return detail::elu(std::move(input), options.alpha(), options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor selu(Tensor input, bool inplace) {
  if (inplace) {
    return torch::selu_(input);
  } else {
    return torch::selu(input);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.selu
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::SELUFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::selu(input, F::SELUFuncOptions(false));
/// ```
inline Tensor selu(Tensor input, const SELUFuncOptions& options = {}) {
  return detail::selu(std::move(input), options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor hardshrink(const Tensor& input, double lambda) {
  return torch::hardshrink(input, lambda);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.hardshrink
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::HardshrinkFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::hardshrink(x, F::HardshrinkFuncOptions().lambda(0.42));
/// ```
inline Tensor hardshrink(
    const Tensor& input,
    const HardshrinkFuncOptions& options = {}) {
  return detail::hardshrink(input, options.lambda());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor hardtanh(
    Tensor input,
    double min_val,
    double max_val,
    bool inplace) {
  if (inplace) {
    return torch::hardtanh_(input, min_val, max_val);
  } else {
    return torch::hardtanh(input, min_val, max_val);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.hardtanh
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::HardtanhFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::hardtanh(x,
/// F::HardtanhFuncOptions().min_val(-1.0).max_val(1.0).inplace(true));
/// ```
inline Tensor hardtanh(Tensor input, const HardtanhFuncOptions& options = {}) {
  return detail::hardtanh(
      std::move(input),
      options.min_val(),
      options.max_val(),
      options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor leaky_relu(Tensor input, double negative_slope, bool inplace) {
  if (inplace) {
    return torch::leaky_relu_(input, negative_slope);
  } else {
    return torch::leaky_relu(input, negative_slope);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.leaky_relu
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::LeakyReLUFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::leaky_relu(x,
/// F::LeakyReLUFuncOptions().negative_slope(0.42).inplace(true));
/// ```
inline Tensor leaky_relu(
    Tensor input,
    const LeakyReLUFuncOptions& options = {}) {
  return detail::leaky_relu(
      std::move(input), options.negative_slope(), options.inplace());
}

// ============================================================================

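/// Applies the log-sigmoid function element-wise. This functional takes no
/// options object.
///
/// Example (a minimal usage sketch; `x` is assumed to be a floating-point
/// tensor):
/// ```
/// namespace F = torch::nn::functional;
/// auto y = F::logsigmoid(x);
/// ```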
inline Tensor logsigmoid(const Tensor& input) {
  return torch::log_sigmoid(input);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor gumbel_softmax(
    const Tensor& logits,
    double tau,
    bool hard,
    int dim) {
  auto gumbels =
      -torch::empty_like(logits).exponential_().log(); // ~Gumbel(0,1)
  gumbels = (logits + gumbels) / tau; // ~Gumbel(logits, tau)
  auto y_soft = gumbels.softmax(dim);

  torch::Tensor ret;
  if (hard) {
    // Straight through.
    auto index = std::get<1>(y_soft.max(dim, /*keepdim=*/true));
    auto y_hard = torch::zeros_like(logits).scatter_(dim, index, 1.0);
    ret = y_hard - y_soft.detach() + y_soft;
  } else {
    ret = y_soft;
  }
  return ret;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.gumbel_softmax
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::GumbelSoftmaxFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(-1));
/// ```
inline Tensor gumbel_softmax(
    const Tensor& logits,
    const GumbelSoftmaxFuncOptions& options = {}) {
  return detail::gumbel_softmax(
      logits, options.tau(), options.hard(), options.dim());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor softmax(
    const Tensor& input,
    int64_t dim,
    std::optional<torch::Dtype> dtype) {
  Tensor ret;

  if (dtype == std::nullopt) {
    ret = input.softmax(dim);
  } else {
    ret = input.softmax(dim, dtype);
  }

  return ret;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.softmax
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::SoftmaxFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softmax(input, F::SoftmaxFuncOptions(1));
/// ```
inline Tensor softmax(const Tensor& input, const SoftmaxFuncOptions& options) {
  return detail::softmax(input, options.dim(), options.dtype());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor softmin(
    const Tensor& input,
    int64_t dim,
    std::optional<torch::Dtype> dtype) {
  Tensor ret;

  if (dtype == std::nullopt) {
    ret = (-input).softmax(dim);
  } else {
    ret = (-input).softmax(dim, dtype);
  }

  return ret;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.softmin
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::SoftminFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softmin(input, F::SoftminFuncOptions(1));
/// ```
inline Tensor softmin(const Tensor& input, const SoftminFuncOptions& options) {
  return detail::softmin(input, options.dim(), options.dtype());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor log_softmax(
    const Tensor& input,
    int64_t dim,
    std::optional<torch::Dtype> dtype) {
  Tensor ret;

  if (dtype == std::nullopt) {
    ret = input.log_softmax(dim);
  } else {
    ret = input.log_softmax(dim, dtype);
  }

  return ret;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.log_softmax
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::LogSoftmaxFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::log_softmax(input, F::LogSoftmaxFuncOptions(1));
/// ```
inline Tensor log_softmax(
    const Tensor& input,
    const LogSoftmaxFuncOptions& options) {
  return detail::log_softmax(input, options.dim(), options.dtype());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor glu(const Tensor& input, int64_t dim) {
  TORCH_CHECK(
      input.dim() != 0,
349 "glu does not suppport scalars because halving size must be even");
  return torch::glu(input, dim);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.glu
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::GLUFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::glu(input, F::GLUFuncOptions(1));
/// ```
inline Tensor glu(const Tensor& input, const GLUFuncOptions& options = {}) {
  return detail::glu(input, options.dim());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor gelu(const Tensor& input, std::string approximate) {
  return torch::gelu(input, approximate);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

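/// See the documentation for `torch::nn::functional::GELUFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example (a minimal sketch; `"tanh"` is assumed to be an accepted value for
/// `approximate`):
/// ```
/// namespace F = torch::nn::functional;
/// F::gelu(x, F::GELUFuncOptions().approximate("tanh"));
/// ```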
inline Tensor gelu(const Tensor& input, const GELUFuncOptions& options = {}) {
  return detail::gelu(input, options.approximate());
}

// ============================================================================

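/// Applies the Sigmoid Linear Unit (SiLU) function element-wise. This
/// functional takes no options object.
///
/// Example (a minimal usage sketch):
/// ```
/// namespace F = torch::nn::functional;
/// auto y = F::silu(x);
/// ```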
inline Tensor silu(const Tensor& input) {
  return torch::silu(input);
}

// ============================================================================

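/// Applies the Mish activation function element-wise. This functional takes
/// no options object.
///
/// Example (a minimal usage sketch):
/// ```
/// namespace F = torch::nn::functional;
/// auto y = F::mish(x);
/// ```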
inline Tensor mish(const Tensor& input) {
  return torch::mish(input);
}

// ============================================================================

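/// Applies PReLU element-wise, using `weight` as the slope for negative
/// inputs.
///
/// Example (a minimal usage sketch; `weight` is assumed to hold either a
/// single slope or one slope per channel, with the same dtype as `x`):
/// ```
/// namespace F = torch::nn::functional;
/// auto y = F::prelu(x, torch::full({1}, 0.25));
/// ```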
inline Tensor prelu(const Tensor& input, const Tensor& weight) {
  return torch::prelu(input, weight);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor relu(Tensor input, bool inplace) {
  if (inplace) {
    return torch::relu_(input);
  } else {
    return torch::relu(input);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.relu
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::ReLUFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::relu(x, F::ReLUFuncOptions().inplace(true));
/// ```
inline Tensor relu(Tensor input, const ReLUFuncOptions& options = {}) {
  return detail::relu(std::move(input), options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor relu6(Tensor input, bool inplace) {
  if (inplace) {
    return torch::relu6_(input);
  } else {
    return torch::relu6(input);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.relu6
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::ReLU6FuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::relu6(x, F::ReLU6FuncOptions().inplace(true));
/// ```
inline Tensor relu6(Tensor input, const ReLU6FuncOptions& options = {}) {
  return detail::relu6(std::move(input), options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor rrelu(
    Tensor input,
    double lower,
    double upper,
    bool training,
    bool inplace) {
  if (inplace) {
    return torch::rrelu_(input, lower, upper, training);
  } else {
    return torch::rrelu(input, lower, upper, training);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.rrelu
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::RReLUFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.4).inplace(true));
/// ```
inline Tensor rrelu(Tensor input, const RReLUFuncOptions& options = {}) {
  return detail::rrelu(
      std::move(input),
      options.lower(),
      options.upper(),
      options.training(),
      options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor celu(Tensor input, double alpha, bool inplace) {
  if (inplace) {
    return torch::celu_(input, alpha);
  } else {
    return torch::celu(input, alpha);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.celu
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::CELUFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::celu(x, F::CELUFuncOptions().alpha(0.42).inplace(true));
/// ```
inline Tensor celu(Tensor input, const CELUFuncOptions& options = {}) {
  return detail::celu(std::move(input), options.alpha(), options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor softplus(const Tensor& input, double beta, double threshold) {
  return torch::softplus(input, beta, threshold);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.softplus
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::SoftplusFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softplus(x, F::SoftplusFuncOptions().beta(0.5).threshold(3.0));
/// ```
inline Tensor softplus(
    const Tensor& input,
    const SoftplusFuncOptions& options = {}) {
  return detail::softplus(input, options.beta(), options.threshold());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor softshrink(const Tensor& input, double lambda) {
  return torch::softshrink(input, lambda);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.softshrink
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::SoftshrinkFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softshrink(x, F::SoftshrinkFuncOptions(0.42));
/// ```
inline Tensor softshrink(
    const Tensor& input,
    const SoftshrinkFuncOptions& options = {}) {
  return detail::softshrink(input, options.lambda());
}

// ============================================================================

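/// Applies the softsign function, `x / (1 + |x|)`, element-wise. This
/// functional takes no options object.
///
/// Example (a minimal usage sketch):
/// ```
/// namespace F = torch::nn::functional;
/// auto y = F::softsign(x);
/// ```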
inline Tensor softsign(const Tensor& input) {
  return input / (input.abs() + 1);
}

// ============================================================================

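/// Applies the tanhshrink function, `x - tanh(x)`, element-wise. This
/// functional takes no options object.
///
/// Example (a minimal usage sketch):
/// ```
/// namespace F = torch::nn::functional;
/// auto y = F::tanhshrink(x);
/// ```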
inline Tensor tanhshrink(const Tensor& input) {
  return input - input.tanh();
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor threshold(
    Tensor input,
    double threshold,
    double value,
    bool inplace) {
  if (inplace) {
    return torch::threshold_(input, threshold, value);
  } else {
    return torch::threshold(input, threshold, value);
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.threshold
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::ThresholdFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::threshold(x, F::ThresholdFuncOptions(0.5, 0.5).inplace(true));
/// ```
inline Tensor threshold(Tensor input, const ThresholdFuncOptions& options) {
  return detail::threshold(
      std::move(input),
      options.threshold(),
      options.value(),
      options.inplace());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline std::tuple<Tensor, Tensor> multi_head_attention_forward(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    int64_t embed_dim_to_check,
    int64_t num_heads,
    const Tensor& in_proj_weight,
    const Tensor& in_proj_bias,
    const Tensor& bias_k,
    const Tensor& bias_v,
    bool add_zero_attn,
    double dropout_p,
    const Tensor& out_proj_weight,
    const Tensor& out_proj_bias,
    bool training = true,
    const Tensor& key_padding_mask = {},
    bool need_weights = true,
    const Tensor& attn_mask = {},
    bool use_separate_proj_weight = false,
    const Tensor& q_proj_weight = {},
    const Tensor& k_proj_weight = {},
    const Tensor& v_proj_weight = {},
    const Tensor& static_k = {},
    const Tensor& static_v = {},
    bool average_attn_weights = true) {
  namespace F = torch::nn::functional;

  const auto query_sizes = query.sizes();
  const auto& tgt_len = query_sizes[0];
  const auto& bsz = query_sizes[1];
  const auto& embed_dim = query_sizes[2];
  TORCH_INTERNAL_ASSERT(embed_dim == embed_dim_to_check);
  TORCH_INTERNAL_ASSERT(key.sizes() == value.sizes());

  const auto head_dim = embed_dim / num_heads;
  TORCH_CHECK(
      head_dim * num_heads == embed_dim,
      "embed_dim must be divisible by num_heads");
  const auto scaling = 1 / std::sqrt(head_dim);

  Tensor q, k, v;
  if (!use_separate_proj_weight) {
    if (torch::equal(query, key) && torch::equal(key, value)) {
      // self-attention
      const auto chunks =
          F::linear(query, in_proj_weight, in_proj_bias).chunk(3, /*dim=*/-1);
      q = chunks[0];
      k = chunks[1];
      v = chunks[2];
    } else if (torch::equal(key, value)) {
      // encoder-decoder attention
      // This is inline in_proj function with in_proj_weight and in_proj_bias
      auto _b = in_proj_bias;
      auto _start = 0;
      auto _end = embed_dim;
      auto _w = in_proj_weight.slice(/*dim=*/0, _start, _end);
      if (_b.defined()) {
        _b = _b.slice(/*dim=*/0, _start, _end);
      }
      q = F::linear(query, _w, _b);

      if (!key.defined()) {
        TORCH_INTERNAL_ASSERT(!value.defined());
        k.reset();
        v.reset();
      } else {
        // This is inline in_proj function with in_proj_weight and in_proj_bias
        _b = in_proj_bias;
        _start = embed_dim;
        _w = in_proj_weight.slice(/*dim=*/0, _start);
        if (_b.defined()) {
          _b = _b.slice(/*dim=*/0, _start);
        }
        const auto chunks = F::linear(key, _w, _b).chunk(2, /*dim=*/-1);
        k = chunks[0];
        v = chunks[1];
      }
    } else {
      // This is inline in_proj function with in_proj_weight and in_proj_bias
      auto _b = in_proj_bias;
      auto _start = 0;
      auto _end = embed_dim;
      auto _w = in_proj_weight.slice(/*dim=*/0, _start, _end);
      if (_b.defined()) {
        _b = _b.slice(/*dim=*/0, _start, _end);
      }
      q = F::linear(query, _w, _b);

      // This is inline in_proj function with in_proj_weight and in_proj_bias
      _b = in_proj_bias;
      _start = embed_dim;
      _end = embed_dim * 2;
      _w = in_proj_weight.slice(/*dim=*/0, _start, _end);
      if (_b.defined()) {
        _b = _b.slice(/*dim=*/0, _start, _end);
      }
      k = F::linear(key, _w, _b);

      // This is inline in_proj function with in_proj_weight and in_proj_bias
      _b = in_proj_bias;
      _start = embed_dim * 2;
      _w = in_proj_weight.slice(/*dim=*/0, _start);
      if (_b.defined()) {
        _b = _b.slice(0, _start);
      }
      v = F::linear(value, _w, _b);
    }
  } else {
    const auto& q_proj_weight_non_opt = q_proj_weight;
    {
      const auto sizes = q_proj_weight_non_opt.sizes();
      const auto len1 = sizes[0];
      const auto len2 = sizes[1];
      TORCH_CHECK(len1 == embed_dim && len2 == query.size(-1));
    }

    const auto& k_proj_weight_non_opt = k_proj_weight;
    {
      const auto sizes = k_proj_weight_non_opt.sizes();
      const auto len1 = sizes[0];
      const auto len2 = sizes[1];
      TORCH_CHECK(len1 == embed_dim && len2 == key.size(-1));
    }

    const auto& v_proj_weight_non_opt = v_proj_weight;
    {
      const auto sizes = v_proj_weight_non_opt.sizes();
      const auto len1 = sizes[0];
      const auto len2 = sizes[1];
      TORCH_CHECK(len1 == embed_dim && len2 == value.size(-1));
    }

    if (in_proj_bias.defined()) {
      q = F::linear(
          query,
          q_proj_weight_non_opt,
          in_proj_bias.slice(/*dim=*/0, 0, embed_dim));
      k = F::linear(
          key,
          k_proj_weight_non_opt,
          in_proj_bias.slice(/*dim=*/0, embed_dim, (embed_dim * 2)));
      v = F::linear(
          value,
          v_proj_weight_non_opt,
          in_proj_bias.slice(/*dim=*/0, (embed_dim * 2)));
    } else {
      q = F::linear(query, q_proj_weight_non_opt, in_proj_bias);
      k = F::linear(key, k_proj_weight_non_opt, in_proj_bias);
      v = F::linear(value, v_proj_weight_non_opt, in_proj_bias);
    }
  }
  q = q * scaling;
  Tensor attn_mask_ = attn_mask;
  Tensor key_padding_mask_ = key_padding_mask;
  if (bias_k.defined() && bias_v.defined()) {
    if (!static_k.defined() && !static_v.defined()) {
      k = torch::cat({k, bias_k.repeat({1, bsz, 1})});
      v = torch::cat({v, bias_v.repeat({1, bsz, 1})});
      if (attn_mask_.defined()) {
        attn_mask_ = torch::cat(
            {attn_mask_,
             torch::zeros(
                 {attn_mask_.size(0), 1},
                 at::TensorOptions(attn_mask_.dtype())
                     .device(attn_mask_.device()))},
            /*dim=*/1);
      }
      if (key_padding_mask_.defined()) {
        key_padding_mask_ = torch::cat(
            {key_padding_mask_,
             torch::zeros(
                 {key_padding_mask_.size(0), 1},
                 at::TensorOptions(key_padding_mask_.dtype())
                     .device(key_padding_mask_.device()))},
            /*dim=*/1);
      }
    } else {
      TORCH_CHECK(!static_k.defined(), "bias cannot be added to static key.");
      TORCH_CHECK(!static_v.defined(), "bias cannot be added to static value.");
    }
  } else {
    TORCH_CHECK(!bias_k.defined());
    TORCH_CHECK(!bias_v.defined());
  }
  q = q.contiguous().view({tgt_len, bsz * num_heads, head_dim}).transpose(0, 1);
  if (k.defined()) {
    k = k.contiguous().view({-1, bsz * num_heads, head_dim}).transpose(0, 1);
  }
  if (v.defined()) {
    v = v.contiguous().view({-1, bsz * num_heads, head_dim}).transpose(0, 1);
  }
  if (static_k.defined()) {
    TORCH_CHECK(static_k.size(0) == bsz * num_heads);
    TORCH_CHECK(static_k.size(2) == head_dim);
    k = static_k;
  }
  if (static_v.defined()) {
    TORCH_CHECK(static_v.size(0) == bsz * num_heads);
    TORCH_CHECK(static_v.size(2) == head_dim);
    v = static_v;
  }
  auto src_len = k.size(1);
  if (key_padding_mask_.defined()) {
    TORCH_CHECK(key_padding_mask_.size(0) == bsz);
    TORCH_CHECK(key_padding_mask_.size(1) == src_len);
  }
  if (add_zero_attn) {
    src_len += 1;
    auto k_sizes = k.sizes().vec();
    k_sizes[1] = 1;
    k = torch::cat(
        {k,
         torch::zeros(
             k_sizes, at::TensorOptions(k.dtype()).device(k.device()))},
        /*dim=*/1);
    auto v_sizes = v.sizes().vec();
    v_sizes[1] = 1;
    v = torch::cat(
        {v,
         torch::zeros(
             v_sizes, at::TensorOptions(v.dtype()).device(v.device()))},
        /*dim=*/1);
    if (attn_mask_.defined()) {
      attn_mask_ = torch::cat(
          {attn_mask_,
           torch::zeros(
               {attn_mask_.size(0), 1},
               at::TensorOptions(attn_mask_.dtype())
                   .device(attn_mask_.device()))},
          /*dim=*/1);
    }
    if (key_padding_mask_.defined()) {
      key_padding_mask_ = torch::cat(
          {key_padding_mask_,
           torch::zeros(
               {key_padding_mask_.size(0), 1},
               at::TensorOptions(key_padding_mask_.dtype())
                   .device(key_padding_mask_.device()))},
          /*dim=*/1);
    }
  }
  auto attn_output_weights = torch::bmm(q, k.transpose(1, 2));
  TORCH_CHECK(
      attn_output_weights.sizes() ==
      IntArrayRef({bsz * num_heads, tgt_len, src_len}));
  if (attn_mask_.defined()) {
    attn_mask_ = attn_mask_.unsqueeze(0);
    attn_output_weights += attn_mask_;
  }
  if (key_padding_mask_.defined()) {
    attn_output_weights =
        attn_output_weights.view({bsz, num_heads, tgt_len, src_len});
    attn_output_weights = AT_DISPATCH_FLOATING_TYPES(
        attn_output_weights.scalar_type(),
        "attn_output_weights.masked_fill",
        [&]() {
          return attn_output_weights.masked_fill(
              key_padding_mask_.unsqueeze(1).unsqueeze(2),
              -std::numeric_limits<scalar_t>::infinity());
        });
    attn_output_weights =
        attn_output_weights.view({bsz * num_heads, tgt_len, src_len});
  }
  // NOLINTNEXTLINE(bugprone-argument-comment)
  attn_output_weights = F::softmax(attn_output_weights, /*dim=*/-1);
  attn_output_weights = F::dropout(
      attn_output_weights,
      F::DropoutFuncOptions().p(dropout_p).training(training));
  auto attn_output = torch::bmm(attn_output_weights, v);
  TORCH_CHECK(
      attn_output.sizes() == IntArrayRef({bsz * num_heads, tgt_len, head_dim}));
  attn_output =
      attn_output.transpose(0, 1).contiguous().view({tgt_len, bsz, embed_dim});
  attn_output = F::linear(attn_output, out_proj_weight, out_proj_bias);
  if (need_weights) {
    attn_output_weights =
        attn_output_weights.view({bsz, num_heads, tgt_len, src_len});
    if (average_attn_weights) {
      // average attention weights over heads
      attn_output_weights = attn_output_weights.sum(/*dim=*/1) / num_heads;
    }
    return std::make_tuple(attn_output, attn_output_weights);
  } else {
    return std::make_tuple(attn_output, Tensor());
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

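/// Computes scaled dot-product multi-head attention, returning a tuple of the
/// attention output and the (optionally head-averaged) attention weights.
///
/// See the documentation for
/// `torch::nn::functional::MultiheadAttentionForwardFuncOptions` class to
/// learn what arguments are supported for this functional.
///
/// Example (a minimal sketch; `options` is assumed to be a fully populated
/// `MultiheadAttentionForwardFuncOptions` instance):
/// ```
/// namespace F = torch::nn::functional;
/// Tensor attn_output, attn_weights;
/// std::tie(attn_output, attn_weights) =
///     F::multi_head_attention_forward(query, key, value, options);
/// ```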
inline std::tuple<Tensor, Tensor> multi_head_attention_forward(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const MultiheadAttentionForwardFuncOptions& options) {
  return detail::multi_head_attention_forward(
      query,
      key,
      value,
      options.embed_dim_to_check(),
      options.num_heads(),
      options.in_proj_weight(),
      options.in_proj_bias(),
      options.bias_k(),
      options.bias_v(),
      options.add_zero_attn(),
      options.dropout_p(),
      options.out_proj_weight(),
      options.out_proj_bias(),
      options.training(),
      options.key_padding_mask(),
      options.need_weights(),
      options.attn_mask(),
      options.use_separate_proj_weight(),
      options.q_proj_weight(),
      options.k_proj_weight(),
      options.v_proj_weight(),
      options.static_k(),
      options.static_v(),
      options.average_attn_weights());
}

} // namespace functional
} // namespace nn
} // namespace torch