xref: /aosp_15_r20/external/pytorch/torch/ao/ns/_numeric_suite_fx.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""
3This module contains tooling to compare weights and activations
4across models. Example usage::
5
6    import copy
7    import torch
8    import torch.ao.quantization.quantize_fx as quantize_fx
9    import torch.ao.ns._numeric_suite_fx as ns
10
11    m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
12    mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
13    # We convert a copy because we need the original prepared model
14    # to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
15    mq = quantize_fx.convert_fx(copy.deepcopy(mp))
16
17    #
18    # Comparing weights
19    #
20
21    # extract weight pairs
22    weight_comparison = ns.extract_weights('a', mp, 'b', mq)
23
24    # add SQNR for each comparison, inplace
25    ns.extend_logger_results_with_comparison(
26        weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
27        'sqnr')
28
29    # weight_comparison contains the weights from `mp` and `mq` stored
30    # in pairs, and can be used for further analysis.
31
32
33    #
34    # Comparing activations, with error propagation
35    #
36
37    # add loggers
38    mp_ns, mq_ns = ns.add_loggers(
39        'a', copy.deepcopy(mp),
40        'b', copy.deepcopy(mq),
41        ns.OutputLogger)
42
43    # send an example datum to capture intermediate activations
44    datum = torch.randn(1, 1, 1, 1)
45    mp_ns(datum)
46    mq_ns(datum)
47
48    # extract intermediate activations
49    act_comparison = ns.extract_logger_info(
50        mp_ns, mq_ns, ns.OutputLogger, 'b')
51
52    # add SQNR for each comparison, inplace
53    ns.extend_logger_results_with_comparison(
54        act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
55        'sqnr')
56
57    # act_comparison contains the activations from `mp_ns` and `mq_ns` stored
58    # in pairs, and can be used for further analysis.
59
60    #
61    # Comparing activations, without error propagation
62    #
63
64    # create shadow model
65    mp_shadows_mq = ns.add_shadow_loggers(
66        'a', copy.deepcopy(mp),
67        'b', copy.deepcopy(mq),
68        ns.OutputLogger)
69
70    # send an example datum to capture intermediate activations
71    datum = torch.randn(1, 1, 1, 1)
72    mp_shadows_mq(datum)
73
74    # extract intermediate activations
75    shadow_act_comparison = ns.extract_shadow_logger_info(
76        mp_shadows_mq, ns.OutputLogger, 'b')
77
78    # add SQNR for each comparison, inplace
79    ns.extend_logger_results_with_comparison(
80        shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
81        'sqnr')
82
    # shadow_act_comparison contains the activations of models 'a' and 'b',
    # as logged by `mp_shadows_mq`, stored in pairs, and can be used for
    # further analysis.
