1 #include <torch/csrc/python_headers.h>
2
3 #include <torch/csrc/Device.h>
4 #include <torch/csrc/Dtype.h>
5 #include <torch/csrc/DynamicTypes.h>
6 #include <torch/csrc/Exceptions.h>
7 #include <torch/csrc/Layout.h>
8 #include <torch/csrc/Storage.h>
9 #include <torch/csrc/autograd/generated/VariableType.h>
10 #include <torch/csrc/utils/cuda_enabled.h>
11 #include <torch/csrc/utils/device_lazy_init.h>
12 #include <torch/csrc/utils/object_ptr.h>
13
14 #include <ATen/ATen.h>
15 #include <ATen/FunctionalStorageImpl.h>
16
17 #include <array>
18 #include <stdexcept>
19
20 namespace torch {
21 namespace {
22 std::array<THPDtype*, static_cast<int>(at::ScalarType::NumOptions)>
23 dtype_registry = {};
24
25 std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)>
26 layout_registry = {};
27
28 } // namespace
29
registerDtypeObject(THPDtype * dtype,at::ScalarType scalarType)30 void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType) {
31 dtype_registry[static_cast<int>(scalarType)] = dtype;
32 }
33
registerLayoutObject(THPLayout * thp_layout,at::Layout layout)34 void registerLayoutObject(THPLayout* thp_layout, at::Layout layout) {
35 layout_registry[static_cast<int>(layout)] = thp_layout;
36 }
37
getTHPDtype(at::ScalarType scalarType)38 THPDtype* getTHPDtype(at::ScalarType scalarType) {
39 auto dtype = dtype_registry[static_cast<int>(scalarType)];
40 if (!dtype) {
41 throw std::invalid_argument("unsupported scalarType");
42 }
43 return dtype;
44 }
45
getTHPLayout(at::Layout layout)46 THPLayout* getTHPLayout(at::Layout layout) {
47 auto thp_layout = layout_registry[static_cast<int>(layout)];
48 if (!thp_layout) {
49 throw std::invalid_argument("unsupported at::Layout");
50 }
51 return thp_layout;
52 }
53
createPyObject(const at::Storage & storage)54 PyObject* createPyObject(const at::Storage& storage) {
55 // Note [Invalid Python Storages]
56 // When a user creates a python tensor wrapper subclass, the subclass
57 // is a tensor object that has a nullptr storage.
58 // We still allow users to call `my_subclass.untyped_storage()`, and get back
59 // a valid storage object (this can be useful for detecting aliasing
60 // information about storages from python). However, any accesses to the
61 // data_ptr is not allowed, through methods like
62 // x.untyped_storage().data_ptr()
63 PyObject* obj = THPStorage_Wrap(storage);
64 if (!obj)
65 throw python_error();
66 return obj;
67 }
68
loadTypedStorageTypeObject()69 PyTypeObject* loadTypedStorageTypeObject() {
70 PyObject* storage_module = PyImport_ImportModule("torch.storage");
71 TORCH_INTERNAL_ASSERT(storage_module && PyModule_Check(storage_module));
72
73 PyObject* typed_storage_obj =
74 PyObject_GetAttrString(storage_module, "TypedStorage");
75 TORCH_INTERNAL_ASSERT(typed_storage_obj && PyType_Check(typed_storage_obj));
76 return reinterpret_cast<PyTypeObject*>(
77 PyObject_GetAttrString(storage_module, "TypedStorage"));
78 }
79
getTypedStorageTypeObject()80 PyTypeObject* getTypedStorageTypeObject() {
81 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
82 static PyTypeObject* typed_storage_type_obj = loadTypedStorageTypeObject();
83 return typed_storage_type_obj;
84 }
85
isStorage(PyObject * obj)86 bool isStorage(PyObject* obj) {
87 if (PyObject_TypeCheck(obj, getTypedStorageTypeObject())) {
88 return true;
89 }
90 return THPStorage_Check(obj);
91 }
92
createStorageGetType(PyObject * obj)93 std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
94 PyObject* obj) {
95 at::ScalarType scalar_type = at::ScalarType::Undefined;
96 bool is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
97 PyObject* untyped_storage_obj = nullptr;
98
99 if (is_typed_storage) {
100 // NOTE: `PyObject_GetAttrString` increments the refcounts to `dtype` and
101 // `_untyped_storage`, so we must decrement them. The refcounts will still
102 // stay nonzero since the `TypedStorage` maintains a reference.
103 PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
104 TORCH_INTERNAL_ASSERT(dtype_obj);
105 TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
106 scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
107 Py_DECREF(dtype_obj);
108
109 untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
110 TORCH_INTERNAL_ASSERT(untyped_storage_obj);
111 Py_DECREF(untyped_storage_obj);
112
113 } else {
114 scalar_type = at::kByte;
115 untyped_storage_obj = obj;
116 }
117
118 TORCH_CHECK(
119 THPStorage_Check(untyped_storage_obj),
120 "not a storage '",
121 Py_TYPE(obj)->tp_name,
122 "'");
123
124 auto storage = THPStorage_Unpack(untyped_storage_obj);
125 return std::make_tuple(storage, scalar_type, is_typed_storage);
126 }
127
createStorage(PyObject * obj)128 at::Storage createStorage(PyObject* obj) {
129 return std::get<0>(createStorageGetType(obj));
130 }
131
132 } // namespace torch
133