xref: /aosp_15_r20/external/pytorch/torch/_lazy/computation.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch._C._lazy
3import torch._C._lazy_ts_backend
4
5
def get_tensors_ts_device_data_node(tensors):
    """Collect the DeviceData nodes found in the IR of ``tensors``.

    Delegates to the TorchScript lazy-backend binding
    ``torch._C._lazy_ts_backend._get_tensors_ts_device_data_node`` and
    hands back its result untouched: the tensor ids and the eager tensors
    for the DeviceData nodes in the IR of the given lazy tensors.

    Args:
        tensors: The lazy tensors whose IR is inspected.

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    ts_backend = torch._C._lazy_ts_backend
    ids_and_data = ts_backend._get_tensors_ts_device_data_node(tensors)
    return ids_and_data
14
15
def get_graph_hash(tensors):
    """Return the graph hash for the passed-in lazy tensors.

    Thin wrapper over ``torch._C._lazy._get_graph_hash``; note that it goes
    through the ``torch._C._lazy`` binding rather than
    ``torch._C._lazy_ts_backend``.

    Args:
        tensors: The lazy tensors whose graph is hashed.
    """
    lazy_bindings = torch._C._lazy
    graph_hash = lazy_bindings._get_graph_hash(tensors)
    return graph_hash
19
20
def run_cached_graph(hash_str, graph_inputs):
    """Run the cached computation graph with the given inputs.

    Args:
        hash_str: Identifies which cached graph to run (presumably the graph
            hash it was cached under -- confirm against callers).
        graph_inputs: Inputs passed to the cached graph.

    Returns:
        Whatever ``torch._C._lazy_ts_backend._run_cached_graph`` returns,
        unmodified.

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    ts_backend = torch._C._lazy_ts_backend
    result = ts_backend._run_cached_graph(hash_str, graph_inputs)
    return result
28