1*523fa7a6SAndroid Build Coastguard Worker /* 2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates. 3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved. 4*523fa7a6SAndroid Build Coastguard Worker * 5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the 6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree. 7*523fa7a6SAndroid Build Coastguard Worker */ 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Worker #pragma once 10*523fa7a6SAndroid Build Coastguard Worker 11*523fa7a6SAndroid Build Coastguard Worker #include <cctype> 12*523fa7a6SAndroid Build Coastguard Worker #include <cmath> 13*523fa7a6SAndroid Build Coastguard Worker #include <cstdio> 14*523fa7a6SAndroid Build Coastguard Worker #include <cstdlib> 15*523fa7a6SAndroid Build Coastguard Worker #include <cstring> 16*523fa7a6SAndroid Build Coastguard Worker #include <memory> 17*523fa7a6SAndroid Build Coastguard Worker #ifdef USE_ATEN_LIB 18*523fa7a6SAndroid Build Coastguard Worker #include <torch/torch.h> 19*523fa7a6SAndroid Build Coastguard Worker #endif 20*523fa7a6SAndroid Build Coastguard Worker 21*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/exec_aten.h> 22*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/compiler.h> 23*523fa7a6SAndroid Build Coastguard Worker 24*523fa7a6SAndroid Build Coastguard Worker namespace executorch { 25*523fa7a6SAndroid Build Coastguard Worker namespace extension { 26*523fa7a6SAndroid Build Coastguard Worker namespace llm { 27*523fa7a6SAndroid Build Coastguard Worker // A simple llama2 sampler. 28*523fa7a6SAndroid Build Coastguard Worker 29*523fa7a6SAndroid Build Coastguard Worker template <typename T> 30*523fa7a6SAndroid Build Coastguard Worker struct ET_EXPERIMENTAL ProbIndex { 31*523fa7a6SAndroid Build Coastguard Worker T prob; 32*523fa7a6SAndroid Build Coastguard Worker int32_t index; 33*523fa7a6SAndroid Build Coastguard Worker }; // struct used when sorting probabilities during top-p sampling 34*523fa7a6SAndroid Build Coastguard Worker 35*523fa7a6SAndroid Build Coastguard Worker class ET_EXPERIMENTAL Sampler { 36*523fa7a6SAndroid Build Coastguard Worker public: 37*523fa7a6SAndroid Build Coastguard Worker Sampler( 38*523fa7a6SAndroid Build Coastguard Worker int32_t vocab_size, 39*523fa7a6SAndroid Build Coastguard Worker float temperature, 40*523fa7a6SAndroid Build Coastguard Worker float topp, 41*523fa7a6SAndroid Build Coastguard Worker unsigned long long rng_seed); 42*523fa7a6SAndroid Build Coastguard Worker 43*523fa7a6SAndroid Build Coastguard Worker template <typename T> 44*523fa7a6SAndroid Build Coastguard Worker int32_t sample(T* logits); 45*523fa7a6SAndroid Build Coastguard Worker 46*523fa7a6SAndroid Build Coastguard Worker private: 47*523fa7a6SAndroid Build Coastguard Worker template <typename T> 48*523fa7a6SAndroid Build Coastguard Worker int32_t sample_topp(T* probabilities, float coin); 49*523fa7a6SAndroid Build Coastguard Worker template <typename T> 50*523fa7a6SAndroid Build Coastguard Worker int32_t sample_mult(T* probabilities, float coin); 51*523fa7a6SAndroid Build Coastguard Worker template <typename T> 52*523fa7a6SAndroid Build Coastguard Worker int32_t sample_argmax(T* probabilities); 53*523fa7a6SAndroid Build Coastguard Worker 54*523fa7a6SAndroid Build Coastguard Worker private: 55*523fa7a6SAndroid Build Coastguard Worker int32_t vocab_size_; 56*523fa7a6SAndroid Build Coastguard Worker // reciprocal of temperature, or 0 if temperature == 0. 57*523fa7a6SAndroid Build Coastguard Worker float inv_temperature_; 58*523fa7a6SAndroid Build Coastguard Worker float topp_; 59*523fa7a6SAndroid Build Coastguard Worker unsigned long long rng_state_; 60*523fa7a6SAndroid Build Coastguard Worker }; 61*523fa7a6SAndroid Build Coastguard Worker 62*523fa7a6SAndroid Build Coastguard Worker } // namespace llm 63*523fa7a6SAndroid Build Coastguard Worker } // namespace extension 64*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch 65*523fa7a6SAndroid Build Coastguard Worker 66*523fa7a6SAndroid Build Coastguard Worker namespace torch { 67*523fa7a6SAndroid Build Coastguard Worker namespace executor { 68*523fa7a6SAndroid Build Coastguard Worker // TODO(T197294990): Remove these deprecated aliases once all users have moved 69*523fa7a6SAndroid Build Coastguard Worker // to the new `::executorch` namespaces. 70*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::ProbIndex; 71*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::Sampler; 72*523fa7a6SAndroid Build Coastguard Worker } // namespace executor 73*523fa7a6SAndroid Build Coastguard Worker } // namespace torch 74