xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/functional/linear.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/types.h>
4 
5 namespace torch {
6 namespace nn {
7 namespace functional {
8 
9 inline Tensor bilinear(
10     const Tensor& input1,
11     const Tensor& input2,
12     const Tensor& weight,
13     const Tensor& bias = Tensor()) {
14   return torch::bilinear(input1, input2, weight, bias);
15 }
16 
17 // ============================================================================
18 
19 inline Tensor linear(
20     const Tensor& input,
21     const Tensor& weight,
22     const Tensor& bias = {}) {
23   if (input.dim() == 2 && bias.defined()) {
24     // fused op is marginally faster
25     return torch::addmm(bias, input, weight.t());
26   } else {
27     auto output = input.matmul(weight.t());
28     if (bias.defined()) {
29       output += bias;
30     }
31     return output;
32   }
33 }
34 
35 } // namespace functional
36 } // namespace nn
37 } // namespace torch
38