# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.fx

from torch.fx._compatibility import compatibility
from torch.fx.node import map_arg

from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
    CALLABLE_NODE_OPS,
    FxNetAccFusionsFinder,
    Names,
    NodeList,
    NodeSet,
    TensorOrTensors,
    Tensors,
)

__all__ = [
    "FxNetMinimizerBadModuleError",
    "FxNetMinimizerRunFuncError",
    "FxNetMinimizerResultMismatchError",
]

_LOGGER = logging.getLogger(__name__)


@compatibility(is_backward_compatible=False)
class FxNetMinimizerBadModuleError(Exception):
    """
    Raised if failed to split out a minimize module
    """


@compatibility(is_backward_compatible=False)
class FxNetMinimizerRunFuncError(Exception):
    """
    Raised if error occurs during run_a or run_b functions
    """


@compatibility(is_backward_compatible=False)
class FxNetMinimizerResultMismatchError(Exception):
    """
    Raised if comparing function thinks the results are mismatching.
    """


@dataclass
class _MinimizerSettingBase:
    """
    Settings shared by all minimizers.

    Args:
        `accumulate_error`: Instead of using a's input for both converted modules
            to verify, use the previous outputs of each converted module as input
            to accumulate the errors.

        `traverse_method`: How the nodes of the FX module are traversed. One of
            "sequential", "binary", "accumulate", "skip", "defined" or "block"
            (dispatched in `_MinimizerBase.minimize`).

        `find_all`: Minimizer will go through the entire model and return all
            problematic nodes instead of stopping at the first one.

        `return_intermediate`: If true, when using `run_nodes()` function to run the
            model, intermediate results of all the ops will be returned as output.
    """

    accumulate_error: bool = False
    traverse_method: str = "sequential"
    find_all: bool = False
    return_intermediate: bool = False

    def __str__(self):
        # Header line followed by one "\t<name>: <value>" line per instance attribute.
        return "FX Minimizer Settings:\n" + "".join(
            f"\t{k}: {v}\n" for k, v in vars(self).items()
        )


