1 #include <c10/core/RefcountedDeleter.h>
2
3 #include <mutex>
4
5 namespace c10 {
6
refcounted_deleter(void * ctx_)7 void refcounted_deleter(void* ctx_) {
8 RefcountedDeleterContext& ctx =
9 *reinterpret_cast<RefcountedDeleterContext*>(ctx_);
10 ctx.refcount--;
11 if (ctx.refcount == 0) {
12 ctx.other_ctx = nullptr;
13 delete &ctx;
14 }
15 }
16
17 std::mutex replace_data_ptr_mutex;
18
maybeApplyRefcountedDeleter(const c10::Storage & storage)19 void maybeApplyRefcountedDeleter(const c10::Storage& storage) {
20 std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
21 c10::DataPtr& data_ptr = storage.mutable_data_ptr();
22
23 if ((void*)data_ptr.get_deleter() == (void*)&c10::refcounted_deleter) {
24 // Data pointer is already shared
25 return;
26 }
27
28 void* data = data_ptr.get();
29 void* other_ctx = data_ptr.get_context();
30 c10::DeleterFnPtr other_deleter = data_ptr.get_deleter();
31 c10::Device device = data_ptr.device();
32
33 // Release the context of the original DataPtr so that the data doesn't
34 // get deleted when the original DataPtr is replaced
35 data_ptr.release_context();
36
37 c10::RefcountedDeleterContext* refcount_ctx =
38 new c10::RefcountedDeleterContext(other_ctx, other_deleter);
39
40 c10::DataPtr new_data_ptr(
41 data,
42 reinterpret_cast<void*>(refcount_ctx),
43 &c10::refcounted_deleter,
44 device);
45 storage.set_data_ptr(std::move(new_data_ptr));
46 }
47
// Creates a new Storage that shares `storage`'s underlying data.
//
// First ensures `storage`'s DataPtr uses `refcounted_deleter`, via
// maybeApplyRefcountedDeleter. It then builds a second DataPtr with the same
// data pointer, context, deleter and device, and increments the shared
// context's refcount so that both storages hold a reference. The data is
// freed only when the last referencing DataPtr is destroyed.
//
// The new StorageImpl copies `storage`'s byte size, allocator and
// resizability.
c10::Storage newStorageImplFromRefcountedDataPtr(const c10::Storage& storage) {
  c10::maybeApplyRefcountedDeleter(storage);

  c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();

  c10::DataPtr& data_ptr = storage.mutable_data_ptr();
  // After maybeApplyRefcountedDeleter, this context is the shared
  // RefcountedDeleterContext installed alongside `refcounted_deleter`.
  void* shared_ctx = data_ptr.get_context();
  c10::DataPtr new_data_ptr(
      data_ptr.get(), shared_ctx, data_ptr.get_deleter(), data_ptr.device());

  // NOTE: This refcount increment should always happen immediately after
  // `new_data_ptr` is created. No other lines of code should be added between
  // them in the future, unless there's a very good reason for it, because if
  // any errors are raised and `new_data_ptr` is deleted before the refcount is
  // incremented, the refcount will get decremented and end up being one less
  // than it should be.
  static_cast<c10::RefcountedDeleterContext*>(shared_ctx)->refcount++;

  c10::Storage new_storage = c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      storage_impl->nbytes(),
      std::move(new_data_ptr),
      storage_impl->allocator(),
      /*resizable=*/storage_impl->resizable());
  return new_storage;
}
77
78 } // namespace c10
79