1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker // This is a modified version of https://github.com/karpathy/llama2.c.git
10*523fa7a6SAndroid Build Coastguard Worker // @lint-ignore-every LICENSELINT
11*523fa7a6SAndroid Build Coastguard Worker /**
12*523fa7a6SAndroid Build Coastguard Worker * MIT License
13*523fa7a6SAndroid Build Coastguard Worker *
14*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) 2023 Andrej
15*523fa7a6SAndroid Build Coastguard Worker *
16*523fa7a6SAndroid Build Coastguard Worker * Permission is hereby granted, free of charge, to any person obtaining a copy
17*523fa7a6SAndroid Build Coastguard Worker * of this software and associated documentation files (the "Software"), to deal
18*523fa7a6SAndroid Build Coastguard Worker * in the Software without restriction, including without limitation the rights
19*523fa7a6SAndroid Build Coastguard Worker * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20*523fa7a6SAndroid Build Coastguard Worker * copies of the Software, and to permit persons to whom the Software is
21*523fa7a6SAndroid Build Coastguard Worker * furnished to do so, subject to the following conditions:
22*523fa7a6SAndroid Build Coastguard Worker *
23*523fa7a6SAndroid Build Coastguard Worker * The above copyright notice and this permission notice shall be included in
24*523fa7a6SAndroid Build Coastguard Worker * all copies or substantial portions of the Software.
25*523fa7a6SAndroid Build Coastguard Worker *
26*523fa7a6SAndroid Build Coastguard Worker * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27*523fa7a6SAndroid Build Coastguard Worker * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28*523fa7a6SAndroid Build Coastguard Worker * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29*523fa7a6SAndroid Build Coastguard Worker * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30*523fa7a6SAndroid Build Coastguard Worker * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31*523fa7a6SAndroid Build Coastguard Worker * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32*523fa7a6SAndroid Build Coastguard Worker * SOFTWARE.
33*523fa7a6SAndroid Build Coastguard Worker */
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/sampler/sampler.h>
36*523fa7a6SAndroid Build Coastguard Worker #include <algorithm>
37*523fa7a6SAndroid Build Coastguard Worker
38*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
39*523fa7a6SAndroid Build Coastguard Worker namespace extension {
40*523fa7a6SAndroid Build Coastguard Worker namespace llm {
41*523fa7a6SAndroid Build Coastguard Worker
42*523fa7a6SAndroid Build Coastguard Worker // sampler stuff
43*523fa7a6SAndroid Build Coastguard Worker template <typename T>
sample_argmax(T * probabilities)44*523fa7a6SAndroid Build Coastguard Worker int32_t Sampler::sample_argmax(T* probabilities) {
45*523fa7a6SAndroid Build Coastguard Worker // return the index that has the highest probability
46*523fa7a6SAndroid Build Coastguard Worker int max_i = 0;
47*523fa7a6SAndroid Build Coastguard Worker T max_p = probabilities[0];
48*523fa7a6SAndroid Build Coastguard Worker for (int i = 1; i < vocab_size_; i++) {
49*523fa7a6SAndroid Build Coastguard Worker if (probabilities[i] > max_p) {
50*523fa7a6SAndroid Build Coastguard Worker max_i = i;
51*523fa7a6SAndroid Build Coastguard Worker max_p = probabilities[i];
52*523fa7a6SAndroid Build Coastguard Worker }
53*523fa7a6SAndroid Build Coastguard Worker }
54*523fa7a6SAndroid Build Coastguard Worker return max_i;
55*523fa7a6SAndroid Build Coastguard Worker }
56*523fa7a6SAndroid Build Coastguard Worker
57*523fa7a6SAndroid Build Coastguard Worker template <typename T>
sample_mult(T * probabilities,float coin)58*523fa7a6SAndroid Build Coastguard Worker int32_t Sampler::sample_mult(T* probabilities, float coin) {
59*523fa7a6SAndroid Build Coastguard Worker // sample index from probabilities (they must sum to 1!)
60*523fa7a6SAndroid Build Coastguard Worker // coin is a random number in [0, 1), usually from random_f32()
61*523fa7a6SAndroid Build Coastguard Worker T cdf = 0.0;
62*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < vocab_size_; i++) {
63*523fa7a6SAndroid Build Coastguard Worker cdf += probabilities[i];
64*523fa7a6SAndroid Build Coastguard Worker if (coin < cdf) {
65*523fa7a6SAndroid Build Coastguard Worker return i;
66*523fa7a6SAndroid Build Coastguard Worker }
67*523fa7a6SAndroid Build Coastguard Worker }
68*523fa7a6SAndroid Build Coastguard Worker return vocab_size_ - 1; // in case of rounding errors
69*523fa7a6SAndroid Build Coastguard Worker }
70*523fa7a6SAndroid Build Coastguard Worker
71*523fa7a6SAndroid Build Coastguard Worker template <typename T>
sample_topp(T * probabilities,float coin)72*523fa7a6SAndroid Build Coastguard Worker int32_t Sampler::sample_topp(T* probabilities, float coin) {
73*523fa7a6SAndroid Build Coastguard Worker // top-p sampling (or "nucleus sampling") samples from the smallest set of
74*523fa7a6SAndroid Build Coastguard Worker // tokens that exceed probability topp. This way we never sample tokens that
75*523fa7a6SAndroid Build Coastguard Worker // have very low probabilities and are less likely to go "off the rails".
76*523fa7a6SAndroid Build Coastguard Worker // coin is a random number in [0, 1), usually from random_f32()
77*523fa7a6SAndroid Build Coastguard Worker int n = vocab_size_;
78*523fa7a6SAndroid Build Coastguard Worker int n0 = 0;
79*523fa7a6SAndroid Build Coastguard Worker // quicksort indices in descending order of probabilities
80*523fa7a6SAndroid Build Coastguard Worker // values smaller than (1 - topp) / (n - 1) cannot be part of the result
81*523fa7a6SAndroid Build Coastguard Worker // so for efficiency we crop these out as candidates before sorting
82*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<ProbIndex<T>[]> probindex =
83*523fa7a6SAndroid Build Coastguard Worker std::make_unique<ProbIndex<T>[]>(vocab_size_);
84*523fa7a6SAndroid Build Coastguard Worker
85*523fa7a6SAndroid Build Coastguard Worker const float cutoff = (1.0f - topp_) / (n - 1);
86*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < n; i++) {
87*523fa7a6SAndroid Build Coastguard Worker if (probabilities[i] >= cutoff) {
88*523fa7a6SAndroid Build Coastguard Worker probindex[n0].index = i;
89*523fa7a6SAndroid Build Coastguard Worker probindex[n0].prob = probabilities[i];
90*523fa7a6SAndroid Build Coastguard Worker n0++;
91*523fa7a6SAndroid Build Coastguard Worker }
92*523fa7a6SAndroid Build Coastguard Worker }
93*523fa7a6SAndroid Build Coastguard Worker
94*523fa7a6SAndroid Build Coastguard Worker auto compare = [](const ProbIndex<T>& a, const ProbIndex<T>& b) {
95*523fa7a6SAndroid Build Coastguard Worker return a.prob > b.prob;
96*523fa7a6SAndroid Build Coastguard Worker };
97*523fa7a6SAndroid Build Coastguard Worker std::sort(probindex.get(), probindex.get() + n0, compare);
98*523fa7a6SAndroid Build Coastguard Worker
99*523fa7a6SAndroid Build Coastguard Worker // truncate the list where cumulative probability exceeds topp
100*523fa7a6SAndroid Build Coastguard Worker T cumulative_prob = 0;
101*523fa7a6SAndroid Build Coastguard Worker int last_idx = n0 - 1; // in case of rounding errors consider all elements
102*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < n0; i++) {
103*523fa7a6SAndroid Build Coastguard Worker cumulative_prob += probindex[i].prob;
104*523fa7a6SAndroid Build Coastguard Worker if (cumulative_prob > topp_) {
105*523fa7a6SAndroid Build Coastguard Worker last_idx = i;
106*523fa7a6SAndroid Build Coastguard Worker break; // we've exceeded topp by including last_idx
107*523fa7a6SAndroid Build Coastguard Worker }
108*523fa7a6SAndroid Build Coastguard Worker }
109*523fa7a6SAndroid Build Coastguard Worker
110*523fa7a6SAndroid Build Coastguard Worker // sample from the truncated list
111*523fa7a6SAndroid Build Coastguard Worker const T& r = coin * cumulative_prob;
112*523fa7a6SAndroid Build Coastguard Worker T cdf = 0;
113*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i <= last_idx; i++) {
114*523fa7a6SAndroid Build Coastguard Worker cdf += probindex[i].prob;
115*523fa7a6SAndroid Build Coastguard Worker if (r < cdf) {
116*523fa7a6SAndroid Build Coastguard Worker return probindex[i].index;
117*523fa7a6SAndroid Build Coastguard Worker }
118*523fa7a6SAndroid Build Coastguard Worker }
119*523fa7a6SAndroid Build Coastguard Worker return probindex[last_idx].index; // in case of rounding errors
120*523fa7a6SAndroid Build Coastguard Worker }
121*523fa7a6SAndroid Build Coastguard Worker
Sampler(int vocab_size,float temperature,float topp,unsigned long long rng_seed)122*523fa7a6SAndroid Build Coastguard Worker Sampler::Sampler(
123*523fa7a6SAndroid Build Coastguard Worker int vocab_size,
124*523fa7a6SAndroid Build Coastguard Worker float temperature,
125*523fa7a6SAndroid Build Coastguard Worker float topp,
126*523fa7a6SAndroid Build Coastguard Worker unsigned long long rng_seed)
127*523fa7a6SAndroid Build Coastguard Worker : vocab_size_(vocab_size),
128*523fa7a6SAndroid Build Coastguard Worker inv_temperature_(static_cast<bool>(temperature) ? 1.0f / temperature : 0),
129*523fa7a6SAndroid Build Coastguard Worker topp_(topp),
130*523fa7a6SAndroid Build Coastguard Worker rng_state_(rng_seed) {}
131*523fa7a6SAndroid Build Coastguard Worker
// Converts the first `size` entries of `x` in place into a probability
// distribution: x[i] = exp(x[i] - max) / sum_j exp(x[j] - max).
// Subtracting the maximum keeps expf() from overflowing. The normalising sum
// is accumulated in float so reduced-precision T (Half / BFloat16) does not
// drop small terms. A non-positive `size` is a no-op (x is not read).
template <typename T>
static void softmax(T* x, int size) {
  if (size <= 0) {
    return;
  }
  // find max value (for numerical stability)
  T max_val = x[0];
  for (int i = 1; i < size; i++) {
    if (x[i] > max_val) {
      max_val = x[i];
    }
  }
  // exp and sum
  float sum = 0.0f;
  for (int i = 0; i < size; i++) {
    x[i] = expf(x[i] - max_val);
    sum += static_cast<float>(x[i]);
  }
  // normalize
  for (int i = 0; i < size; i++) {
    x[i] = static_cast<T>(static_cast<float>(x[i]) / sum);
  }
}
152*523fa7a6SAndroid Build Coastguard Worker
// Advances the 64-bit generator state by one xorshift* step
// (https://en.wikipedia.org/wiki/Xorshift#xorshift*) and returns the high
// 32 bits of the state multiplied by the xorshift* output constant.
static unsigned int random_u32(unsigned long long* state) {
  unsigned long long s = *state;
  s ^= s >> 12;
  s ^= s << 25;
  s ^= s >> 27;
  *state = s;
  const unsigned long long scrambled = s * 0x2545F4914F6CDD1Dull;
  return static_cast<unsigned int>(scrambled >> 32);
}
160*523fa7a6SAndroid Build Coastguard Worker
random_f32(unsigned long long * state)161*523fa7a6SAndroid Build Coastguard Worker static float random_f32(unsigned long long* state) { // random float32 in [0,1)
162*523fa7a6SAndroid Build Coastguard Worker return (random_u32(state) >> 8) / 16777216.0f;
163*523fa7a6SAndroid Build Coastguard Worker }
164*523fa7a6SAndroid Build Coastguard Worker
165*523fa7a6SAndroid Build Coastguard Worker template <typename T>
sample(T * logits)166*523fa7a6SAndroid Build Coastguard Worker int32_t Sampler::sample(T* logits) {
167*523fa7a6SAndroid Build Coastguard Worker // sample the token given the logits and some hyperparameters
168*523fa7a6SAndroid Build Coastguard Worker int next;
169*523fa7a6SAndroid Build Coastguard Worker if (inv_temperature_ == 0.0f) {
170*523fa7a6SAndroid Build Coastguard Worker // greedy argmax sampling: take the token with the highest probability
171*523fa7a6SAndroid Build Coastguard Worker next = sample_argmax(logits);
172*523fa7a6SAndroid Build Coastguard Worker } else {
173*523fa7a6SAndroid Build Coastguard Worker // apply the temperature to the logits
174*523fa7a6SAndroid Build Coastguard Worker for (int q = 0; q < vocab_size_; q++) {
175*523fa7a6SAndroid Build Coastguard Worker logits[q] *= inv_temperature_;
176*523fa7a6SAndroid Build Coastguard Worker }
177*523fa7a6SAndroid Build Coastguard Worker // apply softmax to the logits to get the probabilities for next token
178*523fa7a6SAndroid Build Coastguard Worker softmax(logits, vocab_size_);
179*523fa7a6SAndroid Build Coastguard Worker // flip a (float) coin (this is our source of entropy for sampling)
180*523fa7a6SAndroid Build Coastguard Worker float coin = random_f32(&rng_state_);
181*523fa7a6SAndroid Build Coastguard Worker // we sample from this distribution to get the next token
182*523fa7a6SAndroid Build Coastguard Worker if (topp_ <= 0 || topp_ >= 1) {
183*523fa7a6SAndroid Build Coastguard Worker // simply sample from the predicted probability distribution
184*523fa7a6SAndroid Build Coastguard Worker next = sample_mult(logits, coin);
185*523fa7a6SAndroid Build Coastguard Worker } else {
186*523fa7a6SAndroid Build Coastguard Worker // top-p (nucleus) sampling, clamping the least likely tokens to zero
187*523fa7a6SAndroid Build Coastguard Worker next = sample_topp(logits, coin);
188*523fa7a6SAndroid Build Coastguard Worker }
189*523fa7a6SAndroid Build Coastguard Worker }
190*523fa7a6SAndroid Build Coastguard Worker return next;
191*523fa7a6SAndroid Build Coastguard Worker }
192*523fa7a6SAndroid Build Coastguard Worker
193*523fa7a6SAndroid Build Coastguard Worker template int32_t Sampler::sample<float>(float* logits);
194*523fa7a6SAndroid Build Coastguard Worker template int32_t Sampler::sample<exec_aten::Half>(exec_aten::Half* logits);
195*523fa7a6SAndroid Build Coastguard Worker template int32_t Sampler::sample<exec_aten::BFloat16>(
196*523fa7a6SAndroid Build Coastguard Worker exec_aten::BFloat16* logits);
197*523fa7a6SAndroid Build Coastguard Worker
198*523fa7a6SAndroid Build Coastguard Worker } // namespace llm
199*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
200*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
201