# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Optional

import torch
import torch.fx

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor


def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a ``torch.fx.Node`` whose op is ``"get_attr"``.

    Intended for nodes that fetch a tensor of the model, but only the node's
    op is checked; the referenced attribute itself is not inspected.
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` refers to constant data of the model.

    This holds for ``get_attr`` nodes (unlifted graphs) and for nodes that
    ``exp_prog`` classifies as a parameter, buffer or lifted tensor constant
    (lifted graphs).
    """
    return (
        is_get_attr_node(node)
        or is_param(exp_prog, node)
        or is_buffer(exp_prog, node)
        or is_lifted_tensor_constant(exp_prog, node)
    )


def get_param_tensor(
    exp_prog: ExportedProgram, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    Return the constant data backing ``node``.

    Lookup order: parameter, buffer, lifted tensor constant, then the target
    of a ``get_attr`` node.

    Args:
        exp_prog: Exported program whose parameters, buffers and constants
            are consulted.
        node: Node to resolve. ``None`` is accepted and yields ``None``.

    Returns:
        The value backing ``node``, or ``None`` when ``node`` is ``None``.

    Raises:
        RuntimeError: If ``node`` is none of the supported kinds.
        AttributeError: If a ``get_attr`` target exists on neither the node's
            owning module nor ``exp_prog.graph_module``.
    """
    if node is None:
        return None
    if is_param(exp_prog, node):
        return get_param(exp_prog, node)
    if is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    if is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    if is_get_attr_node(node):
        # HACK: support both lifted and unlifted graphs. Look the attribute up
        # on the module owning the node's graph first, and fall back to the
        # exported program's graph module if it is not found there.
        try:
            return getattr(node.graph.owning_module, node.target)
        except AttributeError:
            return getattr(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")


def create_node(
    graph: torch.fx.Graph,
    op_target: OpOverload,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    quantize: bool = False,
    q_params: Optional[tuple] = None,
):
    """
    Add a ``call_function`` node that calls ``op_target`` to ``graph``.

    The node is placed at the graph's current insertion point. Wrap the call
    in ``graph.inserting_before(...)`` / ``graph.inserting_after(...)`` to
    control where it goes.

    Args:
        graph: Graph to add the node to.
        op_target: Operator the new node calls.
        args: Positional arguments of the new node.
        kwargs: Keyword arguments of the new node. ``None`` means no kwargs.
        quantize: If true and ``q_params`` is non-empty (truthy), a
            quantize/dequantize pair is inserted after the new node via
            ``insert_q_dq_pair``.
        q_params: Arguments appended after the tensor input of the inserted
            quantize and dequantize nodes.

    Returns:
        The dequantize node when a q/dq pair was inserted, otherwise the newly
        created node.
    """
    node = graph.create_node(
        "call_function",
        op_target,
        args=args,
        kwargs=kwargs or {},
    )
    if quantize and q_params:
        return insert_q_dq_pair(graph, node, q_params)
    return node


def insert_q_dq_pair(
    graph: torch.fx.Graph,
    anchor: torch.fx.Node,
    q_params: tuple,
):
    """
    Insert a quantize -> dequantize node pair directly after ``anchor``.

    All previous users of ``anchor`` are rerouted to the dequantize node,
    giving ``anchor -> quantize -> dequantize -> <former users of anchor>``.

    Args:
        graph: Graph containing ``anchor``.
        anchor: Node whose output gets the q/dq pair.
        q_params: Arguments appended after the tensor input of both the
            ``quantize_per_tensor`` and the ``dequantize_per_tensor`` node.

    Returns:
        The newly created dequantize node.
    """
    with graph.inserting_after(anchor):
        q = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(),  # The input is wired up last, see below.
        )
        # Copy instead of aliasing. If anchor, q and dq shared one meta dict,
        # an in-place meta edit on any of them would leak into the others.
        q.meta = anchor.meta.copy()
    with graph.inserting_after(q):
        dq = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(q,) + q_params,
        )
        dq.meta = q.meta.copy()
    anchor.replace_all_uses_with(dq)
    # Set q's input only after the rerouting above. Had q already used anchor,
    # replace_all_uses_with would have redirected it to dq (a q <-> dq cycle).
    q.args = (anchor,) + q_params
    return dq


def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
    """
    Return the FakeTensor stored in ``node.meta["val"]``.

    If ``meta["val"]`` is a tuple or an fx ``immutable_list`` (the node holds
    several fake tensors), its first element is returned.

    Raises:
        KeyError: If ``node.meta`` has no ``"val"`` entry.
        IndexError: If ``meta["val"]`` is an empty tuple/immutable_list.
        AssertionError: If the selected value is not a FakeTensor. This check
            is skipped when Python runs with ``-O``.
    """
    val = node.meta["val"]
    if isinstance(val, (tuple, torch.fx.immutable_collections.immutable_list)):
        fake_tensor = val[0]
    else:
        fake_tensor = val

    assert isinstance(
        fake_tensor, FakeTensor
    ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
    return fake_tensor