# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import warnings
from typing import Any, Optional

import torch
import torch.fx as fx
from executorch.exir.tensor import TensorSpec


class ExportGraph:
    """
    Thin wrapper that sits between EXIR and the FX Graph API.

    Every operation is forwarded to the wrapped ``fx.Graph``. The wrapper
    exists so EXIR-specific invariants can be applied on the way through;
    in this class that means ``get_attr`` attaches a ``"spec"`` entry to the
    metadata of the node it creates whenever the referenced attribute
    resolves to a buffer or parameter of ``owning_module``.
    """

    # Module that is searched when resolving attribute names in ``get_attr``.
    owning_module: fx.GraphModule
    # Wrapped FX graph that receives every forwarded call.
    _graph: fx.Graph

    def __init__(self, owning_module: fx.GraphModule, graph: fx.Graph) -> None:
        """
        Store the wrapped graph and the module used for attribute lookups.

        Args:
            owning_module: Module searched by ``get_attr`` for buffers and
                parameters.
            graph: FX graph that all other methods delegate to.
        """
        self._graph = graph
        self.owning_module = owning_module

    @property
    def nodes(self) -> fx.graph._node_list:
        """
        Return the node list of the wrapped graph.
        """
        return self._graph.nodes

    def erase_node(self, to_erase: fx.Node) -> None:
        """
        Remove ``to_erase`` from the wrapped graph.

        This forwards to ``fx.Graph.erase_node``, which raises if the node
        still has users in the graph.
        """
        return self._graph.erase_node(to_erase)

    def inserting_before(self, n: Optional[fx.Node] = None) -> fx.graph._InsertPoint:
        """
        Set the insertion point of the wrapped graph to just before ``n``.

        This forwards to ``fx.Graph.inserting_before`` and returns its result
        unchanged.
        """
        return self._graph.inserting_before(n)

    # pyre-ignore
    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> fx.Node:
        """
        Insert a ``get_attr`` node and, when possible, tag it with a spec.

        The node is created by the wrapped graph. The attribute named by
        ``qualified_name`` is then looked up on ``owning_module``; if it is a
        non-``None`` buffer or an ``nn.Parameter``, ``node.meta["spec"]`` is
        set from it. Otherwise the node is returned without a ``"spec"``
        entry.

        Args:
            qualified_name: Dotted path of the attribute. Everything before
                the last ``.`` names the submodule, the remainder names the
                attribute on that submodule.
            type_expr: Forwarded unchanged to ``fx.Graph.get_attr``.

        Returns:
            The newly created ``get_attr`` node.

        Warns:
            Emits a warning (and skips the spec) when the submodule part of
            ``qualified_name`` cannot be resolved on ``owning_module``.
        """
        node = self._graph.get_attr(qualified_name, type_expr)

        # Split "a.b.c" into the submodule path "a.b" and the attribute "c".
        # A name without dots yields an empty module path.
        module_path, _, attr_name = qualified_name.rpartition(".")
        try:
            submodule: torch.nn.Module = self.owning_module.get_submodule(module_path)
        except AttributeError:
            warnings.warn(f"Failed to fetch module {module_path}!", stacklevel=1)
            # Nothing to derive a spec from; hand back the bare node.
            return node

        tensor: Optional[torch.Tensor] = None
        if attr_name in submodule._buffers:
            # A registered buffer wins, even when its stored value is None;
            # in that case no parameter lookup is attempted.
            tensor = submodule._buffers[attr_name]
        elif hasattr(submodule, attr_name):
            # Only parameters qualify; any other attribute type is ignored.
            candidate = getattr(submodule, attr_name)
            if isinstance(candidate, torch.nn.Parameter):
                tensor = candidate

        if tensor is not None:
            # NOTE(review): the positional ``True`` is passed through as-is;
            # presumably it marks the spec as constant -- confirm against
            # TensorSpec.from_tensor.
            node.meta["spec"] = TensorSpec.from_tensor(tensor, True)
        return node