xref: /aosp_15_r20/external/executorch/extension/llm/sampler/sampler.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #include <cctype>
12*523fa7a6SAndroid Build Coastguard Worker #include <cmath>
13*523fa7a6SAndroid Build Coastguard Worker #include <cstdio>
14*523fa7a6SAndroid Build Coastguard Worker #include <cstdlib>
15*523fa7a6SAndroid Build Coastguard Worker #include <cstring>
16*523fa7a6SAndroid Build Coastguard Worker #include <memory>
17*523fa7a6SAndroid Build Coastguard Worker #ifdef USE_ATEN_LIB
18*523fa7a6SAndroid Build Coastguard Worker #include <torch/torch.h>
19*523fa7a6SAndroid Build Coastguard Worker #endif
20*523fa7a6SAndroid Build Coastguard Worker 
21*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/exec_aten.h>
22*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/compiler.h>
23*523fa7a6SAndroid Build Coastguard Worker 
24*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
25*523fa7a6SAndroid Build Coastguard Worker namespace extension {
26*523fa7a6SAndroid Build Coastguard Worker namespace llm {
27*523fa7a6SAndroid Build Coastguard Worker // A simple llama2 sampler.
28*523fa7a6SAndroid Build Coastguard Worker 
29*523fa7a6SAndroid Build Coastguard Worker template <typename T>
30*523fa7a6SAndroid Build Coastguard Worker struct ET_EXPERIMENTAL ProbIndex {
31*523fa7a6SAndroid Build Coastguard Worker   T prob;
32*523fa7a6SAndroid Build Coastguard Worker   int32_t index;
33*523fa7a6SAndroid Build Coastguard Worker }; // struct used when sorting probabilities during top-p sampling
34*523fa7a6SAndroid Build Coastguard Worker 
35*523fa7a6SAndroid Build Coastguard Worker class ET_EXPERIMENTAL Sampler {
36*523fa7a6SAndroid Build Coastguard Worker  public:
37*523fa7a6SAndroid Build Coastguard Worker   Sampler(
38*523fa7a6SAndroid Build Coastguard Worker       int32_t vocab_size,
39*523fa7a6SAndroid Build Coastguard Worker       float temperature,
40*523fa7a6SAndroid Build Coastguard Worker       float topp,
41*523fa7a6SAndroid Build Coastguard Worker       unsigned long long rng_seed);
42*523fa7a6SAndroid Build Coastguard Worker 
43*523fa7a6SAndroid Build Coastguard Worker   template <typename T>
44*523fa7a6SAndroid Build Coastguard Worker   int32_t sample(T* logits);
45*523fa7a6SAndroid Build Coastguard Worker 
46*523fa7a6SAndroid Build Coastguard Worker  private:
47*523fa7a6SAndroid Build Coastguard Worker   template <typename T>
48*523fa7a6SAndroid Build Coastguard Worker   int32_t sample_topp(T* probabilities, float coin);
49*523fa7a6SAndroid Build Coastguard Worker   template <typename T>
50*523fa7a6SAndroid Build Coastguard Worker   int32_t sample_mult(T* probabilities, float coin);
51*523fa7a6SAndroid Build Coastguard Worker   template <typename T>
52*523fa7a6SAndroid Build Coastguard Worker   int32_t sample_argmax(T* probabilities);
53*523fa7a6SAndroid Build Coastguard Worker 
54*523fa7a6SAndroid Build Coastguard Worker  private:
55*523fa7a6SAndroid Build Coastguard Worker   int32_t vocab_size_;
56*523fa7a6SAndroid Build Coastguard Worker   // reciprocal of temperature, or 0 if temperature == 0.
57*523fa7a6SAndroid Build Coastguard Worker   float inv_temperature_;
58*523fa7a6SAndroid Build Coastguard Worker   float topp_;
59*523fa7a6SAndroid Build Coastguard Worker   unsigned long long rng_state_;
60*523fa7a6SAndroid Build Coastguard Worker };
61*523fa7a6SAndroid Build Coastguard Worker 
62*523fa7a6SAndroid Build Coastguard Worker } // namespace llm
63*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
64*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
65*523fa7a6SAndroid Build Coastguard Worker 
66*523fa7a6SAndroid Build Coastguard Worker namespace torch {
67*523fa7a6SAndroid Build Coastguard Worker namespace executor {
68*523fa7a6SAndroid Build Coastguard Worker // TODO(T197294990): Remove these deprecated aliases once all users have moved
69*523fa7a6SAndroid Build Coastguard Worker // to the new `::executorch` namespaces.
70*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::ProbIndex;
71*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::Sampler;
72*523fa7a6SAndroid Build Coastguard Worker } // namespace executor
73*523fa7a6SAndroid Build Coastguard Worker } // namespace torch
74