xref: /aosp_15_r20/external/pytorch/torch/fx/passes/net_min_base.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import logging
3from dataclasses import dataclass
4from typing import Any, Callable, Dict, List, Optional, Tuple
5
6import torch
7import torch.fx
8
9from torch.fx._compatibility import compatibility
10from torch.fx.node import map_arg
11
12from .shape_prop import ShapeProp
13from .split_utils import split_by_tags
14from .tools_common import (
15    CALLABLE_NODE_OPS,
16    FxNetAccFusionsFinder,
17    Names,
18    NodeList,
19    NodeSet,
20    TensorOrTensors,
21    Tensors,
22)
23
# Names exported by `from ... import *`; only the minimizer exception classes
# are listed, the minimizer classes themselves are private (leading underscore).
__all__ = [
    "FxNetMinimizerBadModuleError",
    "FxNetMinimizerRunFuncError",
    "FxNetMinimizerResultMismatchError",
]

# Module-level logger (used, for example, to log each node visited by the
# sequential traversal).
_LOGGER = logging.getLogger(__name__)
31
32
@compatibility(is_backward_compatible=False)
class FxNetMinimizerBadModuleError(Exception):
    """
    Raised when splitting the model does not yield exactly one submodule
    containing the nodes to be minimized.
    """
38
39
40
@compatibility(is_backward_compatible=False)
class FxNetMinimizerRunFuncError(Exception):
    """
    Raised when run_a() or run_b() fails while executing a submodule; the
    message carries the submodule name and the underlying error text.
    """
46
47
48
@compatibility(is_backward_compatible=False)
class FxNetMinimizerResultMismatchError(Exception):
    """
    Raised when the comparison function reports that the outputs of run_a()
    and run_b() do not match.
    """
