xref: /aosp_15_r20/external/executorch/exir/graph.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport warnings
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Optional
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport torch
13*523fa7a6SAndroid Build Coastguard Workerimport torch.fx as fx
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tensor import TensorSpec
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Worker
class ExportGraph:
    """
    ExportGraph serves as a layer between EXIR and FX Graph API.
    It enforces EXIR-specific invariants (ex. having nodes contain specs)

    Most methods forward directly to the wrapped ``fx.Graph``. ``get_attr``
    additionally populates ``node.meta["spec"]`` when the referenced attribute
    resolves to a tensor on ``owning_module``.
    """

    # Module whose attributes ``get_attr`` nodes are resolved against.
    owning_module: fx.GraphModule
    # The wrapped FX graph that every operation is delegated to.
    _graph: fx.Graph

    def __init__(self, owning_module: fx.GraphModule, graph: fx.Graph) -> None:
        self.owning_module = owning_module
        self._graph = graph

    @property
    def nodes(self) -> fx.graph._node_list:
        """
        Get the list of Nodes that constitute this Graph.
        """
        return self._graph.nodes

    def erase_node(self, to_erase: fx.Node) -> None:
        """
        Erases a ``Node`` from the ``Graph``. Throws an exception if
        there are still users of that node in the ``Graph``.
        """
        return self._graph.erase_node(to_erase)

    def inserting_before(self, n: Optional[fx.Node] = None) -> fx.graph._InsertPoint:
        """
        Sets the point at which new nodes will be inserted into the graph.

        Forwards to ``fx.Graph.inserting_before(n)`` on the wrapped graph and
        returns the resulting insert point.
        """
        return self._graph.inserting_before(n)

    # pyre-ignore
    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> fx.Node:
        """
        Inserts a ``get_attr`` node into the Graph.

        If ``qualified_name`` resolves to a tensor on ``owning_module`` (a
        registered buffer, a parameter, or a plain tensor attribute), the
        node's ``meta["spec"]`` is set to ``TensorSpec.from_tensor(value, True)``.
        Otherwise the node is returned without a ``"spec"`` entry.

        Args:
            qualified_name: Dotted attribute path, e.g. ``"sub.weight"``.
            type_expr: Optional type annotation forwarded to
                ``fx.Graph.get_attr``.

        Returns:
            The newly created ``get_attr`` node.
        """
        node = self._graph.get_attr(qualified_name, type_expr)

        # Look up the attribute's actual value so it can seed the 'spec'
        # metadata that EXIR expects on nodes.
        value = self._maybe_get_attr_value(self.owning_module, qualified_name)
        if value is not None:
            node.meta["spec"] = TensorSpec.from_tensor(value, True)

        return node

    @staticmethod
    def _maybe_get_attr_value(
        mod: torch.nn.Module, qualified_name: str
    ) -> Optional[torch.Tensor]:
        """
        Resolve ``qualified_name`` on ``mod`` to a tensor, or return ``None``.

        Lookup order:
        1. registered buffers of the owning submodule;
        2. any other attribute that is a ``torch.Tensor``. ``nn.Parameter`` is a
           ``Tensor`` subclass, so parameters are included.

        Non-tensor attributes and missing names yield ``None``. If the
        submodule path cannot be resolved, a warning is emitted and ``None``
        is returned.
        """
        module_path, _, name = qualified_name.rpartition(".")

        try:
            # An empty module_path (no dot in the name) resolves to ``mod``.
            submod: torch.nn.Module = mod.get_submodule(module_path)
        except AttributeError:
            # stacklevel=3 attributes the warning to whoever called
            # ``get_attr``, not to this helper (1) or ``get_attr`` itself (2).
            warnings.warn(f"Failed to fetch module {module_path}!", stacklevel=3)
            return None

        # Registered buffers are stored in ``_buffers``. A buffer registered
        # with a ``None`` value is returned as ``None``.
        if name in submod._buffers:
            return submod._buffers[name]

        # Parameters and plain (unregistered) tensor attributes are both
        # ``torch.Tensor`` instances. Anything else is not a usable value.
        attr = getattr(submod, name, None)
        if isinstance(attr, torch.Tensor):
            return attr

        return None
86*523fa7a6SAndroid Build Coastguard Worker        return node
87