// xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/lazy_graph_executor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/lazy/core/lazy_graph_executor.h>

#include <ATen/ScalarOps.h>
#include <c10/util/Logging.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/unique.h>

#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/ops/arithmetic_ir_ops.h>
#include <torch/csrc/lazy/core/thread_pool.h>

#include <cstring>

namespace torch {
namespace lazy {
namespace {
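// Per-thread counter backing ResetTrimCounter()/IncTrimCounter() below; it is
// cleared at every mark step.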
struct TlsData {
  void Reset() {
    trim_counter = 0;
  }

  size_t trim_counter = 0;
};

thread_local TlsData g_tls_data;

bool TensorCompare(const at::Tensor& t1, const at::Tensor& t2) {
  if (t1.scalar_type() != t2.scalar_type() || t1.sizes() != t2.sizes()) {
    return false;
  }
  // PyTorch currently has an issue comparing tensors which contain NaN values:
  // the comparison is not deterministic. So we do a memory compare here until
  // the PyTorch equal() API is fixed.
  at::Tensor contiguous_t1 = t1.contiguous();
  at::Tensor contiguous_t2 = t2.contiguous();
  return std::memcmp(
             contiguous_t1.data_ptr(),
             contiguous_t2.data_ptr(),
             contiguous_t1.numel() * contiguous_t1.itemsize()) == 0;
}

// Return true if any tensor in the list has an underlying IR (leaf or
// operation).
bool TensorsHaveIR(const std::vector<LazyTensorPtr>& tensors) {
  for (const auto& tensor : tensors) {
    if (tensor->CurrentDataHandle() || tensor->CurrentIrValue()) {
      return true;
    }
  }
  return false;
}

std::atomic<LazyGraphExecutor*> lazy_graph_executor_registry;
} // namespace
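// DeviceContextArena keeps one DeviceContext per backend device, tracking the
// weakly-referenced data of live lazy tensors plus the per-device RNG seed
// state. The singleton is created on first use and never deleted.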
auto LazyGraphExecutor::DeviceContextArena::Get()
    -> LazyGraphExecutor::DeviceContextArena* {
  static DeviceContextArena* arena = new DeviceContextArena();
  return arena;
}

void LazyGraphExecutor::DeviceContextArena::RegisterTensor(
    std::shared_ptr<LazyTensor::Data> data) {
  DeviceContext* devctx = GetDeviceContext(data->device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  devctx->tensors_data.emplace(data->unique_id, data);
}

void LazyGraphExecutor::DeviceContextArena::UnregisterTensor(
    LazyTensor::Data* data) {
  DeviceContext* devctx = GetDeviceContext(data->device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  devctx->tensors_data.erase(data->unique_id);
}

std::vector<LazyTensorPtr> LazyGraphExecutor::DeviceContextArena::
    GetLiveTensors(const BackendDevice* device) {
  std::vector<LazyTensorPtr> tensors;
  auto fn = [&](DeviceContext* devctx) {
    std::lock_guard<std::mutex> lock(devctx->lock);
    for (auto& uid_wptr : devctx->tensors_data) {
      std::shared_ptr<LazyTensor::Data> data = uid_wptr.second.lock();
      if (data != nullptr) {
        tensors.push_back(LazyTensor::Create(std::move(data)));
      }
    }
  };
  ForAllDeviceContexts(fn, device);
  return tensors;
}
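// Returns the IR value to use as RNG seed for the next random op on 'device'.
// The scalar running_seed is advanced with a linear-congruential style update
// (seed' = kSeedAdd + kSeedMul * seed), and the same affine transform is
// applied symbolically to seed_ir_value, so every consumer gets a distinct
// seed while the graph keeps a single root seed parameter instead of one
// uploaded scalar per random op.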
Value LazyGraphExecutor::DeviceContextArena::GetRngSeed(
    const BackendDevice& device) {
  static const at::ScalarType kSeedType = at::ScalarType::Long;
  static const uint64_t kSeedMul = 214013;
  static const uint64_t kSeedAdd = 2531011;
  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  if (!devctx->seed_ir_value) {
    devctx->seed_ir_value =
        IrValueFromScalar(MakeIntScalar(devctx->seed), kSeedType, device);
  }
  // Keep the running seed as scalar as well, so we can return it directly
  // without executing graphs.
  devctx->running_seed = kSeedAdd + kSeedMul * devctx->running_seed;
  // Compose new seeds from the root seed, to avoid creating too many
  // computation parameters which might overflow the device capacity.
  Value k = MakeScalar(MakeIntScalar(kSeedMul), kSeedType);
  Value b = MakeScalar(MakeIntScalar(kSeedAdd), kSeedType);
  devctx->seed_ir_value = b + k * devctx->seed_ir_value;
  return devctx->seed_ir_value;
}

uint64_t LazyGraphExecutor::DeviceContextArena::GetRunningSeed(
    const BackendDevice& device) {
  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  return devctx->running_seed;
}

void LazyGraphExecutor::DeviceContextArena::SetRngSeed(
    const BackendDevice& device,
    uint64_t seed) {
  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  devctx->seed = seed;
  devctx->running_seed = devctx->seed;
  devctx->seed_ir_value = Value();
}

void LazyGraphExecutor::DeviceContextArena::MarkStep(
    const BackendDevice& device) {
  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  devctx->seed = 1012031 + devctx->seed * 7012063;
  devctx->running_seed = devctx->seed;
  devctx->seed_ir_value = Value();
}

std::vector<BackendDevice> LazyGraphExecutor::DeviceContextArena::
    GetActiveDevices() {
  std::vector<BackendDevice> active_devices;
  std::lock_guard<std::mutex> lock(lock_);
  active_devices.reserve(device_contexts_.size());
  for (auto& device_contexts : device_contexts_) {
    active_devices.push_back(device_contexts.first);
  }
  return active_devices;
}

auto LazyGraphExecutor::DeviceContextArena::GetAllDeviceContexts()
    -> std::vector<DeviceContext*> {
  std::vector<DeviceContext*> all_device_contexts;
  std::lock_guard<std::mutex> lock(lock_);
  all_device_contexts.reserve(device_contexts_.size());
  for (auto& device_contexts : device_contexts_) {
    all_device_contexts.push_back(device_contexts.second);
  }
  return all_device_contexts;
}

void LazyGraphExecutor::DeviceContextArena::ForAllDeviceContexts(
    const std::function<void(DeviceContext*)>& fn,
    const BackendDevice* device) {
  if (device == nullptr) {
    for (auto devctx : GetAllDeviceContexts()) {
      fn(devctx);
    }
  } else {
    fn(GetDeviceContext(*device));
  }
}

auto LazyGraphExecutor::DeviceContextArena::GetDeviceContext(
    const BackendDevice& device) -> DeviceContext* {
  std::lock_guard<std::mutex> lock(lock_);
  auto it = device_contexts_.find(device);
  if (it == device_contexts_.end()) {
    it = device_contexts_.emplace(device, new DeviceContext()).first;
  }
  return it->second;
}

Value LazyGraphExecutor::DeviceContextArena::IrValueFromScalar(
    const at::Scalar& value,
    at::ScalarType scalar_type,
    const BackendDevice& device) {
  at::Tensor tensor = at::scalar_tensor(value, at::TensorOptions(scalar_type));
  BackendDataPtr device_data = TensorToDataHandle(tensor, device);
  return MakeDeviceData(std::move(device_data));
}
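// DeviceLocker serializes access to a device. Lock() blocks until the device
// is free and then takes ownership; Unlock() releases it and may record an
// exception, which CheckResetException() rethrows to the next thread that
// locks or barriers on the same device.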
void LazyGraphExecutor::DeviceLocker::Lock() {
  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [this] { return !locked_; });
  CheckResetException();
  locked_ = true;
}

void LazyGraphExecutor::DeviceLocker::Unlock(std::exception_ptr exptr) {
  std::lock_guard<std::mutex> lock(mutex_);
  locked_ = false;
  exptr_ = std::move(exptr);
  cv_.notify_all();
}

void LazyGraphExecutor::DeviceLocker::Barrier() {
  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [this] { return !locked_; });
  cv_.notify_all();
  CheckResetException();
}

void LazyGraphExecutor::DeviceLocker::CheckResetException() {
  std::exception_ptr exptr = std::move(exptr_);
  exptr_ = nullptr;
  if (exptr != nullptr) {
    std::rethrow_exception(exptr);
  }
}

auto LazyGraphExecutor::DeviceLockerArena::Get() -> DeviceLockerArena* {
  static DeviceLockerArena* arena = new DeviceLockerArena();
  return arena;
}

auto LazyGraphExecutor::DeviceLockerArena::GetLocker(
    const BackendDevice& device) -> std::shared_ptr<DeviceLocker> {
  std::lock_guard<std::mutex> lock(mutex_);
  auto it = lockers_.find(device);
  if (it == lockers_.end()) {
    it = lockers_.emplace(device, std::make_shared<DeviceLocker>(device)).first;
  }
  return it->second;
}

void LazyGraphExecutor::DeviceLockerArena::DeviceBarrier(
    const BackendDevice& device) {
  auto locker = DeviceLockerArena::Get()->GetLocker(device);
  locker->Barrier();
}

std::vector<ExceptionCleanup> LazyGraphExecutor::DeviceLockerArena::LockDevices(
    const std::set<BackendDevice>& devices) {
  std::vector<ExceptionCleanup> unlocker;
  unlocker.reserve(devices.size());
  for (auto& device : devices) {
    unlocker.emplace_back(LockDevice(device));
  }
  return unlocker;
}

ExceptionCleanup LazyGraphExecutor::DeviceLockerArena::LockDevice(
    const BackendDevice& device) {
  VLOG(4) << "Waiting on device barrier for device " << device << " ...";
  std::shared_ptr<DeviceLocker> locker;
  {
    TORCH_LAZY_TIMED("DeviceLockWait");
    locker = DeviceLockerArena::Get()->GetLocker(device);
    locker->Lock();
  }
  VLOG(4) << "Waiting on device barrier for device " << device << " done!";
  return torch::lazy::ExceptionCleanup(
      [locker = std::move(locker)](
          torch::lazy::ExceptionCleanup::StatusType status) {
        locker->Unlock(std::move(status));
      });
}

auto LazyGraphExecutor::DataCacheArena::Get() -> DataCacheArena* {
  static DataCacheArena* arena =
      new DataCacheArena(FLAGS_torch_lazy_device_data_cache_size);
  return arena;
}

LazyGraphExecutor::DataCacheArena::DataCacheArena(size_t max_cache_size)
    : max_cache_size_(max_cache_size) {}
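// Returns (and caches) the device handle for 'tensor' on 'device'. The cache
// is keyed by tensor content (see TensorHasher/TensorComparer below), so
// repeated uploads of identical host tensors reuse the same device data.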
BackendDataPtr LazyGraphExecutor::DataCacheArena::GetDeviceData(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  DataCacheArena::DataCache* cache = Get()->GetDataCache(device);
  BackendDataPtr device_data = cache->Get(tensor);
  if (device_data == nullptr) {
    at::Tensor tensor_copy = CopyTensor(tensor);
    device_data = TensorToDataHandle(tensor_copy, device);
    cache->Add(std::move(tensor_copy), device_data);
    TORCH_LAZY_COUNTER("DeviceDataCacheMiss", 1);
  }
  return device_data;
}

BackendDataPtr LazyGraphExecutor::DataCacheArena::GetDeviceData(
    const at::Scalar& value,
    at::ScalarType scalar_type,
    const BackendDevice& device) {
  // Workaround since at::scalar_tensor doesn't support bfloat16 yet.
  at::Tensor t = at::scalar_tensor(
      value,
      at::TensorOptions(
          scalar_type == at::ScalarType::BFloat16 ? at::ScalarType::Float
                                                  : scalar_type));
  if (scalar_type == at::ScalarType::BFloat16) {
    t = t.to(scalar_type);
  }
  return GetDeviceData(t, device);
}

size_t LazyGraphExecutor::DataCacheArena::TensorHasher::operator()(
    const at::Tensor& tensor) const {
  return HashReduce(
      HashCombine(GetEnumValue(tensor.scalar_type()), TensorHash(tensor)));
}

bool LazyGraphExecutor::DataCacheArena::TensorComparer::operator()(
    const at::Tensor& tensor1,
    const at::Tensor& tensor2) const {
  return TensorCompare(tensor1, tensor2);
}

auto LazyGraphExecutor::DataCacheArena::GetDataCache(
    const BackendDevice& device) -> DataCache* {
  std::lock_guard<std::mutex> lock(mutex_);
  if (FLAGS_torch_lazy_enable_device_data_cache) {
    auto it = device_caches_.find(device);
    if (it == device_caches_.end()) {
      it = device_caches_
               .emplace(device, std::make_unique<DataCache>(max_cache_size_))
               .first;
    }
    return it->second.get();
  } else {
    // If the cache is disabled, always return a zero-size cache.
    static DataCache s_empty_cache(0);
    return &s_empty_cache;
  }
}
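// Process-wide registry for the LazyGraphExecutor singleton. A backend is
// expected to call Register() once during its initialization, after which all
// lazy tensor code reaches the executor through Get(). A minimal sketch of
// such an init hook (BackendGraphExecutor is a hypothetical subclass, not
// part of this file) could look like:
//
//   LazyGraphExecutor::Register(new BackendGraphExecutor());
//   ...
//   LazyGraphExecutor::Get()->MarkStep(device);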
void LazyGraphExecutor::Register(LazyGraphExecutor* executor) {
  lazy_graph_executor_registry.store(executor);
}

LazyGraphExecutor* LazyGraphExecutor::Get() {
  auto* executor = lazy_graph_executor_registry.load();
  TORCH_CHECK(executor, "Lazy graph executor not registered.");
  return executor;
}

void LazyGraphExecutor::RegisterTensor(std::shared_ptr<LazyTensor::Data> data) {
  DeviceContextArena::Get()->RegisterTensor(data);
  TORCH_LAZY_COUNTER("CreateLtcTensor", 1);
}

void LazyGraphExecutor::UnregisterTensor(LazyTensor::Data* data) {
  DeviceContextArena::Get()->UnregisterTensor(data);
  TORCH_LAZY_COUNTER("DestroyLtcTensor", 1);
}

Value LazyGraphExecutor::GetRngSeed(const BackendDevice& device) {
  return DeviceContextArena::Get()->GetRngSeed(device);
}

uint64_t LazyGraphExecutor::GetRunningSeed(const BackendDevice& device) {
  return DeviceContextArena::Get()->GetRunningSeed(device);
}

void LazyGraphExecutor::SetRngSeed(const BackendDevice& device, uint64_t seed) {
  DeviceContextArena::Get()->SetRngSeed(device, seed);
}

void LazyGraphExecutor::DeviceBarrier(const BackendDevice& device) {
  DeviceLockerArena::Get()->DeviceBarrier(device);
}

BackendDataPtr LazyGraphExecutor::GetDeviceData(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  return DataCacheArena::Get()->GetDeviceData(tensor, device);
}

BackendDataPtr LazyGraphExecutor::GetDeviceData(
    const at::Scalar& value,
    at::ScalarType scalar_type,
    const BackendDevice& device) {
  return DataCacheArena::Get()->GetDeviceData(value, scalar_type, device);
}

std::vector<LazyTensorPtr> LazyGraphExecutor::GetLiveTensors(
    const BackendDevice* device) {
  return DeviceContextArena::Get()->GetLiveTensors(device);
}

void LazyGraphExecutor::SyncLiveTensorsGraph(
    const BackendDevice* device,
    c10::ArrayRef<std::string> devices,
    bool wait) {
  auto tensors = GetLiveTensors(device);
  VLOG(4) << tensors.size() << " live tensors: devices=("
          << c10::Join(", ", devices) << ")";
  SyncTensorsGraph(&tensors, devices, wait, /*sync_ltc_data=*/true);
}

void LazyGraphExecutor::SyncTensorsGraph(
    std::vector<LazyTensorPtr>* tensors,
    c10::ArrayRef<std::string> devices,
    bool wait,
    bool sync_ltc_data) {
  VLOG(4) << "Trying to sync the value of " << tensors->size() << " tensor(s)";
  SyncTensorsConfig config;
  config.sync_ltc_data = sync_ltc_data;

  auto async = SyncTensorsGraphInternal(tensors, devices, config);
  if (FLAGS_torch_lazy_use_thread_pool && wait && async != nullptr) {
    async->mwait.Wait();
  }
}
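// Marks the end of a lazy trace/step: bumps the MarkStep counter, reseeds the
// per-device RNG, resets IR scopes and the per-thread trim counter, and
// rewinds the TrieCache to its root so the next trace starts fresh.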
void LazyGraphExecutor::MarkStep(const BackendDevice& device) {
  TORCH_LAZY_COUNTER("MarkStep", 1);
  DeviceContextArena::Get()->MarkStep(device);
  ScopePusher::ResetScopes();
  ResetTrimCounter();
  // Move TrieCache's current pointer back to its root
  TrieCache::Get()->ResetCurrent();
}

void LazyGraphExecutor::WaitDeviceOps(c10::ArrayRef<BackendDevice> devices) {
  std::set<BackendDevice> wait_devices;
  if (!devices.empty()) {
    for (auto& device : devices) {
      wait_devices.insert(device);
    }
  } else {
    for (auto& device_str : DeviceContextArena::Get()->GetActiveDevices()) {
      // TODO: Remove the last use of Device(const std::string& device_spec).
      wait_devices.insert(BackendDevice(device_str));
    }
  }
  // The LockDevices() API returns a vector of ExceptionCleanup objects, which
  // are freed immediately, turning this operation into a lock barrier.
  // NOLINTNEXTLINE
  DeviceLockerArena::Get()->LockDevices(wait_devices);
}

std::vector<at::Tensor> LazyGraphExecutor::GetTensors(
    std::vector<LazyTensorPtr>* tensors) {
  VLOG(4) << "Trying to get the value of " << tensors->size() << " tensor(s)";
  return GetTensorsFused(tensors);
}

void LazyGraphExecutor::ResetTrimCounter() const {
  g_tls_data.Reset();
}

size_t LazyGraphExecutor::IncTrimCounter() const {
  return ++g_tls_data.trim_counter;
}

std::string LazyGraphExecutor::DumpBackendComputation(
    const std::vector<LazyTensorPtr>& tensors) {
  std::vector<Value> ir_values;
  for (auto& tensor : tensors) {
    Value ir_value = tensor->CurrentIrValue();
    if (ir_value) {
      ir_values.push_back(std::move(ir_value));
    }
  }
  return !ir_values.empty() ? DumpUtil::ToBackend(ir_values, BackendDevice())
                            : std::string();
}
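// Scalar-to-IR helpers. "Special" scalars (as decided by IsSpecialScalar(),
// typically small values such as 0 and 1) are embedded directly as constant
// scalar IR nodes; everything else is uploaded once as read-only device data
// and wrapped in a DeviceData IR node.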
Value LazyGraphExecutor::GetDeviceDataIrValue(
    const at::Scalar& value,
    c10::ScalarType type,
    const BackendDevice& device) {
  BackendDataPtr data = GetDeviceData(value, type, device);
  data->SetInfo(std::make_shared<DeviceDataInfo>(
      /*tensor_id=*/-1, /*read_only=*/true));
  return MakeDeviceData(std::move(data));
}

Value LazyGraphExecutor::GetIrValueForScalarFromCodegen(
    const at::Scalar& value,
    const BackendDevice& device) {
  if (IsSpecialScalar(value)) {
    return MakeScalar(value, value.type());
  }
  auto data = GetDeviceData(value, value.type(), device);
  data->SetInfo(
      std::make_shared<DeviceDataInfo>(/*tensor_id=*/-1, /*read_only=*/true));
  return MakeDeviceData(std::move(data));
}

Value LazyGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value,
    c10::ScalarType type,
    const BackendDevice& device) {
  if (IsSpecialScalar(value)) {
    return MakeScalar(value, type);
  }
  return GetDeviceDataIrValue(value, type, device);
}

Value LazyGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value,
    const BackendDevice& device) {
  return GetIrValueForScalar(value, value.type(), device);
}

Value LazyGraphExecutor::GetIrValueForExpandedScalar(
    const at::Scalar& value,
    const Shape& shape,
    const BackendDevice& device) {
  c10::ArrayRef<int64_t> dimensions = shape.sizes();
  auto type = shape.scalar_type();
  Value ir_value = GetIrValueForScalar(value, type, device);
  if (!dimensions.empty()) {
    ir_value = MakeExpand(
        ir_value,
        dimensions.vec(),
        /*is_scalar_expand=*/true);
  }
  return ir_value;
}

LazyGraphExecutor::Async::Async(
    SyncTensorCollection* coll,
    std::vector<BackendDataPtr> parameters_data,
    std::vector<BackendDataPtr> tensors_data,
    ComputationCache::TypePtr cached_computation)
    : mwait(1),
      indices(std::move(coll->indices)),
      unlocker(std::move(coll->unlocker)),
      parameters_data(std::move(parameters_data)),
      device(coll->device),
      cached_computation(std::move(cached_computation)),
      tensors_data(std::move(tensors_data)) {}

void LazyGraphExecutor::Async::Wait() {
  mwait.Wait();
  // Accessing other Async members is safe only after MultiWait::Wait()
  // completes.
  ExceptionCleanup::StatusType status;
  for (auto& cleanup : unlocker) {
    const ExceptionCleanup::StatusType& cleanup_status = cleanup.GetStatus();
    if (cleanup_status != nullptr) {
      if (status == nullptr) {
        status = cleanup_status;
      }
      // If we observe the status here, no need to let it propagate to the next
      // device lock operation.
      cleanup.SetStatus(nullptr);
    }
  }
  if (status != nullptr) {
    std::rethrow_exception(status);
  }
}

bool LazyGraphExecutor::ShouldSyncTensor(const LazyTensorPtr& tensor) const {
  return tensor->GetIrValue()->op() != ltc_not_supported;
}
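// Scans 'tensors' and builds the SyncTensorCollection for this sync: the
// unique device, the indices of tensors whose IR actually needs to be synced,
// and a running hash over those IR values. Tensors that only hold at::Tensor
// data (no IR, no device handle) are uploaded to the device here when
// force_ltc_data is set.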
LazyGraphExecutor::SyncTensorCollection LazyGraphExecutor::CollectSyncTensors(
    const std::vector<LazyTensorPtr>& tensors,
    const SyncTensorsConfig& config) {
  Unique<BackendDevice> unique_device;
  for (const auto& tensor : tensors) {
    unique_device.set(tensor->GetDevice());
  }
  SyncTensorCollection coll;
  if (!unique_device) {
    return coll;
  }
  if (!config.force_ltc_data && !TensorsHaveIR(tensors)) {
    return coll;
  }

  std::vector<at::Tensor> at_tensors;
  std::vector<BackendDevice> devices;
  std::vector<size_t> at_tensor_index;
  std::unordered_set<int64_t> tensor_ids;
  // The force_ltc_data flag controls aliasing compilation, so effectively the
  // same graph with force_ltc_data on or off should not match, hash wise.
  coll.hash = MHash(config.force_ltc_data);
  coll.config = config;
  coll.device = *unique_device;
  coll.indices.reserve(tensors.size());

  for (const auto i : c10::irange(tensors.size())) {
    if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
        tensors[i]->CurrentDataHandle() == nullptr) {
      Value ir_value = tensors[i]->CurrentIrValue();
      if (ir_value) {
        if (ShouldSyncTensor(tensors[i])) {
          TORCH_LAZY_COUNTER("SyncedTensorsWithIR", 1);
          // Add only tensors which need to be synced.
          coll.hash = HashCombine(coll.hash, ir_value.hash());
          coll.indices.push_back(i);
        }
      } else if (config.force_ltc_data) {
        // The tensor only has at::Tensor data. We need to queue it for a
        // device upload.
        std::optional<at::Tensor> tensor_data = tensors[i]->CurrentTensorData();
        TORCH_CHECK(tensor_data);
        at_tensors.push_back(*tensor_data);
        devices.push_back(tensors[i]->GetDevice());
        at_tensor_index.push_back(i);
      }
    }
  }
  if (!at_tensors.empty()) {
    TORCH_LAZY_COUNTER("SyncTensorsToData", at_tensors.size());
    std::vector<BackendDataPtr> handles =
        CreateTensorsData(at_tensors, devices);
    for (const auto i : c10::irange(handles.size())) {
      // If we are here, it means that the IR Value for the tensor is not
      // present. Also, we uploaded the at::Tensor data to the device, but such
      // data is still valid so we leave it live on the lazy tensor (so that a
      // following ToTensor() does not need to fetch it from device).
      tensors[at_tensor_index[i]]->data()->handle = std::move(handles[i]);
    }
  }
  VLOG(4) << "Tensors graph hash " << HashToString(coll.hash) << " on device "
          << coll.device;
  return coll;
}

std::vector<Value> LazyGraphExecutor::CollectRoots(
    const std::vector<LazyTensorPtr>& tensors,
    c10::ArrayRef<size_t> indices) {
  std::vector<Value> roots;
  roots.reserve(indices.size());
  for (auto index : indices) {
    roots.push_back(tensors.at(index)->CurrentIrValue());
  }
  return roots;
}

void LazyGraphExecutor::ExtractIRAndPrepareTensorData(
    std::vector<LazyTensorPtr>* tensors,
    const SyncTensorsConfig& config,
    c10::ArrayRef<size_t> indices,
    std::vector<Value>& ir_values,
    std::vector<BackendDataPtr>& tensor_data_vec) {
  ir_values.reserve(indices.size());
  tensor_data_vec.reserve(indices.size());
  for (auto index : indices) {
    LazyTensorPtr& tensor = (*tensors)[index];
    Value ir_value = tensor->CurrentIrValue();
    ir_values.push_back(ir_value);
    const BackendDevice& tensor_device = tensor->GetDevice();
    BackendDataPtr handle = getBackend()->CreateDataPlaceholder(
        tensor_device, std::move(tensor->shape()));
    tensor_data_vec.push_back(handle);
    if (tensor->CurrentDataHandle() == nullptr && config.sync_ltc_data) {
      tensor->AssignIrValue(Value());
    }
  }
}

std::vector<torch::lazy::BackendDataPtr> LazyGraphExecutor::SetTensorData(
    std::vector<LazyTensorPtr>* tensors,
    const SyncTensorsConfig& config,
    c10::ArrayRef<size_t> indices,
    const std::vector<BackendDataPtr>& tensor_data_vec) {
  std::vector<BackendDataPtr> tensors_data;
  tensors_data.reserve(indices.size());
  for (const auto i : c10::irange(indices.size())) {
    auto index = indices[i];
    LazyTensorPtr& tensor = (*tensors)[index];
    // If the config.force_ltc_data flag is true, the purpose of this tensor
    // sync operation is to truncate the IR graph and materialize device data
    // in place of the IR graph, on selected tensors. But since the operation
    // will complete asynchronously, if a tensor does not already have device
    // data, we need to install a placeholder. Since at this point we hold a
    // lock on the device where the tensors reside (locks held within the coll
    // structure, and moved into the async variable), any other operation
    // trying to access the tensor's device data will have to wait until the
    // asynchronous operation completes.
    BackendDataPtr handle = tensor->CurrentDataHandle();
    if (handle == nullptr && config.force_ltc_data) {
      handle = tensor_data_vec[i];
      // Note: We are not using the SetHandleData method here since that method
      // resets the ir_value. We have already done the resetting as part of
      // ExtractIRAndPrepareTensorData to overlap with the previous execution.
      tensor->data()->handle = handle;
      tensor->data()->tensor_data = std::nullopt;
    }
    tensors_data.emplace_back(std::move(handle));
  }
  return tensors_data;
}
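// Walks the IR graphs rooted at 'ir_values' in post-order and collects every
// device-data node as a computation parameter, de-duplicating by backend data
// handle so a tensor referenced multiple times becomes a single parameter.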
LazyGraphExecutor::PostOrderData LazyGraphExecutor::RunPostOrder(
    const std::vector<Value>& ir_values,
    SyncTensorCollection* coll) {
  std::vector<const Node*> roots;
  roots.reserve(ir_values.size());
  for (const auto& ir_value : ir_values) {
    roots.push_back(ir_value.node.get());
  }
  PostOrderData po_data;
  po_data.post_order = Util::ComputePostOrder(roots, &po_data.emission_map);
  std::unordered_map<BackendData::Handle, size_t> data_handles;
  for (auto node : po_data.post_order) {
    const auto backend_data = getBackend()->GetComputationDataFromNode(node);
    if (backend_data) {
      /* Acceptable race condition: HasValue may return false. This is OK
       * since the conditional barrier is a performance optimization. */
      if (!backend_data->HasValue()) {
        TensorCollectionBarrier(coll);
      }
      BackendData::Handle handle = backend_data->GetHandle();
      auto it = data_handles.find(handle);
      if (it != data_handles.end()) {
        po_data.parameter_sequence.push_back(it->second);
      } else {
        po_data.parameter_sequence.push_back(po_data.parameters_data.size());
        data_handles[handle] = po_data.parameters_data.size();
        po_data.parameters_data.push_back(backend_data);
      }
    }
  }
  return po_data;
}

std::shared_ptr<LazyGraphExecutor::Async> LazyGraphExecutor::TryRunCachedSync(
    std::vector<LazyTensorPtr>* tensors,
    SyncTensorCollection* coll,
    PostOrderData* po_data,
    const std::vector<BackendDataPtr>& tensor_data_vec) {
  ComputationCache::TypePtr cached_computation =
      LookupCachedCompile(coll->hash);
  if (cached_computation == nullptr) {
    return nullptr;
  }
  if (GRAPH_DUMP_ENABLED) {
    auto* comp = cached_computation->computation.get();
    LOG(ERROR) << "Run a cached graph: " << comp->to_string() << std::endl;
  }
  TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", po_data->post_order.size());
  VLOG(5) << "TensorsGraphSize=" << po_data->post_order.size();

  return ScheduleSyncTensorsGraph(
      tensors,
      coll,
      std::move(po_data->parameters_data),
      std::move(cached_computation),
      tensor_data_vec);
}
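// Lowers the post-order nodes into a backend computation via a
// LoweringContext, registers the root IR values as outputs, and invokes the
// backend compiler on the result.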
LazyGraphExecutor::CompilationResult LazyGraphExecutor::Compile(
    const std::vector<LazyTensorPtr>& tensors,
    c10::ArrayRef<std::string> devices,
    const SyncTensorCollection& coll,
    PostOrderData* po_data,
    const std::vector<Value>& ir_values) {
  auto lowering_ctx = LoweringContext::Create(
      "SyncTensorsGraph",
      coll.device,
      po_data->post_order,
      std::move(po_data->emission_map));
  for (const auto& ir_value : ir_values) {
    lowering_ctx->AddResult(ir_value);
  }

  ComputationPtr computation = lowering_ctx->Build();
  // If force_ltc_data is true it means that we did a proper sync and are
  // inside a mark step. If GetTensors was called, force_ltc_data will
  // be false, meaning we are prematurely evaluating some value.
  computation->in_mark_step = coll.config.force_ltc_data;

  VLOG(3) << "Compiling IR graph hash " << HashToString(coll.hash)
          << " on device " << coll.device << " ...";
  std::vector<ComputationPtr> computations =
      getBackend()->Compile({computation});
  VLOG(3) << "Compiling IR graph hash " << HashToString(coll.hash)
          << " on device " << coll.device << " done!";
  if (computation) {
    // TODO(whc) should computation be allowed null here? (because it is in one
    // case)
    TORCH_CHECK(
        computation->parameters_size() ==
        static_cast<int>(po_data->parameters_data.size()));
  }

  return {
      /*device=*/coll.device,
      /*emitted_nodes=*/lowering_ctx->GetEmittedNodeCount(),
      /*computation=*/std::move(computations.front()),
      /*parameters_data=*/std::move(po_data->parameters_data)};
}

LazyGraphExecutor::ComputationCache* LazyGraphExecutor::GetComputationCache() {
  static ComputationCache* cache =
      new ComputationCache(FLAGS_torch_lazy_compilation_cache_size);
  return cache;
}

LazyGraphExecutor::ComputationCache::TypePtr LazyGraphExecutor::
    LookupCachedCompile(const hash_t& hash) {
  ComputationCache::TypePtr cached_computation =
      GetComputationCache()->Get(hash);
  if (cached_computation == nullptr) {
    TORCH_LAZY_COUNTER("UncachedCompile", 1);
    return nullptr;
  }
  TORCH_LAZY_COUNTER("CachedCompile", 1);
  return cached_computation;
}

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif
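// Main sync entry point: collect the tensors to sync, extract their IR and
// pre-allocate data placeholders, compute the post-order/parameter hash, and
// either replay a cached computation or compile a new one before scheduling
// the (possibly asynchronous) execution.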
std::shared_ptr<LazyGraphExecutor::Async> LazyGraphExecutor::
    SyncTensorsGraphInternal(
        std::vector<LazyTensorPtr>* tensors,
        c10::ArrayRef<std::string> devices,
        const SyncTensorsConfig& config) {
  SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
  if (coll.indices.empty()) {
    /* Ensure previous execution is complete before exiting this
     * function */
    TensorCollectionBarrier(&coll);
    return nullptr;
  }
  DebugUtil::SaveTensorsGraphInfo(
      "ScheduleSyncTensorsGraph", *tensors, &coll.indices);
  std::vector<Value> ir_values;
  std::vector<BackendDataPtr> tensor_data_vec;
  ExtractIRAndPrepareTensorData(
      tensors, coll.config, coll.indices, ir_values, tensor_data_vec);
  PostOrderData po_data = RunPostOrder(ir_values, &coll);
  coll.hash = HashCombine(coll.hash, Hash(po_data.parameter_sequence));
  VLOG(4) << "Parameter sequence graph hash " << HashToString(coll.hash);
  std::shared_ptr<Async> async =
      TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec);
  if (async != nullptr) {
    return async;
  }

  CompilationResult compile_result =
      Compile(*tensors, devices, coll, &po_data, ir_values);
  if (GRAPH_DUMP_ENABLED) {
    auto* comp = compile_result.computation.get();
    LOG(ERROR) << "Add a cached computation with hash " << coll.hash
               << std::endl;
    LOG(ERROR) << "Add a graph to cache: " << comp->to_string() << std::endl;
  }

  TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
  VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;

  auto cached_computation = std::make_shared<CachedComputation>(
      std::move(compile_result.computation));
  GetComputationCache()->Add(coll.hash, cached_computation);

  return ScheduleSyncTensorsGraph(
      tensors,
      &coll,
      std::move(compile_result.parameters_data),
      std::move(cached_computation),
      tensor_data_vec);
}
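// Takes the device barrier, wraps the parameters, output placeholders and
// cached computation into an Async object, and runs the execution closure
// either inline or on the IO thread pool (FLAGS_torch_lazy_use_thread_pool).
// The Async object carries the device-lock cleanups, so other work targeting
// the same device waits for this execution.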
std::shared_ptr<LazyGraphExecutor::Async> LazyGraphExecutor::
    ScheduleSyncTensorsGraph(
        SyncTensorCollection* coll,
        std::vector<BackendDataPtr> parameters_data,
        std::vector<BackendDataPtr> tensors_data,
        ComputationCache::TypePtr cached_computation) {
  TensorCollectionBarrier(coll);
  std::shared_ptr<Async> async = std::make_shared<Async>(
      coll,
      std::move(parameters_data),
      std::move(tensors_data),
      std::move(cached_computation));

  auto syncfn = [async, hash = coll->hash]() {
    try {
      VLOG(3) << "Executing IR graph hash " << HashToString(hash)
              << " on device " << async->device << " ...";
      auto results = getBackend()->ExecuteComputation(
          async->cached_computation->computation,
          async->parameters_data,
          async->device);
      VLOG(3) << "Executing IR graph hash " << HashToString(hash)
              << " on device " << async->device << " done!";

      TORCH_CHECK(
          async->tensors_data.size() == results.size(),
          "Expected number of outputs does not match TorchScript Stack size: ",
          async->tensors_data.size(),
          " != ",
          results.size());

      for (const auto i : c10::irange(results.size())) {
        if (async->tensors_data[i] != nullptr) {
          async->tensors_data[i]->Assign(*results[i]);
        } else {
          async->tensors_data[i] = std::move(results[i]);
        }
      }
    } catch (...) {
      // There are two paths of discovery of an exception happening on an
      // asynchronous task. One happens if the creator of the asynchronous task
      // explicitly waits for completion, in which case the exception will be
      // thrown from the Wait() API. Re-throwing the exception below makes sure
      // this will be captured by the completer function created below, and
      // surfaced by the Wait() API. But we also need to surface the exception
      // even in case the caller does not wait, and that is accomplished by
      // setting the unlockers status. In that case the exception will be
      // surfaced when the user tries to acquire the device locks the next time.
      for (auto& unlocker : async->unlocker) {
        unlocker.SetStatus(std::current_exception());
      }
      throw;
    }
  };

  if (FLAGS_torch_lazy_use_thread_pool) {
    ScheduleIoClosure(async->mwait.Completer(std::move(syncfn)));
  } else {
    syncfn();
  }
  return async;
}

std::shared_ptr<LazyGraphExecutor::Async> LazyGraphExecutor::
    ScheduleSyncTensorsGraph(
        std::vector<LazyTensorPtr>* tensors,
        SyncTensorCollection* coll,
        std::vector<BackendDataPtr> parameters_data,
        ComputationCache::TypePtr cached_computation,
        const std::vector<BackendDataPtr>& tensor_data_vec) {
  auto tensors_data =
      SetTensorData(tensors, coll->config, coll->indices, tensor_data_vec);
  return ScheduleSyncTensorsGraph(
      coll,
      std::move(parameters_data),
      std::move(tensors_data),
      std::move(cached_computation));
}

std::vector<at::Tensor> LazyGraphExecutor::GetTensorsFused(
    std::vector<LazyTensorPtr>* tensors) {
  SyncTensorsConfig config;
  config.force_ltc_data = false;
  auto async = SyncTensorsGraphInternal(tensors, {}, config);
  if (FLAGS_torch_lazy_use_thread_pool && async != nullptr) {
    async->mwait.Wait();
  }
  std::vector<BackendDataPtr> tensors_data = GatherTensorsData(
      *tensors,
      async != nullptr ? async->indices : c10::ArrayRef<size_t>(),
      async != nullptr ? async->tensors_data : c10::ArrayRef<BackendDataPtr>());
  return FetchTensors(
      tensors, tensors_data, async != nullptr ? &async->indices : nullptr);
}

// This gets tensors from the backend. For the TS backend, we would ideally
// just cut through these layers and not need to copy the tensor, just move it.
// For the XLA backend, a copy is going to have to happen.
//
// Could we replace the 'Data' object with an at::Tensor, which is 'undefined'
// unless a backend attaches a buffer to it? That way we could have a
// 'PopulateTensor' method on the backend, which can either attach an existing
// tensor buffer to the wrapper, or copy data.
std::vector<at::Tensor> LazyGraphExecutor::FetchTensors(
    std::vector<LazyTensorPtr>* tensors,
    c10::ArrayRef<BackendDataPtr> tensors_data,
    const std::vector<size_t>* indices) {
  std::vector<at::Tensor> results;
  size_t literals_index = 0;
  size_t sync_index = 0;
  results.reserve(tensors->size());
  for (const auto i : c10::irange(tensors->size())) {
    if (indices != nullptr && sync_index < indices->size() &&
        i == (*indices)[sync_index]) {
      results.push_back(getBackend()->MakeTensorFromComputationData(
          tensors_data[literals_index], (*tensors)[i]->dtype()));
      ++literals_index;
      ++sync_index;
    } else {
      std::optional<at::Tensor> tensor_data =
          (*tensors)[i]->CurrentTensorData();
      if (tensor_data) {
        results.push_back(*tensor_data);
      } else {
        TORCH_CHECK(literals_index < tensors_data.size());
        results.push_back(getBackend()->MakeTensorFromComputationData(
            tensors_data[literals_index], (*tensors)[i]->dtype()));
        ++literals_index;
      }
    }
  }
  return results;
}

std::vector<BackendDataPtr> LazyGraphExecutor::GatherTensorsData(
    const std::vector<LazyTensorPtr>& tensors,
    c10::ArrayRef<size_t> indices,
    c10::ArrayRef<BackendDataPtr> tensors_data) {
  std::vector<BackendDataPtr> result_tensors_data;
  std::unordered_map<int64_t, size_t> uid_index_map;
  size_t indices_index = 0;
  for (const auto i : c10::irange(tensors.size())) {
    int64_t tensor_id = tensors[i]->GetUniqueId();
    auto it = uid_index_map.find(tensor_id);
    if (it != uid_index_map.end()) {
      // Current tensor is a duplicate of a previously processed tensor that
      // had an IR Node to sync. Get the data from the tensor_data_map.
      result_tensors_data.push_back(result_tensors_data[it->second]);
    } else if (indices_index < indices.size() && i == indices[indices_index]) {
      // If we are at the current index (it means that the tensor at index
      // 'i' had an IR node to sync), use the data held within the Async
      // object.
      uid_index_map.emplace(tensor_id, result_tensors_data.size());
      result_tensors_data.push_back(tensors_data[indices_index]);
      ++indices_index;
    } else if (!tensors[i]->CurrentTensorData()) {
      BackendDataPtr handle = tensors[i]->CurrentDataHandle();
      TORCH_CHECK(handle != nullptr);
      result_tensors_data.push_back(std::move(handle));
    }
  }
  return result_tensors_data;
}

void LazyGraphExecutor::TensorCollectionBarrier(SyncTensorCollection* coll) {
  if (coll) {
    static const std::string invalid_device(
        "Unknown0"); /* Temporary solution to identify unassigned devices */
    if (coll->device.toString() == invalid_device || !coll->unlocker.empty()) {
      return;
    }
    VLOG(4) << "Waiting on device barrier for device " << coll->device
            << " ...";
    {
      TORCH_LAZY_TIMED("DeviceLockWait");
      coll->unlocker = DeviceLockerArena::Get()->LockDevices({coll->device});
    }
    VLOG(4) << "Waiting on device barrier for device " << coll->device
            << " done!";
  }
}
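// Computes the hash of the graph that a sync of 'tensors' would build, using
// the same collection and post-order/parameter hashing as a real sync, but
// without compiling or executing anything.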
hash_t LazyGraphExecutor::GetGraphHash(
    const std::vector<LazyTensorPtr>& tensors) {
  SyncTensorsConfig config;
  config.sync_ltc_data = false;

  auto coll = CollectSyncTensors(tensors, config);
  std::vector<Value> ir_values;
  for (auto index : coll.indices) {
    Value ir_value = tensors[index]->CurrentIrValue();
    ir_values.push_back(ir_value);
  }
  auto po_data = RunPostOrder(ir_values, &coll);
  coll.hash = HashCombine(coll.hash, Hash(po_data.parameter_sequence));
  return coll.hash;
}

} // namespace lazy
} // namespace torch