54
55
56
57@dataclass
58class _MinimizerSettingBase:
59    """
60    Args:
61    `accumulate_error`: Instead of using a's input for both converted module to verify
62    , use the previous outputs of each converted module as input to accumulate the
63    errors.
64
65    `traverse_method`: "sequential" or "binary" or "accumulate"
66    Determine the way of traverse the nodes in FX module.
67
68    `find_all`: Minimizer will go through the entire model and return all problematic nodes.
69
70    `return_intermediate`: If true, when using `run_nodes()` function to run the
71    model, intermediate results of all the ops will be returned as output.
72    """
73
74    accumulate_error: bool = False
75    traverse_method: str = "sequential"
76    find_all: bool = False
77    return_intermediate: bool = False
78
79    def __str__(self):
80        settings_str = "FX Minimizer Settings:\n"
81
82        for k, v in vars(self).items():
83            settings_str += f"\t{k}: {v}\n"
84
85        return settings_str
86
87
88class _MinimizerBase:
89    """
90    This class is used to automatically find problematic nodes in a model. It takes a FX
91    graphmodule and generate some submodules while traverse the graph. Then two functions
92    `run_a` and `run_b` will be used to run the same submodule and a function `compare_fn`
93    will be used to compare the results.
94
95    Currently we provides two ways to traverse the graph and generate submodules.
96        1. Sequential traversal: this will traverse the graph node by node and generate
97           one submodule with one sigle node.
98        2. Binary searching: this will do a binary search style traversal on the graph.
99
100    For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP.
101    """
102
103    def __init__(
104        self,
105        module: torch.fx.GraphModule,
106        sample_input: Tensors,
107        compare_fn: Callable[
108            [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
109        ],
110        settings: _MinimizerSettingBase,
111        module_exporter: Optional[
112            Callable[
113                [Tensors, torch.fx.GraphModule, str],
114                None
115            ]
116        ] = None,
117        exclusion_fn: Optional[
118            Callable[[NodeList, int, int], None]
119        ] = None,
120    ):
121        assert isinstance(module, torch.fx.GraphModule)
122
123        self.module = module
124        self.sample_input = sample_input
125        self.compare_fn = compare_fn
126        self.module_exporter = module_exporter
127        self.settings = settings
128        self.exclusion_fn = exclusion_fn
129
130        # Stores outputs of run_a function
131        self.a_outputs: Dict[str, Any] = {}
132
133        # Stores outputs of run_b function
134        self.b_outputs: Dict[str, Any] = {}
135
136        # Stores the results of compare_fn
137        self.results: Dict[Any, Any] = {}
138
139        # Stores the report for the runs
140        self.reports: List[List[str]] = []
141
142        # Current iteration
143        self.iteration: int = 0
144
145        callable_nodes = {
146            node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
147        }
148        ShapeProp(self.module).propagate(*self.sample_input)
149        self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()
150
151        # Check if number of input in sample_input matches the number of placeholders
152        placeholders = [
153            node.name for node in self.module.graph.nodes if node.op == "placeholder"
154        ]
155        assert len(placeholders) == len(self.sample_input)
156
157        # Store sample_input
158        for i, name in enumerate(placeholders):
159            self.a_outputs[name] = sample_input[i]
160            self.b_outputs[name] = sample_input[i]
161
162    def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors:
163        """
164        Run `mod` with `inputs` and generate output. The output will be compared with
165        output of run_b().
166        """
167        raise RuntimeError("run_a() is not implemented.")
168
169    def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors:
170        """
171        Run `mod` with `inputs` and generate output. The output will be compared with
172        output of run_a().
173        """
174        raise RuntimeError("run_b() is not implemented.")
175
176    def _store_outputs(
177        self,
178        a_result: TensorOrTensors,
179        b_result: TensorOrTensors,
180        submodule: torch.fx.GraphModule,
181    ):
182        """
183        Store the outputs of self.run_a() and self.run_b() into self.a_outputs and
184        self.b_outputs, so that we can use them when execute preceding nodes that
185        use those outputs as inputs.
186
187        Args:
188            a_result: Output of self.run_a(). Could be a tensor or tensors.
189            b_result: Output of self.run_b(). Could be a tensor or tensors.
190            submodule: The module that generates a_result and b_result.
191        """
192        output_node = next(
193            node for node in submodule.graph.nodes if node.op == "output"
194        )
195
196        # Only one output
197        if isinstance(output_node.args[0], torch.fx.Node):
198            self.a_outputs[output_node.args[0].name] = a_result
199            self.b_outputs[output_node.args[0].name] = b_result
200        # Multiple outputs
201        else:
202            for i, arg in enumerate(output_node.args[0]):
203                self.a_outputs[arg.name] = a_result[i]
204                self.b_outputs[arg.name] = b_result[i]
205
206    def _get_submod_inputs(
207        self, main_module: torch.fx.GraphModule, submod_path: str
208    ) -> Tuple[Tensors, Tensors]:
209        """
210        Try get submodule inputs from stored outputs. If not found then use
211        torch_glow.get_submod_inputs to get the inputs.
212
213        If accumulate_error is False, use a_input for run_a() and run_b()
214        otherwise use a_input for run_a and b_input for run_b.
215
216        Args:
217            main_module: Top-levlel fx module.
218            submod_path: Path to the submodule we want to run and compare results.
219
220        Returns:
221            a_input: List of tensor(s) that will be used by run_a() as submodule inputs.
222            b_input: List of tensor(s) that will be used by run_b() as submodule inputs.
223        """
224        a_input = []
225        b_input = []
226        submodule = getattr(main_module, submod_path)
227        placeholders = [
228            node.name for node in submodule.graph.nodes if node.op == "placeholder"
229        ]
230
231        # If all placeholder can be found in stored outputs, use stored
232        # outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs`
233        # to get the inputs.
234        if set(placeholders) <= self.a_outputs.keys():
235            for name in placeholders:
236                a_input.append(self.a_outputs[name])
237                b_input.append(self.b_outputs[name])
238        else:
239            if self.settings.accumulate_error:
240                print(f"Can't find previous stored outputs named {placeholders}!")
241
242            def get_inputs(self: torch.nn.Module, inputs: Any):
243                nonlocal a_input
244                a_input = inputs
245
246            # Use forward hook to get the inputs to the submodule
247            handle = submodule.register_forward_pre_hook(get_inputs)
248            main_module(*self.sample_input)
249            handle.remove()
250
251            b_input = a_input
252
253        if not self.settings.accumulate_error:
254            return a_input, a_input
255
256        return a_input, b_input
257
258    def _tag_nodes(self, selected_nodes: NodeSet):
259        """
260        Tag selected nodes with tag "minimize". Nodes with the same tags will
261        be split to the same submodule afterwards.
262
263        Args:
264            selected_nodes: Nodes that we want to minimize. We will tag those nodes
265                with "minimize", all preceding nodes with "main_0" and all following
266                nodes with "main_1".
267        """
268        for node in self.module.graph.nodes:
269            if node.op not in CALLABLE_NODE_OPS:
270                continue
271
272            if node in selected_nodes:
273                node.tag = "minimize"
274            elif any(
275                n.tag in {"minimize", "main_1"}
276                for n in node.all_input_nodes
277                if n.op in CALLABLE_NODE_OPS
278            ):
279                node.tag = "main_1"
280            else:
281                node.tag = "main_0"
282
283    def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
284        """
285        Split self.module so that one submodule consists of `nodes` and only `nodes`.
286
287        Args:
288            nodes: Nodes that we want to include in the minimize submodule.
289
290        Returns:
291            split_module (torch.fx.GraphModule): the module after split.
292            submodule_name (str): the name of the submodule that consists of `nodes`.
293        """
294        # Color provided nodes
295        self._tag_nodes(nodes)
296
297        # Split module based on coloring
298        split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"])
299
300        # Find submodule containing colored nodes
301        submodule_name: str = ""
302        for child_name, _ in split_module.named_children():  # type: ignore[union-attr]
303            # Skip submodules we're not interested in at the moment
304            if "minimize" not in child_name:
305                continue
306
307            if submodule_name == "":
308                submodule_name = child_name
309            else:
310                raise FxNetMinimizerBadModuleError(
311                    f"Expected only one minimize submodule with nodes {nodes}"
312                )
313
314        if submodule_name == "":
315            raise FxNetMinimizerBadModuleError(
316                f"Minimize submodule was not found with nodes {nodes}"
317            )
318
319        return split_module, submodule_name  # type: ignore[return-value]
320
    def _run_and_compare(
        self,
        split_module: torch.fx.GraphModule,
        submod_name: str,
        output_names: Names,
        report_idx: int = -1
    ):
        """
        Run the submodule in `split_module` that has name `submod_name`
        using `self.run_a` and `self.run_b` and compare their results.

        Args:
            split_module: Main module that contains the minimize submodule.
            submod_name: Name of the minimize submodule.
            output_names: Names of the nodes we want to output. If falsy (None or
                empty), the submodule's original output is kept.
            report_idx: Index in `self.reports` of the report to append to; a
                negative value selects the report at index `self.iteration - 1`.
                Also forwarded to run_a() / run_b().

        Raises:
            FxNetMinimizerRunFuncError: If run_a(), run_b() or storing their
                outputs raises.
            FxNetMinimizerResultMismatchError: If compare_fn reports a mismatch.
        """
        submodule = getattr(split_module, submod_name)
        a_input, b_input = self._get_submod_inputs(split_module, submod_name)

        # Ensure there is a report to write into when no traversal created one.
        if len(self.reports) == 0:
            self.reports.append([])
            self.iteration = 1

        report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1]
        report.append("Run and compare ...")

        if output_names:
            # Replace the submodule's output with the requested nodes: a single
            # node, or a tuple of nodes. The existing output node is erased
            # while the loop is iterating over the graph's nodes.
            output_nodes: NodeList = []
            for node in submodule.graph.nodes:
                if node.op == "output":
                    submodule.graph.erase_node(node)

                if node.name in output_names:
                    output_nodes.append(node)

            submodule.graph.output(
                output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes)
            )
            submodule.graph.lint()
            submodule.recompile()

        # Use name of args in output node as key to store comparison result.
        # result_key mirrors the output node's args with each Node replaced by
        # its name.
        for node in submodule.graph.nodes:
            if node.op == "output":
                result_key = map_arg(node.args, lambda x: x.name)

        try:
            a_result = self.run_a(submodule, a_input, report_idx)
            b_result = self.run_b(submodule, b_input, report_idx)
            self._store_outputs(a_result, b_result, submodule)
        except Exception as e:
            # Any failure while running is reported as a run error; the original
            # exception is not chained (hence the noqa).
            report.append(f"Exception raised when running {submod_name}: {e}")
            raise FxNetMinimizerRunFuncError(  # noqa: B904
                f"Exception raised when running {submod_name}: {e}"
            )

        # Compare results.
        # NOTE(review): an empty list skips the output rewrite above (it is
        # falsy) but is passed to compare_fn as-is, because only None triggers
        # the fallback below -- confirm this asymmetry is intended.
        names: Names = output_names
        if output_names is None:
            names = [str(v) for v in result_key]  # type: ignore[possibly-undefined]

        numeric_result, bool_result = self.compare_fn(a_result, b_result, names)

        self.results[result_key] = numeric_result  # type: ignore[possibly-undefined]
        report.append(f"Numerical accuracy = {numeric_result}")
        if not bool_result:
            report.append(f"Result mismatch for {result_key}")
            # Optionally export the failing submodule together with each side's
            # inputs before signalling the mismatch.
            if self.module_exporter:
                self.module_exporter(
                    a_input, submodule, str(result_key[0]) + "_cpu",  # type: ignore[index]
                )
                self.module_exporter(
                    b_input, submodule, str(result_key[0]) + "_acc",  # type: ignore[index]
                )
            raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
