#pragma once

#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch::jit {

// A StorageGroup represents a collection of tensors that share backing storage.
class StorageGroup {
 public:
  // Every storage group must contain at least one tensor.
  explicit StorageGroup(at::Tensor* tensor) : group_{tensor} {}

  void addTensor(at::Tensor* tensor) {
    group_.push_back(tensor);
  }

  const std::vector<at::Tensor*>& group() const {
    return group_;
  }

  size_t maxTensorSize() const {
    return max_tensor_size_;
  }

  void setMaxTensorSize(size_t new_size) {
    max_tensor_size_ = new_size;
  }

  size_t numManagedTensors() const {
    return group_.size();
  }

 private:
  // The size attribute represents the amount of memory that will be
  // allocated for all tensors in this storage group. It is initially
  // zero; the MemoryPlanner eventually updates it once tensor sizes
  // have been recorded.
  size_t max_tensor_size_ = 0;
  std::vector<at::Tensor*> group_{};
};

// A contiguous buffer of `StorageImpl`s
class ManagedStorages {
 public:
  ManagedStorages();

  ~ManagedStorages();

  void allocate(size_t capacity);

  void deallocate();

  bool is_allocated() const {
    return storages_ != nullptr;
  }

  // Append a new StorageImpl to the buffer. The new StorageImpl is given the
  // same size and allocator as the `storageImpl` argument.
  void append(at::StorageImpl& storageImpl);

  at::StorageImpl& operator[](size_t idx) {
    TORCH_INTERNAL_ASSERT(storages_ != nullptr);
    return storages_[idx];
  }

  const at::StorageImpl& operator[](size_t idx) const {
    TORCH_INTERNAL_ASSERT(storages_ != nullptr);
    return storages_[idx];
  }

  size_t size() const {
    return size_;
  }

  bool empty() const {
    return size_ == 0;
  }

  size_t capacity() const {
    return capacity_;
  }

 private:
  // We will use placement-new to add new storages to this buffer
  at::StorageImpl* storages_;

  // Current number of storages that have been placed into the storage buffer
  size_t size_;

  // Total allocated capacity of the storage buffer
  size_t capacity_;
};

TORCH_API std::vector<StorageGroup> assignStorageToManagedTensors(
    graph_node_list nodes,
    const ManagedTensorRanges& ranges,
    const c10::FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor);
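
// Illustrative example (hypothetical tensors, not part of the API): suppose
// the lifetime information in `ranges` shows that managed tensors `a` (at most
// 1024 bytes) and `b` (at most 512 bytes) are never alive at the same time.
// They can then be assigned to a single StorageGroup so that they share one
// slab of the planner's buffer:
//
//   std::vector<StorageGroup> groups =
//       assignStorageToManagedTensors(nodes, ranges, tensor_value_to_tensor);
//   // groups[0].group() == {&a, &b}  -> a and b share backing storage
//   // groups[0].maxTensorSize()      -> 0 at first; the MemoryPlanner later
//   //                                   sets it to 1024 (the max in the group)
//
// The planner's buffer is then (roughly) the sum of maxTensorSize() over all
// groups, which is how tensors with disjoint lifetimes end up reusing memory.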

// There are three types of ops in a processed graph in Static Runtime:
// 1. op with _out variant
// 2. view-producing op
// 3. tensor-producing op (could be replaced with type 1 by adding the _out
//    variant to Static Runtime)
// In Static Runtime, type 2 ops are replaced with their corresponding copy
// versions when enable_out_variant is enabled and become type 1 ops. The
// memory planner only manages tensors that are outputs of type 1 ops. For
// type 3, the output tensors are allocated inside the operator and can't be
// directly managed by the memory planner.
//
// The memory planner tries to minimize the number of memory allocations by
// tracking the output tensors of ops with _out variants with unique DataPtr
// (part of StorageImpl). It tries to do this in several steps:
// 1. record the max memory usage for each Tensor with unique DataPtr at the
//    end of each iteration
// 2. in the next iteration, allocate the buffer for the max total usage and
//    compute the offset of each allocation with regard to the single memory
//    buffer, optionally reusing memory. In the first iteration, we rely on
//    the default allocator for memory allocation.
// 3. free the buffer at the end of each iteration
// Steps 1 and 3 are handled by `deallocate()`, and step 2 by `allocate()`.
// Only models with simple output types are supported, i.e. None, Tensor or
// List/Tuple/Dict of Tensors. Complex output types such as List of Lists are
// not supported.
//
// Additional Optimizations:
//
// [Borrowed IValue Outputs]
// A few native ops (notably, `static_runtime::dict_unpack` and
// `static_runtime::VarTupleUnpack`) simply unpack IValues to a bunch of
// outputs without modification. For example, `dict_unpack` does the following:
//   for each key in inputs:
//     output[i] = dict_input[key]
// To avoid refcount bumps, the outputs of these ops are non-owning references.
// This requires special logic in the memory planner: when adding an op that
// borrows its outputs, be sure that the memory planner is updated accordingly!
//
// [Managed Output Tensors]
// The memory planner is able to manage output tensors if the appropriate
// `StaticModuleOptions` are set. However, the memory planner handles output
// tensors separately from regular intermediate tensors:
// 1. They don't participate in memory reuse.
// 2. The memory planner cannot reclaim their backing storage until they have
//    been explicitly freed by the client.
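
// To make steps 1-3 above concrete, here is a rough per-iteration sketch
// (illustrative only, not actual Static Runtime code; the planner is driven
// by its owning BlockRunner):
//
//   planner.allocate();   // step 2: on the first iteration nothing is
//                         // pre-allocated and tensors come from the default
//                         // allocator; afterwards, this allocates one buffer
//                         // and hands out offsets for the managed tensors
//   /* ...run all ops in the block... */
//   planner.deallocate(); // steps 1 and 3: records the max nbytes observed
//                         // for each managed tensor, then frees the buffer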

class MemoryPlanner {
 public:
  MemoryPlanner(
      BlockRunner* block_runner,
      const BlockInfo& block_info,
      bool enable_out_variant,
      bool manage_output_tensors);

  // disable copying and moving
  MemoryPlanner(const MemoryPlanner&) = delete;
  MemoryPlanner& operator=(const MemoryPlanner&) = delete;
  MemoryPlanner(MemoryPlanner&&) = delete;
  MemoryPlanner& operator=(MemoryPlanner&&) = delete;
  virtual ~MemoryPlanner() = default;

  void allocate();
  void deallocate();
  void deallocateOutputTensors();

  size_t total_num_managed_tensors() const {
    return num_managed_tensors_;
  }

  size_t total_reused_tensors() const {
    return reused_tensors_;
  }

  size_t total_num_managed_output_tensors() const {
    return managed_output_tensors_.size();
  }

  C10_NODISCARD size_t total_num_unmanaged() const {
    return num_unmanaged_non_scalars() + num_unmanaged_scalars();
  }

  C10_NODISCARD size_t num_unmanaged_non_scalars() const {
    return unmanaged_ivalues_.size() + unmanaged_borrowed_ivalues_.size();
  }

  C10_NODISCARD size_t num_unmanaged_scalars() const {
    return num_unmanaged_scalar_ivalues_;
  }

  size_t total_managed() const {
    return managed_bytes_;
  }

  size_t numOutputBufferBytes() const {
    return output_buffer_bytes_;
  }

  // Check if `ivalue` is contained as a managed tensor. Only used in DCHECK().
  bool isManagedOutputTensor(const IValue& ivalue) const {
    if (!output_buffer_ || // output buffer has already been deallocated
        output_buffer_bytes_ == 0 || // memory planning is not yet initialized
        !ivalue.isTensor() // a non-tensor is never managed
    ) {
      return false;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.has_storage() || !tensor.storage().data_ptr()) {
      return false;
    }
    // TODO: Improve this once D31357486 is landed.
    uint8_t* tensor_ptr =
        static_cast<uint8_t*>(tensor.storage().data_ptr().get());
    uint8_t* buffer_start = static_cast<uint8_t*>(output_buffer_.get());
    uint8_t* buffer_end = buffer_start + output_buffer_bytes_;
    return buffer_start <= tensor_ptr && tensor_ptr < buffer_end;
  }

  bool isManagedStorageImpl(const at::StorageImpl* impl) const {
    if (storages_.empty()) {
      return false;
    }
    // Comparing pointers that aren't within the same array is
    // UB. We're doing fancy memory allocation stuff, so we cast to an
    // integer type and carry on.
    const auto impl_p = reinterpret_cast<uintptr_t>(impl);
    const auto start = reinterpret_cast<uintptr_t>(&storages_[0]);
    const auto end =
        reinterpret_cast<uintptr_t>(&storages_[0] + storages_.size());
    return impl_p >= start && impl_p < end;
  }

  bool overlapWithInternalBuffer(void* data_ptr) {
    return buffer_start_ <= data_ptr && data_ptr < buffer_end_;
  }

 protected:
  uint8_t* allocateBuffer(size_t num_bytes);

  size_t managed_bytes_{0};
  size_t reused_tensors_{0};

  // We allocate StorageImpls ourselves so that 1) we don't have to do
  // an extra two loads per Tensor (which will likely miss in the CPU
  // data cache): first reading the Storage (i.e., StorageImpl pointer)
  // from the TensorImpl object and then dereferencing it, and
  // 2) our memory access pattern during allocate() has high locality.
  // We don't have any guarantee that the model doesn't change the
  // Storage for managed tensors out from under us during execution,
  // so we have to check the StorageImpls each time we deallocate.
  ManagedStorages storages_;

  // Contains the size (in bytes) of the data to be allocated for each storage
  std::vector<size_t> storages_nbytes_;
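
  // Note for subclass authors: a minimal, hypothetical allocateManagedTensors()
  // could carve the arena with the protected members above, one slice per
  // managed storage (this is only a sketch of the intended usage, not the
  // actual strategy used by StandardMemoryPlanner below):
  //
  //   for (size_t i = 0; i < storages_.size(); ++i) {
  //     uint8_t* region = allocateBuffer(storages_nbytes_[i]);
  //     // ...then point storages_[i] (and thus every managed tensor that
  //     // shares it) at `region`.
  //   }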

 private:
  // ivalues created in one run but not managed by MemoryPlanner
  std::vector<IValue*> unmanaged_ivalues_;

  // Special class of unmanaged values: some native ops create IValues
  // in a "borrowed" state that can and must be cleaned up without a
  // reference count decrement.
  std::vector<IValue*> unmanaged_borrowed_ivalues_;

  // Even more special class of unmanaged values: if select_tensor
  // outputs are outputs of the graph, then they need to be restored
  // to an ordinary "strong reference" state.
  std::vector<IValue*> borrowed_ivalues_needing_incref_;

  std::vector<std::pair<size_t, at::Tensor*>> managed_output_tensors_{};
  at::DataPtr buffer_; // allocated each time we call Run()
  uint8_t* buffer_start_{nullptr};
  uint8_t* buffer_end_{nullptr};
  size_t num_managed_tensors_{0};
  size_t num_unmanaged_scalar_ivalues_{0};

  at::DataPtr output_buffer_;
  size_t output_buffer_bytes_{0};

  virtual void allocateManagedTensors() = 0;
  virtual void deallocateManagedTensors() = 0;

  void allocateOutputTensors();
};

class StandardMemoryPlanner : public MemoryPlanner {
 public:
  StandardMemoryPlanner(
      BlockRunner* block_runner,
      const BlockInfo& block_info,
      bool enable_out_variant,
      bool manage_output_tensors,
      bool optimize_memory);

 protected:
  void allocateManagedTensors() override;
  void deallocateManagedTensors() override;

  std::vector<StorageGroup> managed_tensors_{};
};

} // namespace torch::jit
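
// Usage note on managed output tensors (a hedged sketch, not actual client
// code): when `manage_output_tensors` is enabled, output tensors are placed
// in `output_buffer_` and are not reclaimed by `deallocate()`. The client
// consuming the outputs has to trigger an explicit release once it is done
// with them, roughly:
//
//   /* run the block; hand the output IValues to the caller */
//   /* caller finishes reading the outputs */
//   planner.deallocateOutputTensors(); // only now can the output buffer be
//                                      // reclaimed for the next iteration
//
// `isManagedOutputTensor()` exists only so that DCHECKs can verify whether a
// given IValue's tensor actually lives inside that buffer.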