# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
Intermediate Representation of ExecuTorch Concepts in Developer Tools
"""

from __future__ import annotations

import operator
import warnings

from collections import defaultdict

from enum import Enum
from types import NoneType
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
from executorch import exir
from executorch.devtools.debug_format.base_schema import (
    Node,
    OperatorGraph,
    OperatorNode,
    ValueNode,
)
from torch._subclasses import FakeTensor


# Keywords used in debug_format Metadata
class RESERVED_METADATA_ARG(Enum):
    DEBUG_HANDLE = "debug_handle"
    MODULE_STACK = "nn_module_stack"
    SOURCE_FN_STACK = "source_fn_stack"
    MODULE_TYPE = "module_type"
    PROFILE_START_TIME = "profile_start_time"
    PROFILE_END_TIME = "profile_end_time"
    LOAD_START_TIME = "load_start_time"
    LOAD_END_TIME = "load_end_time"
    MEMORY_USAGE = "memory_usage"
    DEBUG_ENTRY = "debug_entry"
    STACK_TRACE = "stack_trace"
    DEBUG_DATA = "debug_data"

    METRICS_KEYWORD = "metrics"
    PROFILE_SUMMARY_COLDSTART = "Coldstart"
    PROFILE_SUMMARY_AVERAGE = "Average"
    PROFILE_SUMMARY_P90 = "P90"
    PROFILE_SUMMARY_P10 = "P10"
    PROFILE_SUMMARY_MIN = "Min"
    PROFILE_SUMMARY_MAX = "Max"

    AGGREGATED_OP_TABLE = "Aggregated op stats"
    RUN_SUMMARY_INDIVIDUAL_RUNS_TABLE = "Run summary individual stats"
    RUN_SUMMARY_TABLE = "Aggregated run summary stats"
    OP_INSTANCE_SUMMARY_TABLE = "Individual op stats"

    TABLES_KEYWORD = "tables"
    KV_KEYWORD = "kv"
    MODEL_LOAD_TIME_KEY = "Model load time (ms)"

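# Illustrative note (a sketch, not part of the original file): these enum values
# are the string keys used to look up entries in an FX node's metadata dict, e.g.
#   debug_handle = fx_node.meta.get(RESERVED_METADATA_ARG.DEBUG_HANDLE.value)
#   module_stack = fx_node.meta.get(RESERVED_METADATA_ARG.MODULE_STACK.value)
# which mirrors how FXOperatorGraph._extract_metadata consumes them below.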

# Representation of an FX GraphModule as an OperatorGraph
class FXOperatorGraph(OperatorGraph):
    # Return a display name for the given node; getitem nodes are named after
    # the node they index into.
    @staticmethod
    def _get_node_name(node: torch.fx.Node) -> str:
        if node.target == operator.getitem:
            # pyre-ignore[9]: Incompatible variable type
            node = node.args[0]
            assert isinstance(
                node, torch.fx.Node
            ), f"First argument of getitem must be a torch fx node. Got {node}"

        # Adding the "_" to the node name prevents TensorBoard from collapsing
        # nodes with similar names that only differ by an integer at the end of
        # their name.
        return node.name + "_"

    @staticmethod
    def _get_op_name(node: torch.fx.Node) -> str:
        # pyre-ignore[16]: has no attribute `__name__`.
        return node.target.__name__
    # Given a node and its metadata, update the provided module mappings if the
    # metadata contains a module stack
    @staticmethod
    def _update_module_mapping(
        node: Node,
        module_mapping: Dict[Tuple[str, str], List[Node]],
        metadata: Dict[str, Any],
    ):
        if (
            source_fn_stack := metadata.get("source_fn_stack")
        ) is not None and "nn_module_stack" in metadata:
            # (module name, module type)
            source_fn = source_fn_stack[-1]
            module_type = (
                source_fn[1] if isinstance(source_fn[1], str) else source_fn[1].__name__
            )
            module_mapping[(source_fn[0], module_type)].append(node)
    # Parse the args of an FX node into the list of input Nodes for the operator,
    # creating ValueNodes for any constants encountered, and return the inputs
    # together with the updated constant counter.
    @staticmethod
    def _parse_args(  # noqa: C901
        node: torch.fx.Node,
        nodes: Dict[str, Node],
        const_count: int,
        module_mapping: Dict[Tuple[str, str], List[Node]],
        enable_module_hierarchy: bool,
    ) -> Tuple[List[Node], int]:
        inputs = []
        op = node.op
        name = node.name
        args = node.args
        kwargs = node.kwargs
        named_args = None
        if node.op == "call_function" and hasattr(node.target, "_schema"):
            # pyre-ignore
            named_args = node.target._schema.arguments

        for index, arg in enumerate(args):
            if isinstance(arg, torch.fx.node.Node):
                if arg.target == exir.memory.alloc:
                    continue
                arg_name = FXOperatorGraph._get_node_name(arg)
            elif isinstance(arg, (int, float, torch.dtype, str)):
                # e.g. The "0" from node.args of squeeze_copy (mm_default, 0)
                if named_args and len(named_args) > index:
                    arg_name = named_args[index].name + "_" + str(const_count)
                else:
                    arg_name = "CONST_" + str(const_count)
                const_count += 1
                const_node = ValueNode(arg_name, val=str(arg))
                nodes[arg_name] = const_node
                if enable_module_hierarchy:
                    FXOperatorGraph._update_module_mapping(
                        const_node, module_mapping, node.meta
                    )
            elif isinstance(arg, list):
                arg_name: List[str] = []
                for list_arg in arg:
                    if isinstance(list_arg, (int, float)):
                        # Consider the whole list of ints/floats as a single constant and
                        # stringify that.
                        if named_args and len(named_args) > index:
                            arg_name = [named_args[index].name + "_" + str(const_count)]
                        else:
                            arg_name = ["CONST_" + str(const_count)]
                        const_count += 1
                        const_node = ValueNode(arg_name[0], val=arg)
                        nodes[arg_name[0]] = const_node
                        if enable_module_hierarchy:
                            FXOperatorGraph._update_module_mapping(
                                const_node, module_mapping, node.meta
                            )
                        break
                    elif isinstance(list_arg, torch.fx.node.Node):
                        arg_name += [FXOperatorGraph._get_node_name(list_arg)]
                    elif list_arg is None:
                        arg_name += ["CONST_NONE" + str(const_count)]
                        const_count += 1
                        const_node = ValueNode(arg_name[-1], val=str(arg))
                        nodes[arg_name[-1]] = const_node
                        if enable_module_hierarchy:
                            FXOperatorGraph._update_module_mapping(
                                const_node, module_mapping, node.meta
                            )
                    else:
                        raise Exception(
                            f"Unsupported argument encountered in list {arg}, {type(arg[0])}"
                        )
            elif isinstance(arg, NoneType):
                continue
            else:
                raise Exception(
                    f"Unsupported argument encountered {op}, {name}, {arg}, {type(arg)}"
                )

            if isinstance(arg_name, list):
                for val in arg_name:
                    inputs.append(nodes[val])
            else:
                inputs.append(nodes[arg_name])
        for _, node in kwargs.items():
            # We can ignore the out kwarg as that's mostly used to pass in the output tensor
            # which has been memory planned. The same value is also returned by the operator
            # which is then consumed by other nodes in the graph.
            if (
                isinstance(node, torch.fx.node.Node)
                and node.target == exir.memory.alloc
            ):
                continue
            else:
                warnings.warn(
                    f"Unsupported kwarg encountered: {name}, {kwargs}", stacklevel=1
                )

        return inputs, const_count
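
    # Illustrative note (an assumption based on the code above, not from the
    # original source): for a call like squeeze_copy(mm_default, 0), the returned
    # inputs would contain the Node registered under "mm_default_" plus a ValueNode
    # for the constant 0, named after the schema argument when the schema is
    # available (e.g. "dim_0") or "CONST_<const_count>" otherwise.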

    # Given an FX GraphModule, parse it into an OperatorGraph
    @staticmethod
    def gen_operator_graph(
        model: torch.fx.GraphModule,
        skip_stack_trace: Optional[bool] = False,
        enable_module_hierarchy: bool = False,
    ) -> FXOperatorGraph:
        graph: torch.fx.Graph = model.graph

        nodes = {}
        input_nodes = {}
        output_nodes = {}
        out_variant_output_nodes = set()
        module_mapping = defaultdict(list)

        const_count = 0
        for fx_node in graph.nodes:
            if (
                fx_node.target == exir.memory.alloc
                or fx_node.target == operator.getitem
            ):
                continue
            op = fx_node.op
            name = FXOperatorGraph._get_node_name(fx_node)
            target = fx_node.target
            args = fx_node.args
            kwargs = fx_node.kwargs
            metadata = FXOperatorGraph._extract_metadata(fx_node.meta, skip_stack_trace)
            output_shapes = FXOperatorGraph._extract_output_shapes(
                fx_node.meta.get("val")
            )
            dtype = FXOperatorGraph._extract_output_dtype(fx_node.meta.get("val")) or ""

            assert (
                op != "call_module"
            ), f"Module call not yet supported in edge model graph [.toEdge()]: {name}, {str(target)}"
            assert (
                op != "call_method"
            ), f"Call Method not yet supported in edge model graph [.toEdge()]: {name}, {str(target)}"

            # Input
            if op == "placeholder":
                node = ValueNode(
                    name,
                    output_shapes=output_shapes,
                    metadata=metadata,
                    dtype=str(dtype),
                )  # val is default arg
                input_nodes[name] = node
            # Constants
            elif op == "get_attr":
                node = ValueNode(
                    name,
                    output_shapes=output_shapes,
                    metadata=metadata,
                    dtype=str(dtype),
                )
            # Output
            elif op == "output":
                assert len(args) == 1
                # The args of an op=='output' node are a singleton tuple wrapping
                # the list of return nodes: ([ret_1, ret_2, ...], )
                in_nodes = [
                    (
                        nodes[FXOperatorGraph._get_node_name(ret)]
                        if ret is not None
                        else []
                    )
                    for ret in args[0]
                ]
                node = ValueNode(
                    name,
                    inputs=in_nodes,
                    output_shapes=output_shapes,
                    metadata=metadata,
                    dtype=str(dtype),
                )
                output_nodes[name] = node
            # Op Calls
            elif op == "call_function":
                inputs, const_count = FXOperatorGraph._parse_args(
                    fx_node, nodes, const_count, module_mapping, enable_module_hierarchy
                )
                named_args = []
                if fx_node.op == "call_function" and hasattr(fx_node.target, "_schema"):
                    named_args = [arg.name for arg in fx_node.target._schema.arguments]
                node = OperatorNode(
                    name,
                    inputs=inputs,
                    output_shapes=output_shapes,
                    metadata=metadata,
                    op=FXOperatorGraph._get_op_name(fx_node),
                    named_args=named_args,
                )
                if enable_module_hierarchy:
                    FXOperatorGraph._update_module_mapping(
                        node, module_mapping, fx_node.meta
                    )

                for kwarg_name, kwarg in kwargs.items():
                    if (
                        isinstance(kwarg, torch.fx.node.Node)
                        and kwarg.target == exir.memory.alloc
                        and kwarg_name == "out"
                    ):
                        nodes[FXOperatorGraph._get_node_name(kwarg)] = node
                        out_variant_output_nodes.add(
                            FXOperatorGraph._get_node_name(kwarg)
                        )
            else:
                raise Exception(f"Unsupported op type encountered {op}, {name}")

            nodes[name] = node
        return FXOperatorGraph._compose_op_graph(
            "base",
            nodes,
            input_nodes,
            output_nodes,
            out_variant_output_nodes,
            module_mapping,
        )

    @staticmethod
    def _compose_op_graph(
        name: str,
        nodes: Dict[str, Node],
        input_nodes: Dict[
            str, Node | OperatorGraph
        ],  # Never OperatorGraph, annotated for Pyre
        output_nodes: Dict[
            str, Node | OperatorGraph
        ],  # Never OperatorGraph, annotated for Pyre
        out_variant_output_nodes: Set[str],
        module_mapping: Dict[
            Tuple[str, str], List[Any]
        ],  # Any used here for Pyre, list of Nodes
    ):
        # Generate Module Graphs
        module_graphs: List[OperatorGraph] = []
        for module_key, module_nodes in module_mapping.items():
            module_element = OperatorGraph(
                graph_name=module_key[0],
                elements=module_nodes,
                metadata={"module_type": module_key[1]},
            )
            module_graphs.append(module_element)

            # Remove module nodes from the main graph
            for node in module_nodes:
                nodes.pop(node.name)

        main_nodes = [
            node
            for name, node in nodes.items()
            if name not in input_nodes
            and name not in output_nodes
            and name not in out_variant_output_nodes
        ]
        main_graph = FXOperatorGraph(
            graph_name="forward", elements=main_nodes + module_graphs
        )
        input_graph = FXOperatorGraph(
            graph_name="inputs", elements=list(input_nodes.values())
        )
        output_graph = FXOperatorGraph(
            graph_name="outputs", elements=list(output_nodes.values())
        )

        return FXOperatorGraph(
            graph_name=name,
            elements=[input_graph, main_graph, output_graph],
        )

    # Given a dict, extract only the utilized metadata
    @staticmethod
    def _extract_metadata(
        metadata: Dict[str, Any], skip_stack_trace: Optional[bool] = False
    ) -> Dict[str, Any]:
        ret = {}
        if RESERVED_METADATA_ARG.DEBUG_HANDLE.value in metadata:
            ret[RESERVED_METADATA_ARG.DEBUG_HANDLE.value] = metadata[
                RESERVED_METADATA_ARG.DEBUG_HANDLE.value
            ]
        if not skip_stack_trace and RESERVED_METADATA_ARG.STACK_TRACE.value in metadata:
            ret[RESERVED_METADATA_ARG.STACK_TRACE.value] = metadata[
                RESERVED_METADATA_ARG.STACK_TRACE.value
            ]
        if RESERVED_METADATA_ARG.MODULE_STACK.value in metadata:
            ret[RESERVED_METADATA_ARG.MODULE_STACK.value] = metadata[
                RESERVED_METADATA_ARG.MODULE_STACK.value
            ]
        return ret

    @staticmethod
    def _extract_output_shapes(val: Any) -> Optional[List[List[int]]]:
        if isinstance(val, (FakeTensor, torch.Tensor)):
            # If val is a single tensor
            return [list(val.shape)]
        elif isinstance(val, tuple) and all(
            isinstance(tensor, (FakeTensor, torch.Tensor)) for tensor in val
        ):
            # If val is a tuple of tensors
            shapes = [list(fake_tensor.shape) for fake_tensor in val]
            return shapes
        else:
            return None

    @staticmethod
    def _extract_output_dtype(val: Any) -> Optional[List[torch.dtype]]:
        if isinstance(val, (FakeTensor, torch.Tensor)):
            # If val is a single tensor
            return [val.dtype]
        elif isinstance(val, tuple) and all(
            isinstance(tensor, (FakeTensor, torch.Tensor)) for tensor in val
        ):
            # If val is a tuple of tensors
            dtypes = [fake_tensor.dtype for fake_tensor in val]
            return dtypes
        else:
            return None
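

# The block below is an illustrative usage sketch and is not part of the original
# module. It assumes `torch.export.export` and `executorch.exir.to_edge` are
# available and that the resulting edge-dialect graph module is a valid input to
# FXOperatorGraph.gen_operator_graph; the demo module name is made up for the example.
if __name__ == "__main__":
    from executorch.exir import to_edge
    from torch.export import export

    class _DemoModule(torch.nn.Module):  # hypothetical toy model for the sketch
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(x + 1)

    # Export a tiny model, lower it to the edge dialect, and build the graph IR.
    edge_program = to_edge(export(_DemoModule(), (torch.randn(2, 3),)))
    op_graph = FXOperatorGraph.gen_operator_graph(
        edge_program.exported_program().graph_module,
        enable_module_hierarchy=False,
    )
    print(op_graph)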
422