#pragma once

#include <ATen/core/Tensor.h>
#include <torch/csrc/python_headers.h>

#include <torch/csrc/utils/python_arg_parser.h>

namespace torch::autograd::utils {

// The allow_copy parameter exists to accept the copy argument for Tensor.to
// (and, by proxy, PackedSequence.to) but not for nn.Module.to.
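// Typical use (sketch only; `parser`, `args`, and `kwargs` below are assumed
// to come from a PythonArgParser registered with the `.to()` overloads this
// function handles, e.g. in THPVariable_to):
//
//   ParsedArgs<5> parsed_args;
//   auto r = parser.parse(args, kwargs, parsed_args);
//   auto [device, scalar_type, non_blocking, copy, opt_memory_format] =
//       parse_to_conversion(r, /*allow_copy=*/true);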
inline std::tuple<
    std::optional<at::Device>,
    std::optional<at::ScalarType>,
    bool,
    bool,
    std::optional<at::MemoryFormat>>
parse_to_conversion(PythonArgs& r, bool allow_copy) {
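  // r.idx is the index of the overload signature that matched when the Python
  // arguments were parsed; each branch below reads the arguments accordingly.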
  if (r.idx == 0) {
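    // Overload taking an optional device and dtype:
    // (device, dtype, non_blocking, copy, memory_format).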
    if (!allow_copy && !r.isNone(3))
      throw std::runtime_error(".to() does not accept copy argument");
    return std::make_tuple(
        r.deviceOptional(0),
        r.scalartypeOptional(1),
        r.toBool(2),
        r.toBool(3),
        r.memoryformatOptional(4));
  } else if (r.idx == 1) {
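    // Overload taking a required dtype and no device:
    // (dtype, non_blocking, copy, memory_format); the device is left unset.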
    if (!allow_copy && !r.isNone(2))
      throw std::runtime_error(".to() does not accept copy argument");
    return std::make_tuple(
        std::nullopt,
        r.scalartype(0),
        r.toBool(1),
        r.toBool(2),
        r.memoryformatOptional(3));
  } else {
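    // Overload taking another tensor as a template:
    // (other, non_blocking, copy, memory_format); device and dtype are taken
    // from `other`.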
    auto tensor = r.tensor(0);
    if (!allow_copy && !r.isNone(2))
      throw std::runtime_error(".to() does not accept copy argument");
    return std::make_tuple(
        tensor.device(),
        tensor.scalar_type(),
        r.toBool(1),
        r.toBool(2),
        r.memoryformatOptional(3));
  }
}
} // namespace torch::autograd::utils