1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <cstdint>
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #include "basic_sampler.h"
12*523fa7a6SAndroid Build Coastguard Worker #include "basic_tokenizer.h"
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/module/module.h>
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/tensor/tensor.h>
16*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/evalue.h>
17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/exec_aten.h>
18*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/result.h>
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker using executorch::aten::ScalarType;
21*523fa7a6SAndroid Build Coastguard Worker using executorch::aten::Tensor;
22*523fa7a6SAndroid Build Coastguard Worker using executorch::extension::from_blob;
23*523fa7a6SAndroid Build Coastguard Worker using executorch::extension::Module;
24*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::EValue;
25*523fa7a6SAndroid Build Coastguard Worker using executorch::runtime::Result;
26*523fa7a6SAndroid Build Coastguard Worker
// The value of the gpt2 `<|endoftext|>` token. Declared as a typed constant
// (rather than a preprocessor macro) so that it obeys normal scoping rules and
// compares against sampled int64_t token ids without an implicit conversion.
constexpr int64_t ENDOFTEXT_TOKEN = 50256;
29*523fa7a6SAndroid Build Coastguard Worker
generate(Module & llm_model,std::string & prompt,BasicTokenizer & tokenizer,BasicSampler & sampler,size_t max_input_length,size_t max_output_length)30*523fa7a6SAndroid Build Coastguard Worker std::string generate(
31*523fa7a6SAndroid Build Coastguard Worker Module& llm_model,
32*523fa7a6SAndroid Build Coastguard Worker std::string& prompt,
33*523fa7a6SAndroid Build Coastguard Worker BasicTokenizer& tokenizer,
34*523fa7a6SAndroid Build Coastguard Worker BasicSampler& sampler,
35*523fa7a6SAndroid Build Coastguard Worker size_t max_input_length,
36*523fa7a6SAndroid Build Coastguard Worker size_t max_output_length) {
37*523fa7a6SAndroid Build Coastguard Worker // Convert the input text into a list of integers (tokens) that represents it,
38*523fa7a6SAndroid Build Coastguard Worker // using the string-to-token mapping that the model was trained on. Each token
39*523fa7a6SAndroid Build Coastguard Worker // is an integer that represents a word or part of a word.
40*523fa7a6SAndroid Build Coastguard Worker std::vector<int64_t> input_tokens = tokenizer.encode(prompt);
41*523fa7a6SAndroid Build Coastguard Worker std::vector<int64_t> output_tokens;
42*523fa7a6SAndroid Build Coastguard Worker
43*523fa7a6SAndroid Build Coastguard Worker for (auto i = 0u; i < max_output_length; i++) {
44*523fa7a6SAndroid Build Coastguard Worker // Convert the input_tokens from a vector of int64_t to EValue. EValue is a
45*523fa7a6SAndroid Build Coastguard Worker // unified data type in the ExecuTorch runtime.
46*523fa7a6SAndroid Build Coastguard Worker auto inputs = from_blob(
47*523fa7a6SAndroid Build Coastguard Worker input_tokens.data(),
48*523fa7a6SAndroid Build Coastguard Worker {1, static_cast<int>(input_tokens.size())},
49*523fa7a6SAndroid Build Coastguard Worker ScalarType::Long);
50*523fa7a6SAndroid Build Coastguard Worker
51*523fa7a6SAndroid Build Coastguard Worker // Run the model. It will return a tensor of logits (log-probabilities).
52*523fa7a6SAndroid Build Coastguard Worker auto logits_evalue = llm_model.forward(inputs);
53*523fa7a6SAndroid Build Coastguard Worker
54*523fa7a6SAndroid Build Coastguard Worker // Convert the output logits from EValue to std::vector, which is what the
55*523fa7a6SAndroid Build Coastguard Worker // sampler expects.
56*523fa7a6SAndroid Build Coastguard Worker Tensor logits_tensor = logits_evalue.get()[0].toTensor();
57*523fa7a6SAndroid Build Coastguard Worker std::vector<float> logits(
58*523fa7a6SAndroid Build Coastguard Worker logits_tensor.data_ptr<float>(),
59*523fa7a6SAndroid Build Coastguard Worker logits_tensor.data_ptr<float>() + logits_tensor.numel());
60*523fa7a6SAndroid Build Coastguard Worker
61*523fa7a6SAndroid Build Coastguard Worker // Sample the next token from the logits.
62*523fa7a6SAndroid Build Coastguard Worker int64_t next_token = sampler.sample(logits);
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker // Break if we reached the end of the text.
65*523fa7a6SAndroid Build Coastguard Worker if (next_token == ENDOFTEXT_TOKEN) {
66*523fa7a6SAndroid Build Coastguard Worker break;
67*523fa7a6SAndroid Build Coastguard Worker }
68*523fa7a6SAndroid Build Coastguard Worker
69*523fa7a6SAndroid Build Coastguard Worker // Add the next token to the output.
70*523fa7a6SAndroid Build Coastguard Worker output_tokens.push_back(next_token);
71*523fa7a6SAndroid Build Coastguard Worker
72*523fa7a6SAndroid Build Coastguard Worker std::cout << tokenizer.decode({next_token});
73*523fa7a6SAndroid Build Coastguard Worker std::cout.flush();
74*523fa7a6SAndroid Build Coastguard Worker
75*523fa7a6SAndroid Build Coastguard Worker // Update next input.
76*523fa7a6SAndroid Build Coastguard Worker input_tokens.push_back(next_token);
77*523fa7a6SAndroid Build Coastguard Worker if (input_tokens.size() > max_input_length) {
78*523fa7a6SAndroid Build Coastguard Worker input_tokens.erase(input_tokens.begin());
79*523fa7a6SAndroid Build Coastguard Worker }
80*523fa7a6SAndroid Build Coastguard Worker }
81*523fa7a6SAndroid Build Coastguard Worker
82*523fa7a6SAndroid Build Coastguard Worker std::cout << std::endl;
83*523fa7a6SAndroid Build Coastguard Worker
84*523fa7a6SAndroid Build Coastguard Worker // Convert the output tokens into a human-readable string.
85*523fa7a6SAndroid Build Coastguard Worker std::string output_string = tokenizer.decode(output_tokens);
86*523fa7a6SAndroid Build Coastguard Worker return output_string;
87*523fa7a6SAndroid Build Coastguard Worker }
88*523fa7a6SAndroid Build Coastguard Worker
main()89*523fa7a6SAndroid Build Coastguard Worker int main() {
90*523fa7a6SAndroid Build Coastguard Worker // Set up the prompt. This provides the seed text for the model to elaborate.
91*523fa7a6SAndroid Build Coastguard Worker std::cout << "Enter model prompt: ";
92*523fa7a6SAndroid Build Coastguard Worker std::string prompt;
93*523fa7a6SAndroid Build Coastguard Worker std::getline(std::cin, prompt);
94*523fa7a6SAndroid Build Coastguard Worker
95*523fa7a6SAndroid Build Coastguard Worker // The tokenizer is used to convert between tokens (used by the model) and
96*523fa7a6SAndroid Build Coastguard Worker // human-readable strings.
97*523fa7a6SAndroid Build Coastguard Worker BasicTokenizer tokenizer("vocab.json");
98*523fa7a6SAndroid Build Coastguard Worker
99*523fa7a6SAndroid Build Coastguard Worker // The sampler is used to sample the next token from the logits.
100*523fa7a6SAndroid Build Coastguard Worker BasicSampler sampler = BasicSampler();
101*523fa7a6SAndroid Build Coastguard Worker
102*523fa7a6SAndroid Build Coastguard Worker // Load the exported nanoGPT program, which was generated via the previous
103*523fa7a6SAndroid Build Coastguard Worker // steps.
104*523fa7a6SAndroid Build Coastguard Worker Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);
105*523fa7a6SAndroid Build Coastguard Worker
106*523fa7a6SAndroid Build Coastguard Worker const auto max_input_tokens = 1024;
107*523fa7a6SAndroid Build Coastguard Worker const auto max_output_tokens = 30;
108*523fa7a6SAndroid Build Coastguard Worker std::cout << prompt;
109*523fa7a6SAndroid Build Coastguard Worker generate(
110*523fa7a6SAndroid Build Coastguard Worker model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens);
111*523fa7a6SAndroid Build Coastguard Worker }
112