1 #if !defined(C10_MOBILE) && !defined(ANDROID) 2 #pragma once 3 4 #include <c10/cuda/CUDAStream.h> 5 #include <torch/csrc/inductor/aoti_runner/model_container_runner.h> 6 7 namespace torch::inductor { 8 9 // NOTICE: Following APIs are subject to change due to active development 10 // We provide NO BC guarantee for these APIs 11 class TORCH_API AOTIModelContainerRunnerCuda : public AOTIModelContainerRunner { 12 public: 13 // @param device_str: cuda device string, e.g. "cuda", "cuda:0" 14 AOTIModelContainerRunnerCuda( 15 const std::string& model_so_path, 16 size_t num_models = 1, 17 const std::string& device_str = "cuda", 18 const std::string& cubin_dir = ""); 19 20 ~AOTIModelContainerRunnerCuda(); 21 22 std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs); 23 24 std::vector<at::Tensor> run_with_cuda_stream( 25 std::vector<at::Tensor>& inputs, 26 at::cuda::CUDAStream cuda_stream); 27 }; 28 29 } // namespace torch::inductor 30 #endif 31