// xref: /aosp_15_r20/external/pytorch/test/custom_operator/op.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/script.h>
2 
3 #include <cstddef>
4 #include <vector>
5 #include <string>
6 
// CUSTOM_OP_API annotates the public symbols of the custom-ops library.
//
// Only Windows needs explicit DLL linkage annotations:
//   * while building the library itself, symbols are exported;
//   * while consuming it, symbols are imported.
// On every other platform the macro expands to nothing.
//
// NOTE(review): `custom_ops_EXPORTS` is presumably supplied by the build
// system (CMake defines `<target>_EXPORTS` for shared-library targets by
// default) — confirm against the CMake target name.
// clang-format off
#  if !defined(_WIN32)
#    define CUSTOM_OP_API
#  elif defined(custom_ops_EXPORTS)
#    define CUSTOM_OP_API __declspec(dllexport)
#  else
#    define CUSTOM_OP_API __declspec(dllimport)
#  endif
// clang-format on
18 
19 CUSTOM_OP_API torch::List<torch::Tensor> custom_op(
20     torch::Tensor tensor,
21     double scalar,
22     int64_t repeat);
23 
24 CUSTOM_OP_API int64_t custom_op2(std::string s1, std::string s2);
25