xref: /aosp_15_r20/external/pytorch/torch/_library/custom_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import inspect
4import logging
5import weakref
6from contextlib import contextmanager
7from typing import (
8    Any,
9    Callable,
10    Dict,
11    Iterable,
12    Iterator,
13    List,
14    Optional,
15    Sequence,
16    Set,
17    Tuple,
18    Union,
19)
20
21import torch
22from torch import _C, _ops, Tensor
23from torch.utils._exposed_in import exposed_in
24
25from . import autograd, utils
26
27
# Type accepted wherever a ``device_types`` argument is taken: ``None`` (meaning
# "all device types"), a single device-type string such as "cpu", or a sequence
# of such strings.
device_types_t = Union[str, Sequence[str], None]

# Logger for this module (named after the module).
log = logging.getLogger(name=__name__)
30
31
@exposed_in("torch.library")
def custom_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    device_types: device_types_t = None,
    schema: Optional[str] = None,
) -> Callable:
    """Wraps a function into a custom operator.

    Reasons why you may want to create a custom op include:
    - Wrapping a third-party library or custom kernel to work with PyTorch
    subsystems like Autograd.
    - Preventing torch.compile/export/FX tracing from peeking inside your function.

    This API is used as a decorator around a function (please see examples).
    The provided function must have type hints; these are needed to interface
    with PyTorch's various subsystems.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        fn (None | Callable): The function to wrap. If omitted, ``custom_op``
            returns a decorator that wraps the function it is applied to.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        device_types (None | str | Sequence[str]): The device type(s) the function
            is valid for. If no device type is provided, then the function
            is used as the default implementation for all device types.
            Examples: "cpu", "cuda".
            When registering a device-specific implementation for an operator that accepts no Tensors,
            we require the operator to have a ``device: torch.device`` argument.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".
            When a schema is given, the arguments it marks as written
            (e.g. ``Tensor(a!) x``) must be exactly the names in ``mutates_args``.

    Returns:
        A ``CustomOpDef`` wrapping ``fn``; or, if ``fn`` was not passed, a
        decorator that produces one.

    Raises:
        ValueError: If ``name`` is not of the form "{namespace}::{name}", or if
            an explicit ``schema`` disagrees with ``mutates_args``.

    .. note::
        We recommend not passing in a ``schema`` arg and instead letting us infer
        it from the type annotations. It is error-prone to write your own schema.
        You may wish to provide your own schema if our interpretation of
        the type annotation is not what you want.
        For more info on how to write a schema string, see
        `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func>`_

    Examples::
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> @custom_op("mylib::numpy_sin", mutates_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that only works for one device type.
        >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
        >>> def numpy_sin_cpu(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin_cpu(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that mutates an input
        >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
        >>> def numpy_sin_inplace(x: Tensor) -> None:
        >>>     x_np = x.numpy()
        >>>     np.sin(x_np, out=x_np)
        >>>
        >>> x = torch.randn(3)
        >>> expected = x.sin()
        >>> numpy_sin_inplace(x)
        >>> assert torch.allclose(x, expected)
        >>>
        >>> # Example of a factory function
        >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
        >>> def bar(device: torch.device) -> Tensor:
        >>>     return torch.ones(3)
        >>>
        >>> bar("cpu")

    """

    def inner(fn):
        # Validate the name up front: a malformed name would otherwise surface
        # as an opaque "not enough values to unpack" error.
        parts = name.split("::")
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"custom_op: expected `name` to look like \"{{namespace}}::{{name}}\" "
                f"(e.g. \"mylib::my_linear\"), but got {name!r}."
            )
        namespace, opname = parts

        if schema is None:
            schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
        else:
            schema_str = schema
            # Check that the schema's alias annotations match those of
            # `mutates_args`. This runs *before* CustomOpDef is constructed:
            # construction defines the op with the dispatcher, so failing
            # afterwards would leave a half-registered operator behind.
            parsed_schema = _C.parse_schema(opname + schema_str)
            expected = {
                arg.name
                for arg in parsed_schema.arguments
                if arg.alias_info is not None and arg.alias_info.is_write
            }
            if expected != set(mutates_args):
                raise ValueError(
                    f"Attempted to create a custom op with `mutates_args={mutates_args}` "
                    f"and `schema={schema}`. The schema suggests that the op mutates {expected} "
                    f"which is different from what was provided to us in `mutates_args`. "
                    f"Please make these consistent."
                )

        result = CustomOpDef(namespace, opname, schema_str, fn)
        result.register_kernel(device_types)(fn)
        return result

    # Support both decorator usage (`@custom_op(name, ...)`) and direct calls
    # (`custom_op(name, fn, ...)`).
    if fn is None:
        return inner
    return inner(fn)
