xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(C10_MOBILE) && !defined(ANDROID)
2 #pragma once
3 
4 #include <c10/cuda/CUDAStream.h>
5 #include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
6 
7 namespace torch::inductor {
8 
9 // NOTICE: Following APIs are subject to change due to active development
10 // We provide NO BC guarantee for these APIs
11 class TORCH_API AOTIModelContainerRunnerCuda : public AOTIModelContainerRunner {
12  public:
13   // @param device_str: cuda device string, e.g. "cuda", "cuda:0"
14   AOTIModelContainerRunnerCuda(
15       const std::string& model_so_path,
16       size_t num_models = 1,
17       const std::string& device_str = "cuda",
18       const std::string& cubin_dir = "");
19 
20   ~AOTIModelContainerRunnerCuda();
21 
22   std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs);
23 
24   std::vector<at::Tensor> run_with_cuda_stream(
25       std::vector<at::Tensor>& inputs,
26       at::cuda::CUDAStream cuda_stream);
27 };
28 
29 } // namespace torch::inductor
30 #endif
31