# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
from typing import Any, cast, Optional, Tuple

import executorch.exir as exir
import torch

from executorch.backends.xnnpack.utils.configs import (
    get_transform_passes,
    get_xnnpack_capture_config,
    get_xnnpack_edge_compile_config,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)


### XNNPACK Capture ###
def capture_graph_for_xnnpack(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor, ...],
    enable_aot: Optional[bool] = None,
    unlift: Optional[bool] = None,
) -> exir.ExirExportedProgram:
    """
    Captures `module` with `exir.capture`, converts it to the edge dialect
    using the XNNPACK edge compile config, and applies the XNNPACK transform
    passes.

    Args:
        module: the model to capture
        inputs: example inputs forwarded to `exir.capture`
        enable_aot: forwarded to `get_xnnpack_capture_config`
        unlift: forwarded to `get_xnnpack_capture_config`

    Returns:
        the transformed edge program
    """
    return (
        exir.capture(
            module,
            inputs,
            get_xnnpack_capture_config(enable_aot=enable_aot, unlift=unlift),
        )
        .to_edge(get_xnnpack_edge_compile_config())
        .transform(*get_transform_passes())
    )


### XNNPACK Utils ###
# Dimension permutation that reorders a 4-dim NCHW layout into NHWC.
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]
# Dimension permutation that reorders a 4-dim NHWC layout back into NCHW.
PERM_NHWC_TO_NCHW = [0, 3, 1, 2]


def check_or_raise(condition: bool, err: str) -> None:
    """
    Raises runtime error if condition is false, with the given error message

    Args:
        condition: boolean condition to check
        err: error message to raise if condition is not true

    Raises:
        RuntimeError: if condition is falsy
    """
    if not condition:
        raise RuntimeError(err)


def is_node(node: Any) -> bool:
    """
    returns true if node is a torch.fx.Node, otherwise false
    """
    return isinstance(node, torch.fx.Node)


def is_getitem(node: torch.fx.Node) -> bool:
    """
    Returns true if node is a call_function whose target is named "getitem".

    Targets without a `__name__` attribute (e.g. some callable objects) are
    treated as "not getitem" instead of raising AttributeError.
    """
    if node.op != "call_function":
        return False

    return getattr(node.target, "__name__", None) == "getitem"


def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
    """
    Returns `node.args[input_index]`, typed as a torch.fx.Node.

    The cast is static only; no runtime check is made that the argument is
    actually a Node.
    """
    return cast(torch.fx.Node, node.args[input_index])


def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Checks if the current node is only consumed by a relu node and can be fused,
    if so, we return the relu node that can be fused, otherwise return None
    """
    if len(node.users) != 1:
        return None

    # Exactly one user exists, so take it directly.
    only_user = next(iter(node.users))
    if only_user.target == exir_ops.edge.aten.relu.default:
        return only_user

    return None


def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node is a torch.fx.Node whose op is "get_attr"
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Returns true if node is a get_attr node, or is a parameter, buffer, or
    lifted tensor constant of `exp_prog`.
    """
    return (
        is_get_attr_node(node)
        or is_param(exp_prog, node)
        or is_buffer(exp_prog, node)
        or is_lifted_tensor_constant(exp_prog, node)
    )


def get_param_tensor(
    exp_prog: ExportedProgram, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    Returns the tensor backing a parameter-like node.

    Lookup order: parameter, buffer, lifted tensor constant of `exp_prog`,
    then get_attr node (resolved on the node's owning module, falling back to
    `exp_prog.graph_module`).

    Args:
        exp_prog: the program that owns the parameters/buffers/constants
        node: the node to resolve; None is passed through

    Returns:
        the resolved tensor, or None if node is None

    Raises:
        RuntimeError: if node is none of the supported kinds
        AttributeError: if a get_attr target resolves on neither module
    """
    if node is None:
        return None
    elif is_param(exp_prog, node):
        return get_param(exp_prog, node)
    elif is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    elif is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    elif is_get_attr_node(node):
        # get_attr targets may be dotted qualified names (e.g. "sub.weight");
        # plain getattr() cannot resolve those, attrgetter walks each part.
        resolve = operator.attrgetter(cast(str, node.target))
        # This is a hack to support both lifted and unlifted graph: try the
        # node's own module first, then the exported program's module.
        try:
            return resolve(node.graph.owning_module)
        except AttributeError:
            return resolve(exp_prog.graph_module)
    raise RuntimeError(f"unsupported param type, {node.op}.")


def get_source_fn(node: torch.fx.Node) -> Optional[Any]:
    """
    Returns the source fn of the given node: the second element of the last
    entry of `node.meta["source_fn_stack"]`.

    Returns None if the node is not a call_function, or if its
    source_fn_stack is missing or empty.
    """
    if node.op != "call_function":
        return None

    source_fn_st = node.meta.get("source_fn_stack", None)
    if not source_fn_st:
        # Missing or empty stack: nothing to report (avoids IndexError).
        return None

    source_fn = source_fn_st[-1]
    return source_fn[1]