1 #include <torch/csrc/Generator.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/CPUGeneratorImpl.h>
5 #include <structmember.h>
6
7 #include <ATen/core/GeneratorForPrivateuseone.h>
8 #include <ATen/detail/XPUHooksInterface.h>
9 #include <torch/csrc/Device.h>
10 #include <torch/csrc/Exceptions.h>
11 #include <torch/csrc/THP.h>
12 #include <torch/csrc/autograd/generated/VariableType.h>
13 #include <torch/csrc/autograd/generated/variable_factories.h>
14 #include <torch/csrc/autograd/python_variable.h>
15 #include <torch/csrc/utils/python_arg_parser.h>
16 #include <torch/csrc/utils/tensor_types.h>
17
18 #include <utility>
19
20 #ifdef USE_CUDA
21 #include <ATen/cuda/CUDAGeneratorImpl.h>
22 #endif
23
24 #ifdef USE_MPS
25 #include <ATen/mps/MPSGeneratorImpl.h>
26 #endif
27
28 using namespace at;
29 using namespace torch;
30
// Python type object for torch._C.Generator, held as a PyObject* so it can be
// handed directly to Python C-API calls. Assigned by THPGenerator_init();
// nullptr until that function has run.
PyObject* THPGeneratorClass = nullptr;
32
THPGenerator_initDefaultGenerator(at::Generator cdata)33 PyObject* THPGenerator_initDefaultGenerator(at::Generator cdata) {
34 auto type = (PyTypeObject*)THPGeneratorClass;
35 auto self = THPObjectPtr{type->tp_alloc(type, 0)};
36 if (!self)
37 throw python_error();
38 auto self_ = reinterpret_cast<THPGenerator*>(self.get());
39 self_->cdata = std::move(cdata);
40 return self.release();
41 }
42
THPGenerator_dealloc(PyObject * _self)43 static void THPGenerator_dealloc(PyObject* _self) {
44 auto self = reinterpret_cast<THPGenerator*>(_self);
45 if (self->cdata.defined()) {
46 self->cdata.set_pyobj(nullptr);
47 self->cdata.~Generator();
48 }
49 Py_TYPE(_self)->tp_free(_self);
50 }
51
THPGenerator_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)52 static PyObject* THPGenerator_pynew(
53 PyTypeObject* type,
54 PyObject* args,
55 PyObject* kwargs) {
56 HANDLE_TH_ERRORS
57 static torch::PythonArgParser parser({"Generator(Device device=None)"});
58 torch::ParsedArgs<1> parsed_args;
59 auto r = parser.parse(args, kwargs, parsed_args);
60 auto device = r.deviceWithDefault(0, at::Device(at::kCPU));
61
62 THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0));
63 if (device.type() == at::kCPU) {
64 self->cdata = make_generator<CPUGeneratorImpl>();
65 }
66 #ifdef USE_CUDA
67 else if (device.type() == at::kCUDA) {
68 self->cdata = make_generator<CUDAGeneratorImpl>(device.index());
69 }
70 #elif USE_MPS
71 else if (device.type() == at::kMPS) {
72 self->cdata = make_generator<MPSGeneratorImpl>();
73 }
74 #endif
75 else if (device.type() == at::kXPU) {
76 self->cdata = at::detail::getXPUHooks().getXPUGenerator(device.index());
77 } else if (device.type() == at::kIPU) {
78 self->cdata = at::detail::getIPUHooks().newIPUGenerator(device.index());
79 } else if (device.type() == at::kPrivateUse1) {
80 self->cdata = at::GetGeneratorForPrivateuse1(device.index());
81 } else {
82 AT_ERROR(
83 "Device type ",
84 c10::DeviceTypeName(device.type()),
85 " is not supported for torch.Generator() api.");
86 }
87 return (PyObject*)self.release();
88 END_HANDLE_TH_ERRORS
89 }
90
THPGenerator_getState(PyObject * _self,PyObject * noargs)91 static PyObject* THPGenerator_getState(PyObject* _self, PyObject* noargs) {
92 using namespace torch::autograd;
93 HANDLE_TH_ERRORS
94 auto& gen = ((THPGenerator*)_self)->cdata;
95
96 // See Note [Acquire lock when using random generators]
97 std::scoped_lock<std::mutex> lock(gen.mutex());
98 auto state_tensor = gen.get_state();
99
100 return THPVariable_Wrap(std::move(state_tensor));
101 END_HANDLE_TH_ERRORS
102 }
103
THPGenerator_setState(PyObject * _self,PyObject * _new_state)104 static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) {
105 using namespace torch::autograd;
106
107 HANDLE_TH_ERRORS
108 if (!THPVariable_Check(_new_state)) {
109 throw torch::TypeError(
110 "expected a torch.ByteTensor, but got %s",
111 Py_TYPE(_new_state)->tp_name);
112 }
113 auto self = (THPGenerator*)_self;
114 auto& gen = self->cdata;
115 const auto& new_state_tensor = THPVariable_Unpack(_new_state);
116
117 // See Note [Acquire lock when using random generators]
118 std::scoped_lock<std::mutex> lock(gen.mutex());
119 gen.set_state(new_state_tensor);
120
121 Py_INCREF(self);
122 return (PyObject*)self;
123 END_HANDLE_TH_ERRORS
124 }
125
unpack_uint64(PyObject * pyobj)126 uint64_t unpack_uint64(PyObject* pyobj) {
127 uint64_t unsigned_obj = 0;
128 try {
129 // First try to interpret as unsigned long
130 unsigned_obj = THPUtils_unpackUInt64(pyobj);
131 } catch (...) {
132 if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
133 // If an overflow happened, then the pyobj could be negative,
134 // so try to interpret it as signed long
135 PyErr_Clear();
136 int64_t obj = THPUtils_unpackLong(pyobj);
137 unsigned_obj = *(reinterpret_cast<uint64_t*>(&obj));
138 } else {
139 // If any other type of exception happened, rethrow it
140 throw;
141 }
142 }
143 return unsigned_obj;
144 }
145
THPGenerator_graphSafeGetState(PyObject * _self,PyObject * noargs)146 static PyObject* THPGenerator_graphSafeGetState(
147 PyObject* _self,
148 PyObject* noargs) {
149 HANDLE_TH_ERRORS
150 auto& gen = ((THPGenerator*)_self)->cdata;
151
152 // See Note [Acquire lock when using random generators]
153 std::scoped_lock<std::mutex> lock(gen.mutex());
154
155 return THPGenerator_Wrap(gen.graphsafe_get_state());
156 END_HANDLE_TH_ERRORS
157 }
158
THPGenerator_graphSafeSetState(PyObject * _self,PyObject * _state)159 static PyObject* THPGenerator_graphSafeSetState(
160 PyObject* _self,
161 PyObject* _state) {
162 HANDLE_TH_ERRORS
163 auto self = (THPGenerator*)_self;
164 auto& gen = self->cdata;
165
166 // See Note [Acquire lock when using random generators]
167 std::scoped_lock<std::mutex> lock(gen.mutex());
168 gen.graphsafe_set_state(THPGenerator_Unwrap(_state));
169
170 Py_INCREF(self);
171 return (PyObject*)self;
172 END_HANDLE_TH_ERRORS
173 }
174
THPGenerator_cloneState(PyObject * _self,PyObject * noargs)175 static PyObject* THPGenerator_cloneState(PyObject* _self, PyObject* noargs) {
176 HANDLE_TH_ERRORS
177 auto& gen = ((THPGenerator*)_self)->cdata;
178
179 // See Note [Acquire lock when using random generators]
180 std::scoped_lock<std::mutex> lock(gen.mutex());
181 auto new_generator = gen.clone();
182
183 return THPGenerator_Wrap(new_generator);
184 END_HANDLE_TH_ERRORS
185 }
186
THPGenerator_manualSeed(PyObject * _self,PyObject * seed)187 static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) {
188 HANDLE_TH_ERRORS
189 auto self = (THPGenerator*)_self;
190 auto generator = self->cdata;
191 TORCH_CHECK(
192 THPUtils_checkLong(seed),
193 "manual_seed expected a long, "
194 "but got ",
195 THPUtils_typename(seed));
196 uint64_t unsigned_seed = unpack_uint64(seed);
197 // See Note [Acquire lock when using random generators]
198 std::scoped_lock<std::mutex> lock(generator.mutex());
199 generator.set_current_seed(unsigned_seed);
200 Py_INCREF(self);
201 return (PyObject*)self;
202 END_HANDLE_TH_ERRORS
203 }
204
THPGenerator_setOffset(PyObject * _self,PyObject * offset)205 static PyObject* THPGenerator_setOffset(PyObject* _self, PyObject* offset) {
206 HANDLE_TH_ERRORS
207 auto self = (THPGenerator*)_self;
208 auto generator = self->cdata;
209 TORCH_CHECK(
210 THPUtils_checkLong(offset),
211 "manual_offset expected a long, "
212 "but got ",
213 THPUtils_typename(offset));
214 uint64_t unsigned_offset = unpack_uint64(offset);
215 // See Note [Acquire lock when using random generators]
216 std::scoped_lock<std::mutex> lock(generator.mutex());
217 generator.set_offset(unsigned_offset);
218 Py_INCREF(self);
219 return (PyObject*)self;
220 END_HANDLE_TH_ERRORS
221 }
222
THPGenerator_seed(PyObject * _self,PyObject * noargs)223 static PyObject* THPGenerator_seed(PyObject* _self, PyObject* noargs) {
224 HANDLE_TH_ERRORS
225 // See Note [Acquire lock when using random generators]
226 auto self = (THPGenerator*)_self;
227 std::scoped_lock<std::mutex> lock(self->cdata.mutex());
228 uint64_t seed_val = self->cdata.seed();
229 return THPUtils_packUInt64(seed_val);
230 END_HANDLE_TH_ERRORS
231 }
232
THPGenerator_initialSeed(PyObject * _self,PyObject * noargs)233 static PyObject* THPGenerator_initialSeed(PyObject* _self, PyObject* noargs) {
234 HANDLE_TH_ERRORS
235 auto self = (THPGenerator*)_self;
236 return THPUtils_packUInt64(self->cdata.current_seed());
237 END_HANDLE_TH_ERRORS
238 }
239
THPGenerator_getOffset(PyObject * _self,PyObject * noargs)240 static PyObject* THPGenerator_getOffset(PyObject* _self, PyObject* noargs) {
241 HANDLE_TH_ERRORS
242 auto self = (THPGenerator*)_self;
243 return THPUtils_packUInt64(self->cdata.get_offset());
244 END_HANDLE_TH_ERRORS
245 }
246
THPGenerator_get_device(THPGenerator * self,void * unused)247 static PyObject* THPGenerator_get_device(THPGenerator* self, void* unused) {
248 HANDLE_TH_ERRORS
249 return THPDevice_New(self->cdata.device());
250 END_HANDLE_TH_ERRORS
251 }
252
THPGenerator_reduce(PyObject * _self,PyObject * noargs)253 PyObject* THPGenerator_reduce(PyObject* _self, PyObject* noargs) {
254 HANDLE_TH_ERRORS
255 auto self = (THPGenerator*)_self;
256 auto& gen = self->cdata;
257
258 auto ret = THPObjectPtr{PyTuple_New(3)};
259 if (!ret)
260 throw python_error();
261
262 py::object torch_module = py::module::import("torch");
263 py::object torch_generator = torch_module.attr("Generator");
264 PyTuple_SET_ITEM(ret.get(), 0, torch_generator.release().ptr());
265
266 auto args = THPObjectPtr{PyTuple_New(1)};
267 if (!args)
268 throw python_error();
269
270 PyTuple_SET_ITEM(args.get(), 0, THPGenerator_get_device(self, nullptr));
271 PyTuple_SET_ITEM(ret.get(), 1, args.release());
272
273 auto state = THPObjectPtr{PyTuple_New(3)};
274 if (!state)
275 throw python_error();
276
277 c10::DeviceType device_type = gen.device().type();
278 PyTuple_SET_ITEM(state.get(), 0, THPGenerator_initialSeed(_self, nullptr));
279 PyTuple_SET_ITEM(
280 state.get(),
281 1,
282 device_type != at::kCPU ? THPGenerator_getOffset(_self, nullptr)
283 : Py_None);
284 PyTuple_SET_ITEM(state.get(), 2, THPGenerator_getState(_self, nullptr));
285 PyTuple_SET_ITEM(ret.get(), 2, state.release());
286
287 return ret.release();
288 END_HANDLE_TH_ERRORS
289 }
290
THPGenerator_pickleSetState(PyObject * _self,PyObject * state)291 static PyObject* THPGenerator_pickleSetState(PyObject* _self, PyObject* state) {
292 HANDLE_TH_ERRORS
293 THPGenerator_manualSeed(_self, PyTuple_GET_ITEM(state, 0));
294 auto& offset = PyTuple_GET_ITEM(state, 1);
295 if (offset != Py_None) {
296 THPGenerator_setOffset(_self, offset);
297 }
298 THPGenerator_setState(_self, PyTuple_GET_ITEM(state, 2));
299 Py_RETURN_NONE;
300 END_HANDLE_TH_ERRORS
301 }
302
// Attribute descriptors for torch._C.Generator: a single `device` property
// with a getter and no setter (so it is read-only), followed by the
// null sentinel entry that terminates the table.
// The (getter) cast adapts THPGenerator_get_device's THPGenerator* first
// parameter to the PyObject* signature the table expects.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPGenerator_properties[] = {
    {"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr},
    {nullptr}};
