xref: /aosp_15_r20/external/executorch/extension/llm/sampler/sampler.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // This is a modified version of https://github.com/karpathy/llama2.c.git
10 // @lint-ignore-every LICENSELINT
11 /**
12  * MIT License
13  *
14  * Copyright (c) 2023 Andrej
15  *
16  * Permission is hereby granted, free of charge, to any person obtaining a copy
17  * of this software and associated documentation files (the "Software"), to deal
18  * in the Software without restriction, including without limitation the rights
19  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20  * copies of the Software, and to permit persons to whom the Software is
21  * furnished to do so, subject to the following conditions:
22  *
23  * The above copyright notice and this permission notice shall be included in
24  * all copies or substantial portions of the Software.
25  *
26  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32  * SOFTWARE.
33  */
34 
#include <executorch/extension/llm/sampler/sampler.h>

#include <algorithm>
#include <cmath>
#include <memory>
37 
38 namespace executorch {
39 namespace extension {
40 namespace llm {
41 
// Token-selection helpers: greedy argmax, multinomial, and top-p (nucleus).
43 template <typename T>
sample_argmax(T * probabilities)44 int32_t Sampler::sample_argmax(T* probabilities) {
45   // return the index that has the highest probability
46   int max_i = 0;
47   T max_p = probabilities[0];
48   for (int i = 1; i < vocab_size_; i++) {
49     if (probabilities[i] > max_p) {
50       max_i = i;
51       max_p = probabilities[i];
52     }
53   }
54   return max_i;
55 }
56 
57 template <typename T>
sample_mult(T * probabilities,float coin)58 int32_t Sampler::sample_mult(T* probabilities, float coin) {
59   // sample index from probabilities (they must sum to 1!)
60   // coin is a random number in [0, 1), usually from random_f32()
61   T cdf = 0.0;
62   for (int i = 0; i < vocab_size_; i++) {
63     cdf += probabilities[i];
64     if (coin < cdf) {
65       return i;
66     }
67   }
68   return vocab_size_ - 1; // in case of rounding errors
69 }
70 
71 template <typename T>
sample_topp(T * probabilities,float coin)72 int32_t Sampler::sample_topp(T* probabilities, float coin) {
73   // top-p sampling (or "nucleus sampling") samples from the smallest set of
74   // tokens that exceed probability topp. This way we never sample tokens that
75   // have very low probabilities and are less likely to go "off the rails".
76   // coin is a random number in [0, 1), usually from random_f32()
77   int n = vocab_size_;
78   int n0 = 0;
79   // quicksort indices in descending order of probabilities
80   // values smaller than (1 - topp) / (n - 1) cannot be part of the result
81   // so for efficiency we crop these out as candidates before sorting
82   std::unique_ptr<ProbIndex<T>[]> probindex =
83       std::make_unique<ProbIndex<T>[]>(vocab_size_);
84 
85   const float cutoff = (1.0f - topp_) / (n - 1);
86   for (int i = 0; i < n; i++) {
87     if (probabilities[i] >= cutoff) {
88       probindex[n0].index = i;
89       probindex[n0].prob = probabilities[i];
90       n0++;
91     }
92   }
93 
94   auto compare = [](const ProbIndex<T>& a, const ProbIndex<T>& b) {
95     return a.prob > b.prob;
96   };
97   std::sort(probindex.get(), probindex.get() + n0, compare);
98 
99   // truncate the list where cumulative probability exceeds topp
100   T cumulative_prob = 0;
101   int last_idx = n0 - 1; // in case of rounding errors consider all elements
102   for (int i = 0; i < n0; i++) {
103     cumulative_prob += probindex[i].prob;
104     if (cumulative_prob > topp_) {
105       last_idx = i;
106       break; // we've exceeded topp by including last_idx
107     }
108   }
109 
110   // sample from the truncated list
111   const T& r = coin * cumulative_prob;
112   T cdf = 0;
113   for (int i = 0; i <= last_idx; i++) {
114     cdf += probindex[i].prob;
115     if (r < cdf) {
116       return probindex[i].index;
117     }
118   }
119   return probindex[last_idx].index; // in case of rounding errors
120 }
121 
Sampler(int vocab_size,float temperature,float topp,unsigned long long rng_seed)122 Sampler::Sampler(
123     int vocab_size,
124     float temperature,
125     float topp,
126     unsigned long long rng_seed)
127     : vocab_size_(vocab_size),
128       inv_temperature_(static_cast<bool>(temperature) ? 1.0f / temperature : 0),
129       topp_(topp),
130       rng_state_(rng_seed) {}
131 
template <typename T>
static void softmax(T* x, int size) {
  // In-place softmax over x[0..size). The exponentials, running sum and
  // normalisation are computed in float, so reduced-precision element types
  // (Half/BFloat16) do not lose the sum to accumulator rounding across a
  // large vocabulary. A non-positive size leaves x untouched.
  if (size <= 0) {
    return;
  }
  // find max value (for numerical stability)
  T max_val = x[0];
  for (int i = 1; i < size; i++) {
    if (x[i] > max_val) {
      max_val = x[i];
    }
  }
  const float max_f = static_cast<float>(max_val);
  // exp and sum
  float sum = 0.0f;
  for (int i = 0; i < size; i++) {
    x[i] = static_cast<T>(std::exp(static_cast<float>(x[i]) - max_f));
    // Sum the values as stored (possibly rounded to T) so the normalised
    // outputs add up to 1 as closely as T allows.
    sum += static_cast<float>(x[i]);
  }
  // normalize
  for (int i = 0; i < size; i++) {
    x[i] = static_cast<T>(static_cast<float>(x[i]) / sum);
  }
}
152 
static unsigned int random_u32(unsigned long long* state) {
  // xorshift* generator (https://en.wikipedia.org/wiki/Xorshift#xorshift.2A):
  // scramble the 64-bit state with three shift/xor steps, store it back, and
  // return the upper 32 bits of the state times a fixed odd multiplier.
  unsigned long long s = *state;
  s ^= s >> 12;
  s ^= s << 25;
  s ^= s >> 27;
  *state = s;
  const unsigned long long mixed = s * 0x2545F4914F6CDD1Dull;
  return static_cast<unsigned int>(mixed >> 32);
}
160 
random_f32(unsigned long long * state)161 static float random_f32(unsigned long long* state) { // random float32 in [0,1)
162   return (random_u32(state) >> 8) / 16777216.0f;
163 }
164 
165 template <typename T>
sample(T * logits)166 int32_t Sampler::sample(T* logits) {
167   // sample the token given the logits and some hyperparameters
168   int next;
169   if (inv_temperature_ == 0.0f) {
170     // greedy argmax sampling: take the token with the highest probability
171     next = sample_argmax(logits);
172   } else {
173     // apply the temperature to the logits
174     for (int q = 0; q < vocab_size_; q++) {
175       logits[q] *= inv_temperature_;
176     }
177     // apply softmax to the logits to get the probabilities for next token
178     softmax(logits, vocab_size_);
179     // flip a (float) coin (this is our source of entropy for sampling)
180     float coin = random_f32(&rng_state_);
181     // we sample from this distribution to get the next token
182     if (topp_ <= 0 || topp_ >= 1) {
183       // simply sample from the predicted probability distribution
184       next = sample_mult(logits, coin);
185     } else {
186       // top-p (nucleus) sampling, clamping the least likely tokens to zero
187       next = sample_topp(logits, coin);
188     }
189   }
190   return next;
191 }
192 
193 template int32_t Sampler::sample<float>(float* logits);
194 template int32_t Sampler::sample<exec_aten::Half>(exec_aten::Half* logits);
195 template int32_t Sampler::sample<exec_aten::BFloat16>(
196     exec_aten::BFloat16* logits);
197 
198 } // namespace llm
199 } // namespace extension
200 } // namespace executorch
201