1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <gflags/gflags.h>
10
11 #include <executorch/examples/models/llama/runner/runner.h>
12
13 #if defined(ET_USE_THREADPOOL)
14 #include <executorch/extension/threadpool/cpuinfo_utils.h>
15 #include <executorch/extension/threadpool/threadpool.h>
16 #endif
17
18 DEFINE_string(
19 model_path,
20 "llama2.pte",
21 "Model serialized in flatbuffer format.");
22
23 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
24
25 DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
26
27 DEFINE_double(
28 temperature,
29 0.8f,
30 "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
31
32 DEFINE_int32(
33 seq_len,
34 128,
35 "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
36
37 DEFINE_int32(
38 cpu_threads,
39 -1,
40 "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
41
42 DEFINE_bool(warmup, false, "Whether to run a warmup run.");
43
main(int32_t argc,char ** argv)44 int32_t main(int32_t argc, char** argv) {
45 gflags::ParseCommandLineFlags(&argc, &argv, true);
46
47 // Create a loader to get the data of the program file. There are other
48 // DataLoaders that use mmap() or point32_t to data that's already in memory,
49 // and users can create their own DataLoaders to load from arbitrary sources.
50 const char* model_path = FLAGS_model_path.c_str();
51
52 const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
53
54 const char* prompt = FLAGS_prompt.c_str();
55
56 double temperature = FLAGS_temperature;
57
58 int32_t seq_len = FLAGS_seq_len;
59
60 int32_t cpu_threads = FLAGS_cpu_threads;
61
62 bool warmup = FLAGS_warmup;
63
64 #if defined(ET_USE_THREADPOOL)
65 uint32_t num_performant_cores = cpu_threads == -1
66 ? ::executorch::extension::cpuinfo::get_num_performant_cores()
67 : static_cast<uint32_t>(cpu_threads);
68 ET_LOG(
69 Info, "Resetting threadpool with num threads = %d", num_performant_cores);
70 if (num_performant_cores > 0) {
71 ::executorch::extension::threadpool::get_threadpool()
72 ->_unsafe_reset_threadpool(num_performant_cores);
73 }
74 #endif
75 // create llama runner
76 example::Runner runner(model_path, tokenizer_path, temperature);
77
78 if (warmup) {
79 runner.warmup(prompt, seq_len);
80 }
81 // generate
82 runner.generate(prompt, seq_len);
83
84 return 0;
85 }
86