1 #pragma once 2 3 #include <torch/csrc/jit/api/module.h> 4 5 namespace torch::jit { 6 7 struct TORCH_API LinearBNParameters { 8 at::Tensor linear_w; 9 at::Tensor linear_b; 10 at::Tensor bn_rm; 11 at::Tensor bn_rv; 12 double bn_eps = 0.0; 13 at::Tensor bn_w; 14 at::Tensor bn_b; 15 }; 16 17 /** 18 * Given the current weight and bias tensors of a Linear module and parameters 19 * of the BatchNorm module we're folding with, compute the updated values 20 * for the weight and bias. 21 * 22 * The function is basically copied from torch/nn/utils/fusion.py 23 */ 24 TORCH_API std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias( 25 const LinearBNParameters& p); 26 27 } // namespace torch::jit 28