#include <ATen/Context.h>
#include <ATen/record_function.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/CallOnce.h>
#include <c10/util/irange.h>
#include <torch/csrc/cuda/memory_snapshot.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/profiler/combined_traceback.h>

#include <algorithm>

namespace torch::cuda {

using c10::Dict;
using c10::IValue;
using torch::jit::Pickler;

using c10::cuda::CUDACachingAllocator::SegmentInfo;

namespace {
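// Serialize an IValue to a pickle byte string (readable by Python's
// pickle.loads) without creating any Python objects.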
std::string write_pickle(const IValue& v) {
  std::vector<char> result;
  {
    auto writer = [&](const char* data, size_t size) {
      result.insert(result.end(), data, data + size);
    };
    Pickler pickler(writer, nullptr, nullptr, nullptr, nullptr, false);
    pickler.protocol();
    pickler.pushIValue(v);
    pickler.stop();
  }
  return std::string(result.begin(), result.end());
}
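// Snapshot entries mix value types, so use untyped (Any -> Any) containers.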
Dict<IValue, IValue> new_dict() {
  return Dict<IValue, IValue>(c10::AnyType::get(), c10::AnyType::get());
}
c10::List<IValue> new_list() {
  return c10::List<IValue>(c10::AnyType::get());
}

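// Map each CapturedTraceback to a list of {filename, line, name} frame dicts.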
std::vector<IValue> ivalue_symbolize(
    std::vector<CapturedTraceback*>& to_symbolize) {
  // Dedup repeated tracebacks so we don't create duplicate frame objects.
  std::unordered_map<CapturedTraceback*, uint64_t> cached_frames;
  std::vector<CapturedTraceback*> unique_frames;
  for (const auto& sc : to_symbolize) {
    auto it = cached_frames.find(sc);
    if (it == cached_frames.end()) {
      cached_frames.insert({sc, unique_frames.size()});
      unique_frames.push_back(sc);
    }
  }
  auto s = symbolize(unique_frames);

  IValue line_s = "line";
  IValue name_s = "name";
  IValue filename_s = "filename";
  std::vector<IValue> all_frames;
  for (const auto& f : s.all_frames) {
    auto d = new_dict();
    d.insert(name_s, f.funcname);
    d.insert(filename_s, f.filename);
    d.insert(line_s, int64_t(f.lineno));
    all_frames.emplace_back(std::move(d));
  }

  std::vector<IValue> py_unique_frames;
  for (const auto& t : s.tracebacks) {
    auto l = new_list();
    for (const auto& e : t) {
      l.push_back(all_frames.at(e));
    }
    py_unique_frames.emplace_back(std::move(l));
  }

  std::vector<IValue> result;
  result.reserve(to_symbolize.size());
  for (const auto& sc : to_symbolize) {
    result.push_back(py_unique_frames.at(cached_frames.at(sc)));
  }
  return result;
}

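// Stack recorders handed to the caching allocator: both capture Python and
// TorchScript frames; gather_with_cpp additionally captures C++ frames.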
std::shared_ptr<c10::GatheredContext> gather() {
  return CapturedTraceback::gather(true, true, false);
}

std::shared_ptr<c10::GatheredContext> gather_with_cpp() {
  return CapturedTraceback::gather(true, true, true);
}

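// The allocator stores opaque GatheredContexts; recover the CapturedTraceback
// we put there, failing loudly on any other type.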
CapturedTraceback* getFromContext(
    const std::shared_ptr<c10::GatheredContext>& x) {
  if (CapturedTraceback* sc = dynamic_cast<CapturedTraceback*>(x.get())) {
    return sc;
  }
  TORCH_CHECK(
      false,
      "attempting to gather stack context from the wrong StackContext type.");
}

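// Install (once per process) a thread-local RecordFunction callback that
// forwards user-scope annotation START/END events into the allocator's
// snapshot trace.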
void _initRecordAnnotations() {
  static c10::once_flag ra_init;
  c10::call_once(ra_init, [&] {
    // Save user annotations to the CUDACachingAllocator memory snapshot tool
    at::addThreadLocalCallback(
        at::RecordFunctionCallback(
            [](const at::RecordFunction& fn)
                -> std::unique_ptr<at::ObserverContext> {
              c10::cuda::CUDACachingAllocator::recordAnnotation(
                  {{"name", fn.name()}, {"stage", "START"}});
              return nullptr;
            },
            [](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) {
              c10::cuda::CUDACachingAllocator::recordAnnotation(
                  {{"name", fn.name()}, {"stage", "END"}});
            })
            .scopes({at::RecordScope::USER_SCOPE}));
  });
}

} // namespace

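// Boolean-flag variant: record_context captures stacks for the snapshot
// state, trace_alloc_record_context captures them for each allocation in the
// trace, and record_cpp_context adds C++ frames when either is enabled.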
void _record_memory_history(
    bool enabled,
    bool record_context,
    int64_t trace_alloc_max_entries,
    bool trace_alloc_record_context,
    bool record_cpp_context) {
  c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather;
  if (enabled && record_cpp_context &&
      (trace_alloc_record_context || record_context)) {
    recorder = gather_with_cpp;
    // warm up C++ stack unwinding
    unwind::unwind();
  }
  auto when = c10::cuda::CUDACachingAllocator::RecordContext::NEVER;
  if (trace_alloc_record_context) {
    when = c10::cuda::CUDACachingAllocator::RecordContext::ALLOC;
  } else if (record_context) {
    when = c10::cuda::CUDACachingAllocator::RecordContext::STATE;
  }
  at::globalContext().lazyInitCUDA();
  _initRecordAnnotations();
  c10::cuda::CUDACachingAllocator::recordHistory(
      enabled, recorder, trace_alloc_max_entries, when);
}

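// Fail with `error` unless `option` is one of the values in `valid`.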
static void checkOptionIn(
    const std::string& option,
    std::initializer_list<std::string> valid,
    const char* error) {
  TORCH_CHECK(
      valid.end() != std::find(valid.begin(), valid.end(), option), error);
}

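// String-option overload, mirrored in Python by
// torch.cuda.memory._record_memory_history. Illustrative usage (Python):
//   torch.cuda.memory._record_memory_history(
//       enabled="all", context="all", stacks="all", max_entries=100000)
//   ...  # run the workload being profiled
//   torch.cuda.memory._dump_snapshot("snapshot.pickle")
// enabled=None turns recording off; "state" keeps only current allocations,
// while "all" also retains up to max_entries trace events.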
void _record_memory_history(
    std::optional<std::string> enabled,
    std::optional<std::string> context,
    const std::string& stacks,
    size_t max_entries) {
  if (enabled) {
    checkOptionIn(
        *enabled,
        {"state", "all"},
        "expected enabled to be 'state', 'all', or None");
  }
  if (context) {
    checkOptionIn(
        *context,
        {"state", "alloc", "all"},
        "expected context to be 'state', 'alloc', 'all', or None");
  }
  checkOptionIn(
      stacks, {"python", "all"}, "expected stacks to be 'python' or 'all'");

  c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather;
  if (enabled && context && stacks == "all") {
    recorder = gather_with_cpp;
    // warm up C++ stack unwinding
    unwind::unwind();
  }
  max_entries = (enabled && *enabled == "all") ? max_entries : 1;
  auto when = c10::cuda::CUDACachingAllocator::RecordContext::NEVER;
  if (context) {
    if (context == "all") {
      when = c10::cuda::CUDACachingAllocator::RecordContext::ALL;
    } else if (context == "alloc") {
      when = c10::cuda::CUDACachingAllocator::RecordContext::ALLOC;
    } else if (context == "state") {
      when = c10::cuda::CUDACachingAllocator::RecordContext::STATE;
    }
  }
  at::globalContext().lazyInitCUDA();
  _initRecordAnnotations();
  c10::cuda::CUDACachingAllocator::recordHistory(
      enabled.has_value(), recorder, max_entries, when);
}

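// Build the complete snapshot (segments, per-device traces, allocator
// settings, external annotations) as nested IValue dicts/lists and pickle it.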
std::string _memory_snapshot_pickled() {
  IValue device_s = "device";
  IValue address_s = "address";
  IValue total_size_s = "total_size";
  IValue allocated_size_s = "allocated_size";
  IValue active_size_s = "active_size";
  IValue requested_size_s = "requested_size";
  IValue stream_s = "stream";
  IValue segment_type_s = "segment_type";
  IValue segment_pool_id = "segment_pool_id";
  IValue large_s = "large";
  IValue small_s = "small";
  IValue size_s = "size";
  IValue state_s = "state";
  IValue active_allocated_s = "active_allocated";
  IValue active_pending_free_s = "active_pending_free";
  IValue inactive_s = "inactive";
  IValue addr_s = "addr";
  IValue filename_s = "filename";
  IValue name_s = "name";
  IValue line_s = "line";
  IValue frames_s = "frames";
  IValue blocks_s = "blocks";
  IValue is_expandable_s = "is_expandable";
  IValue time_us_s = "time_us";

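  // Symbolizing stacks is expensive, so batch it: remember each traceback and
  // the dict its "frames" entry belongs in, then resolve them all at the end.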
  auto empty_frames = new_list();

  std::vector<CapturedTraceback*> frame_tracebacks;
  std::vector<Dict<IValue, IValue>> frame_dict;

  auto add_frame_key = [&](const c10::Dict<IValue, IValue>& d,
                           const std::shared_ptr<c10::GatheredContext>& ctx) {
    if (ctx) {
      frame_tracebacks.push_back(getFromContext(ctx));
      frame_dict.push_back(d);
    } else {
      d.insert(frames_s, empty_frames);
    }
  };

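  // Flatten one allocator segment and its constituent blocks into a dict.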
  const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) {
    auto segmentDict = new_dict();
    segmentDict.insert(device_s, segmentInfo.device);
    segmentDict.insert(address_s, static_cast<int64_t>(segmentInfo.address));
    segmentDict.insert(
        total_size_s, static_cast<int64_t>(segmentInfo.total_size));
    segmentDict.insert(
        allocated_size_s, static_cast<int64_t>(segmentInfo.allocated_size));
    segmentDict.insert(
        active_size_s, static_cast<int64_t>(segmentInfo.active_size));
    segmentDict.insert(
        requested_size_s, static_cast<int64_t>(segmentInfo.requested_size));
    segmentDict.insert(stream_s, int64_t(segmentInfo.stream));
    segmentDict.insert(
        segment_type_s, (segmentInfo.is_large ? large_s : small_s));
    segmentDict.insert(
        segment_pool_id,
        std::tuple<int64_t, int64_t>(segmentInfo.owner_private_pool_id));
    segmentDict.insert(is_expandable_s, segmentInfo.is_expandable);

    add_frame_key(segmentDict, segmentInfo.context_when_allocated);

    // Block addresses are implicit: walk the segment, advancing by block size.
    auto address = segmentInfo.address;
    auto blocks = new_list();
    for (const auto& blockInfo : segmentInfo.blocks) {
      auto blockDict = new_dict();
      blockDict.insert(address_s, static_cast<int64_t>(address));
      blockDict.insert(size_s, static_cast<int64_t>(blockInfo.size));
      blockDict.insert(
          requested_size_s, static_cast<int64_t>(blockInfo.requested_size));
      blockDict.insert(
          state_s,
          (blockInfo.allocated
               ? active_allocated_s
               : (blockInfo.active ? active_pending_free_s : inactive_s)));
      add_frame_key(blockDict, blockInfo.context_when_allocated);
      address += blockInfo.size;
      blocks.push_back(blockDict);
    }
    segmentDict.insert(blocks_s, blocks);

    return segmentDict;
  };

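  // Take one snapshot of the allocator state, then convert each piece below.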
  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();

  auto segments = new_list();
  for (const auto& segmentInfo : snapshot.segments) {
    segments.push_back(segmentInfoToDict(segmentInfo));
  }

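  // Per-device allocator event traces, with enum actions rendered as strings.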
  auto traces = new_list();
  IValue action_s = "action";
  IValue alloc_s = "alloc";
  IValue free_requested_s = "free_requested";
  IValue free_completed_s = "free_completed";
  IValue segment_alloc_s = "segment_alloc";
  IValue segment_free_s = "segment_free";
  IValue segment_map_s = "segment_map";
  IValue segment_unmap_s = "segment_unmap";
  IValue snapshot_s = "snapshot";
  IValue oom_s = "oom";
  IValue device_free_s = "device_free";

  using namespace c10::cuda::CUDACachingAllocator;

  auto action_to_str = [&](TraceEntry::Action action) {
    switch (action) {
      case TraceEntry::ALLOC:
        return alloc_s;
      case TraceEntry::FREE_REQUESTED:
        return free_requested_s;
      case TraceEntry::FREE_COMPLETED:
        return free_completed_s;
      case TraceEntry::SEGMENT_ALLOC:
        return segment_alloc_s;
      case TraceEntry::SEGMENT_FREE:
        return segment_free_s;
      case TraceEntry::OOM:
        return oom_s;
      case TraceEntry::SNAPSHOT:
        return snapshot_s;
      case TraceEntry::SEGMENT_UNMAP:
        return segment_unmap_s;
      case TraceEntry::SEGMENT_MAP:
        return segment_map_s;
    }
    throw std::runtime_error("unreachable");
  };

  for (const auto& traceInfo : snapshot.device_traces) {
    auto trace = new_list();
    for (const auto& te : traceInfo) {
      auto trace_entry = new_dict();
      trace_entry.insert(action_s, action_to_str(te.action_));
      // For OOM events, addr_ holds the bytes reported free on the device.
      trace_entry.insert(
          TraceEntry::OOM == te.action_ ? device_free_s : addr_s,
          static_cast<int64_t>(te.addr_));
      trace_entry.insert(size_s, static_cast<int64_t>(te.size_));
      trace_entry.insert(stream_s, int64_t(te.stream_));
      if (te.context_) {
        auto sc = getFromContext(te.context_);
        frame_tracebacks.push_back(sc);
        frame_dict.push_back(trace_entry);
      }
      trace_entry.insert(time_us_s, te.time_.t_);
      trace.push_back(trace_entry);
    }
    traces.push_back(trace);
  }

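  // User annotations forwarded by the RecordFunction callback installed in
  // _initRecordAnnotations.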
  auto external_annotations = new_list();
  for (const auto& ae : snapshot.external_annotations) {
    auto annotation_entry = new_dict();
    for (const auto& md : ae.metadata_) {
      annotation_entry.insert((IValue)md.first, md.second);
    }
    annotation_entry.insert(device_s, ae.device_);
    annotation_entry.insert(time_us_s, ae.time_.t_);
    external_annotations.push_back(annotation_entry);
  }

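  // Active allocator configuration (as parsed from PYTORCH_CUDA_ALLOC_CONF).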
  auto allocator_settings = new_dict();
  IValue last_allocator_settings_s = "PYTORCH_CUDA_ALLOC_CONF";
  IValue max_split_size_s = "max_split_size";
  IValue garbage_collection_threshold_s = "garbage_collection_threshold";
  IValue expandable_segments_s = "expandable_segments";
  IValue pinned_num_register_threads_s = "pinned_num_register_threads";
  IValue release_lock_on_malloc_s = "release_lock_on_cudamalloc";
  IValue pinned_use_host_register_s = "pinned_use_cuda_host_register";
  IValue roundup_power2_divisions_s = "roundup_power2_divisions";

  allocator_settings.insert(
      last_allocator_settings_s,
      snapshot.config_metadata.last_allocator_settings);
  allocator_settings.insert(
      max_split_size_s, int64_t(snapshot.config_metadata.max_split_size));
  allocator_settings.insert(
      garbage_collection_threshold_s,
      snapshot.config_metadata.garbage_collection_threshold);
  allocator_settings.insert(
      expandable_segments_s, snapshot.config_metadata.expandable_segments);
  allocator_settings.insert(
      pinned_num_register_threads_s,
      int64_t(snapshot.config_metadata.pinned_num_register_threads));
  allocator_settings.insert(
      release_lock_on_malloc_s,
      snapshot.config_metadata.release_lock_on_malloc);
  allocator_settings.insert(
      pinned_use_host_register_s,
      snapshot.config_metadata.pinned_use_host_register);
  // Roundup divisions are keyed by power-of-two size buckets: "1", "2", "4", ...
  unsigned int roundup_key = 1;
  auto roundup_settings = new_dict();
  for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) {
    IValue roundup_key_s = std::to_string(roundup_key);
    roundup_settings.insert(roundup_key_s, int64_t(v));
    roundup_key *= 2;
  }
  allocator_settings.insert(roundup_power2_divisions_s, roundup_settings);

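  // Assemble the top-level dict, then attach symbolized stack frames to every
  // dict queued in frame_dict.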
  auto result = new_dict();
  result.insert("segments", segments);
  result.insert("device_traces", traces);
  result.insert("allocator_settings", allocator_settings);
  result.insert("external_annotations", external_annotations);

  auto frames = ivalue_symbolize(frame_tracebacks);
  for (auto i : c10::irange(frames.size())) {
    frame_dict.at(i).insert(frames_s, frames.at(i));
  }

  return write_pickle(result);
}
} // namespace torch::cuda