xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/jiterator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/jit_macros.h>
3 
4 #if AT_USE_JITERATOR()
5 
6 #include <c10/macros/Export.h>
7 #include <c10/util/SmallVector.h>
8 #include <ATen/core/Tensor.h>
9 
10 #include <string>
11 #include <vector>
12 
13 namespace at::cuda {
14 
15 TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
16   const std::string& code_string,
17   const std::string& kernel_name,
18   const int num_outputs,
19   const c10::SmallVector<at::Tensor>& tensors,
20   const c10::SmallVector<at::Scalar>& extra_args,
21   bool return_by_ref);
22 
23 } // namespace at::cuda
24 
25 #else
26 
27 namespace at::cuda {
28 
CompileAndLaunchKernel(const std::string & code_string,const std::string & kernel_name,const int num_outputs,const c10::SmallVector<at::Tensor> & tensors,const c10::SmallVector<at::Scalar> & extra_args,bool return_by_ref)29 TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
30   const std::string& code_string,
31   const std::string& kernel_name,
32   const int num_outputs,
33   const c10::SmallVector<at::Tensor>& tensors,
34   const c10::SmallVector<at::Scalar>& extra_args,
35   bool return_by_ref) {
36     TORCH_CHECK(false, "Jiterator is not supported");
37   }
38 } // namespace at::cuda
39 
40 #endif // AT_USE_JITERATOR()
41