xref: /aosp_15_r20/external/pytorch/torch/csrc/Device.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/Device.h>
2 
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 #include <torch/csrc/utils/pybind.h>
6 #include <torch/csrc/utils/python_arg_parser.h>
7 #include <torch/csrc/utils/python_numbers.h>
8 #include <torch/csrc/utils/python_strings.h>
9 
10 #include <ATen/Device.h>
11 #include <c10/util/Exception.h>
12 
13 #include <structmember.h>
14 #include <limits>
15 #include <sstream>
16 
17 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
18 PyObject* THPUpperModuleOfDevice = nullptr;
19 
THPDevice_New(const at::Device & device)20 PyObject* THPDevice_New(const at::Device& device) {
21   auto type = (PyTypeObject*)&THPDeviceType;
22   auto self = THPObjectPtr{type->tp_alloc(type, 0)};
23   if (!self)
24     throw python_error();
25   auto self_ = reinterpret_cast<THPDevice*>(self.get());
26   self_->device = device;
27   return self.release();
28 }
29 
THPDevice_repr(THPDevice * self)30 PyObject* THPDevice_repr(THPDevice* self) {
31   std::ostringstream oss;
32   oss << "device(type=\'" << self->device.type() << "\'";
33   if (self->device.has_index()) {
34     // `self->device.index()` returns uint8_t which is treated as ascii while
35     // printing, hence casting it to uint16_t.
36     // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout
37     oss << ", index=" << static_cast<uint16_t>(self->device.index());
38   }
39   oss << ")";
40   return THPUtils_packString(oss.str().c_str());
41 }
42 
THPDevice_str(THPDevice * self)43 PyObject* THPDevice_str(THPDevice* self) {
44   std::ostringstream oss;
45   oss << self->device;
46   return THPUtils_packString(oss.str().c_str());
47 }
48 
THPDevice_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)49 PyObject* THPDevice_pynew(
50     PyTypeObject* type,
51     PyObject* args,
52     PyObject* kwargs) {
53   HANDLE_TH_ERRORS
54   static torch::PythonArgParser parser(
55       {"device(Device device)",
56        "device(c10::string_view type, int64_t? index=-1)"});
57   torch::ParsedArgs<2> parsed_args;
58   auto r = parser.parse(args, kwargs, parsed_args);
59   if (r.has_torch_function()) {
60     return handle_torch_function(
61         r, nullptr, args, kwargs, THPUpperModuleOfDevice, "torch");
62   }
63   if (r.idx == 0) {
64     auto device = r.device(0);
65     return THPDevice_New(device);
66   } else if (r.idx == 1) {
67     auto as_device = r.device(0); // this works, because device can take strings
68     if (as_device.has_index()) {
69       auto device_type = r.string(0);
70       throw std::runtime_error(
71           "type (string) must not include an index because index "
72           "was passed explicitly: " +
73           device_type);
74     }
75     int64_t device_index = -1;
76     if (!r.isNone(1)) {
77       device_index = r.toInt64(1);
78       // -1 is allowed in ATen/C++, to mean the default device, but not in
79       // Python.
80       TORCH_CHECK(device_index >= 0, "Device index must not be negative");
81     }
82     at::Device device(
83         as_device.type(), static_cast<c10::DeviceIndex>(device_index));
84     return THPDevice_New(device);
85   }
86   Py_RETURN_NONE;
87   END_HANDLE_TH_ERRORS
88 }
89 
THPDevice_type(THPDevice * self,PyObject * noargs)90 PyObject* THPDevice_type(THPDevice* self, PyObject* noargs) {
91   HANDLE_TH_ERRORS
92   std::ostringstream oss;
93   oss << self->device.type();
94   return THPUtils_packString(oss.str().c_str());
95   Py_RETURN_NONE;
96   END_HANDLE_TH_ERRORS
97 }
98 
THPDevice_index(THPDevice * self,PyObject * noargs)99 PyObject* THPDevice_index(THPDevice* self, PyObject* noargs) {
100   HANDLE_TH_ERRORS
101   if (self->device.has_index()) {
102     return THPUtils_packInt64(self->device.index());
103   } else {
104     Py_RETURN_NONE;
105   }
106   END_HANDLE_TH_ERRORS
107 }
108 
THPDevice_hash(THPDevice * self)109 static Py_ssize_t THPDevice_hash(THPDevice* self) {
110   HANDLE_TH_ERRORS
111   return static_cast<Py_ssize_t>(
112       std::hash<at::Device>{}(self->device) %
113       std::numeric_limits<Py_ssize_t>::max());
114   END_HANDLE_TH_ERRORS_RET(-1)
115 }
116 
THPDevice_rc(PyObject * a,PyObject * b,int op)117 PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) {
118   HANDLE_TH_ERRORS
119   if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
120     // Py_RETURN_NOTIMPLEMENTED not in python 2.
121     Py_INCREF(Py_NotImplemented);
122     return Py_NotImplemented;
123   }
124   THPDevice* da = reinterpret_cast<THPDevice*>(a);
125   THPDevice* db = reinterpret_cast<THPDevice*>(b);
126 
127   switch (op) {
128     case Py_EQ:
129       if (da->device == db->device) {
130         Py_RETURN_TRUE;
131       } else {
132         Py_RETURN_FALSE;
133       }
134     case Py_NE:
135       if (da->device == db->device) {
136         Py_RETURN_FALSE;
137       } else {
138         Py_RETURN_TRUE;
139       }
140     case Py_LT:
141     case Py_LE:
142     case Py_GT:
143     case Py_GE:
144       throw torch::TypeError("comparison not implemented");
145     default:
146       throw torch::TypeError("unexpected comparison op");
147   }
148   END_HANDLE_TH_ERRORS
149 }
150 
THPDevice_reduce(PyObject * _self,PyObject * noargs)151 PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
152   HANDLE_TH_ERRORS
153   auto self = (THPDevice*)_self;
154   auto ret = THPObjectPtr{PyTuple_New(2)};
155   if (!ret)
156     throw python_error();
157 
158   py::object torch_module = py::module::import("torch");
159   py::object torch_device = torch_module.attr("device");
160   PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr());
161 
162   THPObjectPtr args;
163   std::ostringstream oss;
164   oss << self->device.type();
165   if (self->device.has_index()) {
166     args = THPObjectPtr{Py_BuildValue(
167         "(si)", oss.str().c_str(), static_cast<int>(self->device.index()))};
168   } else {
169     args = THPObjectPtr{Py_BuildValue("(s)", oss.str().c_str())};
170   }
171   if (!args)
172     throw python_error();
173   PyTuple_SET_ITEM(ret.get(), 1, args.release());
174 
175   return ret.release();
176   END_HANDLE_TH_ERRORS
177 }
178 
THPDevice_enter(PyObject * self,PyObject * noargs)179 PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
180   HANDLE_TH_ERRORS
181   py::object mode = py::module::import("torch.utils._device")
182                         .attr("DeviceContext")(py::handle(self));
183   at::impl::PythonTorchFunctionTLS::push_onto_stack(
184       std::make_shared<c10::SafePyObject>(
185           mode.release().ptr(), getPyInterpreter()));
186   // So that with torch.device('cuda') as dev: works
187   Py_INCREF(self);
188   return self;
189   END_HANDLE_TH_ERRORS
190 }
191 
THPDevice_exit(PyObject * self,PyObject * unused)192 PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
193   HANDLE_TH_ERRORS
194   at::impl::PythonTorchFunctionTLS::pop_stack();
195   Py_RETURN_NONE;
196   END_HANDLE_TH_ERRORS
197 }
198 
THPDevice_call(PyObject * self,PyObject * args,PyObject * kwargs)199 PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
200   HANDLE_TH_ERRORS
201   py::object deco =
202       py::module::import("torch.utils._device").attr("device_decorator");
203   return deco(py::handle(self), *py::handle(args), **py::handle(kwargs))
204       .release()
205       .ptr();
206   END_HANDLE_TH_ERRORS
207 }
208 
209 typedef PyObject* (*getter)(PyObject*, void*);
210 
211 // NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
212 
213 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
214 static struct PyGetSetDef THPDevice_properties[] = {
215     {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
216     {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
217     {nullptr}};
218 
219 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
220 static PyMethodDef THPDevice_methods[] = {
221     {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
222     {"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
223     {"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
224     {nullptr} /* Sentinel */
225 };
226 
227 PyTypeObject THPDeviceType = {
228     PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */
229     sizeof(THPDevice), /* tp_basicsize */
230     0, /* tp_itemsize */
231     nullptr, /* tp_dealloc */
232     0, /* tp_vectorcall_offset */
233     nullptr, /* tp_getattr */
234     nullptr, /* tp_setattr */
235     nullptr, /* tp_reserved */
236     (reprfunc)THPDevice_repr, /* tp_repr */
237     nullptr, /* tp_as_number */
238     nullptr, /* tp_as_sequence */
239     nullptr, /* tp_as_mapping */
240     (hashfunc)THPDevice_hash, /* tp_hash  */
241     // TODO: We're not sure if this is a good idea or not, because making
242     // torch.device callable means that it will start returning true
243     // for callable() queries, and that is unexpected.  We can always add
244     // this later, so for now, don't actually implement this
245     // THPDevice_call, /* tp_call */
246     nullptr, /* tp_call */
247     (reprfunc)THPDevice_str, /* tp_str */
248     nullptr, /* tp_getattro */
249     nullptr, /* tp_setattro */
250     nullptr, /* tp_as_buffer */
251     Py_TPFLAGS_DEFAULT, /* tp_flags */
252     nullptr, /* tp_doc */
253     nullptr, /* tp_traverse */
254     nullptr, /* tp_clear */
255     (richcmpfunc)THPDevice_rc, /* tp_richcompare */
256     0, /* tp_weaklistoffset */
257     nullptr, /* tp_iter */
258     nullptr, /* tp_iternext */
259     THPDevice_methods, /* tp_methods */
260     nullptr, /* tp_members */
261     THPDevice_properties, /* tp_getset */
262     nullptr, /* tp_base */
263     nullptr, /* tp_dict */
264     nullptr, /* tp_descr_get */
265     nullptr, /* tp_descr_set */
266     0, /* tp_dictoffset */
267     nullptr, /* tp_init */
268     nullptr, /* tp_alloc */
269     THPDevice_pynew, /* tp_new */
270 };
271 
THPDevice_init(PyObject * module)272 void THPDevice_init(PyObject* module) {
273   if (PyType_Ready(&THPDeviceType) < 0) {
274     throw python_error();
275   }
276   Py_INCREF(&THPDeviceType);
277   THPUpperModuleOfDevice = module;
278   if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) {
279     throw python_error();
280   }
281 }
282