xref: /aosp_15_r20/external/pytorch/torch/_export/pass_infra/proxy_value.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# pyre-strict
3from typing import Union
4
5import torch
6
7
class ProxyValue:
    """Pair a value with the FX graph entity that produced it.

    ``data`` holds the value (for example a ``torch.Tensor`` or a container of
    values) and ``proxy_or_node`` holds either a ``torch.fx.Proxy`` or a bare
    ``torch.fx.Node``. Iteration and truthiness are forwarded to ``data``.
    """

    # pyre-ignore
    def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]):
        # pyre-ignore
        self.data = data
        # Either a Proxy or just a Node; see the ``node``/``proxy`` accessors.
        self.proxy_or_node = proxy

    @property
    def node(self) -> torch.fx.Node:
        """Return the underlying ``torch.fx.Node``.

        Works whether a Proxy (its ``.node`` is returned) or a Node was attached.

        Raises:
            AssertionError: if the attached object is neither a Proxy nor a Node.
        """
        target = self.proxy_or_node
        if isinstance(target, torch.fx.Node):
            return target
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # the exception type is unchanged for existing callers.
        if not isinstance(target, torch.fx.Proxy):
            raise AssertionError(
                f"ProxyValue expected a torch.fx.Proxy or torch.fx.Node, got {type(target)}"
            )
        return target.node

    @property
    def proxy(self) -> torch.fx.Proxy:
        """Return the attached ``torch.fx.Proxy``.

        Raises:
            RuntimeError: if the attached object is not a Proxy (e.g. only a Node).
        """
        target = self.proxy_or_node
        if isinstance(target, torch.fx.Proxy):
            return target
        # Only Nodes have ``format_node``; fall back to ``repr`` for anything else
        # so the intended RuntimeError is raised rather than an AttributeError.
        if isinstance(target, torch.fx.Node):
            description = target.format_node()
        else:
            description = repr(target)
        raise RuntimeError(
            f"ProxyValue doesn't have attached Proxy object. Node: {description}"
        )

    def to_tensor(self) -> torch.Tensor:
        """Return ``data``, which must be a ``torch.Tensor``.

        Raises:
            AssertionError: if ``data`` is not a ``torch.Tensor``.
        """
        if not isinstance(self.data, torch.Tensor):
            # Explicit raise so non-tensor data is never silently returned under -O.
            raise AssertionError(
                f"ProxyValue data is not a torch.Tensor, got {type(self.data)}"
            )
        return self.data

    def is_tensor(self) -> bool:
        """Return True if ``data`` is a ``torch.Tensor``."""
        return isinstance(self.data, torch.Tensor)

    # pyre-ignore
    def __iter__(self):
        """Iterate over the elements of ``data``."""
        yield from self.data

    def __bool__(self) -> bool:
        """Return the truthiness of ``data`` (that type's own rules apply)."""
        return bool(self.data)