xref: /aosp_15_r20/external/executorch/extension/llm/sampler/sampler.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <cctype>
12 #include <cmath>
13 #include <cstdio>
14 #include <cstdlib>
15 #include <cstring>
16 #include <memory>
17 #ifdef USE_ATEN_LIB
18 #include <torch/torch.h>
19 #endif
20 
21 #include <executorch/runtime/core/exec_aten/exec_aten.h>
22 #include <executorch/runtime/platform/compiler.h>
23 
24 namespace executorch {
25 namespace extension {
26 namespace llm {
27 // A simple llama2 sampler.
28 
29 template <typename T>
30 struct ET_EXPERIMENTAL ProbIndex {
31   T prob;
32   int32_t index;
33 }; // struct used when sorting probabilities during top-p sampling
34 
35 class ET_EXPERIMENTAL Sampler {
36  public:
37   Sampler(
38       int32_t vocab_size,
39       float temperature,
40       float topp,
41       unsigned long long rng_seed);
42 
43   template <typename T>
44   int32_t sample(T* logits);
45 
46  private:
47   template <typename T>
48   int32_t sample_topp(T* probabilities, float coin);
49   template <typename T>
50   int32_t sample_mult(T* probabilities, float coin);
51   template <typename T>
52   int32_t sample_argmax(T* probabilities);
53 
54  private:
55   int32_t vocab_size_;
56   // reciprocal of temperature, or 0 if temperature == 0.
57   float inv_temperature_;
58   float topp_;
59   unsigned long long rng_state_;
60 };
61 
62 } // namespace llm
63 } // namespace extension
64 } // namespace executorch
65 
66 namespace torch {
67 namespace executor {
68 // TODO(T197294990): Remove these deprecated aliases once all users have moved
69 // to the new `::executorch` namespaces.
70 using ::executorch::extension::llm::ProbIndex;
71 using ::executorch::extension::llm::Sampler;
72 } // namespace executor
73 } // namespace torch
74