xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_runner/model_container_runner.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(C10_MOBILE) && !defined(ANDROID)
2 #pragma once
3 
4 #include <ATen/Tensor.h>
5 #include <torch/csrc/inductor/aoti_runtime/interface.h>
6 #include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
7 
// Forward declaration of at::DynamicLibrary. Only the incomplete type is
// needed in this header; the full definition must be visible wherever the
// std::unique_ptr<at::DynamicLibrary> member declared below is destroyed.
namespace at {
struct DynamicLibrary;
} // namespace at
12 
13 namespace torch::inductor {
14 using TensorConstantMap = std::unordered_map<std::string, at::Tensor*>;
15 
16 class TORCH_API AOTIModelContainerRunner {
17  public:
18   AOTIModelContainerRunner() = delete;
19   AOTIModelContainerRunner(const AOTIModelContainerRunner& other) = delete;
20   AOTIModelContainerRunner(AOTIModelContainerRunner&& other) = delete;
21   AOTIModelContainerRunner& operator=(const AOTIModelContainerRunner& other) =
22       delete;
23   AOTIModelContainerRunner& operator=(AOTIModelContainerRunner&& other) =
24       delete;
25   ~AOTIModelContainerRunner();
26 
27   std::vector<at::Tensor> run(
28       std::vector<at::Tensor>& inputs,
29       AOTInductorStreamHandle cuda_stream_handle = nullptr);
30 
31   std::unordered_map<std::string, std::string> getConstantNamesToOriginalFQNs()
32       const;
33   std::unordered_map<std::string, int32_t> getConstantNamesToDtypes() const;
34   void update_inactive_constant_buffer(const TensorConstantMap& const_map);
35   void update_constant_buffer(
36       const TensorConstantMap& const_map,
37       bool use_inactive,
38       bool validate_full_updates);
39   void run_const_fold(
40       bool use_inactive,
41       AOTInductorStreamHandle cuda_stream_handle = nullptr);
42   void swap_constant_buffer();
43 
44   std::vector<std::string> get_call_spec();
45 
46  protected:
47   AOTIModelContainerRunner(
48       const std::string& model_so_path,
49       size_t num_models,
50       const std::string& device_str,
51       const std::string& cubin_dir);
52 
53   std::unique_ptr<at::DynamicLibrary> model_so_;
54   decltype(&AOTInductorModelContainerCreateWithDevice) create_func_{nullptr};
55   decltype(&AOTInductorModelContainerDelete) delete_func_{nullptr};
56   decltype(&AOTInductorModelContainerGetNumOutputs) get_num_outputs_func_{
57       nullptr};
58   decltype(&AOTInductorModelContainerRun) run_func_{nullptr};
59   decltype(&AOTInductorModelContainerGetNumConstants) get_num_constants_func_{
60       nullptr};
61   decltype(&AOTInductorModelContainerGetConstantName) get_constant_name_func_{
62       nullptr};
63   decltype(&AOTInductorModelContainerGetConstantOriginalFQN)
64       get_constant_original_fqn_func_{nullptr};
65   decltype(&AOTInductorModelContainerGetConstantDtype) get_constant_dtype_func_{
66       nullptr};
67   decltype(&AOTInductorModelContainerUpdateConstantBuffer)
68       update_constant_buffer_func_{nullptr};
69   decltype(&AOTInductorModelContainerUpdateInactiveConstantBuffer)
70       update_inactive_constant_buffer_func_{nullptr};
71   decltype(&AOTInductorModelContainerRunConstantFolding) run_const_fold_func_{
72       nullptr};
73   decltype(&AOTInductorModelContainerSwapConstantBuffer)
74       swap_constant_buffer_func_{nullptr};
75   decltype(&AOTInductorModelContainerGetCallSpec) get_call_spec_func_{nullptr};
76 
77   AOTInductorModelContainerHandle container_handle_ = nullptr;
78 
79   AOTIProxyExecutorHandle proxy_executor_handle_;
80 
81  private:
82   std::unique_ptr<torch::aot_inductor::ProxyExecutor> proxy_executor_;
83 };
84 
85 using CreateAOTIModelRunnerFunc = std::shared_ptr<AOTIModelContainerRunner> (*)(
86     const std::string& model_so_path,
87     size_t num_models,
88     const std::string& device_str,
89     const std::string& bin_dir);
90 
91 // Return a global map "device name" -> "aoti model runner create function" for
92 // all registered in AOTI external backends
93 TORCH_API std::unordered_map<std::string, CreateAOTIModelRunnerFunc>&
94 getAOTIModelRunnerRegistry();
95 
96 // To register a new external backend in AOTI one needs to create an instance of
97 // this struct. It is not thread-safe. Becase it is expected to be called during
98 // the initialization of the program.
99 struct TORCH_API RegisterAOTIModelRunner {
RegisterAOTIModelRunnerRegisterAOTIModelRunner100   RegisterAOTIModelRunner(
101       const std::string& name,
102       CreateAOTIModelRunnerFunc create_aoti_model_runner_fn) {
103     getAOTIModelRunnerRegistry()[name] = create_aoti_model_runner_fn;
104   }
105 };
106 
107 } // namespace torch::inductor
108 #endif
109