397
398    def _binary_search_impl(
399        self, all_nodes: NodeList, start_idx: int, end_idx: int
400    ) -> NodeSet:
401        """
402        Recursive binary search implementation.
403        """
404        culprits: NodeSet = set()
405        nodes: NodeList = all_nodes[start_idx:end_idx]
406
407        report: List[str] = []
408        if self.exclusion_fn is not None:
409            self.exclusion_fn(nodes, start_idx, end_idx)
410            if len(nodes) == 0:
411                report = ["All nodes are excluded by user"]
412                self.reports.append(report)
413                return culprits
414
415        first_node_name = nodes[0].name
416        output_node_name = nodes[-1].name
417        self.iteration += 1
418        self.reports.append(report)
419        report.append(f"Binary search iteration {self.iteration}")
420        report.append(
421            f"From node index {start_idx}:{first_node_name} to {end_idx-1}:{output_node_name}. "
422            f"Size of the interested node list is {len(nodes)}"
423        )
424        cur_nodes: NodeSet = set(nodes)
425
426        try:
427            split_module, submod_name = self._build_submodule(cur_nodes)
428            self._run_and_compare(split_module, submod_name, [output_node_name])
429
430        except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
431
432            if len(nodes) == 1:
433                report.append(
434                    f"This is the last node in the sub-module. "
435                    f"Search in the current branch is successful with culprit = {cur_nodes}."
436                )
437                self.print_report(report)
438                return cur_nodes
439
440            report.append(
441                "Proceed to split and lower the halves of the current "
442                "sub-module individually."
443            )
444            self.print_report(report)
445
446            mid = len(nodes) // 2
447            culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)
448
449            if len(culprits) != 0 and not self.settings.find_all:
450                return culprits
451
452            culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx)
453
454            if len(culprits) == 0:
455                report.append(
456                    f"Further split and lowering found no errors. "
457                    f"Unable to minimize the submodule with list of nodes: {nodes}"
458                )
459                self.print_report(report)
460
461            return culprits
462        else:
463            report.append("No discrepancy found.")
464            self.print_report(report)
465            return set()
466
467    def _binary_traverse(self, nodes: NodeList) -> NodeSet:
468        """
469        Binary search on `nodes` for culprit.
470        """
471        return self._binary_search_impl(nodes, 0, len(nodes))
472
473    def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
474        """
475        Traverse `nodes` one by one and determine if any of them is a culprit.
476        """
477        culprits: NodeSet = set()
478
479        for node in nodes:
480            report: List[str] = []
481            self.reports.append(report)
482            self.iteration += 1
483            report.append(f"Sequential traverse iteration {self.iteration}.")
484            report.append(f"Visit node: {node.name}")
485
486            _LOGGER.info("Visit node: %s", node.name)
487            node_list: NodeList = [node]
488            if self.exclusion_fn is not None:
489                self.exclusion_fn(node_list, -1, -1)
490                if len(node_list) == 0:
491                    report.append(f"User exclusion : {node.name}")
492                    self.print_report(report)
493                    if not self.settings.find_all:
494                        return culprits
495                    else:
496                        continue
497
498            cur_nodes: NodeSet = {node}
499
500            if node in self.fusions:
501                cur_nodes = self.fusions[node]
502
503            try:
504                split_module, submod_name = self._build_submodule(cur_nodes)
505                self._run_and_compare(split_module, submod_name, [node.name])
506                self.print_report(report)
507            except (FxNetMinimizerResultMismatchError):
508                culprits.add(node)
509                report.append(f"Found culprit from numeric error: {node}")
510                self.print_report(report)
511                if not self.settings.find_all:
512                    return culprits
513            except (FxNetMinimizerRunFuncError):
514                culprits.update(cur_nodes)
515                report.append(f"Found culprit from run error: {node}")
516                self.print_report(report)
517                if not self.settings.find_all:
518                    return culprits
519
520        return culprits
521
522
    def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool) -> int:
        """
        Recursive block search implementation.

        Bisects the index range [start_idx, end_idx] of `nodes` (both ends
        inclusive). Each step builds a submodule from a prefix or a suffix of
        `nodes` and checks whether running/comparing it fails.

        Args:
            nodes: Topologically sorted node list being searched.
            start_idx: First index (inclusive) of the remaining search range.
            end_idx: Last index (inclusive) of the remaining search range.
            find_last_node: If True, search for the last node of the culprit
                block by testing prefixes `nodes[:mid + 1]`. If False, search for
                the first node by testing suffixes `nodes[mid:]`.

        Returns:
            The index found by the search. When the final single-index step does
            NOT fail, the result is `start_idx + 1` (last-node search) or
            `start_idx - 1` (first-node search), so it can fall outside
            `[0, len(nodes) - 1]`.
        """
        report: List[str] = []

        mid = (start_idx + end_idx) // 2
        # Last-node search tests the prefix ending at mid; first-node search
        # tests the suffix starting at mid.
        cur_nodes_list: NodeList = nodes[:mid + 1] if find_last_node else nodes[mid:]

        # exclusion_fn may remove nodes in place; no range is passed (-1, -1).
        # NOTE(review): if it removes every node, cur_nodes_list[0] below raises
        # IndexError.
        if self.exclusion_fn:
            self.exclusion_fn(cur_nodes_list, -1, -1)

        cur_nodes = set(cur_nodes_list)

        first_node_name = cur_nodes_list[0].name
        last_node_name = cur_nodes_list[-1].name
        target_node_name = last_node_name if find_last_node else first_node_name

        self.iteration += 1
        self.reports.append(report)
        report.extend(
            [
                "=" * 30,
                f"Block search iteration {self.iteration}",
            ]
        )
        report.extend(
            [
                f"Search for {'last' if find_last_node else 'first'} node in culprits",
                f"From node index {start_idx}:{nodes[start_idx].name} to {end_idx}:{nodes[end_idx].name}. ",
                f"Subgraph constructed by {first_node_name} to {last_node_name}",
                f"Targeting node: {target_node_name}",
                f"Size of the interested node list is {end_idx - start_idx + 1}",
            ]
        )
        # Pass this iteration's report index explicitly so _run_and_compare
        # writes into this report.
        report_idx = len(self.reports) - 1

        try:
            split_module, submod_name = self._build_submodule(cur_nodes)
            self._run_and_compare(split_module, submod_name, [last_node_name], report_idx)
        except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
            # The tested prefix/suffix fails.
            report.append(f"Culprits found from node {first_node_name} to {last_node_name}.")

            if start_idx == mid:
                report.extend(
                    [
                        "This is the last node in the sub-module. ",
                        "Search in the current branch is successful with node :",
                        f"{start_idx}, node name: {nodes[start_idx].name}."
                    ]
                )
                self.print_report(report)
                return start_idx

            report.append(
                "Proceed to split and lower the halves of the current "
                "sub-module individually."
            )
            self.print_report(report)

            # Failing prefix: the answer is at or before mid. Failing suffix:
            # keep looking for a later start index after mid.
            if find_last_node:
                return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
            else:
                return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node)
        else:
            # The tested prefix/suffix passes.
            # NOTE(review): this message describes the prefix; for the first-node
            # search the tested subgraph is the suffix nodes[mid:].
            report.append(f"Culprits not found from node start to {mid}:{nodes[mid].name}.")

            if start_idx == mid:
                report.extend(
                    [
                        "This is the last node in the sub-module. ",
                        "Search in the current branch is successful with node",
                        f"{start_idx}, node name: {nodes[start_idx].name}.",
                    ]
                )
                self.print_report(report)
                # Step one past the passing index, away from the passing side.
                return start_idx + 1 if find_last_node else start_idx - 1

            report.append(
                "Proceed to split and lower the halves of the current "
                "sub-module individually."
            )
            self.print_report(report)

            # Passing prefix: the answer lies after mid. Passing suffix: the
            # answer lies at or before mid.
            if find_last_node:
                return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node)
            else:
                return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
