xref: /aosp_15_r20/external/executorch/examples/mediatek/executor_runner/llama_runner/Utils.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #pragma once
10 
#include "llm_helper/include/llm_types.h"

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <functional>
#include <regex>
#include <sstream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
20 
21 namespace example {
22 namespace utils {
23 
// Simple stopwatch that reports an elapsed interval to a user callback.
//
// Usage: construct with a callback, call Start() to mark the beginning of the
// interval and End() to mark its end. End() invokes the callback with the
// elapsed time in seconds. Start()/End() may be called repeatedly.
class Timer {
 public:
  // callback: invoked by End() with the elapsed time in seconds. An empty
  // callback is allowed; End() then does nothing besides taking a timestamp.
  explicit Timer(std::function<void(double)> callback)
      : mCallback(std::move(callback)) {}

  // Records the current time as the beginning of the measured interval.
  void Start() {
    mTimeStart = Clock::now();
  }

  // Computes the time elapsed since the last Start() (or since construction
  // if Start() was never called) and passes it, in seconds, to the callback.
  void End() {
    const auto time_end = Clock::now();
    // Keep microsecond granularity before converting to fractional seconds.
    const double elapsed_time_sec =
        std::chrono::duration_cast<std::chrono::microseconds>(
            time_end - mTimeStart)
            .count() /
        1000000.0;
    // Calling an empty std::function throws std::bad_function_call.
    if (mCallback) {
      mCallback(elapsed_time_sec);
    }
  }

 private:
  // A monotonic clock is used so the measured interval cannot be distorted by
  // wall-clock adjustments.
  using Clock = std::chrono::steady_clock;

  // Initialized at construction so End() without Start() is still meaningful.
  Clock::time_point mTimeStart{Clock::now()};
  std::function<void(double)> mCallback;
};
46 
// Splits `str` into tokens delimited by the separator character `sep`.
//
// Tokens are the maximal runs of characters that are not `sep`. Leading,
// trailing and consecutive separators therefore never produce empty tokens,
// and an empty or separator-only input yields an empty vector.
//
// The scan is done directly on the string rather than through a regular
// expression, so every separator value (including ']', '\\' or '^') is
// treated literally.
static std::vector<std::string> split(const std::string& str, const char sep) {
  std::vector<std::string> tokens;
  // Position of the first character of the next token, if any.
  std::string::size_type token_begin = str.find_first_not_of(sep);
  while (token_begin != std::string::npos) {
    const std::string::size_type token_end = str.find(sep, token_begin);
    if (token_end == std::string::npos) {
      // Last token runs to the end of the string.
      tokens.push_back(str.substr(token_begin));
      break;
    }
    tokens.push_back(str.substr(token_begin, token_end - token_begin));
    // Skip the run of separators that follows this token.
    token_begin = str.find_first_not_of(sep, token_end);
  }
  return tokens;
}
61 
// Reads the whole file at `filepath` and returns its contents as a string.
// Returns an empty string when the file cannot be opened.
static std::string read_file(const std::string& filepath) {
  std::ifstream input(filepath);
  if (!input.is_open()) {
    return std::string();
  }
  // Stream the underlying buffer in one shot instead of reading line by line.
  std::ostringstream contents;
  contents << input.rdbuf();
  return contents.str();
}
68 
// Returns the index of the largest element in a buffer of `vocab_size` values
// of type `LogitsType`.
//
// logits_buffer: untyped pointer to the first element; it is interpreted as
//                an array of `LogitsType`.
// vocab_size:    number of elements in the buffer.
//
// When several elements share the maximum value the lowest index wins, since
// only a strictly greater value replaces the current best. Returns 0 when the
// buffer is null or `vocab_size` is 0, so no element is ever read out of
// bounds.
template <typename LogitsType>
static uint64_t argmax(const void* logits_buffer, const size_t vocab_size) {
  if (logits_buffer == nullptr || vocab_size == 0) {
    return 0;
  }
  const auto* logits = static_cast<const LogitsType*>(logits_buffer);
  LogitsType max = logits[0];
  uint64_t index = 0;
  for (size_t i = 1; i < vocab_size; i++) {
    if (logits[i] > max) {
      max = logits[i];
      index = i;
    }
  }
  return index;
}
82 
argmax(const llm_helper::LLMType logitsType,const void * logits_buffer,const size_t vocab_size)83 static uint64_t argmax(
84     const llm_helper::LLMType logitsType,
85     const void* logits_buffer,
86     const size_t vocab_size) {
87   switch (logitsType) {
88     case llm_helper::LLMType::INT16:
89       return argmax<int16_t>(logits_buffer, vocab_size);
90     case llm_helper::LLMType::FP16:
91       return argmax<__fp16>(logits_buffer, vocab_size);
92     case llm_helper::LLMType::FP32:
93       return argmax<float>(logits_buffer, vocab_size);
94     default:
95       ET_LOG(
96           Error,
97           "Unsupported logits type for argmax: %s",
98           getLLMTypeName(logitsType));
99       return 0;
100   }
101 }
102 
// Formats a vector as "{e0, e1, ..., eN}" by streaming each element with
// operator<<. An empty vector is rendered as "{}".
//
// The vector is taken by const reference so that formatting does not copy it.
template <typename T>
static std::string to_string(const std::vector<T>& vec) {
  std::ostringstream ss;
  ss << "{";
  bool is_first = true;
  for (const auto& element : vec) {
    // The separator goes between elements, never before the first one.
    if (!is_first) {
      ss << ", ";
    }
    ss << element;
    is_first = false;
  }
  ss << "}";
  return ss.str();
}
114 
115 } // namespace utils
116 } // namespace example
117