xref: /aosp_15_r20/external/executorch/examples/models/phi-3-mini/runner.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/examples/models/phi-3-mini/runner.h>
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #include <ctime>
12*523fa7a6SAndroid Build Coastguard Worker #include <iostream>
13*523fa7a6SAndroid Build Coastguard Worker 
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/tensor/tensor.h>
16*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/log.h>
17*523fa7a6SAndroid Build Coastguard Worker 
18*523fa7a6SAndroid Build Coastguard Worker using executorch::aten::ScalarType;
19*523fa7a6SAndroid Build Coastguard Worker using executorch::extension::Module;
20*523fa7a6SAndroid Build Coastguard Worker using executorch::extension::llm::BPETokenizer;
21*523fa7a6SAndroid Build Coastguard Worker using executorch::extension::llm::Sampler;
22*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Error;
23*523fa7a6SAndroid Build Coastguard Worker 
24*523fa7a6SAndroid Build Coastguard Worker namespace example {
25*523fa7a6SAndroid Build Coastguard Worker 
// Sampling and vocabulary constants used by the Phi-3-mini runner.
// Typed constexpr values (instead of preprocessor macros) so that they obey
// namespace scoping and take part in normal type checking.

// Third argument handed to the Sampler constructor.
constexpr float SAMPLER_TOP = 0.9f;
// Token id that stops generation when it is sampled.
constexpr uint64_t ENDOFTEXT_TOKEN = 32000;
// Vocabulary size handed to the Sampler constructor.
constexpr int VOCABULARY_SIZE = 32064;
29*523fa7a6SAndroid Build Coastguard Worker 
Runner(const std::string & model_path,const std::string & tokenizer_path,const float temperature)30*523fa7a6SAndroid Build Coastguard Worker Runner::Runner(
31*523fa7a6SAndroid Build Coastguard Worker     const std::string& model_path,
32*523fa7a6SAndroid Build Coastguard Worker     const std::string& tokenizer_path,
33*523fa7a6SAndroid Build Coastguard Worker     const float temperature)
34*523fa7a6SAndroid Build Coastguard Worker     : module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
35*523fa7a6SAndroid Build Coastguard Worker       tokenizer_(std::make_unique<BPETokenizer>()),
36*523fa7a6SAndroid Build Coastguard Worker       sampler_(std::make_unique<Sampler>(
37*523fa7a6SAndroid Build Coastguard Worker           VOCABULARY_SIZE,
38*523fa7a6SAndroid Build Coastguard Worker           temperature,
39*523fa7a6SAndroid Build Coastguard Worker           SAMPLER_TOP,
40*523fa7a6SAndroid Build Coastguard Worker           static_cast<unsigned long long>(std::time(nullptr)))) {
41*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
42*523fa7a6SAndroid Build Coastguard Worker       tokenizer_->load(tokenizer_path) == Error::Ok,
43*523fa7a6SAndroid Build Coastguard Worker       "Failed to load tokenizer at %s",
44*523fa7a6SAndroid Build Coastguard Worker       tokenizer_path.c_str());
45*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
46*523fa7a6SAndroid Build Coastguard Worker       Info,
47*523fa7a6SAndroid Build Coastguard Worker       "Created Phi-3-mini runner: model_path=%s, tokenizer_path=%s",
48*523fa7a6SAndroid Build Coastguard Worker       model_path.c_str(),
49*523fa7a6SAndroid Build Coastguard Worker       tokenizer_path.c_str());
50*523fa7a6SAndroid Build Coastguard Worker }
51*523fa7a6SAndroid Build Coastguard Worker 
generate(const std::string & prompt,std::size_t max_seq_len)52*523fa7a6SAndroid Build Coastguard Worker void Runner::generate(const std::string& prompt, std::size_t max_seq_len) {
53*523fa7a6SAndroid Build Coastguard Worker   auto encode_res = tokenizer_->encode(prompt, 0, 0);
54*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
55*523fa7a6SAndroid Build Coastguard Worker       encode_res.error() == Error::Ok, "Failed to encode %s", prompt.c_str());
56*523fa7a6SAndroid Build Coastguard Worker   auto input_tokens = encode_res.get();
57*523fa7a6SAndroid Build Coastguard Worker   auto prev_token = input_tokens.back();
58*523fa7a6SAndroid Build Coastguard Worker   auto current_token = prefill(input_tokens);
59*523fa7a6SAndroid Build Coastguard Worker   std::cout << tokenizer_->decode(prev_token, current_token).get();
60*523fa7a6SAndroid Build Coastguard Worker   std::cout.flush();
61*523fa7a6SAndroid Build Coastguard Worker 
62*523fa7a6SAndroid Build Coastguard Worker   std::size_t seq_len = input_tokens.size() + 1;
63*523fa7a6SAndroid Build Coastguard Worker 
64*523fa7a6SAndroid Build Coastguard Worker   while (current_token != ENDOFTEXT_TOKEN && seq_len < max_seq_len) {
65*523fa7a6SAndroid Build Coastguard Worker     prev_token = current_token;
66*523fa7a6SAndroid Build Coastguard Worker     current_token = run_model_step(current_token);
67*523fa7a6SAndroid Build Coastguard Worker     std::cout << tokenizer_->decode(prev_token, current_token).get();
68*523fa7a6SAndroid Build Coastguard Worker     std::cout.flush();
69*523fa7a6SAndroid Build Coastguard Worker 
70*523fa7a6SAndroid Build Coastguard Worker     ++seq_len;
71*523fa7a6SAndroid Build Coastguard Worker   }
72*523fa7a6SAndroid Build Coastguard Worker 
73*523fa7a6SAndroid Build Coastguard Worker   std::cout << std::endl;
74*523fa7a6SAndroid Build Coastguard Worker }
75*523fa7a6SAndroid Build Coastguard Worker 
logits_to_token(const exec_aten::Tensor & logits_tensor)76*523fa7a6SAndroid Build Coastguard Worker uint64_t Runner::logits_to_token(const exec_aten::Tensor& logits_tensor) {
77*523fa7a6SAndroid Build Coastguard Worker   return sampler_->sample(logits_tensor.data_ptr<float>());
78*523fa7a6SAndroid Build Coastguard Worker }
79*523fa7a6SAndroid Build Coastguard Worker 
prefill(std::vector<uint64_t> & tokens)80*523fa7a6SAndroid Build Coastguard Worker uint64_t Runner::prefill(std::vector<uint64_t>& tokens) {
81*523fa7a6SAndroid Build Coastguard Worker   auto result = module_->forward(executorch::extension::from_blob(
82*523fa7a6SAndroid Build Coastguard Worker       tokens.data(),
83*523fa7a6SAndroid Build Coastguard Worker       {1, static_cast<exec_aten::SizesType>(tokens.size())},
84*523fa7a6SAndroid Build Coastguard Worker       ScalarType::Long));
85*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(result.error() == Error::Ok, "Failed to prefill tokens");
86*523fa7a6SAndroid Build Coastguard Worker 
87*523fa7a6SAndroid Build Coastguard Worker   return logits_to_token(result.get()[0].toTensor());
88*523fa7a6SAndroid Build Coastguard Worker }
89*523fa7a6SAndroid Build Coastguard Worker 
run_model_step(uint64_t token)90*523fa7a6SAndroid Build Coastguard Worker uint64_t Runner::run_model_step(uint64_t token) {
91*523fa7a6SAndroid Build Coastguard Worker   auto result = module_->forward(
92*523fa7a6SAndroid Build Coastguard Worker       executorch::extension::from_blob(&token, {1, 1}, ScalarType::Long));
93*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
94*523fa7a6SAndroid Build Coastguard Worker       result.error() == Error::Ok,
95*523fa7a6SAndroid Build Coastguard Worker       "Failed to run forward() for token %" PRIu64,
96*523fa7a6SAndroid Build Coastguard Worker       token);
97*523fa7a6SAndroid Build Coastguard Worker 
98*523fa7a6SAndroid Build Coastguard Worker   return logits_to_token(result.get()[0].toTensor());
99*523fa7a6SAndroid Build Coastguard Worker }
100*523fa7a6SAndroid Build Coastguard Worker 
101*523fa7a6SAndroid Build Coastguard Worker } // namespace example
102