xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/UnsafeFromTH.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 
4 namespace at {
5 
unsafeTensorFromTH(void * th_pointer,bool retain)6 inline Tensor unsafeTensorFromTH(void * th_pointer, bool retain) {
7   auto tensor_impl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(static_cast<TensorImpl*>(th_pointer));
8   if (retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) {
9     c10::raw::intrusive_ptr::incref(tensor_impl.get());
10   }
11   return Tensor(std::move(tensor_impl));
12 }
13 
unsafeStorageFromTH(void * th_pointer,bool retain)14 inline Storage unsafeStorageFromTH(void * th_pointer, bool retain) {
15   if (retain && th_pointer) {
16     c10::raw::intrusive_ptr::incref(static_cast<StorageImpl*>(th_pointer));
17   }
18   return Storage(c10::intrusive_ptr<StorageImpl>::reclaim(static_cast<StorageImpl*>(th_pointer)));
19 }
20 
21 }
22