xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(C10_MOBILE) && !defined(ANDROID)
2 #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
3 
4 namespace torch::inductor {
5 
AOTIModelContainerRunnerCuda(const std::string & model_so_path,size_t num_models,const std::string & device_str,const std::string & cubin_dir)6 AOTIModelContainerRunnerCuda::AOTIModelContainerRunnerCuda(
7     const std::string& model_so_path,
8     size_t num_models,
9     const std::string& device_str,
10     const std::string& cubin_dir)
11     : AOTIModelContainerRunner(
12           model_so_path,
13           num_models,
14           device_str,
15           cubin_dir) {}
16 
17 AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() = default;
18 
run(std::vector<at::Tensor> & inputs)19 std::vector<at::Tensor> AOTIModelContainerRunnerCuda::run(
20     std::vector<at::Tensor>& inputs) {
21   at::cuda::CUDAStream cuda_stream = c10::cuda::getCurrentCUDAStream();
22   return AOTIModelContainerRunner::run(
23       inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream.stream()));
24 }
25 
run_with_cuda_stream(std::vector<at::Tensor> & inputs,at::cuda::CUDAStream cuda_stream)26 std::vector<at::Tensor> AOTIModelContainerRunnerCuda::run_with_cuda_stream(
27     std::vector<at::Tensor>& inputs,
28     at::cuda::CUDAStream cuda_stream) {
29   return AOTIModelContainerRunner::run(
30       inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream.stream()));
31 }
32 
33 } // namespace torch::inductor
34 #endif
35