# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-strict from types import FunctionType as function from typing import Dict, List, Tuple, Union import torch LeafValue = Union[ torch.Tensor, str, int, float, bool, complex, torch.dtype, torch.device, torch.memory_format, torch.layout, None, ] # We maintain a global cache of op lookups as this significantly speeds up # deserialization because hasattr(torch.ops, name) is an expensive call. _cache_ops_dict: Dict[ Tuple[str, str], Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket] ] = {} _cache_fake_ops_dict: Dict[Tuple[str, str], function] = {} def _get_submodule( graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int ) -> Tuple[str, torch.nn.Module, torch.fx.Node]: submod_node = node.args[arg_index] assert isinstance(submod_node, torch.fx.Node) assert submod_node.op == "get_attr" assert isinstance(submod_node.target, str) submodule = graph_module.get_submodule(submod_node.target) # pyre-ignore return submod_node.target, submodule, node def get_control_flow_submodules( graph_module: torch.fx.GraphModule, ) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: """ Returns a list of submodules used for control flow operations (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look into submodules). Specifically, the returned value is a list containing a tuple of (name of the submodule that's stored in the graph module, the submodule itself, and the fx node that uses this submodule). """ control_flow_submodules = [] for node in graph_module.graph.nodes: if node.op != "call_function": continue if node.target is torch.ops.higher_order.cond: control_flow_submodules.append(_get_submodule(graph_module, node, 1)) control_flow_submodules.append(_get_submodule(graph_module, node, 2)) if node.target is torch.ops.higher_order.map_impl: control_flow_submodules.append(_get_submodule(graph_module, node, 0)) return control_flow_submodules