# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

# pyre-unsafe
from typing import cast, Dict

import numpy as np
import serializer.tosa_serializer as ts
import torch
import torch.fx
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    get_quant_arg_upstream,
    get_quantized_node_output_dtype,
    is_node_quantized,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import (
    getNodeArgs,
    is_bias_node_for_quantized_conv,
    tosa_shape,
)
from torch.export.exported_program import ExportedProgram


def process_call_function(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    node_visitors: Dict[str, NodeVisitor],
    tosa_spec: TosaSpecification,
):
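    """Serialize a call_function node: convert its arguments and output to
    TosaArgs, register the output tensor in the TOSA graph, and dispatch to
    the NodeVisitor registered for the node's target operator."""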
    # Unpack the node's arguments and convert them to TosaArgs
    inputs = getNodeArgs(node)

    # Convert the output (this node itself)
    output = TosaArg(node)

    is_quant_node = is_node_quantized(node)
    if is_quant_node:
        output_dtype = map_dtype(get_quantized_node_output_dtype(node))
    else:
        output_dtype = output.dtype
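    # Declare the node's output tensor in the current basic block before the
    # operator that produces it is defined.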
    tosa_graph.currRegion.currBasicBlock.addTensor(
        output.name,
        tosa_shape(output.shape, output.dim_order),
        output_dtype,
    )

    # Dispatch to the visitor registered for this node's target operator
    # pyre-ignore[16]: Undefined attribute.
    if node.target.__name__ in node_visitors:
        # pyre-ignore[16]: Undefined attribute.
        node_visitors[node.target.__name__].define_node(
            node,
            tosa_graph,
            inputs,
            output,
            is_quant_node,
        )
    else:
        raise RuntimeError(f"Unknown operator {node.target} for TOSA spec {tosa_spec}")


def process_inputs(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    tosa_spec: TosaSpecification,
):
    """Serialize an input node"""
    # Inputs need to be in the default dim_order (contiguous memory format)
    meta = node.meta["val"]
    if meta.dim_order() != tuple(range(meta.dim())):
        raise RuntimeError(
            f"Arm backend only supports contiguous memory format for inputs. "
            f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
        )
    inputs = [TosaArg(node)]
    input_shape = inputs[0].shape
    input_dim_order = inputs[0].dim_order
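    # The input tensor is declared with no data and a .npy placeholder filename:
    # its values are supplied separately via that file rather than being embedded
    # in the serialized TOSA graph.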
    tensor = ts.TosaSerializerTensor(
        inputs[0].name,
        tosa_shape(input_shape, input_dim_order),
        (
            map_dtype(get_quantized_node_output_dtype(node))
            if is_node_quantized(node)
            else inputs[0].dtype
        ),
        data=None,
        placeholderFilename=inputs[0].name + ".npy",
    )
    tosa_graph.addInputTensor(tensor)


def process_quantized_bias(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    parameter_values,
):
    """
    Serialize a bias node that needs to be quantized.
    """
    consumer_node = list(node.users)[0]
    (
        input_node,
        weight_node,
        _,
    ) = consumer_node.all_input_nodes

    input_node_scale = get_quant_arg_upstream(input_node).scale
    weight_node_scale = get_quant_arg_upstream(weight_node).scale
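    # The bias is requantized to int32 with scale input_scale * weight_scale,
    # matching the scale of the consumer convolution's int32 accumulator.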
    bias_values_quantized = (
        (parameter_values / (input_node_scale * weight_node_scale))
        .round()
        .astype(np.int32)
    )

    tosa_graph.addConst(
        bias_values_quantized.shape,
        ts.DType.INT32,
        bias_values_quantized,
        name=node.name,
    )


def process_inputs_to_parameters(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Serialize bias and non-quantized weights"""
    inputs = [TosaArg(node)]
    parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
    parameter_data = edge_program.state_dict[parameter_name]

    assert isinstance(parameter_data, torch.Tensor), "Expect parameter to be a tensor"
    parameter_values = parameter_data.detach().numpy()

    if is_bias_node_for_quantized_conv(node):
        # BI (base inference) profile: integer bias for a quantized convolution
        assert tosa_spec.support_integer(), f"{tosa_spec} doesn't support integer"
        process_quantized_bias(node, tosa_graph, parameter_values)
    else:
        # MI (main inference) profile: floating-point weights or bias
        if inputs[0].dtype == torch.float32:
            assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

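        # Permute the parameter data into the node's annotated dim_order
        # before serializing it as a constant.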
        parameter_values = np.transpose(parameter_values, inputs[0].dim_order)

        tosa_graph.addConst(
            parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
        )


def process_inputs_to_buffers(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize quantized weights"""
    inputs = [TosaArg(node)]
    buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
    buffer_data = edge_program.state_dict[buffer_name]

    assert isinstance(buffer_data, torch.Tensor), "Expect buffer to be a tensor"
    buffer_values = buffer_data.detach().numpy()

    # TODO: fragile code for a temporary fix.
    # The mean and var tensors are also stored here, but they have shape (1,);
    # we only transpose weights here.
    buffer_values = np.transpose(buffer_values, inputs[0].dim_order)

    tosa_graph.addConst(
        buffer_values.shape, inputs[0].dtype, buffer_values, name=node.name
    )


def process_inputs_to_lifted_tensor_constants(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
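    """Serialize a lifted tensor constant as a TOSA const tensor."""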
    arg = TosaArg(node)
    tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
        arg.name
    ]
    tensor = edge_program.tensor_constants[tensor_name]
    tensor_data = tensor.detach().numpy()

    tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name)


def process_placeholder(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Wrapper for processing and serializing all types of placeholders"""
    assert node.name == node.target, "Expect placeholder name and target to match"
    assert 0 == len(node.args), "Can't handle default input values"

    if node.name in edge_program.graph_signature.user_inputs:
        process_inputs(node, tosa_graph, tosa_spec)
    elif node.name in edge_program.graph_signature.inputs_to_parameters:
        process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
    elif node.name in edge_program.graph_signature.inputs_to_buffers:
        process_inputs_to_buffers(node, tosa_graph, edge_program)
    elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:
        process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
    elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs:
        raise NotImplementedError(
            "Placeholder is of type 'lifted custom object' which is not supported."
        )
    else:
        raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")


def process_output(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
):
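    """Register each tensor consumed by the graph's output node as a TOSA graph output."""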
    for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
        tosa_graph.addOutputTensor(
            tosa_graph.currRegion.currBasicBlock.tensors[output.name]
        )