# mypy: allow-untyped-defs
from torch.fx.proxy import Proxy
from ._compatibility import compatibility


@compatibility(is_backward_compatible=False)
def annotate(val, type):
    """
    Attach a type to the node behind a ``torch.fx.Proxy``.

    Values that are not ``Proxy`` instances are passed through untouched.
    For a ``Proxy``, the given type is stored on ``val.node.type``, provided
    no (truthy) type has been recorded there already.

    Args:
        val (object): The value to annotate. Only ``torch.fx.Proxy``
            instances are modified.
        type (object): The type to record on the proxy's node.

    Returns:
        The same ``val`` object that was passed in.

    Raises:
        RuntimeError: If ``val`` is a ``Proxy`` whose node already carries
            a type.
    """
    # Anything that is not a Proxy is handed back as-is.
    if not isinstance(val, Proxy):
        return val

    node = val.node
    # Refuse to overwrite an existing annotation: a second annotation would
    # silently replace the first one.
    if node.type:
        raise RuntimeError(
            "Tried to annotate a value that already had a type on it!"
            f" Existing type is {node.type} "
            f"and new type is {type}. "
            "This could happen if you tried to annotate a function parameter "
            "value (in which case you should use the type slot "
            "on the function signature) or you called "
            "annotate on the same value twice"
        )

    node.type = type
    return val