# mypy: allow-untyped-defs
import inspect
import warnings
from typing import Callable, Dict, List, Optional, Set, TypeVar

from typing_extensions import ParamSpec

import torch
from torch import Tensor
from torch.types import Number


aten = torch.ops.aten

# Scripted decompositions keyed by ``str(OpOverload._schema)``; the default
# registry used by ``register_decomposition``.
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
# ``__name__`` of every function registered so far; duplicates are rejected
# because jit function serialization needs unique names.
function_name_set: Set[str] = set()

_T = TypeVar("_T")
_P = ParamSpec("_P")


def check_decomposition_has_type_annotations(f):
    """Assert that every parameter of ``f`` and its return value are annotated.

    Args:
        f: a callable accepted by ``inspect.signature``.

    Raises:
        AssertionError: naming the first unannotated parameter, or reporting a
            missing return annotation.
    """
    # Plain functions expose ``__name__``; fall back to ``.name`` / repr for
    # other callables so building the error message can never itself fail.
    fn_name = getattr(f, "__name__", None) or getattr(f, "name", repr(f))
    sig = inspect.signature(f)
    for param in sig.parameters.values():
        assert (
            param.annotation is not inspect.Parameter.empty
        ), f"No signature on param {param.name} for function {fn_name}"

    assert (
        sig.return_annotation is not inspect.Signature.empty
    ), f"No return annotation for function {fn_name}"


def signatures_match(decomposition_sig, torch_op_sig):
    """Return True if a decomposition's signature is compatible with an op's.

    Both arguments are ``inspect.Signature`` objects. Parameters are compared
    positionally: the counts must be equal, each pair must have equal names and
    annotations, defaults must be equal when BOTH sides specify one, and the
    return annotations must be equal. A ``UserWarning`` is emitted for every
    decomposition parameter named ``self``.
    """
    decomp_params = decomposition_sig.parameters
    op_params = torch_op_sig.parameters

    if len(decomp_params) != len(op_params):
        return False

    for decomp_param, op_param in zip(decomp_params.values(), op_params.values()):
        # can't check full equality yet because not all fields are correctly
        # deduced in the torch_op_sig - like default value
        # can't check 'kind' bc
        # kwarg-only values with defaults not yet supported in TS
        for field in ["name", "annotation"]:
            if field == "name" and decomp_param.name == "self":
                warnings.warn("PyTorch uses 'input' instead of 'self' on public api")

            if getattr(decomp_param, field) != getattr(op_param, field):
                return False

        decomp_default = decomp_param.default
        op_default = op_param.default
        # default value not always correctly inferred as being present on torch
        # schema, but if specified on both they should be equal
        if (
            decomp_default is not inspect.Parameter.empty
            and op_default is not inspect.Parameter.empty
        ):
            if decomp_default != op_default:
                return False

    return decomposition_sig.return_annotation == torch_op_sig.return_annotation


def register_decomposition(
    aten_op: torch._ops.OpOverload,
    registry: Optional[Dict[str, torch.jit.ScriptFunction]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """Return a decorator that scripts a function and registers it for ``aten_op``.

    The decorated function is compiled with ``torch.jit.script``. Its graph is
    inlined, then given two rounds of peephole + constant propagation. The
    scripted function is stored in ``registry`` (``decomposition_table`` when
    None), keyed by ``str(aten_op._schema)``. The original Python function is
    returned unchanged.

    Raises:
        AssertionError: if ``aten_op`` is not an ``OpOverload`` or a function
            with the same ``__name__`` was already registered.
    """

    def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
        nonlocal registry
        if registry is None:
            registry = decomposition_table

        assert isinstance(aten_op, torch._ops.OpOverload)

        # Need unique name for jit function serialization
        assert (
            f.__name__ not in function_name_set
        ), f"Duplicated function name {f.__name__}"

        scripted_func = torch.jit.script(f)
        torch._C._jit_pass_inline(scripted_func.graph)

        for _ in range(2):
            torch._C._jit_pass_peephole(scripted_func.graph)
            torch._C._jit_pass_constant_propagation(scripted_func.graph)

        registry[str(aten_op._schema)] = scripted_func
        # Reserve the name only after registration succeeded, so a function
        # that failed to script can be fixed and registered again.
        function_name_set.add(f.__name__)
        return f

    return decomposition_decorator


# TODO: replace torch.sigmoid -> aten.sigmoid


@register_decomposition(aten.var.correction)
def var_decomposition(
    input: Tensor,
    dim: Optional[List[int]] = None,
    correction: Optional[Number] = None,
    keepdim: bool = False,
) -> Tensor:
    """Decompose ``aten.var.correction`` into mean / sub / mul / sum.

    Computes ``sum((input - mean)**2, dim, keepdim) / max(0, n - correction)``.
    Here ``n`` is the number of reduced elements: ``input.numel()`` when ``dim``
    is None or empty, otherwise the product of the sizes of the listed dims.
    ``correction`` defaults to 1 when None.

    Raises:
        RuntimeError: if ``correction`` is neither int nor float.
    """
    if dim is None:
        dim_i: List[int] = []
        dim = dim_i

    # An empty dim list reduces over every element.
    if isinstance(dim, (tuple, list)) and len(dim) == 0:
        n = input.numel()
    else:
        n = 1
        for d in dim:
            n *= input.shape[d]

    mean = aten.mean(input, dim, True)
    sub = input - mean
    sq = sub * sub
    sq_sum = aten.sum(sq, dim, keepdim)

    if correction is None:
        denom = float(n - 1)
    else:
        if isinstance(correction, int):
            denom = float(n - correction)
        elif isinstance(correction, float):
            denom = float(n) - correction
        else:
            raise RuntimeError("correction must be int or float")

    # Clamp a negative denominator to 0 (division then yields inf/nan).
    return sq_sum / max(0, denom)
@register_decomposition(aten.var.default)
def var(input: Tensor, unbiased: bool = True) -> Tensor:
    """Decompose ``aten.var.default`` via ``var_decomposition``.

    ``unbiased=True`` forwards ``correction=1``; ``unbiased=False`` forwards
    ``correction=0``. All dims are reduced (no ``dim`` is passed).
    """
    correction = 0
    if unbiased:
        correction = 1
    return var_decomposition(input, correction=correction)