# mypy: allow-untyped-defs

import torch._C._lazy
from torch.utils._pytree import tree_flatten, tree_unflatten

# add_step_closure is re-exported here as part of the torch._lazy API surface.
from .closure import add_step_closure, run_step_closures


def mark_step(device: str = "", wait=False):
    """Triggers a mark step, which amounts to
    - collecting a group of 'live' lazy tensors to index into the compilation cache
      (lowering/compiling their IR graphs if not cached)
    - kicking off execution of the compiled function
    - (optionally, wait=True) waiting for CPU-side execution to complete (does not sync the accelerator)
    """
    # TODO(whc) expand this to include backend hooks and align with XLA backend needs
    torch._C._lazy._mark_step(device, [], wait=wait)

    run_step_closures()


def wait_device_ops(devices=None):
    """Waits for all the async operations on the given devices to complete.

    Args:
        devices (string..., optional): The devices whose async ops need to be waited
            for. If empty, all the local devices will be waited for.
    """
    if devices is None:
        devices = []
    torch._C._lazy._wait_device_ops(devices=devices)


def sync_multi(tensors, devices):
    """
    Sync the list of lazy tensors so their IR gets lowered for the active backend
    and the compiled computation graph gets cached.
    """
    torch._C._lazy._sync_multi(tensors, devices)


def get_tensor_id(tensor):
    """Return the unique id maintained by LTC for the given lazy tensor."""
    return torch._C._lazy._get_tensor_id(tensor)


def to_cpu(tensors, devices=None):
    """Sync the given (possibly nested) lazy tensors and return CPU copies
    with the same pytree structure.
    """
    devices = devices or ["lazy"]

    flattened, spec = tree_flatten(tensors)
    sync_multi(flattened, devices)
    return tree_unflatten([t.to("cpu") for t in flattened], spec)


def save(tensors, *args, **kwargs):
    """Like torch.save, but moves lazy tensors to CPU before serializing."""
    torch.save(to_cpu(tensors), *args, **kwargs)
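

# Usage sketch (illustrative only, not executed): how client code typically
# drives the helpers in this module. It assumes the TorchScript lazy backend
# in torch._lazy.ts_backend is available in this build and that "lazy" is a
# registered device string; other LTC backends may need different setup.
#
#     import torch
#     import torch._lazy
#     import torch._lazy.ts_backend
#
#     torch._lazy.ts_backend.init()          # one-time backend registration
#
#     x = torch.randn(2, 2, device="lazy")   # ops are traced into an IR graph
#     y = x + x                              # still lazy; nothing has executed
#
#     torch._lazy.mark_step()                # compile the graph and launch it
#     torch._lazy.wait_device_ops()          # block until async work finishes
#
#     print(torch._lazy.to_cpu(y))           # sync and materialize on the CPU
#     torch._lazy.save(y, "y.pt")            # torch.save with a CPU copy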