xref: /aosp_15_r20/external/pytorch/torch/jit/_decompositions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3from torch import Tensor
4
5
# Shorthand for the ATen operator namespace (used below as e.g. `aten.var.correction`,
# `aten.mean`, `aten.sum`).
aten = torch.ops.aten
7import inspect
8import warnings
9from typing import Callable, Dict, List, Optional, Set, TypeVar
10from typing_extensions import ParamSpec
11
12from torch.types import Number
13
14
# Default registry filled by `register_decomposition`: maps `str(aten_op._schema)`
# to the TorchScript function scripted from the decorated Python function.
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
# Names of Python functions already registered as decompositions. Names must be
# unique because JIT function serialization needs unique names; the decorator in
# `register_decomposition` asserts against duplicates.
function_name_set: Set[str] = set()

# Type variables that let `register_decomposition` be typed as returning the
# decorated function with its signature unchanged.
_T = TypeVar("_T")
_P = ParamSpec("_P")
20
21
def check_decomposition_has_type_annotations(f):
    """Assert that ``f`` annotates every parameter and its return value.

    Args:
        f: the candidate decomposition function.

    Raises:
        AssertionError: naming the first unannotated parameter, or reporting
            that the return annotation is missing.
    """
    # Plain Python functions expose ``__name__``, not ``name``; reading
    # ``f.name`` inside the assertion message turned every failure into an
    # AttributeError instead of the intended AssertionError.
    func_name = getattr(f, "__name__", repr(f))
    sig = inspect.signature(f)
    for param in sig.parameters.values():
        # Compare by identity against the public "no annotation" sentinel.
        assert (
            param.annotation is not inspect.Parameter.empty
        ), f"No signature on param {param.name} for function {func_name}"

    assert (
        sig.return_annotation is not inspect.Signature.empty
    ), f"No return annotation for function {func_name}"
33
34
def signatures_match(decomposition_sig, torch_op_sig):
    """Return whether a decomposition's signature is compatible with an op's.

    The two signatures must have the same number of parameters. Parameters
    are compared pairwise, in order, on ``name`` and ``annotation``. Default
    values are compared only when both sides declare one. Finally, the return
    annotations must be equal.

    A warning is emitted for any decomposition parameter named ``self``.

    Args:
        decomposition_sig: ``inspect.Signature`` of the decomposition function.
        torch_op_sig: ``inspect.Signature`` describing the torch op.

    Returns:
        bool: True when the signatures match under the rules above.
    """
    decomp_params = decomposition_sig.parameters
    op_params = torch_op_sig.parameters

    if len(decomp_params) != len(op_params):
        return False

    # Public alias of the "no default" sentinel, hoisted out of the loop.
    empty = inspect.Parameter.empty

    for decomp_param, op_param in zip(decomp_params.values(), op_params.values()):
        # Full equality can't be checked yet, because not all fields are
        # correctly deduced in torch_op_sig (e.g. default values). 'kind' is
        # skipped because kwarg-only values with defaults are not yet
        # supported in TorchScript.
        if decomp_param.name == "self":
            warnings.warn("PyTorch uses 'input' instead of 'self' on public api")

        if (
            decomp_param.name != op_param.name
            or decomp_param.annotation != op_param.annotation
        ):
            return False

        decomp_default = decomp_param.default
        op_default = op_param.default
        # A default is not always reported on the torch schema, but when both
        # sides specify one they must agree. Check the sentinel by identity so
        # a default with a custom ``__ne__`` is never compared against it.
        if decomp_default is not empty and op_default is not empty:
            if decomp_default != op_default:
                return False

    return decomposition_sig.return_annotation == torch_op_sig.return_annotation
64
65
def register_decomposition(
    aten_op: torch._ops.OpOverload,
    registry: Optional[Dict[str, torch.jit.ScriptFunction]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """Build a decorator that registers a TorchScript decomposition for ``aten_op``.

    The decorated function is compiled with ``torch.jit.script``. Its graph is
    inlined and then simplified with two rounds of peephole optimization and
    constant propagation. The resulting ``ScriptFunction`` is stored in
    ``registry`` under ``str(aten_op._schema)``. The decorator returns the
    original Python function ``f``, not the scripted one.

    Args:
        aten_op: the ATen overload being decomposed.
        registry: mapping that receives the scripted function; when omitted,
            the module-level ``decomposition_table`` is used.

    Raises:
        AssertionError: (when the decorator is applied) if ``aten_op`` is not
            an ``OpOverload`` or the decorated function's name was already
            registered.
    """

    def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
        # Resolve the default locally instead of rebinding the closed-over
        # ``registry`` through ``nonlocal``.
        target = decomposition_table if registry is None else registry

        assert isinstance(aten_op, torch._ops.OpOverload)

        # Need unique name for jit function serialization
        assert (
            f.__name__ not in function_name_set
        ), f"Duplicated function name {f.__name__}"

        scripted_func = torch.jit.script(f)
        torch._C._jit_pass_inline(scripted_func.graph)

        for _ in range(2):
            torch._C._jit_pass_peephole(scripted_func.graph)
            torch._C._jit_pass_constant_propagation(scripted_func.graph)

        # Reserve the name only after scripting and the graph passes succeed,
        # so a failed registration does not block a corrected definition that
        # reuses the same name.
        function_name_set.add(f.__name__)
        target[str(aten_op._schema)] = scripted_func
        return f

    return decomposition_decorator
94
95
96# TODO: replace torch.sigmoid -> aten.sigmoid
97
98
99@register_decomposition(aten.var.correction)
100def var_decomposition(
101    input: Tensor,
102    dim: Optional[List[int]] = None,
103    correction: Optional[Number] = None,
104    keepdim: bool = False,
105) -> Tensor:
106    if dim is None:
107        dim_i: List[int] = []
108        dim = dim_i
109
110    if isinstance(dim, (tuple, list)) and len(dim) == 0:
111        n = input.numel()
112    else:
113        n = 1
114        for dim_i in dim:  # type: ignore[assignment]
115            n *= input.shape[dim_i]  # type: ignore[call-overload]
116
117    mean = aten.mean(input, dim, True)
118    sub = input - mean
119    sq = sub * sub
120    sum = aten.sum(sq, dim, keepdim)
121
122    if correction is None:
123        denom = float(n - 1)
124    else:
125        if isinstance(correction, int):
126            denom = float(n - correction)
127        elif isinstance(correction, float):
128            denom = float(n) - correction
129        else:
130            raise RuntimeError("correction must be int or float")
131
132    return sum / max(0, denom)
133
134
@register_decomposition(aten.var.default)
def var(input: Tensor, unbiased: bool = True) -> Tensor:
    """Decompose ``aten.var.default`` by delegating to ``var_decomposition``.

    No ``dim`` is passed, so the variance is taken over every element of
    ``input``. ``unbiased`` selects a correction of 1 (Bessel's correction)
    or 0.
    """
    correction = 1 if unbiased else 0
    return var_decomposition(input, correction=correction)
138