/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Runner stats for LLM
#pragma once
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/platform/log.h>
#include <cinttypes>
#include <sstream>
#include <string>

namespace executorch {
namespace extension {
namespace llm {

21*523fa7a6SAndroid Build Coastguard Worker struct ET_EXPERIMENTAL Stats {
22*523fa7a6SAndroid Build Coastguard Worker // Scaling factor for timestamps - in this case, we use ms.
23*523fa7a6SAndroid Build Coastguard Worker const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
24*523fa7a6SAndroid Build Coastguard Worker // Time stamps for the different stages of the execution
25*523fa7a6SAndroid Build Coastguard Worker // model_load_start_ms: Start of model loading.
26*523fa7a6SAndroid Build Coastguard Worker long model_load_start_ms;
27*523fa7a6SAndroid Build Coastguard Worker // model_load_end_ms: End of model loading.
28*523fa7a6SAndroid Build Coastguard Worker long model_load_end_ms;
29*523fa7a6SAndroid Build Coastguard Worker // inference_start_ms: Immediately after the model is loaded (or we check
30*523fa7a6SAndroid Build Coastguard Worker // for model load), measure the inference time.
31*523fa7a6SAndroid Build Coastguard Worker // NOTE: It's actually the tokenizer encode + model execution time.
32*523fa7a6SAndroid Build Coastguard Worker long inference_start_ms;
33*523fa7a6SAndroid Build Coastguard Worker // End of the tokenizer encode time.
34*523fa7a6SAndroid Build Coastguard Worker long token_encode_end_ms;
35*523fa7a6SAndroid Build Coastguard Worker // Start of the model execution (forward function) time.
36*523fa7a6SAndroid Build Coastguard Worker long model_execution_start_ms;
37*523fa7a6SAndroid Build Coastguard Worker // End of the model execution (forward function) time.
38*523fa7a6SAndroid Build Coastguard Worker long model_execution_end_ms;
39*523fa7a6SAndroid Build Coastguard Worker // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
40*523fa7a6SAndroid Build Coastguard Worker // before the inference loop starts
41*523fa7a6SAndroid Build Coastguard Worker long prompt_eval_end_ms;
42*523fa7a6SAndroid Build Coastguard Worker // first_token: Timestamp when the first generated token is emitted
43*523fa7a6SAndroid Build Coastguard Worker long first_token_ms;
44*523fa7a6SAndroid Build Coastguard Worker // inference_end_ms: End of inference/generation.
45*523fa7a6SAndroid Build Coastguard Worker long inference_end_ms;
46*523fa7a6SAndroid Build Coastguard Worker // Keep a running total of the time spent in sampling.
47*523fa7a6SAndroid Build Coastguard Worker long aggregate_sampling_time_ms;
48*523fa7a6SAndroid Build Coastguard Worker // Token count from prompt
49*523fa7a6SAndroid Build Coastguard Worker int64_t num_prompt_tokens;
50*523fa7a6SAndroid Build Coastguard Worker // Token count from generated (total - prompt)
51*523fa7a6SAndroid Build Coastguard Worker int64_t num_generated_tokens;
on_sampling_beginStats52*523fa7a6SAndroid Build Coastguard Worker inline void on_sampling_begin() {
53*523fa7a6SAndroid Build Coastguard Worker aggregate_sampling_timer_start_timestamp = time_in_ms();
54*523fa7a6SAndroid Build Coastguard Worker }
on_sampling_endStats55*523fa7a6SAndroid Build Coastguard Worker inline void on_sampling_end() {
56*523fa7a6SAndroid Build Coastguard Worker aggregate_sampling_time_ms +=
57*523fa7a6SAndroid Build Coastguard Worker time_in_ms() - aggregate_sampling_timer_start_timestamp;
58*523fa7a6SAndroid Build Coastguard Worker aggregate_sampling_timer_start_timestamp = 0;
59*523fa7a6SAndroid Build Coastguard Worker }
60*523fa7a6SAndroid Build Coastguard Worker
61*523fa7a6SAndroid Build Coastguard Worker void reset(bool all_stats = false) {
62*523fa7a6SAndroid Build Coastguard Worker // Not resetting model_load_start_ms and model_load_end_ms because reset is
63*523fa7a6SAndroid Build Coastguard Worker // typically called after warmup and before running the actual run.
64*523fa7a6SAndroid Build Coastguard Worker // However, we don't load the model again during the actual run after
65*523fa7a6SAndroid Build Coastguard Worker // warmup. So, we don't want to reset these timestamps unless we are
66*523fa7a6SAndroid Build Coastguard Worker // resetting everything.
67*523fa7a6SAndroid Build Coastguard Worker if (all_stats) {
68*523fa7a6SAndroid Build Coastguard Worker model_load_start_ms = 0;
69*523fa7a6SAndroid Build Coastguard Worker model_load_end_ms = 0;
70*523fa7a6SAndroid Build Coastguard Worker }
71*523fa7a6SAndroid Build Coastguard Worker inference_start_ms = 0;
72*523fa7a6SAndroid Build Coastguard Worker prompt_eval_end_ms = 0;
73*523fa7a6SAndroid Build Coastguard Worker first_token_ms = 0;
74*523fa7a6SAndroid Build Coastguard Worker inference_end_ms = 0;
75*523fa7a6SAndroid Build Coastguard Worker aggregate_sampling_time_ms = 0;
76*523fa7a6SAndroid Build Coastguard Worker num_prompt_tokens = 0;
77*523fa7a6SAndroid Build Coastguard Worker num_generated_tokens = 0;
78*523fa7a6SAndroid Build Coastguard Worker aggregate_sampling_timer_start_timestamp = 0;
79*523fa7a6SAndroid Build Coastguard Worker }
80*523fa7a6SAndroid Build Coastguard Worker
81*523fa7a6SAndroid Build Coastguard Worker private:
82*523fa7a6SAndroid Build Coastguard Worker long aggregate_sampling_timer_start_timestamp = 0;
83*523fa7a6SAndroid Build Coastguard Worker };
85*523fa7a6SAndroid Build Coastguard Worker static constexpr auto kTopp = 0.9f;
stats_to_json_string(const Stats & stats)87*523fa7a6SAndroid Build Coastguard Worker inline std::string stats_to_json_string(const Stats& stats) {
88*523fa7a6SAndroid Build Coastguard Worker std::stringstream ss;
89*523fa7a6SAndroid Build Coastguard Worker ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
90*523fa7a6SAndroid Build Coastguard Worker << "\"generated_tokens\":" << stats.num_generated_tokens << ","
91*523fa7a6SAndroid Build Coastguard Worker << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
92*523fa7a6SAndroid Build Coastguard Worker << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
93*523fa7a6SAndroid Build Coastguard Worker << "\"inference_start_ms\":" << stats.inference_start_ms << ","
94*523fa7a6SAndroid Build Coastguard Worker << "\"inference_end_ms\":" << stats.inference_end_ms << ","
95*523fa7a6SAndroid Build Coastguard Worker << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
96*523fa7a6SAndroid Build Coastguard Worker << "\"first_token_ms\":" << stats.first_token_ms << ","
97*523fa7a6SAndroid Build Coastguard Worker << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
98*523fa7a6SAndroid Build Coastguard Worker << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
99*523fa7a6SAndroid Build Coastguard Worker << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
100*523fa7a6SAndroid Build Coastguard Worker return ss.str();
101*523fa7a6SAndroid Build Coastguard Worker }
print_report(const Stats & stats)103*523fa7a6SAndroid Build Coastguard Worker inline void print_report(const Stats& stats) {
104*523fa7a6SAndroid Build Coastguard Worker printf("PyTorchObserver %s\n", stats_to_json_string(stats).c_str());
105*523fa7a6SAndroid Build Coastguard Worker
106*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
107*523fa7a6SAndroid Build Coastguard Worker Info,
108*523fa7a6SAndroid Build Coastguard Worker "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
109*523fa7a6SAndroid Build Coastguard Worker stats.num_prompt_tokens,
110*523fa7a6SAndroid Build Coastguard Worker stats.num_generated_tokens);
111*523fa7a6SAndroid Build Coastguard Worker
112*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
113*523fa7a6SAndroid Build Coastguard Worker Info,
114*523fa7a6SAndroid Build Coastguard Worker "\tModel Load Time:\t\t%f (seconds)",
115*523fa7a6SAndroid Build Coastguard Worker ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
116*523fa7a6SAndroid Build Coastguard Worker stats.SCALING_FACTOR_UNITS_PER_SECOND));
117*523fa7a6SAndroid Build Coastguard Worker double inference_time_ms =
118*523fa7a6SAndroid Build Coastguard Worker (double)(stats.inference_end_ms - stats.inference_start_ms);
119*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
120*523fa7a6SAndroid Build Coastguard Worker Info,
121*523fa7a6SAndroid Build Coastguard Worker "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
122*523fa7a6SAndroid Build Coastguard Worker inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,
123*523fa7a6SAndroid Build Coastguard Worker
124*523fa7a6SAndroid Build Coastguard Worker (stats.num_generated_tokens) /
125*523fa7a6SAndroid Build Coastguard Worker (double)(stats.inference_end_ms - stats.inference_start_ms) *
126*523fa7a6SAndroid Build Coastguard Worker stats.SCALING_FACTOR_UNITS_PER_SECOND);
127*523fa7a6SAndroid Build Coastguard Worker double prompt_eval_time =
128*523fa7a6SAndroid Build Coastguard Worker (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
129*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
130*523fa7a6SAndroid Build Coastguard Worker Info,
131*523fa7a6SAndroid Build Coastguard Worker "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
132*523fa7a6SAndroid Build Coastguard Worker prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
133*523fa7a6SAndroid Build Coastguard Worker (stats.num_prompt_tokens) / prompt_eval_time *
134*523fa7a6SAndroid Build Coastguard Worker stats.SCALING_FACTOR_UNITS_PER_SECOND);
135*523fa7a6SAndroid Build Coastguard Worker
136*523fa7a6SAndroid Build Coastguard Worker double eval_time =
137*523fa7a6SAndroid Build Coastguard Worker (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
138*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
139*523fa7a6SAndroid Build Coastguard Worker Info,
140*523fa7a6SAndroid Build Coastguard Worker "\t\tGenerated %" PRIu64
141*523fa7a6SAndroid Build Coastguard Worker " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
142*523fa7a6SAndroid Build Coastguard Worker stats.num_generated_tokens,
143*523fa7a6SAndroid Build Coastguard Worker eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
144*523fa7a6SAndroid Build Coastguard Worker stats.num_generated_tokens / eval_time *
145*523fa7a6SAndroid Build Coastguard Worker stats.SCALING_FACTOR_UNITS_PER_SECOND);
146*523fa7a6SAndroid Build Coastguard Worker
147*523fa7a6SAndroid Build Coastguard Worker // Time to first token is measured from the start of inference, excluding
148*523fa7a6SAndroid Build Coastguard Worker // model load time.
149*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
150*523fa7a6SAndroid Build Coastguard Worker Info,
151*523fa7a6SAndroid Build Coastguard Worker "\tTime to first generated token:\t%f (seconds)",
152*523fa7a6SAndroid Build Coastguard Worker ((double)(stats.first_token_ms - stats.inference_start_ms) /
153*523fa7a6SAndroid Build Coastguard Worker stats.SCALING_FACTOR_UNITS_PER_SECOND));
154*523fa7a6SAndroid Build Coastguard Worker
155*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
156*523fa7a6SAndroid Build Coastguard Worker Info,
157*523fa7a6SAndroid Build Coastguard Worker "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
158*523fa7a6SAndroid Build Coastguard Worker stats.num_prompt_tokens + stats.num_generated_tokens,
159*523fa7a6SAndroid Build Coastguard Worker (double)stats.aggregate_sampling_time_ms /
160*523fa7a6SAndroid Build Coastguard Worker stats.SCALING_FACTOR_UNITS_PER_SECOND);
161*523fa7a6SAndroid Build Coastguard Worker }

} // namespace llm
} // namespace extension
} // namespace executorch

// Backward-compatibility shim: re-exports the stats utilities under the old
// `::executorch::llm` namespace so pre-rename callers keep compiling.
namespace executorch {
namespace llm {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::kTopp;
using ::executorch::extension::llm::print_report;
using ::executorch::extension::llm::Stats;
} // namespace llm
} // namespace executorch