#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once

#include <ATen/Tensor.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>

// Forward declare DynamicLibrary
namespace at {
struct DynamicLibrary;
}

namespace torch::inductor {
using TensorConstantMap = std::unordered_map<std::string, at::Tensor*>;

class TORCH_API AOTIModelContainerRunner {
 public:
  AOTIModelContainerRunner() = delete;
  AOTIModelContainerRunner(const AOTIModelContainerRunner& other) = delete;
  AOTIModelContainerRunner(AOTIModelContainerRunner&& other) = delete;
  AOTIModelContainerRunner& operator=(const AOTIModelContainerRunner& other) =
      delete;
  AOTIModelContainerRunner& operator=(AOTIModelContainerRunner&& other) =
      delete;
  ~AOTIModelContainerRunner();

  std::vector<at::Tensor> run(
      std::vector<at::Tensor>& inputs,
      AOTInductorStreamHandle cuda_stream_handle = nullptr);

  std::unordered_map<std::string, std::string> getConstantNamesToOriginalFQNs()
      const;
  std::unordered_map<std::string, int32_t> getConstantNamesToDtypes() const;
  void update_inactive_constant_buffer(const TensorConstantMap& const_map);
  void update_constant_buffer(
      const TensorConstantMap& const_map,
      bool use_inactive,
      bool validate_full_updates);
  void run_const_fold(
      bool use_inactive,
      AOTInductorStreamHandle cuda_stream_handle = nullptr);
  void swap_constant_buffer();

  std::vector<std::string> get_call_spec();

 protected:
  AOTIModelContainerRunner(
      const std::string& model_so_path,
      size_t num_models,
      const std::string& device_str,
      const std::string& cubin_dir);

  std::unique_ptr<at::DynamicLibrary> model_so_;
  decltype(&AOTInductorModelContainerCreateWithDevice) create_func_{nullptr};
  decltype(&AOTInductorModelContainerDelete) delete_func_{nullptr};
  decltype(&AOTInductorModelContainerGetNumOutputs) get_num_outputs_func_{
      nullptr};
  decltype(&AOTInductorModelContainerRun) run_func_{nullptr};
  decltype(&AOTInductorModelContainerGetNumConstants) get_num_constants_func_{
      nullptr};
  decltype(&AOTInductorModelContainerGetConstantName) get_constant_name_func_{
      nullptr};
  decltype(&AOTInductorModelContainerGetConstantOriginalFQN)
      get_constant_original_fqn_func_{nullptr};
  decltype(&AOTInductorModelContainerGetConstantDtype) get_constant_dtype_func_{
      nullptr};
  decltype(&AOTInductorModelContainerUpdateConstantBuffer)
      update_constant_buffer_func_{nullptr};
  decltype(&AOTInductorModelContainerUpdateInactiveConstantBuffer)
      update_inactive_constant_buffer_func_{nullptr};
  decltype(&AOTInductorModelContainerRunConstantFolding) run_const_fold_func_{
      nullptr};
  decltype(&AOTInductorModelContainerSwapConstantBuffer)
      swap_constant_buffer_func_{nullptr};
  decltype(&AOTInductorModelContainerGetCallSpec) get_call_spec_func_{nullptr};

  AOTInductorModelContainerHandle container_handle_ = nullptr;

  AOTIProxyExecutorHandle proxy_executor_handle_;

 private:
  std::unique_ptr<torch::aot_inductor::ProxyExecutor> proxy_executor_;
};

using CreateAOTIModelRunnerFunc = std::shared_ptr<AOTIModelContainerRunner> (*)(
    const std::string& model_so_path,
    size_t num_models,
    const std::string& device_str,
    const std::string& bin_dir);

// Returns the global map "device name" -> "aoti model runner create function"
// covering all external backends registered with AOTI.
TORCH_API std::unordered_map<std::string, CreateAOTIModelRunnerFunc>&
getAOTIModelRunnerRegistry();

// To register a new external backend in AOTI, create an instance of this
// struct. Registration is not thread-safe, because it is expected to happen
// during program initialization.
struct TORCH_API RegisterAOTIModelRunner {
  RegisterAOTIModelRunner(
      const std::string& name,
      CreateAOTIModelRunnerFunc create_aoti_model_runner_fn) {
    getAOTIModelRunnerRegistry()[name] = create_aoti_model_runner_fn;
  }
};
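
// Illustrative sketch (not part of this API): how an external backend might
// register itself and how a caller might use the registry. The device name
// "my_device" and the MyDeviceModelContainerRunner subclass below are
// hypothetical; a real backend supplies its own AOTIModelContainerRunner
// subclass and device string.
//
//   std::shared_ptr<AOTIModelContainerRunner> create_my_device_runner(
//       const std::string& model_so_path,
//       size_t num_models,
//       const std::string& device_str,
//       const std::string& bin_dir) {
//     return std::make_shared<MyDeviceModelContainerRunner>(
//         model_so_path, num_models, device_str, bin_dir);
//   }
//
//   // Runs at static-initialization time, before main().
//   static RegisterAOTIModelRunner register_my_device_runner(
//       "my_device", &create_my_device_runner);
//
// A caller can then look up the create function by device name and run the
// AOTInductor-compiled model:
//
//   const std::string so_path = "/path/to/model.so";  // illustrative path
//   auto& registry = getAOTIModelRunnerRegistry();
//   auto it = registry.find("my_device");
//   if (it != registry.end()) {
//     std::shared_ptr<AOTIModelContainerRunner> runner =
//         it->second(so_path, /*num_models=*/1, "my_device", /*bin_dir=*/"");
//     std::vector<at::Tensor> inputs = /* model inputs */;
//     std::vector<at::Tensor> outputs = runner->run(inputs);
//   }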

} // namespace torch::inductor
#endif