#pragma once

#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch::jit {

// A StorageGroup represents a collection of tensors that share backing storage.
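//
// Illustrative sketch (hedged; `a` and `b` stand in for managed tensors whose
// lifetimes do not overlap, so the planner can let them reuse the same storage):
//
//   at::Tensor a, b;          // stand-ins for tensors managed by the planner
//   StorageGroup group(&a);   // every group starts with at least one tensor
//   group.addTensor(&b);      // b will be backed by the same storage as a
//   group.setMaxTensorSize(   // bytes the planner will allocate for the group
//       std::max(a.nbytes(), b.nbytes()));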
class StorageGroup {
 public:
  // Every storage group must contain at least one tensor.
  explicit StorageGroup(at::Tensor* tensor) : group_{tensor} {}

  void addTensor(at::Tensor* tensor) {
    group_.push_back(tensor);
  }

  const std::vector<at::Tensor*>& group() const {
    return group_;
  }

  size_t maxTensorSize() const {
    return max_tensor_size_;
  }

  void setMaxTensorSize(size_t new_size) {
    max_tensor_size_ = new_size;
  }

  size_t numManagedTensors() const {
    return group_.size();
  }

 private:
  // The size attribute represents the amount of memory that will be
  // allocated for all tensors in this storage group. Initially it is
  // zero; the MemoryPlanner updates it later.
  size_t max_tensor_size_ = 0;
  std::vector<at::Tensor*> group_{};
};

// A contiguous buffer of `StorageImpl`s
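//
// Usage sketch (hedged; `impl_a` and `impl_b` stand in for the StorageImpls of
// two managed tensors):
//
//   ManagedStorages storages;
//   storages.allocate(2);     // reserve capacity for two StorageImpls
//   storages.append(impl_a);  // placement-new a StorageImpl with impl_a's
//   storages.append(impl_b);  // size and allocator
//   at::StorageImpl& first = storages[0];  // indexing asserts allocation
//   storages.deallocate();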
class ManagedStorages {
 public:
  ManagedStorages();

  ~ManagedStorages();

  void allocate(size_t capacity);

  void deallocate();

  bool is_allocated() const {
    return storages_ != nullptr;
  }

  // Append a new StorageImpl to the buffer. The new StorageImpl is given the
  // same size and allocator as the `storageImpl` argument.
  void append(at::StorageImpl& storageImpl);

  at::StorageImpl& operator[](size_t idx) {
    TORCH_INTERNAL_ASSERT(storages_ != nullptr);
    return storages_[idx];
  }

  const at::StorageImpl& operator[](size_t idx) const {
    TORCH_INTERNAL_ASSERT(storages_ != nullptr);
    return storages_[idx];
  }

  size_t size() const {
    return size_;
  }

  bool empty() const {
    return size_ == 0;
  }

  size_t capacity() const {
    return capacity_;
  }

 private:
  // We will use placement-new to add new storages to this buffer
  at::StorageImpl* storages_;

  // Current number of storages that have been placed into the storage buffer
  size_t size_;

  // Total allocated capacity of the storage buffer
  size_t capacity_;
};

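// Groups the managed output tensors found in `nodes` into StorageGroups so
// that tensors whose lifetimes (per `ranges`) do not overlap can share one
// backing allocation. Call-pattern sketch (hedged; `managed_ranges` and
// `value_to_tensor` are illustrative stand-ins, and the real call site is the
// memory planner's constructor):
//
//   std::vector<StorageGroup> groups = assignStorageToManagedTensors(
//       graph->block()->nodes(), managed_ranges, value_to_tensor);
//   // Each StorageGroup is later backed by one slice of the planner's buffer.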
TORCH_API std::vector<StorageGroup> assignStorageToManagedTensors(
    graph_node_list nodes,
    const ManagedTensorRanges& ranges,
    const c10::FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor);

// There are three types of ops in a processed graph in Static Runtime:
//   1. op with _out variant
//   2. view-producing op
//   3. tensor-producing op (could be replaced with type 1 by adding the _out
//      variant to Static Runtime)
// In Static Runtime, type 2 ops are replaced with their corresponding copy
// versions when enable_out_variant is enabled and become type 1 ops. The
// memory planner only manages tensors that are outputs of type 1 ops. For
// type 3, the output tensors are allocated inside the operator and can't be
// directly managed by the memory planner.
//
// The memory planner tries to minimize the number of memory allocations by
// tracking the output tensors of ops with _out variants that have a unique
// DataPtr (part of StorageImpl). It tries to do this in several steps:
//   1. record the max memory usage for each Tensor with unique DataPtr at the
//      end of each iteration
//   2. in the next iteration, allocate the buffer for the max total usage and
//      compute the offset of each allocation with regard to the single memory
//      buffer, optionally reusing memory. In the first iteration, we rely on
//      the default allocator for memory allocation.
//   3. free the buffer at the end of each iteration
// Steps 1 and 3 are handled by `deallocate()`, and step 2 by `allocate()`.
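//
// A per-iteration sketch of how these steps map onto the public API (hedged;
// the real driver is BlockRunner, and `run_block_ops()` is a stand-in):
//
//   planner.allocate();    // step 2: back all managed tensors with offsets
//                          // into a single buffer (the first iteration falls
//                          // back to the default allocator instead)
//   run_block_ops();       // ops write into the pre-assigned storage
//   planner.deallocate();  // steps 1 & 3: record max sizes, release the buffer
//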
// Only models with simple output types are supported, i.e. None, Tensor or
// List/Tuple/Dict of Tensors. Complex output types such as List of Lists are
// not supported.
//
// Additional Optimizations:
//
// [Borrowed IValue Outputs]
// A few native ops (notably, `static_runtime::dict_unpack` and
// `static_runtime::VarTupleUnpack`) simply unpack IValues to a bunch of
// outputs without modification. For example, `dict_unpack` does the following:
// for each key in inputs:
//     output[i] = dict_input[key]
// To avoid refcount bumps, the outputs of these ops are non-owning references.
// This requires special logic in the memory planner - when adding an op that
// borrows outputs, be sure that the memory planner is updated accordingly!
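//
// Cleanup sketch (hedged; this mirrors what deallocate() is expected to do for
// these values, assuming the c10::MaybeOwnedTraits<IValue> helpers):
//
//   for (IValue* iv : unmanaged_borrowed_ivalues_) {
//     // drop the borrow without decrementing the target's refcount
//     c10::MaybeOwnedTraits<c10::IValue>::destroyBorrow(*iv);
//   }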
//
// [Managed Output Tensors]
// The memory planner is able to manage output tensors if the appropriate
// `StaticModuleOptions` are set. However, the memory planner handles output
// tensors separately from regular intermediate tensors:
// 1. They don't participate in memory reuse.
// 2. The memory planner cannot reclaim their backing storage until they have
//    been explicitly freed by the client.
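//
// Client-side sketch (hedged; a hypothetical call flow, and the real entry
// points live on StaticRuntime/BlockRunner rather than on MemoryPlanner):
//
//   auto outputs = runtime(args);       // run with manage_output_tensors set
//   consume(outputs);                   // client reads the output tensors
//   runtime.deallocateOutputTensors();  // only now may the planner reuse the
//                                       // output buffer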

class MemoryPlanner {
 public:
  MemoryPlanner(
      BlockRunner* block_runner,
      const BlockInfo& block_info,
      bool enable_out_variant,
      bool manage_output_tensors);

  // disable copying and moving
  MemoryPlanner(const MemoryPlanner&) = delete;
  MemoryPlanner& operator=(const MemoryPlanner&) = delete;
  MemoryPlanner(MemoryPlanner&&) = delete;
  MemoryPlanner& operator=(MemoryPlanner&&) = delete;
  virtual ~MemoryPlanner() = default;

  void allocate();
  void deallocate();
  void deallocateOutputTensors();

  size_t total_num_managed_tensors() const {
    return num_managed_tensors_;
  }

  size_t total_reused_tensors() const {
    return reused_tensors_;
  }

  size_t total_num_managed_output_tensors() const {
    return managed_output_tensors_.size();
  }

  C10_NODISCARD size_t total_num_unmanaged() const {
    return num_unmanaged_non_scalars() + num_unmanaged_scalars();
  }

  C10_NODISCARD size_t num_unmanaged_non_scalars() const {
    return unmanaged_ivalues_.size() + unmanaged_borrowed_ivalues_.size();
  }

  C10_NODISCARD size_t num_unmanaged_scalars() const {
    return num_unmanaged_scalar_ivalues_;
  }

  size_t total_managed() const {
    return managed_bytes_;
  }

  size_t numOutputBufferBytes() const {
    return output_buffer_bytes_;
  }

  // Check if `ivalue` is a managed output tensor. Only used in DCHECK().
  bool isManagedOutputTensor(const IValue& ivalue) const {
    if (!output_buffer_ || // output buffer was already deallocated.
        output_buffer_bytes_ == 0 || // memory planning is not yet initialized.
        !ivalue.isTensor() // a non-tensor is never managed
    ) {
      return false;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.has_storage() || !tensor.storage().data_ptr()) {
      return false;
    }
    // TODO: Improve this once D31357486 is landed.
    uint8_t* tensor_ptr =
        static_cast<uint8_t*>(tensor.storage().data_ptr().get());
    uint8_t* buffer_start = static_cast<uint8_t*>(output_buffer_.get());
    uint8_t* buffer_end = buffer_start + output_buffer_bytes_;
    return buffer_start <= tensor_ptr && tensor_ptr < buffer_end;
  }

  bool isManagedStorageImpl(const at::StorageImpl* impl) const {
    if (storages_.empty()) {
      return false;
    }
    // Comparing pointers that aren't within the same array is
    // UB. We're doing fancy memory allocation stuff, so we cast to an
    // integer type and carry on.
    const auto impl_p = reinterpret_cast<uintptr_t>(impl);
    const auto start = reinterpret_cast<uintptr_t>(&storages_[0]);
    const auto end =
        reinterpret_cast<uintptr_t>(&storages_[0] + storages_.size());
    return impl_p >= start && impl_p < end;
  }

  bool overlapWithInternalBuffer(void* data_ptr) {
    return buffer_start_ <= data_ptr && data_ptr < buffer_end_;
  }

 protected:
  uint8_t* allocateBuffer(size_t num_bytes);

  size_t managed_bytes_{0};
  size_t reused_tensors_{0};

  // We allocate StorageImpls ourselves so that 1) we don't have to do
  // an extra two loads per Tensor (which will likely miss in the CPU
  // data cache): first reading the Storage (i.e., the StorageImpl
  // pointer) from the TensorImpl object and then dereferencing it, and
  // 2) our memory access pattern during allocate() has high locality.
  // We don't have any guarantee that the model doesn't change the
  // Storage for managed tensors out from under us during execution,
  // so we have to check the StorageImpls each time we deallocate.
  ManagedStorages storages_;

  // Contains the size (in bytes) of the data to be allocated for each storage
  std::vector<size_t> storages_nbytes_;

 private:
  // ivalues created in one run but not managed by MemoryPlanner
  std::vector<IValue*> unmanaged_ivalues_;

  // Special class of unmanaged values: some native ops create IValues
  // in a "borrowed" state that can and must be cleaned up without a
  // reference count decrement.
  std::vector<IValue*> unmanaged_borrowed_ivalues_;

  // Even more special class of unmanaged values: if select_tensor
  // outputs are outputs of the graph, then they need to be restored
  // to an ordinary "strong reference" state.
  std::vector<IValue*> borrowed_ivalues_needing_incref_;

  std::vector<std::pair<size_t, at::Tensor*>> managed_output_tensors_{};
  at::DataPtr buffer_; // allocated each time we call Run()
  uint8_t* buffer_start_{nullptr};
  uint8_t* buffer_end_{nullptr};
  size_t num_managed_tensors_{0};
  size_t num_unmanaged_scalar_ivalues_{0};

  at::DataPtr output_buffer_;
  size_t output_buffer_bytes_{0};

  virtual void allocateManagedTensors() = 0;
  virtual void deallocateManagedTensors() = 0;

  void allocateOutputTensors();
};

class StandardMemoryPlanner : public MemoryPlanner {
 public:
  StandardMemoryPlanner(
      BlockRunner* block_runner,
      const BlockInfo& block_info,
      bool enable_out_variant,
      bool manage_output_tensors,
      bool optimize_memory);

 protected:
  void allocateManagedTensors() override;
  void deallocateManagedTensors() override;

  std::vector<StorageGroup> managed_tensors_{};
};

} // namespace torch::jit