#pragma once

#include <c10/util/ArrayRef.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/cache.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/multi_wait.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/util.h>

namespace torch {
namespace lazy {

class TORCH_API LazyGraphExecutor {
 public:
  struct DeviceDataInfo : public BackendData::Info {
    DeviceDataInfo(int64_t tensor_id, bool read_only)
        : tensor_id(tensor_id), read_only(read_only) {}

    int64_t tensor_id = 0;
    bool read_only = false;
  };

  // Register a lazy graph executor instance that can be retrieved using Get()
  static void Register(LazyGraphExecutor*);
  static LazyGraphExecutor* Get();
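
  // Example (sketch): a backend typically registers its executor once during
  // initialization and retrieves it later through Get(). The class name
  // `MyBackendGraphExecutor` is hypothetical and used only for illustration:
  //
  //   class MyBackendGraphExecutor : public LazyGraphExecutor { /* ... */ };
  //
  //   void InitializeBackend() {
  //     static MyBackendGraphExecutor executor;
  //     LazyGraphExecutor::Register(&executor);
  //   }
  //
  //   // Anywhere after registration:
  //   LazyGraphExecutor* executor = LazyGraphExecutor::Get();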

  virtual ~LazyGraphExecutor() = default;

  // Override these methods to perform custom tensor registration and
  // unregistration. Note: it is vital that the parent implementations are also
  // called in order for the tensors to show up in the live tensor list.
  virtual void RegisterTensor(std::shared_ptr<LazyTensor::Data> data);
  virtual void UnregisterTensor(LazyTensor::Data* data);

  // Seed for the random generator.
  // Override to supply your own DeviceContextArena.
  virtual Value GetRngSeed(const BackendDevice& device);
  virtual uint64_t GetRunningSeed(const BackendDevice& device);
  virtual void SetRngSeed(const BackendDevice& device, uint64_t seed);
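
  // Example (sketch): seeds are tracked per device. `device` is a
  // caller-provided BackendDevice; the concrete value is backend specific:
  //
  //   LazyGraphExecutor* executor = LazyGraphExecutor::Get();
  //   executor->SetRngSeed(device, /*seed=*/42);
  //   Value seed_ir = executor->GetRngSeed(device);        // IR value form
  //   uint64_t running = executor->GetRunningSeed(device); // plain integer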

  void DeviceBarrier(const BackendDevice& device);

  BackendDataPtr GetDeviceData(
      const at::Tensor& tensor,
      const BackendDevice& device);

  BackendDataPtr GetDeviceData(
      const at::Scalar& value,
      at::ScalarType scalar_type,
      const BackendDevice& device);

  // Retrieves the set of lazy tensors which are currently live in the system,
  // for the given device. If device is nullptr, the live tensors for all
  // devices will be returned. Returned tensors are sorted by device as primary
  // key, and by unique ID as secondary key.
  std::vector<LazyTensorPtr> GetLiveTensors(const BackendDevice* device);

  // Makes sure that any outstanding IR operations accumulated over live
  // tensors get turned into device data. If wait is true, the sync operation
  // will be run synchronously. The devices argument, if not empty, tells which
  // devices should be participating in the replicated computation.
  virtual void SyncLiveTensorsGraph(
      const BackendDevice* device,
      c10::ArrayRef<std::string> devices,
      bool wait);

  // Applies all the pending IR operations queued over the input tensors. All
  // the tensors must be on the same device. If wait is true, the sync
  // operation will be run synchronously. The devices argument, if not empty,
  // tells which devices should be participating in the replicated computation.
  void SyncTensorsGraph(
      std::vector<LazyTensorPtr>* tensors,
      c10::ArrayRef<std::string> devices,
      bool wait,
      bool sync_ltc_data);

  // Marks an execution step, which allows the tensor framework to understand
  // the computation boundaries.
  // Override to supply your own DeviceContextArena.
  virtual void MarkStep(const BackendDevice& device);
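
  // Example (sketch): a typical training-step boundary flushes the pending IR
  // of all live tensors and then marks the step. The arguments shown are
  // illustrative defaults and `device` is a caller-provided BackendDevice:
  //
  //   LazyGraphExecutor* executor = LazyGraphExecutor::Get();
  //   executor->SyncLiveTensorsGraph(
  //       /*device=*/nullptr, /*devices=*/{}, /*wait=*/false);
  //   executor->MarkStep(device);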

  // Waits for all the outstanding operations on all the supplied devices.
  // If devices is empty, the wait will happen for all local devices.
  void WaitDeviceOps(c10::ArrayRef<BackendDevice> devices);

  // Retrieves the PyTorch CPU tensors behind the lazy tensors' IR operations.
  // All the tensors must be on the same device.
  std::vector<at::Tensor> GetTensors(std::vector<LazyTensorPtr>* tensors);
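
  // Example (sketch): materializing lazy tensors back to PyTorch CPU tensors.
  // `lazy_tensors` is a caller-built vector; all entries must share a device:
  //
  //   std::vector<LazyTensorPtr> lazy_tensors = /* ... */;
  //   std::vector<at::Tensor> cpu_tensors =
  //       LazyGraphExecutor::Get()->GetTensors(&lazy_tensors);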

  size_t IncTrimCounter() const;

  // Dumps the backend specific text of the computation accumulated in the
  // graph which is attached to the tensors.
  std::string DumpBackendComputation(const std::vector<LazyTensorPtr>& tensors);

  Value GetDeviceDataIrValue(
      const at::Scalar& value,
      c10::ScalarType type,
      const BackendDevice& device);
  Value GetIrValueForScalar(
      const at::Scalar& value,
      c10::ScalarType type,
      const BackendDevice& device);
  Value GetIrValueForScalar(
      const at::Scalar& value,
      const BackendDevice& device);

  // TODO: even though this API is currently used **only** in codegen to
  // generate real scalar IR values vs scalar tensors, we would like to
  // use it in other cases where `GetIrValueForXXXScalar` is used as well.
  // In order to do that, we need to untangle the cases where we don't need
  // `expand` and where we don't expect a scalar tensor.
  Value GetIrValueForScalarFromCodegen(
      const at::Scalar& value,
      const BackendDevice& device);
  Value GetIrValueForExpandedScalar(
      const at::Scalar& value,
      const Shape& shape,
      const BackendDevice& device);

  struct CachedComputation {
    explicit CachedComputation(ComputationPtr computation)
        : computation(std::move(computation)) {}

    ComputationPtr computation;
  };

  using ComputationCache = Cache<hash_t, CachedComputation, HashReducer>;

  ComputationCache* GetComputationCache();

  hash_t GetGraphHash(const std::vector<LazyTensorPtr>& tensors);

 protected:
  // TODO(alanwaketan): Revisit if all of them need to be accessible to
  // derived classes.

  struct SyncTensorsConfig {
    // Whether we want to force data on the target tensors (hence trimming
    // the IR graph above them).
    bool force_ltc_data = true;
    // Whether, when setting the data, the other properties of the tensor
    // state should be reset.
    bool sync_ltc_data = true;
  };

  struct SyncTensorCollection {
    SyncTensorCollection() : hash(0) {}

    SyncTensorsConfig config;
    std::vector<size_t> indices;
    hash_t hash;
    std::vector<ExceptionCleanup> unlocker;
    BackendDevice device;
  };

  struct PostOrderData {
    std::vector<const Node*> post_order;
    Util::EmissionMap emission_map;
    std::vector<BackendDataPtr> parameters_data;
    std::vector<size_t> parameter_sequence;
  };

  // Locking:
  // We perform two kinds of operations on tensors, synchronous and
  // asynchronous. ApplyPendingGraph() is synchronous, as we need the device
  // data result immediately. Before the synchronous operations can start, they
  // need to wait until the pending asynchronous operations have completed.
  // Synchronous operations do not hold device locks, since they are strictly
  // sequential, dictated by the PyTorch execution order. SyncTensorsGraph() is
  // asynchronous, and returns immediately after having scheduled the
  // asynchronous operation. While executing, the asynchronous operations will
  // hold locks on all the participating devices (in most common cases there
  // will be only one device).
  // Since asynchronous operations capture device locks, only one asynchronous
  // operation can execute at the same time on a given device. Tensor
  // operations which send data to the device do not need to hold any device
  // locks while doing so. Only operations which _use_ device data
  // (computations, and transfers from the server) need to wait for
  // asynchronous operations to complete (barrier).
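  //
  // Sketch of that protocol in terms of the classes declared below (the code
  // is illustrative only, not an exact excerpt):
  //
  //   // Asynchronous path: capture the locks of the participating devices.
  //   // Locking through a std::set gives a canonical order (ABBA prevention).
  //   std::vector<ExceptionCleanup> unlocker =
  //       DeviceLockerArena::Get()->LockDevices({device});
  //   // ... schedule and run the computation; the locks are held for as long
  //   // as the `unlocker` cleanup objects are alive.
  //
  //   // Synchronous path: wait for in-flight asynchronous work on the device
  //   // before using its data.
  //   DeviceLockerArena::Get()->DeviceBarrier(device);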

  class DeviceLocker {
   public:
    explicit DeviceLocker(BackendDevice device) : device_(std::move(device)) {}

    const BackendDevice& device() const {
      return device_;
    }

    void Lock();
    void Unlock(std::exception_ptr exptr);
    void Barrier();

   private:
    void CheckResetException();

    BackendDevice device_;
    std::mutex mutex_;
    std::condition_variable cv_;
    bool locked_ = false;
    std::exception_ptr exptr_;
  };

  class DeviceLockerArena {
   public:
    static DeviceLockerArena* Get();

    std::shared_ptr<DeviceLocker> GetLocker(const BackendDevice& device);

    void DeviceBarrier(const BackendDevice& device);

    // Use a set to impose an order on the device locking sequence (ABBA
    // prevention).
    std::vector<ExceptionCleanup> LockDevices(
        const std::set<BackendDevice>& devices);

   private:
    ExceptionCleanup LockDevice(const BackendDevice& device);

    std::mutex mutex_;
    std::map<BackendDevice, std::shared_ptr<DeviceLocker>> lockers_;
  };

  class DataCacheArena {
   public:
    static DataCacheArena* Get();

    BackendDataPtr GetDeviceData(
        const at::Tensor& tensor,
        const BackendDevice& device);

    BackendDataPtr GetDeviceData(
        const at::Scalar& value,
        at::ScalarType scalar_type,
        const BackendDevice& device);

   private:
    struct TensorHasher {
      size_t operator()(const at::Tensor& tensor) const;
    };
    struct TensorComparer {
      bool operator()(const at::Tensor& tensor1, const at::Tensor& tensor2)
          const;
    };

    explicit DataCacheArena(size_t max_cache_size);

    using DataCache =
        Cache<at::Tensor, BackendData, TensorHasher, TensorComparer>;

    DataCache* GetDataCache(const BackendDevice& device);

    size_t max_cache_size_ = 0;
    std::mutex mutex_;
    std::map<BackendDevice, std::unique_ptr<DataCache>> device_caches_;
  };

  // The DeviceContextArena holds per-device live information and statistics,
  // among which are the lazy tensors currently alive in the system. This is
  // used to create computation "barriers" in order to flush pending
  // operations and ensure the same computations are created during the
  // training loops.
  // TODO(alanwaketan): Add a registry such that we don't need to make all
  // related methods virtual.
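  //
  // Example (sketch): a backend can subclass the arena to customize how seed
  // values become IR. `MyDeviceContextArena` is hypothetical, shown only to
  // illustrate the override point:
  //
  //   class MyDeviceContextArena : public DeviceContextArena {
  //    protected:
  //     Value IrValueFromScalar(
  //         const at::Scalar& value,
  //         at::ScalarType scalar_type,
  //         const BackendDevice& device) override {
  //       // Build and return a backend-specific IR value here.
  //     }
  //   };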
  class DeviceContextArena {
   protected:
    struct DeviceContext {
      std::mutex lock;
      std::map<int64_t, std::weak_ptr<LazyTensor::Data>> tensors_data;
      uint64_t seed = 101;
      uint64_t running_seed = 101;
      Value seed_ir_value;
    };

   public:
    static DeviceContextArena* Get();
    virtual ~DeviceContextArena() = default;

    void RegisterTensor(std::shared_ptr<LazyTensor::Data> data);
    void UnregisterTensor(LazyTensor::Data* data);

    std::vector<LazyTensorPtr> GetLiveTensors(const BackendDevice* device);

    // Overriding it allows derived classes to use their own IRs for Value.
    virtual Value GetRngSeed(const BackendDevice& device);
    uint64_t GetRunningSeed(const BackendDevice& device);
    void SetRngSeed(const BackendDevice& device, uint64_t seed);

    void MarkStep(const BackendDevice& device);

    std::vector<BackendDevice> GetActiveDevices();

   protected:
    DeviceContext* GetDeviceContext(const BackendDevice& device);

    void ForAllDeviceContexts(
        const std::function<void(DeviceContext*)>& fn,
        const BackendDevice* device);

    // Overriding it allows derived classes to use their own conversions.
    virtual Value IrValueFromScalar(
        const at::Scalar& value,
        at::ScalarType scalar_type,
        const BackendDevice& device);

   private:
    std::vector<DeviceContext*> GetAllDeviceContexts();

    std::mutex lock_;
    std::map<BackendDevice, DeviceContext*> device_contexts_;
  };

  struct Async {
    Async(
        SyncTensorCollection* coll,
        std::vector<BackendDataPtr> parameters_data,
        std::vector<BackendDataPtr> tensors_data,
        ComputationCache::TypePtr cached_computation);
    virtual ~Async() = default;

    void Wait();

    MultiWait mwait;
    std::vector<size_t> indices;
    std::vector<ExceptionCleanup> unlocker;
    std::vector<BackendDataPtr> parameters_data;
    BackendDevice device;
    ComputationCache::TypePtr cached_computation;
    std::vector<BackendDataPtr> tensors_data;
  };
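
  // Example (sketch): the Schedule*/TryRunCachedSync helpers declared below
  // return such an Async handle; a caller that needs synchronous behavior can
  // block on it. The `wait` flag mirrors the public SyncTensorsGraph argument
  // and the variables are shown only for illustration:
  //
  //   std::shared_ptr<Async> async =
  //       SyncTensorsGraphInternal(&tensors, devices, config);
  //   if (async != nullptr && wait) {
  //     async->Wait();
  //   }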

  void ResetTrimCounter() const;

  // Waits for this SyncTensorCollection's device barrier and acquires the
  // lock.
  virtual void TensorCollectionBarrier(SyncTensorCollection* coll);

  // Override this to insert your own profiler.
  virtual PostOrderData RunPostOrder(
      const std::vector<Value>& ir_values,
      SyncTensorCollection* coll);

 private:
  struct CompilationResult {
    BackendDevice device;
    size_t emitted_nodes = 0;
    ComputationPtr computation;
    std::vector<BackendDataPtr> parameters_data;
  };

  virtual bool ShouldSyncTensor(const LazyTensorPtr& tensor) const;

  SyncTensorCollection CollectSyncTensors(
      const std::vector<LazyTensorPtr>& tensors,
      const SyncTensorsConfig& config);

  std::vector<Value> CollectRoots(
      const std::vector<LazyTensorPtr>& tensors,
      c10::ArrayRef<size_t> indices);

  std::vector<BackendDataPtr> SetTensorData(
      std::vector<LazyTensorPtr>* tensors,
      const SyncTensorsConfig& config,
      c10::ArrayRef<size_t> indices,
      const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec);

  void ExtractIRAndPrepareTensorData(
      std::vector<LazyTensorPtr>* tensors,
      const SyncTensorsConfig& config,
      c10::ArrayRef<size_t> indices,
      std::vector<Value>& ir_values,
      std::vector<BackendDataPtr>& tensor_data_vec);

  std::shared_ptr<Async> TryRunCachedSync(
      std::vector<LazyTensorPtr>* tensors,
      SyncTensorCollection* coll,
      PostOrderData* po_data,
      const std::vector<BackendDataPtr>& tensor_data_vec);

  CompilationResult Compile(
      const std::vector<LazyTensorPtr>& tensors,
      c10::ArrayRef<std::string> devices,
      const SyncTensorCollection& coll,
      PostOrderData* po_data,
      const std::vector<Value>& ir_values);

  ComputationCache::TypePtr LookupCachedCompile(const hash_t& hash);

  std::shared_ptr<Async> SyncTensorsGraphInternal(
      std::vector<LazyTensorPtr>* tensors,
      c10::ArrayRef<std::string> devices,
      const SyncTensorsConfig& config);

  // Schedules the execution of a sync tensors operation in the background. The
  // asynchronous operation will hold the device locks by capturing the ones
  // present within the coll structure.
  std::shared_ptr<Async> ScheduleSyncTensorsGraph(
      SyncTensorCollection* coll,
      std::vector<BackendDataPtr> parameters_data,
      std::vector<BackendDataPtr> tensors_data,
      ComputationCache::TypePtr cached_computation);

  std::shared_ptr<Async> ScheduleSyncTensorsGraph(
      std::vector<LazyTensorPtr>* tensors,
      SyncTensorCollection* coll,
      std::vector<BackendDataPtr> parameters_data,
      ComputationCache::TypePtr cached_computation,
      const std::vector<BackendDataPtr>& tensor_data_vec);

  std::vector<at::Tensor> GetTensorsFused(std::vector<LazyTensorPtr>* tensors);

  std::vector<at::Tensor> FetchTensors(
      std::vector<LazyTensorPtr>* tensors,
      c10::ArrayRef<BackendDataPtr> tensors_data,
      const std::vector<size_t>* indices);

  // Gathers the device data for all the input tensors, after an
  // asynchronous operation.
  std::vector<BackendDataPtr> GatherTensorsData(
      const std::vector<LazyTensorPtr>& tensors,
      c10::ArrayRef<size_t> indices,
      c10::ArrayRef<BackendDataPtr> tensors_data);
};

} // namespace lazy
} // namespace torch