from typing import Any, Dict, Set


# Type of a single metadata value; values are not constrained.
NodeMetadataValue = Any


# Keys that ``NodeMetadata.__setitem__`` refuses to assign, whether or not the
# key is already present. They can only be supplied through the constructor.
PROTECTED_KEYS: Set[str] = {
    "val",
    "stack_trace",
    "nn_module_stack",
    "debug_handle",
    "tensor_meta",
}


class NodeMetadata:
    """Dict-like metadata container that forbids overwriting protected keys.

    The constructor takes a shallow snapshot of the given mapping, so later
    changes to the caller's dict are not reflected here (and vice versa).
    Item assignment rejects every key listed in ``PROTECTED_KEYS``.

    NOTE: the underlying dict is exposed as ``self.data``; writing to it
    directly bypasses the protected-key check.
    """

    def __init__(self, data: Dict[str, Any]) -> None:
        """Create a container holding a shallow copy of *data*."""
        self.data: Dict[str, Any] = data.copy()

    def __getitem__(self, key: str) -> NodeMetadataValue:
        """Return the value stored under *key*.

        Raises:
            KeyError: if *key* is not present.
        """
        return self.data[key]

    def __setitem__(self, key: str, value: NodeMetadataValue) -> None:
        """Store *value* under *key*.

        Raises:
            RuntimeError: if *key* is one of ``PROTECTED_KEYS``; the check is
                applied even when the key is not yet present.
        """
        if key in PROTECTED_KEYS:
            raise RuntimeError(f"Could not override node key: {key}")
        self.data[key] = value

    def __contains__(self, key: str) -> bool:
        """Return True if *key* is present."""
        return key in self.data

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.data!r})"

    def copy(self) -> "NodeMetadata":
        """Return a new container with a shallow copy of this one's entries."""
        # __init__ already snapshots its argument, so no extra copy is needed.
        return NodeMetadata(self.data)