// xref: /aosp_15_r20/external/pytorch/c10/core/StorageImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/SymInt.h>
#include <c10/core/impl/COW.h>
#include <c10/core/impl/COWDeleter.h>
#include <c10/core/impl/PyObjectSlot.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/intrusive_ptr.h>
#include <cstddef>
#include <optional>
#include <utility>

17 namespace c10 {
18 
19 C10_API void throwNullDataPtrError();
20 C10_API void warnDeprecatedDataPtr();
21 
22 // A storage represents the underlying backing data buffer for a
23 // tensor.  This concept was inherited from the original Torch7
24 // codebase; we'd kind of like to get rid of the concept
25 // (see https://github.com/pytorch/pytorch/issues/14797) but
26 // it's hard work and no one has gotten around to doing it.
27 //
28 // NB: storage is supposed to uniquely own a data pointer; e.g.,
29 // two non-null data pointers alias if and only if they are from
30 // the same storage.  Technically you can violate this invariant
31 // (e.g., you can create a non-owning StorageImpl with at::from_blob)
32 // but a lot of things won't work correctly, including:
33 //
34 // - An ordinary deleter on such a storage is wrong, because normal deleters
35 //   assume unique ownership, but if you have two storages at the same data,
36 //   that implies there is some sort of shared ownership. So your deleter would
37 //   have to actually be internally doing some sort of refcount thing
38 // - Deepcopy in Python side relies on storage equality and not data pointer
39 //   equality; so if there are two separate storages pointing to the same data,
40 //   the data will actually get duplicated in that case (one data ptr before,
41 //   two data ptrs after)
42 // - Version counts won't work correctly, because we do all VC tracking at the
43 //   level of storages (unless you explicitly disconnect the VC with detach);
44 //   mutation because data pointers are the same are totally untracked
45 struct C10_API StorageImpl : public c10::intrusive_ptr_target {
46  public:
47   struct use_byte_size_t {};
48 
StorageImplStorageImpl49   StorageImpl(
50       use_byte_size_t /*use_byte_size*/,
51       SymInt size_bytes,
52       at::DataPtr data_ptr,
53       at::Allocator* allocator,
54       bool resizable)
55       : data_ptr_(std::move(data_ptr)),
56         size_bytes_(std::move(size_bytes)),
57         size_bytes_is_heap_allocated_(size_bytes_.is_heap_allocated()),
58         resizable_(resizable),
59         received_cuda_(false),
60         allocator_(allocator) {
61     if (resizable) {
62       TORCH_INTERNAL_ASSERT(
63           allocator_, "For resizable storage, allocator must be provided");
64     }
65     refresh_has_data_ptr_check();
66   }
67 
StorageImplStorageImpl68   StorageImpl(
69       use_byte_size_t /*use_byte_size*/,
70       const SymInt& size_bytes,
71       at::Allocator* allocator,
72       bool resizable)
73       : StorageImpl(
74             use_byte_size_t(),
75             size_bytes,
76             size_bytes.is_heap_allocated()
77                 ? allocator->allocate(0)
78                 : allocator->allocate(size_bytes.as_int_unchecked()),
79             allocator,
80             resizable) {}
81 
82   StorageImpl& operator=(StorageImpl&& other) = delete;
83   StorageImpl& operator=(const StorageImpl&) = delete;
84   StorageImpl() = delete;
85   StorageImpl(StorageImpl&& other) = delete;
86   StorageImpl(const StorageImpl&) = delete;
87   ~StorageImpl() override = default;
88 
resetStorageImpl89   void reset() {
90     data_ptr_.clear();
91     size_bytes_ = 0;
92     size_bytes_is_heap_allocated_ = false;
93   }
94 
95   // Destructor doesn't call release_resources because it's
96   // unnecessary; don't forget to change that if needed!
release_resourcesStorageImpl97   void release_resources() override {
98     data_ptr_.clear();
99   }
100 
nbytesStorageImpl101   size_t nbytes() const {
102     // OK to do this instead of maybe_as_int as nbytes is guaranteed positive
103     TORCH_CHECK(!size_bytes_is_heap_allocated_);
104     return size_bytes_.as_int_unchecked();
105   }
106 
sym_nbytesStorageImpl107   SymInt sym_nbytes() const {
108     return size_bytes_;
109   }
110 
111   // TODO: remove later
set_nbytesStorageImpl112   void set_nbytes(size_t size_bytes) {
113     size_bytes_ = static_cast<int64_t>(size_bytes);
114     size_bytes_is_heap_allocated_ = false;
115   }
116 
set_nbytesStorageImpl117   void set_nbytes(c10::SymInt size_bytes) {
118     size_bytes_ = std::move(size_bytes);
119   }
120 
resizableStorageImpl121   bool resizable() const {
122     return resizable_;
123   }
124 
data_ptrStorageImpl125   const at::DataPtr& data_ptr() const {
126     return data_ptr_;
127   }
128 
mutable_data_ptrStorageImpl129   at::DataPtr& mutable_data_ptr() {
130     if (C10_UNLIKELY(has_data_ptr_check_)) {
131       if (throw_on_mutable_data_ptr_) {
132         throwNullDataPtrError();
133       }
134       if (warn_deprecated_on_mutable_data_ptr_) {
135         warnDeprecatedDataPtr();
136       }
137       maybe_materialize_cow();
138     }
139     return data_ptr_;
140   }
141 
142   // Returns the data_ptr. Bypasses all checks.
_mutable_data_ptr_no_checksStorageImpl143   at::DataPtr& _mutable_data_ptr_no_checks() {
144     return data_ptr_;
145   }
146 
147   // Returns the previous data_ptr
set_data_ptrStorageImpl148   at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) {
149     // We need to materialize the old COW DataPtr because it is
150     // being returned as mutable.
151     maybe_materialize_cow();
152     return set_data_ptr_no_materialize_cow(std::move(data_ptr));
153   }
154 
set_data_ptr_noswapStorageImpl155   void set_data_ptr_noswap(at::DataPtr&& data_ptr) {
156     data_ptr_ = std::move(data_ptr);
157     refresh_has_data_ptr_check();
158   }
159 
dataStorageImpl160   const void* data() const {
161     return data_ptr_.get();
162   }
163 
mutable_dataStorageImpl164   void* mutable_data() {
165     if (C10_UNLIKELY(has_data_ptr_check_)) {
166       if (throw_on_mutable_data_ptr_) {
167         throwNullDataPtrError();
168       }
169       if (warn_deprecated_on_mutable_data_ptr_) {
170         warnDeprecatedDataPtr();
171       }
172       maybe_materialize_cow();
173     }
174     return data_ptr_.mutable_get();
175   }
176 
device_typeStorageImpl177   at::DeviceType device_type() const {
178     return data_ptr_.device().type();
179   }
180 
allocatorStorageImpl181   at::Allocator* allocator() {
182     return allocator_;
183   }
184 
allocatorStorageImpl185   const at::Allocator* allocator() const {
186     return allocator_;
187   }
188 
189   // You generally shouldn't use this method, but it is occasionally
190   // useful if you want to override how a tensor will be reallocated,
191   // after it was already allocated (and its initial allocator was
192   // set)
set_allocatorStorageImpl193   void set_allocator(at::Allocator* allocator) {
194     allocator_ = allocator;
195   }
196 
deviceStorageImpl197   Device device() const {
198     return data_ptr_.device();
199   }
200 
set_resizableStorageImpl201   void set_resizable(bool resizable) {
202     if (resizable) {
203       // We need an allocator to be resizable
204       AT_ASSERT(allocator_);
205     }
206     resizable_ = resizable;
207   }
208 
209   /**
210    * Can only be called when use_count is 1
211    */
212   void UniqueStorageShareExternalPointer(
213       void* src,
214       size_t size_bytes,
215       DeleterFnPtr d = nullptr) {
216     UniqueStorageShareExternalPointer(
217         at::DataPtr(src, src, d, data_ptr_.device()), size_bytes);
218   }
219 
220   /**
221    * Can only be called when use_count is 1
222    */
UniqueStorageShareExternalPointerStorageImpl223   void UniqueStorageShareExternalPointer(
224       at::DataPtr&& data_ptr,
225       size_t size_bytes) {
226     data_ptr_ = std::move(data_ptr);
227     size_bytes_ = static_cast<int64_t>(size_bytes);
228     size_bytes_is_heap_allocated_ = false;
229     allocator_ = nullptr;
230     resizable_ = false;
231   }
232 
233   // This method can be used only after storage construction and cannot be used
234   // to modify storage status
set_received_cudaStorageImpl235   void set_received_cuda(bool received_cuda) {
236     received_cuda_ = received_cuda;
237   }
238 
received_cudaStorageImpl239   bool received_cuda() {
240     return received_cuda_;
241   }
242 
pyobj_slotStorageImpl243   impl::PyObjectSlot* pyobj_slot() {
244     return &pyobj_slot_;
245   }
246 
pyobj_slotStorageImpl247   const impl::PyObjectSlot* pyobj_slot() const {
248     return &pyobj_slot_;
249   }
250 
set_throw_on_mutable_data_ptrStorageImpl251   void set_throw_on_mutable_data_ptr() {
252     throw_on_mutable_data_ptr_ = true;
253     refresh_has_data_ptr_check();
254   }
255 
set_warn_deprecated_on_mutable_data_ptrStorageImpl256   void set_warn_deprecated_on_mutable_data_ptr() {
257     warn_deprecated_on_mutable_data_ptr_ = true;
258     refresh_has_data_ptr_check();
259   }
260 
261  protected:
262   // materialize_cow_storage needs to call set_data_ptr_no_materlize_cow
263   friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
264 
265   // Returns the previous data_ptr. If the old data_ptr was COW,
266   // this avoids materializing it
set_data_ptr_no_materialize_cowStorageImpl267   at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) {
268     at::DataPtr old_data_ptr(std::move(data_ptr_));
269     data_ptr_ = std::move(data_ptr);
270     refresh_has_data_ptr_check();
271     return old_data_ptr;
272   }
273 
274  private:
refresh_has_data_ptr_checkStorageImpl275   void refresh_has_data_ptr_check() {
276     has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
277         warn_deprecated_on_mutable_data_ptr_;
278   }
279 
is_cowStorageImpl280   inline bool is_cow() const {
281     return c10::impl::cow::is_cow_data_ptr(data_ptr_);
282   }
283 
284   // Triggers a copy if this is a copy-on-write tensor.
maybe_materialize_cowStorageImpl285   void maybe_materialize_cow() {
286     if (is_cow()) {
287       impl::cow::materialize_cow_storage(*this);
288     }
289   }
290 
291   DataPtr data_ptr_;
292   SymInt size_bytes_;
293   bool size_bytes_is_heap_allocated_;
294   bool resizable_;
295   // Identifies that Storage was received from another process and doesn't have
296   // local to process cuda memory allocation
297   bool received_cuda_;
298   // All special checks in data/data_ptr calls are guarded behind this single
299   // boolean. This is for performance: .data/.data_ptr calls are commonly in the
300   // hot-path.
301   bool has_data_ptr_check_ = false;
302   // If we should throw when mutable_data_ptr() or mutable_data() is called.
303   bool throw_on_mutable_data_ptr_ = false;
304   // If we warn when mutable_data_ptr() or mutable_data() is called.
305   bool warn_deprecated_on_mutable_data_ptr_ = false;
306   Allocator* allocator_;
307   impl::PyObjectSlot pyobj_slot_;
308 };
309 
310 // Declare StorageImpl create function pointer types.
311 using StorageImplCreateHelper = intrusive_ptr<StorageImpl> (*)(
312     StorageImpl::use_byte_size_t,
313     SymInt size_bytes,
314     DataPtr data_ptr,
315     Allocator* allocator,
316     bool resizable);
317 
318 C10_API void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr);
319 
320 C10_API StorageImplCreateHelper GetStorageImplCreate(DeviceType t);
321 
322 C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
323     c10::StorageImpl::use_byte_size_t use_byte_size,
324     c10::SymInt size_bytes,
325     c10::DataPtr data_ptr,
326     c10::Allocator* allocator,
327     bool resizable,
328     std::optional<at::Device> device_opt);
329 
330 } // namespace c10
331