# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

# pyre-unsafe
from typing import cast, Dict

import numpy as np
import serializer.tosa_serializer as ts
import torch
import torch.fx
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    get_quant_arg_upstream,
    get_quantized_node_output_dtype,
    is_node_quantized,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import (
    getNodeArgs,
    is_bias_node_for_quantized_conv,
    tosa_shape,
)
from torch.export.exported_program import ExportedProgram


def process_call_function(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    node_visitors: Dict[str, NodeVisitor],
    tosa_spec: TosaSpecification,
) -> None:
    """Serialize a call_function node through its registered node visitor.

    The node's output tensor is first added to the current basic block of
    ``tosa_graph``; the node is then handed to the visitor registered under
    ``node.target.__name__``.

    Args:
        node: The call_function node to serialize.
        tosa_graph: Serializer that receives the output tensor and operator.
        node_visitors: Mapping from operator name to the visitor handling it.
        tosa_spec: TOSA specification, only used in the error message.

    Raises:
        RuntimeError: If no visitor is registered for the node's target.
    """
    # Unpack arguments and convert
    inputs = getNodeArgs(node)

    # Convert output (this node itself)
    output = TosaArg(node)

    # Quantized nodes take their dtype from the quantization info instead of
    # the dtype recorded on the TosaArg.
    is_quant_node = is_node_quantized(node)
    if is_quant_node:
        output_dtype = map_dtype(get_quantized_node_output_dtype(node))
    else:
        output_dtype = output.dtype
    tosa_graph.currRegion.currBasicBlock.addTensor(
        output.name,
        tosa_shape(output.shape, output.dim_order),
        output_dtype,
    )

    # Visiting each Node
    # pyre-ignore[16]: Undefined attribute.
    op_name = node.target.__name__
    if op_name not in node_visitors:
        raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")

    node_visitors[op_name].define_node(
        node,
        tosa_graph,
        inputs,
        output,
        is_quant_node,
    )


def process_inputs(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    tosa_spec: TosaSpecification,
) -> None:
    """Serialize an input node.

    Args:
        node: Placeholder node for a user input; ``node.meta["val"]`` is read
            to validate the dim order.
        tosa_graph: Serializer that receives the input tensor.
        tosa_spec: TOSA specification (not used by this function).

    Raises:
        RuntimeError: If the input's dim order is not ``(0, 1, ..., n-1)``.
    """
    # inputs need to be in default dim_order (contiguous memory format)
    meta = node.meta["val"]
    expected_dim_order = tuple(range(meta.dim()))
    actual_dim_order = meta.dim_order()
    if actual_dim_order != expected_dim_order:
        raise RuntimeError(
            "Arm backend only supports contiguous memory format for inputs. "
            f"Expected dim_order: {expected_dim_order}, but got: {actual_dim_order} for node {node.name}"
        )

    input_arg = TosaArg(node)
    if is_node_quantized(node):
        input_dtype = map_dtype(get_quantized_node_output_dtype(node))
    else:
        input_dtype = input_arg.dtype

    tensor = ts.TosaSerializerTensor(
        input_arg.name,
        tosa_shape(input_arg.shape, input_arg.dim_order),
        input_dtype,
        data=None,
        placeholderFilename=input_arg.name + ".npy",
    )
    tosa_graph.addInputTensor(tensor)


def process_quantized_bias(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    parameter_values,
) -> None:
    """Serialize bias node that needs to be quantized.

    The bias values are divided by the product of the input scale and the
    weight scale of the first user of ``node``, rounded, cast to int32 and
    added to ``tosa_graph`` as an INT32 constant named after ``node``.

    Args:
        node: The bias placeholder node.
        tosa_graph: Serializer that receives the quantized constant.
        parameter_values: Bias values; must support division by a scalar and
            ``.round().astype(...)`` (the caller in this file passes a numpy
            array).
    """
    # Only the first user is inspected; it must have exactly three input
    # nodes, of which the third is ignored here.
    consumer_node = list(node.users)[0]
    input_node, weight_node, _ = consumer_node.all_input_nodes

    input_node_scale = get_quant_arg_upstream(input_node).scale
    weight_node_scale = get_quant_arg_upstream(weight_node).scale
    bias_scale = input_node_scale * weight_node_scale
    bias_values_quantized = (parameter_values / bias_scale).round().astype(np.int32)

    tosa_graph.addConst(
        bias_values_quantized.shape,
        ts.DType.INT32,
        bias_values_quantized,
        name=node.name,
    )


