1 /*
2 * Copyright (c) 2024 MediaTek Inc.
3 *
4 * Licensed under the BSD License (the "License"); you may not use this file
5 * except in compliance with the License. See the license file in the root
6 * directory of this source tree for more details.
7 */
8
9 #pragma once
10
11 #include "llm_helper/include/llm_types.h"
12
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <functional>
#include <regex>
#include <sstream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
20
21 namespace example {
22 namespace utils {
23
// Measures the time elapsed between Start() and End() and reports it, in
// seconds, to a user-supplied callback.
//
// Usage:
//   Timer t([](double sec) { /* record sec */ });
//   t.Start();
//   /* ... work ... */
//   t.End();  // invokes the callback with the elapsed seconds
//
// End() may be called repeatedly; each call reports the time elapsed since the
// most recent Start(). Calling End() without a prior Start() measures from a
// default-constructed time point, so the reported value is meaningless.
// If the callback is empty, End() throws std::bad_function_call.
class Timer {
 public:
  explicit Timer(std::function<void(double)> callback)
      : mCallback(std::move(callback)) {}

  // Record the current time as the start of the measured interval.
  void Start() {
    mTimeStart = Clock::now();
  }

  // Compute the seconds elapsed since Start() and pass them to the callback.
  void End() {
    // duration<double> defaults to a period of seconds; converting to a
    // floating-point duration is implicit and does not truncate.
    const std::chrono::duration<double> elapsed = Clock::now() - mTimeStart;
    mCallback(elapsed.count());
  }

 private:
  // steady_clock is monotonic, so a measured interval cannot go negative or
  // jump if the system time is adjusted (high_resolution_clock may be an alias
  // of system_clock on some standard libraries).
  using Clock = std::chrono::steady_clock;

  Clock::time_point mTimeStart;
  std::function<void(double)> mCallback;
};
46
// Split `str` into the maximal runs of characters that are not `sep`.
// Empty tokens are dropped: leading, trailing and consecutive separators do
// not produce empty strings, so split("a,,b,", ',') == {"a", "b"} and an input
// consisting only of separators (or an empty input) yields an empty vector.
// Any character may be used as `sep`, including regex metacharacters such as
// ']' or '\\', because no pattern matching is involved.
static std::vector<std::string> split(const std::string& str, const char sep) {
  std::vector<std::string> tokens;
  std::string::size_type start = 0;
  while (start < str.size()) {
    const auto found = str.find(sep, start);
    const auto stop = (found == std::string::npos) ? str.size() : found;
    if (stop > start) {
      tokens.emplace_back(str, start, stop - start);
    }
    // Skip past the separator (or past the end, which terminates the loop).
    start = stop + 1;
  }
  return tokens;
}
61
// Read the whole file at `filepath` and return its contents as a string.
// The file is opened in text mode. If it cannot be opened (or is empty) the
// result is an empty string; callers cannot distinguish those two cases.
static std::string read_file(const std::string& filepath) {
  std::ostringstream contents;
  if (std::ifstream input{filepath}) {
    contents << input.rdbuf();
  }
  return contents.str();
}
68
// Return the index of the largest of the first `vocab_size` values of type
// `LogitsType` stored contiguously at `logits_buffer`.
// Ties resolve to the lowest index, because only a strictly greater value
// replaces the current maximum. When `vocab_size` is 0 the buffer is not read
// and 0 is returned.
template <typename LogitsType>
static uint64_t argmax(const void* logits_buffer, const size_t vocab_size) {
  if (vocab_size == 0) {
    // Nothing to compare; avoid reading logits[0] out of bounds.
    return 0;
  }
  const auto* logits = static_cast<const LogitsType*>(logits_buffer);
  LogitsType max = logits[0];
  uint64_t index = 0;
  for (size_t i = 1; i < vocab_size; ++i) {
    if (logits[i] > max) {
      max = logits[i];
      index = i;
    }
  }
  return index;
}
82
argmax(const llm_helper::LLMType logitsType,const void * logits_buffer,const size_t vocab_size)83 static uint64_t argmax(
84 const llm_helper::LLMType logitsType,
85 const void* logits_buffer,
86 const size_t vocab_size) {
87 switch (logitsType) {
88 case llm_helper::LLMType::INT16:
89 return argmax<int16_t>(logits_buffer, vocab_size);
90 case llm_helper::LLMType::FP16:
91 return argmax<__fp16>(logits_buffer, vocab_size);
92 case llm_helper::LLMType::FP32:
93 return argmax<float>(logits_buffer, vocab_size);
94 default:
95 ET_LOG(
96 Error,
97 "Unsupported logits type for argmax: %s",
98 getLLMTypeName(logitsType));
99 return 0;
100 }
101 }
102
// Format `vec` as "{e0, e1, ..., eN}" using each element's operator<<.
// An empty vector is rendered as "{}".
template <typename T>
static std::string to_string(const std::vector<T>& vec) {
  std::ostringstream ss;
  ss << '{';
  // The first element gets no leading delimiter; every later one gets ", ".
  const char* delim = "";
  for (const auto& item : vec) {
    ss << delim << item;
    delim = ", ";
  }
  ss << '}';
  return ss.str();
}
114
115 } // namespace utils
116 } // namespace example
117