xref: /aosp_15_r20/external/pytorch/aten/src/ATen/MapAllocator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Allocator.h>
4 #include <c10/util/string_view.h>
5 
6 namespace at {
7 
// Mode bits for the mapped allocators declared below. Every enumerator is a
// distinct power of two, so several modes can be OR-ed together into the
// `int flags` arguments those allocators accept.
enum MappedAllocatorModes {
  ALLOCATOR_MAPPED_SHARED = 1 << 0,
  ALLOCATOR_MAPPED_SHAREDMEM = 1 << 1,
  ALLOCATOR_MAPPED_EXCLUSIVE = 1 << 2,
  ALLOCATOR_MAPPED_NOCREATE = 1 << 3,
  ALLOCATOR_MAPPED_KEEPFD = 1 << 4,
  ALLOCATOR_MAPPED_FROMFD = 1 << 5,
  ALLOCATOR_MAPPED_UNLINK = 1 << 6,
};
17 
// Tag used only for overload resolution: passing WITH_FD as the first
// argument selects the constructor / makeDataPtr overload that also takes a
// file descriptor, instead of the filename-only overload.
enum WithFd {
  WITH_FD
};
21 
22 TORCH_API std::string NewProcessWideShmHandle();
23 
24 class TORCH_API MapAllocator {
25  public:
26   MapAllocator(c10::string_view filename, int flags, size_t size);
27   MapAllocator(
28       WithFd,
29       c10::string_view filename,
30       int fd,
31       int flags,
32       size_t size);
33   MapAllocator(const MapAllocator&) = delete;
34   MapAllocator& operator=(const MapAllocator&) = delete;
35   MapAllocator(MapAllocator&&) = delete;
36   MapAllocator& operator=(MapAllocator&&) = delete;
37 
filename()38   const char* filename() const {
39     return filename_.c_str();
40   }
fd()41   int fd() const {
42 #ifdef _WIN32
43     TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows");
44 #else
45     return fd_;
46 #endif
47   }
size()48   ptrdiff_t size() const {
49     return size_;
50   }
51   // Return a pointer to the actual data for this allocator
52   // (in the case of the refcounted allocator, this is offset
53   // from the base pointer.)
data()54   virtual void* data() const {
55     return base_ptr_;
56   }
57 
flags()58   int flags() const {
59     return flags_;
60   }
61 
62   static MapAllocator* fromDataPtr(const at::DataPtr&);
63   static at::DataPtr makeDataPtr(
64       c10::string_view filename,
65       int flags,
66       size_t size,
67       size_t* actual_size_out);
68   static at::DataPtr makeDataPtr(
69       WithFd,
70       const char* filename,
71       int fd,
72       int flags,
73       size_t size,
74       size_t* actual_size_out);
75 
76   // Closes the data.  Helps us avoid destructor shenanigans
77   virtual void close();
78 
79   // This is very dangerous.  You have to redefine this destructor for each
80   // subclass
81   virtual ~MapAllocator();
82 
83  protected:
84   bool closed_ = false;
85   std::string filename_;
86   int flags_ = 0;
87   ptrdiff_t size_; /* mapped size */
88 #ifdef _WIN32
89   void* handle_;
90   void* event_;
91   std::string eventname_;
92 #else
93   int fd_ = -1;
94 #endif
95   void* base_ptr_ = nullptr;
96 };
97 
98 // Base-from-member idiom
99 struct TORCH_API RefcountedMapAllocatorArgCheck {
100   RefcountedMapAllocatorArgCheck(int flags);
101 };
102 
103 class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
104                                          public MapAllocator {
105  public:
106   RefcountedMapAllocator(const char* filename, int flags, size_t size);
107   RefcountedMapAllocator(
108       WithFd,
109       const char* filename,
110       int fd,
111       int flags,
112       size_t size);
113 
114   static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
115   static at::DataPtr makeDataPtr(
116       const char* filename,
117       int flags,
118       size_t size,
119       size_t* actual_size_out);
120   static at::DataPtr makeDataPtr(
121       WithFd,
122       const char* filename,
123       int fd,
124       int flags,
125       size_t size,
126       size_t* actual_size_out);
127 
128   void* data() const override;
129 
130   void incref();
131   int decref();
132   void close() override;
133 
~RefcountedMapAllocator()134   ~RefcountedMapAllocator() override {
135     RefcountedMapAllocator::close();
136   }
137 
138  protected:
139   void checkFlags();
140   void initializeAlloc();
141 };
142 
143 } // namespace at
144