#include <torch/extension.h>

// test include_dirs in setuptools.setup with relative path
#include <tmp.h>
#include <ATen/OpMathType.h>

#include <array>
#include <cstdint>
6
sigmoid_add(torch::Tensor x,torch::Tensor y)7 torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
8 return x.sigmoid() + y.sigmoid();
9 }
10
11 struct MatrixMultiplier {
MatrixMultiplierMatrixMultiplier12 MatrixMultiplier(int A, int B) {
13 tensor_ =
14 torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
15 }
forwardMatrixMultiplier16 torch::Tensor forward(torch::Tensor weights) {
17 return tensor_.mm(weights);
18 }
getMatrixMultiplier19 torch::Tensor get() const {
20 return tensor_;
21 }
22
23 private:
24 torch::Tensor tensor_;
25 };
26
function_taking_optional(std::optional<torch::Tensor> tensor)27 bool function_taking_optional(std::optional<torch::Tensor> tensor) {
28 return tensor.has_value();
29 }
30
random_tensor()31 torch::Tensor random_tensor() {
32 return torch::randn({1});
33 }
34
get_math_type(at::ScalarType other)35 at::ScalarType get_math_type(at::ScalarType other) {
36 return at::toOpMathType(other);
37 }
38
// Python bindings for this extension module.
// It registers the free functions and the MatrixMultiplier class defined
// above, plus zero-argument `get_*` helpers. Each helper returns one specific
// C++ type, so the Python side can observe how pybind11 converts that type.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
  m.def(
      "function_taking_optional",
      &function_taking_optional,
      "function_taking_optional");
  py::class_<MatrixMultiplier>(m, "MatrixMultiplier")
      .def(py::init<int, int>())
      .def("forward", &MatrixMultiplier::forward)
      .def("get", &MatrixMultiplier::get);

  m.def("get_complex", []() { return c10::complex<double>(1.0, 2.0); });
  m.def("get_device", []() { return at::device_of(random_tensor()).value(); });
  m.def("get_generator", []() { return at::detail::getDefaultCPUGenerator(); });
  // An ArrayRef is a non-owning view. Constructing it from a braced list
  // points it at a temporary array that is destroyed when the lambda
  // returns, which happens before pybind11 converts the returned view.
  // Function-local statics keep the viewed elements alive for every call.
  m.def("get_intarrayref", []() {
    static const std::array<int64_t, 3> kValues{1, 2, 3};
    return at::IntArrayRef(kValues);
  });
  m.def("get_memory_format", []() { return c10::get_contiguous_memory_format(); });
  m.def("get_storage", []() { return random_tensor().storage(); });
  m.def("get_symfloat", []() { return c10::SymFloat(1.0); });
  m.def("get_symint", []() { return c10::SymInt(1); });
  // Same lifetime fix as get_intarrayref: the SymInt elements must outlive
  // the returned view.
  m.def("get_symintarrayref", []() {
    static const std::array<c10::SymInt, 3> kValues{
        c10::SymInt(1), c10::SymInt(2), c10::SymInt(3)};
    return at::SymIntArrayRef(kValues);
  });
  m.def("get_tensor", []() { return random_tensor(); });
  m.def("get_math_type", &get_math_type);
}
62