# mypy: allow-untyped-defs
# pyre-strict
from typing import Union

import torch


class ProxyValue:
    """Pair a value with the ``torch.fx`` Proxy or Node that stands for it.

    ``data`` is an arbitrary value; it may be a ``torch.Tensor`` (see
    ``is_tensor``/``to_tensor``) or something else. ``proxy_or_node`` is
    either a ``torch.fx.Proxy`` or a bare ``torch.fx.Node``. Iteration and
    truthiness are forwarded to ``data``.
    """

    # pyre-ignore
    def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]):
        """Store ``data`` together with its associated Proxy or Node.

        Args:
            data: The value to wrap. It is not type-checked here.
            proxy: A ``torch.fx.Proxy`` or ``torch.fx.Node``. It is stored
                unchanged in ``proxy_or_node``; no wrapping or unwrapping is
                done at construction time.
        """
        # pyre-ignore
        self.data = data
        self.proxy_or_node = proxy

    @property
    def node(self) -> torch.fx.Node:
        """Return the underlying ``torch.fx.Node``.

        This works for both storage forms. A stored Node is returned directly,
        and a stored Proxy is unwrapped through its ``.node`` attribute.
        """
        if isinstance(self.proxy_or_node, torch.fx.Node):
            return self.proxy_or_node
        # Per the constructor's type contract, anything that is not a Node
        # must be a Proxy.
        assert isinstance(self.proxy_or_node, torch.fx.Proxy)
        return self.proxy_or_node.node

    @property
    def proxy(self) -> torch.fx.Proxy:
        """Return the stored ``torch.fx.Proxy``.

        Raises:
            RuntimeError: If only a bare ``torch.fx.Node`` was stored. This
                accessor does not wrap a Node into a Proxy.
        """
        if not isinstance(self.proxy_or_node, torch.fx.Proxy):
            raise RuntimeError(
                f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}"
            )
        return self.proxy_or_node

    def to_tensor(self) -> torch.Tensor:
        """Return ``data`` after asserting that it is a ``torch.Tensor``.

        Raises:
            AssertionError: If ``data`` is not a tensor. The check only runs
                when assertions are enabled, i.e. not under ``python -O``.
        """
        assert isinstance(self.data, torch.Tensor)
        return self.data

    def is_tensor(self) -> bool:
        """Return True if the wrapped ``data`` is a ``torch.Tensor``."""
        return isinstance(self.data, torch.Tensor)

    # pyre-ignore
    def __iter__(self):
        """Iterate directly over ``data``, e.g. to unpack a list/tuple value."""
        yield from self.data

    def __bool__(self) -> bool:
        """Delegate truthiness to ``bool(data)``.

        Any error ``bool(data)`` raises propagates unchanged. For example, an
        ambiguous multi-element tensor raises just as it would unwrapped.
        """
        return bool(self.data)