# mypy: allow-untyped-defs

import torch._C._lazy
from torch.utils._pytree import tree_flatten, tree_unflatten

from .closure import add_step_closure, run_step_closures


def mark_step(device: str = "", wait=False):
10    """Triggers a mark step, which amounts to
11    - collecting a group of 'live' lazy tensors to index into the compilation cache
12      (lowering/compiling their IR graphs if not cached)
13    - kicking off execution of the compiled function
14    - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator)
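
    Illustrative example (a sketch; assumes a lazy backend has been initialized,
    e.g. via ``torch._lazy.ts_backend.init()``)::

        >>> x = torch.randn(2, 2, device="lazy")
        >>> y = x + x                # only records IR; nothing executes yet
        >>> torch._lazy.mark_step()  # compile (or hit the cache) and launch the graph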
15    """
16    # TODO(whc) expand this to include backend hooks and align with XLA backend needs
17    torch._C._lazy._mark_step(device, [], wait=wait)
18
19    run_step_closures()
20
21
22def wait_device_ops(devices=None):
23    """Waits for all the async operations on the given devices to complete.
24    Args:
25      devices (string..., optional): The devices whose async ops need to be waited
26        for. If empty, all the local devices will be waited for.
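
    Illustrative example (a sketch; assumes a lazy backend has been initialized)::

        >>> torch._lazy.mark_step()        # kick off async execution
        >>> torch._lazy.wait_device_ops()  # block until all local devices are idle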
27    """
28    if devices is None:
29        devices = []
30    torch._C._lazy._wait_device_ops(devices=devices)
31
32
33def sync_multi(tensors, devices):
34    """
35    Sync the list of lazy tensors so there IR get lowered for the activate backend
36    and the compiled computation graph get cached.
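
    Illustrative example (a sketch; ``"lazy"`` is the device string used elsewhere
    in this module, e.g. by ``to_cpu``)::

        >>> x = torch.randn(4, device="lazy")
        >>> y = x * 2
        >>> sync_multi([y], ["lazy"])  # lower/compile y's graph and cache the result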
37    """
38    torch._C._lazy._sync_multi(tensors, devices)
39
40
41def get_tensor_id(tensor):
42    """Return a unique id of the lazy tensor maintained by LTC"""
    return torch._C._lazy._get_tensor_id(tensor)


def to_cpu(tensors, devices=None):
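    """Sync a (possibly nested) structure of lazy tensors and return the same
    structure with every tensor copied to the CPU.

    Illustrative example (a sketch; assumes a lazy backend has been initialized)::

        >>> out = {"loss": loss, "logits": logits}  # lazy tensors
        >>> cpu_out = to_cpu(out)
        >>> cpu_out["loss"].device
        device(type='cpu')
    """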
    devices = devices or ["lazy"]

    flattened, spec = tree_flatten(tensors)
    sync_multi(flattened, devices)
    return tree_unflatten([t.to("cpu") for t in flattened], spec)


def save(tensors, *args, **kwargs):
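    """Move ``tensors`` to the CPU (syncing any pending lazy computation) and
    delegate to ``torch.save``; extra positional and keyword arguments are passed
    through to ``torch.save``.

    Illustrative example (a sketch)::

        >>> save({"w": w}, "checkpoint.pt")  # w is a lazy tensor
    """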
    torch.save(to_cpu(tensors), *args, **kwargs)
56