1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/examples/models/phi-3-mini/runner.h>
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #include <ctime>
12*523fa7a6SAndroid Build Coastguard Worker #include <iostream>
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/tensor/tensor.h>
16*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/log.h>
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker using executorch::aten::ScalarType;
19*523fa7a6SAndroid Build Coastguard Worker using executorch::extension::Module;
20*523fa7a6SAndroid Build Coastguard Worker using executorch::extension::llm::BPETokenizer;
21*523fa7a6SAndroid Build Coastguard Worker using executorch::extension::llm::Sampler;
22*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Error;
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Worker namespace example {
25*523fa7a6SAndroid Build Coastguard Worker
// Tuning constants for the Phi-3-mini runner. They are typed constexpr values
// rather than preprocessor macros so that they are scoped to this namespace
// and checked by the compiler.

// Third argument handed to the Sampler constructor (presumably the top-p
// nucleus threshold -- confirm against the Sampler declaration).
constexpr float SAMPLER_TOP = 0.9f;
// Token id that stops the generation loop when the model produces it.
constexpr uint64_t ENDOFTEXT_TOKEN = 32000;
// First argument handed to the Sampler constructor: the vocabulary size.
constexpr int VOCABULARY_SIZE = 32064;
29*523fa7a6SAndroid Build Coastguard Worker
Runner(const std::string & model_path,const std::string & tokenizer_path,const float temperature)30*523fa7a6SAndroid Build Coastguard Worker Runner::Runner(
31*523fa7a6SAndroid Build Coastguard Worker const std::string& model_path,
32*523fa7a6SAndroid Build Coastguard Worker const std::string& tokenizer_path,
33*523fa7a6SAndroid Build Coastguard Worker const float temperature)
34*523fa7a6SAndroid Build Coastguard Worker : module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
35*523fa7a6SAndroid Build Coastguard Worker tokenizer_(std::make_unique<BPETokenizer>()),
36*523fa7a6SAndroid Build Coastguard Worker sampler_(std::make_unique<Sampler>(
37*523fa7a6SAndroid Build Coastguard Worker VOCABULARY_SIZE,
38*523fa7a6SAndroid Build Coastguard Worker temperature,
39*523fa7a6SAndroid Build Coastguard Worker SAMPLER_TOP,
40*523fa7a6SAndroid Build Coastguard Worker static_cast<unsigned long long>(std::time(nullptr)))) {
41*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
42*523fa7a6SAndroid Build Coastguard Worker tokenizer_->load(tokenizer_path) == Error::Ok,
43*523fa7a6SAndroid Build Coastguard Worker "Failed to load tokenizer at %s",
44*523fa7a6SAndroid Build Coastguard Worker tokenizer_path.c_str());
45*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
46*523fa7a6SAndroid Build Coastguard Worker Info,
47*523fa7a6SAndroid Build Coastguard Worker "Created Phi-3-mini runner: model_path=%s, tokenizer_path=%s",
48*523fa7a6SAndroid Build Coastguard Worker model_path.c_str(),
49*523fa7a6SAndroid Build Coastguard Worker tokenizer_path.c_str());
50*523fa7a6SAndroid Build Coastguard Worker }
51*523fa7a6SAndroid Build Coastguard Worker
generate(const std::string & prompt,std::size_t max_seq_len)52*523fa7a6SAndroid Build Coastguard Worker void Runner::generate(const std::string& prompt, std::size_t max_seq_len) {
53*523fa7a6SAndroid Build Coastguard Worker auto encode_res = tokenizer_->encode(prompt, 0, 0);
54*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
55*523fa7a6SAndroid Build Coastguard Worker encode_res.error() == Error::Ok, "Failed to encode %s", prompt.c_str());
56*523fa7a6SAndroid Build Coastguard Worker auto input_tokens = encode_res.get();
57*523fa7a6SAndroid Build Coastguard Worker auto prev_token = input_tokens.back();
58*523fa7a6SAndroid Build Coastguard Worker auto current_token = prefill(input_tokens);
59*523fa7a6SAndroid Build Coastguard Worker std::cout << tokenizer_->decode(prev_token, current_token).get();
60*523fa7a6SAndroid Build Coastguard Worker std::cout.flush();
61*523fa7a6SAndroid Build Coastguard Worker
62*523fa7a6SAndroid Build Coastguard Worker std::size_t seq_len = input_tokens.size() + 1;
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker while (current_token != ENDOFTEXT_TOKEN && seq_len < max_seq_len) {
65*523fa7a6SAndroid Build Coastguard Worker prev_token = current_token;
66*523fa7a6SAndroid Build Coastguard Worker current_token = run_model_step(current_token);
67*523fa7a6SAndroid Build Coastguard Worker std::cout << tokenizer_->decode(prev_token, current_token).get();
68*523fa7a6SAndroid Build Coastguard Worker std::cout.flush();
69*523fa7a6SAndroid Build Coastguard Worker
70*523fa7a6SAndroid Build Coastguard Worker ++seq_len;
71*523fa7a6SAndroid Build Coastguard Worker }
72*523fa7a6SAndroid Build Coastguard Worker
73*523fa7a6SAndroid Build Coastguard Worker std::cout << std::endl;
74*523fa7a6SAndroid Build Coastguard Worker }
75*523fa7a6SAndroid Build Coastguard Worker
logits_to_token(const exec_aten::Tensor & logits_tensor)76*523fa7a6SAndroid Build Coastguard Worker uint64_t Runner::logits_to_token(const exec_aten::Tensor& logits_tensor) {
77*523fa7a6SAndroid Build Coastguard Worker return sampler_->sample(logits_tensor.data_ptr<float>());
78*523fa7a6SAndroid Build Coastguard Worker }
79*523fa7a6SAndroid Build Coastguard Worker
prefill(std::vector<uint64_t> & tokens)80*523fa7a6SAndroid Build Coastguard Worker uint64_t Runner::prefill(std::vector<uint64_t>& tokens) {
81*523fa7a6SAndroid Build Coastguard Worker auto result = module_->forward(executorch::extension::from_blob(
82*523fa7a6SAndroid Build Coastguard Worker tokens.data(),
83*523fa7a6SAndroid Build Coastguard Worker {1, static_cast<exec_aten::SizesType>(tokens.size())},
84*523fa7a6SAndroid Build Coastguard Worker ScalarType::Long));
85*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(result.error() == Error::Ok, "Failed to prefill tokens");
86*523fa7a6SAndroid Build Coastguard Worker
87*523fa7a6SAndroid Build Coastguard Worker return logits_to_token(result.get()[0].toTensor());
88*523fa7a6SAndroid Build Coastguard Worker }
89*523fa7a6SAndroid Build Coastguard Worker
run_model_step(uint64_t token)90*523fa7a6SAndroid Build Coastguard Worker uint64_t Runner::run_model_step(uint64_t token) {
91*523fa7a6SAndroid Build Coastguard Worker auto result = module_->forward(
92*523fa7a6SAndroid Build Coastguard Worker executorch::extension::from_blob(&token, {1, 1}, ScalarType::Long));
93*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
94*523fa7a6SAndroid Build Coastguard Worker result.error() == Error::Ok,
95*523fa7a6SAndroid Build Coastguard Worker "Failed to run forward() for token %" PRIu64,
96*523fa7a6SAndroid Build Coastguard Worker token);
97*523fa7a6SAndroid Build Coastguard Worker
98*523fa7a6SAndroid Build Coastguard Worker return logits_to_token(result.get()[0].toTensor());
99*523fa7a6SAndroid Build Coastguard Worker }
100*523fa7a6SAndroid Build Coastguard Worker
101*523fa7a6SAndroid Build Coastguard Worker } // namespace example
102