1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 // This is a modified version of https://github.com/karpathy/llama2.c.git
10 // @lint-ignore-every LICENSELINT
11 /**
12 * MIT License
13 *
14 * Copyright (c) 2023 Andrej
15 *
16 * Permission is hereby granted, free of charge, to any person obtaining a copy
17 * of this software and associated documentation files (the "Software"), to deal
18 * in the Software without restriction, including without limitation the rights
19 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20 * copies of the Software, and to permit persons to whom the Software is
21 * furnished to do so, subject to the following conditions:
22 *
23 * The above copyright notice and this permission notice shall be included in
24 * all copies or substantial portions of the Software.
25 *
26 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32 * SOFTWARE.
33 */
34
35 #include <executorch/extension/llm/sampler/sampler.h>
36 #include <algorithm>
37
38 namespace executorch {
39 namespace extension {
40 namespace llm {
41
42 // sampler stuff
43 template <typename T>
sample_argmax(T * probabilities)44 int32_t Sampler::sample_argmax(T* probabilities) {
45 // return the index that has the highest probability
46 int max_i = 0;
47 T max_p = probabilities[0];
48 for (int i = 1; i < vocab_size_; i++) {
49 if (probabilities[i] > max_p) {
50 max_i = i;
51 max_p = probabilities[i];
52 }
53 }
54 return max_i;
55 }
56
57 template <typename T>
sample_mult(T * probabilities,float coin)58 int32_t Sampler::sample_mult(T* probabilities, float coin) {
59 // sample index from probabilities (they must sum to 1!)
60 // coin is a random number in [0, 1), usually from random_f32()
61 T cdf = 0.0;
62 for (int i = 0; i < vocab_size_; i++) {
63 cdf += probabilities[i];
64 if (coin < cdf) {
65 return i;
66 }
67 }
68 return vocab_size_ - 1; // in case of rounding errors
69 }
70
71 template <typename T>
sample_topp(T * probabilities,float coin)72 int32_t Sampler::sample_topp(T* probabilities, float coin) {
73 // top-p sampling (or "nucleus sampling") samples from the smallest set of
74 // tokens that exceed probability topp. This way we never sample tokens that
75 // have very low probabilities and are less likely to go "off the rails".
76 // coin is a random number in [0, 1), usually from random_f32()
77 int n = vocab_size_;
78 int n0 = 0;
79 // quicksort indices in descending order of probabilities
80 // values smaller than (1 - topp) / (n - 1) cannot be part of the result
81 // so for efficiency we crop these out as candidates before sorting
82 std::unique_ptr<ProbIndex<T>[]> probindex =
83 std::make_unique<ProbIndex<T>[]>(vocab_size_);
84
85 const float cutoff = (1.0f - topp_) / (n - 1);
86 for (int i = 0; i < n; i++) {
87 if (probabilities[i] >= cutoff) {
88 probindex[n0].index = i;
89 probindex[n0].prob = probabilities[i];
90 n0++;
91 }
92 }
93
94 auto compare = [](const ProbIndex<T>& a, const ProbIndex<T>& b) {
95 return a.prob > b.prob;
96 };
97 std::sort(probindex.get(), probindex.get() + n0, compare);
98
99 // truncate the list where cumulative probability exceeds topp
100 T cumulative_prob = 0;
101 int last_idx = n0 - 1; // in case of rounding errors consider all elements
102 for (int i = 0; i < n0; i++) {
103 cumulative_prob += probindex[i].prob;
104 if (cumulative_prob > topp_) {
105 last_idx = i;
106 break; // we've exceeded topp by including last_idx
107 }
108 }
109
110 // sample from the truncated list
111 const T& r = coin * cumulative_prob;
112 T cdf = 0;
113 for (int i = 0; i <= last_idx; i++) {
114 cdf += probindex[i].prob;
115 if (r < cdf) {
116 return probindex[i].index;
117 }
118 }
119 return probindex[last_idx].index; // in case of rounding errors
120 }
121
Sampler(int vocab_size,float temperature,float topp,unsigned long long rng_seed)122 Sampler::Sampler(
123 int vocab_size,
124 float temperature,
125 float topp,
126 unsigned long long rng_seed)
127 : vocab_size_(vocab_size),
128 inv_temperature_(static_cast<bool>(temperature) ? 1.0f / temperature : 0),
129 topp_(topp),
130 rng_state_(rng_seed) {}
131
// In-place softmax over x[0..size-1]: x[i] <- exp(x[i] - max) / sum_j exp(x[j] - max).
// The max is subtracted first for numerical stability (keeps expf from
// overflowing on large logits); it cancels out mathematically.
template <typename T>
static void softmax(T* x, int size) {
  // Pass 1: find the maximum value.
  T max_val = x[0];
  for (int i = 1; i < size; ++i) {
    max_val = std::max(max_val, x[i]);
  }
  // Pass 2: exponentiate the shifted values and accumulate their sum.
  T sum = 0;
  for (int i = 0; i < size; ++i) {
    x[i] = expf(x[i] - max_val);
    sum += x[i];
  }
  // Pass 3: normalize so the entries sum to 1.
  for (int i = 0; i < size; ++i) {
    x[i] /= sum;
  }
}
152
// xorshift* PRNG: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
// Advances *state and returns the high 32 bits of state * multiplier as the
// next pseudo-random value.
static unsigned int random_u32(unsigned long long* state) {
  unsigned long long x = *state;
  x ^= x >> 12;
  x ^= x << 25;
  x ^= x >> 27;
  *state = x;
  return static_cast<unsigned int>((x * 0x2545F4914F6CDD1Dull) >> 32);
}
160
random_f32(unsigned long long * state)161 static float random_f32(unsigned long long* state) { // random float32 in [0,1)
162 return (random_u32(state) >> 8) / 16777216.0f;
163 }
164
165 template <typename T>
sample(T * logits)166 int32_t Sampler::sample(T* logits) {
167 // sample the token given the logits and some hyperparameters
168 int next;
169 if (inv_temperature_ == 0.0f) {
170 // greedy argmax sampling: take the token with the highest probability
171 next = sample_argmax(logits);
172 } else {
173 // apply the temperature to the logits
174 for (int q = 0; q < vocab_size_; q++) {
175 logits[q] *= inv_temperature_;
176 }
177 // apply softmax to the logits to get the probabilities for next token
178 softmax(logits, vocab_size_);
179 // flip a (float) coin (this is our source of entropy for sampling)
180 float coin = random_f32(&rng_state_);
181 // we sample from this distribution to get the next token
182 if (topp_ <= 0 || topp_ >= 1) {
183 // simply sample from the predicted probability distribution
184 next = sample_mult(logits, coin);
185 } else {
186 // top-p (nucleus) sampling, clamping the least likely tokens to zero
187 next = sample_topp(logits, coin);
188 }
189 }
190 return next;
191 }
192
// Explicit instantiations for the logit element types used by the runner,
// so sample<T>'s definition can live in this .cpp file.
template int32_t Sampler::sample<float>(float* logits);
template int32_t Sampler::sample<exec_aten::Half>(exec_aten::Half* logits);
template int32_t Sampler::sample<exec_aten::BFloat16>(
    exec_aten::BFloat16* logits);
197
198 } // namespace llm
199 } // namespace extension
200 } // namespace executorch
201