# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, cast, Optional, Tuple

import executorch.exir as exir
import torch

from executorch.backends.xnnpack.utils.configs import (
    get_transform_passes,
    get_xnnpack_capture_config,
    get_xnnpack_edge_compile_config,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)


### XNNPACK Capture ###
def capture_graph_for_xnnpack(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor, ...],
    enable_aot: Optional[bool] = None,
    unlift: Optional[bool] = None,
) -> exir.ExirExportedProgram:
    """
    Captures `module` with the XNNPACK capture config, lowers it to the edge
    dialect with the XNNPACK edge compile config, and applies the XNNPACK
    transform passes.

    Args:
        module: the module to capture
        inputs: example inputs used for capture (any number of tensors)
        enable_aot: forwarded to get_xnnpack_capture_config
        unlift: forwarded to get_xnnpack_capture_config

    Returns:
        The transformed edge-dialect program.
    """
    return (
        exir.capture(
            module,
            inputs,
            get_xnnpack_capture_config(enable_aot=enable_aot, unlift=unlift),
        )
        .to_edge(get_xnnpack_edge_compile_config())
        .transform(*get_transform_passes())
    )


### XNNPACK Utils ###
# Permutation orders for converting between NCHW and NHWC dim layouts.
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]
PERM_NHWC_TO_NCHW = [0, 3, 1, 2]


def check_or_raise(condition: bool, err: str) -> None:
    """
    Raises runtime error if condition is false, with the given error message

    Args:
        condition: boolean condition to check
        err: error message to raise if condition is not true
    """
    if not condition:
        raise RuntimeError(err)


def is_node(node: Any) -> bool:
    """
    returns true if node is a torch.fx.Node, otherwise false
    """
    return isinstance(node, torch.fx.Node)


def is_getitem(node: torch.fx.Node) -> bool:
    """
    Returns true if `node` is a call_function node whose target is named
    "getitem", otherwise false.
    """
    if node.op != "call_function":
        return False

    # Some callable targets (e.g. callable objects) have no __name__; treat
    # those as "not getitem" instead of raising AttributeError.
    return getattr(node.target, "__name__", None) == "getitem"


def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
    """
    Returns the positional argument of `node` at `input_index`, typed as an
    fx Node. No runtime check is done; the caller must know that argument is
    a Node.
    """
    return cast(torch.fx.Node, node.args[input_index])


def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Checks if the current node is only consumed by a relu node and can be fused,
    if so, we return the relu node that can be fused, otherwise return None
    """
    if len(node.users) != 1:
        return None

    # Exactly one user exists; fetch it without materializing a list.
    only_user = next(iter(node.users))
    if only_user.target == exir_ops.edge.aten.relu.default:
        return only_user

    return None


def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node is a get attr node for a tensor of the model
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Returns true if `node` refers to a constant tensor of the program: a
    get_attr node, a parameter, a buffer, or a lifted tensor constant.
    """
    return (
        is_get_attr_node(node)
        or is_param(exp_prog, node)
        or is_buffer(exp_prog, node)
        or is_lifted_tensor_constant(exp_prog, node)
    )


def get_param_tensor(
    exp_prog: ExportedProgram, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    Returns the tensor backing `node`, looked up as a parameter, buffer,
    lifted tensor constant, or get_attr target (checked in that order).

    Args:
        exp_prog: the exported program that owns the constants
        node: the node to resolve; may be None

    Returns:
        The resolved tensor, or None when `node` is None.

    Raises:
        RuntimeError: if `node` is not any of the supported constant kinds.
    """
    if node is None:
        return None
    elif is_param(exp_prog, node):
        return get_param(exp_prog, node)
    elif is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    elif is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    elif is_get_attr_node(node):
        # This is a hack to support both lifted and unlifted graph: look up
        # the attribute on the node's owning module first, then fall back to
        # the exported program's graph module.
        try:
            return getattr(node.graph.owning_module, node.target)
        except AttributeError:
            return getattr(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")


def get_source_fn(node: torch.fx.Node) -> Optional[Any]:
    """
    Returns the source fn of the given node, return None if something goes wrong

    The source fn is the second element of the last entry in the node's
    "source_fn_stack" metadata. None is returned for non call_function nodes
    and when that metadata is missing or empty.
    """
    if node.op != "call_function":
        return None

    source_fn_st = node.meta.get("source_fn_stack", None)
    # Guard against both a missing and an empty stack; indexing [-1] on an
    # empty list would raise IndexError instead of returning None.
    if not source_fn_st:
        return None

    source_fn = source_fn_st[-1]
    return source_fn[1]