// xref: /aosp_15_r20/external/pytorch/c10/core/Allocator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/core/Allocator.h>

#include <c10/util/ThreadLocalDebugInfo.h>

#include <cstring>
4 
5 namespace c10 {
6 
clone(const void * data,std::size_t n)7 DataPtr Allocator::clone(const void* data, std::size_t n) {
8   DataPtr new_data = allocate(n);
9   copy_data(new_data.mutable_get(), data, n);
10   return new_data;
11 }
12 
default_copy_data(void * dest,const void * src,std::size_t count) const13 void Allocator::default_copy_data(
14     void* dest,
15     const void* src,
16     std::size_t count) const {
17   std::memcpy(dest, src, count);
18 }
19 
is_simple_data_ptr(const DataPtr & data_ptr) const20 bool Allocator::is_simple_data_ptr(const DataPtr& data_ptr) const {
21   return data_ptr.get() == data_ptr.get_context();
22 }
23 
deleteInefficientStdFunctionContext(void * ptr)24 static void deleteInefficientStdFunctionContext(void* ptr) {
25   delete static_cast<InefficientStdFunctionContext*>(ptr);
26 }
27 
makeDataPtr(void * ptr,std::function<void (void *)> deleter,Device device)28 at::DataPtr InefficientStdFunctionContext::makeDataPtr(
29     void* ptr,
30     std::function<void(void*)> deleter,
31     Device device) {
32   return {
33       ptr,
34       new InefficientStdFunctionContext(ptr, std::move(deleter)),
35       &deleteInefficientStdFunctionContext,
36       device};
37 }
38 
39 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
40 C10_API at::Allocator* allocator_array[at::COMPILE_TIME_MAX_DEVICE_TYPES];
41 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
42 C10_API uint8_t allocator_priority[at::COMPILE_TIME_MAX_DEVICE_TYPES] = {0};
43 
SetAllocator(at::DeviceType t,at::Allocator * alloc,uint8_t priority)44 void SetAllocator(at::DeviceType t, at::Allocator* alloc, uint8_t priority) {
45   if (priority >= allocator_priority[static_cast<int>(t)]) {
46     allocator_array[static_cast<int>(t)] = alloc;
47     allocator_priority[static_cast<int>(t)] = priority;
48   }
49 }
50 
// Returns the allocator currently registered for device type `t`.
// The "not set" check is a debug-only internal assert. Presumably it is
// compiled out in release (NDEBUG) builds, where an unregistered device type
// yields nullptr — callers should not rely on the assert.
at::Allocator* GetAllocator(const at::DeviceType& t) {
  at::Allocator* const registered = allocator_array[static_cast<int>(t)];
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      registered, "Allocator for ", t, " is not set.");
  return registered;
}
56 
memoryProfilingEnabled()57 bool memoryProfilingEnabled() {
58   auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>(
59       ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
60   return reporter_ptr && reporter_ptr->memoryProfilingEnabled();
61 }
62 
reportMemoryUsageToProfiler(void * ptr,int64_t alloc_size,size_t total_allocated,size_t total_reserved,Device device)63 void reportMemoryUsageToProfiler(
64     void* ptr,
65     int64_t alloc_size,
66     size_t total_allocated,
67     size_t total_reserved,
68     Device device) {
69   auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>(
70       ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
71   if (reporter_ptr) {
72     reporter_ptr->reportMemoryUsage(
73         ptr, alloc_size, total_allocated, total_reserved, device);
74   }
75 }
76 
reportOutOfMemoryToProfiler(int64_t alloc_size,size_t total_allocated,size_t total_reserved,Device device)77 void reportOutOfMemoryToProfiler(
78     int64_t alloc_size,
79     size_t total_allocated,
80     size_t total_reserved,
81     Device device) {
82   auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>(
83       ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
84   if (reporter_ptr) {
85     reporter_ptr->reportOutOfMemory(
86         alloc_size, total_allocated, total_reserved, device);
87   }
88 }
89 
90 MemoryReportingInfoBase::MemoryReportingInfoBase() = default;
91 
reportOutOfMemory(int64_t,size_t,size_t,Device)92 void MemoryReportingInfoBase::reportOutOfMemory(
93     int64_t /*alloc_size*/,
94     size_t /*total_allocated*/,
95     size_t /*total_reserved*/,
96     Device /*device*/) {}
97 
98 } // namespace c10
99