#pragma once

#include <torch/csrc/jit/api/module.h>

namespace torch::jit {

/** \brief Fold Conv2d-BatchNorm2d into Conv2d in all methods of this
 * module and all its submodules; `forward` is included by default.
 *
 * The weight and bias of the Conv2d are updated accordingly. Should only be
 * used on modules in eval mode.
 */
TORCH_API Module FoldConvBatchNorm(const Module& module);

struct TORCH_API ConvBNParameters {
  at::Tensor conv_w; // Conv weight
  at::Tensor conv_b; // Conv bias
  at::Tensor bn_rm; // BatchNorm running mean
  at::Tensor bn_rv; // BatchNorm running variance
  double bn_eps = 0.0; // BatchNorm epsilon
  at::Tensor bn_w; // BatchNorm weight (gamma)
  at::Tensor bn_b; // BatchNorm bias (beta)
};

/**
 * Given the current weight and bias tensors of a Conv module and the
 * parameters of the BatchNorm module we're folding with, compute the updated
 * values for the weight and bias.
 *
 * The function is essentially a C++ port of torch/nn/utils/fusion.py.
 */
TORCH_API std::tuple<at::Tensor, at::Tensor> computeUpdatedConvWeightAndBias(
    const ConvBNParameters& p);

} // namespace torch::jit
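
// Example usage (a minimal sketch, not part of the API above; the model path
// and the <torch/script.h> include are illustrative assumptions):
//
//   #include <torch/script.h> // provides torch::jit::load and torch::jit::Module
//
//   torch::jit::Module m = torch::jit::load("model.pt");
//   m.eval(); // folding is only valid with fixed (running) batch-norm stats
//   torch::jit::Module folded = torch::jit::FoldConvBatchNorm(m);
//
// The folded parameters follow the standard BatchNorm-folding algebra
// (as in torch/nn/utils/fusion.py), in terms of the ConvBNParameters fields:
//
//   scale = bn_w / sqrt(bn_rv + bn_eps)
//   new_w = conv_w * scale.reshape({-1, 1, 1, 1})
//   new_b = (conv_b - bn_rm) * scale + bn_b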