class _MinimizerBase:
    """
    This class is used to automatically find problematic nodes in a model. It takes
    an FX GraphModule and generates submodules while traversing the graph. Then two
    functions `run_a` and `run_b` are used to run the same submodule and a function
    `compare_fn` is used to compare the results.

    The traversal strategy is chosen by `settings.traverse_method`:
    1. "sequential": traverse the graph node by node and generate one submodule
       per single node (or per fusion group containing that node).
    2. "binary": do a binary-search style traversal on the graph.
    3. "accumulate": grow the submodule by one node per iteration.
    4. "skip": split the graph around user-specified nodes and run each segment.
    5. "defined": run the given node range as a single submodule.
    6. "block": bisect for the smallest contiguous block of nodes that fails.

    Subclasses must implement `run_a` and `run_b`.

    For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Tensors,
        compare_fn: Callable[
            [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
        ],
        settings: _MinimizerSettingBase,
        module_exporter: Optional[
            Callable[[Tensors, torch.fx.GraphModule, str], None]
        ] = None,
        exclusion_fn: Optional[Callable[[NodeList, int, int], None]] = None,
    ):
        """
        Args:
            module: The FX GraphModule to minimize.
            sample_input: One input per placeholder of `module`, in order.
            compare_fn: Called as compare_fn(a_result, b_result, names) and returns
                (numeric_result, passed). A False `passed` marks a mismatch.
            settings: Minimizer settings.
            module_exporter: Optional callback invoked on mismatch with
                (inputs, submodule, name) for both the a and b runs.
            exclusion_fn: Optional callback called as exclusion_fn(nodes, start, end).
                It is expected to remove excluded nodes from `nodes` in place.
        """
        assert isinstance(module, torch.fx.GraphModule), (
            f"Expected a torch.fx.GraphModule, got {type(module)}"
        )

        self.module = module
        self.sample_input = sample_input
        self.compare_fn = compare_fn
        self.module_exporter = module_exporter
        self.settings = settings
        self.exclusion_fn = exclusion_fn

        # Stores outputs of run_a function
        self.a_outputs: Dict[str, Any] = {}

        # Stores outputs of run_b function
        self.b_outputs: Dict[str, Any] = {}

        # Stores the results of compare_fn
        self.results: Dict[Any, Any] = {}

        # Stores the report for the runs
        self.reports: List[List[str]] = []

        # Current iteration
        self.iteration: int = 0

        callable_nodes = {
            node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
        }
        ShapeProp(self.module).propagate(*self.sample_input)
        self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()

        # Check if number of input in sample_input matches the number of placeholders
        placeholders = [
            node.name for node in self.module.graph.nodes if node.op == "placeholder"
        ]
        assert len(placeholders) == len(self.sample_input), (
            f"Module has {len(placeholders)} placeholders but "
            f"{len(self.sample_input)} sample inputs were given"
        )

        # Seed the stored outputs with the sample inputs so submodules whose
        # placeholders are graph inputs can be fed directly.
        for i, name in enumerate(placeholders):
            self.a_outputs[name] = sample_input[i]
            self.b_outputs[name] = sample_input[i]

    def run_a(
        self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
    ) -> TensorOrTensors:
        """
        Run `mod` with `inputs` and generate output. The output will be compared with
        output of run_b().
        """
        raise RuntimeError("run_a() is not implemented.")

    def run_b(
        self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
    ) -> TensorOrTensors:
        """
        Run `mod` with `inputs` and generate output. The output will be compared with
        output of run_a().
        """
        raise RuntimeError("run_b() is not implemented.")

    def _store_outputs(
        self,
        a_result: TensorOrTensors,
        b_result: TensorOrTensors,
        submodule: torch.fx.GraphModule,
    ):
        """
        Store the outputs of self.run_a() and self.run_b() into self.a_outputs and
        self.b_outputs, so that we can use them when executing later nodes that
        use those outputs as inputs.

        Args:
            a_result: Output of self.run_a(). Could be a tensor or tensors.
            b_result: Output of self.run_b(). Could be a tensor or tensors.
            submodule: The module that generates a_result and b_result.
        """
        output_node = next(
            node for node in submodule.graph.nodes if node.op == "output"
        )

        # Only one output
        if isinstance(output_node.args[0], torch.fx.Node):
            self.a_outputs[output_node.args[0].name] = a_result
            self.b_outputs[output_node.args[0].name] = b_result
        # Multiple outputs
        else:
            for i, arg in enumerate(output_node.args[0]):
                self.a_outputs[arg.name] = a_result[i]
                self.b_outputs[arg.name] = b_result[i]

    def _get_submod_inputs(
        self, main_module: torch.fx.GraphModule, submod_path: str
    ) -> Tuple[Tensors, Tensors]:
        """
        Get the inputs for the submodule `submod_path` of `main_module`.

        If every placeholder of the submodule is found among the stored outputs,
        those stored values are used. Otherwise `main_module` is run on
        `self.sample_input` with a forward pre-hook on the submodule, and the
        inputs the submodule receives are captured.

        If accumulate_error is False, use a_input for both run_a() and run_b();
        otherwise use a_input for run_a() and b_input for run_b().

        Args:
            main_module: Top-level fx module.
            submod_path: Path to the submodule we want to run and compare results.

        Returns:
            a_input: List of tensor(s) that will be used by run_a() as submodule inputs.
            b_input: List of tensor(s) that will be used by run_b() as submodule inputs.
        """
        a_input = []
        b_input = []
        submodule = getattr(main_module, submod_path)
        placeholders = [
            node.name for node in submodule.graph.nodes if node.op == "placeholder"
        ]

        if set(placeholders) <= self.a_outputs.keys():
            for name in placeholders:
                a_input.append(self.a_outputs[name])
                b_input.append(self.b_outputs[name])
        else:
            if self.settings.accumulate_error:
                print(f"Can't find previous stored outputs named {placeholders}!")

            def get_inputs(_module: torch.nn.Module, inputs: Any):
                nonlocal a_input
                a_input = inputs

            # Use forward pre-hook to capture the inputs to the submodule.
            handle = submodule.register_forward_pre_hook(get_inputs)
            try:
                main_module(*self.sample_input)
            finally:
                # Always detach the hook, even if the forward pass raises, so it
                # does not fire on later runs of the module.
                handle.remove()

            b_input = a_input

        if not self.settings.accumulate_error:
            return a_input, a_input

        return a_input, b_input

    def _tag_nodes(self, selected_nodes: NodeSet):
        """
        Tag selected nodes with tag "minimize". Nodes with the same tags will
        be split to the same submodule afterwards.

        Args:
            selected_nodes: Nodes that we want to minimize. We will tag those nodes
                with "minimize", all preceding nodes with "main_0" and all following
                nodes with "main_1".
        """
        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            if node in selected_nodes:
                node.tag = "minimize"
            elif any(
                n.tag in {"minimize", "main_1"}
                for n in node.all_input_nodes
                if n.op in CALLABLE_NODE_OPS
            ):
                node.tag = "main_1"
            else:
                node.tag = "main_0"

    def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
        """
        Split self.module so that one submodule consists of `nodes` and only `nodes`.

        Args:
            nodes: Nodes that we want to include in the minimize submodule.

        Returns:
            split_module (torch.fx.GraphModule): the module after split.
            submodule_name (str): the name of the submodule that consists of `nodes`.

        Raises:
            FxNetMinimizerBadModuleError: if zero or more than one minimize
                submodule results from the split.
        """
        # Color provided nodes
        self._tag_nodes(nodes)

        # Split module based on coloring
        split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"])

        # Find submodule containing colored nodes
        submodule_name: str = ""
        for child_name, _ in split_module.named_children():  # type: ignore[union-attr]
            # Skip submodules we're not interested in at the moment
            if "minimize" not in child_name:
                continue

            if submodule_name == "":
                submodule_name = child_name
            else:
                raise FxNetMinimizerBadModuleError(
                    f"Expected only one minimize submodule with nodes {nodes}"
                )

        if submodule_name == "":
            raise FxNetMinimizerBadModuleError(
                f"Minimize submodule was not found with nodes {nodes}"
            )

        return split_module, submodule_name  # type: ignore[return-value]

    def _run_and_compare(
        self,
        split_module: torch.fx.GraphModule,
        submod_name: str,
        output_names: Optional[Names],
        report_idx: int = -1,
    ):
        """
        Run the submodule in `split_module` that has name `submod_name`
        using `self.run_a` and `self.run_b` and compare their results.

        Args:
            split_module: Main module that contains the minimize submodule.
            submod_name: Name of the minimize submodule.
            output_names: Names of the nodes we want to output. If None or empty,
                we will use the original output.
            report_idx: Index into self.reports to log to; a negative value means
                the report of the current iteration.

        Raises:
            FxNetMinimizerRunFuncError: if run_a/run_b (or storing outputs) raises.
            FxNetMinimizerResultMismatchError: if compare_fn reports a mismatch.
        """
        submodule = getattr(split_module, submod_name)
        a_input, b_input = self._get_submod_inputs(split_module, submod_name)

        if len(self.reports) == 0:
            self.reports.append([])
            self.iteration = 1

        report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1]
        report.append("Run and compare ...")

        if output_names:
            # Replace the submodule's output with the requested nodes.
            output_nodes: NodeList = []
            for node in submodule.graph.nodes:
                if node.op == "output":
                    submodule.graph.erase_node(node)

                if node.name in output_names:
                    output_nodes.append(node)

            submodule.graph.output(
                output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes)
            )
            submodule.graph.lint()
            submodule.recompile()

        # Use name of args in output node as key to store comparison result
        output_node = next(
            node for node in submodule.graph.nodes if node.op == "output"
        )
        result_key = map_arg(output_node.args, lambda x: x.name)

        try:
            a_result = self.run_a(submodule, a_input, report_idx)
            b_result = self.run_b(submodule, b_input, report_idx)
            self._store_outputs(a_result, b_result, submodule)
        except Exception as e:
            report.append(f"Exception raised when running {submod_name}: {e}")
            raise FxNetMinimizerRunFuncError(
                f"Exception raised when running {submod_name}: {e}"
            ) from e

        # Compare results. When no explicit output names were requested
        # (None or []), fall back to the names of the submodule's outputs.
        names: Names = output_names  # type: ignore[assignment]
        if not output_names:
            names = [str(v) for v in result_key]

        numeric_result, bool_result = self.compare_fn(a_result, b_result, names)

        self.results[result_key] = numeric_result
        report.append(f"Numerical accuracy = {numeric_result}")
        if not bool_result:
            report.append(f"Result mismatch for {result_key}")
            if self.module_exporter:
                self.module_exporter(
                    a_input, submodule, str(result_key[0]) + "_cpu",  # type: ignore[index]
                )
                self.module_exporter(
                    b_input, submodule, str(result_key[0]) + "_acc",  # type: ignore[index]
                )
            raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")

    def _binary_search_impl(
        self, all_nodes: NodeList, start_idx: int, end_idx: int
    ) -> NodeSet:
        """
        Recursive binary search over all_nodes[start_idx:end_idx].

        The whole slice is run as one submodule; if it fails, each half is
        searched recursively. Returns the culprit nodes found (empty if none).
        """
        culprits: NodeSet = set()
        nodes: NodeList = all_nodes[start_idx:end_idx]

        report: List[str] = []
        if self.exclusion_fn is not None:
            # exclusion_fn is expected to drop excluded nodes from `nodes` in place.
            self.exclusion_fn(nodes, start_idx, end_idx)
        if len(nodes) == 0:
            report = ["All nodes are excluded by user"]
            self.reports.append(report)
            return culprits

        first_node_name = nodes[0].name
        output_node_name = nodes[-1].name
        self.iteration += 1
        self.reports.append(report)
        report.append(f"Binary search iteration {self.iteration}")
        report.append(
            f"From node index {start_idx}:{first_node_name} to {end_idx-1}:{output_node_name}. "
            f"Size of the interested node list is {len(nodes)}"
        )
        cur_nodes: NodeSet = set(nodes)

        try:
            split_module, submod_name = self._build_submodule(cur_nodes)
            self._run_and_compare(split_module, submod_name, [output_node_name])

        except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):

            if len(nodes) == 1:
                report.append(
                    f"This is the last node in the sub-module. "
                    f"Search in the current branch is successful with culprit = {cur_nodes}."
                )
                self.print_report(report)
                return cur_nodes

            report.append(
                "Proceed to split and lower the halves of the current "
                "sub-module individually."
            )
            self.print_report(report)

            mid = len(nodes) // 2
            culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)

            if len(culprits) != 0 and not self.settings.find_all:
                return culprits

            # Merge instead of overwriting so that, with find_all, culprits found
            # in the first half are not lost.
            culprits = culprits | self._binary_search_impl(
                all_nodes, start_idx + mid, end_idx
            )

            if len(culprits) == 0:
                report.append(
                    f"Further split and lowering found no errors. "
                    f"Unable to minimize the submodule with list of nodes: {nodes}"
                )
                self.print_report(report)

            return culprits
        else:
            report.append("No discrepancy found.")
            self.print_report(report)
            return set()

    def _binary_traverse(self, nodes: NodeList) -> NodeSet:
        """
        Binary search on `nodes` for culprit.
        """
        return self._binary_search_impl(nodes, 0, len(nodes))

    def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
        """
        Traverse `nodes` one by one and determine if any of them is a culprit.

        Each node is run together with its fusion group (if any). Without
        find_all, the traversal stops at the first culprit.
        """
        culprits: NodeSet = set()

        for node in nodes:
            report: List[str] = []
            self.reports.append(report)
            self.iteration += 1
            report.append(f"Sequential traverse iteration {self.iteration}.")
            report.append(f"Visit node: {node.name}")

            _LOGGER.info("Visit node: %s", node.name)
            node_list: NodeList = [node]
            if self.exclusion_fn is not None:
                self.exclusion_fn(node_list, -1, -1)
                if len(node_list) == 0:
                    report.append(f"User exclusion : {node.name}")
                    self.print_report(report)
                    # An excluded node is not a culprit: move on to the next node
                    # instead of ending the whole search with no culprits.
                    continue

            cur_nodes: NodeSet = {node}

            if node in self.fusions:
                cur_nodes = self.fusions[node]

            try:
                split_module, submod_name = self._build_submodule(cur_nodes)
                self._run_and_compare(split_module, submod_name, [node.name])
                self.print_report(report)
            except FxNetMinimizerResultMismatchError:
                culprits.add(node)
                report.append(f"Found culprit from numeric error: {node}")
                self.print_report(report)
                if not self.settings.find_all:
                    return culprits
            except FxNetMinimizerRunFuncError:
                culprits.update(cur_nodes)
                report.append(f"Found culprit from run error: {node}")
                self.print_report(report)
                if not self.settings.find_all:
                    return culprits

        return culprits

    def _block_traverse_impl(
        self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool
    ) -> int:
        """
        Recursive block search (bisection) over indices [start_idx, end_idx].

        find_last_node: If True, tests prefixes nodes[:mid + 1] and searches for
            the index of the last node of the failing block. If False, tests
            suffixes nodes[mid:] and searches for the first node of the block.

        Returns the found index. If no tested range fails, the result is
        end_idx + 1 (find_last_node=True) or start_idx - 1 (find_last_node=False),
        which can lie outside `nodes`.
        """
        report: List[str] = []

        mid = (start_idx + end_idx) // 2
        cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:]

        if self.exclusion_fn:
            self.exclusion_fn(cur_nodes_list, -1, -1)

        cur_nodes = set(cur_nodes_list)

        first_node_name = cur_nodes_list[0].name
        last_node_name = cur_nodes_list[-1].name
        target_node_name = last_node_name if find_last_node else first_node_name

        self.iteration += 1
        self.reports.append(report)
        report.extend(
            [
                "=" * 30,
                f"Block search iteration {self.iteration}",
            ]
        )
        report.extend(
            [
                f"Search for {'last' if find_last_node else 'first'} node in culprits",
                f"From node index {start_idx}:{nodes[start_idx].name} to {end_idx}:{nodes[end_idx].name}. ",
                f"Subgraph constructed by {first_node_name} to {last_node_name}",
                f"Targeting node: {target_node_name}",
                f"Size of the interested node list is {end_idx - start_idx + 1}",
            ]
        )
        report_idx = len(self.reports) - 1

        try:
            split_module, submod_name = self._build_submodule(cur_nodes)
            self._run_and_compare(
                split_module, submod_name, [last_node_name], report_idx
            )
        except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
            report.append(
                f"Culprits found from node {first_node_name} to {last_node_name}."
            )

            if start_idx == mid:
                report.extend(
                    [
                        "This is the last node in the sub-module. ",
                        "Search in the current branch is successful with node :",
                        f"{start_idx}, node name: {nodes[start_idx].name}.",
                    ]
                )
                self.print_report(report)
                return start_idx

            report.append(
                "Proceed to split and lower the halves of the current "
                "sub-module individually."
            )
            self.print_report(report)

            if find_last_node:
                return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
            else:
                return self._block_traverse_impl(
                    nodes, mid + 1, end_idx, find_last_node
                )
        else:
            report.append(
                f"Culprits not found from node start to {mid}:{nodes[mid].name}."
            )

            if start_idx == mid:
                report.extend(
                    [
                        "This is the last node in the sub-module. ",
                        "Search in the current branch is successful with node",
                        f"{start_idx}, node name: {nodes[start_idx].name}.",
                    ]
                )
                self.print_report(report)
                return start_idx + 1 if find_last_node else start_idx - 1

            report.append(
                "Proceed to split and lower the halves of the current "
                "sub-module individually."
            )
            self.print_report(report)

            if find_last_node:
                return self._block_traverse_impl(
                    nodes, mid + 1, end_idx, find_last_node
                )
            else:
                return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)

    def _block_traverse(
        self, nodes: NodeList, find_last_node: Optional[bool]
    ) -> NodeSet:
        """
        Traverse a topologically sorted node list and find the minimum block
        nodes[start_idx:end_idx + 1] that contains the culprit.

        1st pass (find_last_node is None or True): search for end_idx, the last
            node of the culprit block, by testing prefixes nodes[0:i + 1].
        2nd pass (find_last_node is None or False): search for start_idx, the
            first node of the culprit block, by testing suffixes of
            nodes[0:end_idx + 1].

        Returns the nodes of the found block, or an empty set if `nodes` is empty
        or no tested range fails.
        """
        culprits: NodeSet = set()
        if not nodes:
            return culprits

        first_node_name = nodes[0].name
        last_node_name = nodes[-1].name
        last_node_report = [f"Block search from {first_node_name} to {last_node_name}"]
        last_node_report.append("*" * 50)
        self.reports.append(last_node_report)

        start_idx = 0
        end_idx = len(nodes) - 1
        run_both = find_last_node is None

        # step 1: find (0, end_idx) of culprit block
        if run_both or find_last_node:
            last_node_report.append("Start searching for last node in culprit")
            self.print_report(last_node_report)
            end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True)
            if end_idx >= len(nodes):
                # Even the longest tested prefix ran cleanly: nothing to minimize.
                last_node_report.append("Finish Pass 1, no culprit found")
                self.print_report(last_node_report)
                return culprits
            last_node_report.extend(
                [
                    "Finish Pass 1",
                    f"Find end_idx = {end_idx}:{nodes[end_idx].name}",
                ]
            )
            self.print_report(last_node_report)

        # step 2: reduce culprit block to (start_idx, end_idx)
        if run_both or not find_last_node:
            first_node_report = ["Start searching for first node in culprit"]
            self.print_report(first_node_report)
            start_idx = self._block_traverse_impl(
                nodes[0 : end_idx + 1], start_idx, end_idx, False
            )
            first_node_report.append("*" * 50)
            self.reports.append(first_node_report)
            if start_idx < 0:
                # Even the longest tested suffix ran cleanly: nothing to minimize.
                first_node_report.append("Finish Pass 2, no culprit found")
                self.print_report(first_node_report)
                return culprits
            first_node_report.extend(
                [
                    "Finish Pass 2",
                    f"Find start_idx = {start_idx}:{nodes[start_idx].name}",
                ]
            )
            self.print_report(first_node_report)

        # step 3: form module with minimum culprits
        culprits.update(nodes[start_idx : end_idx + 1])
        result_report = [
            f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"
        ]
        self.reports.append(result_report)
        self.print_report(result_report)
        return culprits

    def _defined_traverse(self, nodes: NodeList) -> NodeSet:
        """
        Run the user-defined `nodes` as a single submodule and determine if it is
        a culprit. Returns all of `nodes` (after exclusion) if the run fails,
        otherwise an empty set.
        """
        culprits: NodeSet = set()
        if self.exclusion_fn is not None:
            self.exclusion_fn(nodes, -1, -1)
        if len(nodes) == 0:
            report = ["All nodes are excluded by user"]
            self.reports.append(report)
            return culprits

        first_node_name = nodes[0].name
        output_node_name = nodes[-1].name
        report = [f"Defined graph from {first_node_name} to {output_node_name}"]
        # Register the report so _run_and_compare logs into it and
        # print_reports() includes it, as the other traversals do.
        self.iteration += 1
        self.reports.append(report)
        cur_nodes: NodeSet = set(nodes)
        try:
            split_module, submod_name = self._build_submodule(cur_nodes)
            self._run_and_compare(split_module, submod_name, [output_node_name])
            self.print_report(report)
        except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
            report.append(f"Found culprit {cur_nodes}")
            self.print_report(report)
            culprits.update(cur_nodes)

        return culprits

    def _accumulate_traverse(self, nodes: NodeList) -> NodeSet:
        """
        Grow a submodule by one node per iteration and run it. Returns the first
        node whose addition makes the run fail, or an empty set.
        """
        culprits: NodeSet = set()
        nodes_to_run: NodeSet = set()

        # find_all is not supported for accumulate traversal because all the
        # ops run on NNPI. So we return after the first op that raises error.
        if self.settings.find_all:
            print("'Find All' mode is not supported in accumulate traversal.")
            return culprits

        for node in nodes:
            report: List[str] = []
            self.reports.append(report)
            self.iteration += 1
            report.append(f"Accumulate traverse iteration {self.iteration}.")

            nodes_to_run.add(node)

            node_name = node.name
            if node_name is not None and isinstance(node_name, tuple):
                node_name = node_name[0]
            assert node_name is not None and isinstance(
                node_name, str
            ), f"minimize: node_name: {node_name}"

            report.append(f"Add node: {node_name}")

            try:
                split_module, submod_name = self._build_submodule(nodes_to_run)
                self._run_and_compare(split_module, submod_name, [node_name])
                self.print_report(report)
            except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
                culprits.add(node)
                report.append(f"Found culprit {node}")
                self.print_report(report)
                return culprits

        return culprits

    def _skip_traverse_impl(
        self, all_nodes: NodeList, start_idx: int, end_idx: int
    ) -> NodeSet:
        """
        Run all_nodes[start_idx:end_idx] as one submodule (plus fusion groups when
        no exclusion_fn is set). Returns the run nodes if the run fails, else an
        empty set.
        """
        culprits: NodeSet = set()
        nodes: NodeList = all_nodes[start_idx:end_idx]
        cur_nodes: NodeSet = set(nodes)
        if self.exclusion_fn is not None:
            self.exclusion_fn(nodes, start_idx, end_idx)
            cur_nodes = set(nodes)
        else:
            for node in nodes:
                if node in self.fusions:
                    cur_nodes.update(self.fusions[node])
        report: List[str] = []
        self.reports.append(report)
        self.iteration += 1
        report.append(f" Nodes block {self.iteration}.")
        report.append(
            f"From node index {start_idx} to {end_idx-1}. "
            f"Size of the interested node list is {len(nodes)}"
        )

        try:
            split_module, submod_name = self._build_submodule(cur_nodes)
            self._run_and_compare(split_module, submod_name, [])
        except FxNetMinimizerResultMismatchError:
            culprits.update(cur_nodes)
            report.append(f"Found culprit from numeric error: {cur_nodes}")
            self.print_report(report)
            return culprits
        except FxNetMinimizerRunFuncError:
            culprits.update(cur_nodes)
            report.append(f"Found culprit from run error: {cur_nodes}")
            self.print_report(report)
            return culprits
        else:
            report.append("No discrepancy found.")
            self.print_report(report)
            return set()

    def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
        """
        Split `all_nodes` into contiguous segments separated by the nodes whose
        names are in `skip_nodes`, run each segment, and return the union of the
        culprits of all segments.
        """
        skip_names = set(skip_nodes)
        num_nodes = len(all_nodes)
        start_idx = 0
        culprits: NodeSet = set()
        for idx, node in enumerate(all_nodes):
            if node.name in skip_names:  # skip the node
                if idx > start_idx:
                    # Accumulate: a clean later segment must not erase culprits
                    # found in an earlier one.
                    culprits.update(
                        self._skip_traverse_impl(all_nodes, start_idx, idx)
                    )
                start_idx = idx + 1
            elif idx == num_nodes - 1 and start_idx <= idx:  # last node
                culprits.update(
                    self._skip_traverse_impl(all_nodes, start_idx, idx + 1)
                )

        return culprits

    def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
        """
        Collect nodes in the model that between nodes with name of `start` and `end`.
        These two nodes are also included.
        """
        nodes: NodeList = []
        add_node = start is None

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            if node.name == start:
                add_node = True

            if add_node:
                nodes.append(node)

            if node.name == end:
                break

        return nodes

    def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None):
        """
        Run part of the model from `start` node to `end` node. If `start` is None
        then we start from the beginning of the model. If `end` is None then we
        stop at the end of the model.

        Args:
            start: The name of the node which is the first node of the submodule
                we want to run. If set to None, then we'll start with the first
                node of the model.
            end: The name of the node which is the last node of the submodule we
                want to run. If set to None, we'll end with the last node of the
                model.
        """
        nodes = self._collect_nodes(start, end)
        cur_nodes = set(nodes)

        for node in nodes:
            if node in self.fusions:
                cur_nodes.update(self.fusions[node])

        output_names = []
        if self.settings.return_intermediate:
            output_names = [node.name for node in nodes]

        try:
            split_module, submod_name = self._build_submodule(cur_nodes)
            self._run_and_compare(split_module, submod_name, output_names)
        except (
            FxNetMinimizerRunFuncError,
            FxNetMinimizerResultMismatchError,
        ) as e:
            print(e)

    def print_report(self, report: List[str]):
        """Print `report`, prefixing every line after the first with " . "."""
        for i, line in enumerate(report):
            print(" . " + line if i > 0 else line)

    def print_reports(self):
        """Print every stored report in order."""
        for report in self.reports:
            self.print_report(report)

    def minimize(
        self,
        start: Optional[str] = None,
        end: Optional[str] = None,
        skip_nodes: Optional[List] = None,
        find_last_node: Optional[bool] = None,
    ) -> NodeSet:
        """
        Minimizing the model from node with name `start` to node with name `end` based
        on self.settings. Find culprits that cause FxNetMinimizerRunFuncError or
        FxNetMinimizerResultMismatchError errors.

        Args:
            start: The name of the node where we want to start minimizing. If set
                to None, then we'll start with the first node of the model.
            end: The name of the node where we want to terminate minimizing. If
                set to None, we'll end with the last node of the model.
            skip_nodes: The names of nodes where we want to skip during minimizing.
                It'll create subgraphs without these skip nodes under the hood.
                Only applicable in mode "skip".
            find_last_node: True if only last_node of a culprits is needed in mode "block".
                False if only the first_node of a culprits is needed.
                Only applicable in mode "block".

        Returns:
            nodes: A set of nodes that cause FxNetMinimizerRunFuncError or
                FxNetMinimizerResultMismatchError errors during minimizing.

        Raises:
            RuntimeError: for an unknown traverse method, or when mode "skip" is
                used without `skip_nodes`.
        """

        print(self.settings)
        print(self.module.graph)

        nodes = self._collect_nodes(start, end)

        if self.settings.traverse_method == "sequential":
            return self._sequential_traverse(nodes)

        if self.settings.traverse_method == "binary":
            return self._binary_traverse(nodes)

        if self.settings.traverse_method == "accumulate":
            return self._accumulate_traverse(nodes)

        if self.settings.traverse_method == "skip":
            if skip_nodes is None:
                raise RuntimeError(
                    "'skip_nodes' can't be None when 'traverse_method' is 'skip'."
                )
            return self._skip_traverse(nodes, skip_nodes)

        if self.settings.traverse_method == "defined":
            return self._defined_traverse(nodes)

        if self.settings.traverse_method == "block":
            return self._block_traverse(nodes, find_last_node)

        raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!")