# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import warnings
from typing import Any, Optional

import torch
import torch.fx as fx
from executorch.exir.tensor import TensorSpec


def _lookup_tensor_attr(
    root: torch.nn.Module, qualified_name: str
) -> Optional[torch.Tensor]:
    """
    Resolve a dotted attribute path on ``root`` to a buffer or parameter.

    The path is split at its last ``.`` into an owning-submodule path and an
    attribute name. The return value is:

    * the stored buffer value, if the name is a registered buffer of the
      owning submodule (this value may itself be ``None``);
    * the attribute, if it is an ``nn.Parameter``;
    * ``None`` otherwise.

    If the owning submodule cannot be found, a warning is emitted and
    ``None`` is returned.
    """
    owner_path, _, attr_name = qualified_name.rpartition(".")

    try:
        owner: torch.nn.Module = root.get_submodule(owner_path)
    except AttributeError:
        warnings.warn(f"Failed to fetch module {owner_path}!", stacklevel=1)
        return None

    # Registered buffers are checked first and returned as stored.
    buffers = owner._buffers
    if attr_name in buffers:
        return buffers[attr_name]

    # Otherwise accept the attribute only when it is an nn.Parameter.
    candidate = getattr(owner, attr_name, None)
    return candidate if isinstance(candidate, torch.nn.Parameter) else None


class ExportGraph:
    """
    Wrapper that sits between EXIR and the ``torch.fx`` graph API.

    Most operations are forwarded directly to the wrapped ``fx.Graph``. Some
    operations also maintain EXIR-specific node metadata. For example,
    ``get_attr`` attaches a ``spec`` when the referenced value can be resolved.
    """

    owning_module: fx.GraphModule
    _graph: fx.Graph

    def __init__(self, owning_module: fx.GraphModule, graph: fx.Graph) -> None:
        self.owning_module, self._graph = owning_module, graph

    @property
    def nodes(self) -> fx.graph._node_list:
        """
        The node list of the wrapped graph.
        """
        return self._graph.nodes

    def erase_node(self, to_erase: fx.Node) -> None:
        """
        Remove ``to_erase`` from the wrapped graph.

        This delegates to ``fx.Graph.erase_node``, which refuses to remove a
        node that still has users.
        """
        return self._graph.erase_node(to_erase)

    def inserting_before(self, n: Optional[fx.Node] = None) -> fx.graph._InsertPoint:
        """
        Return the wrapped graph's insert-point context for placing new nodes
        before ``n``.
        """
        return self._graph.inserting_before(n)

    # pyre-ignore
    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> fx.Node:
        """
        Add a ``get_attr`` node for ``qualified_name`` to the wrapped graph.

        If ``qualified_name`` resolves on ``owning_module`` to a non-``None``
        buffer or ``nn.Parameter``, the new node's ``meta["spec"]`` is set from
        that tensor. The node is returned in either case.
        """
        new_node = self._graph.get_attr(qualified_name, type_expr)

        value = _lookup_tensor_attr(self.owning_module, qualified_name)
        if value is None:
            return new_node

        # NOTE(review): the second positional argument is presumably a
        # "const" flag on TensorSpec.from_tensor; confirm against TensorSpec.
        new_node.meta["spec"] = TensorSpec.from_tensor(value, True)
        return new_node