xref: /aosp_15_r20/external/pytorch/torch/ao/ns/_numeric_suite.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Any, Callable, Dict, List, Optional, Set, Union
3
4import torch
5import torch.ao.nn.quantized as nnq
6import torch.ao.nn.quantized.dynamic as nnqd
7import torch.nn as nn
8from torch.ao.quantization import prepare
9from torch.ao.quantization.quantization_mappings import (
10    get_default_compare_output_module_list,
11)
12
13
# Module types handed to ``prepare`` as ``observer_non_leaf_module_list`` when
# ``prepare_model_outputs`` attaches loggers to the quantized model.
NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
    nnqd.Linear,
    nnq.Linear,
    nnqd.LSTM,
    nn.LSTM,
}
20
21
def _find_match(
    str_list: Union[Dict[str, Any], List[str]],
    key_str: str,
    postfix: str,
) -> Optional[str]:
    r"""Return the first entry of ``str_list`` that belongs to the same module
    path as ``key_str``.

    ``key_str`` is only considered when its last dotted component equals
    ``postfix``. Its module path (everything before the last component) is
    compared with each candidate's path obtained by dropping the candidate's
    last one or last two components. For ``postfix == "_packed_params"`` a
    second pass drops the last two components of ``key_str`` as well, so that
    "fc._packed_params._packed_params" can be matched with "fc.weight".

    Paths are compared with their "." separators kept, so distinct paths such
    as "a.b" and "ab" are never treated as equal.

    Args:
        str_list: candidate names (iterating a dict yields its keys)
        key_str: dotted name to find a match for
        postfix: required last component of ``key_str``

    Return:
        The first matching candidate in iteration order, or ``None``.
    """
    key_parts = key_str.split(".")
    if key_parts[-1] != postfix:
        return None

    def _first_with_path(module_path: str) -> Optional[str]:
        # A candidate matches when dropping its last one or last two
        # components yields exactly ``module_path``.
        for candidate in str_list:
            parts = candidate.split(".")
            if module_path in (".".join(parts[:-1]), ".".join(parts[:-2])):
                return candidate
        return None

    match = _first_with_path(".".join(key_parts[:-1]))
    if match is not None:
        return match

    # For matching "fc.weight" and "fc._packed_params._packed_params"
    if postfix == "_packed_params":
        module_path = ".".join(key_parts[:-2])
        if not module_path:
            return None
        return _first_with_path(module_path)
    return None
53
54
def compare_weights(
    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare the weights of the float module with its corresponding quantized
    module. Return a dict with key corresponding to module names and each entry being
    a dictionary with two keys 'float' and 'quantized', containing the float and
    quantized weights. This dict can be used to compare and compute the quantization
    error of the weights of float and quantized models.

    Each key of ``quantized_dict`` is tried, in order, as a plain "weight"
    entry, as a "_packed_params" entry, and as an LSTM
    "..._all_weight_values.<layer>.param" entry. Keys that fit none of these
    are skipped.

    Example usage::

        wt_compare_dict = compare_weights(
            float_model.state_dict(), qmodel.state_dict())
        for key in wt_compare_dict:
            print(
                key,
                compute_error(
                    wt_compare_dict[key]['float'],
                    wt_compare_dict[key]['quantized'].dequantize()
                )
            )

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict with key corresponding to module names and each entry being
        a dictionary with two keys 'float' and 'quantized', containing the float and
        quantized weights
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
    weight_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        match_key = _find_match(float_dict, key, "weight")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key]
            continue

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        match_key = _find_match(float_dict, key, "_packed_params")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            # Only the first element of the packed-params entry is recorded.
            weight_dict[key]["quantized"] = quantized_dict[key][0]
            # A key ending in "_packed_params" cannot be an LSTM "param" key.
            continue

        # For LSTM
        split_str = key.split(".")
        # The length guard keeps short keys such as "param" or "x.param" from
        # raising IndexError on split_str[-3].
        if (
            len(split_str) >= 3
            and split_str[-1] == "param"
            and split_str[-3] == "_all_weight_values"
        ):
            layer = split_str[-2]
            module_name = ".".join(split_str[:-3])
            float_weight_ih_key = module_name + ".weight_ih_l" + layer
            float_weight_hh_key = module_name + ".weight_hh_l" + layer
            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
                weight_dict[key] = {}
                # NOTE(review): the weight_ih pair stored here is immediately
                # overwritten by the weight_hh pair below, so only weight_hh
                # is returned for this key. Left as-is because reporting both
                # would change the shape of the returned dict.
                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
                )
                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
                )

    return weight_dict
