1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <cctype> 12 #include <cmath> 13 #include <cstdio> 14 #include <cstdlib> 15 #include <cstring> 16 #include <memory> 17 #ifdef USE_ATEN_LIB 18 #include <torch/torch.h> 19 #endif 20 21 #include <executorch/runtime/core/exec_aten/exec_aten.h> 22 #include <executorch/runtime/platform/compiler.h> 23 24 namespace executorch { 25 namespace extension { 26 namespace llm { 27 // A simple llama2 sampler. 28 29 template <typename T> 30 struct ET_EXPERIMENTAL ProbIndex { 31 T prob; 32 int32_t index; 33 }; // struct used when sorting probabilities during top-p sampling 34 35 class ET_EXPERIMENTAL Sampler { 36 public: 37 Sampler( 38 int32_t vocab_size, 39 float temperature, 40 float topp, 41 unsigned long long rng_seed); 42 43 template <typename T> 44 int32_t sample(T* logits); 45 46 private: 47 template <typename T> 48 int32_t sample_topp(T* probabilities, float coin); 49 template <typename T> 50 int32_t sample_mult(T* probabilities, float coin); 51 template <typename T> 52 int32_t sample_argmax(T* probabilities); 53 54 private: 55 int32_t vocab_size_; 56 // reciprocal of temperature, or 0 if temperature == 0. 57 float inv_temperature_; 58 float topp_; 59 unsigned long long rng_state_; 60 }; 61 62 } // namespace llm 63 } // namespace extension 64 } // namespace executorch 65 66 namespace torch { 67 namespace executor { 68 // TODO(T197294990): Remove these deprecated aliases once all users have moved 69 // to the new `::executorch` namespaces. 70 using ::executorch::extension::llm::ProbIndex; 71 using ::executorch::extension::llm::Sampler; 72 } // namespace executor 73 } // namespace torch 74