xref: /aosp_15_r20/external/executorch/examples/llm_manual/main.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8*523fa7a6SAndroid Build Coastguard Worker 
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "basic_sampler.h"
#include "basic_tokenizer.h"

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::extension::from_blob;
using executorch::extension::Module;
using executorch::runtime::EValue;
using executorch::runtime::Result;
26*523fa7a6SAndroid Build Coastguard Worker 
// The value of the gpt2 `<|endoftext|>` token. Declared as a typed constant
// (rather than a preprocessor macro) so that it has the same type as the
// token ids it is compared against and obeys normal scoping rules.
constexpr int64_t ENDOFTEXT_TOKEN = 50256;
29*523fa7a6SAndroid Build Coastguard Worker 
generate(Module & llm_model,std::string & prompt,BasicTokenizer & tokenizer,BasicSampler & sampler,size_t max_input_length,size_t max_output_length)30*523fa7a6SAndroid Build Coastguard Worker std::string generate(
31*523fa7a6SAndroid Build Coastguard Worker     Module& llm_model,
32*523fa7a6SAndroid Build Coastguard Worker     std::string& prompt,
33*523fa7a6SAndroid Build Coastguard Worker     BasicTokenizer& tokenizer,
34*523fa7a6SAndroid Build Coastguard Worker     BasicSampler& sampler,
35*523fa7a6SAndroid Build Coastguard Worker     size_t max_input_length,
36*523fa7a6SAndroid Build Coastguard Worker     size_t max_output_length) {
37*523fa7a6SAndroid Build Coastguard Worker   // Convert the input text into a list of integers (tokens) that represents it,
38*523fa7a6SAndroid Build Coastguard Worker   // using the string-to-token mapping that the model was trained on. Each token
39*523fa7a6SAndroid Build Coastguard Worker   // is an integer that represents a word or part of a word.
40*523fa7a6SAndroid Build Coastguard Worker   std::vector<int64_t> input_tokens = tokenizer.encode(prompt);
41*523fa7a6SAndroid Build Coastguard Worker   std::vector<int64_t> output_tokens;
42*523fa7a6SAndroid Build Coastguard Worker 
43*523fa7a6SAndroid Build Coastguard Worker   for (auto i = 0u; i < max_output_length; i++) {
44*523fa7a6SAndroid Build Coastguard Worker     // Convert the input_tokens from a vector of int64_t to EValue. EValue is a
45*523fa7a6SAndroid Build Coastguard Worker     // unified data type in the ExecuTorch runtime.
46*523fa7a6SAndroid Build Coastguard Worker     auto inputs = from_blob(
47*523fa7a6SAndroid Build Coastguard Worker         input_tokens.data(),
48*523fa7a6SAndroid Build Coastguard Worker         {1, static_cast<int>(input_tokens.size())},
49*523fa7a6SAndroid Build Coastguard Worker         ScalarType::Long);
50*523fa7a6SAndroid Build Coastguard Worker 
51*523fa7a6SAndroid Build Coastguard Worker     // Run the model. It will return a tensor of logits (log-probabilities).
52*523fa7a6SAndroid Build Coastguard Worker     auto logits_evalue = llm_model.forward(inputs);
53*523fa7a6SAndroid Build Coastguard Worker 
54*523fa7a6SAndroid Build Coastguard Worker     // Convert the output logits from EValue to std::vector, which is what the
55*523fa7a6SAndroid Build Coastguard Worker     // sampler expects.
56*523fa7a6SAndroid Build Coastguard Worker     Tensor logits_tensor = logits_evalue.get()[0].toTensor();
57*523fa7a6SAndroid Build Coastguard Worker     std::vector<float> logits(
58*523fa7a6SAndroid Build Coastguard Worker         logits_tensor.data_ptr<float>(),
59*523fa7a6SAndroid Build Coastguard Worker         logits_tensor.data_ptr<float>() + logits_tensor.numel());
60*523fa7a6SAndroid Build Coastguard Worker 
61*523fa7a6SAndroid Build Coastguard Worker     // Sample the next token from the logits.
62*523fa7a6SAndroid Build Coastguard Worker     int64_t next_token = sampler.sample(logits);
63*523fa7a6SAndroid Build Coastguard Worker 
64*523fa7a6SAndroid Build Coastguard Worker     // Break if we reached the end of the text.
65*523fa7a6SAndroid Build Coastguard Worker     if (next_token == ENDOFTEXT_TOKEN) {
66*523fa7a6SAndroid Build Coastguard Worker       break;
67*523fa7a6SAndroid Build Coastguard Worker     }
68*523fa7a6SAndroid Build Coastguard Worker 
69*523fa7a6SAndroid Build Coastguard Worker     // Add the next token to the output.
70*523fa7a6SAndroid Build Coastguard Worker     output_tokens.push_back(next_token);
71*523fa7a6SAndroid Build Coastguard Worker 
72*523fa7a6SAndroid Build Coastguard Worker     std::cout << tokenizer.decode({next_token});
73*523fa7a6SAndroid Build Coastguard Worker     std::cout.flush();
74*523fa7a6SAndroid Build Coastguard Worker 
75*523fa7a6SAndroid Build Coastguard Worker     // Update next input.
76*523fa7a6SAndroid Build Coastguard Worker     input_tokens.push_back(next_token);
77*523fa7a6SAndroid Build Coastguard Worker     if (input_tokens.size() > max_input_length) {
78*523fa7a6SAndroid Build Coastguard Worker       input_tokens.erase(input_tokens.begin());
79*523fa7a6SAndroid Build Coastguard Worker     }
80*523fa7a6SAndroid Build Coastguard Worker   }
81*523fa7a6SAndroid Build Coastguard Worker 
82*523fa7a6SAndroid Build Coastguard Worker   std::cout << std::endl;
83*523fa7a6SAndroid Build Coastguard Worker 
84*523fa7a6SAndroid Build Coastguard Worker   // Convert the output tokens into a human-readable string.
85*523fa7a6SAndroid Build Coastguard Worker   std::string output_string = tokenizer.decode(output_tokens);
86*523fa7a6SAndroid Build Coastguard Worker   return output_string;
87*523fa7a6SAndroid Build Coastguard Worker }
88*523fa7a6SAndroid Build Coastguard Worker 
main()89*523fa7a6SAndroid Build Coastguard Worker int main() {
90*523fa7a6SAndroid Build Coastguard Worker   // Set up the prompt. This provides the seed text for the model to elaborate.
91*523fa7a6SAndroid Build Coastguard Worker   std::cout << "Enter model prompt: ";
92*523fa7a6SAndroid Build Coastguard Worker   std::string prompt;
93*523fa7a6SAndroid Build Coastguard Worker   std::getline(std::cin, prompt);
94*523fa7a6SAndroid Build Coastguard Worker 
95*523fa7a6SAndroid Build Coastguard Worker   // The tokenizer is used to convert between tokens (used by the model) and
96*523fa7a6SAndroid Build Coastguard Worker   // human-readable strings.
97*523fa7a6SAndroid Build Coastguard Worker   BasicTokenizer tokenizer("vocab.json");
98*523fa7a6SAndroid Build Coastguard Worker 
99*523fa7a6SAndroid Build Coastguard Worker   // The sampler is used to sample the next token from the logits.
100*523fa7a6SAndroid Build Coastguard Worker   BasicSampler sampler = BasicSampler();
101*523fa7a6SAndroid Build Coastguard Worker 
102*523fa7a6SAndroid Build Coastguard Worker   // Load the exported nanoGPT program, which was generated via the previous
103*523fa7a6SAndroid Build Coastguard Worker   // steps.
104*523fa7a6SAndroid Build Coastguard Worker   Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);
105*523fa7a6SAndroid Build Coastguard Worker 
106*523fa7a6SAndroid Build Coastguard Worker   const auto max_input_tokens = 1024;
107*523fa7a6SAndroid Build Coastguard Worker   const auto max_output_tokens = 30;
108*523fa7a6SAndroid Build Coastguard Worker   std::cout << prompt;
109*523fa7a6SAndroid Build Coastguard Worker   generate(
110*523fa7a6SAndroid Build Coastguard Worker       model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens);
111*523fa7a6SAndroid Build Coastguard Worker }
112