/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
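
// Minimal command-line driver for the Llama example: parses flags with
// gflags, optionally sizes the CPU threadpool, and generates text from a
// prompt using example::Runner.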

#include <gflags/gflags.h>

#include <executorch/examples/models/llama/runner/runner.h>

#if defined(ET_USE_THREADPOOL)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif

DEFINE_string(
    model_path,
    "llama2.pte",
    "Path to the model serialized in flatbuffer (.pte) format.");

DEFINE_string(tokenizer_path, "tokenizer.bin", "Path to the tokenizer file.");

DEFINE_string(
    prompt,
    "The answer to the ultimate question is",
    "Input prompt to generate from.");

DEFINE_double(
    temperature,
    0.8,
    "Sampling temperature. Default is 0.8. 0 = greedy argmax sampling "
    "(deterministic); lower values make the output more deterministic.");

DEFINE_int32(
    seq_len,
    128,
    "Total number of tokens to generate (prompt + output). If the number of "
    "prompt tokens plus seq_len exceeds the model's max_seq_len, the output "
    "is truncated to max_seq_len tokens.");

DEFINE_int32(
    cpu_threads,
    -1,
    "Number of CPU threads for inference. Defaults to -1, which means a "
    "heuristic is used to derive the number of performant cores for the "
    "specific device.");

DEFINE_bool(warmup, false, "Whether to run a warmup pass before generating.");

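// Example invocation (the binary name depends on your build target, so
// "llama_main" below is just a placeholder):
//
//   ./llama_main --model_path=llama2.pte --tokenizer_path=tokenizer.bin \
//       --prompt="The answer to the ultimate question is" --seq_len=128 \
//       --cpu_threads=4 --warmup=true
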
int32_t main(int32_t argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  // Create a loader to get the data of the program file. There are other
  // DataLoaders that use mmap() or point to data that's already in memory,
  // and users can create their own DataLoaders to load from arbitrary sources.
  const char* model_path = FLAGS_model_path.c_str();

  const char* tokenizer_path = FLAGS_tokenizer_path.c_str();

  const char* prompt = FLAGS_prompt.c_str();

  double temperature = FLAGS_temperature;

  int32_t seq_len = FLAGS_seq_len;

  int32_t cpu_threads = FLAGS_cpu_threads;

  bool warmup = FLAGS_warmup;

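  // When built with ET_USE_THREADPOOL, size the threadpool to either the
  // user-specified thread count or a cpuinfo-based estimate of the number of
  // performant cores.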
#if defined(ET_USE_THREADPOOL)
  uint32_t num_performant_cores = cpu_threads == -1
      ? ::executorch::extension::cpuinfo::get_num_performant_cores()
      : static_cast<uint32_t>(cpu_threads);
  ET_LOG(
      Info, "Resetting threadpool with num threads = %u", num_performant_cores);
  if (num_performant_cores > 0) {
    ::executorch::extension::threadpool::get_threadpool()
        ->_unsafe_reset_threadpool(num_performant_cores);
  }
#endif
  // Create the Llama runner.
  example::Runner runner(model_path, tokenizer_path, temperature);

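  // Optional warmup pass (enabled with --warmup) before the main generation.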
  if (warmup) {
    runner.warmup(prompt, seq_len);
  }
  // Generate up to seq_len tokens from the prompt.
  runner.generate(prompt, seq_len);

  return 0;
}