from __future__ import annotations

import copy
from typing import Optional, Tuple, TypeVar

import torch


__all__ = [
    "fuse_conv_bn_eval",
    "fuse_conv_bn_weights",
    "fuse_linear_bn_eval",
    "fuse_linear_bn_weights",
]

ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd")
LinearT = TypeVar("LinearT", bound="torch.nn.Linear")


def fuse_conv_bn_eval(
    conv: ConvT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
    transpose: bool = False,
) -> ConvT:
    r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module.

    Args:
        conv (torch.nn.modules.conv._ConvNd): A convolutional module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
        transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False.

    Returns:
        torch.nn.modules.conv._ConvNd: The fused convolutional module.

    .. note::
        Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
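
    Example (a minimal sketch; assumes a 2D conv and a BN whose running stats
    have been populated)::

        >>> conv = torch.nn.Conv2d(3, 16, kernel_size=3).eval()
        >>> bn = torch.nn.BatchNorm2d(16).eval()
        >>> fused = fuse_conv_bn_eval(conv, bn)
        >>> x = torch.randn(2, 3, 8, 8)
        >>> torch.allclose(bn(conv(x)), fused(x), atol=1e-6)
        True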
37    """
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    assert bn.running_mean is not None and bn.running_var is not None
    fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
        fused_conv.weight,
        fused_conv.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
        transpose,
    )

    return fused_conv


def fuse_conv_bn_weights(
    conv_w: torch.Tensor,
    conv_b: Optional[torch.Tensor],
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: Optional[torch.Tensor],
    bn_b: Optional[torch.Tensor],
    transpose: bool = False,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.

    Args:
        conv_w (torch.Tensor): Convolutional weight.
        conv_b (Optional[torch.Tensor]): Convolutional bias.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (Optional[torch.Tensor]): BatchNorm weight.
        bn_b (Optional[torch.Tensor]): BatchNorm bias.
        transpose (bool, optional): If True, transpose the conv weight. Defaults to False.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias.
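
    Example (a minimal sketch with raw tensors; the shapes here are
    illustrative, identity BN stats for brevity)::

        >>> conv_w, conv_b = torch.randn(16, 3, 3, 3), torch.randn(16)
        >>> w, b = fuse_conv_bn_weights(
        ...     conv_w, conv_b, torch.zeros(16), torch.ones(16), 1e-5,
        ...     torch.ones(16), torch.zeros(16)
        ... )
        >>> w.shape, b.shape
        (torch.Size([16, 3, 3, 3]), torch.Size([16]))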
80    """
    conv_weight_dtype = conv_w.dtype
    conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

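    # Regular conv weights put out_channels first, while transposed conv
    # weights put them second, so the per-channel BN scale must broadcast
    # along a different dim in each case.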
    if transpose:
        shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
    else:
        shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)

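    # Fold BN into the conv:  W' = W * gamma / sqrt(var + eps)
    #                         b' = (b - mean) * gamma / sqrt(var + eps) + beta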
    fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(
        dtype=conv_weight_dtype
    )
    fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(
        dtype=conv_bias_dtype
    )

    return (
        torch.nn.Parameter(fused_conv_w, conv_w.requires_grad),
        torch.nn.Parameter(fused_conv_b, conv_b.requires_grad),
    )


def fuse_linear_bn_eval(
    linear: LinearT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
) -> LinearT:
    r"""Fuse a linear module and a BatchNorm module into a single, new linear module.

    Args:
        linear (torch.nn.Linear): A Linear module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.

    Returns:
        torch.nn.Linear: The fused linear module.

    .. note::
        Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
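
    Example (a minimal sketch; assumes ``bn.num_features`` matches
    ``linear.out_features``)::

        >>> linear = torch.nn.Linear(8, 4).eval()
        >>> bn = torch.nn.BatchNorm1d(4).eval()
        >>> fused = fuse_linear_bn_eval(linear, bn)
        >>> x = torch.randn(2, 8)
        >>> torch.allclose(bn(linear(x)), fused(x), atol=1e-6)
        True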
124    """
    assert not (linear.training or bn.training), "Fusion only for eval!"
    fused_linear = copy.deepcopy(linear)

128    """
129    Linear-BN needs to be fused while preserving the shapes of linear weight/bias.
130    To preserve the shapes of linear weight/bias, the channel dim of bn needs to be broadcastable with the last dim of linear,
131    because bn operates over the channel dim, (N, C_in, H, W) while linear operates over the last dim, (*, H_in).
132    To be broadcastable, the number of features in bn and
133    the number of output features from linear must satisfy the following condition:
134    1. they are equal, or
135    2. the number of features in bn is 1
136    Otherwise, skip the folding path
137    """
138    assert (
139        linear.out_features == bn.num_features or bn.num_features == 1
140    ), "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"

    assert bn.running_mean is not None and bn.running_var is not None
    fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
        fused_linear.weight,
        fused_linear.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
    )

    return fused_linear


def fuse_linear_bn_weights(
    linear_w: torch.Tensor,
    linear_b: Optional[torch.Tensor],
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: torch.Tensor,
    bn_b: torch.Tensor,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.

    Args:
        linear_w (torch.Tensor): Linear weight.
        linear_b (Optional[torch.Tensor]): Linear bias.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (torch.Tensor): BatchNorm weight.
        bn_b (torch.Tensor): BatchNorm bias.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias.
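
    Example (a minimal sketch with raw tensors; the shapes here are
    illustrative, identity BN stats for brevity)::

        >>> linear_w, linear_b = torch.randn(4, 8), torch.randn(4)
        >>> w, b = fuse_linear_bn_weights(
        ...     linear_w, linear_b, torch.zeros(4), torch.ones(4), 1e-5,
        ...     torch.ones(4), torch.zeros(4)
        ... )
        >>> w.shape, b.shape
        (torch.Size([4, 8]), torch.Size([4]))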
178    """
    linear_weight_dtype = linear_w.dtype
    linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype
    if linear_b is None:
        linear_b = torch.zeros_like(bn_rm)
    bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)

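    # Same folding as the conv case: scale each output row of the weight by the
    # per-feature BN scale and shift the bias by the normalized running mean.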
    fused_w = (linear_w * bn_scale.unsqueeze(-1)).to(dtype=linear_weight_dtype)
    fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype)

    return (
        torch.nn.Parameter(fused_w, linear_w.requires_grad),
        torch.nn.Parameter(fused_b, linear_b.requires_grad),
    )