# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

import contextlib
import io
import struct

from typing import final, List

import mtk_converter
import mtk_neuron
import torch
from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec

# Compile-spec keys that are NOT forwarded to mtk_neuron.compile as options.
SKIP_COMPILE_SPEC_KEYS = {"ImportForever"}

# These default compile options are only for the mt6989 SoC. User-supplied
# compile specs are appended after them.
_DEFAULT_COMPILE_OPTIONS = ("--arch=mdla5.1,edpa1.0", "--relax-fp32", "--opt=3")

# Header prepended to the compiled model: little-endian with no padding
# (13 bytes total) -- one uint8 (always 1; presumably a format version,
# confirm against the runtime loader), then three uint32 values: number of
# user inputs, number of user outputs, and the compiled model's size in bytes.
_HEADER_FORMAT = "<BIII"
_HEADER_VERSION = 1


def _float32_indices(names, name_to_node) -> List[int]:
    """Return the positions in *names* whose node's ``meta["val"]`` is float32."""
    return [
        idx
        for idx, name in enumerate(names)
        if name_to_node[name].meta["val"].dtype == torch.float32
    ]


def _build_compile_options(module_compile_spec: List[CompileSpec]) -> List[str]:
    """Return the default options followed by one option per compile spec.

    A spec with an empty value becomes ``--key``; otherwise it becomes
    ``--key=value`` with the value decoded as UTF-8. Keys listed in
    ``SKIP_COMPILE_SPEC_KEYS`` are dropped.
    """
    # NOTE(review): a spec whose key is already among the defaults (e.g.
    # "arch") produces a second flag for the same option; which one takes
    # effect depends on mtk_neuron.compile -- confirm.
    options = list(_DEFAULT_COMPILE_OPTIONS)
    for spec in module_compile_spec:
        if spec.key in SKIP_COMPILE_SPEC_KEYS:
            continue
        if spec.value == b"":
            options.append(f"--{spec.key}")
        else:
            value = spec.value.decode("utf-8")
            options.append(f"--{spec.key}={value}")
    return options


def _convert_to_mlir(edge_program, fp_input_indices, fp_output_indices):
    """Convert *edge_program* to MLIR text with mtk_converter, silencing stdout."""
    converter = mtk_converter.PyTorchV2Converter.from_exported_program(edge_program)
    converter.quantize = True
    converter.input_quantization_bitwidths = None
    converter.allow_missing_quantization_ranges = True
    # Judging by the attribute names, quantize ops are prepended to the
    # float32 inputs and dequantize ops appended to the float32 outputs;
    # exact semantics are defined by mtk_converter.
    converter.prepend_input_quantize_ops = True
    converter.prepend_input_quantize_ops_indices = fp_input_indices
    converter.append_output_dequantize_ops = True
    converter.append_output_dequantize_ops_indices = fp_output_indices
    # Discard the converter's stdout output. Redirect into a throwaway StringIO
    # rather than None: redirect_stdout(None) sets sys.stdout to None, so any
    # direct sys.stdout.write()/flush()/isatty() call would raise AttributeError.
    with contextlib.redirect_stdout(io.StringIO()):
        return converter.convert_to_mlir()


@final
class NeuropilotBackend(BackendDetails):
    """ExecuTorch backend: converts with mtk_converter, compiles with mtk_neuron."""

    @classmethod
    def preprocess(
        cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec]
    ) -> PreprocessResult:
        """Lower *edge_program* to a compiled model blob.

        Args:
            edge_program: The exported edge program to lower.
            module_compile_spec: Compile specs appended to the default compile
                options as ``--key`` / ``--key=value`` flags, except keys in
                ``SKIP_COMPILE_SPEC_KEYS``.

        Returns:
            A ``PreprocessResult``. Its ``processed_bytes`` is the 13-byte
            header described by ``_HEADER_FORMAT``, followed by the bytes
            returned from ``mtk_neuron.compile``.
        """
        name_to_node = {node.name: node for node in edge_program.graph.nodes}
        input_names = edge_program.graph_signature.user_inputs
        output_names = edge_program.graph_signature.user_outputs
        fp_input_indices = _float32_indices(input_names, name_to_node)
        fp_output_indices = _float32_indices(output_names, name_to_node)

        compile_options = _build_compile_options(module_compile_spec)
        mlir_str = _convert_to_mlir(edge_program, fp_input_indices, fp_output_indices)
        model_bytes = mtk_neuron.compile(mlir_str, " ".join(compile_options))

        header = struct.pack(
            _HEADER_FORMAT,
            _HEADER_VERSION,
            len(input_names),
            len(output_names),
            len(model_bytes),
        )
        return PreprocessResult(processed_bytes=bytes(header + model_bytes))