xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/Generator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Generator.h>
2 #include <ATen/core/Tensor.h>
3 #include <c10/util/Exception.h>
4 
5 namespace at {
6 
set_state(const at::Tensor & new_state)7 void Generator::set_state(const at::Tensor& new_state) {
8   TORCH_CHECK(new_state.defined(), "Undefined tensor is not allowed");
9   this->impl_->set_state(*new_state.unsafeGetTensorImpl());
10 }
11 
get_state() const12 at::Tensor Generator::get_state() const {
13   return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
14 }
15 
graphsafe_set_state(const Generator & new_state)16 void Generator::graphsafe_set_state(const Generator& new_state) {
17   this->impl_->graphsafe_set_state(new_state.getIntrusivePtr());
18 }
19 
graphsafe_get_state() const20 Generator Generator::graphsafe_get_state() const {
21   return Generator(this->impl_->graphsafe_get_state());
22 }
23 
24 } // namespace at
25