xref: /aosp_15_r20/external/executorch/backends/mediatek/preprocess.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) 2024 MediaTek Inc.
2#
3# Licensed under the BSD License (the "License"); you may not use this file
4# except in compliance with the License. See the license file in the root
5# directory of this source tree for more details.
6
7import contextlib
8import struct
9
10from typing import final, List
11
12import mtk_converter
13import mtk_neuron
14import torch
15from executorch.exir.backend.backend_details import (
16    BackendDetails,
17    ExportedProgram,
18    PreprocessResult,
19)
20from executorch.exir.backend.compile_spec_schema import CompileSpec
21
# Compile-spec keys that NeuropilotBackend.preprocess() does not forward to
# the Neuron compiler as command-line options.
SKIP_COMPILE_SPEC_KEYS = {
    "ImportForever",
}
23
24
@final
class NeuropilotBackend(BackendDetails):
    """ExecuTorch backend that lowers a delegated program with MediaTek NeuroPilot.

    ``preprocess`` converts the edge program to MLIR with ``mtk_converter``,
    compiles it with ``mtk_neuron`` and prefixes the compiled model with a
    small fixed-size little-endian header.
    """

    @staticmethod
    def _fp32_indices(names, nodes_by_name) -> List[int]:
        """Return the positions in ``names`` whose graph node holds a float32 value.

        ``names`` comes from ``graph_signature.user_inputs``/``user_outputs``,
        which can also contain non-node entries (constant argument values).
        Entries that do not name a node, or whose ``meta["val"]`` has no
        ``dtype`` (e.g. symbolic ints), are treated as non-float instead of
        raising ``KeyError``/``AttributeError``.
        """
        indices = []
        for idx, name in enumerate(names):
            node = nodes_by_name.get(name) if isinstance(name, str) else None
            val = None if node is None else node.meta.get("val")
            if getattr(val, "dtype", None) == torch.float32:
                indices.append(idx)
        return indices

    @staticmethod
    def _compile_options(module_compile_spec: List[CompileSpec]) -> List[str]:
        """Build the option list passed to ``mtk_neuron.compile``.

        Starts from the built-in defaults, then appends ``--key`` for each
        spec with an empty value or ``--key=value`` (value decoded as UTF-8)
        otherwise. Keys in ``SKIP_COMPILE_SPEC_KEYS`` are not forwarded.
        """
        # These default compile options are only for the mt6989 SoC.
        options = ["--arch=mdla5.1,edpa1.0", "--relax-fp32", "--opt=3"]
        for spec in module_compile_spec:
            if spec.key in SKIP_COMPILE_SPEC_KEYS:
                continue
            # NOTE(review): a spec such as ``arch`` is appended after the
            # matching default rather than replacing it; presumably the
            # compiler honours the last occurrence — confirm.
            if spec.value == b"":
                options.append(f"--{spec.key}")
            else:
                value = spec.value.decode("utf-8")
                options.append(f"--{spec.key}={value}")
        return options

    @classmethod
    def preprocess(
        cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec]
    ) -> PreprocessResult:
        """Compile ``edge_program`` into a NeuroPilot delegate payload.

        Args:
            edge_program: The program to lower.
            module_compile_spec: Extra compiler options; each becomes ``--key``
                or ``--key=value`` on the ``mtk_neuron`` option string.

        Returns:
            A ``PreprocessResult`` whose bytes are a ``"<BIII"`` header
            followed by the compiled model bytes.
        """
        nodes_by_name = {node.name: node for node in edge_program.graph.nodes}
        input_names = edge_program.graph_signature.user_inputs
        output_names = edge_program.graph_signature.user_outputs
        fp_input_indices = cls._fp32_indices(input_names, nodes_by_name)
        fp_output_indices = cls._fp32_indices(output_names, nodes_by_name)

        compile_options = cls._compile_options(module_compile_spec)

        converter = mtk_converter.PyTorchV2Converter.from_exported_program(edge_program)
        converter.quantize = True
        converter.input_quantization_bitwidths = None
        converter.allow_missing_quantization_ranges = True
        # Quantize/dequantize ops are requested only at the float32 user
        # input/output positions computed above.
        converter.prepend_input_quantize_ops = True
        converter.prepend_input_quantize_ops_indices = fp_input_indices
        converter.append_output_dequantize_ops = True
        converter.append_output_dequantize_ops_indices = fp_output_indices

        # sys.stdout is None inside this block, so Python-level print() output
        # from the converter/compiler is discarded.
        with contextlib.redirect_stdout(None):
            mlir_str = converter.convert_to_mlir()
            model_bytes = mtk_neuron.compile(mlir_str, " ".join(compile_options))

        # Header: uint8 constant 1 (presumably a format version — confirm with
        # the runtime), then uint32 input count, uint32 output count and
        # uint32 compiled-model byte length, all little-endian.
        header = struct.pack(
            "<BIII", 1, len(input_names), len(output_names), len(model_bytes)
        )
        return PreprocessResult(processed_bytes=bytes(header + model_bytes))
73        return PreprocessResult(processed_bytes=bytes(header + model_bytes))
74