xref: /aosp_15_r20/external/pytorch/torch/fx/annotate.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from torch.fx.proxy import Proxy
3from ._compatibility import compatibility
4
@compatibility(is_backward_compatible=False)
def annotate(val, type):
    """
    Annotates a Proxy object with a given type.

    If ``val`` is a ``torch.fx.Proxy``, ``type`` is recorded on the proxy's
    underlying graph node as ``val.node.type``. Any other value is returned
    unchanged without being modified.

    Args:
        val (object): The value to annotate. Only ``torch.fx.Proxy``
            instances are modified.
        type (object): The type to record on ``val.node.type``.
    Returns:
        The given val (the same object that was passed in).
    Raises:
        RuntimeError: If ``val`` is a Proxy whose node already has a type,
            i.e. ``val.node.type`` is not ``None``.
    """
    # Non-proxy values are passed straight through.
    if not isinstance(val, Proxy):
        return val

    # Compare against None instead of relying on truthiness: an annotation
    # object may define __bool__/__len__ and evaluate falsy while still being
    # a real type that must not be silently overwritten by a second call.
    if val.node.type is not None:
        raise RuntimeError(f"Tried to annotate a value that already had a type on it!"
                           f" Existing type is {val.node.type} "
                           f"and new type is {type}. "
                           f"This could happen if you tried to annotate a function parameter "
                           f"value (in which case you should use the type slot "
                           f"on the function signature) or you called "
                           f"annotate on the same value twice")

    val.node.type = type
    return val
32        return val
33