xref: /aosp_15_r20/external/executorch/backends/mediatek/preprocess.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) 2024 MediaTek Inc.
2*523fa7a6SAndroid Build Coastguard Worker#
3*523fa7a6SAndroid Build Coastguard Worker# Licensed under the BSD License (the "License"); you may not use this file
4*523fa7a6SAndroid Build Coastguard Worker# except in compliance with the License. See the license file in the root
5*523fa7a6SAndroid Build Coastguard Worker# directory of this source tree for more details.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport contextlib
8*523fa7a6SAndroid Build Coastguard Workerimport struct
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import final, List
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport mtk_converter
13*523fa7a6SAndroid Build Coastguard Workerimport mtk_neuron
14*523fa7a6SAndroid Build Coastguard Workerimport torch
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_details import (
16*523fa7a6SAndroid Build Coastguard Worker    BackendDetails,
17*523fa7a6SAndroid Build Coastguard Worker    ExportedProgram,
18*523fa7a6SAndroid Build Coastguard Worker    PreprocessResult,
19*523fa7a6SAndroid Build Coastguard Worker)
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import CompileSpec
21*523fa7a6SAndroid Build Coastguard Worker
# Compile-spec keys that NeuropilotBackend.preprocess does NOT forward to
# mtk_neuron.compile as "--key[=value]" command-line options.
SKIP_COMPILE_SPEC_KEYS = {
    "ImportForever",
}
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Worker
@final
class NeuropilotBackend(BackendDetails):
    """ExecuTorch backend that lowers an edge program with MediaTek's toolchain.

    ``preprocess`` converts the program to MLIR via ``mtk_converter``, compiles
    it with ``mtk_neuron`` and prefixes the compiled model with a small header
    recording the number of user inputs/outputs and the model size.
    """

    # These default compile options target the MT6989 SoC only.
    _DEFAULT_COMPILE_OPTIONS = ("--arch=mdla5.1,edpa1.0", "--relax-fp32", "--opt=3")

    @staticmethod
    def _float32_indices(names, name_to_node) -> List[int]:
        """Return the positions in ``names`` whose node value has dtype float32.

        Values without a ``dtype`` attribute (e.g. non-tensor inputs/outputs)
        are treated as non-float instead of raising ``AttributeError``.
        """
        return [
            idx
            for idx, name in enumerate(names)
            if getattr(name_to_node[name].meta["val"], "dtype", None) == torch.float32
        ]

    @classmethod
    def _build_compile_options(cls, module_compile_spec: List[CompileSpec]) -> List[str]:
        """Translate compile specs into mtk_neuron command-line options.

        The defaults come first. Each spec whose key is not in
        ``SKIP_COMPILE_SPEC_KEYS`` is appended as ``--key`` (empty value) or
        ``--key=value`` (value decoded as UTF-8).
        """
        # NOTE(review): user specs are appended after the defaults, so e.g. a
        # user "--arch" appears twice; whether the later flag wins depends on
        # mtk_neuron's option parsing — confirm.
        options = list(cls._DEFAULT_COMPILE_OPTIONS)
        for spec in module_compile_spec:
            if spec.key in SKIP_COMPILE_SPEC_KEYS:
                continue
            if spec.value == b"":
                options.append(f"--{spec.key}")
            else:
                value = spec.value.decode("utf-8")
                options.append(f"--{spec.key}={value}")
        return options

    @classmethod
    def preprocess(
        cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec]
    ) -> PreprocessResult:
        """Convert and compile ``edge_program`` into a Neuron payload.

        Args:
            edge_program: Program to lower. The positions of its float32 user
                inputs/outputs are handed to the converter as the indices at
                which to prepend input-quantize / append output-dequantize ops.
            module_compile_spec: Extra compiler flags; see
                ``_build_compile_options`` for how each spec is rendered.

        Returns:
            A ``PreprocessResult`` whose bytes are a 13-byte little-endian
            header (struct format ``<BIII``: the constant 1, number of user
            inputs, number of user outputs, compiled-model size in bytes)
            followed by the compiled model bytes.

        Raises:
            Any exception from ``mtk_converter`` / ``mtk_neuron``. Their stdout,
            which is otherwise discarded, is written to stderr first so the
            failure can be diagnosed.
        """
        import io
        import sys

        name_to_node = {node.name: node for node in edge_program.graph.nodes}
        input_names = edge_program.graph_signature.user_inputs
        output_names = edge_program.graph_signature.user_outputs
        fp_input_indices = cls._float32_indices(input_names, name_to_node)
        fp_output_indices = cls._float32_indices(output_names, name_to_node)

        compile_options = cls._build_compile_options(module_compile_spec)

        converter = mtk_converter.PyTorchV2Converter.from_exported_program(edge_program)
        converter.quantize = True
        converter.input_quantization_bitwidths = None
        converter.allow_missing_quantization_ranges = True
        converter.prepend_input_quantize_ops = True
        converter.prepend_input_quantize_ops_indices = fp_input_indices
        converter.append_output_dequantize_ops = True
        converter.append_output_dequantize_ops_indices = fp_output_indices

        # Capture (rather than set sys.stdout to None) so callees that write to
        # sys.stdout directly do not crash, and so the output can be shown if
        # conversion or compilation fails.
        captured_stdout = io.StringIO()
        try:
            with contextlib.redirect_stdout(captured_stdout):
                mlir_str = converter.convert_to_mlir()
                model_bytes = mtk_neuron.compile(mlir_str, " ".join(compile_options))
        except Exception:
            sys.stderr.write(captured_stdout.getvalue())
            raise

        # "<" = little-endian with no padding: 1 + 4 + 4 + 4 = 13 header bytes.
        # NOTE(review): the leading constant 1 presumably is a format version
        # read by the runtime loader — confirm.
        header = struct.pack(
            "<BIII", 1, len(input_names), len(output_names), len(model_bytes)
        )
        return PreprocessResult(processed_bytes=bytes(header + model_bytes))
74