def process_inputs_to_parameters(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
) -> None:
    """Serialize bias and non-quantized weights.

    Args:
        node: Placeholder node listed in
            ``edge_program.graph_signature.inputs_to_parameters``.
        tosa_graph: Serializer that receives the constant.
        edge_program: Exported program holding the parameter in its
            ``state_dict``.
        tosa_spec: TOSA specification used to check integer/float support.

    Raises:
        AssertionError: If the parameter is not a ``torch.Tensor`` or the
            specification lacks the required integer/float support.
    """
    inputs = [TosaArg(node)]
    parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
    parameter_data = edge_program.state_dict[parameter_name]

    assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
    parameter_values = parameter_data.detach().numpy()

    if is_bias_node_for_quantized_conv(node):
        # BI bias
        assert tosa_spec.support_integer(), f"{tosa_spec} doesn't support integer"
        process_quantized_bias(node, tosa_graph, parameter_values)
    else:
        # MI weights or bias
        # NOTE(review): inputs[0].dtype is passed to addConst below, where a
        # ts.DType is used elsewhere in this file; confirm that comparing it
        # against torch.float32 can actually be true.
        if inputs[0].dtype == torch.float32:
            assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

        # Reorder the axes according to the node's dim order before storing.
        parameter_values = np.transpose(parameter_values, inputs[0].dim_order)

        tosa_graph.addConst(
            parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
        )


def process_inputs_to_buffers(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
) -> None:
    """Serialize quantized weights.

    Args:
        node: Placeholder node listed in
            ``edge_program.graph_signature.inputs_to_buffers``.
        tosa_graph: Serializer that receives the constant.
        edge_program: Exported program holding the buffer in its
            ``state_dict``.

    Raises:
        AssertionError: If the buffer is not a ``torch.Tensor``.
    """
    inputs = [TosaArg(node)]
    buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
    buffer_data = edge_program.state_dict[buffer_name]

    assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor"
    buffer_values = buffer_data.detach().numpy()

    # TODO: fragile code for temporary fix
    # the mean and var tensors are also stored here but they have shape (1, )
    # The transpose below is applied to every buffer unconditionally; it is
    # only meant to reorder weights.
    buffer_values = np.transpose(buffer_values, inputs[0].dim_order)

    tosa_graph.addConst(
        buffer_values.shape, inputs[0].dtype, buffer_values, name=node.name
    )


def process_inputs_to_lifted_tensor_constants(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
) -> None:
    """Serialize a lifted tensor constant as a TOSA constant.

    Unlike parameters and buffers, the data is stored without transposing.

    Args:
        node: Placeholder node listed in
            ``edge_program.graph_signature.inputs_to_lifted_tensor_constants``.
        tosa_graph: Serializer that receives the constant.
        edge_program: Exported program holding the tensor in its
            ``tensor_constants``.
    """
    arg = TosaArg(node)
    tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
        arg.name
    ]
    tensor = edge_program.tensor_constants[tensor_name]
    tensor_data = tensor.detach().numpy()

    tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name)


def process_placeholder(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
) -> None:
    """Wrapper for processing and serializing all types of placeholders.

    The placeholder kind is determined by looking up ``node.name`` in the
    graph signature of ``edge_program``; the first matching kind wins.

    Raises:
        AssertionError: If the node's name and target differ, or the node has
            arguments (default input values).
        NotImplementedError: If the placeholder is a lifted custom object.
        RuntimeError: If the placeholder matches no known kind.
    """
    assert node.name == node.target, "Expect placeholder name and target to match"
    assert 0 == len(node.args), "Can't handle default input values"

    signature = edge_program.graph_signature
    if node.name in signature.user_inputs:
        process_inputs(node, tosa_graph, tosa_spec)
    elif node.name in signature.inputs_to_parameters:
        process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
    elif node.name in signature.inputs_to_buffers:
        process_inputs_to_buffers(node, tosa_graph, edge_program)
    elif node.name in signature.inputs_to_lifted_tensor_constants:
        process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
    elif node.name in signature.inputs_to_lifted_custom_objs:
        raise NotImplementedError(
            "Placeholder is of type 'lifted custom object' which is not supported."
        )
    else:
        raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")


def process_output(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
) -> None:
    """Mark every node listed in ``node.args[0]`` as a graph output.

    Each output is looked up by name among the tensors of the current basic
    block, so it must already have been added to ``tosa_graph``.
    """
    for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
        tosa_graph.addOutputTensor(
            tosa_graph.currRegion.currBasicBlock.tensors[output.name]
        )