# mypy: allow-untyped-defs
"""
This module contains tooling to compare weights and activations
across models. Example usage::

    import copy
    import torch
    import torch.ao.quantization.quantize_fx as quantize_fx
    import torch.ao.ns._numeric_suite_fx as ns

    m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
    mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
    # We convert a copy because we need the original prepared model
    # to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
    mq = quantize_fx.convert_fx(copy.deepcopy(mp))

    #
    # Comparing weights
    #

    # extract weight pairs
    weight_comparison = ns.extract_weights('a', mp, 'b', mq)

    # add SQNR for each comparison, inplace
    ns.extend_logger_results_with_comparison(
        weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # weight_comparison contains the weights from `mp` and `mq` stored
    # in pairs, and can be used for further analysis.


    #
    # Comparing activations, with error propagation
    #

    # add loggers
    mp_ns, mq_ns = ns.add_loggers(
        'a', copy.deepcopy(mp),
        'b', copy.deepcopy(mq),
        ns.OutputLogger)

    # send an example datum to capture intermediate activations
    datum = torch.randn(1, 1, 1, 1)
    mp_ns(datum)
    mq_ns(datum)

    # extract intermediate activations
    act_comparison = ns.extract_logger_info(
        mp_ns, mq_ns, ns.OutputLogger, 'b')

    # add SQNR for each comparison, inplace
    ns.extend_logger_results_with_comparison(
        act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # act_comparison contains the activations from `mp_ns` and `mq_ns` stored
    # in pairs, and can be used for further analysis.

    #
    # Comparing activations, without error propagation
    #

    # create shadow model
    mp_shadows_mq = ns.add_shadow_loggers(
        'a', copy.deepcopy(mp),
        'b', copy.deepcopy(mq),
        ns.OutputLogger)

    # send an example datum to capture intermediate activations
    datum = torch.randn(1, 1, 1, 1)
    mp_shadows_mq(datum)

    # extract intermediate activations
    shadow_act_comparison = ns.extract_shadow_logger_info(
        mp_shadows_mq, ns.OutputLogger, 'b')

    # add SQNR for each comparison, inplace
    ns.extend_logger_results_with_comparison(
        shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # shadow_act_comparison contains the activations from `mp_shadows_mq` stored
    # in pairs, and can be used for further analysis.
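
    #
    # Comparing N quantization configurations at once ("n shadows")
    #
    # The snippet below is a minimal sketch; the two qconfigs and the
    # backend config are illustrative choices, not recommendations.

    from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
    from torch.ao.quantization import QConfigMapping
    from torch.ao.quantization.backend_config import get_native_backend_config

    qconfig_multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping([
        QConfigMapping().set_global(torch.ao.quantization.default_qconfig),
        QConfigMapping().set_global(torch.ao.quantization.default_per_channel_qconfig),
    ])
    msp = ns.prepare_n_shadows_model(
        copy.deepcopy(m), (datum,), qconfig_multi_mapping,
        get_native_backend_config())
    # calibrate
    msp(datum)
    msq = ns.convert_n_shadows_model(msp)
    # capture logger results
    ns.loggers_set_enabled(msq, True)
    msq(datum)
    shadow_results = ns.extract_results_n_shadows_model(msq)
    ns.print_comparisons_n_shadows_model(shadow_results)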

"""

import collections
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING

import torch
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.nn as nn
from torch.ao.ns.fx.graph_matcher import (
    get_matching_subgraph_pairs,
    get_type_a_related_to_b,
)
from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops
from torch.ao.ns.fx.n_shadows_utils import (
    _get_dedup_subgraphs,
    create_add_loggers_graph,
    create_n_transformed_and_logged_copies_of_subgraph,
    create_results_comparison,
    extract_weight_comparison,
    group_results_by_subgraph,
    OutputProp,
    print_n_shadows_summary,
    SHADOW_WRAPPER_NODE_NAME_PREFIX,
)
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.backend_config.utils import (
    get_fusion_pattern_to_root_node_getter,
)
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
from torch.ao.quantization.fx.match_utils import _find_matches
from torch.ao.quantization.fx.qconfig_mapping_utils import (
    _generate_node_name_to_qconfig,
)
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.fx import GraphModule
from torch.fx.graph import Node

from .fx.graph_passes import add_loggers_to_model, create_a_shadows_b
from .fx.ns_types import NSNodeTargetType, NSResultsType, NSSingleResultValuesType
from .fx.utils import (
    get_target_type_str,
    maybe_add_missing_fqns,
    rekey_logger_info_on_node_name_of_model,
)
from .fx.weight_utils import extract_weight_from_node


if TYPE_CHECKING:
    from torch.ao.quantization.qconfig import QConfigAny

RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


class OutputLogger(nn.Module):
    """
    Base class for capturing intermediate values.
    """

    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
        qconfig_str: Optional[str] = "",
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging input of op1 and logger2 is logging
        # the output of op1:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        # - logger1's prev_node_name is x1 and ref_node_name is op1
        # - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # type of the target of the node which was responsible for adding this
        # logger
        self.ref_node_target_type = ref_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn
        # if loggers are added before prepare_fx, we may not want to collect
        # results during calibration, only after convert_fx, so we add a flag
        # to control whether this logger collects data
        self.enabled = True
        # string representation of qconfig
        self.qconfig_str = qconfig_str
        # this can be turned off to reduce memory usage during calibration
        self.save_activations = True

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        # TODO(future PR): consider designing this better, as the difference
        # between these two flags is subtle and not obvious.
        if not self.enabled:
            return x
        if not self.save_activations:
            return x
        # TODO(future PR): consider refactoring this to better reuse the parent
        # class
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        clean_dict = {
            k: v
            for k, v in self.__dict__.items()
            # skip nn.Module keys
            if (k != "training") and not k.startswith("_")
        }
        return f"OutputLogger({clean_dict})"


class OutputComparisonLogger(OutputLogger):
    """
    Same as OutputLogger, but also requires the original activation
    in order to calculate the comparison at calibration time
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO(future PR): make the comparison function configurable
        self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr
        self.comparison_fn_name = "sqnr"
        # precalculated comparisons of logger output versus reference
        self.comparisons = []
        # precalculated comparisons function

    def forward(self, x, x_ref):
        # fmt: off
        """
        """  # blank docblock to make autodoc happy
        # fmt: on
        if not self.enabled:
            return x
        assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported"
        if self.save_activations:
            # save the activation, for debugging
            self.stats.append(x.detach())
        # save the comparison
        self.comparisons.append(self.comparison_fn(x, x_ref))
        return x

    def __repr__(self):
        clean_dict = {
            k: v
            for k, v in self.__dict__.items()
            # skip nn.Module keys
            if (k != "training") and not k.startswith("_")
        }
        return f"OutputComparisonLogger({clean_dict})"


class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular FX quantization tracer, but treats observers and fake_quantize
    modules as leaf modules.
293 """ 294 295 def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: 296 # fmt: off 297 """ 298 """ # blank docblock to make autodoc happy 299 # fmt: on 300 if isinstance(m, torch.ao.quantization.ObserverBase): 301 return True 302 elif isinstance(m, torch.ao.quantization.FakeQuantizeBase): 303 return True 304 return super().is_leaf_module(m, module_qualified_name) 305 306 307def _extract_weights_one_model( 308 model_name: str, 309 model: GraphModule, 310 nodes_and_names_to_instrument: List[Tuple[Node, str]], 311 results: NSResultsType, 312 op_to_type_to_weight_extraction_fn: Optional[ 313 Dict[str, Dict[Callable, Callable]] 314 ] = None, 315) -> None: 316 torch._C._log_api_usage_once( 317 "quantization_api._numeric_suite_fx._extract_weights_one_model" 318 ) 319 for node, ref_name in nodes_and_names_to_instrument: 320 res_type = NSSingleResultValuesType.WEIGHT.value 321 extracted_weight = extract_weight_from_node( 322 node, model, op_to_type_to_weight_extraction_fn 323 ) 324 if extracted_weight: 325 if ref_name not in results: 326 results[ref_name] = {res_type: {}} 327 results[ref_name][res_type][model_name] = [extracted_weight] 328 329 330def _extract_weights_impl( 331 model_name_a: str, 332 gm_a: GraphModule, 333 model_name_b: str, 334 gm_b: GraphModule, 335 base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None, 336 unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, 337 op_to_type_to_weight_extraction_fn: Optional[ 338 Dict[str, Dict[Callable, Callable]] 339 ] = None, 340) -> NSResultsType: 341 torch._C._log_api_usage_once( 342 "quantization_api._numeric_suite_fx._extract_weights_impl" 343 ) 344 matched_subgraph_pairs = get_matching_subgraph_pairs( 345 gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map 346 ) 347 348 # split the subgraph pairs into one data structure for each model 349 nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = [] 350 nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = [] 351 for match_name, match in matched_subgraph_pairs.items(): 352 subgraph_a, subgraph_b = match 353 nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name)) 354 nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name)) 355 356 # populate the results, one model at a time 357 results: NSResultsType = {} 358 _extract_weights_one_model( 359 model_name_a, 360 gm_a, 361 nodes_and_names_to_instrument_a, 362 results, 363 op_to_type_to_weight_extraction_fn, 364 ) 365 _extract_weights_one_model( 366 model_name_b, 367 gm_b, 368 nodes_and_names_to_instrument_b, 369 results, 370 op_to_type_to_weight_extraction_fn, 371 ) 372 373 # fill in missing fqn entries 374 maybe_add_missing_fqns(results) 375 376 # rekey on names of nodes in gm_b 377 results = rekey_logger_info_on_node_name_of_model(results, model_name_b) 378 379 return results 380 381 382def extract_weights( 383 model_name_a: str, 384 model_a: nn.Module, 385 model_name_b: str, 386 model_b: nn.Module, 387 base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None, 388 unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, 389 op_to_type_to_weight_extraction_fn: Optional[ 390 Dict[str, Dict[Callable, Callable]] 391 ] = None, 392) -> NSResultsType: 393 """ 394 Extract weights from model A and model B, and return a comparison. 

    Args:
        model_name_a: string name of model A to use in results
        model_a: model A
        model_name_b: string name of model B to use in results
        model_b: model B
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
        op_to_type_to_weight_extraction_fn: optional override of function which extracts weight
            from a type, subject to change

    Return:
        NSResultsType, containing the weight comparisons
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(
        model_a, "node_name_to_scope"
    )
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(
        model_b, "node_name_to_scope"
    )
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _extract_weights_impl(
        model_name_a,
        gm_a,
        model_name_b,
        gm_b,
        base_name_to_sets_of_related_ops,
        unmatchable_types_map,
        op_to_type_to_weight_extraction_fn,
    )


def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
    logger_cls: Callable,
) -> nn.Module:
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx._add_loggers_one_model"
    )

    # TODO(future PR): do not observe nodes we do not care
    # about (both fp32, denylist, etc)
    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)

    model = add_loggers_to_model(
        model,
        node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name,
        logger_cls,
        model_name,
    )
    return model


def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
    )
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append(
                (subgraph_a.start_node, match_name, ref_node_type_a)
            )
            nodes_and_names_to_instrument_inputs_b.append(
                (subgraph_b.start_node, match_name, ref_node_type_b)
            )
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append(
            (subgraph_a.end_node, match_name, ref_node_type_a)
        )
        nodes_and_names_to_instrument_outputs_b.append(
            (subgraph_b.end_node, match_name, ref_node_type_b)
        )

    new_model_a = _add_loggers_one_model(
        name_a,
        gm_a,
        nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a,
        logger_cls,
    )
    new_model_b = _add_loggers_one_model(
        name_b,
        gm_b,
        nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b,
        logger_cls,
    )
    return (new_model_a, new_model_b)


def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Instrument model A and model B with loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to log inputs in addition to outputs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change

    Return:
        Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace.
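
    Example (a minimal sketch; ``mp`` and ``mq`` are the prepared and converted
    models from the module docstring, and logging of inputs is optional)::

        mp_ns, mq_ns = add_loggers(
            'a', copy.deepcopy(mp), 'b', copy.deepcopy(mq),
            OutputLogger, should_log_inputs=True)
        datum = torch.randn(1, 1, 1, 1)
        mp_ns(datum)
        mq_ns(datum)
        act_comparison = extract_logger_info(mp_ns, mq_ns, OutputLogger, 'b')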
553 """ 554 555 torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers") 556 # TODO(future PR): expose these 557 skipped_module_names: List[str] = [] 558 skipped_module_classes: List[Callable] = [] 559 tracer_a = NSTracer(skipped_module_names, skipped_module_classes) 560 tracer_b = NSTracer(skipped_module_names, skipped_module_classes) 561 gm_a = GraphModule(model_a, tracer_a.trace(model_a)) 562 maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( 563 model_a, "node_name_to_scope" 564 ) 565 if maybe_model_a_node_name_to_scope is not None: 566 gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope 567 gm_b = GraphModule(model_b, tracer_b.trace(model_b)) 568 maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( 569 model_b, "node_name_to_scope" 570 ) 571 if maybe_model_b_node_name_to_scope is not None: 572 gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope 573 return _add_loggers_impl( 574 name_a, 575 gm_a, 576 name_b, 577 gm_b, 578 logger_cls, 579 should_log_inputs=should_log_inputs, 580 base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, 581 unmatchable_types_map=unmatchable_types_map, 582 ) 583 584 585def _extract_logger_info_one_model( 586 model: nn.Module, 587 results: NSResultsType, 588 logger_cls: Callable, 589) -> None: 590 torch._C._log_api_usage_once( 591 "quantization_api._numeric_suite_fx._extract_logger_info_one_model" 592 ) 593 for gm_name, mod in model.named_modules(): 594 # TODO(future PR): better check when scripted 595 is_logger = isinstance(mod, logger_cls) or ( # type: ignore[arg-type] 596 isinstance(mod, torch.jit.RecursiveScriptModule) 597 and mod.original_name == "OutputLogger" 598 ) 599 if is_logger: 600 key = mod.ref_name 601 if key not in results: 602 results[key] = {} 603 assert ( 604 mod.model_name not in results[key] 605 ), f"{mod.model_name} is already present in results" 606 if mod.results_type not in results[key]: 607 results[key][mod.results_type] = {} 608 if mod.model_name not in results[key][mod.results_type]: 609 results[key][mod.results_type][mod.model_name] = [] 610 stats_to_use = mod.stats 611 if len(mod.stats_rnn) > 0: 612 stats_to_use = mod.stats_rnn 613 data = { 614 "type": mod.results_type, 615 "values": stats_to_use, 616 "ref_node_name": mod.ref_node_name, 617 "ref_node_target_type": mod.ref_node_target_type, 618 "prev_node_name": mod.prev_node_name, 619 "prev_node_target_type": mod.prev_node_target_type, 620 "index_within_arg": mod.index_within_arg, 621 "index_of_arg": mod.index_of_arg, 622 "fqn": mod.fqn, 623 "qconfig_str": mod.qconfig_str, 624 } 625 if hasattr(mod, "comparisons"): 626 data["comparisons"] = mod.comparisons 627 data["comparison_fn_name"] = mod.comparison_fn_name 628 else: 629 data["comparisons"] = [] 630 data["comparison_fn_name"] = "" 631 results[key][mod.results_type][mod.model_name].append(data) 632 # ensure the list stays sorted 633 results[key][mod.results_type][mod.model_name].sort( 634 key=lambda res: f"{res['index_of_arg']}:{res['index_within_arg']}" 635 ) 636 637 638# TODO(future PR): align on naming 639# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs` 640def extract_logger_info( 641 model_a: nn.Module, 642 model_b: nn.Module, 643 logger_cls: Callable, 644 model_name_to_use_for_layer_names: str, 645) -> NSResultsType: 646 """ 647 Traverse all loggers in `model_a` and `model_b`, and extract the logged 648 information. 

    Args:
        model_a: model A
        model_b: model B
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
            layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.extract_logger_info"
    )
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names
    )
    return results


def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx._add_shadow_loggers_impl"
    )
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map
    )
    gm_a_shadows_b = create_a_shadows_b(
        name_a,
        gm_a,
        name_b,
        gm_b,
        matched_subgraph_pairs,
        logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map,
    )
    return gm_a_shadows_b


def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Instrument model A and model B with shadow loggers.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to log inputs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        node_type_to_io_type_map: optional override of node type to io type map, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.add_shadow_loggers"
    )
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(
        model_a, "node_name_to_scope"
    )
    if maybe_model_a_node_name_to_scope is not None:
        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(
        model_b, "node_name_to_scope"
    )
    if maybe_model_b_node_name_to_scope is not None:
        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
    return _add_shadow_loggers_impl(
        name_a,
        gm_a,
        name_b,
        gm_b,
        logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map,
    )


def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Traverse all loggers in a shadow model, and extract the logged
    information.

    Args:
        model_a_shadows_b: shadow model
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
            layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once(
        "quantization_api._numeric_suite_fx.extract_shadow_logger_info"
    )
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names
    )
    return dict(results)


def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace.

    Args:
        results: the result data structure from `extract_logger_info` or
            `extract_shadow_logger_info`.
        model_name_1: string name of model 1
        model_name_2: string name of model 2
        comparison_fn: function to compare two Tensors
        comparison_name: string name under which to store the comparison
            results (for example, 'sqnr')
    """
    for results_type_to_results in results.values():
        for model_name_to_results in results_type_to_results.values():
            assert (
                model_name_1 in model_name_to_results
            ), f"{model_name_1} not found in results"
            assert (
                model_name_2 in model_name_to_results
            ), f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                index_within_arg_2 = result_2["index_within_arg"]
                index_of_arg_2 = result_2["index_of_arg"]
                # find corresponding result_1
                result_1 = None
                for cur_result_1 in results_1:
                    index_within_arg_1 = cur_result_1["index_within_arg"]
                    index_of_arg_1 = cur_result_1["index_of_arg"]
                    if (index_within_arg_1 == index_within_arg_2) and (
                        index_of_arg_1 == index_of_arg_2
                    ):
                        result_1 = cur_result_1
                        break
                assert result_1 is not None

                values_1 = result_1["values"]
                values_2 = result_2["values"]
                result_2[comparison_name] = []
                for value_1, value_2 in zip(values_1, values_2):
                    comparison_result = comparison_fn(value_1, value_2)
                    result_2[comparison_name].append(comparison_result)


def prepare_n_shadows_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_multi_mapping: QConfigMultiMapping,
    backend_config: BackendConfig,
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
    custom_tracer: Any = None,
) -> GraphModule:
    """
    Given a model whose graph has M ops such as


      args_kwargs_m -> op_m -> output_m


    and a set of N qconfigs for each op, creates a new model with each
    subgraph of `op_m` transformed into

    .. code::

           |---------> op_m_n -> log_m_n
           |                    /
      args_kwargs_m ---------> op_m -> log_m_0

    Where op_m_n is op_m wrapped in a submodule and transformed with
    qconfig_n, and its inner graph looks like

    .. code::

      args_m -------- op_m_prepared_with_qconfig_n -> out_m_n
                    /
      kwargs_m ---

    This is useful for testing different quantization configurations for
    multiple layers in a single pass through the model.

    High level TODOs for future PRs:
    * figure out a better way to name the output structure
    * return a results data structure instead of printing it out
    * add examples to docblocks
    """

    if custom_tracer is None:
        tracer = quantize_fx.QuantizationTracer([], [])
    else:
        tracer = custom_tracer
    mt = torch.fx.GraphModule(model, tracer.trace(model))
    # this is necessary to ensure logger FQNs get populated
    mt._node_name_to_scope = tracer.node_name_to_scope  # type: ignore[assignment]

    # run example input propagation, we need this to call prepare_fx on
    # individual subgraphs
    output_prop = OutputProp(mt)
    output_prop.propagate(*example_inputs)

    # Find the set of subgraphs in the original graph which we need to
    # consider.
    modules = dict(mt.named_modules(remove_duplicate=False))
    patterns = _get_pattern_to_quantize_handlers(backend_config)
    root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
    standalone_module_names: List[str] = []
    standalone_module_classes: List[Type] = []
    custom_module_classes: List[Type] = []
    matches = _find_matches(
        mt.graph,
        modules,
        patterns,
        root_node_getter_mapping,
        standalone_module_names,
        standalone_module_classes,
        custom_module_classes,
    )
    subgraphs_dedup: Dict[str, List[Node]] = _get_dedup_subgraphs(matches)

    # generate node to qconfig for each subgraph
    # TODO(future PR): deduplicate repeating entries
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = []
    for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list:
        node_name_to_qconfig = _generate_node_name_to_qconfig(
            mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope
        )
        list_of_node_name_to_qconfig.append(node_name_to_qconfig)

    # For each region in the model, do the following:
    #   For each qconfig for that region, do the following:
    #     1. create a copy of the region wrapped in a module
    #     2. pass original args, original kwargs, and expected output to module
    #     3. add an output comparison logger and hook it up to compare
    #        actual output to expected output
    #     4. run `prepare_fx` on the module
    for subgraph_idx, (match_name, nodes_in_this_subgraph) in enumerate(
        subgraphs_dedup.items()
    ):
        create_n_transformed_and_logged_copies_of_subgraph(
            mt,
            subgraph_idx,
            match_name,
            nodes_in_this_subgraph,
            qconfig_multi_mapping.qconfig_mappings_list,
            list_of_node_name_to_qconfig,
            custom_prepare_fn,
            custom_prepare_kwargs,  # type: ignore[arg-type]
        )

    return mt


# TODO(future PR): we should rethink the names of all the PNP APIs
def _prepare_n_shadows_add_loggers_model(
    model: torch.nn.Module,
    example_inputs: Any,
    qconfig_mapping: QConfigMapping,
    backend_config: BackendConfig,
) -> torch.nn.Module:
    r"""
    Note: this API is not recommended for wide usage, it is only
    provided for customers who need to migrate from the `add_loggers`
    API.

    This creates a model which provides logging for the following use case:
    quantize `model` with `qconfig_mapping`, feed the same input through
    both the original and the quantized model, and log the comparisons of
    corresponding intermediate layers.

    The problem is solved with a single model. Specifically, we
    partition `model` into N subgraphs, create a copy of each relevant
    subgraph, wrap it in a module, apply the quantization API to that
    module, and hook up loggers to measure the comparisons.

    Example starting graph:

      x0 -> op0 -> x1 -> op1 -> x2

    Example config: quantize op0 to int8, do nothing to op1.
    The following graph will be created:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                           \       # noqa: W605
         ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog

    Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized
    to int8, op1_0 is op1 (appearing in the graph twice), log is a logger,
    and clog is a comparison logger.
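
    Example (a minimal sketch of one possible flow; the qconfig mapping and
    backend config below are illustrative, and the downstream calls mirror
    the `prepare_n_shadows_model` workflow)::

        from torch.ao.quantization import QConfigMapping, default_qconfig
        from torch.ao.quantization.backend_config import get_native_backend_config

        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
        mp = _prepare_n_shadows_add_loggers_model(
            m, (datum,), qconfig_mapping, get_native_backend_config())
        mp(datum)  # calibrate
        mq = convert_n_shadows_model(mp)
        loggers_set_enabled(mq, True)
        mq(datum)  # capture intermediate values
        results = extract_results_n_shadows_model(mq)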
996 """ 997 998 tracer = quantize_fx.QuantizationTracer([], []) 999 mt = torch.fx.GraphModule(model, tracer.trace(model)) 1000 # this is necessary to ensure logger FQNs get populated 1001 mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment] 1002 1003 # run example input propagation, we need this to call prepare_fx on 1004 # individual subgraphs 1005 output_prop = OutputProp(mt) 1006 output_prop.propagate(*example_inputs) 1007 1008 # Find the set of subgraphs in the original graph which we need to 1009 # consider. 1010 modules = dict(mt.named_modules(remove_duplicate=False)) 1011 patterns = _get_pattern_to_quantize_handlers(backend_config) 1012 root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) 1013 standalone_module_names: List[str] = [] 1014 standalone_module_classes: List[Type] = [] 1015 custom_module_classes: List[Type] = [] 1016 matches = _find_matches( 1017 mt.graph, 1018 modules, 1019 patterns, 1020 root_node_getter_mapping, 1021 standalone_module_names, 1022 standalone_module_classes, 1023 custom_module_classes, 1024 ) 1025 subgraphs_dedup: Dict[str, List[Node]] = _get_dedup_subgraphs(matches) 1026 1027 # generate node to qconfig for each subgraph 1028 node_name_to_qconfig = _generate_node_name_to_qconfig( 1029 mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope 1030 ) 1031 1032 # Now, mutate the graph to be the add_loggers graph with propagation 1033 # error. 1034 create_add_loggers_graph(mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig) 1035 1036 return mt 1037 1038 1039# TODO(future PR): we should rethink the names of all the PNP APIs 1040def _n_shadows_compare_weights( 1041 model: torch.nn.Module, 1042 example_inputs: Any, 1043 qconfig_mapping: QConfigMapping, 1044 backend_config: BackendConfig, 1045) -> NSResultsType: 1046 """ 1047 Note: this API is not recommended for wide usage, it is only 1048 provided for customers who need to migrate from the `add_loggers` 1049 API. 
1050 """ 1051 qconfig_multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping( 1052 [qconfig_mapping] 1053 ) 1054 mp = prepare_n_shadows_model( 1055 model, example_inputs, qconfig_multi_mapping, backend_config 1056 ) 1057 # passing inputs through the model is necessary to populate 1058 # observers which observe weights with real values 1059 mp(*example_inputs) 1060 mq = convert_n_shadows_model(mp) 1061 weight_comparison = extract_weight_comparison(mq) 1062 return weight_comparison 1063 1064 1065# TODO(future PR): consider aligning API signature with other similar quantization 1066# functions (enable_fake_quant, etc) 1067def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None: 1068 """ 1069 Sets the `enabled` setting on a `model`'s loggers 1070 """ 1071 for name, child in model.named_modules(): 1072 if isinstance(child, OutputLogger): 1073 child.enabled = enabled 1074 1075 1076# TODO(future PR): consider aligning API signature with other similar quantization 1077# functions (enable_fake_quant, etc) 1078def loggers_set_save_activations( 1079 model: torch.nn.Module, 1080 save_activations: bool, 1081) -> None: 1082 """ 1083 Sets the `save_activations` setting on a `model`'s loggers 1084 """ 1085 for name, child in model.named_modules(): 1086 if isinstance(child, OutputLogger): 1087 child.save_activations = save_activations 1088 1089 1090def convert_n_shadows_model( 1091 model: GraphModule, 1092 custom_convert_fn: Optional[Callable] = None, 1093 custom_convert_kwargs: Optional[Dict[str, Any]] = None, 1094) -> GraphModule: 1095 """ 1096 Given a model from `prepare_n_shadows_model`, runs `convert_fx` 1097 on each shadow submodule. 1098 """ 1099 for node in model.graph.nodes: 1100 # TODO(future PR): consider matching in a safer way than 1101 # node name string match 1102 if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX): 1103 orig_mod = getattr(model, node.name) 1104 if custom_convert_fn is None: 1105 converted_mod = torch.ao.quantization.quantize_fx.convert_fx(orig_mod) 1106 else: 1107 if custom_convert_kwargs is None: 1108 custom_convert_kwargs = {} 1109 converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs) 1110 setattr(model, node.name, converted_mod) 1111 1112 return model 1113 1114 1115def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType: 1116 """ 1117 Extracts logger results from `model`. 1118 """ 1119 results: NSResultsType = {} 1120 _extract_logger_info_one_model(model, results, OutputLogger) 1121 return results 1122 1123 1124def print_comparisons_n_shadows_model(results: NSResultsType) -> None: 1125 """ 1126 Prints a summary of extracted `results`. 1127 """ 1128 results_grouped = group_results_by_subgraph(results) 1129 results_comparison = create_results_comparison(results_grouped) 1130 print_n_shadows_summary(results_comparison) 1131