# xref: /aosp_15_r20/external/executorch/devtools/debug_format/base_schema.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
Base Intermediate Representation for Developer Tools consumers
(e.g. TensorBoard, Terminal Debugger)
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Node:
    """Base representation of a generic node within a ModelGraph.

    Only ``name`` is required; every other field defaults to ``None``.

    As a plain (non-frozen, ``eq=True``) dataclass, instances compare
    field by field -- ``inputs`` is therefore compared recursively node by
    node -- and are unhashable, so they cannot be used as set members or
    dict keys.
    """

    # Name of this node.
    name: str
    # Nodes that this Node consumes (its in-edges).
    inputs: Optional[List[Node]] = None
    # List of output shapes; each shape is a list of ints.
    output_shapes: Optional[List[List[int]]] = None
    # Generic node-level metadata.
    metadata: Optional[Dict[str, Any]] = None
    # Names of the arguments, derived from the op schema.
    named_args: Optional[List[str]] = None


@dataclass
class OperatorGraph:
    """Base representation of an operator subgraph with metadata.

    ``elements`` may mix plain nodes and nested ``OperatorGraph``
    instances, so subgraphs can be nested inside one another.
    """

    # Identifier used for grouping nodes (e.g. expand/minimize Module).
    graph_name: str
    # Nodes and sub-graphs contained in this graph.
    elements: List[Node | OperatorGraph]
    # Graph-level metadata.
    metadata: Optional[Dict[str, Any]] = None


"""
Node Subclass Types
"""


@dataclass
class ValueNode(Node):
    """Representation of a "Value" node within a ModelGraph.

    i.e. a non-operator node. Adds ``dtype`` and ``val`` after the
    inherited ``Node`` fields (dataclass inheritance appends subclass
    fields after the base-class fields).
    """

    # Both fields need defaults: the inherited optional Node fields already
    # have defaults, and a dataclass field without a default may not follow
    # one that has a default.
    # Data type, stored as a string; empty string when not provided.
    dtype: str = ""
    # The node's value, if any.
    val: Optional[Any] = None


@dataclass
class OperatorNode(Node):
    """Representation of an "OP" node within a ModelGraph.

    Adds a single ``op`` field after the inherited ``Node`` fields.
    """

    # The operator this node represents, as a string; None when unset.
    op: Optional[str] = None