613
614
    def _block_traverse(self, nodes: NodeList, find_last_node: Optional[bool]) -> NodeSet:
        """
        Traverse a topologically sorted node list and find a minimal block
        `nodes[start_idx:end_idx + 1]` that contains the culprit.

        Pass 1 (run when find_last_node is None or True): search for end_idx,
        the first index whose prefix `nodes[:end_idx + 1]` fails run/compare.
        This assumes failures are monotone along the prefix.
        Pass 2 (run when find_last_node is None or False): within
        `nodes[:end_idx + 1]`, search for start_idx, the last index whose suffix
        `nodes[start_idx:end_idx + 1]` still fails.

        Without pass 1, end_idx stays at the last node. Without pass 2,
        start_idx stays 0.

        Returns:
            The set of nodes `nodes[start_idx:end_idx + 1]`.
        """
        culprits: NodeSet = set()
        first_node_name = nodes[0].name
        last_node_name = nodes[-1].name
        last_node_report = [f"Block search from {first_node_name} to {last_node_name}"]
        last_node_report.append("*" * 50)
        self.reports.append(last_node_report)

        start_idx = 0
        end_idx = len(nodes) - 1
        run_both = True if find_last_node is None else False

        # step 1: find (0, end_idx) of culprit block
        # NOTE(review): if no prefix fails, _block_traverse_impl can return
        # len(nodes), and nodes[end_idx] below then raises IndexError.
        if run_both or find_last_node:
            last_node_report.append("Start searching for last node in culprit")
            self.print_report(last_node_report)
            end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True)
            last_node_report.extend(
                [
                    "Finish Pass 1",
                    f"Find end_idx = {end_idx}:{nodes[end_idx].name}"
                ]
            )
            self.print_report(last_node_report)

        # step 2: reduce culprit block to (start_idx, end_idx)
        # The pass-2 report is printed before the search and only appended to
        # self.reports after it, i.e. after the search iterations' own reports.
        # NOTE(review): the search can return -1, which nodes[start_idx]
        # interprets as the last node.
        if run_both or not find_last_node:
            first_node_report = ["Start searching for first node in culprit"]
            self.print_report(first_node_report)
            start_idx = self._block_traverse_impl(nodes[0:end_idx + 1], start_idx, end_idx, False)
            first_node_report.append("*" * 50)
            self.reports.append(first_node_report)
            first_node_report.extend(
                [
                    "Finish Pass 2",
                    f"Find start_idx = {start_idx}:{nodes[start_idx].name}"
                ]
            )
            self.print_report(first_node_report)

        # step 3: form module with minimum culprits
        culprits.update(nodes[start_idx:end_idx + 1])
        result_report = [f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"]
        self.reports.append(result_report)
        self.print_report(result_report)
        return culprits
670
671
672    def _defined_traverse(self, nodes: NodeList) -> NodeSet:
673        """
674        run user defined `nodes` and determine if it is a culprit.
675        """
676        culprits: NodeSet = set()
677        if self.exclusion_fn is not None:
678            self.exclusion_fn(nodes, -1, -1)
679        if len(nodes) == 0:
680            report = ["All nodes are excluded by user"]
681            self.reports.append(report)
682            return culprits
683
684        first_node_name = nodes[0].name
685        output_node_name = nodes[-1].name
686        report = [f"Defined graph from {first_node_name} to {output_node_name}"]
687        cur_nodes: NodeSet = set(nodes)
688        try:
689            split_module, submod_name = self._build_submodule(cur_nodes)
690            self._run_and_compare(split_module, submod_name, [output_node_name])
691            self.print_report(report)
692        except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
693            report.append(f"Found culprit {cur_nodes}")
694            self.print_report(report)
695            return culprits
696
697        return culprits
698
699    def _accumulate_traverse(self, nodes: NodeList) -> NodeSet:
700        culprits: NodeSet = set()
701        nodes_to_run: NodeSet = set()
702
703        # find_all is not supported for accumulate traversal because all the
704        # ops run on NNPI. So we return after the first op that raises error.
705        if self.settings.find_all:
706            print("'Find All' mode is not supported in accumulate traversal.")
707            return culprits
708
709        for node in nodes:
710            report: List[str] = []
711            self.reports.append(report)
712            self.iteration += 1
713            report.append(f"Accumulate traverse iteration {self.iteration}.")
714
715            nodes_to_run.add(node)
716
717            node_name = node.name
718            if node_name is not None and isinstance(node_name, tuple):
719                node_name = node_name[0]
720            assert node_name is not None and isinstance(
721                node_name, str
722            ), f"minimize: node_name: {node_name}"
723
724            report.append(f"Add node: {node_name}")
725
726            try:
727                split_module, submod_name = self._build_submodule(nodes_to_run)
728                self._run_and_compare(split_module, submod_name, [node_name])
729                self.print_report(report)
730            except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
731                culprits.add(node)
732                report.append(f"Found culprit {node}")
733                self.print_report(report)
734                return culprits
735
736        return culprits
737
738    def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet:
739        """
740        Skip certain nodes in graph based on settings
741        """
742        culprits: NodeSet = set()
743        nodes: NodeList = all_nodes[start_idx:end_idx]
744        cur_nodes: NodeSet = set(nodes)
745        if self.exclusion_fn is not None:
746            self.exclusion_fn(nodes, start_idx, end_idx)
747            cur_nodes = set(nodes)
748        else:
749            for node in nodes:
750                if node in self.fusions:
751                    cur_nodes.update(self.fusions[node])
752        report: List[str] = []
753        self.reports.append(report)
754        self.iteration += 1
755        report.append(f" Nodes block {self.iteration}.")
756        report.append(
757            f"From node index {start_idx} to {end_idx-1}. "
758            f"Size of the interested node list is {len(nodes)}"
759        )
760
761        try:
762            split_module, submod_name = self._build_submodule(cur_nodes)
763            self._run_and_compare(split_module, submod_name, [])
764        except (FxNetMinimizerResultMismatchError):
765            culprits.update(cur_nodes)
766            report.append(f"Found culprit from numeric error: {cur_nodes}")
767            self.print_report(report)
768            return culprits
769        except (FxNetMinimizerRunFuncError):
770            culprits.update(cur_nodes)
771            report.append(f"Found culprit from run error: {cur_nodes}")
772            self.print_report(report)
773            return culprits
774        else:
775            report.append("No discrepancy found.")
776            self.print_report(report)
777            return set()
778
779
780    def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
781        """
782        Skip certain nodes in graph based on settings
783        """
784        start_idx = 0
785        num_nodes = len(all_nodes)
786        idx = 0
787        culprits = set()
788        while idx < num_nodes:
789            node = all_nodes[idx]
790            if (node.name in skip_nodes):  # skip the node
791                if idx > start_idx:
792                    culprits = self._skip_traverse_impl(all_nodes, start_idx, idx)
793                start_idx = idx + 1
794            elif idx == num_nodes - 1 and start_idx <= idx:  # last node
795                culprits = self._skip_traverse_impl(all_nodes, start_idx, idx + 1)
796            idx += 1
797
798        return culprits
799
800
801
802    def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
803        """
804        Collect nodes in the model that between nodes with name of `start` and `end`.
805        These two nodes are also included.
806        """
807        nodes: NodeList = []
808        add_node = start is None
809
810        for node in self.module.graph.nodes:
811            if node.op not in CALLABLE_NODE_OPS:
812                continue
813
814            if node.name == start:
815                add_node = True
816
817            if add_node:
818                nodes.append(node)
819
820            if node.name == end:
821                break
822
823        return nodes
824
825    def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None):
826        """
827        Run part of the model from `start` node to `end` node. If `start` is None
828        then we start from the beginning of the model. If `end` is None then we
829        stop at the end of the model.
830
831        Args:
832            start: The name of the node which is the first node of the submodule
833                we want to run. If set to None, then we'll start with the first
834                node of the model.
835            end: The name of the node which is the last node of the submodule we
836                want to run. If set to None, we'll end with the last node of the
837                model.
838        """
839        nodes = self._collect_nodes(start, end)
840        cur_nodes = set(nodes)
841
842        for node in nodes:
843            if node in self.fusions:
844                cur_nodes.update(self.fusions[node])
845
846        output_names = []
847        if self.settings.return_intermediate:
848            output_names = [node.name for node in nodes]
849
850        try:
851            split_module, submod_name = self._build_submodule(cur_nodes)
852            self._run_and_compare(split_module, submod_name, output_names)
853        except (
854            FxNetMinimizerRunFuncError,
855            FxNetMinimizerResultMismatchError,
856        ) as e:
857            print(e)
858
859    def print_report(self, report: List[str]):
860        for i in range(len(report)):
861            if i > 0:
862                print(" . " + report[i])
863            else:
864                print(report[i])
865
866    def print_reports(self):
867        for report in self.reports:
868            self.print_report(report)
869
870    def minimize(
871        self,
872        start: Optional[str] = None,
873        end: Optional[str] = None,
874        skip_nodes: Optional[List] = None,
875        find_last_node: Optional[bool] = None,
876    ) -> NodeSet:
877        """
878        Minimizing the model from node with name `start` to node with name `end` base
879        on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or
880        FxNetMinimizerResultMismatchError errors.
881
882        Args:
883            start: The name of the node where we want to start minimizing. If set
884                to None, then we'll start with the first node of the model.
885            end: The name of the node where we want to terminate minimizing. If
886                set to None, we'll end with the last node of the model.
887            skip_nodes: The names of nodes where we want to skip during minimizing.
888                It'll create subgraphs without these skip nodes under the hood.
889                Only applicable in mode "skip".
890            find_last_node: True if only last_node of a culprits is needed in mode "block".
891                False if only the first_node of a culprits is needed.
892                Only applicable in mode "block".
893
894        Returns:
895            nodes: A list of nodes that causes FxNetMinimizerRunFuncError or
896                FxNetMinimizerResultMismatchError errors during minimizing.
897        """
898
899        print(self.settings)
900        print(self.module.graph)
901
902        nodes = self._collect_nodes(start, end)
903
904        if self.settings.traverse_method == "sequential":
905            return self._sequential_traverse(nodes)
906
907        if self.settings.traverse_method == "binary":
908            return self._binary_traverse(nodes)
909
910        if self.settings.traverse_method == "accumulate":
911            return self._accumulate_traverse(nodes)
912
913        if self.settings.traverse_method == "skip":
914            if (skip_nodes is None):
915                raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
916            return self._skip_traverse(nodes, skip_nodes)
917
918        if self.settings.traverse_method == "defined":
919            return self._defined_traverse(nodes)
920
921        if self.settings.traverse_method == "block":
922            return self._block_traverse(nodes, find_last_node)
923
924        raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!")
925