122
123
def _get_logger_dict_helper(
    mod: nn.Module,
    target_dict: Dict[str, Any],
    prefix: str = "",
) -> None:
    r"""Recursive worker behind get_logger_dict.

    If ``mod`` has a direct child that is a Logger, the stats of the first such
    child are stored under "<prefix>.stats" (or just "stats" for an empty
    prefix). Every direct child is then visited with its own dotted prefix.

    Args:
        mod: module we want to save all logger stats
        target_dict: the dictionary used to save all logger stats (filled in place)
        prefix: prefix for the current module
    """
    # Only the first Logger among the direct children is recorded.
    first_logger = next(
        (child for _, child in mod.named_children() if isinstance(child, Logger)),
        None,
    )
    if first_logger is not None:
        stats_key = "stats" if prefix == "" else prefix + ".stats"
        target_dict[stats_key] = first_logger.stats

    for child_name, child in mod.named_children():
        child_prefix = prefix + "." + child_name if prefix else child_name
        _get_logger_dict_helper(child, target_dict, child_prefix)
148
149
def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
    r"""Walk ``mod`` and collect the stats of every logger found into one dict.
    This is mainly used for quantization accuracy debug.

    Type of loggers supported:
        ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module,
        OutputLogger: used to log the outputs of the modules

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module

    Return:
        A new dictionary mapping "<module path>.stats" to the ``stats`` of the
        logger attached to that module.

    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")

    collected: Dict[str, Dict] = {}
    _get_logger_dict_helper(mod, collected, prefix)
    return collected
171
172
class Logger(nn.Module):
    r"""Base class for stats logging.

    Holds a ``stats`` dict for recorded values. The base ``forward`` records
    nothing and returns ``None``.
    """

    def __init__(self):
        super().__init__()
        # Starts empty; nothing in this base class writes to it.
        self.stats = {}
        # We only insert observer if the op is quantized with static quantization,
        # which is identified by activation_observer.dtype == quint8.  This is needed
        # when attaching Logger as observer for FX mode
        self.dtype = torch.quint8

    def forward(self, x):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
189
190
class ShadowLogger(Logger):
    r"""Logger used inside a Shadow module. Each call appends the output of the
    original module to ``stats["quantized"]`` and the output of the shadow
    module to ``stats["float"]``.
    """

    def __init__(self):
        super().__init__()
        self.stats.update({"float": [], "quantized": []})

    def forward(self, x, y):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        # When an output has more than one element along its first axis, only
        # element 0 is kept.
        # NOTE(review): presumably meant for tuple outputs, but ``len(x) > 1``
        # is also true for a tensor whose first dimension exceeds 1 - confirm.
        quantized_out = x[0] if len(x) > 1 else x
        float_out = y[0] if len(y) > 1 else y
        self.stats["quantized"].append(quantized_out.detach())
        self.stats["float"].append(float_out.detach())
212
213
class OutputLogger(Logger):
    r"""Logger that records every value passed through it in
    ``stats["tensor_val"]`` and returns that value unchanged.
    """

    def __init__(self):
        super().__init__()
        self.stats.update({"tensor_val": []})

    def forward(self, x):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        recorded = self.stats["tensor_val"]
        # The value is stored as received (no copy, no detach).
        recorded.append(x)
        return x
228
229
def _convert_tuple_to_list(t: Any) -> Any:
    r"""Recursively turn tuples into lists. Anything whose type is not exactly
    ``tuple`` (including lists and tuple subclasses) is returned untouched.
    """
    if type(t) is not tuple:
        return t
    return list(map(_convert_tuple_to_list, t))
232
233
def _dequantize_tensor_list(t: Any) -> Any:
    r"""Recursively walk (nested) lists and dequantize every leaf whose
    ``is_quantized`` is true. Other leaves are returned as they are.
    """
    if type(t) is list:
        return [_dequantize_tensor_list(item) for item in t]
    if t.is_quantized:
        return t.dequantize()
    return t
242
243
class Shadow(nn.Module):
    r"""Wraps a quantized module together with its float counterpart (the
    "shadow"). Each call runs the quantized module, runs the shadow on the
    dequantized inputs, passes both outputs to the logger, and returns the
    quantized module's output.

    Args:
        q_module: module quantized from float_module that we want to shadow
        float_module: float module used to shadow q_module
        logger_cls: type of logger used to process the outputs of q_module and
            float_module. ShadowLogger or custom loggers can be used.
    """

    def __init__(self, q_module, float_module, logger_cls):
        super().__init__()
        self.orig_module = q_module
        self.shadow_module = float_module
        # Not referenced by any method of this class; kept as a submodule.
        self.dequant = nnq.DeQuantize()
        self.logger = logger_cls()

    def forward(self, *x) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        inputs = _convert_tuple_to_list(x)
        output = self.orig_module(*inputs)
        float_inputs = _dequantize_tensor_list(inputs)
        self.logger(output, self.shadow_module(*float_inputs))
        return output

    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.add(x, y)
        x_float, y_float = x.dequantize(), y.dequantize()
        self.logger(output, self.shadow_module.add(x_float, y_float))
        return output

    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.add_scalar(x, y)
        x_float = x.dequantize()
        self.logger(output, self.shadow_module.add_scalar(x_float, y))
        return output

    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.mul(x, y)
        x_float, y_float = x.dequantize(), y.dequantize()
        self.logger(output, self.shadow_module.mul(x_float, y_float))
        return output

    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.mul_scalar(x, y)
        x_float = x.dequantize()
        self.logger(output, self.shadow_module.mul_scalar(x_float, y))
        return output

    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.cat(x, dim)
        float_inputs = [tensor.dequantize() for tensor in x]
        self.logger(output, self.shadow_module.cat(float_inputs, dim))
        return output

    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        output = self.orig_module.add_relu(x, y)
        x_float, y_float = x.dequantize(), y.dequantize()
        self.logger(output, self.shadow_module.add_relu(x_float, y_float))
        return output
343
344
def prepare_model_with_stubs(
    float_module: nn.Module,
    q_module: nn.Module,
    module_swap_list: Set[type],
    logger_cls: Callable,
) -> None:
    r"""Attach float modules to their quantized counterparts as shadows.

    Children of ``q_module`` are paired with children of ``float_module`` by
    name. A pair whose float type is not in ``module_swap_list`` is descended
    into recursively. A pair whose float type is in the list is replaced, in
    ``q_module``, by a Shadow module, unless both sides consist of the same
    module types. ``q_module`` is modified in place.

    Example usage::

        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
        q_model(data)
        ob_dict = get_logger_dict(q_model)

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        module_swap_list: list of float module types to attach the shadow
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite.prepare_model_with_stubs"
    )

    float_children = dict(float_module.named_children())

    # Replacements are collected first and installed after the walk, so the
    # children of q_module are not swapped while being iterated.
    shadows = {}
    for name, q_child in q_module.named_children():
        if name not in float_children:
            continue

        float_child = float_children[name]
        if type(float_child) not in module_swap_list:
            prepare_model_with_stubs(float_child, q_child, module_swap_list, logger_cls)
        elif not _is_identical_module_type(q_child, float_child):
            # Insert shadow module only if the module is not of the same type
            # as the floating point module
            shadows[name] = Shadow(q_child, float_child, logger_cls)

    q_module._modules.update(shadows)
