1# mypy: allow-untyped-defs 2from dataclasses import dataclass 3from typing import Union 4 5import torch 6from torch import SymBool, SymFloat, SymInt 7from torch.types import py_sym_types 8 9 10@dataclass 11class _SymExprHash: 12 """ 13 Hash for a py_sym_types that will use the underlying sympy expression 14 """ 15 16 sym_obj: Union[SymInt, SymFloat, SymBool] 17 18 def __hash__(self) -> int: 19 return hash((type(self.sym_obj), self.sym_obj.node.expr)) 20 21 def __eq__(self, value) -> bool: 22 if not isinstance(value, _SymExprHash): 23 return False 24 return self.sym_obj.node.expr == value.sym_obj.node.expr 25 26 27class _SymHashingDict: 28 """ 29 Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse 30 existing sym proxies. 31 32 SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail, 33 fallback to symnodes. 34 """ 35 36 def __init__(self): 37 self.sym_hash_dict = {} 38 39 def __setitem__(self, key, value): 40 self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value) 41 42 def __getitem__(self, key): 43 return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)] 44 45 def __contains__(self, key): 46 return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict 47 48 def get(self, key, default=None): 49 return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default) 50 51 def _wrap_to_sym_expr_hash(self, key): 52 return _SymExprHash(key) if isinstance(key, py_sym_types) else key 53 54 55def dedupe_symints(graph: torch.fx.Graph): 56 """ 57 Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs. 58 59 We only dedupe from graph inputs to avoid adding a potential dependency in the forward 60 from the backward. 61 62 """ 63 64 sym_dict = _SymHashingDict() 65 resolvable_from_input_symints = set() 66 67 for node in graph.nodes: 68 val = node.meta.get("val", None) 69 if val is None or not isinstance(val, py_sym_types): 70 continue 71 72 if node.op == "placeholder": 73 resolvable_from_input_symints.add(node) 74 sym_dict[val] = node 75 elif existing_node := sym_dict.get(val): 76 node.replace_all_uses_with(existing_node) 77 graph.erase_node(node) 78 elif all(n in resolvable_from_input_symints for n in node.all_input_nodes): 79 sym_dict[val] = node 80 resolvable_from_input_symints.add(node) 81