# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from multiprocessing.connection import Client

import torch
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset

from executorch.backends.qualcomm.utils.utils import (
    ExecutorchBackendConfig,
    from_context_binary,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
)
from executorch.examples.qualcomm.qaihub_scripts.utils.utils import (
    gen_pte_from_ctx_bin,
    get_encoding,
)
from executorch.examples.qualcomm.utils import (
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass


def main(args):
    os.makedirs(args.artifact, exist_ok=True)

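    # QAI Hub exports Llama3 8B as 5 sharded context binaries; pick the
    # prompt-processor or token-generator variant of each shard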
    target_names = (
        [
            f"llama_v3_8b_chat_quantized_PromptProcessor_{i}_Quantized.bin"
            for i in range(1, 6)
        ]
        if args.use_prompt_processor
        else [
            f"llama_v3_8b_chat_quantized_TokenGenerator_{i}_Quantized.bin"
            for i in range(1, 6)
        ]
    )

    # common part for compile & inference
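    # the shards are pre-quantized, so fp16 is disabled; multi-context support
    # is enabled because the model spans several QNN context binaries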
    backend_options = generate_htp_compiler_spec(
        use_fp16=False,
        use_multi_contexts=True,
    )
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=getattr(QcomChipset, args.model),
        backend_options=backend_options,
        is_from_context_binary=True,
    )

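    # I/O arity of the last shard, used below to look up the quantization
    # encoding of the logits output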
    if args.use_prompt_processor:
        pte_name = "qaihub_llama3_8b_prompt"
        last_shard_num_inputs = 4
        last_shard_num_outputs = 65
    else:
        pte_name = "qaihub_llama3_8b_token"
        last_shard_num_inputs = 68
        last_shard_num_outputs = 65

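    # lower each shard's context binary to a .pte unless pre-compiled ones
    # were supplied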
    if args.pre_gen_pte is None:
        # create custom operators as context loader
        soc_model = get_soc_to_chipset_map()[args.model]
        bundle_programs = [
            from_context_binary(
                ctx_path=f"{args.context_binaries}/{target}",
                op_name=f"ctx_loader_{i}",
                soc_model=soc_model,
            )
            for i, target in enumerate(target_names)
        ]
        pte_names = [f"{pte_name}_{i}" for i in range(len(target_names))]
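        # skip memory planning for graph inputs/outputs; the runner provides
        # those buffers at runtime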
        memory_planning_pass = MemoryPlanningPass(
            alloc_graph_input=False,
            alloc_graph_output=False,
        )
        pte_files = gen_pte_from_ctx_bin(
            artifact=args.artifact,
            pte_names=pte_names,
            bundle_programs=bundle_programs,
            backend_config=ExecutorchBackendConfig(
                memory_planning_pass=memory_planning_pass
            ),
        )
    else:
        pte_files = [f"{args.pre_gen_pte}/{pte_name}_{i}.pte" for i in range(5)]

    if args.compile_only:
        return

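    # stage the compiled programs, runner binary, and auxiliary files on the
    # device over adb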
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=args.build_folder,
        pte_path=pte_files,
        workspace=f"/data/local/tmp/executorch/{pte_name}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        runner="examples/qualcomm/qaihub_scripts/llama/qaihub_llama3_8b_runner",
    )
    output_file = "result.txt"
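    # rotary embedding tables, computed below and pushed to the device as raw files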
    pos_embs_file = ["freq_cos", "freq_sin"]

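    # read the affine quantization encoding of the final shard's outputs; the
    # scale/offset of its last output feed --logits_scale/--logits_offset so
    # the runner can dequantize the logits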
    encoding = get_encoding(
        path_to_shard=f"{args.context_binaries}/{target_names[-1]}",
        compiler_specs=compiler_specs,
        get_input=False,
        get_output=True,
        num_input=last_shard_num_inputs,
        num_output=last_shard_num_outputs,
    )[0]
    scale = encoding["scale"][-1]
    offset = encoding["offset"][-1]
    outputs = []
    runner_args = [
        *[
            f"--sharded_{i+1}_path {os.path.basename(pte_file)}"
            for i, pte_file in enumerate(pte_files)
        ],
        *[f"--{fname}_path {fname}.raw" for fname in pos_embs_file],
        f"--output_path {adb.output_folder}/{output_file}",
        f"--tokenizer_path {os.path.basename(args.tokenizer_model)}",
        f"--prompt '{args.prompt}'",
        f"--temperature {args.temperature}",
        f"--seq_len {args.seq_len}",
        f"--eval_mode {0 if args.use_prompt_processor else 1}",
        f"--logits_scale {scale}",
        f"--logits_offset {-offset}",
        f"--system_prompt '{args.system_prompt}'",
    ]
    runner_cmds = " ".join(
        [
            f"cd {adb.workspace} &&",
            f"./qaihub_llama3_8b_runner {' '.join(runner_args)}",
        ]
    )

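    # precompute the RoPE cos/sin frequency tables; head_dim=128 matches
    # Llama3 8B (4096 hidden dims / 32 heads)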
    def compute_pos_embedding():
        head_dim, max_seq_len, theta = 128, 1024, 10000.0
        base = torch.arange(0, head_dim, 2)
        freqs = 1.0 / (theta ** (base[: (head_dim // 2)].float() / head_dim))
        t = torch.arange(max_seq_len * 2)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        freqs_cis = freqs_cis[0:max_seq_len]
        freqs_real = torch.view_as_real(freqs_cis)
        return freqs_real[:, :, 0], freqs_real[:, :, 1]

    def post_process():
        with open(f"{args.artifact}/outputs/{output_file}", "r") as f:
            outputs.append(f.read())

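    # quantize each table to uint16 with a per-tensor affine encoding before
    # writing it out as a raw file for the runner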
    custom_files = [args.tokenizer_model]
    for var_name, freq in zip(pos_embs_file, compute_pos_embedding()):
        custom_files.append(f"{adb.working_dir}/{var_name}.raw")
        scale, offset = (freq.max() - freq.min()) / 65535, 32768
        freq = (freq / scale + offset).clip(min=0, max=65535).detach()
        freq.to(dtype=torch.uint16).numpy().tofile(custom_files[-1])

    if not args.skip_push:
        adb.push(files=custom_files)
    adb.execute(custom_runner_cmd=runner_cmds)
    adb.pull(args.artifact, callback=post_process)
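    # report the generated text over a socket when ip/port are given (e.g.
    # when driven by a parent process); otherwise print it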
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps(
                    {
                        "result": outputs[0],
                    }
                )
            )
    else:
        print(outputs[0])


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. Default ./llama3_qai_hub",
        default="./llama3_qai_hub",
        type=str,
    )

    parser.add_argument(
        "--context_binaries",
        help="path to context binaries generated from qai_hub",
        required=True,
    )

    parser.add_argument(
        "--use_prompt_processor",
        help="evaluate all prompt tokens at once with the prompt processor",
        default=False,
        action="store_true",
    )

    parser.add_argument(
        "--tokenizer_model",
        help="llama3 tokenizer model",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--seq_len",
        help="output sequence length for llama3",
        default=128,
        type=int,
    )

    parser.add_argument(
        "--temperature",
        help="sampling temperature for llama3",
        default=0.0,
        type=float,
    )

    parser.add_argument(
        "--prompt",
        help="user prompt for llama3",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--system_prompt",
        help="Tells the model what kind of assistant it should be. For example: 'You are a helpful AI assistant for travel tips and recommendations.' Defaults to an empty string.",
        default="",
        type=str,
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="folder path to pre-compiled ptes",
        default=None,
        type=str,
    )

    args = parser.parse_args()

    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # re-raise as-is to preserve the original type and traceback
            raise
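
# Example invocation (a sketch; the device serial, SoC name, and paths are
# illustrative, and -b/-s/-m come from setup_common_args_and_variables):
#   python qaihub_llama3_8b.py -b build-android -s <device_serial> -m SM8650 \
#     --context_binaries ./context_binaries --tokenizer_model tokenizer.model \
#     --prompt "What is the capital of France?" --use_prompt_processor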