307
// Python-visible methods of torch._C.Generator. METH_NOARGS entries take no
// arguments; METH_O entries take exactly one positional argument.
// `__reduce__` / `__setstate__` implement pickling. The table ends with a
// null sentinel entry.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPGenerator_methods[] = {
    {"__reduce__", THPGenerator_reduce, METH_NOARGS, nullptr},
    {"__setstate__", THPGenerator_pickleSetState, METH_O, nullptr},
    {"get_state", THPGenerator_getState, METH_NOARGS, nullptr},
    {"set_state", THPGenerator_setState, METH_O, nullptr},
    {"clone_state", THPGenerator_cloneState, METH_NOARGS, nullptr},
    {"graphsafe_get_state",
     THPGenerator_graphSafeGetState,
     METH_NOARGS,
     nullptr},
    {"graphsafe_set_state", THPGenerator_graphSafeSetState, METH_O, nullptr},
    {"set_offset", THPGenerator_setOffset, METH_O, nullptr},
    {"manual_seed", THPGenerator_manualSeed, METH_O, nullptr},
    {"seed", THPGenerator_seed, METH_NOARGS, nullptr},
    {"initial_seed", THPGenerator_initialSeed, METH_NOARGS, nullptr},
    {"get_offset", THPGenerator_getOffset, METH_NOARGS, nullptr},
    {nullptr}};
