# mypy: allow-untyped-defs
import torch
import torch.ao.nn.quantized as nnq
import torch.ao.ns._numeric_suite as ns
import torch.ao.quantization
import torch.nn as nn


__all__ = [
    "get_module",
    "parent_child_names",
    "get_param",
    "MeanShadowLogger",
    "bias_correction",
]

_supported_modules = {nn.Linear, nn.Conv2d}
_supported_modules_quantized = {nnq.Linear, nnq.Conv2d}


def get_module(model, name):
    """Given the name of a submodule, return that submodule from the given model."""
    return dict(model.named_modules())[name]


def parent_child_names(name):
    """Split the full name of a submodule into the parent's full name and the child's name."""
    split_name = name.rsplit(".", 1)
    if len(split_name) == 1:
        return "", split_name[0]
    else:
        return split_name[0], split_name[1]

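# For example, parent_child_names("features.3.conv") returns ("features.3", "conv"),
# and a top-level name like "fc" returns ("", "fc").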

def get_param(module, attr):
    """Get the parameter given a module and an attribute name.

    Sometimes the weight/bias attribute gives you the raw tensor, and sometimes it
    gives you a function that returns the raw tensor; this function handles both cases.
    """
    param = getattr(module, attr, None)
    if callable(param):
        return param()
    else:
        return param

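# Illustrative example: on a float nn.Linear the weight is a plain tensor
# attribute, while on a quantized nnq.Linear it is exposed as a method, e.g.
#
#   get_param(nn.Linear(4, 4), "weight")   # returns the nn.Parameter directly
#   get_param(nnq.Linear(4, 4), "weight")  # calls .weight() on the module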

class MeanShadowLogger(ns.Logger):
    """Mean Logger for a Shadow module.

    A logger for a Shadow module whose purpose is to record the rolling mean
    of the data passed to the floating-point and quantized models.
    """

    def __init__(self):
        """Set up initial values for float and quantized stats, count, float sum, and quant sum."""
        super().__init__()
        self.stats["float"] = None
        self.stats["quantized"] = None
        self.count = 0
        self.float_sum = None
        self.quant_sum = None

    def forward(self, x, y):
        """Compute the rolling average of the quantized and floating-point data.

        The inputs x, y are output data from the quantized and floating-point modules:
        x comes from the quantized module, y from the floating-point module.
        """
        if x.is_quantized:
            x = x.dequantize()

        self.count += 1
        if self.stats["quantized"] is None:
            self.stats["quantized"] = x
            self.quant_sum = x
        else:
            self.quant_sum += x
            self.stats["quantized"] = self.quant_sum / self.count

        if self.stats["float"] is None:
            self.stats["float"] = y
            self.float_sum = y
        else:
            self.float_sum += y
            self.stats["float"] = self.float_sum / self.count

    def clear(self):
        """Reset the recorded statistics and running sums."""
        self.stats["float"] = None
        self.stats["quantized"] = None
        self.count = 0
        self.float_sum = None
        self.quant_sum = None

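# Illustrative sketch of the rolling-mean behavior (values chosen arbitrarily):
#
#   logger = MeanShadowLogger()
#   logger(torch.ones(1, 2), torch.zeros(1, 2))
#   logger(3 * torch.ones(1, 2), torch.zeros(1, 2))
#   # logger.stats["quantized"] is now tensor([[2., 2.]]): the mean of 1 and 3
#   # logger.stats["float"] is tensor([[0., 0.]])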

def bias_correction(
    float_model,
    quantized_model,
    img_data,
    target_modules=_supported_modules_quantized,
    neval_batches=None,
):
103    """Perform bias correction on a module.
104
105    Using numeric suite shadow module, the expected output of the floating point and quantized modules
106    is recorded. Using that data the bias of supported modules is shifted to compensate for the drift caused
107    by quantization
108    Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2)
109
110    Args:
111        float_model: a trained model that serves as a reference to what bias correction should aim for
112        quantized_model: quantized form of float_model that bias correction is to applied to
113        img_data: calibration data to estimate the expected output (used to find quantization error)
114        target_modules: specifies what submodules in quantized_model need bias correction (can be extended to
115                unquantized submodules)
116        neval_batches: a cap to the number of batches you want to be used for estimating the expected output
117    """
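    # Wrap each supported float/quantized submodule pair in a numeric suite
    # Shadow module: every forward pass then routes the same input through both
    # modules and records their outputs in a MeanShadowLogger.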
    ns.prepare_model_with_stubs(
        float_model, quantized_model, _supported_modules, MeanShadowLogger
    )

    uncorrected_modules = {}
    for name, submodule in quantized_model.named_modules():
        if type(submodule) in target_modules:
            uncorrected_modules[name] = submodule

    for uncorrected_module in uncorrected_modules:
        quantized_submodule = get_module(quantized_model, uncorrected_module)
        bias = get_param(quantized_submodule, "bias")
        if bias is not None:
            count = 0
            for data in img_data:
                quantized_model(data[0])
                count += 1
                if count == neval_batches:
                    break
            ob_dict = ns.get_logger_dict(quantized_model)
            parent_name, _ = parent_child_names(uncorrected_module)

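            # After shadow-wrapping, the quantized submodule lives at
            # "<name>.orig_module" inside its Shadow wrapper, so parent_name is
            # the Shadow's name and its logger stats are keyed "<name>.stats".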
            float_data = ob_dict[parent_name + ".stats"]["float"]
            quant_data = ob_dict[parent_name + ".stats"]["quantized"]

            # math for expected_error
            quantization_error = quant_data - float_data
            dims = list(range(quantization_error.dim()))
            # Note: we don't want to take the mean over the output channel dimension
            dims.remove(1)
            expected_error = torch.mean(quantization_error, dims)
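            # e.g. for a conv output of shape (N, C, H, W) this averages over
            # dims [0, 2, 3], leaving a per-channel error of shape (C,) that
            # broadcasts against the per-channel bias below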

            updated_bias = bias.data - expected_error

            bias.data = updated_bias

            # Resets the data contained in the loggers
            for name, submodule in quantized_model.named_modules():
                if isinstance(submodule, MeanShadowLogger):
                    submodule.clear()
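

# Minimal usage sketch, assuming an eager-mode post-training quantization flow;
# `MyModel` and `calibration_batches` (an iterable of (input, target) pairs) are
# hypothetical placeholders supplied by the caller:
#
#   import copy
#
#   float_model = MyModel().eval()
#   quantized_model = copy.deepcopy(float_model)
#   quantized_model.qconfig = torch.ao.quantization.default_qconfig
#   torch.ao.quantization.prepare(quantized_model, inplace=True)
#   for data, _ in calibration_batches:
#       quantized_model(data)
#   torch.ao.quantization.convert(quantized_model, inplace=True)
#
#   bias_correction(float_model, quantized_model, calibration_batches,
#                   neval_batches=len(calibration_batches))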