1# Copyright (c) Qualcomm Innovation Center, Inc. 2# All rights reserved 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import json 8import os 9from multiprocessing.connection import Client 10 11import torch 12from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset 13from executorch.backends.qualcomm.utils.utils import ( 14 ExecutorchBackendConfig, 15 from_context_binary, 16 generate_htp_compiler_spec, 17 generate_qnn_executorch_compiler_spec, 18 get_soc_to_chipset_map, 19) 20from executorch.examples.qualcomm.qaihub_scripts.utils.utils import ( 21 gen_pte_from_ctx_bin, 22 get_encoding, 23) 24from executorch.examples.qualcomm.utils import ( 25 setup_common_args_and_variables, 26 SimpleADB, 27) 28from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass 29 30 31def main(args): 32 os.makedirs(args.artifact, exist_ok=True) 33 34 target_names = ( 35 [ 36 f"llama_v2_7b_chat_quantized_PromptProcessor_{i}_Quantized.bin" 37 for i in range(1, 5) 38 ] 39 if args.use_prompt_processor 40 else [ 41 f"llama_v2_7b_chat_quantized_TokenGenerator_{i}_Quantized.bin" 42 for i in range(1, 5) 43 ] 44 ) 45 46 # common part for compile & inference 47 backend_options = generate_htp_compiler_spec( 48 use_fp16=False, 49 use_multi_contexts=True, 50 ) 51 compiler_specs = generate_qnn_executorch_compiler_spec( 52 soc_model=getattr(QcomChipset, args.model), 53 backend_options=backend_options, 54 is_from_context_binary=True, 55 ) 56 57 if args.use_prompt_processor: 58 pte_name = "qaihub_llama2_7b_prompt" 59 last_shard_num_inputs = 4 60 last_shard_num_outputs = 513 61 else: 62 pte_name = "qaihub_llama2_7b_token" 63 last_shard_num_inputs = 516 64 last_shard_num_outputs = 513 65 66 if args.pre_gen_pte is None: 67 # create custom operators as context loader 68 soc_model = get_soc_to_chipset_map()[args.model] 69 bundle_programs = [ 70 from_context_binary( 71 ctx_path=f"{args.context_binaries}/{target}", 72 op_name=f"ctx_loader_{i}", 73 soc_model=soc_model, 74 ) 75 for i, target in enumerate(target_names) 76 ] 77 pte_names = [f"{pte_name}_{i}" for i in range(len(target_names))] 78 memory_planning_pass = MemoryPlanningPass( 79 alloc_graph_input=False, 80 alloc_graph_output=False, 81 ) 82 pte_files = gen_pte_from_ctx_bin( 83 artifact=args.artifact, 84 pte_names=pte_names, 85 bundle_programs=bundle_programs, 86 backend_config=ExecutorchBackendConfig( 87 memory_planning_pass=memory_planning_pass 88 ), 89 ) 90 else: 91 pte_files = [f"{args.pre_gen_pte}/{pte_name}_{i}.pte" for i in range(4)] 92 93 if args.compile_only: 94 return 95 96 adb = SimpleADB( 97 qnn_sdk=os.getenv("QNN_SDK_ROOT"), 98 build_path=args.build_folder, 99 pte_path=pte_files, 100 workspace=f"/data/local/tmp/executorch/{pte_name}", 101 device_id=args.device, 102 host_id=args.host, 103 soc_model=args.model, 104 runner="examples/qualcomm/qaihub_scripts/llama/qaihub_llama2_7b_runner", 105 ) 106 output_file = "result.txt" 107 pos_embs_file = ["freq_cos", "freq_sin"] 108 encoding = get_encoding( 109 path_to_shard=f"{args.context_binaries}/{target_names[-1]}", 110 compiler_specs=compiler_specs, 111 get_input=False, 112 get_output=True, 113 num_input=last_shard_num_inputs, 114 num_output=last_shard_num_outputs, 115 )[0] 116 scale = encoding["scale"][-1] 117 offset = encoding["offset"][-1] 118 outputs = [] 119 runner_args = [ 120 *[ 121 f"--sharded_{i+1}_path {os.path.basename(pte_file)}" 122 for i, pte_file in enumerate(pte_files) 123 ], 124 *[f"--{fname}_path {fname}.raw" for fname in pos_embs_file], 125 f"--output_path {adb.output_folder}/{output_file}", 126 f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}", 127 f"--prompt '{args.prompt}'", 128 f"--temperature {args.temperature}", 129 f"--seq_len {args.seq_len}", 130 f"--eval_mode {0 if args.use_prompt_processor else 1}", 131 f"--logits_scale {scale}", 132 f"--logits_offset {-offset}", 133 ] 134 runner_cmds = " ".join( 135 [ 136 f"cd {adb.workspace} &&", 137 f"./qaihub_llama2_7b_runner {' '.join(runner_args)}", 138 ] 139 ) 140 141 def compute_pos_embedding(): 142 head_dim, max_seq_len, theta = 128, 1024, 10000.0 143 base = torch.arange(0, head_dim, 2) 144 freqs = 1.0 / (theta ** (base[: (head_dim // 2)].float() / head_dim)) 145 t = torch.arange(max_seq_len * 2) 146 freqs = torch.outer(t, freqs).float() 147 freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 148 freqs_cis = freqs_cis[0:max_seq_len] 149 freqs_real = torch.view_as_real(freqs_cis) 150 return freqs_real[:, :, 0], freqs_real[:, :, 1] 151 152 def post_process(): 153 with open(f"{args.artifact}/outputs/{output_file}", "r") as f: 154 outputs.append(f.read()) 155 156 custom_files = [args.tokenizer_bin] 157 for var_name, freq in zip(pos_embs_file, compute_pos_embedding()): 158 custom_files.append(f"{adb.working_dir}/{var_name}.raw") 159 scale, offset = (freq.max() - freq.min()) / 65535, 32768 160 freq = (freq / scale + offset).clip(min=0, max=65535).detach() 161 freq.to(dtype=torch.uint16).numpy().tofile(custom_files[-1]) 162 163 if not args.skip_push: 164 adb.push(files=custom_files) 165 adb.execute(custom_runner_cmd=runner_cmds) 166 adb.pull(args.artifact, callback=post_process) 167 if args.ip and args.port != -1: 168 with Client((args.ip, args.port)) as conn: 169 conn.send( 170 json.dumps( 171 { 172 "result": outputs[0], 173 } 174 ) 175 ) 176 else: 177 print(outputs[0]) 178 179 180if __name__ == "__main__": 181 parser = setup_common_args_and_variables() 182 183 parser.add_argument( 184 "-a", 185 "--artifact", 186 help="path for storing generated artifacts by this example. Default ./llama2_qai_hub", 187 default="./llama2_qai_hub", 188 type=str, 189 ) 190 191 parser.add_argument( 192 "--context_binaries", 193 help="path to context binaries generated from qai_hub", 194 required=True, 195 ) 196 197 parser.add_argument( 198 "--use_prompt_processor", 199 help="tokens will be evaluated all at once", 200 default=False, 201 action="store_true", 202 ) 203 204 parser.add_argument( 205 "--tokenizer_bin", 206 help="llama2 tokenizer binary", 207 required=True, 208 type=str, 209 ) 210 211 parser.add_argument( 212 "--seq_len", 213 help="ouput sequence length for llama2", 214 default=128, 215 type=int, 216 ) 217 218 parser.add_argument( 219 "--temperature", 220 help="sampling temperature for llama2", 221 default=0.0, 222 type=float, 223 ) 224 225 parser.add_argument( 226 "--prompt", 227 help="user prompts for llama2", 228 required=True, 229 type=str, 230 ) 231 232 parser.add_argument( 233 "--pre_gen_pte", 234 help="folder path to pre-compiled ptes", 235 default=None, 236 type=str, 237 ) 238 239 args = parser.parse_args() 240 241 try: 242 main(args) 243 except Exception as e: 244 if args.ip and args.port != -1: 245 with Client((args.ip, args.port)) as conn: 246 conn.send(json.dumps({"Error": str(e)})) 247 else: 248 raise Exception(e) 249