xref: /aosp_15_r20/external/pytorch/torch/csrc/Generator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/Generator.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/CPUGeneratorImpl.h>
5 #include <structmember.h>
6 
7 #include <ATen/core/GeneratorForPrivateuseone.h>
8 #include <ATen/detail/XPUHooksInterface.h>
9 #include <torch/csrc/Device.h>
10 #include <torch/csrc/Exceptions.h>
11 #include <torch/csrc/THP.h>
12 #include <torch/csrc/autograd/generated/VariableType.h>
13 #include <torch/csrc/autograd/generated/variable_factories.h>
14 #include <torch/csrc/autograd/python_variable.h>
15 #include <torch/csrc/utils/python_arg_parser.h>
16 #include <torch/csrc/utils/tensor_types.h>
17 
18 #include <utility>
19 
20 #ifdef USE_CUDA
21 #include <ATen/cuda/CUDAGeneratorImpl.h>
22 #endif
23 
24 #ifdef USE_MPS
25 #include <ATen/mps/MPSGeneratorImpl.h>
26 #endif
27 
28 using namespace at;
29 using namespace torch;
30 
// Python type object for torch._C.Generator, held as a PyObject*.
// Assigned to &THPGeneratorType by THPGenerator_init; read by
// THPGenerator_initDefaultGenerator and THPGenerator_Wrap to allocate new
// wrapper objects. Null until THPGenerator_init runs.
PyObject* THPGeneratorClass = nullptr;
32 
THPGenerator_initDefaultGenerator(at::Generator cdata)33 PyObject* THPGenerator_initDefaultGenerator(at::Generator cdata) {
34   auto type = (PyTypeObject*)THPGeneratorClass;
35   auto self = THPObjectPtr{type->tp_alloc(type, 0)};
36   if (!self)
37     throw python_error();
38   auto self_ = reinterpret_cast<THPGenerator*>(self.get());
39   self_->cdata = std::move(cdata);
40   return self.release();
41 }
42 
THPGenerator_dealloc(PyObject * _self)43 static void THPGenerator_dealloc(PyObject* _self) {
44   auto self = reinterpret_cast<THPGenerator*>(_self);
45   if (self->cdata.defined()) {
46     self->cdata.set_pyobj(nullptr);
47     self->cdata.~Generator();
48   }
49   Py_TYPE(_self)->tp_free(_self);
50 }
51 
THPGenerator_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)52 static PyObject* THPGenerator_pynew(
53     PyTypeObject* type,
54     PyObject* args,
55     PyObject* kwargs) {
56   HANDLE_TH_ERRORS
57   static torch::PythonArgParser parser({"Generator(Device device=None)"});
58   torch::ParsedArgs<1> parsed_args;
59   auto r = parser.parse(args, kwargs, parsed_args);
60   auto device = r.deviceWithDefault(0, at::Device(at::kCPU));
61 
62   THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0));
63   if (device.type() == at::kCPU) {
64     self->cdata = make_generator<CPUGeneratorImpl>();
65   }
66 #ifdef USE_CUDA
67   else if (device.type() == at::kCUDA) {
68     self->cdata = make_generator<CUDAGeneratorImpl>(device.index());
69   }
70 #elif USE_MPS
71   else if (device.type() == at::kMPS) {
72     self->cdata = make_generator<MPSGeneratorImpl>();
73   }
74 #endif
75   else if (device.type() == at::kXPU) {
76     self->cdata = at::detail::getXPUHooks().getXPUGenerator(device.index());
77   } else if (device.type() == at::kIPU) {
78     self->cdata = at::detail::getIPUHooks().newIPUGenerator(device.index());
79   } else if (device.type() == at::kPrivateUse1) {
80     self->cdata = at::GetGeneratorForPrivateuse1(device.index());
81   } else {
82     AT_ERROR(
83         "Device type ",
84         c10::DeviceTypeName(device.type()),
85         " is not supported for torch.Generator() api.");
86   }
87   return (PyObject*)self.release();
88   END_HANDLE_TH_ERRORS
89 }
90 
THPGenerator_getState(PyObject * _self,PyObject * noargs)91 static PyObject* THPGenerator_getState(PyObject* _self, PyObject* noargs) {
92   using namespace torch::autograd;
93   HANDLE_TH_ERRORS
94   auto& gen = ((THPGenerator*)_self)->cdata;
95 
96   // See Note [Acquire lock when using random generators]
97   std::scoped_lock<std::mutex> lock(gen.mutex());
98   auto state_tensor = gen.get_state();
99 
100   return THPVariable_Wrap(std::move(state_tensor));
101   END_HANDLE_TH_ERRORS
102 }
103 
THPGenerator_setState(PyObject * _self,PyObject * _new_state)104 static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) {
105   using namespace torch::autograd;
106 
107   HANDLE_TH_ERRORS
108   if (!THPVariable_Check(_new_state)) {
109     throw torch::TypeError(
110         "expected a torch.ByteTensor, but got %s",
111         Py_TYPE(_new_state)->tp_name);
112   }
113   auto self = (THPGenerator*)_self;
114   auto& gen = self->cdata;
115   const auto& new_state_tensor = THPVariable_Unpack(_new_state);
116 
117   // See Note [Acquire lock when using random generators]
118   std::scoped_lock<std::mutex> lock(gen.mutex());
119   gen.set_state(new_state_tensor);
120 
121   Py_INCREF(self);
122   return (PyObject*)self;
123   END_HANDLE_TH_ERRORS
124 }
125 
unpack_uint64(PyObject * pyobj)126 uint64_t unpack_uint64(PyObject* pyobj) {
127   uint64_t unsigned_obj = 0;
128   try {
129     // First try to interpret as unsigned long
130     unsigned_obj = THPUtils_unpackUInt64(pyobj);
131   } catch (...) {
132     if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
133       // If an overflow happened, then the pyobj could be negative,
134       // so try to interpret it as signed long
135       PyErr_Clear();
136       int64_t obj = THPUtils_unpackLong(pyobj);
137       unsigned_obj = *(reinterpret_cast<uint64_t*>(&obj));
138     } else {
139       // If any other type of exception happened, rethrow it
140       throw;
141     }
142   }
143   return unsigned_obj;
144 }
145 
THPGenerator_graphSafeGetState(PyObject * _self,PyObject * noargs)146 static PyObject* THPGenerator_graphSafeGetState(
147     PyObject* _self,
148     PyObject* noargs) {
149   HANDLE_TH_ERRORS
150   auto& gen = ((THPGenerator*)_self)->cdata;
151 
152   // See Note [Acquire lock when using random generators]
153   std::scoped_lock<std::mutex> lock(gen.mutex());
154 
155   return THPGenerator_Wrap(gen.graphsafe_get_state());
156   END_HANDLE_TH_ERRORS
157 }
158 
THPGenerator_graphSafeSetState(PyObject * _self,PyObject * _state)159 static PyObject* THPGenerator_graphSafeSetState(
160     PyObject* _self,
161     PyObject* _state) {
162   HANDLE_TH_ERRORS
163   auto self = (THPGenerator*)_self;
164   auto& gen = self->cdata;
165 
166   // See Note [Acquire lock when using random generators]
167   std::scoped_lock<std::mutex> lock(gen.mutex());
168   gen.graphsafe_set_state(THPGenerator_Unwrap(_state));
169 
170   Py_INCREF(self);
171   return (PyObject*)self;
172   END_HANDLE_TH_ERRORS
173 }
174 
THPGenerator_cloneState(PyObject * _self,PyObject * noargs)175 static PyObject* THPGenerator_cloneState(PyObject* _self, PyObject* noargs) {
176   HANDLE_TH_ERRORS
177   auto& gen = ((THPGenerator*)_self)->cdata;
178 
179   // See Note [Acquire lock when using random generators]
180   std::scoped_lock<std::mutex> lock(gen.mutex());
181   auto new_generator = gen.clone();
182 
183   return THPGenerator_Wrap(new_generator);
184   END_HANDLE_TH_ERRORS
185 }
186 
THPGenerator_manualSeed(PyObject * _self,PyObject * seed)187 static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) {
188   HANDLE_TH_ERRORS
189   auto self = (THPGenerator*)_self;
190   auto generator = self->cdata;
191   TORCH_CHECK(
192       THPUtils_checkLong(seed),
193       "manual_seed expected a long, "
194       "but got ",
195       THPUtils_typename(seed));
196   uint64_t unsigned_seed = unpack_uint64(seed);
197   // See Note [Acquire lock when using random generators]
198   std::scoped_lock<std::mutex> lock(generator.mutex());
199   generator.set_current_seed(unsigned_seed);
200   Py_INCREF(self);
201   return (PyObject*)self;
202   END_HANDLE_TH_ERRORS
203 }
204 
THPGenerator_setOffset(PyObject * _self,PyObject * offset)205 static PyObject* THPGenerator_setOffset(PyObject* _self, PyObject* offset) {
206   HANDLE_TH_ERRORS
207   auto self = (THPGenerator*)_self;
208   auto generator = self->cdata;
209   TORCH_CHECK(
210       THPUtils_checkLong(offset),
211       "manual_offset expected a long, "
212       "but got ",
213       THPUtils_typename(offset));
214   uint64_t unsigned_offset = unpack_uint64(offset);
215   // See Note [Acquire lock when using random generators]
216   std::scoped_lock<std::mutex> lock(generator.mutex());
217   generator.set_offset(unsigned_offset);
218   Py_INCREF(self);
219   return (PyObject*)self;
220   END_HANDLE_TH_ERRORS
221 }
222 
THPGenerator_seed(PyObject * _self,PyObject * noargs)223 static PyObject* THPGenerator_seed(PyObject* _self, PyObject* noargs) {
224   HANDLE_TH_ERRORS
225   // See Note [Acquire lock when using random generators]
226   auto self = (THPGenerator*)_self;
227   std::scoped_lock<std::mutex> lock(self->cdata.mutex());
228   uint64_t seed_val = self->cdata.seed();
229   return THPUtils_packUInt64(seed_val);
230   END_HANDLE_TH_ERRORS
231 }
232 
THPGenerator_initialSeed(PyObject * _self,PyObject * noargs)233 static PyObject* THPGenerator_initialSeed(PyObject* _self, PyObject* noargs) {
234   HANDLE_TH_ERRORS
235   auto self = (THPGenerator*)_self;
236   return THPUtils_packUInt64(self->cdata.current_seed());
237   END_HANDLE_TH_ERRORS
238 }
239 
THPGenerator_getOffset(PyObject * _self,PyObject * noargs)240 static PyObject* THPGenerator_getOffset(PyObject* _self, PyObject* noargs) {
241   HANDLE_TH_ERRORS
242   auto self = (THPGenerator*)_self;
243   return THPUtils_packUInt64(self->cdata.get_offset());
244   END_HANDLE_TH_ERRORS
245 }
246 
THPGenerator_get_device(THPGenerator * self,void * unused)247 static PyObject* THPGenerator_get_device(THPGenerator* self, void* unused) {
248   HANDLE_TH_ERRORS
249   return THPDevice_New(self->cdata.device());
250   END_HANDLE_TH_ERRORS
251 }
252 
THPGenerator_reduce(PyObject * _self,PyObject * noargs)253 PyObject* THPGenerator_reduce(PyObject* _self, PyObject* noargs) {
254   HANDLE_TH_ERRORS
255   auto self = (THPGenerator*)_self;
256   auto& gen = self->cdata;
257 
258   auto ret = THPObjectPtr{PyTuple_New(3)};
259   if (!ret)
260     throw python_error();
261 
262   py::object torch_module = py::module::import("torch");
263   py::object torch_generator = torch_module.attr("Generator");
264   PyTuple_SET_ITEM(ret.get(), 0, torch_generator.release().ptr());
265 
266   auto args = THPObjectPtr{PyTuple_New(1)};
267   if (!args)
268     throw python_error();
269 
270   PyTuple_SET_ITEM(args.get(), 0, THPGenerator_get_device(self, nullptr));
271   PyTuple_SET_ITEM(ret.get(), 1, args.release());
272 
273   auto state = THPObjectPtr{PyTuple_New(3)};
274   if (!state)
275     throw python_error();
276 
277   c10::DeviceType device_type = gen.device().type();
278   PyTuple_SET_ITEM(state.get(), 0, THPGenerator_initialSeed(_self, nullptr));
279   PyTuple_SET_ITEM(
280       state.get(),
281       1,
282       device_type != at::kCPU ? THPGenerator_getOffset(_self, nullptr)
283                               : Py_None);
284   PyTuple_SET_ITEM(state.get(), 2, THPGenerator_getState(_self, nullptr));
285   PyTuple_SET_ITEM(ret.get(), 2, state.release());
286 
287   return ret.release();
288   END_HANDLE_TH_ERRORS
289 }
290 
THPGenerator_pickleSetState(PyObject * _self,PyObject * state)291 static PyObject* THPGenerator_pickleSetState(PyObject* _self, PyObject* state) {
292   HANDLE_TH_ERRORS
293   THPGenerator_manualSeed(_self, PyTuple_GET_ITEM(state, 0));
294   auto& offset = PyTuple_GET_ITEM(state, 1);
295   if (offset != Py_None) {
296     THPGenerator_setOffset(_self, offset);
297   }
298   THPGenerator_setState(_self, PyTuple_GET_ITEM(state, 2));
299   Py_RETURN_NONE;
300   END_HANDLE_TH_ERRORS
301 }
302 
303 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
304 static struct PyGetSetDef THPGenerator_properties[] = {
305     {"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr},
306     {nullptr}};
307 
308 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
309 static PyMethodDef THPGenerator_methods[] = {
310     {"__reduce__", THPGenerator_reduce, METH_NOARGS, nullptr},
311     {"__setstate__", THPGenerator_pickleSetState, METH_O, nullptr},
312     {"get_state", THPGenerator_getState, METH_NOARGS, nullptr},
313     {"set_state", THPGenerator_setState, METH_O, nullptr},
314     {"clone_state", THPGenerator_cloneState, METH_NOARGS, nullptr},
315     {"graphsafe_get_state",
316      THPGenerator_graphSafeGetState,
317      METH_NOARGS,
318      nullptr},
319     {"graphsafe_set_state", THPGenerator_graphSafeSetState, METH_O, nullptr},
320     {"set_offset", THPGenerator_setOffset, METH_O, nullptr},
321     {"manual_seed", THPGenerator_manualSeed, METH_O, nullptr},
322     {"seed", THPGenerator_seed, METH_NOARGS, nullptr},
323     {"initial_seed", THPGenerator_initialSeed, METH_NOARGS, nullptr},
324     {"get_offset", THPGenerator_getOffset, METH_NOARGS, nullptr},
325     {nullptr}};
326 
327 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
328 static struct PyMemberDef THPGenerator_members[] = {
329     {"_cdata", T_ULONGLONG, offsetof(THPGenerator, cdata), READONLY, nullptr},
330     {nullptr}};
331 
// Static type object for torch._C.Generator.
//
// Slots are filled positionally, so the entries must stay in exact
// PyTypeObject field order; each trailing /* tp_* */ comment names its slot.
// Py_TPFLAGS_BASETYPE lets Python code subclass torch.Generator.
// tp_alloc is left null here, while the allocation paths in this file call
// type->tp_alloc; this relies on PyType_Ready (invoked by THPGenerator_init)
// supplying an inherited allocator, per the CPython type-object docs.
PyTypeObject THPGeneratorType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C.Generator", /* tp_name */
    sizeof(THPGenerator), /* tp_basicsize */
    0, /* tp_itemsize */
    THPGenerator_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPGenerator_methods, /* tp_methods */
    THPGenerator_members, /* tp_members */
    THPGenerator_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPGenerator_pynew, /* tp_new */
};
372 
THPGenerator_init(PyObject * module)373 bool THPGenerator_init(PyObject* module) {
374   THPGeneratorClass = (PyObject*)&THPGeneratorType;
375   if (PyType_Ready(&THPGeneratorType) < 0)
376     return false;
377   Py_INCREF(&THPGeneratorType);
378   PyModule_AddObject(module, "Generator", (PyObject*)&THPGeneratorType);
379   return true;
380 }
381 
set_pyobj(const Generator & self,PyObject * pyobj)382 void set_pyobj(const Generator& self, PyObject* pyobj) {
383   TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined generator");
384   self.set_pyobj(pyobj);
385 }
386 
pyobj(const Generator & self)387 PyObject* pyobj(const Generator& self) {
388   TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined generator");
389   return self.pyobj();
390 }
391 
THPGenerator_Wrap(Generator gen)392 PyObject* THPGenerator_Wrap(Generator gen) {
393   if (!gen.defined()) {
394     Py_RETURN_NONE;
395   }
396 
397   if (auto obj = pyobj(gen)) {
398     Py_INCREF(obj);
399     return obj;
400   }
401 
402   return THPGenerator_NewWithVar(
403       (PyTypeObject*)THPGeneratorClass, std::move(gen));
404 }
405 
THPGenerator_Unwrap(PyObject * state)406 at::Generator THPGenerator_Unwrap(PyObject* state) {
407   if (!Py_IS_TYPE(state, &THPGeneratorType)) {
408     throw torch::TypeError(
409         "expected a Generator, but got %s", Py_TYPE(state)->tp_name);
410   }
411   return reinterpret_cast<THPGenerator*>(state)->cdata;
412 }
413 
414 // Creates a new Python object for a Generator. The Generator must not already
415 // have a PyObject* associated with it.
THPGenerator_NewWithVar(PyTypeObject * type,Generator gen)416 PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) {
417   PyObject* obj = type->tp_alloc(type, 0);
418   if (obj) {
419     auto g = (THPGenerator*)obj;
420     new (&g->cdata) Generator(std::move(gen));
421     set_pyobj(g->cdata, obj);
422   }
423   return obj;
424 }
425