394
395
def _is_identical_module_type(mod1, mod2):
    # True when walking ``modules()`` on both arguments yields the same
    # sequence of Python types (same length, same order).
    types_of_mod1 = list(map(type, mod1.modules()))
    types_of_mod2 = list(map(type, mod2.modules()))
    return types_of_mod1 == types_of_mod2
401
402
def compare_model_stub(
    float_model: nn.Module,
    q_model: nn.Module,
    module_swap_list: Set[type],
    *data,
    logger_cls=ShadowLogger,
) -> Dict[str, Dict]:
    r"""Compare quantized modules in a model with their floating point
    counterparts, feeding both the same input. Return a dict with key
    corresponding to module names and each entry being the stats recorded by
    the logger of that module. With the default ShadowLogger these are the
    'float' and 'quantized' outputs, which can be used to compute the module
    level quantization error.

    The function runs three steps: prepare_model_with_stubs() swaps the
    quantized modules selected by ``module_swap_list`` for Shadow modules,
    ``q_model`` is then run once on ``data``, and finally the logger stats are
    gathered from ``q_model``. Note that ``q_model`` is left with the Shadow
    modules installed.

    Example usage::

        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
        ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
        for key in ob_dict:
            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        module_swap_list: list of float module types at which shadow modules will
            be attached.
        data: input data used to run the prepared q_model
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
    # Only the forward pass's logging side effects are needed.
    q_model(*data)
    return get_logger_dict(q_model)
446
447
def get_matching_activations(
    float_module: nn.Module,
    q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Find the matching activation between float and quantized modules.

    Logger stats are collected from both modules. Every quantized entry with
    at least one recorded value in "tensor_val" is paired with the float entry
    that ``_find_match`` selects for it. Quantized entries without a recorded
    value or without a float match are left out.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module

    Return:
        act_dict: dict with key corresponding to quantized module names and each
        entry being a dictionary with two keys 'float' and 'quantized', containing
        the matching float and quantized activations
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite.get_matching_activations"
    )
    float_dict = get_logger_dict(float_module)
    quantized_dict = get_logger_dict(q_module)
    # The candidate order does not depend on the quantized key, so it is
    # computed once rather than re-sorted on every loop iteration.
    float_keys = sorted(float_dict, reverse=True)
    act_dict: Dict[str, Dict] = {}
    for key, quantized_stats in quantized_dict.items():
        quantized_vals = quantized_stats["tensor_val"]
        if len(quantized_vals) == 0:
            continue
        match_key = _find_match(float_keys, key, "stats")
        if match_key is not None:
            act_dict[key] = {
                "float": float_dict[match_key]["tensor_val"],
                "quantized": quantized_vals,
            }
    return act_dict
478
479
def prepare_model_outputs(
    float_module: nn.Module,
    q_module: nn.Module,
    logger_cls=OutputLogger,
    allow_list=None,
) -> None:
    r"""Attach loggers to both the float module and the quantized module for
    the module types in the allow_list.

    Both modules get a QConfig whose activation is ``logger_cls`` and are then
    passed to ``prepare`` in place. The quantized module is additionally
    prepared with NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST as
    ``observer_non_leaf_module_list``.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger; when None, the
            result of get_default_compare_output_module_list() is used
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite.prepare_model_outputs"
    )
    allow_list = (
        get_default_compare_output_module_list() if allow_list is None else allow_list
    )

    logging_qconfig = torch.ao.quantization.QConfig(
        activation=logger_cls, weight=None
    )

    # Float side: default non-leaf handling.
    float_module.qconfig = logging_qconfig  # type: ignore[assignment]
    prepare(
        float_module,
        inplace=True,
        allow_list=allow_list,
        prepare_custom_config_dict={},
    )

    # Quantized side: also pass the non-leaf module list.
    q_module.qconfig = logging_qconfig  # type: ignore[assignment]
    prepare(
        q_module,
        inplace=True,
        allow_list=allow_list,
        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
        prepare_custom_config_dict={},
    )
