// external/pytorch/torch/csrc/utils/tensor_types.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <Python.h>

#include <torch/csrc/utils/tensor_types.h>

#include <ATen/Context.h>
#include <ATen/Formatting.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/tensor/python_tensor.h>

#include <c10/util/CallOnce.h>

#include <algorithm>
#include <sstream>
#include <unordered_map>

using namespace at;

namespace torch::utils {

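// Legacy type names for PrivateUse1 backends live under "torch.<backend name>"
// (and "torch.<backend name>.sparse"), where the backend name comes from
// get_privateuse1_backend().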
static const char* parse_privateuseone_backend(bool is_sparse = false) {
  static std::string backend_name = "torch." + get_privateuse1_backend();
  static std::string sparse_backend_name = backend_name + ".sparse";
  return is_sparse == false ? backend_name.c_str()
                            : sparse_backend_name.c_str();
}

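// Maps a Backend to the Python module prefix used in legacy type names,
// e.g. Backend::CUDA -> "torch.cuda", Backend::SparseCPU -> "torch.sparse".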
const char* backend_to_string(const at::Backend& backend) {
  switch (backend) {
    case at::Backend::CPU:
      return "torch";
    case at::Backend::CUDA:
      return "torch.cuda";
    case at::Backend::XPU:
      return "torch.xpu";
    case at::Backend::IPU:
      return "torch.ipu";
    case at::Backend::SparseCPU:
      return "torch.sparse";
    case at::Backend::SparseCUDA:
      return "torch.cuda.sparse";
    case at::Backend::SparseXPU:
      return "torch.xpu.sparse";
    case at::Backend::QuantizedCPU:
      return "torch.quantized";
    case at::Backend::HPU:
      return "torch.hpu";
    case at::Backend::MPS:
      return "torch.mps";
    case at::Backend::MTIA:
      return "torch.mtia";
    case at::Backend::PrivateUse1:
      return parse_privateuseone_backend();
    case at::Backend::SparsePrivateUse1:
      return parse_privateuseone_backend(true);
    case at::Backend::Lazy:
      return "torch.lazy";
    case at::Backend::XLA:
      return "torch.xla";
    case at::Backend::Meta:
      return "torch.meta";
    default:
      AT_ERROR("Unimplemented backend ", backend);
  }
}

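// Builds the legacy type name for a TensorOptions,
// e.g. CUDA + Float -> "torch.cuda.FloatTensor".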
std::string options_to_string(const at::TensorOptions& options) {
  std::ostringstream ss;
  ss << backend_to_string(options.backend()) << "."
     << toString(at::typeMetaToScalarType(options.dtype())) << "Tensor";
  return ss.str();
}

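// Builds the legacy type name for a DeprecatedTypeProperties,
// e.g. CPU + Double -> "torch.DoubleTensor".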
std::string type_to_string(const at::DeprecatedTypeProperties& type) {
  std::ostringstream ss;
  ss << backend_to_string(type.backend()) << "." << toString(type.scalarType())
     << "Tensor";
  return ss.str();
}

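// Inverse of type_to_string: parses a legacy type name such as
// "torch.cuda.FloatTensor" back into TensorOptions. Per-device lookup tables
// are built lazily, exactly once, from the declared variable types.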
at::TensorOptions options_from_string(const std::string& str) {
  static std::string cuda_prefix("torch.cuda.");
  static std::string xpu_prefix("torch.xpu.");
  static std::string privateUser_prefix(
      std::string(parse_privateuseone_backend()) + ".");
  static c10::once_flag cpu_once;
  static c10::once_flag cuda_once;
  static c10::once_flag xpu_once;
  static c10::once_flag privateUser1_once;
  static std::unordered_map<std::string, at::DeprecatedTypeProperties*> cpu_map;
  static std::unordered_map<std::string, at::DeprecatedTypeProperties*> xpu_map;
  static std::unordered_map<std::string, at::DeprecatedTypeProperties*>
      cuda_map;
  static std::unordered_map<std::string, at::DeprecatedTypeProperties*>
      privateUser1_map;

  const std::unordered_map<std::string, at::DeprecatedTypeProperties*>* map =
      nullptr;

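  // "torch.Tensor" refers to the current default dispatch key and scalar type.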
  if (str == "torch.Tensor") {
    auto backend =
        dispatchKeyToBackend(torch::tensors::get_default_dispatch_key());
    auto scalar_type = torch::tensors::get_default_scalar_type();
    return getDeprecatedTypeProperties(backend, scalar_type).options();
  }

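  // Pick the lookup table based on the module prefix. std::mismatch returns
  // iterators to the first differing elements; if the first iterator reached
  // prefix.end(), the whole prefix matched, i.e. str starts with it.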
  if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin())
          .first == cuda_prefix.end()) {
    // torch.cuda. is prefix of str
    c10::call_once(cuda_once, []() {
      for (auto type : autograd::VariableType::allCUDATypes()) {
        cuda_map.emplace(type_to_string(*type), type);
      }
    });
    map = &cuda_map;
  } else if (
      std::mismatch(xpu_prefix.begin(), xpu_prefix.end(), str.begin()).first ==
      xpu_prefix.end()) {
    // torch.xpu. is prefix of str
    c10::call_once(xpu_once, []() {
      for (auto type : autograd::VariableType::allXPUTypes()) {
        xpu_map.emplace(type_to_string(*type), type);
      }
    });
    map = &xpu_map;
  } else if (
      std::mismatch(
          privateUser_prefix.begin(), privateUser_prefix.end(), str.begin())
          .first == privateUser_prefix.end()) {
    // torch.<privateuse1 backend name>. is prefix of str
    c10::call_once(privateUser1_once, []() {
      for (auto type : autograd::VariableType::allPrivateUser1Types()) {
        privateUser1_map.emplace(type_to_string(*type), type);
      }
    });
    map = &privateUser1_map;
  } else {
    c10::call_once(cpu_once, []() {
      for (auto type : autograd::VariableType::allCPUTypes()) {
        cpu_map.emplace(type_to_string(*type), type);
      }
    });
    map = &cpu_map;
  }

  auto it = map->find(str);
  TORCH_CHECK_VALUE(it != map->end(), "invalid type: '", str, "'");
  return it->second->options();
}

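// Enumerates the (Backend, ScalarType) pairs for which legacy tensor type
// classes (e.g. torch.cuda.FloatTensor) are created.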
std::vector<std::pair<Backend, ScalarType>> all_declared_types() {
  std::vector<std::pair<Backend, ScalarType>> ret;

  // NOTE: Do not add more types here. This list controls the creation
  // of legacy tensor types e.g. torch.cuda.FloatTensor which are
  // maintained for backwards-compatibility only.
  auto backends = {
      Backend::CPU, Backend::CUDA, Backend::SparseCPU, Backend::SparseCUDA};
  auto scalar_types = {
      ScalarType::Byte,
      ScalarType::Char,
      ScalarType::Double,
      ScalarType::Float,
      ScalarType::Int,
      ScalarType::Long,
      ScalarType::Short,
      ScalarType::Half,
      ScalarType::Bool,
      ScalarType::BFloat16};

  for (auto& backend : backends) {
    for (auto& scalar_type : scalar_types) {
      // there is no sparse bool type.
      if (scalar_type == ScalarType::Bool &&
          (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) {
        continue;
      }
      ret.emplace_back(backend, scalar_type);
    }
  }

  return ret;
}

} // namespace torch::utils