1 #pragma once 2 3 #include <torch/types.h> 4 5 namespace torch { 6 namespace nn { 7 namespace functional { 8 9 inline Tensor bilinear( 10 const Tensor& input1, 11 const Tensor& input2, 12 const Tensor& weight, 13 const Tensor& bias = Tensor()) { 14 return torch::bilinear(input1, input2, weight, bias); 15 } 16 17 // ============================================================================ 18 19 inline Tensor linear( 20 const Tensor& input, 21 const Tensor& weight, 22 const Tensor& bias = {}) { 23 if (input.dim() == 2 && bias.defined()) { 24 // fused op is marginally faster 25 return torch::addmm(bias, input, weight.t()); 26 } else { 27 auto output = input.matmul(weight.t()); 28 if (bias.defined()) { 29 output += bias; 30 } 31 return output; 32 } 33 } 34 35 } // namespace functional 36 } // namespace nn 37 } // namespace torch 38