514
515
def compare_model_outputs(
    float_model: nn.Module,
    q_model: nn.Module,
    *data,
    logger_cls=OutputLogger,
    allow_list=None,
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare output activations between float and quantized models at
    corresponding locations for the same input. Return a dict with key corresponding
    to quantized module names and each entry being a dictionary with two keys
    'float' and 'quantized', containing the activations of quantized model and
    float model at matching locations. This dict can be used to compare and
    compute the propagation quantization error.

    Both models get loggers attached in place, are each run once on ``data``
    (float model first), and the recorded activations are then matched.

    Example usage::

        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
        for key in act_compare_dict:
            print(
                key,
                compute_error(
                    act_compare_dict[key]['float'],
                    act_compare_dict[key]['quantized'].dequantize()
                )
            )

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        data: input data used to run the prepared float_model and q_model
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger

    Return:
        act_compare_dict: dict with key corresponding to quantized module names
        and each entry being a dictionary with two keys 'float' and 'quantized',
        containing the matching float and quantized activations
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite.compare_model_outputs"
    )
    allow_list = (
        get_default_compare_output_module_list() if allow_list is None else allow_list
    )
    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)

    # The forward passes are run only so the attached loggers record values.
    float_model(*data)
    q_model(*data)

    return get_matching_activations(float_model, q_model)
564