/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2/3 runner that includes preprocessing and post-processing
// logic. The module takes in a string as input and emits a string as output.

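// Example usage (illustrative sketch only; file names, shard layer counts, and
// quantization parameters below are placeholders, not values from this repo):
//
//   example::Runner runner(
//       {"shard_0.pte", "shard_1.pte"},      // one program per model shard
//       {"pos_emb_0.raw", "pos_emb_1.raw"},  // position-embedding binaries
//       {16, 16},                            // decoder layers per shard
//       "tokenizer.bin",
//       EvalMode::kKVCached,
//       /*temperature=*/0.8f,
//       logits_scale,
//       logits_offset);
//   runner.generate(
//       "What is an SoC?",
//       /*system_prompt=*/"",
//       /*seq_len=*/128,
//       [](const std::string& piece) { printf("%s", piece.c_str()); },
//       [](const Runner::Stats& stats) { /* inspect timing stats */ });
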
#if defined(QAIHUB_LLAMA3_RUNNER)
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#else
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#endif
#include <executorch/examples/qualcomm/qaihub_scripts/llama/runner/runner.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/log.h>

#include <ctime>
#include <memory>
#include <sstream>

#if defined(__aarch64__)
#include "arm_neon.h"
#endif

using executorch::aten::Tensor;
using executorch::extension::Module;
using executorch::extension::llm::Sampler;
using executorch::extension::llm::time_in_ms;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::MethodMeta;
using executorch::runtime::Result;

namespace example {

namespace {
static constexpr auto kTopp = 0.9f;
void printReport(const Runner::Stats& stats);
std::string statsToJsonString(const Runner::Stats& stats);
} // namespace

Runner::Runner(
    const std::vector<std::string>& models_path,
    const std::vector<std::string>& pos_embs_path,
    const std::vector<int>& shard_layers,
    const std::string& tokenizer_path,
    const int eval_mode,
    const float temperature,
    const float logits_scale,
    const int logits_offset)
    : tokenizer_path_(tokenizer_path),
      temperature_(temperature),
      n_bos_(1),
      n_eos_(1),
      vocab_size_(QAIHUB_LLAMA_LOGITS),
      max_seq_len_(1024),
      eval_mode_(eval_mode),
      stats_({}),
      logits_scale_(logits_scale),
      logits_offset_(logits_offset) {
  for (size_t i = 0; i < models_path.size(); ++i) {
    modules_.push_back(std::make_shared<Module>(
        models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors));
    ET_LOG(Info, "creating module: model_path=%s", models_path[i].c_str());
  }
  ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());

// load tokenizer
#if defined(QAIHUB_LLAMA3_RUNNER)
  tokenizer_ = example::get_tiktoken_for_llama();
  tokenizer_->load(tokenizer_path_);
  eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
  version_ = LlamaVersion::kLlama3;
#else
  tokenizer_ = std::make_unique<executorch::extension::llm::BPETokenizer>();
  tokenizer_->load(tokenizer_path_);
  version_ = LlamaVersion::kLlama2;
#endif

  bos_id_ = tokenizer_->bos_tok();
  eos_id_.insert(tokenizer_->eos_tok());

  switch (eval_mode_) {
    case EvalMode::kBert:
      io_mem_ =
          std::make_unique<BertMemory>(pos_embs_path, modules_, shard_layers);
      break;
    case EvalMode::kKVCached:
      io_mem_ = std::make_unique<KVCachedMemory>(
          pos_embs_path, modules_, shard_layers);
      break;
    default:
      ET_CHECK_MSG(false, "unsupported evaluation mode");
  }
  ET_LOG(Info, "creating io_memory");
}

bool Runner::is_loaded() const {
  bool loaded = true;
  for (const std::shared_ptr<Module>& module : modules_) {
    loaded &= module->is_loaded();
  }
  return loaded && tokenizer_ && sampler_;
}

Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
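  // Each shard is assumed to expose a single forward method; record its name
  // so it can be executed and queried for metadata later.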
  for (std::shared_ptr<Module>& module : modules_) {
    method_names_.emplace_back(*module->method_names()->begin());
    ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(method_names_.back()));
  }

  // create sampler
  sampler_ = std::make_unique<Sampler>(
      vocab_size_,
      temperature_,
      kTopp,
      static_cast<unsigned long long>(std::time(nullptr)));

  // prepare io
  auto methods_meta = get_methods_meta();
  io_mem_->prepare_io(methods_meta);
  return Error::Ok;
}

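// Convert the final shard's quantized (uint16) logits into float values via
// real = (quantized - logits_offset_) * logits_scale_, then sample the next
// token with the temperature/top-p sampler.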
int32_t Runner::logitsToToken(const Tensor& logits_tensor) {
  static std::vector<float> logits_f(vocab_size_);
  const uint16_t* logits = logits_tensor.data_ptr<uint16_t>();

#if defined(__aarch64__)
  static int32x4_t offset = vmovq_n_s32(logits_offset_);
  static float32x4_t scale = vmovq_n_f32(logits_scale_);
  // dequantize
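  // NEON path: dequantize four logits per iteration. This assumes vocab_size_
  // is a multiple of 4; otherwise the tail would need a scalar fallback.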
  for (int i = 0; i < vocab_size_; i += 4) {
    const uint16_t* in = logits + i;
    float* out = logits_f.data() + i;
    int32_t data[4] = {in[0], in[1], in[2], in[3]};
    int32x4_t quantized = vld1q_s32(data);
    int32x4_t shifted = vsubq_s32(quantized, offset);
    float32x4_t shifted_f = vcvtq_f32_s32(shifted);
    vst1q_f32(out, vmulq_f32(shifted_f, scale));
  }
#else
  // dequantize
  for (int i = 0; i < vocab_size_; i++) {
    logits_f[i] = (logits[i] - logits_offset_) * logits_scale_;
  }
#endif

  return sampler_->sample(logits_f.data());
}

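// Execute each shard in sequence. Output tensors were registered with
// set_output() in generate(), so results presumably flow between shards
// through the shared buffers managed by io_mem_ rather than being copied out.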
void Runner::run_model_step(std::vector<std::vector<EValue>>& inputs) {
  for (size_t i = 0, num_modules = modules_.size(); i < num_modules; ++i) {
    Result<std::vector<EValue>> outputs_res =
        modules_[i]->execute(method_names_[i], inputs[i]);
    ET_CHECK_MSG(
        outputs_res.error() == Error::Ok, "shard %zu inference failed", i);
  }
}

// TODO: add overloaded method for on-device tokenize
Error Runner::generate(
    const std::string& prompt,
    const std::string& system_prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
  ET_CHECK_MSG(!prompt.empty(), "prompt cannot be empty");

  std::vector<std::vector<Tensor>> input_tensors, output_tensors;
  std::vector<std::vector<EValue>> inputs;
  if (!is_loaded()) {
    stats_.model_load_start_ms = time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(load());
    for (int i = 0; i < modules_.size(); ++i) {
      input_tensors.emplace_back(io_mem_->get_input_tensors(i));
      output_tensors.emplace_back(io_mem_->get_output_tensors(i));
      for (size_t j = 0; j < output_tensors[i].size(); ++j) {
        ET_CHECK_MSG(
            modules_[i]->set_output(
                method_names_[i], output_tensors[i][j], j) == Error::Ok,
            "failed to set output tensor for module %d's %zu'th output",
            i,
            j);
      }
      inputs.emplace_back(
          std::vector<EValue>(begin(input_tensors[i]), end(input_tensors[i])));
    }
    stats_.model_load_end_ms = time_in_ms();
  }

  stats_.inference_start_ms = time_in_ms();
  seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;

  std::string post_process_prompt;
  switch (version_) {
    case LlamaVersion::kLlama2:
      post_process_prompt.append(prompt);
      break;
    case LlamaVersion::kLlama3:
      if (!system_prompt.empty()) {
        post_process_prompt.append(
            "<|start_header_id|>system<|end_header_id|>\n\n");
        post_process_prompt.append(system_prompt);
        post_process_prompt.append("<|eot_id|>\n");
      }
      post_process_prompt.append(
          "<|start_header_id|>user<|end_header_id|>\n\n");
      post_process_prompt.append(prompt);
      post_process_prompt.append(
          "<|eot_id|><|start_header_id|>assistant<|end_header_id|>");
      // tokenizer_->encode will add the <|begin_of_text|> token for us.
      // For now, invoke the token callback here so the output format matches
      // the llama3 model card.
      if (token_callback && eval_mode_ == EvalMode::kKVCached) {
        token_callback("<|begin_of_text|>");
      }
      break;
    default:
      ET_CHECK_MSG(false, "unsupported llama version");
      break;
  }

  Result<std::vector<uint64_t>> encode_res =
      tokenizer_->encode(post_process_prompt, n_bos_, 0);
  ET_CHECK_OK_OR_RETURN_ERROR(
      encode_res.error(),
      "failed to encode prompt %s",
      post_process_prompt.c_str());

  std::vector<uint64_t> prompt_tokens = encode_res.get();
  int num_prompt_tokens = prompt_tokens.size();
  ET_CHECK_MSG(num_prompt_tokens < max_seq_len_, "max seq length exceeded");
  ET_CHECK_MSG(
      num_prompt_tokens < seq_len,
      "sequence length exceeded - please increase the seq_len value");

  int64_t pos = 0, prev_token, cur_token = prompt_tokens[0];
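  // Prefill differs by evaluation mode: kBert runs the whole (right-aligned)
  // prompt through the model in one shot, while kKVCached feeds a single token
  // per step and relies on update_io() to advance the cached state each
  // iteration.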
  if (eval_mode_ == EvalMode::kBert) {
    BertMemory::IO* ptr =
        static_cast<BertMemory::IO*>(io_mem_->get_mutable_ptr());

    int start_index = max_seq_len_ - num_prompt_tokens;
    // indices are filled from the back; taking 3 tokens as an example:
    // > tokens : [...tok_pad, tok_bos, tok1, tok2]
    // > indices: [0.....1020, 1021,    1022, 1023]
    for (int i = 0; i < num_prompt_tokens; i++) {
      ptr->input_ids[start_index + i] = static_cast<int32_t>(prompt_tokens[i]);
    }
    // the causal attention mask is filled as follows:
    // 0, 65535 map to -100.0, 0.0 after dequantizing
    // 0      : [0,...................0,     0,     0,     0]
    // 1-1019 : ...
    // 1020   : [0,...............65535,     0,     0,     0]
    // 1021   : [0,...............65535, 65535,     0,     0]
    // 1022   : [0,...............65535, 65535, 65535,     0]
    // 1023   : [0,...............65535, 65535, 65535, 65535]
    for (int i = max_seq_len_ - 1, len = num_prompt_tokens; len >= 0;
         --i, --len) {
      for (int j = 0; j <= len; ++j) {
        ptr->attention_mask[i * max_seq_len_ + start_index - 1 + j] = 65535;
      }
    }
    pos = num_prompt_tokens - 1;
    cur_token = prompt_tokens[pos];
  } else if (eval_mode_ == EvalMode::kKVCached) {
    KVCachedMemory::IO* ptr =
        static_cast<KVCachedMemory::IO*>(io_mem_->get_mutable_ptr());
    ptr->input_ids = static_cast<int32_t>(cur_token);
    ptr->attention_mask[max_seq_len_ - 1] = 65535;
  }

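  // Autoregressive decode loop: run every shard, sample the next token from
  // the last shard's logits (while the prompt is still being consumed, the
  // sampled token is replaced by the next prompt token), update the IO
  // buffers, stream the decoded piece, and stop on EOS or when seq_len is hit.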
  while (pos < seq_len - 1) {
    // inference
    run_model_step(inputs);
    Tensor& logits_tensor = output_tensors.back().back();

    if (pos == num_prompt_tokens) {
      stats_.first_token_ms = time_in_ms();
    } else if (pos == num_prompt_tokens - 1) {
      stats_.prompt_eval_end_ms = time_in_ms();
    }

    long sample_start_time_ms = time_in_ms();
    prev_token = cur_token;
    cur_token = logitsToToken(logits_tensor);
    stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;

    if (pos < num_prompt_tokens - 1) {
      cur_token = prompt_tokens[pos + 1];
    }
    io_mem_->update_io(cur_token, ++pos, output_tensors);

    auto piece_res = tokenizer_->decode(prev_token, cur_token);
    ET_CHECK(piece_res.ok());

    if (token_callback) {
      token_callback(piece_res.get().c_str());
    }

    if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) {
      ET_LOG(Info, "\nReached the end of generation");
      break;
    }
  }
  stats_.inference_end_ms = time_in_ms();

  if (pos == seq_len - 1) {
    ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len);
  }

  stats_.num_prompt_tokens = num_prompt_tokens;
  stats_.num_generated_tokens = pos - num_prompt_tokens;
  printReport(stats_);
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}

namespace {
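// All Stats timestamps come from time_in_ms(), so durations are in
// milliseconds; SCALING_FACTOR_UNITS_PER_SECOND converts them to seconds for
// the report below.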
void printReport(const Runner::Stats& stats) {
  printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

  ET_LOG(
      Info,
      "\tPrompt Tokens: %" PRIu64 "    Generated Tokens: %" PRIu64,
      stats.num_prompt_tokens,
      stats.num_generated_tokens);

  ET_LOG(
      Info,
      "\tModel Load Time:\t\t%f (seconds)",
      ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));
  double inference_time_ms =
      (double)(stats.inference_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_generated_tokens) /
          (double)(stats.inference_end_ms - stats.inference_start_ms) *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
  double prompt_eval_time =
      (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_prompt_tokens) / prompt_eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  double eval_time =
      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
  ET_LOG(
      Info,
      "\t\tGenerated %" PRIu64
      " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      stats.num_generated_tokens,
      eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      stats.num_generated_tokens / eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  // Time to first token is measured from the start of inference, excluding
  // model load time.
  ET_LOG(
      Info,
      "\tTime to first generated token:\t%f (seconds)",
      ((double)(stats.first_token_ms - stats.inference_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));

  ET_LOG(
      Info,
      "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
      stats.num_prompt_tokens + stats.num_generated_tokens,
      (double)stats.aggregate_sampling_time_ms /
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
}

std::string statsToJsonString(const Runner::Stats& stats) {
  std::stringstream ss;
  ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
     << "\"generated_tokens\":" << stats.num_generated_tokens << ","
     << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
     << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
     << "\"inference_start_ms\":" << stats.inference_start_ms << ","
     << "\"inference_end_ms\":" << stats.inference_end_ms << ","
     << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
     << "\"first_token_ms\":" << stats.first_token_ms << ","
     << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
     << ","
     << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
     << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
  return ss.str();
}
} // namespace

std::vector<Result<MethodMeta>> Runner::get_methods_meta() {
  std::vector<Result<MethodMeta>> methods_meta;
  methods_meta.reserve(modules_.size());
  for (size_t i = 0; i < modules_.size(); ++i) {
    methods_meta.emplace_back(modules_[i]->method_meta(method_names_[i]));
  }
  return methods_meta;
}
} // namespace example