/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2/3 runner that includes preprocessing and post-processing
// logic. The module takes in a string as input and emits a string as output.
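//
// Typical usage, as a minimal sketch only: the file names, layer split, and
// quantization parameters below are illustrative placeholders rather than
// values shipped with this example, and the eval_mode value is assumed to
// correspond to EvalMode::kKVCached as declared in runner.h.
//
//   std::vector<std::string> models = {"llama_shard_0.pte", "llama_shard_1.pte"};
//   std::vector<std::string> pos_embs = {"freq_cos.bin", "freq_sin.bin"};
//   std::vector<int> shard_layers = {16, 16};
//   example::Runner runner(
//       models, pos_embs, shard_layers, "tokenizer.bin",
//       /*eval_mode=*/1, /*temperature=*/0.8f,
//       /*logits_scale=*/0.01f, /*logits_offset=*/0);
//   runner.generate(
//       "What is the capital of France?", /*system_prompt=*/"",
//       /*seq_len=*/128,
//       [](const std::string& piece) { fputs(piece.c_str(), stdout); },
//       /*stats_callback=*/{});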

#if defined(QAIHUB_LLAMA3_RUNNER)
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#else
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#endif
#include <executorch/examples/qualcomm/qaihub_scripts/llama/runner/runner.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/log.h>

#include <cinttypes>
#include <ctime>
#include <memory>
#include <sstream>

#if defined(__aarch64__)
#include "arm_neon.h"
#endif

using executorch::aten::Tensor;
using executorch::extension::Module;
using executorch::extension::llm::Sampler;
using executorch::extension::llm::time_in_ms;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::MethodMeta;
using executorch::runtime::Result;

namespace example {

namespace {
static constexpr auto kTopp = 0.9f; // top-p (nucleus) sampling threshold
void printReport(const Runner::Stats& stats);
std::string statsToJsonString(const Runner::Stats& stats);
} // namespace

Runner::Runner(
    const std::vector<std::string>& models_path,
    const std::vector<std::string>& pos_embs_path,
    const std::vector<int>& shard_layers,
    const std::string& tokenizer_path,
    const int eval_mode,
    const float temperature,
    const float logits_scale,
    const int logits_offset)
    : tokenizer_path_(tokenizer_path),
      temperature_(temperature),
      n_bos_(1),
      n_eos_(1),
      vocab_size_(QAIHUB_LLAMA_LOGITS),
      max_seq_len_(1024),
      eval_mode_(eval_mode),
      stats_({}),
      logits_scale_(logits_scale),
      logits_offset_(logits_offset) {
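  // Each entry in models_path is one shard of the partitioned model; every
  // shard becomes its own Module, memory-mapped (mlock errors are ignored).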
  for (size_t i = 0; i < models_path.size(); ++i) {
    modules_.push_back(std::make_shared<Module>(
        models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors));
    ET_LOG(Info, "creating module: model_path=%s", models_path[i].c_str());
  }
  ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());

  // load tokenizer
#if defined(QAIHUB_LLAMA3_RUNNER)
  tokenizer_ = example::get_tiktoken_for_llama();
  tokenizer_->load(tokenizer_path_);
  eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
  version_ = LlamaVersion::kLlama3;
#else
  tokenizer_ = std::make_unique<executorch::extension::llm::BPETokenizer>();
  tokenizer_->load(tokenizer_path_);
  version_ = LlamaVersion::kLlama2;
#endif

  bos_id_ = tokenizer_->bos_tok();
  eos_id_.insert(tokenizer_->eos_tok());

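  // Select the I/O memory layout for the chosen evaluation mode: kBert runs
  // the whole (right-aligned) prompt in a single forward pass, while
  // kKVCached decodes one token per step against a KV cache.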
  switch (eval_mode_) {
    case EvalMode::kBert:
      io_mem_ =
          std::make_unique<BertMemory>(pos_embs_path, modules_, shard_layers);
      break;
    case EvalMode::kKVCached:
      io_mem_ = std::make_unique<KVCachedMemory>(
          pos_embs_path, modules_, shard_layers);
      break;
    default:
      ET_CHECK_MSG(false, "unsupported evaluation mode");
  }
  ET_LOG(Info, "creating io_memory");
}

bool Runner::is_loaded() const {
  bool loaded = true;
  for (const std::shared_ptr<Module>& module : modules_) {
    loaded &= module->is_loaded();
  }
  return loaded && tokenizer_ && sampler_;
}

Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  for (std::shared_ptr<Module>& module : modules_) {
    method_names_.emplace_back(*module->method_names()->begin());
    ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(method_names_.back()));
  }

  // create a top-p sampler seeded with the current time
  sampler_ = std::make_unique<Sampler>(
      vocab_size_,
      temperature_,
      kTopp,
      static_cast<unsigned long long>(std::time(nullptr)));

  // prepare input/output buffers according to each method's metadata
  auto methods_meta = get_methods_meta();
  io_mem_->prepare_io(methods_meta);
  return Error::Ok;
}


int32_t Runner::logitsToToken(const Tensor& logits_tensor) {
  static std::vector<float> logits_f(vocab_size_);
  const uint16_t* logits = logits_tensor.data_ptr<uint16_t>();

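  // Both branches below apply the same affine dequantization to the uint16
  // logits: logit_f = (logit_q - logits_offset_) * logits_scale_. The NEON
  // path handles four entries per iteration, so vocab_size_ is assumed to be
  // a multiple of 4.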
#if defined(__aarch64__)
  static int32x4_t offset = vmovq_n_s32(logits_offset_);
  static float32x4_t scale = vmovq_n_f32(logits_scale_);
  // dequantize
  for (int i = 0; i < vocab_size_; i += 4) {
    const uint16_t* in = logits + i;
    float* out = logits_f.data() + i;
    int32_t data[4] = {in[0], in[1], in[2], in[3]};
    int32x4_t quantized = vld1q_s32(data);
    int32x4_t shifted = vsubq_s32(quantized, offset);
    float32x4_t shifted_f = vcvtq_f32_s32(shifted);
    vst1q_f32(out, vmulq_f32(shifted_f, scale));
  }
#else
  // dequantize
  for (int i = 0; i < vocab_size_; i++) {
    logits_f[i] = (logits[i] - logits_offset_) * logits_scale_;
  }
#endif

  return sampler_->sample(logits_f.data());
}

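// Runs every shard in sequence for one decoding step. Inputs come from the
// pre-populated EValue vectors, and outputs land in the tensors previously
// bound via set_output(), so this function itself copies no data between
// shards.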
void Runner::run_model_step(std::vector<std::vector<EValue>>& inputs) {
  for (size_t i = 0, num_modules = modules_.size(); i < num_modules; ++i) {
    Result<std::vector<EValue>> outputs_res =
        modules_[i]->execute(method_names_[i], inputs[i]);
    ET_CHECK_MSG(
        outputs_res.error() == Error::Ok, "shard %zu inference failed", i);
  }
}

// TODO: add overloaded method for on-device tokenize
Error Runner::generate(
    const std::string& prompt,
    const std::string& system_prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
  ET_CHECK_MSG(!prompt.empty(), "prompt cannot be empty");

  std::vector<std::vector<Tensor>> input_tensors, output_tensors;
  std::vector<std::vector<EValue>> inputs;
  if (!is_loaded()) {
    stats_.model_load_start_ms = time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(load());
    for (int i = 0; i < modules_.size(); ++i) {
      input_tensors.emplace_back(io_mem_->get_input_tensors(i));
      output_tensors.emplace_back(io_mem_->get_output_tensors(i));
      for (size_t j = 0; j < output_tensors[i].size(); ++j) {
        ET_CHECK_MSG(
            modules_[i]->set_output(
                method_names_[i], output_tensors[i][j], j) == Error::Ok,
            "failed to set output tensor for module %d's %zu'th output",
            i,
            j);
      }
      inputs.emplace_back(
          std::vector<EValue>(begin(input_tensors[i]), end(input_tensors[i])));
    }
    stats_.model_load_end_ms = time_in_ms();
  }

  stats_.inference_start_ms = time_in_ms();
  seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;

  std::string post_process_prompt;
  switch (version_) {
    case LlamaVersion::kLlama2:
      post_process_prompt.append(prompt);
      break;
    case LlamaVersion::kLlama3:
      if (!system_prompt.empty()) {
        post_process_prompt.append(
            "<|start_header_id|>system<|end_header_id|>\n\n");
        post_process_prompt.append(system_prompt);
        post_process_prompt.append("<|eot_id|>\n");
      }
      post_process_prompt.append(
          "<|start_header_id|>user<|end_header_id|>\n\n");
      post_process_prompt.append(prompt);
      post_process_prompt.append(
          "<|eot_id|><|start_header_id|>assistant<|end_header_id|>");
      // tokenizer_->encode will add the <|begin_of_text|> token for us.
      // For now, invoke the token callback here so the output format matches
      // the llama3 model card.
      if (token_callback && eval_mode_ == EvalMode::kKVCached) {
        token_callback("<|begin_of_text|>");
      }
      break;
    default:
      ET_CHECK_MSG(false, "unsupported llama version");
      break;
  }

  Result<std::vector<uint64_t>> encode_res =
      tokenizer_->encode(post_process_prompt, n_bos_, 0);
  ET_CHECK_OK_OR_RETURN_ERROR(
      encode_res.error(),
      "failed to encode prompt %s",
      post_process_prompt.c_str());

  std::vector<uint64_t> prompt_tokens = encode_res.get();
  int num_prompt_tokens = prompt_tokens.size();
  ET_CHECK_MSG(num_prompt_tokens < max_seq_len_, "max seq length exceeded");
  ET_CHECK_MSG(
      num_prompt_tokens < seq_len,
      "sequence length exceeded - please increase the seq_len value");

  int64_t pos = 0, prev_token, cur_token = prompt_tokens[0];
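  // Prefill: in kBert mode the whole prompt is evaluated in one forward pass,
  // packed at the end of the fixed-length window; in kKVCached mode only the
  // first prompt token is fed here and the rest are consumed one step at a
  // time by the loop below.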
  if (eval_mode_ == EvalMode::kBert) {
    BertMemory::IO* ptr =
        static_cast<BertMemory::IO*>(io_mem_->get_mutable_ptr());

    int start_index = max_seq_len_ - num_prompt_tokens;
    // indices are filled from the back; taking 3 tokens as an example:
    // > tokens : [...tok_pad, tok_bos, tok1, tok2]
    // > indices: [0.....1020, 1021, 1022, 1023]
    for (int i = 0; i < num_prompt_tokens; i++) {
      ptr->input_ids[start_index + i] = static_cast<int32_t>(prompt_tokens[i]);
    }
    // the causal attention mask is filled as follows:
    // 0 and 65535 map to -100.0 and 0.0 after dequantization
    // 0      : [0,...................0, 0, 0, 0]
    // 1-1019 : ...
    // 1020   : [0,...............65535, 0, 0, 0]
    // 1021   : [0,...............65535, 65535, 0, 0]
    // 1022   : [0,...............65535, 65535, 65535, 0]
    // 1023   : [0,...............65535, 65535, 65535, 65535]
    for (int i = max_seq_len_ - 1, len = num_prompt_tokens; len >= 0;
         --i, --len) {
      for (int j = 0; j <= len; ++j) {
        ptr->attention_mask[i * max_seq_len_ + start_index - 1 + j] = 65535;
      }
    }
    pos = num_prompt_tokens - 1;
    cur_token = prompt_tokens[pos];
  } else if (eval_mode_ == EvalMode::kKVCached) {
    KVCachedMemory::IO* ptr =
        static_cast<KVCachedMemory::IO*>(io_mem_->get_mutable_ptr());
    ptr->input_ids = static_cast<int32_t>(cur_token);
    ptr->attention_mask[max_seq_len_ - 1] = 65535;
  }

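  // Autoregressive loop: run the model, sample the next token, and stream it
  // through token_callback. While the prompt is still being consumed, the
  // sampled token is discarded and the next prompt token is forced instead.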
  while (pos < seq_len - 1) {
    // inference
    run_model_step(inputs);
    Tensor& logits_tensor = output_tensors.back().back();

    if (pos == num_prompt_tokens) {
      stats_.first_token_ms = time_in_ms();
    } else if (pos == num_prompt_tokens - 1) {
      stats_.prompt_eval_end_ms = time_in_ms();
    }

    long sample_start_time_ms = time_in_ms();
    prev_token = cur_token;
    cur_token = logitsToToken(logits_tensor);
    stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;

    if (pos < num_prompt_tokens - 1) {
      cur_token = prompt_tokens[pos + 1];
    }
    io_mem_->update_io(cur_token, ++pos, output_tensors);

    auto piece_res = tokenizer_->decode(prev_token, cur_token);
    ET_CHECK(piece_res.ok());

    if (token_callback) {
      token_callback(piece_res.get().c_str());
    }

    if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) {
      ET_LOG(Info, "\nReached the end of generation");
      break;
    }
  }
  stats_.inference_end_ms = time_in_ms();

  // the loop above exits with pos at most seq_len - 1, so compare against that
  if (pos >= seq_len - 1) {
    ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len);
  }

  stats_.num_prompt_tokens = num_prompt_tokens;
  stats_.num_generated_tokens = pos - num_prompt_tokens;
  printReport(stats_);
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}

namespace {
void printReport(const Runner::Stats& stats) {
  printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

  ET_LOG(
      Info,
      "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
      stats.num_prompt_tokens,
      stats.num_generated_tokens);

  ET_LOG(
      Info,
      "\tModel Load Time:\t\t%f (seconds)",
      ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));
  double inference_time_ms =
      (double)(stats.inference_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_generated_tokens) /
          (double)(stats.inference_end_ms - stats.inference_start_ms) *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
  double prompt_eval_time =
      (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_prompt_tokens) / prompt_eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  double eval_time =
      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
  ET_LOG(
      Info,
      "\t\tGenerated %" PRIu64
      " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      stats.num_generated_tokens,
      eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      stats.num_generated_tokens / eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  // Time to first token is measured from the start of inference, excluding
  // model load time.
  ET_LOG(
      Info,
      "\tTime to first generated token:\t%f (seconds)",
      ((double)(stats.first_token_ms - stats.inference_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));

  ET_LOG(
      Info,
      "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
      stats.num_prompt_tokens + stats.num_generated_tokens,
      (double)stats.aggregate_sampling_time_ms /
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
}

std::string statsToJsonString(const Runner::Stats& stats) {
  std::stringstream ss;
  ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
     << "\"generated_tokens\":" << stats.num_generated_tokens << ","
     << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
     << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
     << "\"inference_start_ms\":" << stats.inference_start_ms << ","
     << "\"inference_end_ms\":" << stats.inference_end_ms << ","
     << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
     << "\"first_token_ms\":" << stats.first_token_ms << ","
     << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
     << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
     << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
  return ss.str();
}
} // namespace

std::vector<Result<MethodMeta>> Runner::get_methods_meta() {
  std::vector<Result<MethodMeta>> methods_meta;
  methods_meta.reserve(modules_.size());
  for (size_t i = 0; i < modules_.size(); ++i) {
    methods_meta.emplace_back(modules_[i]->method_meta(method_names_[i]));
  }
  return methods_meta;
}
} // namespace example