xref: /aosp_15_r20/external/pytorch/torch/_decomp/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import inspect
3from collections import defaultdict
4from functools import wraps
5from itertools import chain
6from typing import Callable, Dict, List, Sequence, TypeVar, Union
7from typing_extensions import ParamSpec
8
9import torch
10import torch.library
11from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
12from torch._prims_common import CustomOutParamAnnotation
13from torch.utils import _pytree as pytree
14
15
16__all__ = [
17    "decomposition_table",
18    "pre_autograd_decomposition_table",
19    "meta_table",
20    "register_decomposition",
21    "get_decompositions",
22    "core_aten_decompositions",
23]
24
25_T = TypeVar("_T")
26_P = ParamSpec("_P")
27
28# TODO: relax key type here; torch registrations should be possible to; but
29# right now this type is accurate
30global_decomposition_table: Dict[
31    str, Dict[torch._ops.OperatorBase, Callable]
32] = defaultdict(dict)
33
34decomposition_table = global_decomposition_table["post_autograd"]
35pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
36meta_table = global_decomposition_table["meta"]
37
38
39def _add_op_to_registry(registry, op, fn):
40    """
41    This is an internal API for adding an op to the decomposition table.
42
43    If op is OpOverload, it will be added to the registry directly.
44    If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
45    """
46    overloads: List[Union[torch._ops.OperatorBase]] = []
47    if isinstance(op, HigherOrderOperator):
48        # There's no concept of overloads for HigherOrderOperator
49        registry[op] = fn
50        return
51    elif isinstance(op, OpOverload):
52        overloads.append(op)
53    else:
54        assert isinstance(op, OpOverloadPacket)
55        for ol in op.overloads():
56            overloads.append(getattr(op, ol))
57
58    for op_overload in overloads:
59        if op_overload in registry:
60            raise RuntimeError(f"duplicate registrations for {op_overload}")
61        # TorchScript dumps a bunch of extra nonsense overloads
62        # which don't have corresponding dispatcher entries, we need
63        # to filter those out, e.g aten.add.float_int
64        if torch._C._dispatch_has_kernel(op_overload.name()):
65            registry[op_overload] = fn
66
67
68def _convert_out_params(f):
69    out_annotation = f.__annotations__.get("out")
70
71    # If there are no out params, do not wrap the function.
72    if not out_annotation:
73        return f
74
75    # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
76    if getattr(out_annotation, "__origin__", None) is tuple:
77        sig = inspect.signature(f)
78        out_names = sig.return_annotation._fields
79        # If out is a tuple, we need to register a function that unpacks all the out
80        # elements as this is what native_functions.yaml expects
81
82        @wraps(f)
83        def _fn(*args, **kwargs):
84            out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
85            # Either all of the out kwargs are set or none of them
86            is_none = out_kwargs[0] is None
87            assert all((o is None) == is_none for o in out_kwargs)
88            return f(*args, **kwargs, out=None if is_none else out_kwargs)
89
90        out_params = [
91            inspect.Parameter(
92                o,
93                kind=inspect.Parameter.KEYWORD_ONLY,
94                default=None,
95                annotation=t,
96            )
97            for o, t in zip(out_names, out_annotation.__args__)
98        ]
99        # Drop the out parameter and concatenate the new kwargs in the signature
100        params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
101        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
102            parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
103        )
104        # Drop the out parameter and concatenate the new kwargs in the annotations
105        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
106        for o in out_params:
107            _fn.__annotations__[o.name] = o.annotation
108
109        # Propagate that this function is wrapped by `out_wrapper`
110        _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper  # type: ignore[attr-defined]
111
112        return _fn
113
114    # Alternatively, there may be a single tensor out parameter with a name
115    # other than "out". This will need special treatment and is indicated by an
116    # annotation, which we will remove here so it is not exposed after wrapping.
117    custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
118    if custom_out_param_name:
119
120        @wraps(f)
121        def _fn(*args, **kwargs):
122            out_kwarg = kwargs.pop(custom_out_param_name, None)
123            return f(*args, **kwargs, out=out_kwarg)
124
125        out_param = inspect.Parameter(
126            custom_out_param_name,
127            kind=inspect.Parameter.KEYWORD_ONLY,
128            default=None,
129            annotation=out_annotation,
130        )
131
132        # Drop the out parameter and concatenate the new kwarg in the signature
133        sig = inspect.signature(f)
134        params = chain(
135            (v for k, v in sig.parameters.items() if k != "out"), (out_param,)
136        )
137        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
138            parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
139        )
140
141        # Drop the out parameter and concatenate the new kwargs in the annotations
142        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
143        _fn.__annotations__[out_param.name] = out_param.annotation
144
145        return _fn
146
147    return f
148
149
150def register_decomposition(
151    aten_op, registry=None, *, type="post_autograd", unsafe=False
152) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
153    """
154    A decorator to register a function as a decomposition to the Python
155    decomposition table.  Use it like this::
156
157        @register_decomposition(torch.ops.aten.clamp_min)
158        def clamp_min(x):
159            return torch.clamp(self, min=min)
160
161    If you are writing a new decomposition, consider contributing it
162    directly to PyTorch in torch._decomp.decompositions.
163
164    This API is experimental; we are almost certainly going to extend
165    the API when we make decompositions eligible for use in transforms (e.g.,
166    autograd) and not just backend tracing, where we then need to know if a
167    decomposition can be used to simulate a transform.
168
169    By default, we also will register it to the Meta key of dispatcher,
170    and replace the c++ Meta implementation if there is already one.
171
172    unsafe kwarg is for reuse of this function for registering non-function
173    things
174    """
175
176    assert type in {"post_autograd", "pre_autograd", "meta"}
177
178    def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
179        orig_fn = fn
180        if not unsafe:
181            fn = _convert_out_params(fn)
182
183        nonlocal registry
184        if registry is None:
185            registry = global_decomposition_table[type]
186
187        def register(op):
188            _add_op_to_registry(registry, op, fn)
189
190        # To handle allowing multiple aten_ops at once
191        pytree.tree_map_(register, aten_op)
192        return orig_fn
193
194    return decomposition_decorator
195
196
197def get_decompositions(
198    aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
199    type: str = "post_autograd",
200) -> Dict[torch._ops.OperatorBase, Callable]:
201    """
202    Retrieve a dictionary of decompositions corresponding to the list of
203    operator overloads and overload packets passed as input.  Overload
204    packets will include all decomposed overloads in the packet.  If there is
205    no decomposition for a requested operator, it is silently ignored.
206
207    This API is experimental; we are almost certainly going to give an alternate,
208    more recommended formulation, where a user provides the set of operators
209    they know how to implement, and we provide decompositions for everything
210    not in this set.
211    """
212    assert type in {"post_autograd", "pre_autograd", "meta"}
213
214    registry = global_decomposition_table[type]
215    packets_to_overloads = defaultdict(list)
216    for opo in registry:
217        if isinstance(opo, (OpOverload, OpOverloadPacket)):
218            packets_to_overloads[opo.overloadpacket].append(opo)
219    decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
220    for op in aten_ops:
221        if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
222            for op_overload in packets_to_overloads[op]:
223                decompositions[op_overload] = registry[op_overload]
224        elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
225            decompositions[op] = registry[op]
226    return decompositions
227
228
229def remove_decompositions(
230    decompositions: Dict[torch._ops.OperatorBase, Callable],
231    aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
232) -> None:
233    """
234    Given a dictionary of decompositions obtained from get_decompositions(), removes
235    operators associated with a list of operator overloads and overload packets passed
236    as input. If the decomposition dictionary does not contain a decomposition that is
237    specified to be removed, it is silently ignored.
238    """
239    for op in aten_ops:
240        if isinstance(op, OpOverloadPacket):
241            for overload_name in op.overloads():
242                opo = getattr(op, overload_name)
243                decompositions.pop(opo, None)
244        elif isinstance(op, OpOverload):
245            decompositions.pop(op, None)
246
247
248# populate the table
249import torch._decomp.decompositions
250import torch._refs
251
252
253# See NOTE [Core ATen Ops]
254#
255# list was copied from torch/_inductor/decomposition.py
256# excluding decompositions that results in prim ops
257# Resulting opset of decomposition is core aten ops
258def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
259    aten = torch.ops.aten
260    return get_decompositions(
261        [
262            aten.addcdiv,
263            aten.addcdiv_,
264            aten.addcmul,
265            aten.addcmul_,
266            aten.addr,
267            aten.affine_grid_generator,
268            aten.alias_copy,
269            aten.all,
270            aten.aminmax,
271            aten.arange.default,
272            aten.arange.start,
273            aten.avg_pool2d_backward,
274            aten.baddbmm,
275            aten.binary_cross_entropy,
276            aten.binary_cross_entropy_backward,
277            aten.binary_cross_entropy_with_logits,
278            aten.block_diag,
279            aten.celu,
280            aten.celu_,
281            aten.channel_shuffle,
282            aten.clamp_max,
283            aten.clamp_min,
284            aten.col2im,
285            aten.count_nonzero,
286            aten.linalg_cross,
287            aten.cudnn_batch_norm,
288            aten.cudnn_batch_norm_backward,
289            aten.miopen_batch_norm_backward,
290            aten.deg2rad,
291            aten.deg2rad_,
292            aten.detach,
293            aten.diag_embed,
294            aten.diagonal_backward,
295            aten.dot,
296            aten.vdot,
297            aten.elu,
298            aten.elu_,
299            aten.elu_backward,
300            aten._embedding_bag,
301            aten.embedding_dense_backward,
302            aten.empty_like,
303            aten._euclidean_dist.default,
304            aten.expand_as,
305            aten.expand_copy,
306            aten.eye,
307            aten.fill,
308            aten.fill_,
309            aten.floor_divide,
310            aten.frac,
311            aten.frac_,
312            aten._fused_moving_avg_obs_fq_helper,
313            aten.gelu_,
314            aten.gelu_backward,
315            aten.glu,
316            aten.glu_backward,
317            aten.hardshrink,
318            aten.hardsigmoid,
319            aten.hardsigmoid_,
320            aten.hardsigmoid_backward,
321            aten.hardswish,
322            aten.hardswish_,
323            aten.hardswish_backward,
324            aten.hardtanh_,
325            aten.hardtanh_backward,
326            aten.heaviside,
327            aten.heaviside_,
328            aten.huber_loss,
329            aten.huber_loss_backward,
330            aten.im2col,
331            aten.index_add,
332            aten.index_add_,
333            aten.index_copy,
334            aten.index_copy_,
335            aten.index_fill,
336            aten.index_fill_,
337            aten.isin,
338            aten.isneginf,
339            aten.isposinf,
340            aten.l1_loss,
341            aten._lazy_clone,
342            aten._test_parallel_materialize,
343            aten.leaky_relu_,
344            aten.leaky_relu_backward,
345            aten.lerp,
346            aten.lerp_,
347            aten.linspace,
348            aten.logaddexp,
349            aten.logaddexp2,
350            aten.logit,
351            aten.logit_,
352            aten.logit_backward,
353            aten.log_sigmoid_backward,
354            aten.log_sigmoid_forward,
355            aten._log_softmax_backward_data,
356            aten.logspace,
357            aten.logsumexp.default,
358            aten.masked_fill,
359            aten.masked_fill_,
360            aten.mish,
361            aten.mish_,
362            aten.mse_loss,
363            aten.mse_loss_backward,
364            aten.multi_margin_loss,
365            aten.multilabel_margin_loss_forward,
366            aten.mv,
367            aten.mvlgamma,
368            aten.mvlgamma_,
369            aten.nansum,
370            aten.nan_to_num,
371            aten.nan_to_num_,
372            aten.narrow,
373            aten.native_batch_norm_backward,
374            aten.native_dropout_backward,
375            aten.native_group_norm_backward,
376            aten.native_layer_norm_backward,
377            aten.new_empty,
378            aten.new_full,
379            aten.new_ones,
380            aten.new_zeros,
381            aten.nll_loss2d_forward,
382            aten.nll_loss2d_backward,
383            aten.nll_loss_backward,
384            aten.nll_loss_forward,
385            aten.norm,
386            aten.ones,
387            aten.ones_like,
388            aten.pixel_shuffle,
389            aten.pixel_unshuffle,
390            aten._prelu_kernel,
391            aten._prelu_kernel_backward,
392            aten._reshape_alias,
393            aten.rad2deg,
394            aten.rad2deg_,
395            aten.reflection_pad1d,
396            aten.reflection_pad1d_backward,
397            aten.reflection_pad2d,
398            aten.reflection_pad2d_backward,
399            aten.reflection_pad3d,
400            aten.reflection_pad3d_backward,
401            aten.replication_pad1d,
402            aten.replication_pad2d,
403            aten.replication_pad3d,
404            aten.renorm,
405            aten.renorm_,
406            aten.replication_pad2d,
407            aten.resize_as,
408            aten.roll,
409            aten.rot90,
410            aten.rrelu_with_noise,
411            aten.rrelu_with_noise_,
412            aten.rsub,
413            aten._safe_softmax,
414            aten._scaled_dot_product_flash_attention_for_cpu.default,
415            aten.select_backward,
416            aten.select_scatter,
417            aten.sgn,
418            aten.sgn_,
419            aten.sigmoid_backward,
420            aten.silu,
421            aten.silu_,
422            aten.silu_backward,
423            aten.sinc,
424            aten.sinc_,
425            aten.slice_backward,
426            aten.smooth_l1_loss,
427            aten.smooth_l1_loss_backward,
428            aten.soft_margin_loss,
429            aten.soft_margin_loss_backward,
430            aten._softmax_backward_data,
431            aten.softplus,
432            aten.softplus_backward,
433            aten.softshrink,
434            aten.special_entr,
435            aten.special_log_ndtr,
436            aten.special_xlog1py,
437            aten.split.Tensor,
438            aten.split_with_sizes_copy,
439            aten.squeeze.default,
440            aten.squeeze.dim,
441            aten.std,
442            aten.std_mean,
443            aten.stack,
444            aten.sum.default,
445            aten.sum.out,
446            aten.t,
447            aten.t_copy,
448            aten.take,
449            aten.tanh_backward,
450            aten.threshold,
451            aten.threshold_,
452            aten.threshold_backward,
453            aten.trace,
454            aten.transpose.int,
455            aten.tril,
456            aten.tril_,
457            aten.triu,
458            aten.triu_,
459            aten.unbind,
460            aten.unfold_backward,
461            aten.unfold_copy,
462            aten._unsafe_index,
463            aten._unsafe_index_put,
464            aten._unsafe_masked_index,
465            aten._unsafe_masked_index_put_accumulate,
466            aten.unsafe_split.Tensor,
467            aten.unsafe_split_with_sizes,
468            aten.unsqueeze_copy,
469            aten._unsafe_view,
470            aten.upsample_linear1d,
471            aten.upsample_bilinear2d,
472            aten.upsample_trilinear3d,
473            aten.upsample_nearest2d_backward,
474            aten.view_as_complex,
475            aten.xlogy,
476            aten.xlogy_,
477            aten.zero,
478            aten.zero_,
479            aten.zeros,
480            aten.zeros_like,
481            aten._chunk_cat,
482            aten._weight_norm_interface,
483        ]
484    )
485