1 #include <torch/csrc/Device.h>
2
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 #include <torch/csrc/utils/pybind.h>
6 #include <torch/csrc/utils/python_arg_parser.h>
7 #include <torch/csrc/utils/python_numbers.h>
8 #include <torch/csrc/utils/python_strings.h>
9
10 #include <ATen/Device.h>
11 #include <c10/util/Exception.h>
12
13 #include <structmember.h>
14 #include <limits>
15 #include <sstream>
16
17 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
18 PyObject* THPUpperModuleOfDevice = nullptr;
19
THPDevice_New(const at::Device & device)20 PyObject* THPDevice_New(const at::Device& device) {
21 auto type = (PyTypeObject*)&THPDeviceType;
22 auto self = THPObjectPtr{type->tp_alloc(type, 0)};
23 if (!self)
24 throw python_error();
25 auto self_ = reinterpret_cast<THPDevice*>(self.get());
26 self_->device = device;
27 return self.release();
28 }
29
THPDevice_repr(THPDevice * self)30 PyObject* THPDevice_repr(THPDevice* self) {
31 std::ostringstream oss;
32 oss << "device(type=\'" << self->device.type() << "\'";
33 if (self->device.has_index()) {
34 // `self->device.index()` returns uint8_t which is treated as ascii while
35 // printing, hence casting it to uint16_t.
36 // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout
37 oss << ", index=" << static_cast<uint16_t>(self->device.index());
38 }
39 oss << ")";
40 return THPUtils_packString(oss.str().c_str());
41 }
42
THPDevice_str(THPDevice * self)43 PyObject* THPDevice_str(THPDevice* self) {
44 std::ostringstream oss;
45 oss << self->device;
46 return THPUtils_packString(oss.str().c_str());
47 }
48
THPDevice_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)49 PyObject* THPDevice_pynew(
50 PyTypeObject* type,
51 PyObject* args,
52 PyObject* kwargs) {
53 HANDLE_TH_ERRORS
54 static torch::PythonArgParser parser(
55 {"device(Device device)",
56 "device(c10::string_view type, int64_t? index=-1)"});
57 torch::ParsedArgs<2> parsed_args;
58 auto r = parser.parse(args, kwargs, parsed_args);
59 if (r.has_torch_function()) {
60 return handle_torch_function(
61 r, nullptr, args, kwargs, THPUpperModuleOfDevice, "torch");
62 }
63 if (r.idx == 0) {
64 auto device = r.device(0);
65 return THPDevice_New(device);
66 } else if (r.idx == 1) {
67 auto as_device = r.device(0); // this works, because device can take strings
68 if (as_device.has_index()) {
69 auto device_type = r.string(0);
70 throw std::runtime_error(
71 "type (string) must not include an index because index "
72 "was passed explicitly: " +
73 device_type);
74 }
75 int64_t device_index = -1;
76 if (!r.isNone(1)) {
77 device_index = r.toInt64(1);
78 // -1 is allowed in ATen/C++, to mean the default device, but not in
79 // Python.
80 TORCH_CHECK(device_index >= 0, "Device index must not be negative");
81 }
82 at::Device device(
83 as_device.type(), static_cast<c10::DeviceIndex>(device_index));
84 return THPDevice_New(device);
85 }
86 Py_RETURN_NONE;
87 END_HANDLE_TH_ERRORS
88 }
89
THPDevice_type(THPDevice * self,PyObject * noargs)90 PyObject* THPDevice_type(THPDevice* self, PyObject* noargs) {
91 HANDLE_TH_ERRORS
92 std::ostringstream oss;
93 oss << self->device.type();
94 return THPUtils_packString(oss.str().c_str());
95 Py_RETURN_NONE;
96 END_HANDLE_TH_ERRORS
97 }
98
THPDevice_index(THPDevice * self,PyObject * noargs)99 PyObject* THPDevice_index(THPDevice* self, PyObject* noargs) {
100 HANDLE_TH_ERRORS
101 if (self->device.has_index()) {
102 return THPUtils_packInt64(self->device.index());
103 } else {
104 Py_RETURN_NONE;
105 }
106 END_HANDLE_TH_ERRORS
107 }
108
THPDevice_hash(THPDevice * self)109 static Py_ssize_t THPDevice_hash(THPDevice* self) {
110 HANDLE_TH_ERRORS
111 return static_cast<Py_ssize_t>(
112 std::hash<at::Device>{}(self->device) %
113 std::numeric_limits<Py_ssize_t>::max());
114 END_HANDLE_TH_ERRORS_RET(-1)
115 }
116
THPDevice_rc(PyObject * a,PyObject * b,int op)117 PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) {
118 HANDLE_TH_ERRORS
119 if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
120 // Py_RETURN_NOTIMPLEMENTED not in python 2.
121 Py_INCREF(Py_NotImplemented);
122 return Py_NotImplemented;
123 }
124 THPDevice* da = reinterpret_cast<THPDevice*>(a);
125 THPDevice* db = reinterpret_cast<THPDevice*>(b);
126
127 switch (op) {
128 case Py_EQ:
129 if (da->device == db->device) {
130 Py_RETURN_TRUE;
131 } else {
132 Py_RETURN_FALSE;
133 }
134 case Py_NE:
135 if (da->device == db->device) {
136 Py_RETURN_FALSE;
137 } else {
138 Py_RETURN_TRUE;
139 }
140 case Py_LT:
141 case Py_LE:
142 case Py_GT:
143 case Py_GE:
144 throw torch::TypeError("comparison not implemented");
145 default:
146 throw torch::TypeError("unexpected comparison op");
147 }
148 END_HANDLE_TH_ERRORS
149 }
150
THPDevice_reduce(PyObject * _self,PyObject * noargs)151 PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
152 HANDLE_TH_ERRORS
153 auto self = (THPDevice*)_self;
154 auto ret = THPObjectPtr{PyTuple_New(2)};
155 if (!ret)
156 throw python_error();
157
158 py::object torch_module = py::module::import("torch");
159 py::object torch_device = torch_module.attr("device");
160 PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr());
161
162 THPObjectPtr args;
163 std::ostringstream oss;
164 oss << self->device.type();
165 if (self->device.has_index()) {
166 args = THPObjectPtr{Py_BuildValue(
167 "(si)", oss.str().c_str(), static_cast<int>(self->device.index()))};
168 } else {
169 args = THPObjectPtr{Py_BuildValue("(s)", oss.str().c_str())};
170 }
171 if (!args)
172 throw python_error();
173 PyTuple_SET_ITEM(ret.get(), 1, args.release());
174
175 return ret.release();
176 END_HANDLE_TH_ERRORS
177 }
178
THPDevice_enter(PyObject * self,PyObject * noargs)179 PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
180 HANDLE_TH_ERRORS
181 py::object mode = py::module::import("torch.utils._device")
182 .attr("DeviceContext")(py::handle(self));
183 at::impl::PythonTorchFunctionTLS::push_onto_stack(
184 std::make_shared<c10::SafePyObject>(
185 mode.release().ptr(), getPyInterpreter()));
186 // So that with torch.device('cuda') as dev: works
187 Py_INCREF(self);
188 return self;
189 END_HANDLE_TH_ERRORS
190 }
191
THPDevice_exit(PyObject * self,PyObject * unused)192 PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
193 HANDLE_TH_ERRORS
194 at::impl::PythonTorchFunctionTLS::pop_stack();
195 Py_RETURN_NONE;
196 END_HANDLE_TH_ERRORS
197 }
198
THPDevice_call(PyObject * self,PyObject * args,PyObject * kwargs)199 PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
200 HANDLE_TH_ERRORS
201 py::object deco =
202 py::module::import("torch.utils._device").attr("device_decorator");
203 return deco(py::handle(self), *py::handle(args), **py::handle(kwargs))
204 .release()
205 .ptr();
206 END_HANDLE_TH_ERRORS
207 }
208
209 typedef PyObject* (*getter)(PyObject*, void*);
210
211 // NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
212
213 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
214 static struct PyGetSetDef THPDevice_properties[] = {
215 {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
216 {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
217 {nullptr}};
218
219 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
220 static PyMethodDef THPDevice_methods[] = {
221 {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
222 {"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
223 {"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
224 {nullptr} /* Sentinel */
225 };
226
// The Python type object for torch.device.
// The initializer is positional: values must appear in PyTypeObject
// declaration order, and each trailing /* tp_... */ comment names the slot its
// value fills. Keep them aligned when editing.
// Slots left as nullptr here (e.g. tp_alloc, which THPDevice_New calls) may be
// inherited from the base type when PyType_Ready runs in THPDevice_init.
// Py_TPFLAGS_BASETYPE is not set, so torch.device cannot be subclassed from
// Python.
PyTypeObject THPDeviceType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */
    sizeof(THPDevice), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPDevice_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    (hashfunc)THPDevice_hash, /* tp_hash */
    // TODO: We're not sure if this is a good idea or not, because making
    // torch.device callable means that it will start returning true
    // for callable() queries, and that is unexpected. We can always add
    // this later, so for now, don't actually implement this
    // THPDevice_call, /* tp_call */
    nullptr, /* tp_call */
    (reprfunc)THPDevice_str, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    (richcmpfunc)THPDevice_rc, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPDevice_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPDevice_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPDevice_pynew, /* tp_new */
};
271
THPDevice_init(PyObject * module)272 void THPDevice_init(PyObject* module) {
273 if (PyType_Ready(&THPDeviceType) < 0) {
274 throw python_error();
275 }
276 Py_INCREF(&THPDeviceType);
277 THPUpperModuleOfDevice = module;
278 if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) {
279 throw python_error();
280 }
281 }
282