85
86"""
87
88import collections
89from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
90
91import torch
92import torch.ao.quantization.quantize_fx as quantize_fx
93import torch.nn as nn
94from torch.ao.ns.fx.graph_matcher import (
95    get_matching_subgraph_pairs,
96    get_type_a_related_to_b,
97)
98from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops
99from torch.ao.ns.fx.n_shadows_utils import (
100    _get_dedup_subgraphs,
101    create_add_loggers_graph,
102    create_n_transformed_and_logged_copies_of_subgraph,
103    create_results_comparison,
104    extract_weight_comparison,
105    group_results_by_subgraph,
106    OutputProp,
107    print_n_shadows_summary,
108    SHADOW_WRAPPER_NODE_NAME_PREFIX,
109)
110from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
111from torch.ao.quantization import QConfigMapping
112from torch.ao.quantization.backend_config import BackendConfig
113from torch.ao.quantization.backend_config.utils import (
114    get_fusion_pattern_to_root_node_getter,
115)
116from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
117from torch.ao.quantization.fx.match_utils import _find_matches
118from torch.ao.quantization.fx.qconfig_mapping_utils import (
119    _generate_node_name_to_qconfig,
120)
121from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
122from torch.fx import GraphModule
123from torch.fx.graph import Node
124
125from .fx.graph_passes import add_loggers_to_model, create_a_shadows_b
126from .fx.ns_types import NSNodeTargetType, NSResultsType, NSSingleResultValuesType
127from .fx.utils import (
128    get_target_type_str,
129    maybe_add_missing_fqns,
130    rekey_logger_info_on_node_name_of_model,
131)
132from .fx.weight_utils import extract_weight_from_node
133
134
135if TYPE_CHECKING:
136    from torch.ao.quantization.qconfig import QConfigAny
137
# Structure of the tuple values that `OutputLogger` stores in `stats_rnn`:
# a tensor paired with a 2-tuple of tensors.
# NOTE(review): presumably the (output, (h_n, c_n)) result of an LSTM - confirm.
RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
139
140
class OutputLogger(nn.Module):
    """
    Pass-through module which records the values flowing through it.

    Tensors are appended (detached) to ``stats``; tuples shaped like
    ``(tensor, (tensor, tensor))`` are appended (detached) to ``stats_rnn``.
    The input is always returned unchanged. This is the base class for the
    other loggers in this file.
    """

    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    # Flag the module as impure so that DCE keeps the calls to it.
    _is_impure = True

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
        qconfig_str: Optional[str] = "",
    ):
        super().__init__()
        # NOTE: `__repr__` prints `__dict__` in insertion order, so the order
        # of the attribute assignments below is part of the observable output.
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # Name of the node responsible for this logger being added:
        # - when logging node outputs, it equals prev_node_name
        # - when logging node inputs, it is the node whose input is logged
        #
        # Example, with logger1 on the input of op1 and logger2 on the
        # output of op1:
        #
        #  x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # here,
        #   - logger1 has prev_node_name == x1 and ref_node_name == op1
        #   - logger2 has prev_node_name == op1 and ref_node_name == op1
        self.ref_node_name = ref_node_name
        # Name of the node whose output is captured by this logger.
        self.prev_node_name = prev_node_name

        # Name of the model the node came from.
        self.model_name = model_name
        # Reference name; loggers from separate models are matched to each
        # other through it.
        self.ref_name = ref_name
        # Target type of the node whose output is being logged.
        self.prev_node_target_type = prev_node_target_type
        # Target type of the node responsible for this logger being added.
        self.ref_node_target_type = ref_node_target_type
        # Kind of values stored in stats.
        self.results_type = results_type
        # Position of this node inside one arg of the input/output node,
        # e.g. in cat([x1, x2, x3], dim=0), x2 has index_within_arg == 1.
        self.index_within_arg = index_within_arg
        # Position of this node among the args of the input/output node,
        # e.g. in add(x1, x2), x2 has index_of_arg == 1.
        self.index_of_arg = index_of_arg
        # Fully qualified name.
        self.fqn = fqn
        # Switch controlling whether the logger collects data at all. Useful
        # when loggers are added before prepare_fx but only the results after
        # convert_fx (not the calibration results) are wanted.
        self.enabled = True
        # String representation of the qconfig.
        self.qconfig_str = qconfig_str
        # Can be switched off to reduce memory usage during calibration.
        self.save_activations = True

    # Note: the type of x is left unannotated because TorchScript does not
    #   support the Union type.
    def forward(self, x):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        # TODO(future PR): consider designing this better, as the difference
        # between these two flags is subtle and not obvious.
        if not (self.enabled and self.save_activations):
            return x
        # TODO(future PR): consider refactoring this to better reuse the parent
        # class
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            detached_entry = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(detached_entry)
        return x

    def __repr__(self):
        # Drop nn.Module bookkeeping ("training" and underscore-prefixed keys).
        public_attrs = {}
        for attr_name, attr_value in self.__dict__.items():
            if attr_name == "training" or attr_name.startswith("_"):
                continue
            public_attrs[attr_name] = attr_value
        return f"OutputLogger({public_attrs})"
247
248
class OutputComparisonLogger(OutputLogger):
    """
    Variant of OutputLogger whose forward also receives the reference
    activation, so that the comparison between the two can be computed
    at calibration time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO(future PR): make the comparison function configurable
        self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr
        self.comparison_fn_name = "sqnr"
        # Comparison results of the logged value versus the reference,
        # one entry per forward call made while the logger is enabled.
        self.comparisons = []

    def forward(self, x, x_ref):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        if not self.enabled:
            return x
        assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported"
        if self.save_activations:
            # keep the raw activation around, for debugging
            self.stats.append(x.detach())
        # the comparison is recorded regardless of `save_activations`
        comparison = self.comparison_fn(x, x_ref)
        self.comparisons.append(comparison)
        return x

    def __repr__(self):
        # Drop nn.Module bookkeeping ("training" and underscore-prefixed keys).
        public_attrs = {}
        for attr_name, attr_value in self.__dict__.items():
            if attr_name == "training" or attr_name.startswith("_"):
                continue
            public_attrs[attr_name] = attr_value
        return f"OutputComparisonLogger({public_attrs})"
