/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// NOTE(review): the include targets below were reconstructed; confirm them
// against the build before landing.
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <memory>
#ifdef USE_ATEN_LIB
#include <torch/torch.h>
#endif

// The runtime headers are probed with __has_include so that these
// declarations still parse when the ExecuTorch runtime is not on the include
// path (e.g. standalone compile checks or tooling).
#if __has_include(<executorch/runtime/core/exec_aten/exec_aten.h>)
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#endif
#if __has_include(<executorch/runtime/platform/compiler.h>)
#include <executorch/runtime/platform/compiler.h>
#endif

// Fallback used only when no included header supplied the annotation macro.
#ifndef ET_EXPERIMENTAL
#define ET_EXPERIMENTAL
#endif

namespace executorch {
namespace extension {
namespace llm {

/**
 * Pairs a probability with an index. Used when sorting probabilities during
 * top-p sampling.
 *
 * @tparam T Type of the stored probability value.
 */
template <typename T>
struct ET_EXPERIMENTAL ProbIndex {
  T prob;
  int32_t index;
};

/**
 * A simple llama2 sampler.
 *
 * Only declarations live in this header; the member definitions (including
 * the template members) are provided elsewhere, so callers can only use the
 * element types that are instantiated there.
 */
class ET_EXPERIMENTAL Sampler {
 public:
  /**
   * @param vocab_size Stored as the sampler's vocabulary size; presumably the
   *     number of entries expected in the arrays handed to sample() — confirm
   *     against the implementation.
   * @param temperature Sampling temperature. The sampler keeps its
   *     reciprocal, or 0 when the temperature is 0.
   * @param topp Top-p threshold.
   * @param rng_seed Seed for the sampler's random number generator state.
   */
  Sampler(
      int32_t vocab_size,
      float temperature,
      float topp,
      unsigned long long rng_seed);

  /**
   * Samples from `logits`.
   *
   * @param logits Pointer to the logits to sample from. It is non-const, so
   *     the implementation may modify the buffer — confirm before relying on
   *     its contents afterwards.
   * @return An int32_t; presumably the index of the selected entry.
   */
  template <typename T>
  int32_t sample(T* logits);

 private:
  // Sampling strategies behind sample(). `coin` is presumably a random draw
  // in [0, 1) — confirm against the implementation.
  template <typename T>
  int32_t sample_topp(T* probabilities, float coin);
  template <typename T>
  int32_t sample_mult(T* probabilities, float coin);
  template <typename T>
  int32_t sample_argmax(T* probabilities);

  int32_t vocab_size_;
  // reciprocal of temperature, or 0 if temperature == 0.
  float inv_temperature_;
  float topp_;
  unsigned long long rng_state_;
};

} // namespace llm
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::ProbIndex;
using ::executorch::extension::llm::Sampler;
} // namespace executor
} // namespace torch