326
// Data members exposed on torch._C.Generator: `_cdata` is a READONLY attribute
// that reads the start of the `cdata` field as an unsigned long long
// (T_ULONGLONG). NOTE(review): presumably this is the raw GeneratorImpl
// pointer value held by at::Generator — confirm against its layout.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMemberDef THPGenerator_members[] = {
    {"_cdata", T_ULONGLONG, offsetof(THPGenerator, cdata), READONLY, nullptr},
    {nullptr}};
331
// Static type object for torch._C.Generator.
// Fields are initialized positionally, so their order must match the
// PyTypeObject layout; each entry is labelled with the slot it fills.
// Py_TPFLAGS_BASETYPE allows Python subclasses. Instances are created by
// THPGenerator_pynew and destroyed by THPGenerator_dealloc.
// THPGenerator_init() finalizes the type with PyType_Ready.
PyTypeObject THPGeneratorType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C.Generator", /* tp_name */
    sizeof(THPGenerator), /* tp_basicsize */
    0, /* tp_itemsize */
    THPGenerator_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPGenerator_methods, /* tp_methods */
    THPGenerator_members, /* tp_members */
    THPGenerator_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPGenerator_pynew, /* tp_new */
};
372
THPGenerator_init(PyObject * module)373 bool THPGenerator_init(PyObject* module) {
374 THPGeneratorClass = (PyObject*)&THPGeneratorType;
375 if (PyType_Ready(&THPGeneratorType) < 0)
376 return false;
377 Py_INCREF(&THPGeneratorType);
378 PyModule_AddObject(module, "Generator", (PyObject*)&THPGeneratorType);
379 return true;
380 }
381
set_pyobj(const Generator & self,PyObject * pyobj)382 void set_pyobj(const Generator& self, PyObject* pyobj) {
383 TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined generator");
384 self.set_pyobj(pyobj);
385 }
386
pyobj(const Generator & self)387 PyObject* pyobj(const Generator& self) {
388 TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined generator");
389 return self.pyobj();
390 }
391
THPGenerator_Wrap(Generator gen)392 PyObject* THPGenerator_Wrap(Generator gen) {
393 if (!gen.defined()) {
394 Py_RETURN_NONE;
395 }
396
397 if (auto obj = pyobj(gen)) {
398 Py_INCREF(obj);
399 return obj;
400 }
401
402 return THPGenerator_NewWithVar(
403 (PyTypeObject*)THPGeneratorClass, std::move(gen));
404 }
405
THPGenerator_Unwrap(PyObject * state)406 at::Generator THPGenerator_Unwrap(PyObject* state) {
407 if (!Py_IS_TYPE(state, &THPGeneratorType)) {
408 throw torch::TypeError(
409 "expected a Generator, but got %s", Py_TYPE(state)->tp_name);
410 }
411 return reinterpret_cast<THPGenerator*>(state)->cdata;
412 }
413
414 // Creates a new Python object for a Generator. The Generator must not already
415 // have a PyObject* associated with it.
THPGenerator_NewWithVar(PyTypeObject * type,Generator gen)416 PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) {
417 PyObject* obj = type->tp_alloc(type, 0);
418 if (obj) {
419 auto g = (THPGenerator*)obj;
420 new (&g->cdata) Generator(std::move(gen));
421 set_pyobj(g->cdata, obj);
422 }
423 return obj;
424 }
425