287
288
class NSTracer(quantize_fx.QuantizationTracer):
    """
    FX quantization tracer which additionally treats observer and
    fake_quantize modules as leaf modules.
    """

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        always_leaf_types = (
            torch.ao.quantization.ObserverBase,
            torch.ao.quantization.FakeQuantizeBase,
        )
        if isinstance(m, always_leaf_types):
            return True
        # everything else follows the regular quantization tracer rules
        return super().is_leaf_module(m, module_qualified_name)
305
306
307def _extract_weights_one_model(
308    model_name: str,
309    model: GraphModule,
310    nodes_and_names_to_instrument: List[Tuple[Node, str]],
311    results: NSResultsType,
312    op_to_type_to_weight_extraction_fn: Optional[
313        Dict[str, Dict[Callable, Callable]]
314    ] = None,
315) -> None:
316    torch._C._log_api_usage_once(
317        "quantization_api._numeric_suite_fx._extract_weights_one_model"
318    )
319    for node, ref_name in nodes_and_names_to_instrument:
320        res_type = NSSingleResultValuesType.WEIGHT.value
321        extracted_weight = extract_weight_from_node(
322            node, model, op_to_type_to_weight_extraction_fn
323        )
324        if extracted_weight:
325            if ref_name not in results:
326                results[ref_name] = {res_type: {}}
327            results[ref_name][res_type][model_name] = [extracted_weight]
328
329
def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[
        Dict[str, Dict[Callable, Callable]]
    ] = None,
) -> NSResultsType:
    """
    Match the subgraphs of `gm_a` and `gm_b` and collect the weights of the
    base op node of every matched subgraph, for both models.

    The returned results have missing fqn entries filled in and are rekeyed
    on the node names of model B.
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx._extract_weights_impl"
    )
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
    )

    # one (base op node, match name) list per model
    nodes_and_names_a: List[Tuple[Node, str]] = [
        (subgraph_a.base_op_node, match_name)
        for match_name, (subgraph_a, _) in matched_subgraph_pairs.items()
    ]
    nodes_and_names_b: List[Tuple[Node, str]] = [
        (subgraph_b.base_op_node, match_name)
        for match_name, (_, subgraph_b) in matched_subgraph_pairs.items()
    ]

    # fill the shared results structure, model A first and model B second
    results: NSResultsType = {}
    per_model_inputs = (
        (model_name_a, gm_a, nodes_and_names_a),
        (model_name_b, gm_b, nodes_and_names_b),
    )
    for cur_model_name, cur_gm, cur_nodes_and_names in per_model_inputs:
        _extract_weights_one_model(
            cur_model_name,
            cur_gm,
            cur_nodes_and_names,
            results,
            op_to_type_to_weight_extraction_fn,
        )

    # fill in missing fqn entries
    maybe_add_missing_fqns(results)

    # rekey on names of nodes in gm_b
    return rekey_logger_info_on_node_name_of_model(results, model_name_b)
380
381
def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[
        Dict[str, Dict[Callable, Callable]]
    ] = None,
) -> NSResultsType:
    """
    Extract weights from model A and model B, and return a comparison.

    Both models are traced with `NSTracer` before their weights are
    extracted. If a model carries a `node_name_to_scope` observed graph
    module attribute, it is copied onto the traced module.

    Args:
        model_name_a: string name of model A to use in results
        model_a: model A
        model_name_b: string name of model B to use in results
        model_b: model B
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
        op_to_type_to_weight_extraction_fn: optional override of function which extracts weight
            from a type, subject to change

    Return:
        NSResultsType, containing the weight comparisons
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    # Note: the related-types mapping used to be computed here as well, but
    # its result was never read in this function, so the call was dropped.

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(
        model_a, "node_name_to_scope"
    )
    if maybe_model_a_node_name_to_scope is not None:
        # carry the scope information over to the freshly traced module
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(
        model_b, "node_name_to_scope"
    )
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _extract_weights_impl(
        model_name_a,
        gm_a,
        model_name_b,
        gm_b,
        base_name_to_sets_of_related_ops,
        unmatchable_types_map,
        op_to_type_to_weight_extraction_fn,
    )
