# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import warnings
from typing import Any, Optional

import torch
import torch.fx as fx
from executorch.exir.tensor import TensorSpec


class ExportGraph:
    """
    ExportGraph serves as a layer between EXIR and the FX Graph API.
    It enforces EXIR-specific invariants (e.g. ensuring that nodes carry
    ``spec`` metadata).
    """
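    # Illustrative construction (a sketch; ``gm`` stands in for an
    # fx.GraphModule produced by export and is not defined in this module):
    #
    #     graph = ExportGraph(gm, gm.graph)
    #     for node in graph.nodes:
    #         print(node)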

    owning_module: fx.GraphModule
    _graph: fx.Graph

    def __init__(self, owning_module: fx.GraphModule, graph: fx.Graph) -> None:
        self.owning_module = owning_module
        self._graph = graph

    @property
    def nodes(self) -> fx.graph._node_list:
        """
        Get the list of Nodes that constitute this Graph.
        """
        return self._graph.nodes

    def erase_node(self, to_erase: fx.Node) -> None:
        """
        Erases a ``Node`` from the ``Graph``. Throws an exception if
        there are still users of that node in the ``Graph``.
        """
        return self._graph.erase_node(to_erase)
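
    # Illustrative use of ``erase_node`` (a sketch; ``graph``, ``old_node``, and
    # ``new_node`` are hypothetical): a node must have no remaining users before
    # it can be erased.
    #
    #     old_node.replace_all_uses_with(new_node)
    #     graph.erase_node(old_node)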

    def inserting_before(self, n: Optional[fx.Node] = None) -> fx.graph._InsertPoint:
        """
        Sets the point at which newly created nodes will be inserted into the
        graph: immediately before ``n``, or at the beginning of the graph if
        ``n`` is None.
        """
        return self._graph.inserting_before(n)
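
    # Illustrative use of ``inserting_before`` (a sketch; ``graph`` and
    # ``anchor`` are hypothetical): nodes created while the insert point is
    # active are placed before ``anchor``.
    #
    #     with graph.inserting_before(anchor):
    #         attr_node = graph.get_attr("submodule.my_buffer")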

    # pyre-ignore
    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> fx.Node:
        """
        Inserts a ``get_attr`` node into the Graph.
        """
        node = self._graph.get_attr(qualified_name, type_expr)

        # Gets the actual value of the attribute if it exists so that we can use
        # it to set the 'spec' metadata
        def _maybe_get_attr_value(
            mod: torch.nn.Module, qualified_name: str
        ) -> Optional[torch.Tensor]:
            module_path, _, name = qualified_name.rpartition(".")

            try:
                submod: torch.nn.Module = mod.get_submodule(module_path)
            except AttributeError:
                warnings.warn(f"Failed to fetch module {module_path}!", stacklevel=1)
                return None

            # See if the value is a buffer
            if name in submod._buffers:
                return submod._buffers[name]

            # See if the value is a parameter
            if hasattr(submod, name):
                attr = getattr(submod, name)
                if isinstance(attr, torch.nn.Parameter):
                    return attr

            return None

        buffer = _maybe_get_attr_value(self.owning_module, qualified_name)
        if buffer is not None:
            node.meta["spec"] = TensorSpec.from_tensor(buffer, True)

        return node
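

# Illustrative use of ``ExportGraph.get_attr`` (a sketch; ``graph`` is a
# hypothetical ExportGraph and "decoder.bias" a hypothetical attribute path):
# if the qualified name resolves to a buffer or parameter on the owning module,
# the returned node carries a TensorSpec under ``node.meta["spec"]``.
#
#     attr_node = graph.get_attr("decoder.bias")
#     spec = attr_node.meta.get("spec")  # TensorSpec when the lookup succeeded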