xref: /aosp_15_r20/external/executorch/extension/llm/runner/stats.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker // Runner stats for LLM
10*523fa7a6SAndroid Build Coastguard Worker #pragma once
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/platform/log.h>
#include <cinttypes>
#include <cstdio>
#include <sstream>
#include <string>
16*523fa7a6SAndroid Build Coastguard Worker 
17*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
18*523fa7a6SAndroid Build Coastguard Worker namespace extension {
19*523fa7a6SAndroid Build Coastguard Worker namespace llm {
20*523fa7a6SAndroid Build Coastguard Worker 
21*523fa7a6SAndroid Build Coastguard Worker struct ET_EXPERIMENTAL Stats {
22*523fa7a6SAndroid Build Coastguard Worker   // Scaling factor for timestamps - in this case, we use ms.
23*523fa7a6SAndroid Build Coastguard Worker   const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
24*523fa7a6SAndroid Build Coastguard Worker   // Time stamps for the different stages of the execution
25*523fa7a6SAndroid Build Coastguard Worker   // model_load_start_ms: Start of model loading.
26*523fa7a6SAndroid Build Coastguard Worker   long model_load_start_ms;
27*523fa7a6SAndroid Build Coastguard Worker   // model_load_end_ms: End of model loading.
28*523fa7a6SAndroid Build Coastguard Worker   long model_load_end_ms;
29*523fa7a6SAndroid Build Coastguard Worker   // inference_start_ms: Immediately after the model is loaded (or we check
30*523fa7a6SAndroid Build Coastguard Worker   // for model load), measure the inference time.
31*523fa7a6SAndroid Build Coastguard Worker   // NOTE: It's actually the tokenizer encode + model execution time.
32*523fa7a6SAndroid Build Coastguard Worker   long inference_start_ms;
33*523fa7a6SAndroid Build Coastguard Worker   // End of the tokenizer encode time.
34*523fa7a6SAndroid Build Coastguard Worker   long token_encode_end_ms;
35*523fa7a6SAndroid Build Coastguard Worker   // Start of the model execution (forward function) time.
36*523fa7a6SAndroid Build Coastguard Worker   long model_execution_start_ms;
37*523fa7a6SAndroid Build Coastguard Worker   // End of the model execution (forward function) time.
38*523fa7a6SAndroid Build Coastguard Worker   long model_execution_end_ms;
39*523fa7a6SAndroid Build Coastguard Worker   // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
40*523fa7a6SAndroid Build Coastguard Worker   // before the inference loop starts
41*523fa7a6SAndroid Build Coastguard Worker   long prompt_eval_end_ms;
42*523fa7a6SAndroid Build Coastguard Worker   // first_token: Timestamp when the first generated token is emitted
43*523fa7a6SAndroid Build Coastguard Worker   long first_token_ms;
44*523fa7a6SAndroid Build Coastguard Worker   // inference_end_ms: End of inference/generation.
45*523fa7a6SAndroid Build Coastguard Worker   long inference_end_ms;
46*523fa7a6SAndroid Build Coastguard Worker   // Keep a running total of the time spent in sampling.
47*523fa7a6SAndroid Build Coastguard Worker   long aggregate_sampling_time_ms;
48*523fa7a6SAndroid Build Coastguard Worker   // Token count from prompt
49*523fa7a6SAndroid Build Coastguard Worker   int64_t num_prompt_tokens;
50*523fa7a6SAndroid Build Coastguard Worker   // Token count from generated (total - prompt)
51*523fa7a6SAndroid Build Coastguard Worker   int64_t num_generated_tokens;
on_sampling_beginStats52*523fa7a6SAndroid Build Coastguard Worker   inline void on_sampling_begin() {
53*523fa7a6SAndroid Build Coastguard Worker     aggregate_sampling_timer_start_timestamp = time_in_ms();
54*523fa7a6SAndroid Build Coastguard Worker   }
on_sampling_endStats55*523fa7a6SAndroid Build Coastguard Worker   inline void on_sampling_end() {
56*523fa7a6SAndroid Build Coastguard Worker     aggregate_sampling_time_ms +=
57*523fa7a6SAndroid Build Coastguard Worker         time_in_ms() - aggregate_sampling_timer_start_timestamp;
58*523fa7a6SAndroid Build Coastguard Worker     aggregate_sampling_timer_start_timestamp = 0;
59*523fa7a6SAndroid Build Coastguard Worker   }
60*523fa7a6SAndroid Build Coastguard Worker 
61*523fa7a6SAndroid Build Coastguard Worker   void reset(bool all_stats = false) {
62*523fa7a6SAndroid Build Coastguard Worker     // Not resetting model_load_start_ms and model_load_end_ms because reset is
63*523fa7a6SAndroid Build Coastguard Worker     // typically called after warmup and before running the actual run.
64*523fa7a6SAndroid Build Coastguard Worker     // However, we don't load the model again during the actual run after
65*523fa7a6SAndroid Build Coastguard Worker     // warmup. So, we don't want to reset these timestamps unless we are
66*523fa7a6SAndroid Build Coastguard Worker     // resetting everything.
67*523fa7a6SAndroid Build Coastguard Worker     if (all_stats) {
68*523fa7a6SAndroid Build Coastguard Worker       model_load_start_ms = 0;
69*523fa7a6SAndroid Build Coastguard Worker       model_load_end_ms = 0;
70*523fa7a6SAndroid Build Coastguard Worker     }
71*523fa7a6SAndroid Build Coastguard Worker     inference_start_ms = 0;
72*523fa7a6SAndroid Build Coastguard Worker     prompt_eval_end_ms = 0;
73*523fa7a6SAndroid Build Coastguard Worker     first_token_ms = 0;
74*523fa7a6SAndroid Build Coastguard Worker     inference_end_ms = 0;
75*523fa7a6SAndroid Build Coastguard Worker     aggregate_sampling_time_ms = 0;
76*523fa7a6SAndroid Build Coastguard Worker     num_prompt_tokens = 0;
77*523fa7a6SAndroid Build Coastguard Worker     num_generated_tokens = 0;
78*523fa7a6SAndroid Build Coastguard Worker     aggregate_sampling_timer_start_timestamp = 0;
79*523fa7a6SAndroid Build Coastguard Worker   }
80*523fa7a6SAndroid Build Coastguard Worker 
81*523fa7a6SAndroid Build Coastguard Worker  private:
82*523fa7a6SAndroid Build Coastguard Worker   long aggregate_sampling_timer_start_timestamp = 0;
83*523fa7a6SAndroid Build Coastguard Worker };
84*523fa7a6SAndroid Build Coastguard Worker 
// Default top-p value exposed to runners. Presumably the nucleus-sampling
// threshold passed to the sampler — TODO confirm against its users.
static constexpr float kTopp = 0.9f;
86*523fa7a6SAndroid Build Coastguard Worker 
stats_to_json_string(const Stats & stats)87*523fa7a6SAndroid Build Coastguard Worker inline std::string stats_to_json_string(const Stats& stats) {
88*523fa7a6SAndroid Build Coastguard Worker   std::stringstream ss;
89*523fa7a6SAndroid Build Coastguard Worker   ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
90*523fa7a6SAndroid Build Coastguard Worker      << "\"generated_tokens\":" << stats.num_generated_tokens << ","
91*523fa7a6SAndroid Build Coastguard Worker      << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
92*523fa7a6SAndroid Build Coastguard Worker      << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
93*523fa7a6SAndroid Build Coastguard Worker      << "\"inference_start_ms\":" << stats.inference_start_ms << ","
94*523fa7a6SAndroid Build Coastguard Worker      << "\"inference_end_ms\":" << stats.inference_end_ms << ","
95*523fa7a6SAndroid Build Coastguard Worker      << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
96*523fa7a6SAndroid Build Coastguard Worker      << "\"first_token_ms\":" << stats.first_token_ms << ","
97*523fa7a6SAndroid Build Coastguard Worker      << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
98*523fa7a6SAndroid Build Coastguard Worker      << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
99*523fa7a6SAndroid Build Coastguard Worker      << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
100*523fa7a6SAndroid Build Coastguard Worker   return ss.str();
101*523fa7a6SAndroid Build Coastguard Worker }
102*523fa7a6SAndroid Build Coastguard Worker 
print_report(const Stats & stats)103*523fa7a6SAndroid Build Coastguard Worker inline void print_report(const Stats& stats) {
104*523fa7a6SAndroid Build Coastguard Worker   printf("PyTorchObserver %s\n", stats_to_json_string(stats).c_str());
105*523fa7a6SAndroid Build Coastguard Worker 
106*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
107*523fa7a6SAndroid Build Coastguard Worker       Info,
108*523fa7a6SAndroid Build Coastguard Worker       "\tPrompt Tokens: %" PRIu64 "    Generated Tokens: %" PRIu64,
109*523fa7a6SAndroid Build Coastguard Worker       stats.num_prompt_tokens,
110*523fa7a6SAndroid Build Coastguard Worker       stats.num_generated_tokens);
111*523fa7a6SAndroid Build Coastguard Worker 
112*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
113*523fa7a6SAndroid Build Coastguard Worker       Info,
114*523fa7a6SAndroid Build Coastguard Worker       "\tModel Load Time:\t\t%f (seconds)",
115*523fa7a6SAndroid Build Coastguard Worker       ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
116*523fa7a6SAndroid Build Coastguard Worker        stats.SCALING_FACTOR_UNITS_PER_SECOND));
117*523fa7a6SAndroid Build Coastguard Worker   double inference_time_ms =
118*523fa7a6SAndroid Build Coastguard Worker       (double)(stats.inference_end_ms - stats.inference_start_ms);
119*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
120*523fa7a6SAndroid Build Coastguard Worker       Info,
121*523fa7a6SAndroid Build Coastguard Worker       "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
122*523fa7a6SAndroid Build Coastguard Worker       inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,
123*523fa7a6SAndroid Build Coastguard Worker 
124*523fa7a6SAndroid Build Coastguard Worker       (stats.num_generated_tokens) /
125*523fa7a6SAndroid Build Coastguard Worker           (double)(stats.inference_end_ms - stats.inference_start_ms) *
126*523fa7a6SAndroid Build Coastguard Worker           stats.SCALING_FACTOR_UNITS_PER_SECOND);
127*523fa7a6SAndroid Build Coastguard Worker   double prompt_eval_time =
128*523fa7a6SAndroid Build Coastguard Worker       (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
129*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
130*523fa7a6SAndroid Build Coastguard Worker       Info,
131*523fa7a6SAndroid Build Coastguard Worker       "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
132*523fa7a6SAndroid Build Coastguard Worker       prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
133*523fa7a6SAndroid Build Coastguard Worker       (stats.num_prompt_tokens) / prompt_eval_time *
134*523fa7a6SAndroid Build Coastguard Worker           stats.SCALING_FACTOR_UNITS_PER_SECOND);
135*523fa7a6SAndroid Build Coastguard Worker 
136*523fa7a6SAndroid Build Coastguard Worker   double eval_time =
137*523fa7a6SAndroid Build Coastguard Worker       (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
138*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
139*523fa7a6SAndroid Build Coastguard Worker       Info,
140*523fa7a6SAndroid Build Coastguard Worker       "\t\tGenerated %" PRIu64
141*523fa7a6SAndroid Build Coastguard Worker       " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
142*523fa7a6SAndroid Build Coastguard Worker       stats.num_generated_tokens,
143*523fa7a6SAndroid Build Coastguard Worker       eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
144*523fa7a6SAndroid Build Coastguard Worker       stats.num_generated_tokens / eval_time *
145*523fa7a6SAndroid Build Coastguard Worker           stats.SCALING_FACTOR_UNITS_PER_SECOND);
146*523fa7a6SAndroid Build Coastguard Worker 
147*523fa7a6SAndroid Build Coastguard Worker   // Time to first token is measured from the start of inference, excluding
148*523fa7a6SAndroid Build Coastguard Worker   // model load time.
149*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
150*523fa7a6SAndroid Build Coastguard Worker       Info,
151*523fa7a6SAndroid Build Coastguard Worker       "\tTime to first generated token:\t%f (seconds)",
152*523fa7a6SAndroid Build Coastguard Worker       ((double)(stats.first_token_ms - stats.inference_start_ms) /
153*523fa7a6SAndroid Build Coastguard Worker        stats.SCALING_FACTOR_UNITS_PER_SECOND));
154*523fa7a6SAndroid Build Coastguard Worker 
155*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
156*523fa7a6SAndroid Build Coastguard Worker       Info,
157*523fa7a6SAndroid Build Coastguard Worker       "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
158*523fa7a6SAndroid Build Coastguard Worker       stats.num_prompt_tokens + stats.num_generated_tokens,
159*523fa7a6SAndroid Build Coastguard Worker       (double)stats.aggregate_sampling_time_ms /
160*523fa7a6SAndroid Build Coastguard Worker           stats.SCALING_FACTOR_UNITS_PER_SECOND);
161*523fa7a6SAndroid Build Coastguard Worker }
162*523fa7a6SAndroid Build Coastguard Worker 
163*523fa7a6SAndroid Build Coastguard Worker } // namespace llm
164*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
165*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
166*523fa7a6SAndroid Build Coastguard Worker 
167*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
168*523fa7a6SAndroid Build Coastguard Worker namespace llm {
169*523fa7a6SAndroid Build Coastguard Worker // TODO(T197294990): Remove these deprecated aliases once all users have moved
170*523fa7a6SAndroid Build Coastguard Worker // to the new `::executorch` namespaces.
171*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::kTopp;
172*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::print_report;
173*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::Stats;
174*523fa7a6SAndroid Build Coastguard Worker } // namespace llm
175*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
176