xref: /aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/efficient_conv_bn_eval.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3import torch.nn as nn
4from torch._dynamo.utils import counters
5from torch._inductor import config as inductor_config
6from torch.func import functional_call
7
8from ..pattern_matcher import (
9    CallFunctionVarArgs,
10    CallModuleVarArgs,
11    Match,
12    register_graph_pattern,
13)
14from .pre_grad import efficient_conv_bn_eval_pass
15
16
def efficient_conv_bn_eval(
    bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor
):
    """
    Compute ``bn(conv(x))`` from BN's running statistics by folding the
    normalization into the conv weight and bias on the fly.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for **training** as well, but only if one sets `bn.training=False`. It
    reduces memory footprint and computation cost, at the cost of slightly
    reduced numerical stability.

    Args:
        bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module whose
            ``running_var`` is not None.
        conv (nn.modules.conv._ConvNd): a conv module. Transposed convs are
            supported for any ``groups`` value.
        x (torch.Tensor): Input feature map.

    Returns:
        The output of ``conv`` called on ``x`` through ``functional_call``
        with the folded ``weight`` and ``bias`` substituted.
    """

    assert bn.running_var is not None
    running_var = bn.running_var

    # Fall back to neutral values for a conv without bias and a BN without
    # affine transform, so the folding formula below covers every case.
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(running_var)
    bn_weight = bn.weight if bn.weight is not None else torch.ones_like(running_var)
    bn_bias = bn.bias if bn.bias is not None else torch.zeros_like(running_var)

    # shape of [C_out]: the per-output-channel scale applied by BN
    channel_coeff = bn_weight * torch.rsqrt(running_var + bn.eps)

    weight = conv.weight
    if isinstance(conv, nn.modules.conv._ConvTransposeNd):
        # Transposed conv weight is [C_in, C_out // groups, *kernel]: input
        # channels of group g produce output channels
        # [g * C_out // groups, (g + 1) * C_out // groups). View the weight
        # as [groups, C_in // groups, C_out // groups, *kernel] so that every
        # entry lines up with the scale of its output channel.
        groups = conv.groups
        grouped_weight = weight.reshape(
            [groups, weight.shape[0] // groups, -1] + list(weight.shape[2:])
        )
        grouped_coeff = channel_coeff.reshape(
            [groups, 1, -1] + [1] * (weight.ndim - 2)
        )
        weight_on_the_fly = (grouped_weight * grouped_coeff).reshape(weight.shape)
    else:
        # shape of [C_out, 1, 1, 1] in Conv2d; C_out is the leading weight dim
        weight_on_the_fly = weight * channel_coeff.reshape(
            [-1] + [1] * (weight.ndim - 1)
        )

    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + channel_coeff * (conv_bias - bn.running_mean)

    params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly}
    return functional_call(conv, params, x)
