1 #if !defined(C10_MOBILE) && !defined(ANDROID)
2 #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
3
4 namespace torch::inductor {
5
// Constructs a CUDA-specific AOTI runner by forwarding every argument,
// unchanged, to the generic AOTIModelContainerRunner base.
//   model_so_path: path to the AOTInductor-compiled model shared library.
//   num_models:    number of model instances the container holds.
//   device_str:    device specifier forwarded to the base (presumably a
//                  string such as "cuda" or "cuda:0" — confirm against the
//                  base-class contract).
//   cubin_dir:     directory of compiled CUDA kernel binaries (cubins).
AOTIModelContainerRunnerCuda::AOTIModelContainerRunnerCuda(
    const std::string& model_so_path,
    size_t num_models,
    const std::string& device_str,
    const std::string& cubin_dir)
    : AOTIModelContainerRunner(
          model_so_path,
          num_models,
          device_str,
          cubin_dir) {}
16
// Out-of-line defaulted destructor: the base class releases all resources.
AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() = default;
18
run(std::vector<at::Tensor> & inputs)19 std::vector<at::Tensor> AOTIModelContainerRunnerCuda::run(
20 std::vector<at::Tensor>& inputs) {
21 at::cuda::CUDAStream cuda_stream = c10::cuda::getCurrentCUDAStream();
22 return AOTIModelContainerRunner::run(
23 inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream.stream()));
24 }
25
run_with_cuda_stream(std::vector<at::Tensor> & inputs,at::cuda::CUDAStream cuda_stream)26 std::vector<at::Tensor> AOTIModelContainerRunnerCuda::run_with_cuda_stream(
27 std::vector<at::Tensor>& inputs,
28 at::cuda::CUDAStream cuda_stream) {
29 return AOTIModelContainerRunner::run(
30 inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream.stream()));
31 }
32
33 } // namespace torch::inductor
34 #endif
35