441
442
443def _add_loggers_one_model(
444    model_name: str,
445    model: GraphModule,
446    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
447    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
448    logger_cls: Callable,
449) -> nn.Module:
450    torch._C._log_api_usage_once(
451        "quantization_api._numeric_suite_fx._add_loggers_one_model"
452    )
453
454    # TODO(future PR): do not observe nodes we do not care
455    #   about (both fp32, denylist, etc)
456    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
457    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
458    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
459        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
460    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
461        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)
462
463    model = add_loggers_to_model(
464        model,
465        node_to_instrument_inputs_to_ref_name,
466        node_to_instrument_outputs_to_ref_name,
467        logger_cls,
468        model_name,
469    )
470    return model
471
472
def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Match the subgraphs of `gm_a` and `gm_b` and add loggers to both models.

    The end node of every matched subgraph is instrumented for output
    logging. When `should_log_inputs` is set, the start node is instrumented
    for input logging as well. Returns the two instrumented models as a
    `(model_a, model_b)` tuple.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
    )
    inputs_a: List[Tuple[Node, str, str]] = []
    inputs_b: List[Tuple[Node, str, str]] = []
    outputs_a: List[Tuple[Node, str, str]] = []
    outputs_b: List[Tuple[Node, str, str]] = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        per_model = (
            (subgraph_a, gm_a, inputs_a, outputs_a),
            (subgraph_b, gm_b, inputs_b, outputs_b),
        )
        for subgraph, gm, inputs_to_instrument, outputs_to_instrument in per_model:
            ref_node_type = get_target_type_str(subgraph.base_op_node, gm)
            # Note: inputs are matched on start_node, e.g. the input of
            # linear is observed in linear-relu
            if should_log_inputs:
                inputs_to_instrument.append(
                    (subgraph.start_node, match_name, ref_node_type)
                )
            # Note: activations are always matched on end_node, e.g. the
            # output of relu is observed in linear-relu
            outputs_to_instrument.append(
                (subgraph.end_node, match_name, ref_node_type)
            )

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, inputs_a, outputs_a, logger_cls
    )
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, inputs_b, outputs_b, logger_cls
    )
    return (new_model_a, new_model_b)
527
528
def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Instrument model A and model B with loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to log inputs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change

    Return:
        Returns a tuple of (model_a_with_loggers, model_b_with_loggers).  Modifies both models inplace.
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []

    # Trace model A, then model B, each with its own tracer instance.
    traced_models: List[GraphModule] = []
    for model in (model_a, model_b):
        tracer = NSTracer(skipped_module_names, skipped_module_classes)
        traced = GraphModule(model, tracer.trace(model))
        maybe_node_name_to_scope = _get_observed_graph_module_attr(
            model, "node_name_to_scope"
        )
        if maybe_node_name_to_scope is not None:
            # carry the scope information over to the freshly traced module
            traced._node_name_to_scope = maybe_node_name_to_scope
        traced_models.append(traced)
    gm_a, gm_b = traced_models

    return _add_loggers_impl(
        name_a,
        gm_a,
        name_b,
        gm_b,
        logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map,
    )
