# mypy: allow-untyped-defs
import collections
import copy
import operator
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch
import torch.fx
from torch.ao.ns.fx.graph_passes import _maybe_get_fqn
from torch.ao.ns.fx.ns_types import NSResultsType, NSSingleResultValuesType
from torch.ao.ns.fx.utils import (  # TODO(future PR): make this work correctly for methods
    get_normalized_nth_input,
    get_target_type_str,
)
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.fx.match_utils import _MatchResult
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import Graph, GraphModule, Node
from torch.utils._pytree import tree_map


SHADOW_NODE_NAME_PREFIX = "shadow"
SHADOW_WRAPPER_NODE_NAME_PREFIX = "shadow_wrapper"

# TODO(future PR): reuse existing mapping instead of creating a new one
BINARY_FUNCTIONS = {
    torch.add,
    torch.Tensor.add,
    operator.add,
    torch.mul,
    torch.Tensor.mul,
    operator.mul,
}


def _get_attr_name(subgraph_idx, subgraph_candidate_idx):
    return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"


def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx):
    return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"


class OutputProp:
    """
    Output propagation (modeled after shape propagation).

    Given a GraphModule and an example input, saves the output flowing
    through each node on `node.traced_result`.

    Code based on the example from
    https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
    """

    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env: Dict[str, Any] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target: str):
            target_atoms = target.split(".")
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(
                        f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
                    )
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == "placeholder":
                result = next(args_iter)
            elif node.op == "get_attr":
                result = fetch_attr(node.target)
            elif node.op == "call_function":
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == "call_method":
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == "call_module":
                result = self.modules[node.target](
                    *load_arg(node.args), **load_arg(node.kwargs)
                )

            if isinstance(result, torch.Tensor):  # type: ignore[possibly-undefined]
                node.traced_result = result

            env[node.name] = result

        return None

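# Example usage of `OutputProp` (illustrative sketch; the model class and the
# input shape are hypothetical):
#
#   mod = torch.fx.symbolic_trace(MyModel())
#   OutputProp(mod).propagate(torch.randn(1, 3, 224, 224))
#   for node in mod.graph.nodes:
#       if hasattr(node, "traced_result"):
#           print(node.name, node.traced_result.shape)
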
def _get_dedup_subgraphs(matches: Dict[str, _MatchResult]) -> Dict[str, List[Node]]:
    # the original matches variable is unique by node, make it unique by
    # subgraph instead
    seen_nodes = set()
    subgraphs_dedup = {}

    # Dict items are not reversible until Python 3.8, so we hack it
    # to be compatible with previous Python versions
    # TODO(future PR): try reversed(list(matches.items()))
    matches_items_reversed: List[Tuple[str, _MatchResult]] = []
    for name, cur_match in matches.items():
        matches_items_reversed.insert(0, (name, cur_match))

    # Note: the order is important. `matches` currently provides the matches
    # in reverse order. We would like to process the matches in non-reverse
    # order, so that we can create an intuitive naming scheme, such as
    # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)`
    for name, cur_match in matches_items_reversed:  # type: ignore[call-overload]
        was_seen = False
        for node_or_tuple in cur_match[1]:
            # `cur_match[1]` has an unusual type: it is annotated as
            # `List[Node]`, but it is really not. Furthermore, the contents of
            # this field can differ across the match results of multiple nodes
            # of the same pattern.
            #
            # For example, for conv -> bn -> relu, we see
            # match_results = {
            #     'conv': (relu, [(bn, conv), relu], ...),
            #     'bn': (relu, [(bn, conv), relu], ...),
            #     'relu': (relu, [(bn, conv), relu], ...),
            # }
            #
            # Ideally we should clean up the `find_matches` function to make
            # this more intuitive. For the purposes of this prototype, we hack
            # around it.

            if isinstance(node_or_tuple, Node):
                if node_or_tuple in seen_nodes:
                    was_seen = True
                seen_nodes.add(node_or_tuple)

            else:
                assert isinstance(node_or_tuple, tuple)
                for node in node_or_tuple:
                    assert isinstance(node, Node)
                    if node in seen_nodes:
                        was_seen = True
                    seen_nodes.add(node)

        if was_seen:
            continue

        # Start with the unusual type, convert it to [op_0, ..., op_n]
        list_of_nodes = []

        if len(cur_match[1]) == 1:
            list_of_nodes = cur_match[1]
        else:
            assert len(cur_match[1]) == 2
            # either (a, b), or ((a, b), c) or (c, (a, b))
            # cannot make any assumptions on order, not clear what the
            # _find_matches function is doing to populate this
            # TODO(future PR): make this code less confusing, see discussion
            # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836

            def _order_nodes(node_a, node_b, node_c) -> List[Node]:
                nodes = [node_a, node_b, node_c]
                first_node = None
                mid_node = None
                last_node = None
                for n in nodes:
                    prev_n = n.args[0]
                    next_n = next(iter(n.users))
                    if prev_n not in nodes:
                        first_node = n
                    elif next_n not in nodes:
                        last_node = n
                    else:
                        mid_node = n
                assert (
                    first_node is not None
                    and mid_node is not None
                    and last_node is not None
                )
                assert mid_node.args[0] is first_node
                assert last_node.args[0] is mid_node
                return [last_node, mid_node, first_node]

            if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
                # (a, b)
                list_of_nodes = cur_match[1]
            elif isinstance(cur_match[1][0], tuple):
                # ((a, b), c)
                node_a, node_b = cur_match[1][0]
                node_c = cur_match[1][1]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)
            elif isinstance(cur_match[1][1], tuple):
                # (a, (b, c))
                node_a, node_b = cur_match[1][1]
                node_c = cur_match[1][0]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)

        # [node_n, ..., node_0], note that the order is reversed
        # to make it chronological for simple subgraphs
        list_of_nodes.reverse()
        subgraphs_dedup[name] = list_of_nodes

    return subgraphs_dedup

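# Illustrative example (hedged; the exact `_MatchResult` internals vary): for
# the conv -> bn -> relu match results shown in the comments above, this
# function keeps a single entry and orders its nodes chronologically:
#
#   {'conv': [conv_node, bn_node, relu_node]}
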
def _get_logger_for_subgraph(
    model: GraphModule,
    first_node: Node,
    last_node: Node,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    qconfig_str: str,
    logger_cls: Callable,
    fqn: Optional[str],
) -> torch.nn.Module:
    """
    Given a model and a linear subgraph starting from `first_node` and
    ending with `last_node`, creates a logger for the end of this
    subgraph.
    """
    if fqn is None:
        fqn = ""
    logger_mod_orig = logger_cls(
        first_node.name,  # ref_node_name
        last_node.name,  # prev_node_name
        f"subgraph_{subgraph_idx}_{subgraph_candidate_idx}",  # model_name
        "model",  # ref_name
        get_target_type_str(last_node, model),  # prev_node_target_type
        get_target_type_str(first_node, model),  # ref_node_target_type
        NSSingleResultValuesType.NODE_OUTPUT.value,  # results_type
        0,  # index_within_arg
        0,  # index_of_arg
        fqn,  # fqn
        qconfig_str,
    )
    # Usually we expect the user to add loggers, then calibrate, then convert,
    # and then populate loggers. This is why the loggers start disabled.
    # TODO(future PR): reconsider the design to make this more intuitive.
    logger_mod_orig.enabled = False
    return logger_mod_orig

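# Example wiring (illustrative sketch, mirroring what
# `create_one_transformed_and_logged_copy_of_subgraph` does below):
#
#   logger = _get_logger_for_subgraph(
#       m, first_node, last_node, 0, 0, "", OutputLogger, fqn=None)
#   setattr(m, _get_attr_name(0, 0), logger)
#   with m.graph.inserting_after(last_node):
#       m.graph.call_module(_get_attr_name(0, 0), args=(last_node,), kwargs={})
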
def create_submodule_from_subgraph(
    model: torch.nn.Module,
    first_node: Node,
    last_node: Node,
) -> GraphModule:
    """
    Input: a model, and a linear subgraph within the model from first_node to
    last_node.

    Output: a new submodule containing a copy of the subgraph, with the inputs
    to the first node becoming the inputs to the submodule, and all other
    nodes in the subgraph being copied.

    Example inputs:

    `model`: a module with graph

      x0 -> op1 -> x1 -> op2 -> x2
             |
            arg1

    `first_node`: op1
    `last_node`: op2

    Example output: a new module with graph

      input1 -> op1_copy -> x1 -> op2_copy -> output1
                  |
                 arg1
    """

    #
    # create a blank GraphModule with an empty graph
    #

    class M(torch.nn.Module):
        def forward(self, x):
            pass

    m = M()
    gm = torch.fx.symbolic_trace(m)
    g = gm.graph
    for node in reversed(gm.graph.nodes):
        g.erase_node(node)

    #
    # modify the graph to have a copy of our subgraph
    #

    cur_node_orig = first_node
    cur_args_orig = cur_node_orig.args
    cur_kwargs_orig = cur_node_orig.kwargs

    cur_name_idx = 0

    iteration_limit = 100
    cur_iteration = 0

    while True:
        if cur_node_orig is first_node:
            # we are at the first node, we need to set up graph inputs
            # TODO(future): some graphs could have placeholders which are unrelated
            # to the first node, need to handle this
            cur_args_copy = []
            cur_kwargs_copy = {}
            seen_names: Set[str] = set()
            old_name_to_new_node: Dict[str, Node] = {}

            def _add_placeholder(
                g: Graph, node: Node, seen_names, old_name_to_new_node
            ):
                # note: for graphs starting with patterns such as `y = x + x`, we
                # need to ensure we do not add multiple placeholders with the
                # same name
                counter = 0
                while node.name + "_" + str(counter) in seen_names:
                    counter += 1
                cur_name = node.name + "_" + str(counter)
                seen_names.add(cur_name)
                placeholder = g.placeholder(cur_name)
                old_name_to_new_node[node.name] = placeholder
                return placeholder

            for arg in cur_node_orig.args:
                if isinstance(arg, Node):
                    p = _add_placeholder(g, arg, seen_names, old_name_to_new_node)
                    cur_args_copy.append(p)
                elif isinstance(arg, (list, tuple)):
                    new_arg = []
                    for inner_arg in arg:
                        if isinstance(inner_arg, Node):
                            new_arg.append(
                                _add_placeholder(
                                    g, inner_arg, seen_names, old_name_to_new_node
                                )
                            )
                        else:
                            new_arg.append(inner_arg)
                    cur_args_copy.append(new_arg)
                else:
                    cur_args_copy.append(arg)

            # TODO(future PR): handle non-normalized kwargs
            for kwarg_name, kwarg in cur_node_orig.kwargs.items():
                if isinstance(kwarg, Node):
                    cur_kwargs_copy[kwarg_name] = _add_placeholder(
                        g, kwarg, seen_names, old_name_to_new_node
                    )
                elif isinstance(kwarg, (list, tuple)):
                    new_kwarg = []
                    for inner_kwarg in kwarg:
                        p = _add_placeholder(
                            g, inner_kwarg, seen_names, old_name_to_new_node
                        )
                        new_kwarg.append(p)
                    cur_kwargs_copy[kwarg_name] = new_kwarg
                else:
                    cur_kwargs_copy[kwarg_name] = kwarg

            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]
        else:
            # we are not at the first node; the first arg comes from the
            # previous node, and all other args are copied

            # the current implementation is simplistic and cannot handle
            # ops with two or more arguments which need to be passed from
            # the previous op, so we assert them out
            assert cur_node_orig.target not in BINARY_FUNCTIONS

            # at this point in the code, cur_node_copy is pointing to the copy
            # of the previous node
            # TODO(future PR): this is not handling complicated graphs correctly, need to
            # look at actual relationships instead of assuming sequential graph
            # TODO(future PR): this is ignoring kwargs, will need to support kwargs
            # for any fusion pattern which has them for a node that is not the
            # first node.
            cur_args_copy = [cur_node_copy]  # type: ignore[has-type, possibly-undefined]  # noqa: F821

            if len(cur_node_orig.args) > 1:
                for arg in cur_node_orig.args[1:]:
                    if isinstance(arg, torch.nn.Parameter):
                        new_arg = arg.clone().detach()  # type: ignore[assignment]
                        mod_name = f"mod_{cur_name_idx}"
                        cur_name_idx += 1
                        setattr(gm, mod_name, new_arg)
                        new_arg_placeholder = g.placeholder(mod_name)
                        cur_args_copy.append(new_arg_placeholder)
                    elif isinstance(arg, (float, int, torch.dtype)):
                        cur_args_copy.append(arg)
                    else:
                        raise AssertionError(f"arg of type {type(arg)} not handled yet")
            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]

        # copy the node
        if cur_node_orig.op == "call_module":
            orig_mod = getattr_from_fqn(model, cur_node_orig.target)  # type: ignore[arg-type]
            orig_mod_copy = copy.deepcopy(orig_mod)
            mod_name = f"mod_{cur_name_idx}"
            setattr(gm, mod_name, orig_mod_copy)
            cur_name_idx += 1
            cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined,arg-type]

        elif cur_node_orig.op == "call_function":
            cur_node_copy = g.call_function(
                cur_node_orig.target,  # type: ignore[arg-type]
                cur_args_copy,  # type: ignore[arg-type]
                cur_kwargs_copy,  # type: ignore[possibly-undefined]
            )

        elif cur_node_orig.op == "call_method":
            cur_node_copy = g.call_method(
                cur_node_orig.target,  # type: ignore[arg-type]
                cur_args_copy,  # type: ignore[arg-type]
                cur_kwargs_copy,  # type: ignore[possibly-undefined]
            )

        else:
            raise AssertionError(f"{cur_node_orig.op} not supported yet")

        if cur_node_orig is last_node:
            break

        # go to the next node
        assert (
            len(cur_node_orig.users.keys()) == 1
        ), f"{cur_node_orig} has more than 1 user, not supported yet"
        cur_node_orig = next(iter(cur_node_orig.users.keys()))
        cur_args_orig = cur_node_orig.args
        cur_kwargs_orig = cur_node_orig.kwargs

        cur_iteration += 1
        if cur_iteration > iteration_limit:
            raise AssertionError("iteration limit exceeded")

    # set up outputs
    g.output(cur_node_copy)

    gm.recompile()
    return gm

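# Example usage (illustrative sketch; `some_model` and the node lookups are
# hypothetical):
#
#   parent = torch.fx.symbolic_trace(some_model)
#   first_node, last_node = ...  # locate a linear chain such as conv -> relu
#   submodule = create_submodule_from_subgraph(parent, first_node, last_node)
#   # `submodule` is a GraphModule whose forward runs the copied chain on its
#   # placeholder inputs
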
def create_one_transformed_and_logged_copy_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    first_node: Node,
    last_node: Node,
    fqn: Optional[str],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    example_inputs: Any,
    last_added_shadow_node_list: List[Optional[Node]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Given a subgraph in `mt` and a subgraph candidate idx, inserts the
    subgraph candidate copy and instruments it with loggers.

    If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just
    add a logger to the end.

    If subgraph_candidate_idx is not 0, we create a copy of the subgraph and
    prepare it with `prepare_fx`.
    """

    # TODO(future PR): move logger classes to utils to remove circular dependency
    from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger

    if subgraph_candidate_idx == 0:
        # idx = 0 is the floating point (original) version of the subgraph
        # We keep the subgraph as is, and add a logger at the end

        qconfig_str = ""
        logger_mod_orig = _get_logger_for_subgraph(
            mt,
            first_node,
            last_node,
            subgraph_idx,
            subgraph_candidate_idx,
            qconfig_str,
            OutputLogger,
            fqn,
        )

        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, logger_mod_orig)
        with mt.graph.inserting_after(last_node):
            new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
            last_added_shadow_node_list[0] = new_node

    else:
        # idx > 0 means we have a candidate qconfig to try, so we need
        # to make a copy of the subgraph, feed it with the right inputs,
        # and add a logger at the end

        # get the qconfig
        # subtract one because the first candidate is the floating point
        # version of the subgraph
        node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
        qconfig = node_name_to_qconfig[first_node.name]

        # if no quantization is requested, skip
        # TODO(future PR): deduplicate equivalent qconfigs that come from
        # different qconfig mapping objects
        if qconfig is None:
            return

        qconfig_mapping = QConfigMapping().set_global(qconfig)

        # create a copy of the submodule, wrapped in a separate module
        orig_mod_copy_wrapped = create_submodule_from_subgraph(
            mt, first_node, last_node
        )

        # add a call to prepare_fx on the wrapper module
        if custom_prepare_fn is None:
            orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx(
                orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs
            )
        else:
            if custom_prepare_kwargs is None:
                custom_prepare_kwargs = {}
            for kwarg_name in [
                "example_inputs",
                "prepare_custom_config",
                "qconfig_mapping",
            ]:
                assert (
                    kwarg_name not in custom_prepare_kwargs
                ), f"cannot specify {kwarg_name} in custom_prepare_kwargs"
            prepare_kwargs: Dict[str, Any] = {
                "example_inputs": example_inputs,
                "qconfig_mapping": qconfig_mapping,
            }
            prepare_kwargs.update(custom_prepare_kwargs)
            orig_mod_copy_wrapped = custom_prepare_fn(
                orig_mod_copy_wrapped, **prepare_kwargs
            )

        # attach the wrapper to the model
        attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, orig_mod_copy_wrapped)

        # add a call to the wrapper module from the parent graph
        insert_after_node = last_added_shadow_node_list[0]
        with mt.graph.inserting_after(insert_after_node):
            # TODO(future PR): handle fusion patterns where non-first nodes
            # need inputs

            # pass in all node args and kwargs

            new_args = []
            for arg in first_node.args:
                if isinstance(arg, Node):
                    new_args.append(arg)
                elif (
                    isinstance(arg, (list, tuple))
                    and len(arg)
                    and isinstance(arg[0], Node)
                ):
                    for inner_arg in arg:
                        if isinstance(inner_arg, Node):
                            new_args.append(inner_arg)

            new_kwargs = {}
            for name, old_kwarg in first_node.kwargs.items():
                if isinstance(old_kwarg, Node):
                    new_kwargs[name] = old_kwarg
                elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg):
                    # TODO(future PR): clarify why we are adding kwargs to args
                    new_args.extend(old_kwarg)

            new_args = tuple(new_args)  # type: ignore[assignment]

            new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs)  # type: ignore[arg-type]

        # add a logger to parent graph to observe the shadow wrapper
        logger_mod_orig = _get_logger_for_subgraph(
            mt,
            first_node,
            last_node,
            subgraph_idx,
            subgraph_candidate_idx,
            str(qconfig),
            OutputComparisonLogger,
            fqn,
        )

        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
        assert not hasattr(mt, attr_name)
        setattr(mt, attr_name, logger_mod_orig)
        with mt.graph.inserting_after(new_node):
            logger = mt.graph.call_module(
                attr_name, args=(new_node, last_node), kwargs={}
            )
            last_added_shadow_node_list[0] = logger

    mt.recompile()

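# Resulting parent graph (illustrative sketch) after running this for
# candidates 0 and 1 on a single-op subgraph `op0`:
#
#   x0 -> op0 -> shadow_0_0 (OutputLogger)
#    \
#     -> shadow_wrapper_0_1 -> shadow_0_1 (OutputComparisonLogger, which also
#                                          reads op0's output)
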
def create_n_transformed_and_logged_copies_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    match_name: str,
    nodes_in_this_subgraph: List[Any],
    qconfig_mappings: List[QConfigMapping],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Given a model `mt` and a subgraph_idx, creates the needed copies
    of the subgraph for all qconfigs, and instruments them with loggers.
    """
    # for now, assume that
    # 1. the first node has one input
    # 2. the last node has one output

    # for now, ignore all subgraphs that contain non-nodes (tuples, etc)
    # TODO(future PR): implement this
    if any(not isinstance(node, Node) for node in nodes_in_this_subgraph):
        return

    first_node = nodes_in_this_subgraph[0]
    last_node = nodes_in_this_subgraph[-1]
    # We used output propagation to populate example values on each
    # node. Use the example values from the previous node as the input
    # to the current node.
    prev_node = get_normalized_nth_input(first_node, mt, 0)
    if isinstance(prev_node, list):
        example_inputs = [x.traced_result for x in prev_node]
    elif isinstance(prev_node, tuple):
        example_inputs = tuple(x.traced_result for x in prev_node)  # type: ignore[assignment]
    else:
        # currently some customer models do not have a traced_result in
        # every node, so we have to guard for this case since we cannot
        # quantize without an example input
        # TODO(future PR): add a test case for this once we have an easy
        # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489
        # for additional context
        if hasattr(prev_node, "traced_result"):
            example_inputs = (prev_node.traced_result,)  # type: ignore[attr-defined, assignment]
        else:
            print(
                "unable to get example input for node "
                + f"{first_node.format_node()}, skipping"
            )
            return

    # If there are no quantization configs for this subgraph, skip adding
    # loggers. This reduces memory usage for models where not all layers are
    # quantized.
    # TODO(future): consider making this configurable
    found_at_least_one_qconfig = False
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
        if subgraph_candidate_idx == 0:
            # fp32 baseline does not need a qconfig
            continue

        # a. we have N shadows, so len(qconfig_mappings) is N
        # b. we will have the fp32 layer + N shadows, so overall number of
        #    (original_op) + (*shadows) will be N+1
        # c. since `subgraph_candidate_idx` represents (b), we need
        #    to subtract 1 to query from (a)
        node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
        qconfig = node_name_to_qconfig[first_node.name]
        if qconfig is not None:
            found_at_least_one_qconfig = True
            break
    if not found_at_least_one_qconfig:
        print(
            "unable to find at least one qconfig for node "
            + f"{first_node.format_node()}, skipping"
        )
        return

    fqn = _maybe_get_fqn(first_node, mt)

    # We want the results to contain the subgraphs in natural order,
    # and the graph to also contain shadow wrappers and shadow loggers
    # in natural order.
    # If we just iterate in reverse, the graph will be in natural
    # order but the eventual results will be in reverse order.
    # So, we keep track of the last shadow logger we added and
    # always insert after it.
    last_added_shadow_node_list: List[Optional[Node]] = [None]
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
        create_one_transformed_and_logged_copy_of_subgraph(
            mt,
            subgraph_idx,
            subgraph_candidate_idx,
            first_node,
            last_node,
            fqn,
            list_of_node_name_to_qconfig,
            example_inputs,
            last_added_shadow_node_list,
            custom_prepare_fn,
            custom_prepare_kwargs,
        )

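# For example (illustrative), with subgraph_idx == 3 and two qconfig mappings,
# this adds the following to the parent graph:
#
#   shadow_3_0                       # fp32 output logger
#   shadow_wrapper_3_1, shadow_3_1   # candidate 1 copy + comparison logger
#   shadow_wrapper_3_2, shadow_3_2   # candidate 2 copy + comparison logger
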
def create_add_loggers_graph(
    model: GraphModule,
    subgraphs_dedup: Dict[str, List[Node]],
    qconfig_mapping: QConfigMapping,
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> None:
    r"""
    Given a model, a model graph partition (currently a set of matched
    subgraphs) and instructions for how to transform each subgraph
    (currently quantizing it according to qconfig_mapping), modifies
    the model graph to create an alternate path through the original graph,
    with each of the subgraphs quantized. This is useful to compare the
    propagation error of a transformation such as quantization.

    For example, given layer op0 and op1, there are four cases when handling op1:
    1. op0 and op1 quantized
    2. op0 and op1 unquantized
    3. op0 quantized, op1 unquantized
    4. op0 unquantized, op1 quantized

    Example input, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \        \                  \                  \            # noqa: W605
        ---> op0_1 -> x1_1 ----> clog      op1_1 -> x2_1 ----> clog

    Example output, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \        \                                     \            # noqa: W605
        ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog

    """
    # TODO(future PR): move logger classes to utils to remove circular dependency
    from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger

    def _get_subgraph_containing_node(node, subgraphs_dedup):
        for subgraph in subgraphs_dedup.values():
            if node in subgraph:
                return subgraph
        return None

    # First, we need to create shadow branches, going from
    #
    #   x0 -> op0 -> x1 -> ...
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log -> ...
    #    \       \
    #     -> op0_1 -> x1_1 -> clog
    #
    # Later, the outputs of each shadow will be rerouted to calculate
    # propagation error.

    # Note: we cannot iterate over matched subgraphs because some nodes
    # may not be matched. So, we iterate over nodes in the graph, and
    # associate them to matched subgraphs if possible.

    nodes_to_skip = set()
    # for each subgraph, save a mapping from first node of subgraph
    # to first and last node of the shadow of this subgraph
    orig_first_node_to_shadow_in_node = {}
    orig_first_node_to_shadow_out_node = {}
    # need to record original list because we will mutate the graph as we go
    orig_nodes = list(model.graph.nodes)  # type: ignore[union-attr, arg-type]
    cur_subgraph_idx = 0
    for n in orig_nodes:
        if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        insert_submodule_copy = False
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            nodes_to_skip.update(maybe_subgraph)
            qconfig = node_name_to_qconfig[first_node.name]
            if qconfig is not None:
                insert_submodule_copy = True
        else:
            first_node, last_node = n, n

        if insert_submodule_copy:
            match_name = first_node.name
            create_n_transformed_and_logged_copies_of_subgraph(
                model,
                cur_subgraph_idx,
                match_name,
                maybe_subgraph,
                [qconfig_mapping],
                [node_name_to_qconfig],
                None,
                None,  # type: ignore[arg-type]
            )
            # find the created shadow module and record it so we
            # can find it easily in step 2
            expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1"
            new_shadow_mod = None
            for maybe_shadow_mod in model.graph.nodes:
                if (
                    maybe_shadow_mod.op == "call_module"
                    and maybe_shadow_mod.target == expected_shadow_target
                ):
                    new_shadow_mod = maybe_shadow_mod
                    break
            assert new_shadow_mod is not None
            orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
            orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod

        else:
            # create a copy of the subgraph by only copying FX nodes
            # but not copying any parameters, to minimize memory usage
            subgraph_to_use = (
                maybe_subgraph if maybe_subgraph is not None else [first_node]
            )

            # add a regular logger after last_node
            qconfig_str = ""
            subgraph_candidate_idx = 0
            fqn = _maybe_get_fqn(first_node, model)
            logger_mod_orig = _get_logger_for_subgraph(
                model,
                first_node,
                last_node,
                cur_subgraph_idx,
                subgraph_candidate_idx,
                qconfig_str,
                OutputLogger,
                fqn,
            )
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            insertion_point = last_node
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(last_node,), kwargs={}
                )
                insertion_point = logger

            # create a copy of the subgraph
            cur_node_orig = first_node
            cur_node_copy = None
            first_node_copy = None
            while cur_node_orig in subgraph_to_use:
                # TODO(future PR): make this support all possible args/kwargs
                if cur_node_orig is first_node:
                    new_args = cur_node_orig.args
                    new_kwargs = cur_node_orig.kwargs
                else:
                    first_arg_for_copy = cur_node_copy
                    new_args = (first_arg_for_copy, *cur_node_orig.args[1:])
                    new_kwargs = cur_node_orig.kwargs
                # make a copy of cur_node_orig
                with model.graph.inserting_after(insertion_point):
                    cur_node_copy = model.graph.create_node(
                        cur_node_orig.op,
                        cur_node_orig.target,
                        new_args,
                        new_kwargs,
                        # cur_node_orig.name,  # TODO(future PR): set name explicitly
                    )
                if first_node_copy is None:
                    first_node_copy = cur_node_copy
                # since only linear subgraphs are supported for now, all nodes
                # except the last one must have exactly one user
                if cur_node_orig != last_node:
                    assert len(cur_node_orig.users.keys()) == 1
                cur_node_orig = next(iter(cur_node_orig.users.keys()))
                assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
                insertion_point = cur_node_copy
            # add a comparison logger after last_node's copy
            subgraph_candidate_idx = 1
            logger_mod_orig = _get_logger_for_subgraph(
                model,
                first_node,
                last_node,
                cur_subgraph_idx,
                subgraph_candidate_idx,
                qconfig_str,
                OutputComparisonLogger,
                fqn,
            )
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(cur_node_copy, last_node), kwargs={}
                )

            # save the final node so we can use it in step 2
            orig_first_node_to_shadow_in_node[first_node] = first_node_copy
            orig_first_node_to_shadow_out_node[first_node] = cur_node_copy

        cur_subgraph_idx += 1

    model.recompile()

    # Now, we go from
    #
    #   x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ...
    #    \       \                      \
    #     -> op0_1 -> x1_1 -> clog       -> op1_1 -> ...
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ...
    #    \       \
    #     -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ...
    #
    # sample values of key internal variables for the example above:
    #
    #   orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1}
    #   orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1}
    #
    # note: for subgraphs with more than one node, in_node will be different
    # compared to out_node

    nodes_to_skip = set()
    for n in orig_nodes:
        if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            nodes_to_skip.update(maybe_subgraph)
        else:
            first_node, last_node = n, n

        def maybe_remap_node_to_shadow(node):
            """
            If unshadowed `node` has a shadow version, return that. If not,
            return `node`.
            """
            if not isinstance(node, Node):
                # handle scalars
                return node

            if node.op in ("placeholder", "get_attr"):
                return node

            # Find the shadowed version of this arg from the previous
            # subgraph. For this, we need to:
            # 1. navigate to the first node of the previous subgraph
            # 2. get the output of the shadow wrapper which has (1) as an input

            # For now, assume the arg is in matched subgraphs. In the
            # future we may have to handle the case where this is not true.
            prev_subgraph = _get_subgraph_containing_node(node, subgraphs_dedup)
            if prev_subgraph is None:
                prev_subgraph = [node]
            prev_first_node = prev_subgraph[0]
            prev_shadow_output = orig_first_node_to_shadow_out_node[prev_first_node]
            return prev_shadow_output

        cur_shadow_input = orig_first_node_to_shadow_in_node[first_node]
        assert cur_shadow_input is not None
        cur_shadow_input.args = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.args
        )
        cur_shadow_input.kwargs = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.kwargs
        )

    model.recompile()

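# Typical invocation (illustrative sketch; `subgraphs_dedup` comes from
# `_get_dedup_subgraphs` and `node_name_to_qconfig` from the prepare_fx
# tracing metadata):
#
#   create_add_loggers_graph(
#       model, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig)
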
def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
    # input: shadow wrapper module
    # output if shadow wrapper module has a weighted op:
    #   (quantize_fn, (quantize_fn_args))
    # output if shadow wrapper module doesn't have a weighted op:
    #   None

    # For now, assume that the weight is the second input
    # to the shadow module. If that changes, we can fix it later.
    placeholders_seen = 0
    for shadow_n in shadow_wrapper.graph.nodes:  # type: ignore[union-attr]
        if shadow_n.op != "placeholder":
            continue

        placeholders_seen += 1
        if placeholders_seen != 2:
            continue

        # the subgraph looks like
        #
        #   _input_scale_1 = self._input_scale_1
        #   _input_zero_point_1 = self._input_zero_point_1
        #   quantize_per_channel = torch.quantize_per_channel(
        #       w2_0, _input_scale_1, _input_zero_point_1,
        #       0, torch.qint8)
        #
        # we have `w2_0`, and are navigating this subgraph
        # to get `_input_scale_1` and `_input_zero_point_1`

        assert len(shadow_n.users) == 1
        quant_node = next(iter(shadow_n.users.keys()))
        new_args: Any = None
        if quant_node.target == torch.quantize_per_channel:
            _weight, scale_node, zp_node, axis, dtype = quant_node.args
            scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
            zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
            new_args = (scale_val, zp_val, axis, dtype)
        else:
            assert quant_node.target == torch.quantize_per_tensor
            _weight, scale_node, zp_node, dtype = quant_node.args
            scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
            zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
            new_args = (scale_val, zp_val, dtype)
        return (quant_node.target, new_args)

    return None

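# Example return value (illustrative): for a per-tensor quantized weighted op,
# something like
#
#   (torch.quantize_per_tensor, (scale, zero_point, torch.qint8))
#
# which `extract_weight_comparison` below applies as `quant_fn(weight, *args)`.
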
def extract_weight_comparison(m: GraphModule) -> NSResultsType:
    # example graph:
    #
    #   w1 = self.w1
    #   b1 = self.b1
    #   linear = torch._C._nn.linear(x, w1, b1)
    #   shadow_0_0 = self.shadow_0_0(linear)
    #   shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1)
    #   shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear)
    #
    # algorithm:
    # 1. for each call_function node matching our allowlist:
    # 2.   if a corresponding shadow wrapper exists, extract the weight pair
    #
    # Note: this is not super robust, but that's ok because this is
    # just for legacy customers who depend on the previous two-model version
    # of this API. TBD if we need to make this robust.
    # Note: modules are not supported, since existing customers only
    # use functions.

    # TODO(future PR): move this to config
    weighted_ops = {
        torch.nn.functional.linear,
    }

    results: NSResultsType = {"model": {NSSingleResultValuesType.WEIGHT.value: {}}}

    for n in m.graph.nodes:  # type: ignore[union-attr]
        if not (n.op == "call_function" and n.target in weighted_ops):
            continue

        # Check if we have a corresponding shadow wrapper
        # TODO(future PR, if needed): support kwargs
        # TODO(future PR, if needed): support multiple shadow users
        first_arg = n.args[0]
        shadow_wrapper_node = None
        for user in first_arg.users:
            if user.op == "call_module" and user.target.startswith(
                SHADOW_WRAPPER_NODE_NAME_PREFIX
            ):
                shadow_wrapper_node = user
                break

        if shadow_wrapper_node is None:
            continue

        shadow_wrapper = getattr_from_fqn(
            m, shadow_wrapper_node.target
        )  # type: ignore[arg-type]
        weight_info = _get_weight_info_from_shadow_wrapper(shadow_wrapper)
        if weight_info is None:
            continue

        # get weight
        w_node = n.args[1]
        w_obj = getattr_from_fqn(m, w_node.target).detach()

        # get a quantized version of weight
        quant_fn, quant_fn_args_except_first = weight_info
        new_args = (w_obj, *quant_fn_args_except_first)
        w_obj_q = quant_fn(*new_args)

        # add a comparison
        ref_node_name = n.name
        prev_node_name = n.name
        ref_node_type = get_target_type_str(n, m)
        prev_node_type = ref_node_type
        fqn = None
        if hasattr(m, "_node_name_to_scope"):
            fqn = m._node_name_to_scope[n.name][0]  # type: ignore[index]
        comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q)
        result_fp32 = {
            "res_type": NSSingleResultValuesType.WEIGHT.value,
            "values": [w_obj],
            "prev_node_name": prev_node_name,
            "prev_node_target_type": prev_node_type,
            "ref_node_name": ref_node_name,
            "ref_node_target_type": ref_node_type,
            "index_within_arg": 0,
            "index_of_arg": 0,
            "fqn": fqn,
            "qconfig_str": "",
            "comparisons": [comparison],
            "comparison_fn_name": "sqnr",
        }
        result_q = {
            "res_type": NSSingleResultValuesType.WEIGHT.value,
            "values": [w_obj_q],
            "prev_node_name": prev_node_name,
            "prev_node_target_type": prev_node_type,
            "ref_node_name": ref_node_name,
            "ref_node_target_type": ref_node_type,
            "index_within_arg": 0,
            "index_of_arg": 0,
            "fqn": fqn,
            "qconfig_str": "",
            "comparisons": [comparison],
            "comparison_fn_name": "sqnr",
        }

        # go from subgraph_n_1 to subgraph_n_0
        _1, _2, node_idx, _3 = shadow_wrapper_node.target.split("_")
        name_fp32 = f"subgraph_{node_idx}_0"
        name_q = f"subgraph_{node_idx}_1"

        results["model"][NSSingleResultValuesType.WEIGHT.value][name_fp32] = [
            result_fp32
        ]
        results["model"][NSSingleResultValuesType.WEIGHT.value][name_q] = [result_q]

    return results

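# Example output shape (illustrative sketch, assuming
# `NSSingleResultValuesType.WEIGHT.value == 'weight'`):
#
#   {
#       'model': {
#           'weight': {
#               'subgraph_0_0': [{'values': [w_fp32], 'comparisons': [sqnr], ...}],
#               'subgraph_0_1': [{'values': [w_q], 'comparisons': [sqnr], ...}],
#           },
#       },
#   }
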
# TODO(future PR): redesign this to make it easier to consume outputs
def group_results_by_subgraph(results: NSResultsType) -> Any:
    """
    Creates a comparison of results

    Input:

    {
      'model': {
        'node_output': {
          'subgraph_0_0': [
            'values': [torch.tensor(...), ...], ...
            'ref_node_name': ...,
            'ref_node_target_type': ...,
            'qconfig_str': ...,
            'comparisons': [], ...
            'comparison_fn_name': '',
            'fqn': '...',
          ],
          'subgraph_0_1': [
            'values': [torch.tensor(...), ...], ...
            'ref_node_name': ...,
            'ref_node_target_type': ...,
            'qconfig_str': ...,
            'comparisons': [torch.tensor(...), ...], ...
            'comparison_fn_name': '...',
            'fqn': '...',
          ],
          ...
        },
      },
    }

    Output:

    {
      'subgraph_0': {
        '0': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': None,
          'comparisons': [torch.tensor(...), ...], ...
          'comparison_fn_name': '...',
          'fqn': '...',
        },
        '1': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': '...',
          'comparisons': [torch.tensor(...), ...], ...
          'comparison_fn_name': '...',
          'fqn': '...',
        },
      },
    }

    """
    subgraph_name_to_subgraph_results: Any = collections.defaultdict(dict)

    # node_output or weight
    key_to_use = next(iter(results["model"].keys()))

    for subgraph_name_with_idx, subgraph_candidate_results in results["model"][
        key_to_use
    ].items():
        # convert from `subgraph_m_n` to `subgraph_m` and `n`
        (
            subgraph_str,
            subgraph_idx,
            subgraph_candidate_idx,
        ) = subgraph_name_with_idx.split("_")
        subgraph_name = f"{subgraph_str}_{subgraph_idx}"

        subgraph_results = {
            "ref_node_name": subgraph_candidate_results[0]["ref_node_name"],
            "ref_node_target_type": subgraph_candidate_results[0][
                "ref_node_target_type"
            ],
            "fqn": subgraph_candidate_results[0]["fqn"],
            "values": subgraph_candidate_results[0]["values"],
            "qconfig_str": subgraph_candidate_results[0]["qconfig_str"],
            "comparisons": subgraph_candidate_results[0]["comparisons"],
            "comparison_fn_name": subgraph_candidate_results[0]["comparison_fn_name"],
        }

        subgraph_name_to_subgraph_results[subgraph_name][
            subgraph_candidate_idx
        ] = subgraph_results

    return dict(subgraph_name_to_subgraph_results)

# TODO(future PR): redesign this to make it easier to consume outputs
def create_results_comparison(
    results_grouped,
) -> Any:
    """
    Input:

    {
      'subgraph_0': {
        '0': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': '',
          'comparisons': [],
          'comparison_fn_name': '',
          'fqn': '...',
        },
        '1': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': '...',
          'comparisons': [torch.tensor(...), ...],
          'comparison_fn_name': 'sqnr',
          'fqn': '...',
        },
      },
    }

    Output:

    {
      'subgraph_0': {
        'ref_node_name': '...',
        'ref_node_target_type': '...',
        'fqn': '...',
        'candidates': {
          '1': {
            'qconfig_str': ...,
            'comparison_fn_name': 'sqnr',
            'cmp_raw': [..., ...],
            'cmp_mean': ...,
          },
          ...,
        },
      },
    }
    """

    results_comparison = {}

    for subgraph_name, subgraph_results in results_grouped.items():
        candidates = {}
        for subgraph_inner_name, subgraph_inner_result in subgraph_results.items():
            # skip comparing baseline to baseline
            if subgraph_inner_name == "0":
                continue

            # we expect the comparisons to be precalculated from
            # calibration, so we just fetch them here
            cmp_raw = subgraph_inner_result["comparisons"]
            cmp_raw_tensor = torch.stack(cmp_raw)

            candidates[subgraph_inner_name] = {
                "qconfig_str": subgraph_inner_result["qconfig_str"],
                "comparison_fn_name": subgraph_inner_result["comparison_fn_name"],
                "cmp_raw": cmp_raw_tensor,
                "cmp_mean": torch.mean(cmp_raw_tensor),
            }

        results_comparison[subgraph_name] = {
            "ref_node_name": subgraph_results["0"]["ref_node_name"],
            "ref_node_target_type": subgraph_results["0"]["ref_node_target_type"],
            "fqn": subgraph_results["0"]["fqn"],
            "candidates": candidates,
        }

    return results_comparison

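# Typical end-to-end flow for consuming these results (illustrative sketch;
# `results` would come from extracting logger data after calibration):
#
#   grouped = group_results_by_subgraph(results)
#   comparison = create_results_comparison(grouped)
#   print_n_shadows_summary(comparison)
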
# TODO(future PR): redesign this to make it easier to consume outputs
def print_n_shadows_summary(
    results_comparison,
) -> None:
    """
    Input:

    {
      'subgraph_0': {
        'ref_node_name': 'linear1',
        'ref_node_target_type': '...',
        'fqn': '...',
        'candidates': {
          '1': {
            'qconfig_str': ...,
            'comparison_fn_name': ...,
            'cmp_raw': [45.0, 55.0],
            'cmp_mean': 50.0,
          },
          ...,
        },
      },
    }

    Prints:

    node_name | node_type | fqn | 0    | 1    | ...
    linear1   | ...       | ... | 45.0 | 50.0 | ...
    """

    try:
        from tabulate import tabulate
    except ImportError:
        print(
            "`print_n_shadows_summary` relies on the library `tabulate`, "
            "which could not be found on this machine. Run `pip "
            "install tabulate` to install the library."
        )
        return

    results = []
    for subgraph_data in results_comparison.values():
        mean_all_candidates = [
            candidate["cmp_mean"]
            for candidate in subgraph_data["candidates"].values()
        ]

        data_row = [
            subgraph_data["ref_node_name"],
            subgraph_data["ref_node_target_type"],
            subgraph_data["fqn"],
            *mean_all_candidates,
        ]
        results.append(data_row)

    max_candidate_idx_len = -1
    for data_row in results:
        # the first 3 columns are fixed (node_name, node_type, fqn); the
        # remaining columns hold one mean per candidate
        max_candidate_idx_len = max(max_candidate_idx_len, len(data_row) - 3)
    candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)]

    headers = ["node_name", "node_type", "fqn", *candidate_idx_headers]
    print(tabulate(results, headers=headers))