xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/fold_linear_bn.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <tuple>

#include <torch/csrc/jit/api/module.h>
4 
5 namespace torch::jit {
6 
7 struct TORCH_API LinearBNParameters {
8   at::Tensor linear_w;
9   at::Tensor linear_b;
10   at::Tensor bn_rm;
11   at::Tensor bn_rv;
12   double bn_eps = 0.0;
13   at::Tensor bn_w;
14   at::Tensor bn_b;
15 };
16 
17 /**
18  * Given the current weight and bias tensors of a Linear module and parameters
19  * of the BatchNorm module we're folding with, compute the updated values
20  * for the weight and bias.
21  *
22  * The function is basically copied from torch/nn/utils/fusion.py
23  */
24 TORCH_API std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias(
25     const LinearBNParameters& p);
26 
27 } // namespace torch::jit
28