1*523fa7a6SAndroid Build Coastguard Worker /* 2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates. 3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved. 4*523fa7a6SAndroid Build Coastguard Worker * 5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the 6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree. 7*523fa7a6SAndroid Build Coastguard Worker */ 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Worker // An interface for LLM runners. Developers can create their own runner that 10*523fa7a6SAndroid Build Coastguard Worker // implements their own load and generation logic to run the model. 11*523fa7a6SAndroid Build Coastguard Worker 12*523fa7a6SAndroid Build Coastguard Worker #pragma once 13*523fa7a6SAndroid Build Coastguard Worker 14*523fa7a6SAndroid Build Coastguard Worker #include <functional> 15*523fa7a6SAndroid Build Coastguard Worker #include <string> 16*523fa7a6SAndroid Build Coastguard Worker 17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/runner/stats.h> 18*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/module/module.h> 19*523fa7a6SAndroid Build Coastguard Worker 20*523fa7a6SAndroid Build Coastguard Worker namespace executorch { 21*523fa7a6SAndroid Build Coastguard Worker namespace extension { 22*523fa7a6SAndroid Build Coastguard Worker namespace llm { 23*523fa7a6SAndroid Build Coastguard Worker 24*523fa7a6SAndroid Build Coastguard Worker class ET_EXPERIMENTAL IRunner { 25*523fa7a6SAndroid Build Coastguard Worker public: 26*523fa7a6SAndroid Build Coastguard Worker virtual ~IRunner() = default; 27*523fa7a6SAndroid Build Coastguard Worker 28*523fa7a6SAndroid Build Coastguard Worker // Checks if the model is loaded. 29*523fa7a6SAndroid Build Coastguard Worker virtual bool is_loaded() const = 0; 30*523fa7a6SAndroid Build Coastguard Worker 31*523fa7a6SAndroid Build Coastguard Worker // Load the model and tokenizer. 32*523fa7a6SAndroid Build Coastguard Worker virtual ::executorch::runtime::Error load() = 0; 33*523fa7a6SAndroid Build Coastguard Worker 34*523fa7a6SAndroid Build Coastguard Worker // Generate the output tokens. 35*523fa7a6SAndroid Build Coastguard Worker virtual ::executorch::runtime::Error generate( 36*523fa7a6SAndroid Build Coastguard Worker const std::string& prompt, 37*523fa7a6SAndroid Build Coastguard Worker int32_t seq_len, 38*523fa7a6SAndroid Build Coastguard Worker std::function<void(const std::string&)> token_callback = {}, 39*523fa7a6SAndroid Build Coastguard Worker std::function<void(const ::executorch::extension::llm::Stats&)> 40*523fa7a6SAndroid Build Coastguard Worker stats_callback = {}, 41*523fa7a6SAndroid Build Coastguard Worker bool echo = true, 42*523fa7a6SAndroid Build Coastguard Worker bool warming = false) = 0; 43*523fa7a6SAndroid Build Coastguard Worker 44*523fa7a6SAndroid Build Coastguard Worker // Stop the generation. 45*523fa7a6SAndroid Build Coastguard Worker virtual void stop() = 0; 46*523fa7a6SAndroid Build Coastguard Worker }; 47*523fa7a6SAndroid Build Coastguard Worker 48*523fa7a6SAndroid Build Coastguard Worker } // namespace llm 49*523fa7a6SAndroid Build Coastguard Worker } // namespace extension 50*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch 51