1 #pragma once
#include <ATen/jit_macros.h>

// Both the real declaration and the fallback stub in the #else branch use
// these names, so they are included unconditionally rather than only when
// the jiterator is enabled.
#include <ATen/core/Tensor.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>

#include <string>
#include <vector>

#if AT_USE_JITERATOR()
12
13 namespace at::cuda {
14
15 TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
16 const std::string& code_string,
17 const std::string& kernel_name,
18 const int num_outputs,
19 const c10::SmallVector<at::Tensor>& tensors,
20 const c10::SmallVector<at::Scalar>& extra_args,
21 bool return_by_ref);
22
23 } // namespace at::cuda
24
25 #else
26
27 namespace at::cuda {
28
CompileAndLaunchKernel(const std::string & code_string,const std::string & kernel_name,const int num_outputs,const c10::SmallVector<at::Tensor> & tensors,const c10::SmallVector<at::Scalar> & extra_args,bool return_by_ref)29 TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
30 const std::string& code_string,
31 const std::string& kernel_name,
32 const int num_outputs,
33 const c10::SmallVector<at::Tensor>& tensors,
34 const c10::SmallVector<at::Scalar>& extra_args,
35 bool return_by_ref) {
36 TORCH_CHECK(false, "Jiterator is not supported");
37 }
38 } // namespace at::cuda
39
40 #endif // AT_USE_JITERATOR()
41