xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/utils/python_arg_parsing.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <torch/csrc/python_headers.h>
5 
6 #include <torch/csrc/utils/python_arg_parser.h>
7 
8 namespace torch::autograd::utils {
9 
10 // The parameter allow_copy is to accept copy for Tensor.to (and by proxy
11 // PackedSequences.to) but not nn.Module.to.
12 inline std::tuple<
13     std::optional<at::Device>,
14     std::optional<at::ScalarType>,
15     bool,
16     bool,
17     std::optional<at::MemoryFormat>>
parse_to_conversion(PythonArgs & r,bool allow_copy)18 parse_to_conversion(PythonArgs& r, bool allow_copy) {
19   if (r.idx == 0) {
20     if (!allow_copy && !r.isNone(3))
21       throw std::runtime_error(".to() does not accept copy argument");
22     return std::make_tuple(
23         r.deviceOptional(0),
24         r.scalartypeOptional(1),
25         r.toBool(2),
26         r.toBool(3),
27         r.memoryformatOptional(4));
28   } else if (r.idx == 1) {
29     if (!allow_copy && !r.isNone(2))
30       throw std::runtime_error(".to() does not accept copy argument");
31     return std::make_tuple(
32         std::nullopt,
33         r.scalartype(0),
34         r.toBool(1),
35         r.toBool(2),
36         r.memoryformatOptional(3));
37   } else {
38     auto tensor = r.tensor(0);
39     if (!allow_copy && !r.isNone(2))
40       throw std::runtime_error(".to() does not accept copy argument");
41     return std::make_tuple(
42         tensor.device(),
43         tensor.scalar_type(),
44         r.toBool(1),
45         r.toBool(2),
46         r.memoryformatOptional(3));
47   }
48 }
49 } // namespace torch::autograd::utils
50