#pragma once

#include <c10/util/ArrayRef.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/cache.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/multi_wait.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/util.h>

namespace torch {
namespace lazy {

class TORCH_API LazyGraphExecutor {
 public:
  struct DeviceDataInfo : public BackendData::Info {
    DeviceDataInfo(int64_t tensor_id, bool read_only)
        : tensor_id(tensor_id), read_only(read_only) {}

    int64_t tensor_id = 0;
    bool read_only = false;
  };

  // Register a lazy graph executor instance that can be retrieved using Get().
  static void Register(LazyGraphExecutor*);
  static LazyGraphExecutor* Get();

  virtual ~LazyGraphExecutor() = default;

  // Override these methods to perform custom tensor registration and
  // unregistration. Note: it is vital that the parent implementations are
  // also called, in order for the tensors to show up in the live tensor list.
  virtual void RegisterTensor(std::shared_ptr<LazyTensor::Data> data);
  virtual void UnregisterTensor(LazyTensor::Data* data);

  // Seed for the random generator.
  // Override to supply your own DeviceContextArena.
  virtual Value GetRngSeed(const BackendDevice& device);
  virtual uint64_t GetRunningSeed(const BackendDevice& device);
  virtual void SetRngSeed(const BackendDevice& device, uint64_t seed);

  void DeviceBarrier(const BackendDevice& device);

  BackendDataPtr GetDeviceData(
      const at::Tensor& tensor,
      const BackendDevice& device);

  BackendDataPtr GetDeviceData(
      const at::Scalar& value,
      at::ScalarType scalar_type,
      const BackendDevice& device);

  // Retrieves the set of lazy tensors which are currently live in the system,
  // for the given device. If device is nullptr, the live tensors for all
  // devices will be returned. Returned tensors are sorted by device as the
  // primary key, and by unique ID as the secondary key.
  std::vector<LazyTensorPtr> GetLiveTensors(const BackendDevice* device);

  // Makes sure that any outstanding IR operation accumulated over live
  // tensors gets turned into device data. If wait is true, the sync operation
  // will be run synchronously. The devices argument, if not empty, tells
  // which devices should be participating in the replicated computation.
  virtual void SyncLiveTensorsGraph(
      const BackendDevice* device,
      c10::ArrayRef<std::string> devices,
      bool wait);

  // Applies all the pending IR operations queued over the input tensors. All
  // the tensors must be on the same device. If wait is true, the sync
  // operation will be run synchronously. The devices argument, if not empty,
  // tells which devices should be participating in the replicated
  // computation.
  void SyncTensorsGraph(
      std::vector<LazyTensorPtr>* tensors,
      c10::ArrayRef<std::string> devices,
      bool wait,
      bool sync_ltc_data);

  // Marks an execution step, which allows the tensor framework to understand
  // the computation boundaries.
  // Override to supply your own DeviceContextArena.
  virtual void MarkStep(const BackendDevice& device);
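  // A minimal, commented-out usage sketch (illustrative only, not part of
  // this header): how a frontend might drive the step/sync entry points
  // above. `device` is a placeholder, and the exact call ordering used by a
  // given frontend may differ.
  //
  //   auto* executor = LazyGraphExecutor::Get();
  //   BackendDevice device = /* backend-specific device */;
  //   // ... record lazy ops on tensors living on `device` ...
  //   executor->SyncLiveTensorsGraph(&device, /*devices=*/{}, /*wait=*/false);
  //   executor->MarkStep(device);          // mark the step boundary
  //   executor->WaitDeviceOps({device});   // optional: block until done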
  // Waits for all the outstanding operations on all the supplied devices.
  // If devices is empty, the wait will happen for all local devices.
  void WaitDeviceOps(c10::ArrayRef<BackendDevice> devices);

  // Retrieves the PyTorch CPU tensors behind the lazy tensors' IR operations.
  // All the tensors must be on the same device.
  std::vector<at::Tensor> GetTensors(std::vector<LazyTensorPtr>* tensors);

  size_t IncTrimCounter() const;

  // Dumps the backend-specific text of the computation accumulated in the
  // graph which is attached to the tensors.
  std::string DumpBackendComputation(const std::vector<LazyTensorPtr>& tensors);

  Value GetDeviceDataIrValue(
      const at::Scalar& value,
      c10::ScalarType type,
      const BackendDevice& device);
  Value GetIrValueForScalar(
      const at::Scalar& value,
      c10::ScalarType type,
      const BackendDevice& device);
  Value GetIrValueForScalar(
      const at::Scalar& value,
      const BackendDevice& device);

  // TODO: even though this API is currently used **only** in codegen to
  // generate real scalar IR values vs scalar tensors, we would like to use it
  // in the other cases where `GetIrValueForXXXScalar` is used as well. In
  // order to do that, we need to untangle the cases where we don't need
  // `expand` and where we don't expect a scalar tensor.
  Value GetIrValueForScalarFromCodegen(
      const at::Scalar& value,
      const BackendDevice& device);
  Value GetIrValueForExpandedScalar(
      const at::Scalar& value,
      const Shape& shape,
      const BackendDevice& device);

  struct CachedComputation {
    explicit CachedComputation(ComputationPtr computation)
        : computation(std::move(computation)) {}

    ComputationPtr computation;
  };

  using ComputationCache = Cache<hash_t, CachedComputation, HashReducer>;

  ComputationCache* GetComputationCache();

  hash_t GetGraphHash(const std::vector<LazyTensorPtr>& tensors);

 protected:
  // TODO(alanwaketan): Revisit if all of them need to be accessible to
  // derived classes.

  struct SyncTensorsConfig {
    // Whether we want to force data on the target tensors (hence trimming
    // the IR graph above them).
    bool force_ltc_data = true;
    // Whether, when setting the data, the other properties of the tensor
    // state should be reset.
    bool sync_ltc_data = true;
  };

  struct SyncTensorCollection {
    SyncTensorCollection() : hash(0) {}

    SyncTensorsConfig config;
    std::vector<size_t> indices;
    hash_t hash;
    std::vector<ExceptionCleanup> unlocker;
    BackendDevice device;
  };

  struct PostOrderData {
    std::vector<const Node*> post_order;
    Util::EmissionMap emission_map;
    std::vector<BackendDataPtr> parameters_data;
    std::vector<size_t> parameter_sequence;
  };

  // Locking:
  // We perform two kinds of operations on tensors, synchronous and
  // asynchronous. The ApplyPendingGraph() operations are synchronous, as we
  // need the device data result immediately. Before the synchronous
  // operations can start, they need to wait for the pending asynchronous
  // operations to complete. Synchronous operations do not hold device locks,
  // since they are strictly sequential, dictated by the PyTorch execution
  // order. SyncTensorsGraph() is asynchronous, and returns immediately after
  // having scheduled the asynchronous operation. While executing, the
  // asynchronous operations will hold locks on all the participating devices
  // (in the most common case there will be only one device).
  // Since asynchronous operations capture device locks, only one asynchronous
  // operation can execute at the same time on a given device. Tensor
  // operations which send data to the device do not need to hold any device
  // locks while doing so. Only operations which _use_ device data
  // (computations, and transfers from the server) need to wait for
  // asynchronous operations to complete (barrier).
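  // Illustrative, commented-out sketch of the protocol above in terms of the
  // helpers declared below (the arena calls are internal implementation
  // details; this is not a prescribed usage pattern):
  //
  //   // Synchronous path: wait for any in-flight async work on the device
  //   // before consuming its data.
  //   DeviceLockerArena::Get()->DeviceBarrier(device);
  //
  //   // Asynchronous path: take the locks of all participating devices; the
  //   // returned ExceptionCleanup objects release the locks (and can carry a
  //   // captured exception) once the async work finishes.
  //   std::vector<ExceptionCleanup> unlocker =
  //       DeviceLockerArena::Get()->LockDevices({device});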
  class DeviceLocker {
   public:
    explicit DeviceLocker(BackendDevice device) : device_(std::move(device)) {}

    const BackendDevice& device() const {
      return device_;
    }

    void Lock();
    void Unlock(std::exception_ptr exptr);
    void Barrier();

   private:
    void CheckResetException();

    BackendDevice device_;
    std::mutex mutex_;
    std::condition_variable cv_;
    bool locked_ = false;
    std::exception_ptr exptr_;
  };

  class DeviceLockerArena {
   public:
    static DeviceLockerArena* Get();

    std::shared_ptr<DeviceLocker> GetLocker(const BackendDevice& device);

    void DeviceBarrier(const BackendDevice& device);

    // Use a set to impose an order on the device locking sequence (ABBA
    // prevention).
    std::vector<ExceptionCleanup> LockDevices(
        const std::set<BackendDevice>& devices);

   private:
    ExceptionCleanup LockDevice(const BackendDevice& device);

    std::mutex mutex_;
    std::map<BackendDevice, std::shared_ptr<DeviceLocker>> lockers_;
  };

  class DataCacheArena {
   public:
    static DataCacheArena* Get();

    BackendDataPtr GetDeviceData(
        const at::Tensor& tensor,
        const BackendDevice& device);

    BackendDataPtr GetDeviceData(
        const at::Scalar& value,
        at::ScalarType scalar_type,
        const BackendDevice& device);

   private:
    struct TensorHasher {
      size_t operator()(const at::Tensor& tensor) const;
    };
    struct TensorComparer {
      bool operator()(const at::Tensor& tensor1, const at::Tensor& tensor2)
          const;
    };

    explicit DataCacheArena(size_t max_cache_size);

    using DataCache =
        Cache<at::Tensor, BackendData, TensorHasher, TensorComparer>;

    DataCache* GetDataCache(const BackendDevice& device);

    size_t max_cache_size_ = 0;
    std::mutex mutex_;
    std::map<BackendDevice, std::unique_ptr<DataCache>> device_caches_;
  };
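  // Commented-out sketch (an assumption based on the matching signatures, not
  // a guarantee of the implementation): the public GetDeviceData() overloads
  // above mirror this arena's, so repeated transfers of an equal host tensor
  // or scalar to the same device can be answered from the per-device cache
  // instead of re-uploading. For example:
  //
  //   at::Tensor step = at::scalar_tensor(1, at::kLong);
  //   // The first call uploads the data; a later call with an equal tensor
  //   // may return the cached BackendDataPtr.
  //   BackendDataPtr handle =
  //       LazyGraphExecutor::Get()->GetDeviceData(step, device);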
  // The DeviceContextArena holds per-device live information and statistics,
  // among which are the lazy tensors which are currently alive in the system.
  // This is used to create computation "barriers" in order to flush pending
  // operations and ensure the same computations are created during the
  // training loops.
  // TODO(alanwaketan): Add a registry such that we don't need to make all
  // related methods virtual.
  class DeviceContextArena {
   protected:
    struct DeviceContext {
      std::mutex lock;
      std::map<int64_t, std::weak_ptr<LazyTensor::Data>> tensors_data;
      uint64_t seed = 101;
      uint64_t running_seed = 101;
      Value seed_ir_value;
    };

   public:
    static DeviceContextArena* Get();
    virtual ~DeviceContextArena() = default;

    void RegisterTensor(std::shared_ptr<LazyTensor::Data> data);
    void UnregisterTensor(LazyTensor::Data* data);

    std::vector<LazyTensorPtr> GetLiveTensors(const BackendDevice* device);

    // Overriding this allows derived classes to use their own IRs for Value.
    virtual Value GetRngSeed(const BackendDevice& device);
    uint64_t GetRunningSeed(const BackendDevice& device);
    void SetRngSeed(const BackendDevice& device, uint64_t seed);

    void MarkStep(const BackendDevice& device);

    std::vector<BackendDevice> GetActiveDevices();

   protected:
    DeviceContext* GetDeviceContext(const BackendDevice& device);

    void ForAllDeviceContexts(
        const std::function<void(DeviceContext*)>& fn,
        const BackendDevice* device);

    // Overriding this allows derived classes to use their own conversions.
    virtual Value IrValueFromScalar(
        const at::Scalar& value,
        at::ScalarType scalar_type,
        const BackendDevice& device);

   private:
    std::vector<DeviceContext*> GetAllDeviceContexts();

    std::mutex lock_;
    std::map<BackendDevice, DeviceContext*> device_contexts_;
  };

  struct Async {
    Async(
        SyncTensorCollection* coll,
        std::vector<BackendDataPtr> parameters_data,
        std::vector<BackendDataPtr> tensors_data,
        ComputationCache::TypePtr cached_computation);
    virtual ~Async() = default;

    void Wait();

    MultiWait mwait;
    std::vector<size_t> indices;
    std::vector<ExceptionCleanup> unlocker;
    std::vector<BackendDataPtr> parameters_data;
    BackendDevice device;
    ComputationCache::TypePtr cached_computation;
    std::vector<BackendDataPtr> tensors_data;
  };

  void ResetTrimCounter() const;

  // Waits for this SyncTensorCollection's device barrier and acquires the
  // lock.
  virtual void TensorCollectionBarrier(SyncTensorCollection* coll);

  // Override this to insert your own profiler.
  virtual PostOrderData RunPostOrder(
      const std::vector<Value>& ir_values,
      SyncTensorCollection* coll);

 private:
  struct CompilationResult {
    BackendDevice device;
    size_t emitted_nodes = 0;
    ComputationPtr computation;
    std::vector<BackendDataPtr> parameters_data;
  };

  virtual bool ShouldSyncTensor(const LazyTensorPtr& tensor) const;

  SyncTensorCollection CollectSyncTensors(
      const std::vector<LazyTensorPtr>& tensors,
      const SyncTensorsConfig& config);

  std::vector<Value> CollectRoots(
      const std::vector<LazyTensorPtr>& tensors,
      c10::ArrayRef<size_t> indices);

  std::vector<BackendDataPtr> SetTensorData(
      std::vector<LazyTensorPtr>* tensors,
      const SyncTensorsConfig& config,
      c10::ArrayRef<size_t> indices,
      const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec);

  void ExtractIRAndPrepareTensorData(
      std::vector<LazyTensorPtr>* tensors,
      const SyncTensorsConfig& config,
      c10::ArrayRef<size_t> indices,
      std::vector<Value>& ir_values,
      std::vector<BackendDataPtr>& tensor_data_vec);

  std::shared_ptr<Async> TryRunCachedSync(
      std::vector<LazyTensorPtr>* tensors,
      SyncTensorCollection* coll,
      PostOrderData* po_data,
      const std::vector<BackendDataPtr>& tensor_data_vec);

  CompilationResult Compile(
      const std::vector<LazyTensorPtr>& tensors,
      c10::ArrayRef<std::string> devices,
      const SyncTensorCollection& coll,
      PostOrderData* po_data,
      const std::vector<Value>& ir_values);

  ComputationCache::TypePtr LookupCachedCompile(const hash_t& hash);

  std::shared_ptr<Async> SyncTensorsGraphInternal(
      std::vector<LazyTensorPtr>* tensors,
      c10::ArrayRef<std::string> devices,
      const SyncTensorsConfig& config);
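  // Rough, commented-out sketch of how the helpers above could fit together
  // inside SyncTensorsGraphInternal(). This is pseudocode inferred only from
  // the declarations in this header; the actual implementation may differ.
  //
  //   SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
  //   std::vector<Value> ir_values;
  //   std::vector<BackendDataPtr> tensor_data_vec;
  //   ExtractIRAndPrepareTensorData(
  //       tensors, coll.config, coll.indices, ir_values, tensor_data_vec);
  //   PostOrderData po_data = RunPostOrder(ir_values, &coll);
  //   // Prefer a computation cached under the collection hash...
  //   if (auto async =
  //           TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec)) {
  //     return async;
  //   }
  //   // ...otherwise compile and schedule the asynchronous execution, which
  //   // keeps holding the device locks captured in `coll`.
  //   CompilationResult result =
  //       Compile(*tensors, devices, coll, &po_data, ir_values);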
  // Schedules the execution of a sync tensors operation in the background.
  // The asynchronous operation will hold the device locks by capturing the
  // ones present within the coll structure.
  std::shared_ptr<Async> ScheduleSyncTensorsGraph(
      SyncTensorCollection* coll,
      std::vector<BackendDataPtr> parameters_data,
      std::vector<BackendDataPtr> tensors_data,
      ComputationCache::TypePtr cached_computation);

  std::shared_ptr<Async> ScheduleSyncTensorsGraph(
      std::vector<LazyTensorPtr>* tensors,
      SyncTensorCollection* coll,
      std::vector<BackendDataPtr> parameters_data,
      ComputationCache::TypePtr cached_computation,
      const std::vector<BackendDataPtr>& tensor_data_vec);

  std::vector<at::Tensor> GetTensorsFused(std::vector<LazyTensorPtr>* tensors);

  std::vector<at::Tensor> FetchTensors(
      std::vector<LazyTensorPtr>* tensors,
      c10::ArrayRef<BackendDataPtr> tensors_data,
      const std::vector<size_t>* indices);

  // Gathers the device data for all the input tensors, after an
  // asynchronous operation.
  std::vector<BackendDataPtr> GatherTensorsData(
      const std::vector<LazyTensorPtr>& tensors,
      c10::ArrayRef<size_t> indices,
      c10::ArrayRef<BackendDataPtr> tensors_data);
};

} // namespace lazy
} // namespace torch