xref: /aosp_15_r20/external/pytorch/torch/utils/_sympy/symbol.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""
3This file contains canonical definitions for our symbol naming conventions,
4across torch.fx.experimental.symbolic_shapes and torch._inductor.  The
5intention is:
6
71. To make it easily greppable where all the sites we use a prefix are
82. Make it possible to easily tell if we can introduce a new prefix without
9   introducing a conflict
10
11You can occasionally test if prefixes have been hardcoded by renaming prefixes
12in this file and seeing what breaks.
13"""
14
15from enum import auto, Enum
16from typing import Sequence, Union
17
18import sympy
19
20
class SymT(Enum):
    """Kinds of symbols; each member is mapped to a name prefix in ``prefix_str``."""

    # Member values are assigned by auto() in declaration order.
    SIZE = auto()
    FLOAT = auto()
    UNBACKED_INT = auto()
    UNBACKED_FLOAT = auto()
    # Inductor: The intermediates in inner_fn (e.g. tmp0), one generated per
    # ops call.  If one of these shows up in an indexing expression, that
    # means an indirect load is happening.
    TMP = auto()
    # Inductor: Placeholder variable that is later replaced with TMP
    INDIRECT = auto()
    # Inductor: Some size expressions are replaced with a precomputed size
    # (e.g. ps0) which is computed host side, and then directly reused in the
    # kernel, so we don't repeatedly recompute it on device.
    PRECOMPUTED_SIZE = auto()
    # Inductor: An indexing variable (e.g. i0) in loops IR which ranges over a
    # non-reduced dim in the loop
    INDEX = auto()
    # Inductor: A reduction indexing variable (e.g. r0) in loops IR which
    # ranges over a reduced dim in the loop
    RINDEX = auto()
    # Inductor: In templated kernels torch._inductor.kernel, we have a hook to
    # store the final output and append epilogue fusions.  To do this, we must
    # know what indexes the outputs range over.  NB: These will also
    # advertise as INDEX (their prefix "idx" starts with the INDEX prefix
    # "i"), this is... probably OK?
    TEMPLATE_INDEX = auto()
    # Inductor: iteration domain for blockIdx.x/blockIdx.y
    XBLOCK = auto()
    YBLOCK = auto()
    # Inductor: this is used solely for dynamic_reshape_indexer
    VIEW = auto()
    # Alternate (non-modular) indexing used in halide kernels
    HALIDE = auto()
54
55
# Maps each SymT kind to the string that symbol names of that kind start with.
#
# Invariant: there must not be a prefix which is a prefix of another string,
# as this introduces ambiguity (symbol_is_type matches names with
# str.startswith).
# NOTE(review): the table does not fully satisfy this today -- "i" (INDEX) is
# a prefix of both "idx" (TEMPLATE_INDEX) and "indirect" (INDIRECT), so
# symbols of those two kinds also match SymT.INDEX.
prefix_str = {
    SymT.SIZE: "s",  # integer
    SymT.UNBACKED_INT: "u",  # integer
    # Prefix z here is chosen to avoid false aliasing in symbol_is_type test
    # DO NOT add a "z" type.  You also need to avoid conflicts on these
    # prefixes but this is somewhat easier to manage
    SymT.FLOAT: "zf",
    SymT.UNBACKED_FLOAT: "zuf",
    SymT.TMP: "tmp",
    SymT.PRECOMPUTED_SIZE: "ps",
    SymT.INDEX: "i",
    SymT.RINDEX: "r",
    SymT.TEMPLATE_INDEX: "idx",
    SymT.XBLOCK: "x",
    SymT.YBLOCK: "y",
    SymT.INDIRECT: "indirect",  # false aliasing? (also starts with "i")
    SymT.VIEW: "view",
    SymT.HALIDE: "h",
}
77
78
def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
    """Build the sympy symbol whose name is the prefix string followed by ``idx``.

    Args:
        prefix: Symbol kind; its string is looked up in ``prefix_str``.
        idx: Number appended to the prefix string to form the name.
        **kwargs: Forwarded unchanged to ``sympy.Symbol``.

    Returns:
        The ``sympy.Symbol`` with the composed name.

    Raises:
        KeyError: If ``prefix`` has no entry in ``prefix_str``.
    """
    # TODO: maybe put the assumptions here directly
    stem = prefix_str[prefix]
    symbol_name = f"{stem}{idx}"
    return sympy.Symbol(symbol_name, **kwargs)
82
83
# The annotation on ``sym`` is a little wider than it should be: free_symbols
# is declared as containing Basic rather than Symbol, so that is what callers
# hand us.
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> bool:
    """Return whether the name of ``sym`` starts with the prefix string(s) of ``prefix``.

    Args:
        sym: The symbol to test; must actually be a ``sympy.Symbol``.
        prefix: One symbol kind, or a sequence of kinds; the name matches if
            it starts with the ``prefix_str`` entry of any of them.

    Returns:
        True if the lowercased symbol name starts with one of the prefix
        strings, False otherwise (including for an empty sequence).
    """
    assert isinstance(sym, sympy.Symbol)
    # Compare in lowercase so capitalized names like XBLOCK, RBLOCK match too.
    lowered = sym.name.lower()
    if isinstance(prefix, SymT):
        candidates = (prefix_str[prefix],)
    else:
        candidates = tuple(prefix_str[kind] for kind in prefix)
    # str.startswith accepts a tuple and succeeds if any entry matches.
    return lowered.startswith(candidates)
93
94
def free_symbol_is_type(
    e: sympy.Expr, prefix: Union[SymT, Sequence[SymT]]
) -> bool:
    """Return whether any free symbol of ``e`` is of the given kind(s).

    Args:
        e: Expression whose ``free_symbols`` are inspected.
        prefix: One symbol kind, or a sequence of kinds; passed through to
            ``symbol_is_type`` unchanged, so it accepts the same values.

    Returns:
        True if ``symbol_is_type`` holds for at least one free symbol of
        ``e``; False otherwise, including when ``e`` has no free symbols.
    """
    return any(symbol_is_type(v, prefix) for v in e.free_symbols)
97