xref: /aosp_15_r20/external/pytorch/torch/_export/pass_infra/node_metadata.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Any, Dict, Set
2
3
# Alias used to annotate metadata values. It is plain ``Any``, so it documents
# intent at use sites without constraining what may be stored.
NodeMetadataValue = Any


# Metadata keys that ``NodeMetadata.__setitem__`` refuses to assign.
# (Listed alphabetically; set membership is order-independent.)
PROTECTED_KEYS: Set[str] = {
    "debug_handle",
    "nn_module_stack",
    "stack_trace",
    "tensor_meta",
    "val",
}


class NodeMetadata:
    """Dict-backed container for a node's metadata with write-protected keys.

    Keys listed in ``PROTECTED_KEYS`` may be supplied through the constructor,
    but ``__setitem__`` rejects them afterwards. This applies whether or not
    the key is already present.
    """

    def __init__(self, data: Dict[str, Any]) -> None:
        """Wrap a shallow copy of ``data``.

        Copying the dict means that adding or removing keys in the caller's
        dict afterwards does not affect this object, and vice versa. The
        values themselves are shared, not copied.
        """
        self.data: Dict[str, Any] = data.copy()

    def __getitem__(self, key: str) -> NodeMetadataValue:
        """Return the value stored under ``key``; raise ``KeyError`` if absent."""
        return self.data[key]

    def __setitem__(self, key: str, value: NodeMetadataValue) -> None:
        """Store ``value`` under ``key``.

        Raises:
            RuntimeError: if ``key`` is one of ``PROTECTED_KEYS``.
        """
        if key in PROTECTED_KEYS:
            raise RuntimeError(f"Could not override node key: {key}")
        self.data[key] = value

    def __contains__(self, key: str) -> bool:
        """Return whether a value is stored under ``key``."""
        return key in self.data

    def copy(self) -> "NodeMetadata":
        """Return a new ``NodeMetadata`` holding a shallow copy of this data.

        Protected keys are carried over, because the constructor does not
        apply the ``PROTECTED_KEYS`` check.
        """
        # ``__init__`` already copies its argument, so pass ``self.data``
        # directly instead of copying twice.
        return NodeMetadata(self.data)