583
584
585def _extract_logger_info_one_model(
586    model: nn.Module,
587    results: NSResultsType,
588    logger_cls: Callable,
589) -> None:
590    torch._C._log_api_usage_once(
591        "quantization_api._numeric_suite_fx._extract_logger_info_one_model"
592    )
593    for gm_name, mod in model.named_modules():
594        # TODO(future PR): better check when scripted
595        is_logger = isinstance(mod, logger_cls) or (  # type: ignore[arg-type]
596            isinstance(mod, torch.jit.RecursiveScriptModule)
597            and mod.original_name == "OutputLogger"
598        )
599        if is_logger:
600            key = mod.ref_name
601            if key not in results:
602                results[key] = {}
603            assert (
604                mod.model_name not in results[key]
605            ), f"{mod.model_name} is already present in results"
606            if mod.results_type not in results[key]:
607                results[key][mod.results_type] = {}
608            if mod.model_name not in results[key][mod.results_type]:
609                results[key][mod.results_type][mod.model_name] = []
610            stats_to_use = mod.stats
611            if len(mod.stats_rnn) > 0:
612                stats_to_use = mod.stats_rnn
613            data = {
614                "type": mod.results_type,
615                "values": stats_to_use,
616                "ref_node_name": mod.ref_node_name,
617                "ref_node_target_type": mod.ref_node_target_type,
618                "prev_node_name": mod.prev_node_name,
619                "prev_node_target_type": mod.prev_node_target_type,
620                "index_within_arg": mod.index_within_arg,
621                "index_of_arg": mod.index_of_arg,
622                "fqn": mod.fqn,
623                "qconfig_str": mod.qconfig_str,
624            }
625            if hasattr(mod, "comparisons"):
626                data["comparisons"] = mod.comparisons
627                data["comparison_fn_name"] = mod.comparison_fn_name
628            else:
629                data["comparisons"] = []
630                data["comparison_fn_name"] = ""
631            results[key][mod.results_type][mod.model_name].append(data)
632            # ensure the list stays sorted
633            results[key][mod.results_type][mod.model_name].sort(
634                key=lambda res: f"{res['index_of_arg']}:{res['index_within_arg']}"
635            )
636
637
638# TODO(future PR): align on naming
639# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Walk over every logger found in `model_a` and `model_b` and gather
    what they recorded.

    Args:
        model_a: model A
        model_b: model B
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
          layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.extract_logger_info"
    )
    collected: NSResultsType = {}
    # model A is processed first, model B second
    _extract_logger_info_one_model(model_a, collected, logger_cls)
    _extract_logger_info_one_model(model_b, collected, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(collected)
    # rekey on the node names of the model chosen for layer names
    return rekey_logger_info_on_node_name_of_model(
        collected, model_name_to_use_for_layer_names
    )
673
674
675def _add_shadow_loggers_impl(
676    name_a: str,
677    gm_a: GraphModule,
678    name_b: str,
679    gm_b: GraphModule,
680    logger_cls: Callable,
681    should_log_inputs: bool,
682    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
683    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
684    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
685) -> nn.Module:
686    torch._C._log_api_usage_once(
687        "quantization_api._numeric_suite_fx._add_shadow_loggers_impl"
688    )
689    matched_subgraph_pairs = get_matching_subgraph_pairs(
690        gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
691    )
692    gm_a_shadows_b = create_a_shadows_b(
693        name_a,
694        gm_a,
695        name_b,
696        gm_b,
697        matched_subgraph_pairs,
698        logger_cls,
699        should_log_inputs=should_log_inputs,
700        node_type_to_io_type_map=node_type_to_io_type_map,
701    )
702    return gm_a_shadows_b
703
704
def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Instrument model A and model B with shadow loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to log inputs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        node_type_to_io_type_map: optional override of the node type to IO type
            mapping, forwarded to the shadow model construction, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change

    Return:
        The shadow model, as an `nn.Module`.
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.add_shadow_loggers"
    )
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []

    # Trace model A, then model B, each with its own tracer instance.
    traced_models: List[GraphModule] = []
    for model in (model_a, model_b):
        tracer = NSTracer(skipped_module_names, skipped_module_classes)
        traced = GraphModule(model, tracer.trace(model))
        maybe_node_name_to_scope = _get_observed_graph_module_attr(
            model, "node_name_to_scope"
        )
        if maybe_node_name_to_scope is not None:
            # carry the scope information over to the freshly traced module
            traced._node_name_to_scope = maybe_node_name_to_scope
        traced_models.append(traced)
    gm_a, gm_b = traced_models

    return _add_shadow_loggers_impl(
        name_a,
        gm_a,
        name_b,
        gm_b,
        logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map,
    )
