xref: /aosp_15_r20/external/pytorch/torch/library.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import contextlib
import functools
import inspect
import re
import sys
import traceback
import weakref
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import deprecated

import torch
import torch._library as _library
from torch._library.custom_ops import (
    _maybe_get_opdef,
    custom_op,
    CustomOpDef,
    device_types_t,
)
from torch._library.infer_schema import infer_schema  # noqa: F401
from torch._ops import OpOverload


__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "register_fake",
    "register_torch_dispatch",
    "register_vmap",
    "get_ctx",
    "custom_op",
    "infer_schema",
]

# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality;
# otherwise a library could end up calling into kernels that were not intended to be called.
_impls: Set[str] = set()
_defs: Set[str] = set()

# prim is reserved by the TorchScript interpreter
_reserved_namespaces = ["prim"]


def fallthrough_kernel():
    """
    A dummy function to pass to ``Library.impl`` in order to register a fallthrough.
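
    A minimal usage sketch (the namespace and dispatch key here are
    illustrative): registering ``fallthrough_kernel`` for a dispatch key makes
    the dispatcher skip that key and fall through to the next one.

    Example::
        >>> my_lib = Library("aten", "IMPL")
        >>> my_lib.impl("div.Tensor", fallthrough_kernel, "Autocast")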
    """
    raise NotImplementedError("fallthrough_kernel() should never be called.")


class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch key name if they only want to register
    kernels corresponding to one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".
    To create a fragment of a possibly existing library to register operators (and bypass
    the limitation that there is only one library for a given namespace), set the kind to
    "FRAGMENT".

    Args:
        ns: library name
        kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
        dispatch_key: PyTorch dispatch key (default: "")
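
    A minimal usage sketch (the ``mylib`` namespace and schema here are
    illustrative):

    Example::
        >>> my_lib = Library("mylib", "FRAGMENT")
        >>> my_lib.define("sum(Tensor self) -> Tensor")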
    """

    def __init__(self, ns, kind, dispatch_key=""):
        if kind not in ("IMPL", "DEF", "FRAGMENT"):
            raise ValueError(f"Unsupported kind: {kind}")

        if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
            raise ValueError(
                f"{ns} is a reserved namespace. Please try creating a library with another name."
            )

        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m: Optional[Any] = torch._C._dispatch_library(
            kind, ns, dispatch_key, filename, lineno
        )
        self.ns = ns
        self._op_defs: Set[str] = set()
        self._op_impls: Set[str] = set()
        self._registration_handles: List[torch._library.utils.RegistrationHandle] = []
        self.kind = kind
        self.dispatch_key = dispatch_key
        # Use a finalizer to set up the "destructor" instead of __del__.
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). Finalizers help the
        # situation because they let us capture references and keep them alive.
        weakref.finalize(
            self,
            _del_library,
            _impls,
            self._op_impls,
            _defs,
            self._op_defs,
            self._registration_handles,
        )

    def __repr__(self):
        return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})"

    def define(self, schema, alias_analysis="", *, tags=()):
        r"""Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                                       inferred from the schema (default behavior) or not ("CONSERVATIVE").
            tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
                                       operator. Tagging an operator changes the operator's behavior
                                       under various PyTorch subsystems; please read the docs for the
                                       torch.Tag carefully before applying it.

        Returns:
            name of the operator as inferred from the schema.

        Example::
            >>> my_lib = Library("mylib", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        """
        # This check also disallows PURE_FUNCTION alias analysis, which is a valid
        # AliasAnalysis type in C++.
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
        assert self.m is not None
        if isinstance(tags, torch.Tag):
            tags = (tags,)

        name = schema.split("(")[0]
        packet_name = name.split(".")[0] if "." in name else name
        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
            getattr(torch.ops, self.ns), packet_name
        )

        result = self.m.define(schema, alias_analysis, tuple(tags))
        qualname = self.ns + "::" + name

        # If the OpOverloadPacket exists already, then this means we're adding a
        # new OpOverload for it. Refresh the packet to include the new OpOverload.
        if has_preexisting_packet:
            ns = getattr(torch.ops, self.ns)
            packet = getattr(ns, packet_name)
            torch._ops._refresh_packet(packet)

        self._op_defs.add(qualname)
        _defs.add(qualname)
        return result

    def _register_fake(self, op_name, fn, _stacklevel=1):
        r"""Registers the fake impl for an operator defined in the library."""
        source = torch._library.utils.get_source(_stacklevel + 1)
        frame = sys._getframe(_stacklevel)
        caller_module = inspect.getmodule(frame)
        # Can be None if you call register_fake from somewhere there isn't a module
        # (e.g. __main__)
        caller_module_name = None if caller_module is None else caller_module.__name__

        # TODO(rzou): We're gonna need to stage this change with torchvision,
        # since torchvision is github first.
        if caller_module_name is not None and caller_module_name.startswith(
            "torchvision."
        ):
            caller_module_name = None

        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        if caller_module_name is not None:
            func_to_register = _check_pystubs_once(fn, qualname, caller_module_name)
        else:
            func_to_register = fn

        handle = entry.fake_impl.register(func_to_register, source)
        self._registration_handles.append(handle)

    def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn):
        r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class.

        This allows for open registration to specify the behavior between the operator
        and the torch_dispatch_class without needing to modify the torch_dispatch_class
        or the operator directly.

        The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a
        TorchDispatchMode.

        If it is a Tensor subclass, we expect fn to have the following signature:
        (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

        If it is a TorchDispatchMode, we expect fn to have the following signature:
        (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
        """
        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn)
        self._registration_handles.append(handle)

    def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
        r"""Register the operator to use the AOTI-compiled implementation.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key
        assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
                "as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        assert self.m is not None
        impl_fn: Callable = self.m.impl_with_aoti_compile
        impl_fn(self.ns, name.split("::")[-1], dispatch_key)

        _impls.add(key)
        self._op_impls.add(key)

    def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                         to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        """
        if not callable(fn):
            raise TypeError(
                f"Input function is required to be a callable but found type {type(fn)}"
            )
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "impl should be passed either a name or an OpOverload object as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if "::" not in dispatcher_op_name:
                dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(
                dispatcher_op_name, "CompositeImplicitAutograd"
            ):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into."
                )

        assert self.m is not None
        self.m.impl(
            name,
            dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
            fn,
            with_keyset,
        )

        _impls.add(key)
        self._op_impls.add(key)

    def fallback(self, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation as the fallback for the given key.

        This function only works for a library with global namespace ("_").

        Args:
            fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                         to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Example::
            >>> my_lib = Library("_", "IMPL")
            >>> def fallback_kernel(op, *args, **kwargs):
            >>>     # Handle all autocast ops generically
            >>>     # ...
            >>> my_lib.fallback(fallback_kernel, "Autocast")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if self.ns != "_":
            raise RuntimeError(
                f"""Fallback can only be registered using a library fragment on the global namespace "_" but it is {self.ns}"""
            )

        assert dispatch_key != ""
        assert self.m is not None

        self.m.fallback(dispatch_key, fn, with_keyset)

    def _destroy(self):
        if self.m is not None:
            self.m.reset()
        self.m = None
        for handle in self._registration_handles:
            handle.destroy()
        self._registration_handles.clear()
        global _impls
        _impls -= self._op_impls
        for name in self._op_defs:
            # Delete the cached torch.ops.ns.foo if it was registered.
            # Otherwise, accessing it leads to a segfault.
            # It's possible that we only registered an overload in this Library
            # and another library owns an alive overload.
            # That's OK - the next time torch.ops.ns.foo gets called, it'll be
            # recomputed to point at the right collection of overloads.
            ns, name_with_overload = name.split("::")
            name = name_with_overload.split(".")[0]
            if not hasattr(torch.ops, ns):
                continue
            namespace = getattr(torch.ops, ns)
            if not hasattr(namespace, name):
                continue
            delattr(namespace, name)


def _del_library(
    captured_impls,
    op_impls,
    captured_defs,
    op_defs,
    registration_handles,
):
    captured_impls -= op_impls
    captured_defs -= op_defs
    for handle in registration_handles:
        handle.destroy()


@contextlib.contextmanager
def _scoped_library(*args, **kwargs):
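    """Create a :class:`Library` that is destroyed when the ``with`` block exits.

    A minimal usage sketch (the namespace and schema here are illustrative);
    any definitions and registrations made through ``lib`` are removed on exit:

    Example::
        >>> with _scoped_library("mylib", "FRAGMENT") as lib:
        >>>     lib.define("sum(Tensor self) -> Tensor")
    """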
    try:
        lib = Library(*args, **kwargs)
        yield lib
    finally:
        lib._destroy()


_keep_alive: List[Library] = []


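# Matches a schema with the operator name omitted, e.g. "(Tensor x) -> Tensor";
# torch.library.define() receives the name separately via `qualname`.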
NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")


@functools.singledispatch
def define(qualname, schema, *, lib=None, tags=()):
    r"""Defines a new operator.

    In PyTorch, defining an op (short for "operator") is a two-step process:
    - we need to define the op (by providing an operator name and schema)
    - we need to implement behavior for how the operator interacts with
    various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the custom operator (the first step);
    you must then perform the second step by calling various
    ``impl_*`` APIs, like :func:`torch.library.impl` or
    :func:`torch.library.register_fake`.

    Args:
        qualname (str): The qualified name for the operator. Should be
            a string that looks like "namespace::name", e.g. "aten::sin".
            Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module.
        schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor"
            for an op that accepts one Tensor and returns one Tensor. It does
            not contain the operator name (that is passed in ``qualname``).
        lib (Optional[Library]): If provided, the lifetime of this operator
            will be tied to the lifetime of the Library object.
        tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
            operator. Tagging an operator changes the operator's behavior
            under various PyTorch subsystems; please read the docs for the
            torch.Tag carefully before applying it.

    Example::
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the operator
        >>> @torch.library.impl("mylib::sin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Call the new operator from torch.ops.
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.sin(x)
        >>> assert torch.allclose(y, x.sin())

    """
    if not isinstance(qualname, str):
        raise ValueError(
            f"define(qualname, schema): expected qualname "
            f"to be instance of str, got {type(qualname)}"
        )
    namespace, name = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)
    if not NAMELESS_SCHEMA.fullmatch(schema):
        raise ValueError(
            f"define(qualname, schema, ...): expected schema "
            f'to look like e.g. "(Tensor x) -> Tensor" but '
            f'got "{schema}"'
        )
    lib.define(name + schema, alias_analysis="", tags=tags)


@define.register
def _(lib: Library, schema, alias_analysis=""):
    """The old torch.library.define.
    We're keeping this around for BC reasons
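
    A minimal usage sketch (the ``Library`` instance, schema, and kernel here
    are illustrative): the decorator defines the op and registers the decorated
    function as its implementation.

    Example::
        >>> my_lib = Library("mylib", "DEF")
        >>> @torch.library.define(my_lib, "sum2(Tensor self) -> Tensor")
        >>> def sum2(self):
        >>>     return self.sum()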
    """

    def wrap(f):
        name = lib.define(schema, alias_analysis)
        lib.impl(name, f)
        return f

    return wrap


@functools.singledispatch
def impl(qualname, types, func=None, *, lib=None):
    """Register an implementation for a device type for this operator.

    You may pass "default" for ``types`` to register this implementation as the
    default implementation for ALL device types.
    Please only use this if the implementation truly supports all device types;
    for example, this is true if it is a composition of built-in PyTorch operators.

    Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".

    Args:
        qualname (str): Should be a string that looks like "namespace::operator_name".
        types (str | Sequence[str]): The device types to register an impl to.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the cpu device
        >>> @torch.library.impl("mylib::mysin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.mysin(x)
        >>> assert torch.allclose(y, x.sin())
    """
    return _impl(qualname, types, func, lib=lib, disable_dynamo=False)


def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
    if isinstance(types, str):
        types = (types,)
    keys = set()
    for typ in types:
        is_dispatch_key = torch._C._parse_dispatch_key(typ)
        if is_dispatch_key:
            # We also support passing a DispatchKey to impl. Please prefer using
            # the higher-level torch.library APIs and only pass DispatchKey to
            # torch.library.impl with caution (or even better, don't use this
            # option and file an issue on GitHub for what you need).
            # We don't advertise this to users because
            # it is very easy to shoot yourself in the foot.
            keys.add(typ)
        else:
            keys.add(_device_type_to_key(typ))

    def register(func):
        namespace, _ = torch._library.utils.parse_namespace(qualname)

        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        if disable_dynamo:

            @torch._disable_dynamo
            def func_no_dynamo(*args, **kwargs):
                return func(*args, **kwargs)

            for key in keys:
                use_lib.impl(qualname, func_no_dynamo, key)
        else:
            for key in keys:
                use_lib.impl(qualname, func, key)

    if func is None:
        return register
    else:
        register(func)


def _device_type_to_key(device_type: str) -> str:
    if device_type == "default":
        # This is technically not correct, because although all device_type
        # DispatchKeys are included in CompositeExplicitAutograd,
        # not everything in CompositeExplicitAutograd is associated with a
        # device_type. I don't really care that much about the difference.
        return "CompositeExplicitAutograd"
    return torch._C._dispatch_key_for_device(device_type)


@impl.register
def _(lib: Library, name, dispatch_key=""):
    """Legacy torch.library.impl API. Kept around for BC.
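
    A minimal usage sketch (the ``Library`` instance and kernel here are
    illustrative):

    Example::
        >>> my_lib = Library("aten", "IMPL")
        >>> @torch.library.impl(my_lib, "div.Tensor", "CPU")
        >>> def div_cpu(self, other):
        >>>     return self * (1 / other)
    """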

    def wrap(f):
        lib.impl(name, f, dispatch_key)
        return f

    return wrap


@deprecated(
    "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that "
    "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.",
    category=FutureWarning,
)
def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
    r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4.
    Please use that instead.
    """
    if func is not None:
        _stacklevel = _stacklevel + 1
    return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)


_op_identifier = Union[
    str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
]


def register_kernel(
    op: _op_identifier,
    device_types: device_types_t,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
):
    """Register an implementation for a device type for this operator.

    Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
    This API may be used as a decorator.

    Args:
        op (str | OpOverload): The operator to register an impl to.
        device_types (None | str | Sequence[str]): The device_types to register an impl to.
            If None, we will register to all device types -- please only use
            this option if your implementation is truly device-type-agnostic.
        func (Callable): The function to register as the implementation for
            the given device types.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> # Create a custom op that works on cpu
        >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> # Add implementations for the cuda device
        >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
        >>> def _(x):
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x_cpu = torch.randn(3)
        >>> x_cuda = x_cpu.cuda()
        >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
        >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())

    """

    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_kernel(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_kernel(device_types, func)
    assert isinstance(op, str)
    if device_types is None:
        device_types = "CompositeExplicitAutograd"

    return _impl(op, device_types, func, lib=lib, disable_dynamo=True)


def register_fake(
    op: _op_identifier,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
    _stacklevel: int = 1,
):
    r"""Register a FakeTensor implementation ("fake impl") for this operator.

    Also sometimes known as a "meta kernel" or "abstract impl".

    A "FakeTensor implementation" specifies the behavior of this operator on
    Tensors that carry no data ("FakeTensor"). Given some input Tensors with
    certain properties (sizes/strides/storage_offset/device), it specifies
    what the properties of the output Tensors are.

    The FakeTensor implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write a FakeTensor
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The FakeTensor implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError("Implementation goes here")
        >>>
        >>> @torch.library.register_fake("mylib::custom_linear")
        >>> def _(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> with torch._subclasses.fake_tensor.FakeTensorMode():
        >>>     x = torch.randn(2, 3)
        >>>     w = torch.randn(3, 3)
        >>>     b = torch.randn(3)
        >>>     y = torch.ops.mylib.custom_linear(x, w, b)
        >>>
        >>> assert y.shape == (2, 3)
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy(force=True)
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     return torch.tensor(res, device=x.device)
        >>>
        >>> @torch.library.register_fake("mylib::custom_nonzero")
        >>> def _(x):
        >>>     # Number of nonzero-elements is data-dependent.
        >>>     # Since we cannot peek at the data in a fake impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch.library.get_ctx()
        >>>     nnz = ctx.new_dynamic_size()
        >>>     shape = [nnz, x.dim()]
        >>>     result = x.new_empty(shape, dtype=torch.int64)
        >>>     return result
        >>>
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>>
        >>> x = torch.tensor([0, 1, 2, 3, 4, 0])
        >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
        >>> trace.print_readable()
        >>>
        >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_fake(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        if func is None:
            return opdef.register_fake
        else:
            return opdef.register_fake(func)
    assert isinstance(op, str)

    stacklevel = _stacklevel

    def register(func):
        namespace, op_name = torch._library.utils.parse_namespace(op)
        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1)
        return func

    if func is None:
        return register
    else:
        stacklevel += 1
        return register(func)


def register_autograd(
    op: _op_identifier,
    backward: Callable,
    /,
    *,
    setup_context: Optional[Callable] = None,
    lib=None,
) -> None:
    r"""Register a backward formula for this custom op.

    In order for an operator to work with autograd, you need to register
    a backward formula:
    1. You must tell us how to compute gradients during the backward pass
    by providing us a "backward" function.
    2. If you need any values from the forward to compute gradients, you can
    use `setup_context` to save values for backward.

    ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``:
    - ``grads`` is one or more gradients. The number of gradients matches
    the number of outputs of the operator.
    The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
    :class:`torch.autograd.Function`. The semantics of ``backward`` are the
    same as :meth:`torch.autograd.Function.backward`.

    ``setup_context(ctx, inputs, output)`` runs during the forward pass.
    Please save quantities needed for backward onto the ``ctx`` object via
    either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
    or assigning them as attributes of ``ctx``. If your custom op has
    kwarg-only arguments, we expect the signature of ``setup_context``
    to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.

    Both ``setup_context`` and ``backward`` must be traceable. That is,
    they may not directly access :meth:`torch.Tensor.data_ptr` and they must
    not depend on or mutate global state. If you need a non-traceable backward,
    you can make it a separate custom_op that you call inside ``backward``.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> def setup_context(ctx, inputs, output) -> None:
        >>>     x, = inputs
        >>>     ctx.save_for_backward(x)
        >>>
        >>> def backward(ctx, grad):
        >>>     x, = ctx.saved_tensors
        >>>     return grad * x.cos()
        >>>
        >>> torch.library.register_autograd(
        ...     "mylib::numpy_sin", backward, setup_context=setup_context
        ... )
        >>>
        >>> x = torch.randn(3, requires_grad=True)
        >>> y = numpy_sin(x)
        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
        >>> assert torch.allclose(grad_x, x.cos())
        >>>
        >>> # Example with a keyword-only arg
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = x_np * val
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> None:
        >>>     ctx.val = keyword_only_inputs["val"]
        >>>
        >>> def backward(ctx, grad):
        >>>     return grad * ctx.val
        >>>
        >>> torch.library.register_autograd(
        ...     "mylib::numpy_mul", backward, setup_context=setup_context
        ... )
        >>>
        >>> x = torch.randn(3, requires_grad=True)
        >>> y = numpy_mul(x, val=3.14)
        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
        >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_autograd(op): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        opdef.register_autograd(backward, setup_context=setup_context)
        return

    assert isinstance(op, str)
    qualname = op
    op = torch._library.utils.lookup_op(qualname)
    schema = op._schema
    if not _library.utils.is_functional_schema(schema):
        raise RuntimeError(
            f"Cannot register autograd formula for non-functional operator "
            f"{op} with schema {schema}. Please create "
            f"a functional operator and register an autograd formula for that."
        )
    if _library.utils.has_kwarg_only_tensors(schema):
        raise NotImplementedError(
            f"register_autograd with kwarg-only Tensor args. In the original "
            f"definition of the op, please make your tensors not kwarg-only. "
            f"Got: {schema}"
        )

    info = _library.autograd.Info(backward, setup_context)
    autograd_kernel = _library.autograd.make_autograd_impl(op, info)
    namespace, opname = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)
    lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True)


def register_torch_dispatch(
    op: _op_identifier,
    torch_dispatch_class: Any,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
):
    r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.

    This allows for open registration to specify the behavior between the operator
    and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
    or the operator directly.

    The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a
    TorchDispatchMode.

    If it is a Tensor subclass, we expect ``func`` to have the following signature:
    ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``

    If it is a TorchDispatchMode, we expect ``func`` to have the following signature:
    ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``

    ``args`` and ``kwargs`` will have been normalized the same way they are
    in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`).

    Examples:

        >>> import torch
        >>>
        >>> @torch.library.custom_op("mylib::foo", mutates_args={})
        >>> def foo(x: torch.Tensor) -> torch.Tensor:
        >>>     return x.clone()
        >>>
        >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
        >>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        >>>         return func(*args, **kwargs)
        >>>
        >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
        >>> def _(mode, func, types, args, kwargs):
        >>>     x, = args
        >>>     return x + 1
        >>>
        >>> x = torch.randn(3)
        >>> y = foo(x)
        >>> assert torch.allclose(y, x)
        >>>
        >>> with MyMode():
        >>>     y = foo(x)
        >>> assert torch.allclose(y, x + 1)

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_torch_dispatch(op): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_torch_dispatch(torch_dispatch_class, func)
    assert isinstance(op, str)

    def register(func):
        namespace, op_name = torch._library.utils.parse_namespace(op)
        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func)
        return func

    if func is None:
        return register
    else:
        return register(func)


def register_vmap(
    op: _op_identifier,
    func: Optional[Callable] = None,
    /,
    *,
    lib=None,
):
    r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.

    This API may be used as a decorator (see examples).

    In order for an operator to work with :func:`torch.vmap`, you may need to register a
    vmap implementation in the following signature:

        ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,

    where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
    We do not support kwarg-only Tensor args.

    It specifies how we compute the batched version of ``op`` given inputs with an additional
    dimension (specified by ``in_dims``).

    For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
    if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer
    specifying what dimension of the Tensor is being vmapped over.

    ``info`` is a collection of additional metadata that may be helpful:
    ``info.batch_size`` specifies the size of the dimension being vmapped over, while
    ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.

    The function ``func`` should return a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
    ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
    per output that specifies if the output has the vmapped dimension and at which index it appears.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>> from typing import Tuple
        >>>
        >>> def to_numpy(tensor):
        >>>     return tensor.cpu().numpy()
        >>>
        >>> lib = torch.library.Library("mylib", "FRAGMENT")
        >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
        >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
        >>>     x_np = to_numpy(x)
        >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
        >>>     return torch.tensor(x_np ** 3, device=x.device), dx
        >>>
        >>> def numpy_cube_vmap(info, in_dims, x):
        >>>     result = numpy_cube(x)
        >>>     return result, (in_dims[0], in_dims[0])
        >>>
        >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
        >>>
        >>> x = torch.randn(3)
        >>> torch.vmap(numpy_cube)(x)
        >>>
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
        >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
        >>>
        >>> @torch.library.register_vmap("mylib::numpy_mul")
        >>> def numpy_mul_vmap(info, in_dims, x, y):
        >>>     x_bdim, y_bdim = in_dims
        >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        >>>     result = x * y
        >>>     result = result.movedim(-1, 0)
        >>>     return result, 0
        >>>
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.randn(3)
        >>> torch.vmap(numpy_mul)(x, y)

    .. note::
        The vmap function should aim to preserve the semantics of the entire custom operator.
        That is, ``grad(vmap(op))`` should be replaceable with ``grad(map(op))``.

        If your custom operator has any custom behavior in the backward pass, please
        keep this in mind.

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_vmap(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_vmap(func)
    assert isinstance(op, str)
    qualname = op
    op = torch._library.utils.lookup_op(qualname)
    schema = op._schema
    if _library.utils.has_kwarg_only_tensors(schema):
        raise NotImplementedError(
            f"register_vmap with kwarg-only Tensor args. In the original "
            f"definition of the op, please make your tensors not kwarg-only. "
            f"Got: {schema}"
        )

    def register(func):
        nonlocal op, lib

        namespace, opname = torch._library.utils.parse_namespace(qualname)
        if lib is None:
            lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(lib)

        from torch._functorch.autograd_function import custom_function_call_vmap_helper
        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter

        def wrapped_func(keyset, *args, **kwargs):
            interpreter = retrieve_current_functorch_interpreter()
            return custom_function_call_vmap_helper(
                interpreter, func, op, *args, **kwargs
            )

        lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True)

    if func is None:
        return register
    else:
        return register(func)


# If the op was defined in C++, then we want to make sure there was an
# m.set_python_module(module, ...) call and that the module is the
# same as the module that called torch.library.register_fake.
def _check_pystubs_once(func, qualname, actual_module_name):
    checked = False

    def inner(*args, **kwargs):
        nonlocal checked
        if checked:
            return func(*args, **kwargs)

        op = torch._library.utils.lookup_op(qualname)
        if op._defined_in_python:
            checked = True
            return func(*args, **kwargs)

        maybe_pystub = torch._C._dispatch_pystub(
            op._schema.name, op._schema.overload_name
        )
        if maybe_pystub is None:
            if torch._library.utils.requires_set_python_module():
                namespace = op.namespace
                cpp_filename = op._handle.debug()
                raise RuntimeError(
                    f"Operator '{qualname}' was defined in C++ and has a Python "
                    f"fake impl. In this situation, we require there to also be a "
                    f'companion C++ `m.set_python_module("{actual_module_name}")` '
                    f"call, but we could not find one. Please add that to "
                    f"the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
                    f"operator was registered in ({cpp_filename})"
                )
        else:
            pystub_module = maybe_pystub[0]
            if actual_module_name != pystub_module:
                cpp_filename = op._handle.debug()
                raise RuntimeError(
                    f"Operator '{qualname}' specified that its python fake impl "
                    f"is in the Python module '{pystub_module}' but it was actually found "
                    f"in '{actual_module_name}'. Please either move the fake impl "
                    f"or correct the m.set_python_module call ({cpp_filename})"
                )
        checked = True
        return func(*args, **kwargs)

    return inner


# NOTE [ctx inside the fake implementation]
# If a user has an operator with data-dependent output shape, then when writing
# a fake implementation they must query the current ctx and use methods on the
# ctx to construct a new unbacked symint.
#
# This is done via us setting the global_ctx_getter function every time a fake
# implementation is invoked.
def get_ctx() -> "torch._library.fake_impl.FakeImplCtx":
    """get_ctx() returns the current FakeImplCtx object.

    Calling ``get_ctx()`` is only valid inside of a fake impl
    (see :func:`torch.library.register_fake` for more usage details).
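
    A minimal sketch (the op name is illustrative): inside a fake impl for an
    operator with data-dependent output shape, the ctx can mint a new
    data-dependent symint.

    Example::
        >>> @torch.library.register_fake("mylib::custom_nonzero")
        >>> def _(x):
        >>>     ctx = torch.library.get_ctx()
        >>>     nnz = ctx.new_dynamic_size()
        >>>     return x.new_empty([nnz, x.dim()], dtype=torch.int64)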
    """
    return torch._library.fake_impl.global_ctx_getter()


_OPCHECK_DEFAULT_UTILS = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


def opcheck(
    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
    raise_exception: bool = True,
) -> Dict[str, str]:
    """Given an operator and some sample arguments, tests if the operator is
    registered correctly.

    That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
    custom op, you specify metadata (e.g. mutability info) about the custom op,
    and these APIs require that the functions you pass them satisfy certain
    properties (e.g. no data pointer access in the fake/meta/abstract kernel).
    ``opcheck`` tests this metadata and these properties.

    Concretely, we test the following:

    - test_schema: If the schema matches the implementation of
      the operator. For example: if the schema specifies a Tensor is mutated,
      then we check the implementation mutates the Tensor. If the schema
      specifies that we return a new Tensor, then we check that the
      implementation returns a new Tensor (instead of an existing one or
      a view of an existing one).
    - test_autograd_registration: If the operator supports training
      (autograd): we check that its autograd formula is registered via
      torch.library.register_autograd or a manual registration to one
      or more DispatchKey::Autograd keys. Any other DispatchKey-based
      registrations may lead to undefined behavior.
    - test_faketensor: If the operator has a FakeTensor kernel
      (and if it is correct). The FakeTensor kernel is necessary
      (but not sufficient) for the operator to work with PyTorch compilation
      APIs (torch.compile/export/FX). We check that a FakeTensor kernel
      (also sometimes known as a meta kernel) was registered for the
      operator and that it is correct. This test takes the result of
      running the operator on real tensors and the result of running
      the operator on FakeTensors and checks that they have the same
      Tensor metadata (sizes/strides/dtype/device/etc).
    - test_aot_dispatch_dynamic: If the operator has correct behavior
      with PyTorch compilation APIs (torch.compile/export/FX).
      This checks that the outputs (and gradients, if applicable) are the
      same under eager-mode PyTorch and torch.compile.
      This test is a superset of ``test_faketensor`` and is an e2e test;
      other things it tests are that the operator supports
      functionalization and that the backward pass (if it exists) also
      supports FakeTensor and functionalization.

    For best results, please call ``opcheck`` multiple times with a
    representative set of inputs. If your operator supports
    autograd, please use ``opcheck`` with inputs with ``requires_grad = True``;
    if your operator supports multiple devices (e.g. CPU and CUDA), please
    use ``opcheck`` with inputs on all supported devices.

    Args:
        op: The operator. Must either be a function decorated with
            :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket
            found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
        args: The args to the operator
        kwargs: The kwargs to the operator
        test_utils: Tests that we should run. Default: all of them.
            Example: ("test_schema", "test_faketensor")
        raise_exception: If we should raise an exception on the first
            error. If False, we will return a dict with information
            on if each test passed or not.

    .. warning::

        opcheck and :func:`torch.autograd.gradcheck` test different things;
        opcheck tests if your usage of torch.library APIs is correct while
        :func:`torch.autograd.gradcheck` tests if your autograd formula is
        mathematically correct. Use both to test custom ops that support
        gradient computation.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> @torch.library.custom_op("mylib::numpy_add", mutates_args=())
        >>> def numpy_add(x: Tensor, y: float) -> Tensor:
        >>>     x_np = x.numpy(force=True)
        >>>     z_np = x_np + y
        >>>     return torch.from_numpy(z_np).to(x.device)
        >>>
        >>> @numpy_add.register_fake
        >>> def _(x, y):
        >>>     return torch.empty_like(x)
        >>>
        >>> def setup_context(ctx, inputs, output):
        >>>     x, y = inputs
        >>>     ctx.y = y
        >>>
        >>> def backward(ctx, grad):
        >>>     return grad * ctx.y, None
        >>>
        >>> numpy_add.register_autograd(backward, setup_context=setup_context)
        >>>
        >>> sample_inputs = [
        >>>     (torch.randn(3), 3.14),
        >>>     (torch.randn(2, 3, device='cuda'), 2.718),
        >>>     (torch.randn(1, 10, requires_grad=True), 1.234),
        >>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
        >>> ]
        >>>
        >>> for args in sample_inputs:
        >>>     torch.library.opcheck(numpy_add, args)

    """
    import torch.testing._internal.optests as optests

    return optests.opcheck(
        op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
    )