75
76
def efficient_conv_bn_eval_decomposed(
    bn_weight,
    bn_bias,
    bn_running_mean,
    bn_running_var,
    bn_eps,
    conv: torch._ops.OpOverload,
    conv_weight,
    conv_bias,
    x,
    conv_remainging_args,
):
    """
    Functional form of the efficient ConvBN evaluation: fold BN's running
    statistics into the conv weight and bias, then call ``conv`` once.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for **training** as well, but only if one sets `bn.training=False`. It
    reduces memory footprint and computation cost, at the cost of slightly
    reduced numerical stability.

    Args:
        bn_weight: BN scale, or None for a BN without affine transform.
        bn_bias: BN shift, or None for a BN without affine transform.
        bn_running_mean: BN running mean.
        bn_running_var: BN running variance; must not be None.
        bn_eps: value added to the variance before the reciprocal sqrt.
        conv: the conv callable, invoked as
            ``conv(x, weight, bias, *conv_remainging_args)``. It is treated
            as a transposed conv when its string form contains
            ``"conv_transpose"``.
        conv_weight: conv weight.
        conv_bias: conv bias, or None for a conv without bias.
        x: Input feature map.
        conv_remainging_args: tuple of the positional conv arguments that
            follow the bias. For transposed convs the entry at index 3 is
            read as ``groups`` (the order stride, padding, output_padding,
            groups, dilation is assumed); it defaults to 1 when absent.

    Returns:
        The output of ``conv`` called with the folded weight and bias.
    """
    assert bn_running_var is not None

    # Fall back to neutral values for a conv without bias and a BN without
    # affine transform, so the folding formula below covers every case.
    if conv_bias is None:
        conv_bias = torch.zeros_like(bn_running_var)
    if bn_weight is None:
        bn_weight = torch.ones_like(bn_running_var)
    if bn_bias is None:
        bn_bias = torch.zeros_like(bn_running_var)

    # shape of [C_out]: the per-output-channel scale applied by BN
    channel_coeff = bn_weight * torch.rsqrt(bn_running_var + bn_eps)

    if "conv_transpose" in str(conv):
        # Transposed conv weight is [C_in, C_out // groups, *kernel]: input
        # channels of group g produce output channels
        # [g * C_out // groups, (g + 1) * C_out // groups). View the weight
        # as [groups, C_in // groups, C_out // groups, *kernel] so that every
        # entry lines up with the scale of its output channel.
        groups = conv_remainging_args[3] if len(conv_remainging_args) > 3 else 1
        grouped_weight = conv_weight.reshape(
            [groups, conv_weight.shape[0] // groups, -1] + list(conv_weight.shape[2:])
        )
        grouped_coeff = channel_coeff.reshape(
            [groups, 1, -1] + [1] * (conv_weight.ndim - 2)
        )
        weight_on_the_fly = (grouped_weight * grouped_coeff).reshape(
            conv_weight.shape
        )
    else:
        # shape of [C_out, 1, 1, 1] in Conv2d; C_out is the leading weight dim
        weight_on_the_fly = conv_weight * channel_coeff.reshape(
            [-1] + [1] * (conv_weight.ndim - 1)
        )

    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + channel_coeff * (conv_bias - bn_running_mean)

    return conv(*((x, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args))
138
139
@register_graph_pattern(
    CallFunctionVarArgs(
        [
            torch.nn.functional.batch_norm,
        ]
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs):
    """
    Replace ``F.batch_norm(conv(...), ...)`` with one fused call node.

    The matched node is a ``torch.nn.functional.batch_norm`` call. When it is
    an eval-mode call whose input comes from a supported conv / linear
    function with no other user, both nodes are replaced by a single
    ``call_function`` node targeting ``efficient_conv_bn_eval_decomposed``.
    The graph is left untouched in every other case.

    Only fully positional calls are rewritten: keyword arguments on either
    node cannot be forwarded to the fused function, so such calls are skipped
    rather than rewritten incorrectly.
    """
    bn_node = match.nodes[0]
    graph = match.graph

    # The indices used below rely on all 8 arguments being positional.
    if len(bn_node.args) != 8 or bn_node.kwargs:
        return

    # We can only use efficient conv-bn for eval mode with track_running_stats
    # bn_node.args[5] is `training`
    if bn_node.args[5]:
        return

    # Check if the input is Conv
    conv_node = bn_node.args[0]
    if conv_node.op != "call_function":  # type: ignore[union-attr]
        return

    supported_convs = (
        torch._C._nn.linear,
        torch.conv1d,
        torch.conv2d,
        torch.conv3d,
        torch.conv_transpose1d,
        torch.conv_transpose2d,
        torch.conv_transpose3d,
    )
    conv_fn = conv_node.target  # type: ignore[union-attr]
    if not any(conv_fn is fn for fn in supported_convs):
        return

    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    # Only positional conv arguments are forwarded to the fused function;
    # dropping keyword arguments (e.g. stride=...) would change the result.
    conv_args = conv_node.args  # type: ignore[union-attr]
    if len(conv_args) < 2 or conv_node.kwargs:  # type: ignore[union-attr]
        return

    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(bn_node):
        # prepare args for the fused function
        fused_args = (
            bn_node.args[3],  # bn weight
            bn_node.args[4],  # bn bias
            bn_node.args[1],  # bn running mean
            bn_node.args[2],  # bn running var
            bn_node.args[7],  # bn eps
            conv_fn,
            conv_args[1],  # conv weight
            conv_args[2] if len(conv_args) >= 3 else None,  # conv bias
            conv_args[0],  # conv input
            conv_args[3:],  # remaining positional conv args
        )

        # create a new node
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval_decomposed,
            args=fused_args,  # type: ignore[arg-type]
            name="efficient_conv_bn_eval",
        )

    # this node replaces the original conv + bn, and therefore
    # should replace the uses of bn_node
    bn_node.replace_all_uses_with(new_node)
    # take care of the deletion order:
    # delete bn_node first, and then conv_node
    graph.erase_node(bn_node)
    graph.erase_node(conv_node)  # type: ignore[arg-type]
229
230
@register_graph_pattern(
    CallFunctionVarArgs(
        [
            torch.ops.aten.batch_norm.default,
        ]
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs):
    """
    Replace ``aten.batch_norm(aten.conv(...), ...)`` with one fused call node.

    The matched node is a ``torch.ops.aten.batch_norm.default`` call. When it
    is an eval-mode call whose input comes from a supported aten conv / linear
    overload with no other user, both nodes are replaced by a single
    ``call_function`` node targeting ``efficient_conv_bn_eval_decomposed``.
    The graph is left untouched in every other case.

    Only fully positional calls are rewritten: keyword arguments on either
    node cannot be forwarded to the fused function, so such calls are skipped
    rather than rewritten incorrectly.
    """
    bn_node = match.nodes[0]
    graph = match.graph

    # The indices used below rely on all 9 arguments being positional.
    if len(bn_node.args) != 9 or bn_node.kwargs:
        return

    # We can only use efficient conv-bn for eval mode with track_running_stats
    # bn_node.args[5] is `training`
    if bn_node.args[5]:
        return

    # Check if the input is Conv
    conv_node = bn_node.args[0]
    if conv_node.op != "call_function":  # type: ignore[union-attr]
        return

    supported_convs = (
        torch.ops.aten.linear.default,
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv3d.default,
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d.input,
        torch.ops.aten.conv_transpose3d.input,
    )
    conv_fn = conv_node.target  # type: ignore[union-attr]
    if not any(conv_fn is fn for fn in supported_convs):
        return

    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    # Only positional conv arguments are forwarded to the fused function;
    # dropping keyword arguments (e.g. stride=...) would change the result.
    conv_args = conv_node.args  # type: ignore[union-attr]
    if len(conv_args) < 2 or conv_node.kwargs:  # type: ignore[union-attr]
        return

    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(bn_node):
        # prepare args for the fused function
        fused_args = (
            bn_node.args[1],  # bn weight
            bn_node.args[2],  # bn bias
            bn_node.args[3],  # bn running mean
            bn_node.args[4],  # bn running var
            bn_node.args[7],  # bn eps
            conv_fn,
            conv_args[1],  # conv weight
            conv_args[2] if len(conv_args) >= 3 else None,  # conv bias
            conv_args[0],  # conv input
            conv_args[3:],  # remaining positional conv args
        )

        # create a new node
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval_decomposed,
            args=fused_args,  # type: ignore[arg-type]
            name="efficient_conv_bn_eval",
        )

    # this node replaces the original conv + bn, and therefore
    # should replace the uses of bn_node
    bn_node.replace_all_uses_with(new_node)
    # take care of the deletion order:
    # delete bn_node first, and then conv_node
    graph.erase_node(bn_node)
    graph.erase_node(conv_node)  # type: ignore[arg-type]
320
321
@register_graph_pattern(
    CallModuleVarArgs(
        [
            nn.modules.batchnorm._BatchNorm,
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            nn.SyncBatchNorm,
        ],
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
    """
    Replace a ``call_module`` conv followed by a ``call_module`` BN with one
    fused call node.

    The matched node is a BatchNorm module call. When the BN is in eval mode
    with tracked running stats, and its input comes from a supported conv /
    linear module call with no other user, both nodes are replaced by a single
    ``call_function`` node targeting ``efficient_conv_bn_eval``. That node
    receives the two modules through ``get_attr`` nodes plus the conv input.
    The graph is left untouched in every other case.

    The conv module must be called with exactly one argument (its input):
    the fused function forwards nothing else, so calls carrying extra
    arguments are skipped rather than rewritten incorrectly.
    """
    # We matched a BN node
    bn_node = match.nodes[0]
    graph = match.graph
    gm = graph.owning_module
    bn_mod = getattr(gm, bn_node.target)  # type: ignore[arg-type]

    # We can only use efficient conv-bn for eval mode with track_running_stats
    if not bn_mod.track_running_stats or bn_mod.training:
        return

    # Check if the input is Conv
    if bn_node.args:
        input_node = bn_node.args[0]
    else:
        input_node = bn_node.kwargs["input"]
    if input_node.op != "call_module":  # type: ignore[union-attr]
        return
    if not hasattr(gm, input_node.target):  # type: ignore[arg-type, union-attr]
        return
    input_mod = getattr(gm, input_node.target)  # type: ignore[arg-type, union-attr]
    supported_convs = (
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    )
    if not isinstance(input_mod, supported_convs):
        return
    conv_node = input_node
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    # The fused function only receives the conv input. Any additional call
    # argument (e.g. the optional `output_size` of transposed convs) would be
    # silently dropped, so leave such calls alone.
    conv_call_args = conv_node.args  # type: ignore[union-attr]
    conv_call_kwargs = conv_node.kwargs  # type: ignore[union-attr]
    if len(conv_call_args) + len(conv_call_kwargs) != 1:
        return
    if conv_call_args:
        conv_input = conv_call_args[0]
    elif "input" in conv_call_kwargs:
        conv_input = conv_call_kwargs["input"]
    else:
        return

    # Find a pair of conv and bn computation nodes to optimize.
    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(conv_node):  # type: ignore[arg-type]
        # create `get_attr` node to access modules
        # note that we directly call `create_node` to fill the `name`
        # argument. `graph.get_attr` and
        # `graph.call_function` does not allow the `name` argument.
        conv_get_node = graph.create_node(
            op="get_attr", target=conv_node.target, name="get_conv"  # type: ignore[union-attr]
        )
        bn_get_node = graph.create_node(
            op="get_attr", target=bn_node.target, name="get_bn"
        )
        # create a new node calling the fused function
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval,
            args=(bn_get_node, conv_get_node, conv_input),
            name="efficient_conv_bn_eval",
        )
    # this node replaces the original conv + bn, and therefore
    # should replace the uses of bn_node
    bn_node.replace_all_uses_with(new_node)
    # take care of the deletion order:
    # delete bn_node first, and then conv_node
    graph.erase_node(bn_node)
    graph.erase_node(conv_node)  # type: ignore[arg-type]
407