1 #include <torch/csrc/tensor/python_tensor.h>
2
3 #include <pybind11/pybind11.h>
4 #include <structmember.h>
5 #include <torch/csrc/utils/pybind.h>
6
7 #include <torch/csrc/Dtype.h>
8 #include <torch/csrc/DynamicTypes.h>
9 #include <torch/csrc/Exceptions.h>
10 #include <torch/csrc/Layout.h>
11 #include <torch/csrc/autograd/generated/VariableType.h>
12 #include <torch/csrc/autograd/python_variable.h>
13 #include <torch/csrc/autograd/utils/wrap_outputs.h>
14 #include <torch/csrc/autograd/variable.h>
15 #include <torch/csrc/utils/cuda_enabled.h>
16 #include <torch/csrc/utils/device_lazy_init.h>
17 #include <torch/csrc/utils/python_strings.h>
18 #include <torch/csrc/utils/tensor_new.h>
19 #include <torch/csrc/utils/tensor_types.h>
20
21 #include <ATen/ATen.h>
22
23 #include <sstream>
24 #include <string>
25 #include <type_traits>
26 #include <vector>
27
28 namespace torch::tensors {
29
30 using namespace at;
31 using namespace torch::autograd;
32
33 struct PyTensorType {
34 PyTypeObject py_type;
35 THPDtype* dtype;
36 THPLayout* layout;
37 bool is_cuda;
38 bool is_xpu;
39 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
40 char name[64];
41 int backend;
42 int scalar_type;
43
get_backendtorch::tensors::PyTensorType44 Backend get_backend() const {
45 return static_cast<Backend>(backend);
46 }
47
get_dispatch_keytorch::tensors::PyTensorType48 DispatchKey get_dispatch_key() const {
49 return backendToDispatchKey(static_cast<Backend>(backend));
50 }
51
get_scalar_typetorch::tensors::PyTensorType52 ScalarType get_scalar_type() const {
53 return static_cast<ScalarType>(scalar_type);
54 }
55 };
56
57 static_assert(
58 std::is_standard_layout_v<PyTensorType>,
59 "PyTensorType must be standard layout");
60
61 static Backend default_backend = Backend::CPU;
62
63 static void py_bind_tensor_types(
64 const std::vector<PyTensorType*>& tensor_types);
65
Tensor_new(PyTypeObject * type,PyObject * args,PyObject * kwargs)66 static PyObject* Tensor_new(
67 PyTypeObject* type,
68 PyObject* args,
69 PyObject* kwargs) {
70 HANDLE_TH_ERRORS
71 auto& tensor_type = *((PyTensorType*)type);
72 TORCH_CHECK_TYPE(
73 !tensor_type.is_cuda || torch::utils::cuda_enabled(),
74 "type ",
75 tensor_type.name,
76 " not available. Torch not compiled with CUDA enabled.")
77 if (tensor_type.is_cuda) {
78 TORCH_WARN_ONCE(
79 "The torch.cuda.*DtypeTensor constructors are no longer recommended. "
80 "It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors.")
81 }
82 return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(
83 tensor_type.get_dispatch_key(),
84 tensor_type.get_scalar_type(),
85 args,
86 kwargs));
87 END_HANDLE_TH_ERRORS
88 }
89
90 // TODO: Deprecate this instancecheck entirely. It's here to make
91 // instanceof(t, torch.FloatTensor) work, but we are not going to keep
92 // adding torch.QuantizedIntTensor classes for every new tensor type
93 // we add...
Tensor_instancecheck(PyObject * _self,PyObject * arg)94 static PyObject* Tensor_instancecheck(PyObject* _self, PyObject* arg) {
95 HANDLE_TH_ERRORS
96 auto self = (PyTensorType*)_self;
97 if (THPVariable_Check(arg)) {
98 const auto& var = THPVariable_Unpack(arg);
99 // NB: This is a little unfortunate, in that if I do an isinstance check
100 // against torch.cuda.FloatTensor, this will immediately initialize CUDA.
101 // I originally thought that it would not be possible for aten_type_ to
102 // be nullptr if you had a tensor of some type, in which case you can
103 // skip initializing aten_type(), but TestAutograd.test_type_conversions
104 // seems to violate this property (for whatever reason.)
105 //
106 // TODO: Stop using legacyExtractDispatchKey here (probably need to build
107 // in instanceof checking to Tensor class itself)
108 if (legacyExtractDispatchKey(var.key_set()) == self->get_dispatch_key() &&
109 var.scalar_type() == static_cast<ScalarType>(self->scalar_type)) {
110 Py_RETURN_TRUE;
111 }
112 }
113 Py_RETURN_FALSE;
114 END_HANDLE_TH_ERRORS
115 }
116
Tensor_dtype(PyTensorType * self,void * unused)117 static PyObject* Tensor_dtype(PyTensorType* self, void* unused) {
118 return torch::autograd::utils::wrap(self->dtype);
119 }
120
Tensor_layout(PyTensorType * self,void * unused)121 static PyObject* Tensor_layout(PyTensorType* self, void* unused) {
122 return torch::autograd::utils::wrap(self->layout);
123 }
124
Tensor_is_cuda(PyTensorType * self,void * unused)125 static PyObject* Tensor_is_cuda(PyTensorType* self, void* unused) {
126 if (self->is_cuda) {
127 Py_RETURN_TRUE;
128 } else {
129 Py_RETURN_FALSE;
130 }
131 }
132
Tensor_is_xpu(PyTensorType * self,void * unused)133 static PyObject* Tensor_is_xpu(PyTensorType* self, void* unused) {
134 if (self->is_xpu) {
135 Py_RETURN_TRUE;
136 } else {
137 Py_RETURN_FALSE;
138 }
139 }
140
Tensor_is_sparse(PyTensorType * self,void * unused)141 static PyObject* Tensor_is_sparse(PyTensorType* self, void* unused) {
142 if (self->layout->layout == at::Layout::Strided) {
143 Py_RETURN_FALSE;
144 } else {
145 Py_RETURN_TRUE;
146 }
147 }
148
Tensor_is_sparse_csr(PyTensorType * self,void * unused)149 static PyObject* Tensor_is_sparse_csr(PyTensorType* self, void* unused) {
150 if (self->layout->layout == at::Layout::SparseCsr) {
151 Py_RETURN_TRUE;
152 } else {
153 Py_RETURN_FALSE;
154 }
155 }
156
157 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
158 static struct PyMethodDef metaclass_methods[] = {
159 {"__instancecheck__", Tensor_instancecheck, METH_O, nullptr},
160 {nullptr}};
161
162 typedef PyObject* (*getter)(PyObject*, void*);
163
164 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
165 static struct PyGetSetDef metaclass_properties[] = {
166 {"dtype", (getter)Tensor_dtype, nullptr, nullptr, nullptr},
167 {"layout", (getter)Tensor_layout, nullptr, nullptr, nullptr},
168 {"is_cuda", (getter)Tensor_is_cuda, nullptr, nullptr, nullptr},
169 {"is_xpu", (getter)Tensor_is_xpu, nullptr, nullptr, nullptr},
170 {"is_sparse", (getter)Tensor_is_sparse, nullptr, nullptr, nullptr},
171 {"is_sparse_csr", (getter)Tensor_is_sparse_csr, nullptr, nullptr, nullptr},
172 {nullptr}};
173
174 static PyTypeObject metaclass = {
175 PyVarObject_HEAD_INIT(nullptr, 0) "torch.tensortype", /* tp_name */
176 sizeof(PyTypeObject) /* tp_basicsize */
177 };
178
py_initialize_metaclass(PyTypeObject & metaclass)179 static void py_initialize_metaclass(PyTypeObject& metaclass) {
180 // NOLINTNEXTLINE(misc-redundant-expression)
181 metaclass.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
182 metaclass.tp_methods = metaclass_methods;
183 metaclass.tp_getset = metaclass_properties;
184 metaclass.tp_base = &PyType_Type;
185 if (PyType_Ready(&metaclass) < 0) {
186 throw python_error();
187 }
188 }
189
190 static PyTypeObject tensor_type_prototype = {
191 PyVarObject_HEAD_INIT(&metaclass, 0) nullptr, /* tp_name */
192 sizeof(PyTensorType) /* tp_basicsize */
193 };
194
py_initialize_tensor_type(PyTypeObject & type,const char * name,PyObject * tp_dict)195 static void py_initialize_tensor_type(
196 PyTypeObject& type,
197 const char* name,
198 PyObject* tp_dict) {
199 // NOTE: we don't use the typical static declaration of PyTypeObject because
200 // we need to initialize as many types as there are VariableType instances.
201 // We copy the basic object fields from a prototype definition and initialize
202 // the remaining fields below.
203 memcpy(&type, &tensor_type_prototype, sizeof(PyTypeObject));
204 // Subclassing from torch.<ScalarType>Tensor isn't supported.
205 // (Py_TPFLAGS_BASETYPE omitted). Subclassing torch.Tensor still allowed.
206 type.tp_flags = Py_TPFLAGS_DEFAULT;
207 type.tp_name = name;
208 type.tp_new = Tensor_new;
209 if (PyType_Ready(&type) < 0) {
210 throw python_error();
211 }
212 if (PyDict_Merge(type.tp_dict, tp_dict, 0) < 0) {
213 throw python_error();
214 }
215 }
216
get_name(Backend backend,ScalarType scalarType)217 static std::string get_name(Backend backend, ScalarType scalarType) {
218 std::ostringstream ss;
219 ss << torch::utils::backend_to_string(backend) << "." << toString(scalarType)
220 << "Tensor";
221 return ss.str();
222 }
223
get_storage_obj(Backend backend,ScalarType dtype)224 static THPObjectPtr get_storage_obj(Backend backend, ScalarType dtype) {
225 auto module_name = torch::utils::backend_to_string(backend);
226 auto module_obj = THPObjectPtr(PyImport_ImportModule(module_name));
227 if (!module_obj)
228 throw python_error();
229
230 auto storage_name = std::string(toString(dtype)) + "Storage";
231 THPObjectPtr storage(
232 PyObject_GetAttrString(module_obj.get(), storage_name.c_str()));
233 TORCH_CHECK_TYPE(
234 storage.get(), "couldn't find storage object ", storage_name);
235 return storage;
236 }
237
set_type(PyTensorType & type_obj,Backend backend,ScalarType scalarType)238 static void set_type(
239 PyTensorType& type_obj,
240 Backend backend,
241 ScalarType scalarType) {
242 // This field is lazily initialized from backend and scalar_type
243 type_obj.backend = static_cast<int>(backend);
244 type_obj.scalar_type = static_cast<int>(scalarType);
245 type_obj.layout =
246 (THPLayout*)Py_NewRef(torch::getTHPLayout(layout_from_backend(backend)));
247 type_obj.dtype = (THPDtype*)Py_NewRef(torch::getTHPDtype(scalarType));
248 type_obj.is_cuda =
249 (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA);
250 type_obj.is_xpu =
251 (backend == at::Backend::XPU || backend == at::Backend::SparseXPU);
252 }
253
set_name(PyTensorType & type_obj,const std::string & name)254 static void set_name(PyTensorType& type_obj, const std::string& name) {
255 size_t n = sizeof(type_obj.name);
256 strncpy(type_obj.name, name.c_str(), n);
257 type_obj.name[n - 1] = '\0';
258 }
259
get_tensor_dict()260 static THPObjectPtr get_tensor_dict() {
261 auto torch = THPObjectPtr(PyImport_ImportModule("torch"));
262 if (!torch)
263 throw python_error();
264
265 auto tensor_class = THPObjectPtr(PyObject_GetAttrString(torch, "Tensor"));
266 if (!tensor_class)
267 throw python_error();
268
269 auto tensor_type = (PyTypeObject*)tensor_class.get();
270 TORCH_CHECK(tensor_type->tp_base, "missing base type for Tensor");
271
272 auto res = THPObjectPtr(PyDict_New());
273 if (!res)
274 throw python_error();
275
276 if (PyDict_Merge(res.get(), tensor_type->tp_dict, 0) < 0) {
277 throw python_error();
278 }
279 if (PyDict_Merge(res.get(), tensor_type->tp_base->tp_dict, 0) < 0) {
280 throw python_error();
281 }
282
283 return res;
284 }
285
// A note about the lifetime of the various PyTensorType objects: normally
// PyTypeObject instances are statically allocated, but we want to create them
// dynamically at init time, because their exact number depends on
// torch::utils::all_declared_types(). The memory for each PyTensorType is
// allocated by initialize_aten_types() and never freed: technically it's a
// leak, but it's not a problem since we want them to be alive for the whole
// lifetime of the process anyway.
//
// An alternative would be to use a std::vector<PyTensorType> instead, and let
// std::vector manage the lifetime of its items. This is problematic though,
// because it means that the memory of each PyTensorType is deallocated at
// some point during exit: if by chance another global destructor and/or
// atexit() function tries to access the PyTensorTypes, we risk a
// use-after-free error. This happens, for example, if we embed CPython and
// call Py_Finalize inside an atexit() function which was registered before
// importing torch.
static std::vector<PyTensorType*> tensor_types;
303
set_default_storage_type(Backend backend,ScalarType dtype)304 static void set_default_storage_type(Backend backend, ScalarType dtype) {
305 THPObjectPtr storage = get_storage_obj(backend, dtype);
306
307 auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
308 if (!torch_module)
309 throw python_error();
310
311 if (PyObject_SetAttrString(torch_module.get(), "Storage", storage) != 0) {
312 throw python_error();
313 }
314 }
315
set_default_tensor_type(std::optional<Backend> backend,std::optional<ScalarType> dtype)316 static void set_default_tensor_type(
317 std::optional<Backend> backend,
318 std::optional<ScalarType> dtype) {
319 if (backend.has_value()) {
320 TORCH_CHECK_TYPE(
321 *backend != Backend::Undefined, "default type cannot be undefined");
322 TORCH_CHECK_TYPE(
323 !isSparse(*backend),
324 "only dense types are supported as the default type");
325 }
326 if (dtype.has_value()) {
327 TORCH_CHECK_TYPE(
328 at::isFloatingType(*dtype),
329 "only floating-point types are supported as the default type");
330 }
331
332 // Try setting default storage in python first as it's the only operation that
333 // can fail
334 set_default_storage_type(
335 backend.value_or(default_backend),
336 dtype.value_or(at::get_default_dtype_as_scalartype()));
337
338 if (dtype.has_value()) {
339 at::set_default_dtype(scalarTypeToTypeMeta(*dtype));
340 }
341 if (backend.has_value()) {
342 default_backend = *backend;
343 }
344 }
345
initialize_aten_types(std::vector<PyTensorType * > & tensor_types)346 static void initialize_aten_types(std::vector<PyTensorType*>& tensor_types) {
347 // includes CUDA types even when PyTorch is not built with CUDA
348 auto declared_types = torch::utils::all_declared_types();
349 tensor_types.resize(declared_types.size());
350
351 for (size_t i = 0, end = declared_types.size(); i != end; i++) {
352 tensor_types[i] = new PyTensorType();
353 auto& tensor_type = *tensor_types[i];
354 Backend backend = declared_types[i].first;
355 ScalarType scalar_type = declared_types[i].second;
356 set_type(tensor_type, backend, scalar_type);
357 set_name(tensor_type, get_name(backend, scalar_type));
358 }
359
360 set_default_tensor_type(Backend::CPU, ScalarType::Float);
361 }
362
initialize_python_bindings()363 void initialize_python_bindings() {
364 // Initialize the at::Type* pointers, name, and properties of the PyTensorType
365 // vector. After this call, the vector must not be resized.
366 initialize_aten_types(tensor_types);
367
368 // Initialize the Python metaclass for the torch.FloatTensor, etc. types.
369 // The metaclass handles __instancecheck__ checks and binds the dtype property
370 // on the type objects.
371 py_initialize_metaclass(metaclass);
372
373 // Get the tp_dict of the Variable class. We copy function definitions
374 // onto each Tensor type object so that they can be accessed via e.g.
375 // `torch.FloatTensor.add`.
376 auto tensor_dict = get_tensor_dict();
377
378 // Initialize each Python type object torch.FloatTensor, torch.DoubleTensor,
379 // etc.
380 for (auto& tensor_type : tensor_types) {
381 py_initialize_tensor_type(
382 tensor_type->py_type, tensor_type->name, tensor_dict.get());
383 }
384
385 // Add the type objects to their corresponding modules. e.g. torch.FloatTensor
386 // is added to the `torch` module as `FloatTensor`. Also add all the type
387 // objects to the set torch._tensor_classes.
388 py_bind_tensor_types(tensor_types);
389 }
390
py_bind_tensor_types(const std::vector<PyTensorType * > & tensor_types)391 static void py_bind_tensor_types(
392 const std::vector<PyTensorType*>& tensor_types) {
393 auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
394 if (!torch_module)
395 throw python_error();
396
397 auto tensor_classes = THPObjectPtr(
398 PyObject_GetAttrString(torch_module.get(), "_tensor_classes"));
399 if (!tensor_classes)
400 throw python_error();
401
402 for (auto& tensor_type : tensor_types) {
403 auto name = std::string(tensor_type->name);
404 auto idx = name.rfind('.');
405 auto type_name = name.substr(idx + 1);
406 auto module_name = name.substr(0, idx);
407
408 auto module_obj = THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
409 if (!module_obj)
410 throw python_error();
411
412 PyObject* type_obj = (PyObject*)tensor_type;
413 Py_INCREF(type_obj);
414 if (PyModule_AddObject(module_obj.get(), type_name.c_str(), type_obj) < 0) {
415 throw python_error();
416 }
417 if (PySet_Add(tensor_classes.get(), type_obj) < 0) {
418 throw python_error();
419 }
420 }
421 }
422
PyTensorType_Check(PyObject * obj)423 static bool PyTensorType_Check(PyObject* obj) {
424 auto it = std::find_if(
425 tensor_types.begin(), tensor_types.end(), [obj](PyTensorType* x) {
426 return (PyObject*)x == obj;
427 });
428 return it != tensor_types.end();
429 }
430
py_set_default_tensor_type(PyObject * obj)431 void py_set_default_tensor_type(PyObject* obj) {
432 TORCH_WARN_ONCE(
433 "torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, "
434 "please use torch.set_default_dtype() and torch.set_default_device() as alternatives.")
435 TORCH_CHECK_TYPE(
436 PyTensorType_Check(obj),
437 "invalid type object: only floating-point types are supported as the default type");
438 PyTensorType* type = (PyTensorType*)obj;
439 TORCH_CHECK_TYPE(
440 !type->is_cuda || torch::utils::cuda_enabled(),
441 "type ",
442 type->name,
443 " not available. Torch not compiled with CUDA enabled.")
444 set_default_tensor_type(type->get_backend(), type->get_scalar_type());
445 }
446
py_set_default_dtype(PyObject * obj)447 void py_set_default_dtype(PyObject* obj) {
448 TORCH_CHECK_TYPE(
449 THPDtype_Check(obj),
450 "invalid dtype object: only floating-point types are supported as the default type");
451 auto scalar_type = ((THPDtype*)obj)->scalar_type;
452 set_default_tensor_type(/*backend=*/std::nullopt, scalar_type);
453 }
454
get_default_dispatch_key()455 c10::DispatchKey get_default_dispatch_key() {
456 return backendToDispatchKey(default_backend);
457 }
458
get_default_device()459 at::Device get_default_device() {
460 return at::Device(c10::backendToDeviceType(default_backend));
461 }
462
get_default_scalar_type()463 ScalarType get_default_scalar_type() {
464 return get_default_dtype_as_scalartype();
465 }
466
467 } // namespace torch::tensors
468