# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

import contextlib
import io
import struct

from typing import final, List

import mtk_converter
import mtk_neuron
import torch
from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec

# Compile-spec keys that are NOT forwarded to the Neuron compiler as
# command-line options. NOTE(review): presumably these are consumed
# elsewhere (e.g. by the runtime) — confirm where "ImportForever" is read.
SKIP_COMPILE_SPEC_KEYS = {"ImportForever"}


def _float32_indices(names, name_to_node) -> List[int]:
    """Return the positions in *names* whose graph node holds a float32 value.

    ``names`` comes from ``graph_signature.user_inputs`` / ``user_outputs``.
    Entries that do not name a graph node (e.g. constant arguments, which the
    signature reports by value), nodes without a ``"val"`` meta entry, and
    values without a ``dtype`` (e.g. symbolic ints) are treated as
    non-float32 instead of raising.
    """
    indices = []
    for idx, name in enumerate(names):
        node = name_to_node.get(name)
        if node is None:
            continue
        val = node.meta.get("val")
        if getattr(val, "dtype", None) == torch.float32:
            indices.append(idx)
    return indices


@final
class NeuropilotBackend(BackendDetails):
    """ExecuTorch backend that lowers a delegated edge program for NeuroPilot.

    The program is converted to MLIR with ``mtk_converter`` and compiled with
    ``mtk_neuron``. The result is prefixed with a small header describing the
    number of inputs/outputs and the payload length.
    """

    @classmethod
    def preprocess(
        cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec]
    ) -> PreprocessResult:
        """Convert and compile *edge_program* into a NeuroPilot payload.

        Args:
            edge_program: The delegated subgraph to lower.
            module_compile_spec: Extra compiler options. A spec whose value is
                ``b""`` becomes ``--<key>``. Any other spec becomes
                ``--<key>=<value>`` with the value decoded as UTF-8. Keys in
                ``SKIP_COMPILE_SPEC_KEYS`` are ignored here.

        Returns:
            A ``PreprocessResult`` whose ``processed_bytes`` starts with a
            13-byte little-endian header, followed by the bytes returned by
            ``mtk_neuron.compile``. The header fields are:

            * uint8 version, always 1;
            * uint32 number of user inputs;
            * uint32 number of user outputs;
            * uint32 byte length of the compiled model.
        """
        name_to_node_mappings = {node.name: node for node in edge_program.graph.nodes}
        input_names = edge_program.graph_signature.user_inputs
        output_names = edge_program.graph_signature.user_outputs
        # Only float32 I/O gets a quantize (input) / dequantize (output) op
        # inserted by the converter; other I/O is passed through as-is.
        fp_input_indices = _float32_indices(input_names, name_to_node_mappings)
        fp_output_indices = _float32_indices(output_names, name_to_node_mappings)

        # This default compile options are only for mt6989 SOC
        compile_options = ["--arch=mdla5.1,edpa1.0", "--relax-fp32", "--opt=3"]
        for spec in module_compile_spec:
            if spec.key in SKIP_COMPILE_SPEC_KEYS:
                continue
            if spec.value == b"":
                compile_options.append(f"--{spec.key}")
            else:
                value = spec.value.decode("utf-8")
                compile_options.append(f"--{spec.key}={value}")

        converter = mtk_converter.PyTorchV2Converter.from_exported_program(edge_program)
        converter.quantize = True
        converter.input_quantization_bitwidths = None
        converter.allow_missing_quantization_ranges = True
        converter.prepend_input_quantize_ops = True
        converter.prepend_input_quantize_ops_indices = fp_input_indices
        converter.append_output_dequantize_ops = True
        converter.append_output_dequantize_ops_indices = fp_output_indices
        # Silence converter chatter. Redirect into a throwaway buffer rather
        # than None: with None, sys.stdout itself becomes None, and any direct
        # sys.stdout.write()/flush() in the converter would raise.
        with contextlib.redirect_stdout(io.StringIO()):
            mlir_str = converter.convert_to_mlir()
            model_bytes = mtk_neuron.compile(mlir_str, " ".join(compile_options))

        num_inputs = len(input_names)
        num_outputs = len(output_names)
        # "<BIII": little-endian, no padding -> 1 + 4 + 4 + 4 = 13 bytes.
        header = struct.pack("<BIII", 1, num_inputs, num_outputs, len(model_bytes))
        return PreprocessResult(processed_bytes=bytes(header + model_bytes))