158
159
160class CustomOpDef:
161    """CustomOpDef is a wrapper around a function that turns it into a custom op.
162
163    It has various methods for registering additional behavior for this
164    custom op.
165
166    You should not instantiate CustomOpDef directly; instead, use the
167    :func:`torch.library.custom_op` API.
168    """
169
170    def __init__(self, namespace: str, name: str, schema: str, fn: Callable) -> None:
171        # Fields used to interface with the PyTorch dispatcher
172        self._namespace = namespace
173        self._name = name
174        self._schema = schema
175
176        self._init_fn = fn
177
178        self._backend_fns: Dict[Union[str, None], Callable] = {}
179        self._abstract_fn: Optional[Callable] = None
180        self._setup_context_fn: Optional[Callable] = None
181        self._backward_fn: Optional[Callable] = None
182        self._torch_dispatch_fns: Dict[type, Callable] = {}
183        self._vmap_fn: Optional[Callable] = None
184
185        self._lib = get_library_allowing_overwrite(self._namespace, self._name)
186        self._register_to_dispatcher()
187        self._disabled_kernel: Set = set()
188        OPDEFS[self._qualname] = self
189
190    @property
191    def _qualname(self) -> str:
192        return f"{self._namespace}::{self._name}"
193
194    def __repr__(self) -> str:
195        return f"<CustomOpDef({self._qualname})>"
196
197    @contextmanager
198    def set_kernel_enabled(self, device_type: str, enabled: bool = True):
199        """
200        Disable or re-enable an already registered kernel for this custom operator.
201
202        If the kernel is already disabled/enabled, this is a no-op.
203
204        Note:
205            If a kernel is first disabled and then registered, it is disabled until enabled again.
206
207        Args:
208            device_type (str): The device type to disable/enable the kernel for.
209            disable (bool): Whether to disable or enable the kernel.
210
211        Example:
212            >>> inp = torch.randn(1)
213            >>>
214            >>> # define custom op `f`.
215            >>> @custom_op("mylib::f", mutates_args=())
216            >>> def f(x: Tensor) -> Tensor:
217            >>>     return torch.zeros(1)
218            >>>
219            >>> print(f(inp))  # tensor([0.]), default kernel
220            >>>
221            >>> @f.register_kernel("cpu")
222            >>> def _(x):
223            >>>     return torch.ones(1)
224            >>>
225            >>> print(f(inp))  # tensor([1.]), CPU kernel
226            >>>
227            >>> # temporarily disable the CPU kernel
228            >>> with f.set_kernel_enabled("cpu", enabled = False):
229            >>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled
230
231        """
232        action = "enable" if enabled else "disable"
233        originally_disabled = device_type in self._disabled_kernel
234        if device_type not in self._backend_fns:
235            log.warning(
236                "Attempted to %s kernel for %s but no kernel was registered for this device type.",
237                action,
238                device_type,
239            )
240
241        if not enabled:
242            if originally_disabled:
243                log.warning(
244                    "Attempted to disable kernel for %s but it was already disabled.",
245                    device_type,
246                )
247            else:
248                self._disabled_kernel.add(device_type)
249        else:  # enable the kernel
250            if not originally_disabled:
251                log.warning(
252                    "Attempted to enable kernel for  %s but it was already enabled.",
253                    device_type,
254                )
255            else:
256                self._disabled_kernel.remove(device_type)
257
258        try:
259            yield
260        finally:
261            # restore original state
262            if originally_disabled:
263                self._disabled_kernel.add(device_type)
264            else:
265                self._disabled_kernel.discard(device_type)
266
267    def register_kernel(
268        self, device_types: device_types_t, fn: Optional[Callable] = None, /
269    ) -> Callable:
270        """Register an implementation for a device type for this operator.
271
272        Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
273        This API may be used as a decorator.
274
275        Args:
276            fn (Callable): The function to register as the implementation for
277                the given device types.
278            device_types (str | Sequence[str]): The device device_types to register an impl to.
279
280        Examples::
281            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
282            >>> import torch
283            >>> from torch import Tensor
284            >>> from torch.library import custom_op
285            >>> import numpy as np
286            >>>
287            >>> # Create a custom op that works on cpu
288            >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
289            >>> def numpy_sin(x: Tensor) -> Tensor:
290            >>>     x_np = x.numpy()
291            >>>     y_np = np.sin(x_np)
292            >>>     return torch.from_numpy(y_np)
293            >>>
294            >>> # Add implementations for the cuda device
295            >>> @numpy_sin.register_kernel("cuda")
296            >>> def _(x):
297            >>>     x_np = x.cpu().numpy()
298            >>>     y_np = np.sin(x_np)
299            >>>     return torch.from_numpy(y_np).to(device=x.device)
300            >>>
301            >>> x_cpu = torch.randn(3)
302            >>> x_cuda = x_cpu.cuda()
303            >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
304            >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
305
306        """
307
308        def inner(fn):
309            if device_types is None or isinstance(device_types, str):
310                dtypes: List[Union[str, None]] = [device_types]
311            else:
312                dtypes = list(device_types)
313            for device_type in dtypes:
314                if device_type not in self._backend_fns:
315
316                    def backend_impl(*args, **kwargs):
317                        # Checks the assumption that outputs cannot alias
318                        # inputs or other outputs.
319                        storages = {
320                            id(tensor.untyped_storage())
321                            for tensor in iter_tensors(args, kwargs)
322                        }
323
324                        result = self._backend_fns[device_type](*args, **kwargs)
325
326                        tuple_result = result
327                        if not isinstance(result, tuple):
328                            tuple_result = (result,)
329                        for tensor in iter_tensors(tuple_result, {}):
330                            key = id(tensor.untyped_storage())
331                            if id(tensor.untyped_storage()) in storages:
332                                fn = self._backend_fns[device_type]
333                                module = inspect.getmodule(fn)
334                                raise RuntimeError(
335                                    f"{self._name} (with implementation in {module}): "
336                                    f"The output of this custom operator (1) must not "
337                                    f"also be an input to this custom operator and "
338                                    f"(2) may not alias any inputs to this custom operator "
339                                    f"or other returns. "
340                                    f"The most common way to trigger this error is if "
341                                    f"we have y = custom_op(x) and y and x are the same Tensor. "
342                                    f"Please instead return a clone of the offending output "
343                                    f"tensor(s) (e.g. return x.clone()) or refactor the custom "
344                                    f"operator to not return y."
345                                )
346                            storages.add(key)
347                        return result
348
349                    if device_type is None:
350                        self._lib.impl(
351                            self._name, backend_impl, "CompositeExplicitAutograd"
352                        )
353                    else:
354                        self._lib.impl(
355                            self._name,
356                            backend_impl,
357                            _C._dispatch_key_for_device(device_type),
358                        )
359
360                # Wrap function to choose between the default implementation or the device-specific
361                # implementation depending on if the kernel is disabled.
362                @torch._disable_dynamo
363                def wrapped_fn(*args, **kwargs):
364                    if device_type in self._disabled_kernel:
365                        return self._init_fn(*args, **kwargs)
366                    else:
367                        return fn(*args, **kwargs)
368
369                self._backend_fns[device_type] = wrapped_fn
370            return fn
371
372        if device_types is not None and not utils.has_tensor_arg(
373            self._opoverload._schema
374        ):
375            device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
376            if device_arg_index is None:
377                raise ValueError(
378                    "Functions without tensor inputs are required to have a `device: torch.device` argument"
379                )
380            self._register_backend_select_dispatcher(device_arg_index)
381
382        # See NOTE: [Supporting decorator and non-decorator usage]
383        if fn is None:
384            return inner
385        return inner(fn)
386
387    def register_fake(self, fn: Callable, /) -> Callable:
388        r"""Register a FakeTensor implementation for this custom op.
389
390        This is necessary to get the operator to work efficiently with torch.compile.
391
392        The Fake impl (sometimes also known as a meta kernel or abstract impl)
393        specifies the behavior of this operator on Tensors that carry no data.
394        Given some input Tensors with certain properties
395        (sizes/strides/storage_offset/device), it specifies what the properties of
396        the output Tensors are.
397
398        Please see :func:`torch.library.impl_abstract` for more details.
399
400        Args:
401            fn (Callable): The function to register as the FakeTensor
402                implementation.
403
404        Examples:
405            >>> import torch
406            >>> import numpy as np
407            >>> from torch import Tensor
408            >>>
409            >>> # Example 1: an operator without data-dependent output shape
410            >>> @torch.library.custom_op("mylib::linear", mutates_args=())
411            >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
412            >>>     return (x @ weight.t()) + bias
413            >>>
414            >>> @linear.register_fake
415            >>> def _(x, weight, bias):
416            >>>     assert x.dim() == 2
417            >>>     assert weight.dim() == 2
418            >>>     assert bias.dim() == 1
419            >>>     assert x.shape[1] == weight.shape[1]
420            >>>     assert weight.shape[0] == bias.shape[0]
421            >>>     assert x.device == weight.device
422            >>>     return x.new_empty(x.size(0), weight.size(0))
423            >>>
424            >>> x = torch.randn(2, 2)
425            >>> weight = torch.randn(2, 2)
426            >>> bias = torch.randn(2)
427            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
428            >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
429            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
430            >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
431            >>>
432            >>> # Example 2: an operator with data-dependent output shape
433            >>> @torch.library.custom_op("mylib::nonzero", mutates_args=())
434            >>> def nonzero(x: Tensor) -> Tensor:
435            >>>     x_np = x.cpu().numpy()
436            >>>     res = np.stack(np.nonzero(x_np), axis=1)
437            >>>     return torch.tensor(res, device=x.device)
438            >>>
439            >>> @nonzero.register_fake
440            >>> def _(x):
441            >>>     # Number of nonzero-elements is data-dependent.
442            >>>     # Since we cannot peek at the data in an abstract impl,
443            >>>     # we use the ctx object to construct a new symint that
444            >>>     # represents the data-dependent size.
445            >>>     ctx = torch.library.get_ctx()
446            >>>     nnz = ctx.new_dynamic_size()
447            >>>     shape = [nnz, x.dim()]
448            >>>     result = x.new_empty(shape, dtype=torch.int64)
449            >>>     return result
450            >>>
451            >>> x = torch.tensor([0, 1, 2, 0, 0, 1])
452            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
453            >>> out = torch.compile(nonzero, fullgraph=True)(x)
454            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
455            >>> assert torch.allclose(out, x.nonzero())
456
457        """
458        self._abstract_fn = fn
459        return fn
460
461    def register_torch_dispatch(
462        self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
463    ) -> Callable:
464        r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
465
466        This allows for open registration to specify the behavior between the operator
467        and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
468        or the operator directly.
469
470        Please see :func:`torch.library.register_torch_dispatch` for examples and more details.
471        """
472
473        def register(fn):
474            if torch_dispatch_class not in self._torch_dispatch_fns:
475
476                def inner(*args, **kwargs):
477                    return self._torch_dispatch_fns[torch_dispatch_class](
478                        *args, **kwargs
479                    )
480
481                self._lib._register_torch_dispatch_rule(
482                    self._name, torch_dispatch_class, inner
483                )
484            self._torch_dispatch_fns[torch_dispatch_class] = fn
485            return fn
486
487        if fn is None:
488            return register
489        else:
490            return register(fn)
491
492    def register_autograd(
493        self,
494        backward: Callable,
495        /,
496        *,
497        setup_context: Optional[Callable] = None,
498    ) -> None:
499        r"""Register a backward formula for this custom op.
500
501        In order for an operator to work with autograd, you need to register
502        a backward formula:
503        1. You must tell us how to compute gradients during the backward pass
504        by providing us a "backward" function.
505        2. If you need any values from the forward to compute gradients, you can
506        use `setup_context` to save values for backward.
507
508        ``backward_fn`` runs during the backward pass. It accepts ``(ctx, *grads)``:
509        - ``grads`` is one or more gradients. The number of gradients matches
510        the number of outputs of the operator.
511        The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
512        :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
513        same as :meth:`torch.autograd.Function.backward`.
514
515        ``setup_context(ctx, inputs, output)`` runs during the forward pass.
516        Please save quantities needed for backward onto the ``ctx`` object via
517        either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
518        or assigning them as attributes of ``ctx``. If your custom op has
519        kwarg-only arguments, we expect the signature of ``setup_context``
520        to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.
521
522        Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
523        they may not directly access :meth:`torch.Tensor.data_ptr` and they must
524        not depend on or mutate global state. If you need a non-traceable backward,
525        you can make it a separate custom_op that you call inside ``backward_fn``.
526
527        Examples:
528            >>> import torch
529            >>> import numpy as np
530            >>> from torch import Tensor
531            >>>
532            >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
533            >>> def numpy_sin(x: Tensor) -> Tensor:
534            >>>     x_np = x.cpu().numpy()
535            >>>     y_np = np.sin(x_np)
536            >>>     return torch.from_numpy(y_np).to(device=x.device)
537            >>>
538            >>> def setup_context(ctx, inputs, output) -> Tensor:
539            >>>     x, = inputs
540            >>>     ctx.save_for_backward(x)
541            >>>
542            >>> def backward(ctx, grad):
543            >>>     x, = ctx.saved_tensors
544            >>>     return grad * x.cos()
545            >>>
546            >>> numpy_sin.register_autograd(backward, setup_context=setup_context)
547            >>>
548            >>> x = torch.randn(3, requires_grad=True)
549            >>> y = numpy_sin(x)
550            >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
551            >>> assert torch.allclose(grad_x, x.cos())
552            >>>
553            >>> # Example with a keyword-only arg
554            >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
555            >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
556            >>>     x_np = x.cpu().numpy()
557            >>>     y_np = x_np * val
558            >>>     return torch.from_numpy(y_np).to(device=x.device)
559            >>>
560            >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
561            >>>     ctx.val = keyword_only_inputs["val"]
562            >>>
563            >>> def backward(ctx, grad):
564            >>>     return grad * ctx.val
565            >>>
566            >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
567            >>>
568            >>> x = torch.randn(3, requires_grad=True)
569            >>> y = numpy_mul(x, val=3.14)
570            >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
571            >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
572
573        """
574        schema = self._opoverload._schema
575        if not utils.is_functional_schema(schema):
576            raise RuntimeError(
577                f"Cannot register autograd formula for non-functional operator "
578                f"{self} with schema {schema}. Please create "
579                f"a functional operator and register an autograd formula for that."
580            )
581
582        self._backward_fn = backward
583        self._setup_context_fn = setup_context
584
585    def _register_to_dispatcher(self) -> None:
586        lib = self._lib
587        schema_str = self._name + self._schema
588        cpp_schema = _C.parse_schema(schema_str)
589        if utils.has_kwarg_only_tensors(cpp_schema):
590            # If you want to support this, the progression is:
591            # - supporting kwarg-only Tensors that are non-differentiable
592            # - supporting kwarg-only Tensors (regardless of differentiability)
593            raise NotImplementedError(
594                f"custom_op with kwarg-only Tensor args. Please make your "
595                f"tensors not kwarg-only. Got: {schema_str}"
596            )
597
598        lib.define(
599            schema_str,
600            tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order],
601        )
602        self._opoverload = utils.lookup_op(self._qualname)
603
604        def fake_impl(*args, **kwargs):
605            if self._abstract_fn is None:
606                if utils.can_generate_trivial_fake_impl(self._opoverload):
607                    return None
608                raise RuntimeError(
609                    f"There was no fake impl registered for {self}. "
610                    f"This is necessary for torch.compile/export/fx tracing to work. "
611                    f"Please use `{self._init_fn.__name__}.register_fake` to add an "
612                    f"fake impl."
613                )
614            return self._abstract_fn(*args, **kwargs)
615
616        lib._register_fake(self._name, fake_impl, _stacklevel=4)
617
618        autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
619        lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
620
621        schema = self._opoverload._schema
622        if schema.is_mutable:
623
624            def adinplaceorview_impl(keyset, *args, **kwargs):
625                for arg, val in utils.zip_schema(schema, args, kwargs):
626                    if not arg.alias_info:
627                        continue
628                    if not arg.alias_info.is_write:
629                        continue
630                    if isinstance(val, Tensor):
631                        torch.autograd.graph.increment_version(val)
632                    elif isinstance(val, (tuple, list)):
633                        for v in val:
634                            if isinstance(v, Tensor):
635                                torch.autograd.graph.increment_version(v)
636                with _C._AutoDispatchBelowADInplaceOrView():
637                    return self._opoverload.redispatch(
638                        keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
639                    )
640
641            lib.impl(
642                self._name,
643                adinplaceorview_impl,
644                "ADInplaceOrView",
645                with_keyset=True,
646            )
647
648    def _register_backend_select_dispatcher(self, device_arg_index: int):
649        """
650        Switch on the device argument to select the correct backend to dispatch to.
651        """
652
653        def backend_select(keyset, *args, **kwargs):
654            device = args[device_arg_index].type
655            if device not in self._backend_fns:
656                raise RuntimeError(
657                    f"{self._name} does not have a kernel registered for {device}. "
658                    "Please use register_kernel to do so."
659                )
660            dispatch_key = _C._dispatch_key_for_device(device)
661            dispatch_key = getattr(_C.DispatchKey, dispatch_key)
662            return self._opoverload.redispatch(
663                _C.DispatchKeySet(dispatch_key), *args, **kwargs
664            )
665
666        self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True)
667
668    def __call__(self, *args, **kwargs):
669        return self._opoverload(*args, **kwargs)
670
671    def register_vmap(
672        self,
673        func: Optional[Callable] = None,
674    ):
675        r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
676
677        This API may be used as a decorator.
678
679        In order for an operator to work with :func:`torch.vmap`, you may need to register a
680        vmap implementation in the following signature:
681
682            ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
683
684        where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
685
686        It specifies how do we compute the batched version of ``op`` given inputs with an additional
687        dimension (specified by ``in_dims``).
688
689        For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
690        if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
691        specifying what dimension of the Tensor is being vmapped over.
692
693        ``info`` is a collection of additional metadata that may be helpful:
694        ``info.batch_size`` specifies the size of the dimension being vmapped over, while
695        ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
696
697        The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
698        ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
699        per output that specifies if the output has the vmapped dimension and what index it is in.
700
701        Examples:
702            >>> import torch
703            >>> import numpy as np
704            >>> from torch import Tensor
705            >>> from typing import Tuple
706            >>>
707            >>> def to_numpy(tensor):
708            >>>     return tensor.cpu().numpy()
709            >>>
710            >>> lib = torch.library.Library("mylib", "FRAGMENT")
711            >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
712            >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
713            >>>     x_np = to_numpy(x)
714            >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
715            >>>     return torch.tensor(x_np ** 3, device=x.device), dx
716            >>>
717            >>> def numpy_cube_vmap(info, in_dims, x):
718            >>>     result = numpy_cube(x)
719            >>>     return result, (in_dims[0], in_dims[0])
720            >>>
721            >>> numpy_cube.register_vmap(numpy_cube_vmap)
722            >>>
723            >>> x = torch.randn(3)
724            >>> torch.vmap(numpy_cube)(x)
725            >>>
726            >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
727            >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
728            >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
729            >>>
730            >>> @numpy_mul.register_vmap
731            >>> def numpy_mul_vmap(info, in_dims, x, y):
732            >>>     x_bdim, y_bdim = in_dims
733            >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
734            >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
735            >>>     result = x * y
736            >>>     result = result.movedim(-1, 0)
737            >>>     return result, 0
738            >>>
739            >>>
740            >>> x = torch.randn(3)
741            >>> y = torch.randn(3)
742            >>> torch.vmap(numpy_mul)(x, y)
743        """
744        from torch._functorch.autograd_function import custom_function_call_vmap_helper
745        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
746
747        def register(func):
748            need_register = self._vmap_fn is None
749            self._vmap_fn = func
750
751            if need_register:
752
753                def wrapped_func(keyset, *args, **kwargs):
754                    interpreter = retrieve_current_functorch_interpreter()
755                    return custom_function_call_vmap_helper(
756                        interpreter, self._vmap_fn, self._opoverload, *args, **kwargs
757                    )
758
759                self._lib.impl(
760                    self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
761                )
762
763        if func is None:
764            return register
765        else:
766            return register(func)
767
768
769# NOTE: [Supporting decorator and non-decorator usage]
770#
771# Some APIs may be both used as a decorator and not as a decorator.
772# For example:
773#
774# >>> def fn(x):
775# >>>     return x.sin()
776# >>>
777# >>> # Usage 1: not as a decorator
778# >>> numpy_sin.register_kernel("cuda", fn)
779# >>>
780# >>> # Usage 2: as a decorator
781# >>> @numpy_sin.register_kernel("cuda")
782# >>> def fn2(x):
# >>>     return x.sin()
784#
785# The way we support this is that `register_kernel` accepts an optional `fn`.
786# If `fn` is provided (Usage 1), then we know that the user is using it not
787# as a decorator.
788# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a
789# decorator.
790
791
792OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {}
793OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
794
795
796def get_library_allowing_overwrite(
797    namespace: str, name: str
798) -> "torch.library.Library":
799    qualname = f"{namespace}::{name}"
800
801    if qualname in OPDEF_TO_LIB:
802        OPDEF_TO_LIB[qualname]._destroy()
803        del OPDEF_TO_LIB[qualname]
804
805    lib = torch.library.Library(namespace, "FRAGMENT")  # noqa: TOR901
806    OPDEF_TO_LIB[qualname] = lib
807    return lib
808
809
810def iter_tensors(
811    args: Tuple[Any], kwargs: Dict[str, Any], allowed_nesting: int = 1
812) -> Iterator[Tensor]:
813    def check(arg):
814        if isinstance(arg, Tensor):
815            yield arg
816        elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
817            yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)
818
819    for arg in args:
820        yield from check(arg)
821    for kwarg in kwargs.values():
822        yield from check(kwarg)
823
824
825def _maybe_get_opdef(
826    op: Union[CustomOpDef, _ops.OpOverload, str]
827) -> Optional[CustomOpDef]:
828    if isinstance(op, CustomOpDef):
829        return op
830    if isinstance(op, _ops.OpOverload):
831        op = op._name
832    assert isinstance(op, str)
833    if op in OPDEFS:
834        return OPDEFS[op]
835    return None
836