# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import warnings
from typing import Any, Optional

import torch
import torch.fx as fx
from executorch.exir.tensor import TensorSpec


class ExportGraph:
    """
    ExportGraph serves as a layer between EXIR and FX Graph API.
    It enforces EXIR-specific invariants (e.g. having nodes contain specs).

    Most methods forward directly to the wrapped ``fx.Graph``; ``get_attr``
    additionally attaches a ``spec`` entry to the created node's metadata
    when the referenced attribute resolves to a tensor.
    """

    # Module that owns ``_graph``; used to resolve ``get_attr`` targets.
    owning_module: fx.GraphModule
    # The underlying FX graph that all operations are forwarded to.
    _graph: fx.Graph

    def __init__(self, owning_module: fx.GraphModule, graph: fx.Graph) -> None:
        self.owning_module = owning_module
        self._graph = graph

    @property
    def nodes(self) -> fx.graph._node_list:
        """
        Get the list of Nodes that constitute this Graph.
        """
        return self._graph.nodes

    def erase_node(self, to_erase: fx.Node) -> None:
        """
        Erases a ``Node`` from the ``Graph``. Throws an exception if
        there are still users of that node in the ``Graph``.
        """
        return self._graph.erase_node(to_erase)

    def inserting_before(self, n: Optional[fx.Node] = None) -> fx.graph._InsertPoint:
        """
        Sets the point at which new nodes will be inserted into the graph.

        Forwards to ``fx.Graph.inserting_before`` and returns its insert-point
        object, which is intended to be used as a context manager
        (``with graph.inserting_before(node): ...``). Per the torch.fx API,
        passing ``None`` inserts at the beginning of the graph.
        """
        return self._graph.inserting_before(n)

    @staticmethod
    def _maybe_get_attr_value(
        mod: torch.nn.Module, qualified_name: str
    ) -> Optional[torch.Tensor]:
        """
        Look up the tensor that ``qualified_name`` refers to on ``mod``.

        Returns the registered buffer or ``torch.nn.Parameter`` with that
        name. Returns ``None`` in three cases:
        - the submodule path cannot be fetched (a warning is emitted);
        - the name is neither a registered buffer nor a Parameter;
        - the name is a buffer that was registered as ``None``.
        """
        # "a.b.c" -> submodule path "a.b", attribute "c". A bare name yields
        # an empty path, which get_submodule resolves to ``mod`` itself.
        module_path, _, name = qualified_name.rpartition(".")

        try:
            submod: torch.nn.Module = mod.get_submodule(module_path)
        except AttributeError:
            warnings.warn(f"Failed to fetch module {module_path}!", stacklevel=1)
            return None

        # Registered buffers take precedence.
        if name in submod._buffers:
            return submod._buffers[name]

        # Otherwise accept the attribute only if it is a Parameter; plain
        # (unregistered) tensor attributes are not resolved here.
        attr = getattr(submod, name, None)
        if isinstance(attr, torch.nn.Parameter):
            return attr

        return None

    # pyre-ignore
    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> fx.Node:
        """
        Inserts a ``get_attr`` node into the Graph.

        If ``qualified_name`` resolves to a buffer or ``torch.nn.Parameter``
        of ``owning_module``, the new node's ``meta["spec"]`` is set to a
        ``TensorSpec`` built from that tensor. Otherwise no spec is attached.

        Args:
            qualified_name: Dotted attribute path, e.g. ``"sub.weight"``.
            type_expr: Optional type annotation forwarded to
                ``fx.Graph.get_attr``.

        Returns:
            The newly created ``get_attr`` node.
        """
        node = self._graph.get_attr(qualified_name, type_expr)

        # Fetch the actual tensor (if any) so the node can carry spec metadata.
        value = self._maybe_get_attr_value(self.owning_module, qualified_name)
        if value is not None:
            # NOTE(review): the positional ``True`` is presumably TensorSpec's
            # ``const`` flag, marking this tensor as a constant - confirm.
            node.meta["spec"] = TensorSpec.from_tensor(value, True)

        return node