# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from multiprocessing.connection import Client

import torch
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset

from executorch.backends.qualcomm.utils.utils import (
    ExecutorchBackendConfig,
    from_context_binary,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
)
from executorch.examples.qualcomm.qaihub_scripts.utils.utils import (
    gen_pte_from_ctx_bin,
    get_encoding,
)
from executorch.examples.qualcomm.utils import (
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass


def main(args):
    os.makedirs(args.artifact, exist_ok=True)

    target_names = (
        [
            f"llama_v3_8b_chat_quantized_PromptProcessor_{i}_Quantized.bin"
            for i in range(1, 6)
        ]
        if args.use_prompt_processor
        else [
            f"llama_v3_8b_chat_quantized_TokenGenerator_{i}_Quantized.bin"
            for i in range(1, 6)
        ]
    )

    # common part for compile & inference
    backend_options = generate_htp_compiler_spec(
        use_fp16=False,
        use_multi_contexts=True,
    )
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=getattr(QcomChipset, args.model),
        backend_options=backend_options,
        is_from_context_binary=True,
    )

    if args.use_prompt_processor:
        pte_name = "qaihub_llama3_8b_prompt"
        last_shard_num_inputs = 4
        last_shard_num_outputs = 65
    else:
        pte_name = "qaihub_llama3_8b_token"
        last_shard_num_inputs = 68
        last_shard_num_outputs = 65

    if args.pre_gen_pte is None:
        # create custom operators as context loader
        soc_model = get_soc_to_chipset_map()[args.model]
        bundle_programs = [
            from_context_binary(
                ctx_path=f"{args.context_binaries}/{target}",
                op_name=f"ctx_loader_{i}",
                soc_model=soc_model,
            )
            for i, target in enumerate(target_names)
        ]
        pte_names = [f"{pte_name}_{i}" for i in range(len(target_names))]
        memory_planning_pass = MemoryPlanningPass(
            alloc_graph_input=False,
            alloc_graph_output=False,
        )
        pte_files = gen_pte_from_ctx_bin(
            artifact=args.artifact,
            pte_names=pte_names,
            bundle_programs=bundle_programs,
            backend_config=ExecutorchBackendConfig(
                memory_planning_pass=memory_planning_pass
            ),
        )
    else:
        pte_files = [f"{args.pre_gen_pte}/{pte_name}_{i}.pte" for i in range(5)]

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=args.build_folder,
        pte_path=pte_files,
        workspace=f"/data/local/tmp/executorch/{pte_name}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        runner="examples/qualcomm/qaihub_scripts/llama/qaihub_llama3_8b_runner",
    )
    output_file = "result.txt"
    pos_embs_file = ["freq_cos", "freq_sin"]

    encoding = get_encoding(
        path_to_shard=f"{args.context_binaries}/{target_names[-1]}",
        compiler_specs=compiler_specs,
        get_input=False,
        get_output=True,
        num_input=last_shard_num_inputs,
        num_output=last_shard_num_outputs,
    )[0]
    scale = encoding["scale"][-1]
    offset = encoding["offset"][-1]
    outputs = []
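    # Illustrative only: with the prompt-processor PTE names generated above, the
    # sharded path flags below expand to
    #   --sharded_1_path qaihub_llama3_8b_prompt_0.pte ... --sharded_5_path qaihub_llama3_8b_prompt_4.pte
    # (the runner flags are 1-indexed while the generated PTE files are 0-indexed).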
*[f"--{fname}_path {fname}.raw" for fname in pos_embs_file], 127 f"--output_path {adb.output_folder}/{output_file}", 128 f"--tokenizer_path {os.path.basename(args.tokenizer_model)}", 129 f"--prompt '{args.prompt}'", 130 f"--temperature {args.temperature}", 131 f"--seq_len {args.seq_len}", 132 f"--eval_mode {0 if args.use_prompt_processor else 1}", 133 f"--logits_scale {scale}", 134 f"--logits_offset {-offset}", 135 f"--system_prompt '{args.system_prompt}'", 136 ] 137 runner_cmds = " ".join( 138 [ 139 f"cd {adb.workspace} &&", 140 f"./qaihub_llama3_8b_runner {' '.join(runner_args)}", 141 ] 142 ) 143 144 def compute_pos_embedding(): 145 head_dim, max_seq_len, theta = 128, 1024, 10000.0 146 base = torch.arange(0, head_dim, 2) 147 freqs = 1.0 / (theta ** (base[: (head_dim // 2)].float() / head_dim)) 148 t = torch.arange(max_seq_len * 2) 149 freqs = torch.outer(t, freqs).float() 150 freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 151 freqs_cis = freqs_cis[0:max_seq_len] 152 freqs_real = torch.view_as_real(freqs_cis) 153 return freqs_real[:, :, 0], freqs_real[:, :, 1] 154 155 def post_process(): 156 with open(f"{args.artifact}/outputs/{output_file}", "r") as f: 157 outputs.append(f.read()) 158 159 custom_files = [args.tokenizer_model] 160 for var_name, freq in zip(pos_embs_file, compute_pos_embedding()): 161 custom_files.append(f"{adb.working_dir}/{var_name}.raw") 162 scale, offset = (freq.max() - freq.min()) / 65535, 32768 163 freq = (freq / scale + offset).clip(min=0, max=65535).detach() 164 freq.to(dtype=torch.uint16).numpy().tofile(custom_files[-1]) 165 166 if not args.skip_push: 167 adb.push(files=custom_files) 168 adb.execute(custom_runner_cmd=runner_cmds) 169 adb.pull(args.artifact, callback=post_process) 170 if args.ip and args.port != -1: 171 with Client((args.ip, args.port)) as conn: 172 conn.send( 173 json.dumps( 174 { 175 "result": outputs[0], 176 } 177 ) 178 ) 179 else: 180 print(outputs[0]) 181 182 183if __name__ == "__main__": 184 parser = setup_common_args_and_variables() 185 186 parser.add_argument( 187 "-a", 188 "--artifact", 189 help="path for storing generated artifacts by this example. Default ./llama3_qai_hub", 190 default="./llama3_qai_hub", 191 type=str, 192 ) 193 194 parser.add_argument( 195 "--context_binaries", 196 help="path to context binaries generated from qai_hub", 197 required=True, 198 ) 199 200 parser.add_argument( 201 "--use_prompt_processor", 202 help="tokens will be evaluated all at once", 203 default=False, 204 action="store_true", 205 ) 206 207 parser.add_argument( 208 "--tokenizer_model", 209 help="llama3 tokenizer model", 210 required=True, 211 type=str, 212 ) 213 214 parser.add_argument( 215 "--seq_len", 216 help="ouput sequence length for llama3", 217 default=128, 218 type=int, 219 ) 220 221 parser.add_argument( 222 "--temperature", 223 help="sampling temperature for llama3", 224 default=0.0, 225 type=float, 226 ) 227 228 parser.add_argument( 229 "--prompt", 230 help="user prompts for llama3", 231 required=True, 232 type=str, 233 ) 234 235 parser.add_argument( 236 "--system_prompt", 237 help="Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. 
Default is None", 238 default="", 239 type=str, 240 ) 241 242 parser.add_argument( 243 "--pre_gen_pte", 244 help="folder path to pre-compiled ptes", 245 default=None, 246 type=str, 247 ) 248 249 args = parser.parse_args() 250 251 try: 252 main(args) 253 except Exception as e: 254 if args.ip and args.port != -1: 255 with Client((args.ip, args.port)) as conn: 256 conn.send(json.dumps({"Error": str(e)})) 257 else: 258 raise Exception(e) 259
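
# Illustrative invocation sketch (not an exhaustive flag list). The script
# filename and every value below are placeholders, and the -b/-m/-s short flags
# are assumed to come from setup_common_args_and_variables rather than this file:
#
#   python qaihub_llama3_8b.py \
#       -b build-android -m SM8650 -s <device_serial> \
#       --context_binaries <path_to_qai_hub_context_binaries> \
#       --tokenizer_model <path_to_llama3_tokenizer.model> \
#       --use_prompt_processor \
#       --prompt "What is the capital of France?"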