#include <ATen/MapAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/CudaIPCTypes.h>
#include <atomic>
#include <map>
#include <mutex>
#include <string>

namespace torch {

namespace {

void warnProducerTerminatedBeforeSharedTensorsReleased() {
  static bool warned = false;
  if (!warned) {
    LOG(WARNING)
        << "Producer process has been terminated before all shared CUDA tensors were released. See Note [Sharing CUDA tensors]";
    warned = true;
  }
}

struct CudaIPCGlobalEntities {
  // This class is used as a singleton (see cuda_ipc_global_entities).
  // This variable is used to track its lifetime to avoid accessing it
  // after it was destroyed, which would lead to segmentation faults.
  // Note that a trivial type is used, which doesn't suffer from
  // construction and destruction order issues.
  static bool alive;

  std::mutex ref_counters_mutex_;
  std::atomic<int64_t> sync_events_used_{0};
  std::map<std::string, std::shared_ptr<CudaIPCRefCountersFile>>
      ref_counters_files_;
  std::shared_ptr<CudaIPCRefCountersFile> next_available_ref_counters_file_;
  CudaIPCSentDataLimbo CudaIPCSentDataLimbo_;
  CudaIPCGlobalEntities() {
    alive = true;
  }
  ~CudaIPCGlobalEntities() {
    CudaIPCSentDataLimbo_.collect();
    safe_clean_current_file();
    if (next_available_ref_counters_file_) {
      warnProducerTerminatedBeforeSharedTensorsReleased();
    }
    alive = false;
  }
  void safe_clean_current_file() {
    std::lock_guard<std::mutex> lock(ref_counters_mutex_);
    if (next_available_ref_counters_file_ &&
        next_available_ref_counters_file_->offsets_in_use() == 0) {
      ref_counters_files_.erase(next_available_ref_counters_file_->handle());
      next_available_ref_counters_file_.reset();
    }
  }
};

bool CudaIPCGlobalEntities::alive = false;
CudaIPCGlobalEntities cuda_ipc_global_entities;

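// The limbo holds CudaIPCSentData blocks that the producer has already
// released but that consumer processes still reference. If any remain at
// destruction time, the producer is exiting before consumers are done;
// warn rather than fail.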
CudaIPCSentDataLimbo::~CudaIPCSentDataLimbo() {
  collect();
  if (size() > 0) {
    warnProducerTerminatedBeforeSharedTensorsReleased();
  }
}

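// Walks the limbo and frees every block whose consumer-side reference count
// has dropped to zero. Returns true if any memory was actually released.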
bool CudaIPCSentDataLimbo::collect() {
  bool freed_memory = false;
  std::vector<std::unique_ptr<CudaIPCSentData>> reset_blocks;
  { // Begin critical section to modify shared blocks
    std::lock_guard<std::mutex> lock(limbo_mutex_);
    std::vector<std::unique_ptr<CudaIPCSentData>> kept_blocks;
    for (auto& sd : shared_blocks_) {
      if (sd->counter_value() > 0) {
        kept_blocks.push_back(std::move(sd));
      } else {
        freed_memory = true;
        reset_blocks.push_back(std::move(sd));
      }
    }
    shared_blocks_ = std::move(kept_blocks);
  }
  // Need to reset blocks out of the critical section here, otherwise it
  // deadlocks.
  for (auto& sd : reset_blocks) {
    sd.reset();
  }
  return freed_memory;
}

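// Moves a block that consumers still reference into the limbo so that its
// memory is released by a later collect() call instead of immediately.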
void CudaIPCSentDataLimbo::add(std::unique_ptr<CudaIPCSentData> shared_block) {
  std::lock_guard<std::mutex> lock(limbo_mutex_);
  static bool warned = false;
  if (shared_blocks_.size() > CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO &&
      !warned) {
    LOG(WARNING)
        << "Producer process tried to deallocate over "
        << CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO
        << " memory blocks referred to by consumer processes. Deallocation might be significantly slowed down. "
        << "We assume this will never be the case, but if it is, please file a bug at https://github.com/pytorch/pytorch";
    warned = true;
  }
  shared_blocks_.push_back(std::move(shared_block));
}

uint64_t CudaIPCSentDataLimbo::size() {
  std::lock_guard<std::mutex> lock(limbo_mutex_);
  return shared_blocks_.size();
}

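// Deleter attached to the DataPtr returned by GetNewRefCountedSentData().
// Blocks that consumers still reference are parked in the limbo; otherwise
// the block is destroyed here. Either way, collect() is run to reap
// previously parked blocks.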
void CudaIPCSentDataDelete(void* ptr) {
  std::unique_ptr<CudaIPCSentData> sent_data(
      static_cast<CudaIPCSentData*>(ptr));
  if (!CudaIPCGlobalEntities::alive) {
    return;
  }
  if (sent_data->counter_value() > 0) {
    cuda_ipc_global_entities.CudaIPCSentDataLimbo_.add(std::move(sent_data));
  }
  cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect();
}

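// Returns a reference-counter slot to its shared-memory file and erases the
// file once no offsets are in use and none are left to hand out.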
void ReturnRefCounter(const std::string& handle, uint64_t offset) {
  if (!CudaIPCGlobalEntities::alive) {
    return;
  }
  std::lock_guard<std::mutex> lock(
      cuda_ipc_global_entities.ref_counters_mutex_);
  auto& map = cuda_ipc_global_entities.ref_counters_files_;
  auto it = map.find(handle);
  if (it != map.end()) {
    it->second->return_offset(offset);
    if (it->second->offsets_in_use() == 0 && !it->second->have_offsets()) {
      map.erase(handle);
    }
  }
}

} // namespace

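// Producer-side record for one shared CUDA allocation. On CUDA builds this
// records a blocking interprocess event on the current stream (falling back
// to a full stream synchronization once the event budget is exhausted); on
// ROCm it always synchronizes the stream.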
CudaIPCSentData::CudaIPCSentData(
    std::string handle,
    uint64_t offset,
    uint64_t* counter_ptr,
    at::Device device)
    : handle_(std::move(handle)),
      offset_(offset),
      counter_ptr_(counter_ptr),
      original_ptr_(),
      device_(device) {
#if !defined(USE_ROCM)
  // CUDA has an unofficial limit on the number of recorded blocking
  // interprocess events. To avoid exhausting all events, we switch to
  // stream synchronization before the limit is reached. The limit can be
  // hit with, e.g.:
  //
  //  ```python
  //  import torch
  //  a = [ torch.cuda.Event(
  //      enable_timing=False, blocking=True, interprocess=True) for i in
  //      range(30000) ]
  //  [i.record() for i in a]
  //  ```
  //
  if (cuda_ipc_global_entities.sync_events_used_.load() <
      CUDA_IPC_MAXIMUM_EVENTS_TO_USE) {
    // TODO: It would be more efficient to create the event inside the main
    // thread (at the moment of the queue.put). The reason is that the main
    // thread may have queued extra work on the stream, which this event
    // would consequently (and uselessly) wait for.
    cuda_ipc_global_entities.sync_events_used_++;
    C10_CUDA_CHECK(cudaEventCreateWithFlags(
        &event_,
        cudaEventDisableTiming | cudaEventInterprocess |
            cudaEventBlockingSync));
    C10_CUDA_CHECK(cudaEventRecord(
        event_, c10::cuda::getCurrentCUDAStream(device.index())));
    event_sync_required_ = true;
  } else {
    auto stream = c10::cuda::getCurrentCUDAStream(device.index());
    at::cuda::stream_synchronize(stream);
    event_ = nullptr;
    event_sync_required_ = false;
  }
#else
  // cuIpcGetEventHandle is not supported with HIP, so we have to sync the
  // stream instead of passing an event.
  auto stream = c10::cuda::getCurrentCUDAStream(device.index());
  at::cuda::stream_synchronize(stream);
  event_sync_required_ = false;
#endif
}

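// Returns the reference-counter slot to its file and, if an interprocess
// event was created, destroys it and frees up its slot in the event budget.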
CudaIPCSentData::~CudaIPCSentData() {
  ReturnRefCounter(handle_, offset_);
#if !defined(USE_ROCM)
  try {
    if (event_sync_required_) {
      at::cuda::CUDAGuard device_guard(device_.index());
      C10_CUDA_CHECK(cudaEventDestroy(event_));
      if (!CudaIPCGlobalEntities::alive) {
        return;
      }
      cuda_ipc_global_entities.sync_events_used_--;
    }
  } catch (...) { /* No throw */
  }
#endif
}

uint64_t CudaIPCSentData::counter_value() {
  return *counter_ptr_;
}

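// Wraps a CUDA allocation that is about to be shared with another process:
// grabs a slot in a shared-memory reference-counter file (creating a new
// file when the current one is exhausted), initializes the counter to 1,
// and returns a DataPtr whose deleter, CudaIPCSentDataDelete, defers the
// actual free until consumers are done with the block.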
at::DataPtr GetNewRefCountedSentData(void* data, at::Device device) {
  {
    std::lock_guard<std::mutex> lock(
        cuda_ipc_global_entities.ref_counters_mutex_);
    if (!cuda_ipc_global_entities.next_available_ref_counters_file_) {
      std::string ref_counter_handle = at::NewProcessWideShmHandle();

      int flags =
          at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE;
      at::DataPtr sptr = at::RefcountedMapAllocator::makeDataPtr(
          ref_counter_handle.c_str(),
          flags,
          sizeof(int64_t) * CUDA_IPC_REF_COUNTER_FILE_SIZE,
          nullptr);
      auto rc = std::make_shared<CudaIPCRefCountersFile>(
          ref_counter_handle, CUDA_IPC_REF_COUNTER_FILE_SIZE, std::move(sptr));
      cuda_ipc_global_entities.ref_counters_files_[ref_counter_handle] = rc;
      cuda_ipc_global_entities.next_available_ref_counters_file_ = rc;
    }
  }
  cuda_ipc_global_entities.next_available_ref_counters_file_->set_counter(1);
  auto sent_data = new CudaIPCSentData(
      cuda_ipc_global_entities.next_available_ref_counters_file_->handle(),
      cuda_ipc_global_entities.next_available_ref_counters_file_->get_offset(),
      cuda_ipc_global_entities.next_available_ref_counters_file_->counter_ptr(),
      device);

  cuda_ipc_global_entities.next_available_ref_counters_file_->rotate_offset();
  if (!cuda_ipc_global_entities.next_available_ref_counters_file_
           ->have_offsets()) {
    cuda_ipc_global_entities.next_available_ref_counters_file_.reset();
  }
  return at::DataPtr(data, sent_data, CudaIPCSentDataDelete, device);
}

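// Reaps limbo blocks whose consumers have finished; once the limbo is empty,
// also cleans up the current reference-counter file (see the
// "cuda_ipc_collect" free-memory callback registered at the bottom of this
// file).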
bool CudaIPCCollect() {
  if (!CudaIPCGlobalEntities::alive) {
    return true;
  }
  bool freed_memory = cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect();
  if (cuda_ipc_global_entities.CudaIPCSentDataLimbo_.size() == 0) {
    cuda_ipc_global_entities.safe_clean_current_file();
  }
  return freed_memory;
}

} // namespace torch

namespace c10 {
namespace {
REGISTER_FREE_MEMORY_CALLBACK("cuda_ipc_collect", CudaIPCCollectCallback);
}
} // namespace c10