xref: /aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/dedupe_symint_uses.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from dataclasses import dataclass
3from typing import Union
4
5import torch
6from torch import SymBool, SymFloat, SymInt
7from torch.types import py_sym_types
8
9
10@dataclass
11class _SymExprHash:
12    """
13    Hash for a py_sym_types that will use the underlying sympy expression
14    """
15
16    sym_obj: Union[SymInt, SymFloat, SymBool]
17
18    def __hash__(self) -> int:
19        return hash((type(self.sym_obj), self.sym_obj.node.expr))
20
21    def __eq__(self, value) -> bool:
22        if not isinstance(value, _SymExprHash):
23            return False
24        return self.sym_obj.node.expr == value.sym_obj.node.expr
25
26
27class _SymHashingDict:
28    """
29    Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse
30    existing sym proxies.
31
32    SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail,
33    fallback to symnodes.
34    """
35
36    def __init__(self):
37        self.sym_hash_dict = {}
38
39    def __setitem__(self, key, value):
40        self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value)
41
42    def __getitem__(self, key):
43        return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)]
44
45    def __contains__(self, key):
46        return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict
47
48    def get(self, key, default=None):
49        return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default)
50
51    def _wrap_to_sym_expr_hash(self, key):
52        return _SymExprHash(key) if isinstance(key, py_sym_types) else key
53
54
def dedupe_symints(graph: torch.fx.Graph):
    """
    Dedupe symbolic-scalar nodes in ``graph`` onto nodes that are resolvable
    from symbolic graph inputs.

    A node whose ``meta["val"]`` is a py_sym_types value (SymInt / SymFloat /
    SymBool) is replaced by an earlier node carrying an equal symbolic value.
    That earlier node must be a placeholder, or be computed solely from nodes
    that are themselves resolvable from placeholders.

    We only dedupe onto nodes resolvable from graph inputs. This avoids adding
    a potential dependency in the forward from the backward.

    The graph is modified in place: duplicate nodes have their uses redirected
    to the existing node and are then erased.
    """
    sym_dict = _SymHashingDict()
    resolvable_from_input_symints = set()

    # NOTE(review): this erases the node currently being visited. It relies on
    # graph.nodes iteration tolerating removal of the current node (true for
    # torch.fx's node list) — confirm if the iteration is ever changed.
    for node in graph.nodes:
        val = node.meta.get("val")
        # Only nodes producing a symbolic scalar are candidates. A missing
        # "val" (None) is rejected by the isinstance check as well.
        if not isinstance(val, py_sym_types):
            continue

        if node.op == "placeholder":
            # Graph inputs are always valid dedupe targets.
            resolvable_from_input_symints.add(node)
            sym_dict[val] = node
        elif (existing_node := sym_dict.get(val)) is not None:
            # An equal symbolic value is already available from an
            # input-resolvable node: reuse it and drop this duplicate.
            node.replace_all_uses_with(existing_node)
            graph.erase_node(node)
        elif all(n in resolvable_from_input_symints for n in node.all_input_nodes):
            # Computed only from input-resolvable nodes, so this node may
            # itself serve as a dedupe target for later nodes.
            sym_dict[val] = node
            resolvable_from_input_symints.add(node)
81