# torch/ao/quantization/_equalize.py
# mypy: allow-untyped-defs
import copy
from typing import Any, Dict

import torch


__all__ = [
    "set_module_weight",
    "set_module_bias",
    "has_bias",
    "get_module_weight",
    "get_module_bias",
    "max_over_ndim",
    "min_over_ndim",
    "channel_range",
    "get_name_by_module",
    "cross_layer_equalization",
    "process_paired_modules_list_to_name",
    "expand_groups_in_paired_modules_list",
    "equalize",
    "converged",
]

_supported_types = {torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d}
_supported_intrinsic_types = {
    torch.ao.nn.intrinsic.ConvReLU2d,
    torch.ao.nn.intrinsic.LinearReLU,
    torch.ao.nn.intrinsic.ConvReLU1d,
}
_all_supported_types = _supported_types.union(_supported_intrinsic_types)


# For fused intrinsic modules (e.g. ConvReLU2d), the underlying conv/linear is
# the first child, so the accessors below index into the container.
def set_module_weight(module, weight) -> None:
    if type(module) in _supported_types:
        module.weight = torch.nn.Parameter(weight)
    else:
        module[0].weight = torch.nn.Parameter(weight)


def set_module_bias(module, bias) -> None:
    if type(module) in _supported_types:
        module.bias = torch.nn.Parameter(bias)
    else:
        module[0].bias = torch.nn.Parameter(bias)


def has_bias(module) -> bool:
    if type(module) in _supported_types:
        return module.bias is not None
    else:
        return module[0].bias is not None


def get_module_weight(module):
    if type(module) in _supported_types:
        return module.weight
    else:
        return module[0].weight


def get_module_bias(module):
    if type(module) in _supported_types:
        return module.bias
    else:
        return module[0].bias
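

# Example (illustrative sketch, not in the original file): for a fused
# intrinsic module, the accessors above read through to the first child:
#
#     import torch.ao.nn.intrinsic as nni
#
#     fused = nni.LinearReLU(torch.nn.Linear(4, 8), torch.nn.ReLU())
#     get_module_weight(fused).shape  # torch.Size([8, 4]), i.e. fused[0].weight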


def max_over_ndim(input, axis_list, keepdim=False):
    """Apply 'torch.max' over the given axes."""
    # Reduce the highest axis first so the remaining axis indices stay valid.
    axis_list.sort(reverse=True)
    for axis in axis_list:
        input, _ = input.max(axis, keepdim)
    return input


def min_over_ndim(input, axis_list, keepdim=False):
    """Apply 'torch.min' over the given axes."""
    # Reduce the highest axis first so the remaining axis indices stay valid.
    axis_list.sort(reverse=True)
    for axis in axis_list:
        input, _ = input.min(axis, keepdim)
    return input
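

# Example (illustrative sketch, not in the original file): reducing a
# (2, 3, 4) tensor over axes 1 and 2 leaves one value per slice along axis 0:
#
#     t = torch.arange(24.0).reshape(2, 3, 4)
#     max_over_ndim(t, [1, 2])  # tensor([11., 23.])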


def channel_range(input, axis=0):
    """Find the range (max - min) of the weights of each channel along the given axis."""
    size_of_tensor_dim = input.ndim
    axis_list = list(range(size_of_tensor_dim))
    axis_list.remove(axis)

    mins = min_over_ndim(input, axis_list)
    maxs = max_over_ndim(input, axis_list)

    assert mins.size(0) == input.size(
        axis
    ), "Dimensions of resultant channel range do not match size of requested axis"
    return maxs - mins
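

# Example (illustrative sketch, not in the original file): for a Conv2d weight
# of shape (out_channels, in_channels, kH, kW), channel_range(w, axis=0)
# reduces over dims 1..3 and yields one range per output channel:
#
#     w = torch.randn(8, 3, 3, 3)
#     channel_range(w, axis=0).shape  # torch.Size([8])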


def get_name_by_module(model, module):
    """Get the name of a module within a model.

    Args:
        model: a model (nn.Module) that equalization is to be applied on
        module: a module within the model

    Returns:
        name: the name of the module within the model
    """
    for name, m in model.named_modules():
        if m is module:
            return name
    raise ValueError("module is not in the model")
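

# Example (illustrative sketch, not in the original file): look up a
# submodule's dotted name as reported by named_modules():
#
#     model = torch.nn.Sequential(torch.nn.Linear(2, 2))
#     get_name_by_module(model, model[0])  # '0'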


def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
    """Scale the weights of two adjacent modules so their channel ranges match.

    Given two adjacent modules, the weights are scaled such that the range of
    each output channel of the first module equals the range of the
    corresponding input channel of the second module.
    """
    if (
        type(module1) not in _all_supported_types
        or type(module2) not in _all_supported_types
    ):
        raise ValueError(
            f"module type not supported: {type(module1)} {type(module2)}"
        )

    conv1_has_bias = has_bias(module1)
    bias = None

    weight1 = get_module_weight(module1)
    weight2 = get_module_weight(module2)

    if weight1.size(output_axis) != weight2.size(input_axis):
        raise TypeError(
            "Number of output channels of first arg does not match "
            "number of input channels of second arg"
        )

    if conv1_has_bias:
        bias = get_module_bias(module1)

    weight1_range = channel_range(weight1, output_axis)
    weight2_range = channel_range(weight2, input_axis)

    # producing the scaling factors to be applied; the epsilon avoids
    # division by zero for dead channels
    weight2_range += 1e-9
    # s = sqrt(r1 / r2) equalizes the ranges: r1 / s == r2 * s == sqrt(r1 * r2)
    # (cf. section 4.1 of https://arxiv.org/abs/1906.04721)
    scaling_factors = torch.sqrt(weight1_range / weight2_range)
    inverse_scaling_factors = torch.reciprocal(scaling_factors)

    if conv1_has_bias:
        bias = bias * inverse_scaling_factors

    # reshape the 1D scaling tensors with singleton dims so they broadcast
    # along the correct axis of each weight tensor
    size1 = [1] * weight1.ndim
    size1[output_axis] = weight1.size(output_axis)
    size2 = [1] * weight2.ndim
    size2[input_axis] = weight2.size(input_axis)

    scaling_factors = torch.reshape(scaling_factors, size2)
    inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)

    weight1 = weight1 * inverse_scaling_factors
    weight2 = weight2 * scaling_factors

    set_module_weight(module1, weight1)
    if conv1_has_bias:
        set_module_bias(module1, bias)
    set_module_weight(module2, weight2)
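

# Example (illustrative sketch, not in the original file): after equalizing
# two adjacent Linear layers, the per-channel ranges match:
#
#     m1, m2 = torch.nn.Linear(4, 8), torch.nn.Linear(8, 6)
#     cross_layer_equalization(m1, m2)
#     r1 = channel_range(m1.weight, axis=0)  # output-channel ranges
#     r2 = channel_range(m2.weight, axis=1)  # input-channel ranges
#     torch.allclose(r1, r2, atol=1e-5)      # True (up to the 1e-9 epsilon)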


def process_paired_modules_list_to_name(model, paired_modules_list):
    """Process a list of paired modules into a list of names of paired modules."""
    for group in paired_modules_list:
        for i, item in enumerate(group):
            if isinstance(item, torch.nn.Module):
                group[i] = get_name_by_module(model, item)
            elif not isinstance(item, str):
                raise TypeError("item must be a nn.Module or a string")
    return paired_modules_list


def expand_groups_in_paired_modules_list(paired_modules_list):
    """Expand module pair groups larger than two into overlapping groups of two modules."""
    new_list = []

    for group in paired_modules_list:
        if len(group) < 2:
            raise ValueError("Group must have at least two modules")
        elif len(group) == 2:
            new_list.append(group)
        else:
            for i in range(len(group) - 1):
                new_list.append([group[i], group[i + 1]])

    return new_list
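

# Example (illustrative sketch, not in the original file): a chain of three
# names expands into overlapping adjacent pairs:
#
#     expand_groups_in_paired_modules_list([["fc1", "fc2", "fc3"]])
#     # -> [["fc1", "fc2"], ["fc2", "fc3"]]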


def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
    """Equalize modules until convergence is achieved.

    Given a list of adjacent modules within a model, equalization is applied
    between each pair; this is repeated until convergence is achieved.

    Keeps a copy of the changing modules from the previous iteration; if the
    copies are not meaningfully different from the current modules (as
    determined by `converged`), then the modules have converged enough that
    further equalizing is not necessary.

    Reference: section 4.1 of https://arxiv.org/pdf/1906.04721.pdf

    Args:
        model: a model (nn.Module) that equalization is to be applied on
        paired_modules_list (List[List[nn.Module or str]]): a list of lists
            where each sublist is a pair of two submodules found in the model.
            The two modules in each pair have to be adjacent in the model,
            with only piece-wise-linear functions like a (P)ReLU or LeakyReLU
            in between, to get the expected results. The list can contain
            either modules or names of modules in the model. If you pass more
            than two modules in the same sublist, all of them will be
            equalized together (pairwise, in order).
        threshold (float): a number used by the converged function to
            determine what degree of similarity between models is necessary
            for them to be called equivalent
        inplace (bool): determines if function is inplace or not
    """
    paired_modules_list = process_paired_modules_list_to_name(
        model, paired_modules_list
    )

    if not inplace:
        model = copy.deepcopy(model)

    paired_modules_list = expand_groups_in_paired_modules_list(paired_modules_list)

    name_to_module: Dict[str, torch.nn.Module] = {}
    previous_name_to_module: Dict[str, Any] = {}
    name_set = {name for pair in paired_modules_list for name in pair}

    for name, module in model.named_modules():
        if name in name_set:
            name_to_module[name] = module
            previous_name_to_module[name] = None

    while not converged(name_to_module, previous_name_to_module, threshold):
        for pair in paired_modules_list:
            previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]])
            previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]])

            cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])

    return model
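

# Example (illustrative sketch, not in the original file): equalize a small
# model in place; the layer names follow nn.Sequential's default numbering:
#
#     model = torch.nn.Sequential(
#         torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
#     )
#     equalize(model, [["0", "2"]], threshold=1e-4)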


def converged(curr_modules, prev_modules, threshold=1e-4):
    """Test whether modules are converged to a specified threshold.

    Tests whether the summed norm of the differences between each pair of
    modules is less than the given threshold.

    Takes two dictionaries mapping names to modules; both must have the same
    set of names. For each name, the difference between the weights of the
    associated modules is accumulated into the summed norm.
    """
    if curr_modules.keys() != prev_modules.keys():
        raise ValueError(
            "The keys to the given mappings must have the same set of names of modules"
        )

    # a None entry means there is no previous iteration to compare against yet
    if None in prev_modules.values():
        return False

    summed_norms = torch.tensor(0.0)
    for name in curr_modules:
        curr_weight = get_module_weight(curr_modules[name])
        prev_weight = get_module_weight(prev_modules[name])

        difference = curr_weight.sub(prev_weight)
        summed_norms += torch.norm(difference)
    return bool(summed_norms < threshold)
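

# Example (illustrative sketch, not in the original file): identical snapshots
# are converged; a None placeholder (no previous iteration yet) is not:
#
#     lin = torch.nn.Linear(2, 2)
#     converged({"fc": lin}, {"fc": copy.deepcopy(lin)})  # True
#     converged({"fc": lin}, {"fc": None})                # False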