760
761
def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Walk over every logger found in a shadow model and gather what they
    recorded.

    Args:
        model_a_shadows_b: shadow model
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
          layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.extract_shadow_logger_info"
    )
    collected: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, collected, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(collected)
    # rekey on the node names of the model chosen for layer names
    rekeyed = rekey_logger_info_on_node_name_of_model(
        collected, model_name_to_use_for_layer_names
    )
    # hand back a plain dict
    return dict(rekeyed)
792
793
def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace.

    Entries of the two models correspond when both their `index_within_arg`
    and `index_of_arg` are equal. Values are compared pairwise, in order.

    Args:
        results: the result data structure from `extract_logger_info` or
          `extract_shadow_logger_info`.
        model_name_1: string name of model 1
        model_name_2: string name of model 2
        comparison_fn: function to compare two Tensors, called as
          `comparison_fn(value_from_model_1, value_from_model_2)`
        comparison_name: key under which the list of comparison results is
          stored in each result entry of `model_name_2`
    """
    for results_type_to_results in results.values():
        for model_name_to_results in results_type_to_results.values():
            assert (
                model_name_1 in model_name_to_results
            ), f"{model_name_1} not found in results"
            assert (
                model_name_2 in model_name_to_results
            ), f"{model_name_2} not found in results"

            reference_entries = model_name_to_results[model_name_1]
            target_entries = model_name_to_results[model_name_2]

            for target_entry in target_entries:
                wanted_position = (
                    target_entry["index_within_arg"],
                    target_entry["index_of_arg"],
                )
                # first entry of model 1 sitting at the same position
                reference_entry = next(
                    (
                        candidate
                        for candidate in reference_entries
                        if (candidate["index_within_arg"], candidate["index_of_arg"])
                        == wanted_position
                    ),
                    None,
                )
                assert reference_entry is not None

                reference_values = reference_entry["values"]
                target_values = target_entry["values"]
                comparisons = target_entry[comparison_name] = []
                for reference_value, target_value in zip(
                    reference_values, target_values
                ):
                    comparisons.append(comparison_fn(reference_value, target_value))
848
849
def prepare_n_shadows_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_multi_mapping: QConfigMultiMapping,
    backend_config: BackendConfig,
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
    custom_tracer: Any = None,
) -> GraphModule:
    """
    Given a model with a graph with M ops such as


      args_kwargs_m -> op_m -> output_m


    And a set of N qconfigs for each op, creates a new model, with
    each of the subgraph of `op_m` transformed into

    .. code::

           |---------> op_m_n -> log_m_n
           |                     /
      args_kwargs_m ---------> op_m -> log_m_0

    Where op_m_n is op_m wrapped in a submodule and transformed with
    qconfig_n, and its inner graph looks like

    .. code::

      args_m -------- op_m_prepared_with_qconfig_n -> out_m_n
                  /
      kwargs_m ---

    This is useful for testing different quantization of multiple layers in
    a single pass through the model.

    Args:
        model: the model to trace and transform
        example_inputs: example inputs, unpacked positionally when
          propagating outputs through the traced model
        qconfig_multi_mapping: provides the list of qconfig mappings, one
          per shadow copy of each subgraph
        backend_config: used to look up the patterns to match and the
          root node getters of fusion patterns
        custom_prepare_fn: forwarded to
          `create_n_transformed_and_logged_copies_of_subgraph`
        custom_prepare_kwargs: forwarded to
          `create_n_transformed_and_logged_copies_of_subgraph`
        custom_tracer: tracer to use instead of a default
          `QuantizationTracer`

    Return:
        the traced `GraphModule`, after the shadow copies were added to it

    High level TODOs for future PRs:
    * figure out a better way to name the output structure
    * return a results data structure instead of printing it out
    * add examples to docblocks
    """

    tracer = (
        quantize_fx.QuantizationTracer([], [])
        if custom_tracer is None
        else custom_tracer
    )
    traced = torch.fx.GraphModule(model, tracer.trace(model))
    # without this, logger FQNs do not get populated
    traced._node_name_to_scope = tracer.node_name_to_scope  # type: ignore[assignment]

    # `prepare_fx` is later called on individual subgraphs, which needs
    # example inputs to have been propagated through the graph
    OutputProp(traced).propagate(*example_inputs)

    # Match the original graph against the backend patterns to find the
    # subgraphs which need to be considered. No standalone or custom
    # modules are passed to the matcher.
    named_mods = dict(traced.named_modules(remove_duplicate=False))
    no_standalone_names: List[str] = []
    no_standalone_classes: List[Type] = []
    no_custom_classes: List[Type] = []
    found_matches = _find_matches(
        traced.graph,
        named_mods,
        _get_pattern_to_quantize_handlers(backend_config),
        get_fusion_pattern_to_root_node_getter(backend_config),
        no_standalone_names,
        no_standalone_classes,
        no_custom_classes,
    )
    subgraphs_dedup: Dict[str, List[Node]] = _get_dedup_subgraphs(found_matches)

    # one node-name-to-qconfig dict per qconfig mapping
    # TODO(future PR): deduplicate repeating entries
    mappings_list = qconfig_multi_mapping.qconfig_mappings_list
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = [
        _generate_node_name_to_qconfig(
            traced, named_mods, traced.graph, cur_mapping, tracer.node_name_to_scope
        )
        for cur_mapping in mappings_list
    ]

    # For each region in the model, do the following:
    #   For each qconfig for that region, do the following:
    #     1. create a copy of the region wrapped in a module
    #     2. pass original args, original kwargs, and expected output to module
    #     3. add an output comparison logger and hook it up to compare
    #        actual output to expected output
    #     4. run `prepare_fx` on the module
    for subgraph_idx, (match_name, subgraph_nodes) in enumerate(
        subgraphs_dedup.items()
    ):
        create_n_transformed_and_logged_copies_of_subgraph(
            traced,
            subgraph_idx,
            match_name,
            subgraph_nodes,
            mappings_list,
            list_of_node_name_to_qconfig,
            custom_prepare_fn,
            custom_prepare_kwargs,  # type: ignore[arg-type]
        )

    return traced
956
957
958# TODO(future PR): we should rethink the names of all the PNP APIs
def _prepare_n_shadows_add_loggers_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_mapping: QConfigMapping,
    backend_config: BackendConfig,
) -> torch.nn.Module:
    r"""
    Note: this API is not recommended for wide usage, it is only
    provided for customers who need to migrate from the `add_loggers`
    API.

    This creates a model which provides logging for the following
    problem: if we quantize `model` with `qconfig_mapping` and feed
    the same input through both models, log the comparisons of
    corresponding intermediate layers.

    The problem is solved with a single model.  Specifically, we
    partition `model` into N subgraphs, create a copy of each relevant
    subgraph, wrap it in a module, apply the quantization API to that
    module, and hook up loggers to measure the comparisons.

    Example starting graph:

      x0 -> op0 -> x1 -> op1 -> x2

    Example config: quantize op0 to int8, do nothing to op1.
    The following graph will be created:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                           \       # noqa: W605
         ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog

    Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized
    to int8, op1_0 is op1 (appearing in the graph twice), log is a logger,
    and clog is a comparison logger.

    Args:
        model: the model to trace and transform
        example_inputs: example inputs, unpacked positionally when
          propagating outputs through the traced model
        qconfig_mapping: the qconfig mapping to apply
        backend_config: used to look up the patterns to match and the
          root node getters of fusion patterns

    Return:
        the traced model, mutated into the add_loggers graph
    """

    tracer = quantize_fx.QuantizationTracer([], [])
    traced = torch.fx.GraphModule(model, tracer.trace(model))
    # without this, logger FQNs do not get populated
    traced._node_name_to_scope = tracer.node_name_to_scope  # type: ignore[assignment]

    # `prepare_fx` is later called on individual subgraphs, which needs
    # example inputs to have been propagated through the graph
    OutputProp(traced).propagate(*example_inputs)

    # Match the original graph against the backend patterns to find the
    # subgraphs which need to be considered. No standalone or custom
    # modules are passed to the matcher.
    named_mods = dict(traced.named_modules(remove_duplicate=False))
    no_standalone_names: List[str] = []
    no_standalone_classes: List[Type] = []
    no_custom_classes: List[Type] = []
    found_matches = _find_matches(
        traced.graph,
        named_mods,
        _get_pattern_to_quantize_handlers(backend_config),
        get_fusion_pattern_to_root_node_getter(backend_config),
        no_standalone_names,
        no_standalone_classes,
        no_custom_classes,
    )
    subgraphs_dedup: Dict[str, List[Node]] = _get_dedup_subgraphs(found_matches)

    # node name to qconfig, for the single qconfig mapping
    node_name_to_qconfig = _generate_node_name_to_qconfig(
        traced, named_mods, traced.graph, qconfig_mapping, tracer.node_name_to_scope
    )

    # Now, mutate the graph to be the add_loggers graph with propagation
    # error.
    create_add_loggers_graph(
        traced, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig
    )

    return traced
1037
1038
1039# TODO(future PR): we should rethink the names of all the PNP APIs
def _n_shadows_compare_weights(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_mapping: QConfigMapping,
    backend_config: BackendConfig,
) -> NSResultsType:
    """
    Note: this API is not recommended for wide usage, it is only
    provided for customers who need to migrate from the `add_loggers`
    API.

    Prepares an n-shadows model for the single `qconfig_mapping`, runs
    `example_inputs` through it, converts it, and returns the extracted
    weight comparison.
    """
    single_mapping = QConfigMultiMapping.from_list_qconfig_mapping([qconfig_mapping])
    prepared = prepare_n_shadows_model(
        model, example_inputs, single_mapping, backend_config
    )
    # observers which observe weights only hold real values after
    # inputs have been passed through the model
    prepared(*example_inputs)
    return extract_weight_comparison(convert_n_shadows_model(prepared))
1063
1064
1065# TODO(future PR): consider aligning API signature with other similar quantization
1066# functions (enable_fake_quant, etc)
def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None:
    """
    Writes `enabled` to the `enabled` attribute of every `OutputLogger`
    found among the modules of `model`.
    """
    output_loggers = (
        mod for _, mod in model.named_modules() if isinstance(mod, OutputLogger)
    )
    for logger in output_loggers:
        logger.enabled = enabled
1075
1076# TODO(future PR): consider aligning API signature with other similar quantization
1077# functions (enable_fake_quant, etc)
def loggers_set_save_activations(
    model: torch.nn.Module,
    save_activations: bool,
) -> None:
    """
    Writes `save_activations` to the `save_activations` attribute of every
    `OutputLogger` found among the modules of `model`.
    """
    output_loggers = (
        mod for _, mod in model.named_modules() if isinstance(mod, OutputLogger)
    )
    for logger in output_loggers:
        logger.save_activations = save_activations
1089
def convert_n_shadows_model(
    model: GraphModule,
    custom_convert_fn: Optional[Callable] = None,
    custom_convert_kwargs: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    """
    Given a model from `prepare_n_shadows_model`, converts each shadow
    submodule and stores the converted module back on `model` under the
    same attribute name.

    Args:
        model: model whose shadow submodules are converted, modified inplace
        custom_convert_fn: if provided, called as
          `custom_convert_fn(submodule, **custom_convert_kwargs)` instead
          of `convert_fx`
        custom_convert_kwargs: keyword arguments for `custom_convert_fn`,
          `None` is treated as no keyword arguments

    Return:
        the same `model` object
    """
    extra_kwargs = {} if custom_convert_kwargs is None else custom_convert_kwargs
    for node in model.graph.nodes:
        # TODO(future PR): consider matching in a safer way than
        # node name string match
        if not node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX):
            continue
        shadow_mod = getattr(model, node.name)
        if custom_convert_fn is not None:
            new_mod = custom_convert_fn(shadow_mod, **extra_kwargs)
        else:
            new_mod = torch.ao.quantization.quantize_fx.convert_fx(shadow_mod)
        setattr(model, node.name, new_mod)

    return model
1113
1114
def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType:
    """
    Collects the information logged by the `OutputLogger` modules of
    `model` into a new results dict and returns it.
    """
    collected: NSResultsType = {}
    _extract_logger_info_one_model(model, collected, OutputLogger)
    return collected
1122
1123
def print_comparisons_n_shadows_model(results: NSResultsType) -> None:
    """
    Groups the extracted `results` by subgraph, builds the comparison of
    the grouped results, and prints a summary of it.
    """
    print_n_shadows_summary(
        create_results_comparison(group_results_by_subgraph(results))
    )
1131