xref: /aosp_15_r20/external/executorch/extension/llm/runner/irunner.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker // An interface for LLM runners. Developers can create their own runner that
10*523fa7a6SAndroid Build Coastguard Worker // implements their own load and generation logic to run the model.
11*523fa7a6SAndroid Build Coastguard Worker 
12*523fa7a6SAndroid Build Coastguard Worker #pragma once
13*523fa7a6SAndroid Build Coastguard Worker 
14*523fa7a6SAndroid Build Coastguard Worker #include <functional>
15*523fa7a6SAndroid Build Coastguard Worker #include <string>
16*523fa7a6SAndroid Build Coastguard Worker 
17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/runner/stats.h>
18*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/module/module.h>
19*523fa7a6SAndroid Build Coastguard Worker 
20*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
21*523fa7a6SAndroid Build Coastguard Worker namespace extension {
22*523fa7a6SAndroid Build Coastguard Worker namespace llm {
23*523fa7a6SAndroid Build Coastguard Worker 
24*523fa7a6SAndroid Build Coastguard Worker class ET_EXPERIMENTAL IRunner {  // Abstract LLM runner interface: every operation below is pure virtual.
25*523fa7a6SAndroid Build Coastguard Worker  public:
26*523fa7a6SAndroid Build Coastguard Worker   virtual ~IRunner() = default;  // Virtual so an implementation can be destroyed through an IRunner pointer.
27*523fa7a6SAndroid Build Coastguard Worker 
28*523fa7a6SAndroid Build Coastguard Worker   // Reports whether the model is loaded. Declared const, so it is callable on a const runner.
29*523fa7a6SAndroid Build Coastguard Worker   virtual bool is_loaded() const = 0;
30*523fa7a6SAndroid Build Coastguard Worker 
31*523fa7a6SAndroid Build Coastguard Worker   // Loads the model and tokenizer. Returns an ::executorch::runtime::Error status.
32*523fa7a6SAndroid Build Coastguard Worker   virtual ::executorch::runtime::Error load() = 0;
33*523fa7a6SAndroid Build Coastguard Worker 
34*523fa7a6SAndroid Build Coastguard Worker   // Generates the output tokens for `prompt`; returns an ::executorch::runtime::Error status. Default arguments of a virtual function are taken from the static type at the call site, so overrides should repeat the same defaults.
35*523fa7a6SAndroid Build Coastguard Worker   virtual ::executorch::runtime::Error generate(
36*523fa7a6SAndroid Build Coastguard Worker       const std::string& prompt,  // Input prompt text.
37*523fa7a6SAndroid Build Coastguard Worker       int32_t seq_len,  // Presumably a sequence-length limit -- TODO confirm. NOTE(review): int32_t comes from <cstdint>, which this header does not include directly.
38*523fa7a6SAndroid Build Coastguard Worker       std::function<void(const std::string&)> token_callback = {},  // Empty by default, so implementations must check it before calling; presumably invoked with generated token text -- TODO confirm.
39*523fa7a6SAndroid Build Coastguard Worker       std::function<void(const ::executorch::extension::llm::Stats&)>
40*523fa7a6SAndroid Build Coastguard Worker           stats_callback = {},  // Empty by default, so implementations must check it before calling; presumably receives the run's Stats -- TODO confirm.
41*523fa7a6SAndroid Build Coastguard Worker       bool echo = true,  // Defaults to true; presumably controls whether the prompt is echoed back -- TODO confirm.
42*523fa7a6SAndroid Build Coastguard Worker       bool warming = false) = 0;  // Defaults to false; presumably marks a warm-up run -- TODO confirm.
43*523fa7a6SAndroid Build Coastguard Worker 
44*523fa7a6SAndroid Build Coastguard Worker   // Stops the generation. NOTE(review): whether this may be called concurrently with generate() is not specified here -- confirm per implementation.
45*523fa7a6SAndroid Build Coastguard Worker   virtual void stop() = 0;
46*523fa7a6SAndroid Build Coastguard Worker };
47*523fa7a6SAndroid Build Coastguard Worker 
48*523fa7a6SAndroid Build Coastguard Worker } // namespace llm
49*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
50*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
51