# mypy: allow-untyped-defs
import collections
import copy
import operator
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch
import torch.fx
from torch.ao.ns.fx.graph_passes import _maybe_get_fqn
from torch.ao.ns.fx.ns_types import NSResultsType, NSSingleResultValuesType
from torch.ao.ns.fx.utils import (  # TODO(future PR): make this work correctly for methods
    get_normalized_nth_input,
    get_target_type_str,
)
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.fx.match_utils import _MatchResult
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import Graph, GraphModule, Node
from torch.utils._pytree import tree_map


SHADOW_NODE_NAME_PREFIX = "shadow"
SHADOW_WRAPPER_NODE_NAME_PREFIX = "shadow_wrapper"

# TODO(future PR): reuse existing mapping instead of creating a new one
BINARY_FUNCTIONS = {
    torch.add,
    torch.Tensor.add,
    operator.add,
    torch.mul,
    torch.Tensor.mul,
    operator.mul,
}


def _get_attr_name(subgraph_idx, subgraph_candidate_idx):
    return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"


def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx):
    return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"
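
# For example (illustrative), the attribute names for subgraph 0, candidate 1
# are:
#
#   _get_attr_name(0, 1)          # -> "shadow_0_1"
#   _get_attr_wrapper_name(0, 1)  # -> "shadow_wrapper_0_1"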

class OutputProp:
    """
    Output propagation (modeled from shape propagation).

    Given a GraphModule and an example input, saves the output flowing
    through each node on `node.traced_result`.

    Code based on the example from
    https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
    """

    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env: Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target: str):
            target_atoms = target.split(".")
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(
                        f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
                    )
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == "placeholder":
                result = next(args_iter)
            elif node.op == "get_attr":
                result = fetch_attr(node.target)
            elif node.op == "call_function":
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == "call_method":
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == "call_module":
                result = self.modules[node.target](
                    *load_arg(node.args), **load_arg(node.kwargs)
                )

            if isinstance(result, torch.Tensor):  # type: ignore[possibly-undefined]
                node.traced_result = result

            env[node.name] = result

        return None
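
# Minimal usage sketch (illustrative; `TwoLayer` is a hypothetical module):
#
#   class TwoLayer(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.fc1 = torch.nn.Linear(4, 4)
#           self.fc2 = torch.nn.Linear(4, 4)
#
#       def forward(self, x):
#           return self.fc2(self.fc1(x))
#
#   gm = torch.fx.symbolic_trace(TwoLayer())
#   OutputProp(gm).propagate(torch.randn(2, 4))
#   for n in gm.graph.nodes:
#       if hasattr(n, "traced_result"):
#           print(n.name, n.traced_result.shape)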

def _get_dedup_subgraphs(matches: Dict[str, _MatchResult]) -> Dict[str, List[Node]]:
    # the original matches variable is unique by node; make it unique by
    # subgraph instead
    seen_nodes = set()
    subgraphs_dedup = {}

    # Dict items are not reversible until Python 3.8, so we hack it
    # to be compatible with previous Python versions
    # TODO(future PR): try reversed(list(matches.items()))
    matches_items_reversed: List[Tuple[str, _MatchResult]] = []
    for name, cur_match in matches.items():
        matches_items_reversed.insert(0, (name, cur_match))

    # Note: the order is important.  `matches` currently provides the matches
    # in reverse order.  We would like to process the matches in non-reverse
    # order, so that we can create an intuitive naming scheme, such as
    # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)`
    for name, cur_match in matches_items_reversed:  # type: ignore[call-overload]
        was_seen = False
        for node_or_tuple in cur_match[1]:
            # cur_match[1] has an unusual type: its annotation says it is a
            # `List[Node]`, but that is not what it actually contains.
            # Furthermore, the contents of this field can change across
            # match results of multiple nodes of the same pattern.
            #
            # For example, for conv -> bn -> relu, we see
            # match_results = {
            #   'conv': (relu, [(bn, conv), relu], ...),
            #   'bn': (relu, [(bn, conv), relu], ...),
            #   'relu': (relu, [(bn, conv), relu], ...),
            # }
            #
            # Ideally we should clean up the `find_matches` function to make
            # this more intuitive. For the purposes of this prototype, we hack
            # around it.

            if isinstance(node_or_tuple, Node):
                if node_or_tuple in seen_nodes:
                    was_seen = True
                seen_nodes.add(node_or_tuple)

            else:
                assert isinstance(node_or_tuple, tuple)
                for node in node_or_tuple:
                    assert isinstance(node, Node)
                    if node in seen_nodes:
                        was_seen = True
                    seen_nodes.add(node)

        if was_seen:
            continue

        # Start with the unusual type, convert it to [op_0, ..., op_n]
        list_of_nodes = []

        if len(cur_match[1]) == 1:
            list_of_nodes = cur_match[1]
        else:
            assert len(cur_match[1]) == 2
            # either (a, b), or ((a, b), c) or (c, (a, b))
            # cannot make any assumptions on order; it is not clear what the
            # _find_matches function is doing to populate this
            # TODO(future PR): make this code less confusing, see discussion
            # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836

            def _order_nodes(node_a, node_b, node_c) -> List[Node]:
                nodes = [node_a, node_b, node_c]
                first_node = None
                mid_node = None
                last_node = None
                for n in nodes:
                    prev_n = n.args[0]
                    next_n = next(iter(n.users))
                    if prev_n not in nodes:
                        first_node = n
                    elif next_n not in nodes:
                        last_node = n
                    else:
                        mid_node = n
                assert (
                    first_node is not None
                    and mid_node is not None
                    and last_node is not None
                )
                assert mid_node.args[0] is first_node
                assert last_node.args[0] is mid_node
                return [last_node, mid_node, first_node]

            if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
                # (a, b)
                list_of_nodes = cur_match[1]
            elif isinstance(cur_match[1][0], tuple):
                # ((a, b), c)
                node_a, node_b = cur_match[1][0]
                node_c = cur_match[1][1]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)
            elif isinstance(cur_match[1][1], tuple):
                # (a, (b, c))
                node_a, node_b = cur_match[1][1]
                node_c = cur_match[1][0]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)

        # [node_n, ..., node_0], note that the order is reversed
        # to make it chronological for simple subgraphs
        list_of_nodes.reverse()
        subgraphs_dedup[name] = list_of_nodes

    return subgraphs_dedup
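
# For example (illustrative), given the conv -> bn -> relu match_results shown
# in the comment above, the three per-node entries collapse into a single
# chronologically ordered subgraph, keyed by the surviving match name, roughly:
#
#   _get_dedup_subgraphs(match_results)
#   # -> {'conv': [conv, bn, relu]}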

def _get_logger_for_subgraph(
    model: GraphModule,
    first_node: Node,
    last_node: Node,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    qconfig_str: str,
    logger_cls: Callable,
    fqn: Optional[str],
) -> torch.nn.Module:
    """
    Given a model and a linear subgraph starting from `first_node` and
    ending with `last_node`, creates a logger for the end of this
    subgraph.
    """
    if fqn is None:
        fqn = ""
    logger_mod_orig = logger_cls(
        first_node.name,  # ref_node_name
        last_node.name,  # prev_node_name
        f"subgraph_{subgraph_idx}_{subgraph_candidate_idx}",  # model_name
        "model",  # ref_name
        get_target_type_str(last_node, model),  # prev_node_target_type
        get_target_type_str(first_node, model),  # ref_node_target_type
        NSSingleResultValuesType.NODE_OUTPUT.value,  # results_type
        0,  # index_within_arg
        0,  # index_of_arg
        fqn,  # fqn
        qconfig_str,
    )
    # Usually we expect the user to add loggers, then calibrate, then convert,
    # and then populate loggers.  This is why the loggers start disabled.
    # TODO(future PR): reconsider the design to make this more intuitive.
    logger_mod_orig.enabled = False
    return logger_mod_orig


def create_submodule_from_subgraph(
    model: torch.nn.Module,
    first_node: Node,
    last_node: Node,
) -> GraphModule:
    """
    Input: a model, and a linear subgraph within the model from first_node to
      last_node.

    Output: a new submodule containing a copy of the subgraph, with the inputs
      to the first node becoming the inputs to the submodule, and all other
      nodes in the subgraph being copied.

    Example inputs:

    `model`: a module with graph

      x0 -> op1 -> x1 -> op2 -> x2
             |
            arg1

    `first_node`: op1
    `last_node`: op2

    Example output: a new module with graph

      input1 -> op1_copy -> x1 -> op2_copy -> output1
                   |
                  arg1
    """

    #
    # create a blank GraphModule with an empty graph
    #

    class M(torch.nn.Module):
        def forward(self, x):
            pass

    m = M()
    gm = torch.fx.symbolic_trace(m)
    g = gm.graph
    for node in reversed(gm.graph.nodes):
        g.erase_node(node)

    #
    # modify the graph to have a copy of our subgraph
    #

    cur_node_orig = first_node
    cur_args_orig = cur_node_orig.args
    cur_kwargs_orig = cur_node_orig.kwargs

    cur_name_idx = 0

    iteration_limit = 100
    cur_iteration = 0

    while True:
        if cur_node_orig is first_node:
            # we are at the first node, we need to set up graph inputs
            # TODO(future): some graphs could have placeholders which are unrelated
            # to the first node, need to handle this
            cur_args_copy = []
            cur_kwargs_copy = {}
            seen_names: Set[str] = set()
            old_name_to_new_node: Dict[str, Node] = {}

            def _add_placeholder(
                g: Graph, node: Node, seen_names, old_name_to_new_node
            ):
                # note: for graphs starting with patterns such as `y = x + x`, we
                # need to ensure we do not add multiple placeholders with the
                # same name
                counter = 0
                while node.name + "_" + str(counter) in seen_names:
                    counter += 1
                cur_name = node.name + "_" + str(counter)
                seen_names.add(cur_name)
                placeholder = g.placeholder(cur_name)
                old_name_to_new_node[node.name] = placeholder
                return placeholder

            for arg in cur_node_orig.args:
                if isinstance(arg, Node):
                    p = _add_placeholder(g, arg, seen_names, old_name_to_new_node)
                    cur_args_copy.append(p)
                elif isinstance(arg, (list, tuple)):
                    new_arg = []
                    for inner_arg in arg:
                        if isinstance(inner_arg, Node):
                            new_arg.append(
                                _add_placeholder(
                                    g, inner_arg, seen_names, old_name_to_new_node
                                )
                            )
                        else:
                            new_arg.append(inner_arg)
                    cur_args_copy.append(new_arg)
                else:
                    cur_args_copy.append(arg)

            # TODO(future PR): handle non-normalized kwargs
            for kwarg_name, kwarg in cur_node_orig.kwargs.items():
                if isinstance(kwarg, Node):
                    cur_kwargs_copy[kwarg_name] = _add_placeholder(
                        g, kwarg, seen_names, old_name_to_new_node
                    )
                elif isinstance(kwarg, (list, tuple)):
                    new_kwarg = []
                    for inner_kwarg in kwarg:
                        p = _add_placeholder(
                            g, inner_kwarg, seen_names, old_name_to_new_node
                        )
                        new_kwarg.append(p)
                    cur_kwargs_copy[kwarg_name] = new_kwarg
                else:
                    cur_kwargs_copy[kwarg_name] = kwarg

            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]
        else:
            # we are not at the first node; the first arg comes from the
            # previous node, and all other args are copied

            # the current implementation is simplistic and cannot handle
            # ops with two or more arguments which need to be passed from
            # the previous op, so we assert them out
            assert cur_node_orig.target not in BINARY_FUNCTIONS

            # at this point in the code, cur_node_copy is pointing to the copy
            # of the previous node
            # TODO(future PR): this is not handling complicated graphs correctly, need to
            # look at actual relationships instead of assuming sequential graph
            # TODO(future PR): this is ignoring kwargs, will need to support kwargs
            # for any fusion pattern which has them for a node that is not the
            # first node.
            cur_args_copy = [cur_node_copy]  # type: ignore[has-type, possibly-undefined]  # noqa: F821

            if len(cur_node_orig.args) > 1:
                for arg in cur_node_orig.args[1:]:
                    if isinstance(arg, torch.nn.Parameter):
                        new_arg = arg.detach().clone()  # type: ignore[assignment]
                        mod_name = f"mod_{cur_name_idx}"
                        cur_name_idx += 1
                        setattr(gm, mod_name, new_arg)
                        # reference the parameter copy we just attached to `gm`
                        new_arg_node = g.get_attr(mod_name)
                        cur_args_copy.append(new_arg_node)
                    elif isinstance(arg, (float, int, torch.dtype)):
                        cur_args_copy.append(arg)
                    else:
                        raise AssertionError(f"arg of type {type(arg)} not handled yet")
            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]

        # copy the node
        if cur_node_orig.op == "call_module":
            orig_mod = getattr_from_fqn(model, cur_node_orig.target)  # type: ignore[arg-type]
            orig_mod_copy = copy.deepcopy(orig_mod)
            mod_name = f"mod_{cur_name_idx}"
            setattr(gm, mod_name, orig_mod_copy)
            cur_name_idx += 1
            cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined,arg-type]

        elif cur_node_orig.op == "call_function":
            cur_node_copy = g.call_function(
                cur_node_orig.target,  # type: ignore[arg-type]
                cur_args_copy,  # type: ignore[arg-type]
                cur_kwargs_copy,  # type: ignore[possibly-undefined]
            )

        elif cur_node_orig.op == "call_method":
            cur_node_copy = g.call_method(
                cur_node_orig.target,  # type: ignore[arg-type]
                cur_args_copy,  # type: ignore[arg-type]
                cur_kwargs_copy,  # type: ignore[possibly-undefined]
            )

        else:
            raise AssertionError(f"{cur_node_orig.op} not supported yet")

        if cur_node_orig is last_node:
            break

        # go to next node
        assert (
            len(cur_node_orig.users.keys()) == 1
        ), f"{cur_node_orig} has more than 1 user, not supported yet"
        cur_node_orig = next(iter(cur_node_orig.users.keys()))
        cur_args_orig = cur_node_orig.args
        cur_kwargs_orig = cur_node_orig.kwargs

        cur_iteration += 1
        if cur_iteration > iteration_limit:
            raise AssertionError("iteration limit exceeded")

    # set up outputs
    g.output(cur_node_copy)

    gm.recompile()
    return gm
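
# Minimal usage sketch (illustrative; assumes `gm` is a symbolically traced
# GraphModule and `first_node` / `last_node` bound a linear chain in its graph,
# e.g. one of the subgraphs returned by _get_dedup_subgraphs):
#
#   sub_gm = create_submodule_from_subgraph(gm, first_node, last_node)
#   print(sub_gm.graph)  # placeholders feeding a copy of first_node..last_node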

def create_one_transformed_and_logged_copy_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    first_node: Node,
    last_node: Node,
    fqn: Optional[str],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    example_inputs: Any,
    last_added_shadow_node_list: List[Optional[Node]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Given a subgraph in `mt` and a subgraph candidate idx, inserts the
    subgraph candidate copy and instruments it with loggers.

    If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just
    add a logger to the end.

    If subgraph_candidate_idx is not 0, we create a copy of the subgraph and
    prepare it with `prepare_fx`.
    """

    # TODO(future PR): move logger classes to utils to remove circular dependency
    from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger

    if subgraph_candidate_idx == 0:
        # idx = 0 is the floating point (original) version of the subgraph
        # We keep the subgraph as is, and add a logger at the end

        qconfig_str = ""
        logger_mod_orig = _get_logger_for_subgraph(
            mt,
            first_node,
            last_node,
            subgraph_idx,
            subgraph_candidate_idx,
            qconfig_str,
            OutputLogger,
            fqn,
        )

        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, logger_mod_orig)
        with mt.graph.inserting_after(last_node):
            new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
            last_added_shadow_node_list[0] = new_node

    else:
        # idx > 0 means we have a candidate qconfig to try, so we need
        # to make a copy of the subgraph, feed it the right inputs,
        # and add a logger at the end

        # get the qconfig
        # subtract one because the first candidate is the floating point
        # version of the subgraph
        node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
        qconfig = node_name_to_qconfig[first_node.name]

        # if no quantization is requested, skip
        # TODO(future PR): deduplicate equivalent qconfigs that come from
        #   different qconfig mapping objects
        if qconfig is None:
            return

        qconfig_mapping = QConfigMapping().set_global(qconfig)

        # create a copy of the submodule, wrapped in a separate module
        orig_mod_copy_wrapped = create_submodule_from_subgraph(
            mt, first_node, last_node
        )

        # add a call to prepare_fx on the wrapper module
        if custom_prepare_fn is None:
            orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx(
                orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs
            )
        else:
            if custom_prepare_kwargs is None:
                custom_prepare_kwargs = {}
            for kwarg_name in [
                "example_inputs",
                "prepare_custom_config",
                "qconfig_mapping",
            ]:
                assert (
                    kwarg_name not in custom_prepare_kwargs
                ), f"cannot specify {kwarg_name} in custom_prepare_kwargs"
            prepare_kwargs: Dict[str, Any] = {
                "example_inputs": example_inputs,
                "qconfig_mapping": qconfig_mapping,
            }
            prepare_kwargs.update(custom_prepare_kwargs)
            orig_mod_copy_wrapped = custom_prepare_fn(
                orig_mod_copy_wrapped, **prepare_kwargs
            )

        # attach the wrapper to the model
        attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, orig_mod_copy_wrapped)

        # add a call to the wrapper module from the parent graph
        insert_after_node = last_added_shadow_node_list[0]
        with mt.graph.inserting_after(insert_after_node):
            # TODO(future PR): handle fusion patterns where non-first nodes
            # need inputs

            # pass in all node args and kwargs

            new_args = []
            for arg in first_node.args:
                if isinstance(arg, Node):
                    new_args.append(arg)
                elif (
                    isinstance(arg, (list, tuple))
                    and len(arg)
                    and isinstance(arg[0], Node)
                ):
                    for inner_arg in arg:
                        if isinstance(inner_arg, Node):
                            new_args.append(inner_arg)

            new_kwargs = {}
            for name, old_kwarg in first_node.kwargs.items():
                if isinstance(old_kwarg, Node):
                    new_kwargs[name] = old_kwarg
                elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg):
                    # TODO(future PR): clarify why we are adding kwargs to args
                    new_args.extend(old_kwarg)

            new_args = tuple(new_args)  # type: ignore[assignment]

            new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs)  # type: ignore[arg-type]

        # add a logger to the parent graph to observe the shadow wrapper
        logger_mod_orig = _get_logger_for_subgraph(
            mt,
            first_node,
            last_node,
            subgraph_idx,
            subgraph_candidate_idx,
            str(qconfig),
            OutputComparisonLogger,
            fqn,
        )

        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, logger_mod_orig)
        with mt.graph.inserting_after(new_node):
            logger = mt.graph.call_module(
                attr_name, args=(new_node, last_node), kwargs={}
            )
            last_added_shadow_node_list[0] = logger

    mt.recompile()
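
# For example (illustrative), for subgraph 0 with a single qconfig candidate,
# running the function above for candidate indices 0 and 1 adds the following
# attributes to the parent module `mt`:
#
#   mt.shadow_0_0          # OutputLogger on the original fp32 output
#   mt.shadow_wrapper_0_1  # prepared (to-be-quantized) copy of the subgraph
#   mt.shadow_0_1          # OutputComparisonLogger: candidate vs fp32 output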
def create_n_transformed_and_logged_copies_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    match_name: str,
    nodes_in_this_subgraph: List[Any],
    qconfig_mappings: List[QConfigMapping],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Given a model `mt` and a subgraph_idx, creates the needed copies
    of the subgraph for all qconfigs, and instruments them with loggers.
    """
    # for now, assume that
    # 1. the first node has one input
    # 2. the last node has one output

    # for now, ignore all subgraphs that contain non-nodes (tuples, etc)
    # TODO(future PR): implement this
    if any(not isinstance(node, Node) for node in nodes_in_this_subgraph):
        return

    first_node = nodes_in_this_subgraph[0]
    last_node = nodes_in_this_subgraph[-1]
    # We used output propagation to populate example values on each
    # node. Use the example values from the previous node as the input
    # to the current node.
    prev_node = get_normalized_nth_input(first_node, mt, 0)
    if isinstance(prev_node, list):
        example_inputs = [x.traced_result for x in prev_node]
    elif isinstance(prev_node, tuple):
        example_inputs = tuple(x.traced_result for x in prev_node)  # type: ignore[assignment]
    else:
        # currently some customer models do not have a traced_result in
        # every node, so we have to guard for this case since we cannot
        # quantize without an example input
        # TODO(future PR): add a test case for this once we have an easy
        # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489
        # for additional context
        if hasattr(prev_node, "traced_result"):
            example_inputs = (prev_node.traced_result,)  # type: ignore[attr-defined, assignment]
        else:
            print(
                "unable to get example input for node "
                + f"{first_node.format_node()}, skipping"
            )
            return

    # If there are no quantization configs for this subgraph, skip adding
    # loggers. This reduces memory usage for models where not all layers are
    # quantized.
    # TODO(future): consider making this configurable
    found_at_least_one_qconfig = False
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
        if subgraph_candidate_idx == 0:
            # fp32 baseline does not need a qconfig
            continue

        # a. we have N shadows, so len(qconfig_mappings) is N
        # b. we will have the fp32 layer + N shadows, so the overall number of
        #    (original_op) + (*shadows) will be N+1
        # c. since `subgraph_candidate_idx` represents (b), we need
        #    to subtract 1 to query from (a)
        node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
        qconfig = node_name_to_qconfig[first_node.name]
        if qconfig is not None:
            found_at_least_one_qconfig = True
            break
    if not found_at_least_one_qconfig:
        print(
            "unable to find at least one qconfig for node "
            + f"{first_node.format_node()}, skipping"
        )
        return

    fqn = _maybe_get_fqn(first_node, mt)

    # We want the results to contain the subgraphs in natural order,
    # and the graph to also contain shadow wrappers and shadow loggers
    # in natural order.
    # If we just iterate in reverse, the graph will be in natural
    # order but the eventual results will be in reverse order.
    # So, we keep track of the last shadow logger we added and
    # always insert after it.
    last_added_shadow_node_list: List[Optional[Node]] = [None]
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
        create_one_transformed_and_logged_copy_of_subgraph(
            mt,
            subgraph_idx,
            subgraph_candidate_idx,
            first_node,
            last_node,
            fqn,
            list_of_node_name_to_qconfig,
            example_inputs,
            last_added_shadow_node_list,
            custom_prepare_fn,
            custom_prepare_kwargs,
        )


def create_add_loggers_graph(
    model: GraphModule,
    subgraphs_dedup: Dict[str, List[Node]],
    qconfig_mapping: QConfigMapping,
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> None:
    r"""
    Given a model, a model graph partition (currently a set of matched
    subgraphs) and instructions on how to transform each subgraph
    (currently quantizing it according to qconfig_mapping), modifies
    the model graph to create an alternate path through the original graph,
    with each of the subgraphs quantized.  This is useful to compare
    propagation error of a transformation such as quantization.

    For example, given layers op0 and op1, there are four cases when handling op1:
    1. op0 and op1 quantized
    2. op0 and op1 unquantized
    3. op0 quantized, op1 unquantized
    4. op0 unquantized, op1 quantized

    Example input, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \          \                 \       # noqa: W605
         ---> op0_1 -> x1_1 ----> clog    op1_1 -> x2_1 ----> clog

    Example output, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                           \        # noqa: W605
         ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog

    """
    # TODO(future PR): move logger classes to utils to remove circular dependency
    from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger

    def _get_subgraph_containing_node(node, subgraphs_dedup):
        for subgraph in subgraphs_dedup.values():
            if node in subgraph:
                return subgraph
        return None

    # First, we need to create shadow branches, going from
    #
    #   x0 -> op0 -> x1 -> ...
    #
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log -> ...
    #    \                     \
    #      -> op0_1 -> x1_1 -> clog
    #
    # Later, the outputs of each shadow will be rerouted to calculate
    # propagation error.

    # Note: we cannot iterate over matched subgraphs because some nodes
    # may not be matched. So, we iterate over nodes in the graph, and
    # associate them to matched subgraphs if possible.

    nodes_to_skip = set()
    # for each subgraph, save a mapping from the first node of the subgraph
    # to the first and last node of the shadow of this subgraph
    orig_first_node_to_shadow_in_node = {}
    orig_first_node_to_shadow_out_node = {}
    # need to record the original list because we will mutate the graph as we go
    orig_nodes = list(model.graph.nodes)  # type: ignore[union-attr, arg-type]
    cur_subgraph_idx = 0
    for n in orig_nodes:
        if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        insert_submodule_copy = False
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            nodes_to_skip.update(maybe_subgraph)
            qconfig = node_name_to_qconfig[first_node.name]
            if qconfig is not None:
                insert_submodule_copy = True
        else:
            first_node, last_node = n, n

        if insert_submodule_copy:
            match_name = first_node.name
            create_n_transformed_and_logged_copies_of_subgraph(
                model,
                cur_subgraph_idx,
                match_name,
                maybe_subgraph,
                [qconfig_mapping],
                [node_name_to_qconfig],
                None,
                None,  # type: ignore[arg-type]
            )
            # find the created shadow module and record it so we
            # can find it easily in step 2
            expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1"
            new_shadow_mod = None
            for maybe_shadow_mod in model.graph.nodes:
                if (
                    maybe_shadow_mod.op == "call_module"
                    and maybe_shadow_mod.target == expected_shadow_target
                ):
                    new_shadow_mod = maybe_shadow_mod
                    break
            assert new_shadow_mod is not None
            orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
            orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod

        else:
            # create a copy of the subgraph by only copying FX nodes
            # but not copying any parameters, to minimize memory usage
            subgraph_to_use = (
                maybe_subgraph if maybe_subgraph is not None else [first_node]
            )

            # add a regular logger after last_node
            qconfig_str = ""
            subgraph_candidate_idx = 0
            fqn = _maybe_get_fqn(first_node, model)
            logger_mod_orig = _get_logger_for_subgraph(
                model,
                first_node,
                last_node,
                cur_subgraph_idx,
                subgraph_candidate_idx,
                qconfig_str,
                OutputLogger,
                fqn,
            )
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            insertion_point = last_node
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(last_node,), kwargs={}
                )
                insertion_point = logger

            # create a copy of the subgraph
            cur_node_orig = first_node
            cur_node_copy = None
            first_node_copy = None
            while cur_node_orig in subgraph_to_use:
                # TODO(future PR): make this support all possible args/kwargs
                if cur_node_orig is first_node:
                    new_args = cur_node_orig.args
                    new_kwargs = cur_node_orig.kwargs
                else:
                    first_arg_for_copy = cur_node_copy
                    new_args = (first_arg_for_copy, *cur_node_orig.args[1:])
                    new_kwargs = cur_node_orig.kwargs
                # make a copy of cur_node_orig
                with model.graph.inserting_after(insertion_point):
                    cur_node_copy = model.graph.create_node(
                        cur_node_orig.op,
                        cur_node_orig.target,
                        new_args,
                        new_kwargs,
                        # cur_node_orig.name,  # TODO(future PR): set name explicitly
                    )
                    if first_node_copy is None:
                        first_node_copy = cur_node_copy
                # since only linear subgraphs are supported for now, all nodes
                # except the last one must have exactly one user
                if cur_node_orig != last_node:
                    assert len(cur_node_orig.users.keys()) == 1
                cur_node_orig = next(iter(cur_node_orig.users.keys()))
                assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
                insertion_point = cur_node_copy

            # add a comparison logger after last_node's copy
            subgraph_candidate_idx = 1
            logger_mod_orig = _get_logger_for_subgraph(
                model,
                first_node,
                last_node,
                cur_subgraph_idx,
                subgraph_candidate_idx,
                qconfig_str,
                OutputComparisonLogger,
                fqn,
            )
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(cur_node_copy, last_node), kwargs={}
                )

            # save the final node so we can use it in step 2
            orig_first_node_to_shadow_in_node[first_node] = first_node_copy
            orig_first_node_to_shadow_out_node[first_node] = cur_node_copy

        cur_subgraph_idx += 1

    model.recompile()

    # Now, we go from
    #
    #   x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ...
    #    \                     \       \
    #      -> op0_1 -> x1_1 -> clog      -> op1_1 -> ...
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ...
    #    \                     \
    #      -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ...
    #
    # sample values of key internal variables for the example above:
    #
    #   orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1}
    #   orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1}
    #
    # note: for subgraphs with more than one node, in_node will be different
    # from out_node

    nodes_to_skip = set()
    for n in orig_nodes:
        if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            nodes_to_skip.update(maybe_subgraph)
        else:
            first_node, last_node = n, n

        def maybe_remap_node_to_shadow(node):
            """
            If unshadowed `node` has a shadow version, return that. If not,
            return `node`.
            """
            if not isinstance(node, Node):
                # handle scalars
                return node

            if node.op in ("placeholder", "get_attr"):
                return node

            # Find the shadowed version of this arg from the previous
            # subgraph. For this, we need to:
            # 1. navigate to the first node of the previous subgraph
            # 2. get the output of the shadow wrapper which has (1) as an input

            # For now, assume the arg is in matched subgraphs. In the
            # future we may have to handle the case where this is not true.
            prev_subgraph = _get_subgraph_containing_node(node, subgraphs_dedup)
            if prev_subgraph is None:
                prev_subgraph = [node]
            prev_first_node = prev_subgraph[0]
            prev_shadow_output = orig_first_node_to_shadow_out_node[prev_first_node]
            return prev_shadow_output

        cur_shadow_input = orig_first_node_to_shadow_in_node[first_node]
        assert cur_shadow_input is not None
        cur_shadow_input.args = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.args
        )
        cur_shadow_input.kwargs = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.kwargs
        )

        model.recompile()


def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
    # input: shadow wrapper module
    # output if the shadow wrapper module has a weighted op:
    #   (quantize_fn, (quantize_fn_args))
    # output if the shadow wrapper module does not have a weighted op:
    #   None

    # For now, assume that the weight is the second input
    # to the shadow module. If that changes, we can fix it later.
    placeholders_seen = 0
    for shadow_n in shadow_wrapper.graph.nodes:  # type: ignore[union-attr]
        if shadow_n.op != "placeholder":
            continue

        placeholders_seen += 1
        if placeholders_seen != 2:
            continue

        # the subgraph looks like
        #
        #   _input_scale_1 = self._input_scale_1
        #   _input_zero_point_1 = self._input_zero_point_1
        #   quantize_per_channel = torch.quantize_per_channel(
        #       w2_0, _input_scale_1, _input_zero_point_1,
        #       0, torch.qint8)
        #
        #  we have `w2_0`, and are navigating this subgraph
        #  to get `_input_scale_1` and `_input_zero_point_1`

        assert len(shadow_n.users) == 1
        quant_node = next(iter(shadow_n.users.keys()))
        new_args: Any = None
        if quant_node.target == torch.quantize_per_channel:
            _weight, scale_node, zp_node, axis, dtype = quant_node.args
            scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
            zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
            new_args = (scale_val, zp_val, axis, dtype)
        else:
            assert quant_node.target == torch.quantize_per_tensor
            _weight, scale_node, zp_node, dtype = quant_node.args
            scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
            zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
            new_args = (scale_val, zp_val, dtype)
        return (quant_node.target, new_args)

    return None
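
# Minimal usage sketch (illustrative; assumes `shadow_wrapper` is a converted
# shadow wrapper module with a weighted op, and `w` is the matching fp32
# weight tensor), mirroring how extract_weight_comparison consumes the result:
#
#   weight_info = _get_weight_info_from_shadow_wrapper(shadow_wrapper)
#   if weight_info is not None:
#       quant_fn, quant_fn_args_except_first = weight_info
#       w_q = quant_fn(w, *quant_fn_args_except_first)  # quantized weight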

def extract_weight_comparison(m: GraphModule) -> NSResultsType:
    # example graph:
    #
    #   w1 = self.w1
    #   b1 = self.b1
    #   linear = torch._C._nn.linear(x, w1, b1)
    #   shadow_0_0 = self.shadow_0_0(linear)
    #   shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1)
    #   shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear)
    #
    # algorithm:
    # 1. for each call_function node matching our allowlist:
    # 2.   if a corresponding shadow wrapper exists, extract the weight pair
    #
    # Note: this is not super robust, but that's ok because this is
    # just for legacy customers who depend on the previous two-model version
    # of this API. TBD if we need to make this robust.
    # Note: modules are not supported, since existing customers only
    # use functions.

    # TODO(future PR): move this to config
    weighted_ops = {
        torch.nn.functional.linear,
    }

    results: NSResultsType = {"model": {NSSingleResultValuesType.WEIGHT.value: {}}}

    for n in m.graph.nodes:  # type: ignore[union-attr]
        if not (n.op == "call_function" and n.target in weighted_ops):
            continue

        # Check if we have a corresponding shadow wrapper
        # TODO(future PR, if needed): support kwargs
        # TODO(future PR, if needed): support multiple shadow users
        first_arg = n.args[0]
        shadow_wrapper_node = None
        for user in first_arg.users:
            # TODO(before land): fix string match
            if user.op == "call_module" and user.target.startswith("shadow_wrapper"):
                shadow_wrapper_node = user
                break

        if shadow_wrapper_node is None:
            continue

        shadow_wrapper = getattr_from_fqn(
            m, shadow_wrapper_node.target
        )  # type: ignore[arg-type]
        weight_info = _get_weight_info_from_shadow_wrapper(shadow_wrapper)
        if weight_info is None:
            continue

        # get weight
        w_node = n.args[1]
        w_obj = getattr_from_fqn(m, w_node.target).detach()

        # get a quantized version of weight
        quant_fn, quant_fn_args_except_first = weight_info
        new_args = (w_obj, *quant_fn_args_except_first)
        w_obj_q = quant_fn(*new_args)

        # add a comparison
        ref_node_name = n.name
        prev_node_name = n.name
        ref_node_type = get_target_type_str(n, m)
        prev_node_type = ref_node_type
        fqn = None
        if hasattr(m, "_node_name_to_scope"):
            fqn = m._node_name_to_scope[n.name][0]  # type: ignore[index]
        comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q)
        result_fp32 = {
            "res_type": NSSingleResultValuesType.WEIGHT.value,
            "values": [w_obj],
            "prev_node_name": prev_node_name,
            "prev_node_target_type": prev_node_type,
            "ref_node_name": ref_node_name,
            "ref_node_target_type": ref_node_type,
            "index_within_arg": 0,
            "index_of_arg": 0,
            "fqn": fqn,
            "qconfig_str": "",
            "comparisons": [comparison],
            "comparison_fn_name": "sqnr",
        }
        result_q = {
            "res_type": NSSingleResultValuesType.WEIGHT.value,
            "values": [w_obj_q],
            "prev_node_name": prev_node_name,
            "prev_node_target_type": prev_node_type,
            "ref_node_name": ref_node_name,
            "ref_node_target_type": ref_node_type,
            "index_within_arg": 0,
            "index_of_arg": 0,
            "fqn": fqn,
            "qconfig_str": "",
            "comparisons": [comparison],
            "comparison_fn_name": "sqnr",
        }

        # go from subgraph_n_1 to subgraph_n_0
        _1, _2, node_idx, _3 = shadow_wrapper_node.target.split("_")
        name_fp32 = f"subgraph_{node_idx}_0"
        name_q = f"subgraph_{node_idx}_1"

        results["model"][NSSingleResultValuesType.WEIGHT.value][name_fp32] = [
            result_fp32
        ]
        results["model"][NSSingleResultValuesType.WEIGHT.value][name_q] = [result_q]

    return results


# TODO(future PR): redesign this to make it easier to consume outputs
def group_results_by_subgraph(results: NSResultsType) -> Any:
    """
    Groups the per-candidate results by subgraph.

    Input:

    {
      'model': {
        'node_output': {
          'subgraph_0_0': [{
            'values': [torch.tensor(...), ...],
            'ref_node_name': ...,
            'ref_node_target_type': ...,
            'qconfig_str': ...,
            'comparisons': [],
            'comparison_fn_name': '',
            'fqn': '...',
          }],
          'subgraph_0_1': [{
            'values': [torch.tensor(...), ...],
            'ref_node_name': ...,
            'ref_node_target_type': ...,
            'qconfig_str': ...,
            'comparisons': [torch.tensor(...), ...],
            'comparison_fn_name': '...',
            'fqn': '...',
          }],
          ...
        },
      },
    }

    Output:
    {
      'subgraph_0': {
        '0': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': None,
          'comparisons': [torch.tensor(...), ...],
          'comparison_fn_name': '...',
          'fqn': '...',
        },
        '1': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': '...',
          'comparisons': [torch.tensor(...), ...],
          'comparison_fn_name': '...',
          'fqn': '...',
        },
      },
    }

    """
    subgraph_name_to_subgraph_results: Any = collections.defaultdict(dict)

    # node_output or weight
    key_to_use = next(iter(results["model"].keys()))

    for subgraph_name_with_idx, subgraph_candidate_results in results["model"][
        key_to_use
    ].items():
        # convert from `subgraph_m_n` to `subgraph_m` and `n`
        (
            subgraph_str,
            subgraph_idx,
            subgraph_candidate_idx,
        ) = subgraph_name_with_idx.split("_")
        subgraph_name = f"{subgraph_str}_{subgraph_idx}"

        subgraph_results = {
            "ref_node_name": subgraph_candidate_results[0]["ref_node_name"],
            "ref_node_target_type": subgraph_candidate_results[0][
                "ref_node_target_type"
            ],
            "fqn": subgraph_candidate_results[0]["fqn"],
            "values": subgraph_candidate_results[0]["values"],
            "qconfig_str": subgraph_candidate_results[0]["qconfig_str"],
            "comparisons": subgraph_candidate_results[0]["comparisons"],
            "comparison_fn_name": subgraph_candidate_results[0]["comparison_fn_name"],
        }

        subgraph_name_to_subgraph_results[subgraph_name][
            subgraph_candidate_idx
        ] = subgraph_results

    return dict(subgraph_name_to_subgraph_results)
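
# For example (illustrative), the `subgraph_m_n` parsing above behaves as:
#
#   "subgraph_0_1".split("_")  # -> ["subgraph", "0", "1"]
#
# i.e. subgraph 0, candidate 1, where candidate "0" is always the fp32
# baseline and candidates "1".."n" are the shadow copies.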

# TODO(future PR): redesign this to make it easier to consume outputs
def create_results_comparison(
    results_grouped,
) -> Any:
    """
    Input:

    {
      'subgraph_0': {
        '0': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': '',
          'comparisons': [],
          'comparison_fn_name': '',
          'fqn': '...',
        },
        '1': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': '...',
          'comparisons': [torch.tensor(...), ...],
          'comparison_fn_name': 'sqnr',
          'fqn': '...',
        },
      },
    }

    Output:
    {
      'subgraph_0': {
        'ref_node_name': '...',
        'ref_node_target_type': '...',
        'fqn': '...',
        'candidates': {
          '1': {
            'qconfig_str': ...,
            'comparison_fn_name': 'sqnr',
            'cmp_raw': [..., ...],
            'cmp_mean': ...,
          },
          ...,
        },
      },
    }
    """

    results_comparison = {}

    for subgraph_name, subgraph_results in results_grouped.items():
        candidates = {}
        for subgraph_inner_name, subgraph_inner_result in subgraph_results.items():
            # skip comparing baseline to baseline
            if subgraph_inner_name == "0":
                continue

            # we expect the comparisons to be precalculated from
            # calibration, so we just fetch them here
            cmp_raw = subgraph_inner_result["comparisons"]
            cmp_raw_tensor = torch.stack(cmp_raw)

            candidates[subgraph_inner_name] = {
                "qconfig_str": subgraph_inner_result["qconfig_str"],
                "comparison_fn_name": subgraph_inner_result["comparison_fn_name"],
                "cmp_raw": cmp_raw_tensor,
                "cmp_mean": torch.mean(cmp_raw_tensor),
            }

        results_comparison[subgraph_name] = {
            "ref_node_name": subgraph_results["0"]["ref_node_name"],
            "ref_node_target_type": subgraph_results["0"]["ref_node_target_type"],
            "fqn": subgraph_results["0"]["fqn"],
            "candidates": candidates,
        }

    return results_comparison


# TODO(future PR): redesign this to make it easier to consume outputs
def print_n_shadows_summary(
    results_comparison,
) -> None:
    """
    Input:

    {
      'subgraph_0': {
        'ref_node_name': 'linear1',
        'ref_node_target_type': '...',
        'fqn': '...',
        'candidates': {
          '1': {
            'qconfig_str': ...,
            'comparison_fn_name': ...,
            'cmp_raw': [45.0, 55.0],
            'cmp_mean': 50.0,
          },
          ...,
        },
      },
    }

    Prints:

    node_name | node_type | fqn | 0    | 1    | ...
    linear1   | ...       | ... | 45.0 | 50.0 | ...
    """

    try:
        from tabulate import tabulate
    except ImportError:
        print(
            "`print_n_shadows_summary` relies on the library `tabulate`, "
            "which could not be found on this machine. Run `pip "
            "install tabulate` to install the library."
        )
        return

    results = []
    for subgraph_data in results_comparison.values():
        mean_all_candidates = [
            candidate["cmp_mean"]
            for candidate_name, candidate in subgraph_data["candidates"].items()
        ]

        data_row = [
            subgraph_data["ref_node_name"],
            subgraph_data["ref_node_target_type"],
            subgraph_data["fqn"],
            *mean_all_candidates,
        ]
        results.append(data_row)

    max_candidate_idx_len = -1
    for data_row in results:
        # the first three entries of each row are metadata; the rest are
        # per-candidate means
        max_candidate_idx_len = max(max_candidate_idx_len, len(data_row) - 3)
    candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)]

    headers = ["node_name", "node_type", "fqn", *candidate_idx_headers]
    print(tabulate(results, headers=headers))
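

# End-to-end usage sketch (illustrative; the public entry points live in
# torch.ao.ns._numeric_suite_fx, and `model`, `example_inputs`,
# `qconfig_multi_mapping`, and `backend_config` are assumed to be provided
# by the caller):
#
#   import torch.ao.ns._numeric_suite_fx as ns
#
#   model = ns.prepare_n_shadows_model(
#       model, example_inputs, qconfig_multi_mapping, backend_config
#   )
#   model(*example_inputs)               # calibrate
#   model = ns.convert_n_shadows_model(model)
#   ns.loggers_set_enabled(model, True)  # loggers start disabled (see above)
#   model(*example_inputs)               # populate the loggers
#   results = ns.extract_results_n_shadows_model(model)
#   ns.print_comparisons_n_shadows_model(results)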
1383