1 #include <Python.h>
2
3 #include <torch/csrc/utils/tensor_types.h>
4
5 #include <ATen/Context.h>
6 #include <ATen/Formatting.h>
7 #include <torch/csrc/Exceptions.h>
8 #include <torch/csrc/autograd/generated/VariableType.h>
9 #include <torch/csrc/tensor/python_tensor.h>
10
11 #include <c10/util/CallOnce.h>
12
13 #include <algorithm>
14 #include <sstream>
15 #include <unordered_map>
16
17 using namespace at;
18
19 namespace torch::utils {
20
parse_privateuseone_backend(bool is_sparse=false)21 static const char* parse_privateuseone_backend(bool is_sparse = false) {
22 static std::string backend_name = "torch." + get_privateuse1_backend();
23 static std::string sparse_backend_name = backend_name + ".sparse";
24 return is_sparse == false ? backend_name.c_str()
25 : sparse_backend_name.c_str();
26 }
27
backend_to_string(const at::Backend & backend)28 const char* backend_to_string(const at::Backend& backend) {
29 switch (backend) {
30 case at::Backend::CPU:
31 return "torch";
32 case at::Backend::CUDA:
33 return "torch.cuda";
34 case at::Backend::XPU:
35 return "torch.xpu";
36 case at::Backend::IPU:
37 return "torch.ipu";
38 case at::Backend::SparseCPU:
39 return "torch.sparse";
40 case at::Backend::SparseCUDA:
41 return "torch.cuda.sparse";
42 case at::Backend::SparseXPU:
43 return "torch.xpu.sparse";
44 case at::Backend::QuantizedCPU:
45 return "torch.quantized";
46 case at::Backend::HPU:
47 return "torch.hpu";
48 case at::Backend::MPS:
49 return "torch.mps";
50 case at::Backend::MTIA:
51 return "torch.mtia";
52 case at::Backend::PrivateUse1:
53 return parse_privateuseone_backend();
54 case at::Backend::SparsePrivateUse1:
55 return parse_privateuseone_backend(true);
56 case at::Backend::Lazy:
57 return "torch.lazy";
58 case at::Backend::XLA:
59 return "torch.xla";
60 case at::Backend::Meta:
61 return "torch.meta";
62 default:
63 AT_ERROR("Unimplemented backend ", backend);
64 }
65 }
66
options_to_string(const at::TensorOptions & options)67 std::string options_to_string(const at::TensorOptions& options) {
68 std::ostringstream ss;
69 ss << backend_to_string(options.backend()) << "."
70 << toString(at::typeMetaToScalarType(options.dtype())) << "Tensor";
71 return ss.str();
72 }
73
type_to_string(const at::DeprecatedTypeProperties & type)74 std::string type_to_string(const at::DeprecatedTypeProperties& type) {
75 std::ostringstream ss;
76 ss << backend_to_string(type.backend()) << "." << toString(type.scalarType())
77 << "Tensor";
78 return ss.str();
79 }
80
options_from_string(const std::string & str)81 at::TensorOptions options_from_string(const std::string& str) {
82 static std::string cuda_prefix("torch.cuda.");
83 static std::string xpu_prefix("torch.xpu.");
84 static std::string privateUser_prefix(
85 std::string(parse_privateuseone_backend()) + ".");
86 static c10::once_flag cpu_once;
87 static c10::once_flag cuda_once;
88 static c10::once_flag xpu_once;
89 static c10::once_flag privateUser1_once;
90 static std::unordered_map<std::string, at::DeprecatedTypeProperties*> cpu_map;
91 static std::unordered_map<std::string, at::DeprecatedTypeProperties*> xpu_map;
92 static std::unordered_map<std::string, at::DeprecatedTypeProperties*>
93 cuda_map;
94 static std::unordered_map<std::string, at::DeprecatedTypeProperties*>
95 privateUser1_map;
96
97 const std::unordered_map<std::string, at::DeprecatedTypeProperties*>* map =
98 nullptr;
99
100 if (str == "torch.Tensor") {
101 auto backend =
102 dispatchKeyToBackend(torch::tensors::get_default_dispatch_key());
103 auto scalar_type = torch::tensors::get_default_scalar_type();
104 return getDeprecatedTypeProperties(backend, scalar_type).options();
105 }
106
107 if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin())
108 .first == cuda_prefix.end()) {
109 // torch.cuda. is prefix of str
110 c10::call_once(cuda_once, []() {
111 for (auto type : autograd::VariableType::allCUDATypes()) {
112 cuda_map.emplace(type_to_string(*type), type);
113 }
114 });
115 map = &cuda_map;
116 } else if (
117 std::mismatch(xpu_prefix.begin(), xpu_prefix.end(), str.begin()).first ==
118 xpu_prefix.end()) {
119 // torch.xpu. is prefix of str
120 c10::call_once(xpu_once, []() {
121 for (auto type : autograd::VariableType::allXPUTypes()) {
122 xpu_map.emplace(type_to_string(*type), type);
123 }
124 });
125 map = &xpu_map;
126 } else if (
127 std::mismatch(
128 privateUser_prefix.begin(), privateUser_prefix.end(), str.begin())
129 .first == privateUser_prefix.end()) {
130 // torch.foo. foo is privateUser1 name
131 c10::call_once(privateUser1_once, []() {
132 for (auto type : autograd::VariableType::allPrivateUser1Types()) {
133 privateUser1_map.emplace(type_to_string(*type), type);
134 }
135 });
136 map = &privateUser1_map;
137 } else {
138 c10::call_once(cpu_once, []() {
139 for (auto type : autograd::VariableType::allCPUTypes()) {
140 cpu_map.emplace(type_to_string(*type), type);
141 }
142 });
143 map = &cpu_map;
144 }
145
146 auto it = map->find(str);
147 TORCH_CHECK_VALUE(it != map->end(), "invalid type: '", str, "'");
148 return it->second->options();
149 }
150
all_declared_types()151 std::vector<std::pair<Backend, ScalarType>> all_declared_types() {
152 std::vector<std::pair<Backend, ScalarType>> ret;
153
154 // NOTE: Do not add more types here. This list controls the creation
155 // of legacy tensor types e.g. torch.cuda.FloatTensor which are
156 // maintained for backwards-compatibility only.
157 auto backends = {
158 Backend::CPU, Backend::CUDA, Backend::SparseCPU, Backend::SparseCUDA};
159 auto scalar_types = {
160 ScalarType::Byte,
161 ScalarType::Char,
162 ScalarType::Double,
163 ScalarType::Float,
164 ScalarType::Int,
165 ScalarType::Long,
166 ScalarType::Short,
167 ScalarType::Half,
168 ScalarType::Bool,
169 ScalarType::BFloat16};
170
171 for (auto& backend : backends) {
172 for (auto& scalar_type : scalar_types) {
173 // there is no sparse bool type.
174 if (scalar_type == ScalarType::Bool &&
175 (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) {
176 continue;
177 }
178 ret.emplace_back(backend, scalar_type);
179 }
180 }
181
182 return ret;
183 }
184
185 } // namespace torch::utils
186