1# mypy: allow-untyped-defs
2import abc
3import contextlib
4import ctypes
5import importlib
6import inspect
7import sys
8import types
9from typing import Any, Callable, Dict, List, Set, Type, Union
10
11import torch
12import torch.utils._pytree as pytree
13from torch import _utils_internal
14from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
15from torch._functorch.pyfunctorch import dispatch_functorch
16from torch.utils._python_dispatch import TorchDispatchMode
17
18
19# Query `hasattr` only once.
20_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
21
22
23@contextlib.contextmanager
24def dl_open_guard():
25    """
26    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
27    shared library to load custom operators.
28    """
29    if not _SET_GLOBAL_FLAGS:
30        yield
31        return
32    old_flags = sys.getdlopenflags()
33    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
34    try:
35        yield
36    finally:
37        sys.setdlopenflags(old_flags)
38
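# A minimal usage sketch for dl_open_guard (illustrative only; "libcustom_ops.so" is a
# hypothetical library name, not something shipped with PyTorch):
#
#   with dl_open_guard():
#       ctypes.CDLL("path/to/libcustom_ops.so")  # opened with RTLD_GLOBAL where supported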
39
40class OperatorBase:
41    """
42    Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator
43    (which represents Python-only operators that are unrepresentable in TorchScript).
44    """
45
46    def __init__(self):
47        # The dispatch cache precomputes a mapping of dispatch key that the
48        # dispatcher wants to dispatch to, to an actual implementation of the
49        # dispatch key.  Confusingly, the actual implementation could *also* be a
50        # dispatch key, but in this case, this refers to the C++ kernel that
51        # was registered to some dispatch key.  Aliases are permitted in the
52        # latter but not the former; for example, you might lookup the
53        # entry for AutogradCPU, and this maps you to the Autograd key for
54        # the generic autograd kernel that works for all devices.  Since this
55        # is the Python dispatcher, you can also put an arbitrary Python
56        # callable to call instead.  This handler gets precisely the
57        # args/kwargs that the operator was __call__'ed with.
58        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
59        # for use with OpOverload; cache lookup is done entirely from C++
60        # for speed.
61        # TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
62        self._dispatch_cache: Dict[
63            DispatchKey, Union[DispatchKey, Callable[..., Any]]
64        ] = {}
65
66        # This table allows you to override the behavior of a particular
67        # dispatch key to call a custom Python function, rather than the
68        # ordinary C++-configured behavior.  This is the raison d'etre of the
69        # Python dispatcher: to let you program the dispatcher from Python
70        # in case you need something unusual, and don't want to clobber
71        # the existing registrations using the Python operator registration
72        # API.
73        self.py_kernels: Dict[DispatchKey, Callable[..., Any]] = {}
74
75        # This table allows you to override the behavior of a particular
76        # operator for a particular TorchDispatchMode.  In practice,
77        # we are using this mostly for ProxyTensorMode.  Modes can be
78        # thought of as an open world extension of dispatch keys, so it
79        # makes sense that you should be able to register them, the same
80        # way you can register dispatch keys.
81        self.python_key_table: Dict[
82            Union[Type[TorchDispatchMode], Type[torch.Tensor]], Callable[..., Any]
83        ] = {}
84
85        # This table allows you to override the behavior of functorch
86        # transformations.  NB: this currently only does something for
87        # HigherOrderOperator
88        self.functorch_table = {}
89
90    def __call__(self, *args, **kwargs):
91        raise NotImplementedError
92
93    def has_kernel_for_dispatch_key(self, k):
94        return k in self.py_kernels
95
96    def has_kernel_for_any_dispatch_key(self, ks):
97        for k in self.py_kernels:
98            if not torch._C._dispatch_is_alias_key(k) and ks.has(k):
99                return True
100        return False
101
102    def py_impl(self, k):
103        def inner(fn):
104            if inspect.isclass(k) and (
105                issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
106            ):
107                assert k not in self.python_key_table
108                # TODO(voz): Should we replace setting DispatchKey.Python entirely with setting mode keys?
109                self.python_key_table[k] = fn
110                self._dispatch_cache.clear()
111                return fn
112
113            if isinstance(k, torch._C._functorch.TransformType):
114                assert k not in self.functorch_table
115                self.functorch_table[k] = fn
116                return fn
117
118            assert isinstance(k, DispatchKey)
119            assert (
120                k != DispatchKey.Python
121            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
122
123            if k in self.py_kernels:
124                raise RuntimeError(
125                    f"Trying to override a python impl for {k} on operator {self.name()}"
126                )
127            self.py_kernels[k] = fn
128            self._dispatch_cache.clear()
129            return fn
130
131        return inner
132
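    # A hedged usage sketch for py_impl (illustrative only; `my_op` and `MyMode` are
    # hypothetical names, not defined in this module):
    #
    #   @my_op.py_impl(DispatchKey.CPU)
    #   def my_cpu_kernel(*args, **kwargs):
    #       ...
    #
    #   @my_op.py_impl(MyMode)  # MyMode: a TorchDispatchMode subclass; the handler gets the mode first
    #   def my_mode_rule(mode, *args, **kwargs):
    #       ...
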
133    # Registers an implementation to all **3** variants of functionalization that we have:
134    # - DispatchKey.Functionalize
135    # - functorch.TransformType.Functionalize
136    # - FunctionalTensorMode
137    # Example:
138    #   @py_functionalize_impl
139    #   def functionalize_rule(ctx, inner_f, *args):
140    #       args_unwrapped = ctx.unwrap_tensors(args)
141    #       with ctx.redispatch_to_next():
142    #           out = ctx.functionalize(inner_f)(*args_unwrapped)
143    #           return ctx.wrap_tensors(out)
144    def py_functionalize_impl(self, fn):
145        from torch._subclasses.functional_tensor import (
146            CppFunctionalizeAPI as _CppFunctionalizeAPI,
147            FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
148            PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
149        )
150
151        # Construct our three flavors of functionalization,
152        # each of which has slightly different wrap/unwrap/redispatch policies
153        def functionalize_dk_fn(*args, **kwargs):
154            return fn(_CppFunctionalizeAPI(), *args, **kwargs)
155
156        def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
157            return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)
158
159        def functionalize_functorch_fn(interpreter, *args, **kwargs):
160            return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
161
162        self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
163        self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
164            functionalize_dispatch_mode_fn
165        )
166        self.py_impl(torch._C._functorch.TransformType.Functionalize)(
167            functionalize_functorch_fn
168        )
169
170        return fn
171
172    def name(self):
173        raise NotImplementedError
174
175
176# Equivalent to computeDispatchTableEntryWithDebug
177def resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
178    # 1. (Direct) operator registration
179    if op.has_kernel_for_dispatch_key(k):
180        return k
181    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
182    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
183    if (
184        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
185    ) and op.has_kernel_for_dispatch_key(cand):
186        return cand
187    # 2.2 Use CompositeExplicitAutograd kernel if available
188    cand = DispatchKey.CompositeExplicitAutograd
189    if (
190        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
191    ) and op.has_kernel_for_dispatch_key(cand):
192        return cand
193    has_backend_kernel = op.has_kernel_for_any_dispatch_key(
194        torch._C._dispatch_get_backend_keyset_from_autograd(k)
195    ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
196    # 2.3. Use CompositeImplicitAutograd kernel if available
197    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
198    if (
199        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
200        and op.has_kernel_for_dispatch_key(cand)
201        and not has_backend_kernel
202    ):
203        return cand
204    cand = DispatchKey.CompositeImplicitAutograd
205    if (
206        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
207    ) and op.has_kernel_for_dispatch_key(cand):
208        if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key(
209            torch._C._dispatch_autogradother_backends
210        ):
211            raise RuntimeError("ambiguous autogradother kernel")
212        elif not has_backend_kernel:
213            return cand
214    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
215    cand = DispatchKey.Autograd
216    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
217        return cand
218    # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
219    cand = DispatchKey.FuncTorchBatchedDecomposition
220    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
221        return cand
222    # Backend fallback
223    if torch._C._dispatch_has_backend_fallback(k):
224        # The dispatch key itself will implicitly route to backend fallback.
225        # This is probably not great for the pure Python implementation.
226        return k
227    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
228
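# A hedged illustration of resolve_key (aten.add.Tensor is only an example; the exact
# resolution depends on the build and on which kernels are registered):
#
#   op = torch.ops.aten.add.Tensor
#   resolve_key(op, DispatchKey.CPU)          # -> DispatchKey.CPU (direct backend kernel)
#   resolve_key(op, DispatchKey.AutogradCPU)  # -> typically DispatchKey.Autograd (alias kernel)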
229
230_higher_order_ops: Dict[str, "HigherOrderOperator"] = {}
231
232_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
233    DispatchKey.PythonDispatcher,  # type: ignore[attr-defined]
234    DispatchKey.PythonTLSSnapshot,  # type: ignore[attr-defined]
235    DispatchKey.ADInplaceOrView,
236    DispatchKey.BackendSelect,
237    DispatchKey.AutocastCPU,  # type: ignore[attr-defined]
238    DispatchKey.AutocastCUDA,  # type: ignore[attr-defined]
239]
240
241
242class HigherOrderOperator(OperatorBase, abc.ABC):
243    # The HigherOrderOperator will appear as torch.ops.higher_order.{name}
244    #
245    # If you're creating a new HigherOrderOperator, please do not change the
246    # default. Adding operators to the global torch.ops namespace is a bad
247    # practice due to name collisions.
248    def __init__(self, name):
249        super().__init__()
250        if type(self) is HigherOrderOperator:
251            raise RuntimeError(
252                "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
253            )
254        self._name = name
255
256        # Make _OpNamespace not scream; this whole name-based association needs a good hard look
257        self.__name__ = name
258        _higher_order_ops[name] = self
259        self._ns = "higher_order"
260        self.__module__ = "torch.ops.higher_order"
261
262        self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
263
264        for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
265            self.fallthrough(dispatch_key)
266
267        # [NOTE] We have to register a pre-dispatch key implementation
268        # because sometimes HOPs use aot-dispatch tracing to detect certain
269        # mutations. This is problematic when we are functionalizing a HOP
270        # during pre-dispatch because when the inner tracer starts, it will see
271        # that the PreDispatch key is still active. In that case, we just redispatch
272        # to the next key. This is only safe to do when the PreDispatch key stack has no
273        # active modes.
274
275    def py_impl(self, k):
276        if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
277            self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
278        return super().py_impl(k)
279
280    @property
281    def namespace(self):
282        return self._ns
283
284    def fallthrough(self, dispatch_key):
285        self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
286
287    # Use a positional-only argument to avoid naming collisions with custom op arguments
288    # that are named "self".
289    def dispatch(self, /, dispatch_key, *args, **kwargs):
290        from torch.utils._python_dispatch import _get_current_dispatch_mode
291
292        if dispatch_key in self._dispatch_cache:
293            kernel = self._dispatch_cache[dispatch_key]
294            assert not isinstance(kernel, DispatchKey)
295            return kernel(*args, **kwargs)
296
297        if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
298            return dispatch_functorch(self, args, kwargs)
299
300        if dispatch_key == DispatchKey.Python:
301            # Keep the following 1:1 with handle_torch_function_no_python_arg_parser
302            # in torch/csrc/utils/python_arg_parser.cpp
303
304            overloaded_args_list = []
305
306            def has_python_key(tensor):
307                return torch._C._dispatch_keys(tensor).has("Python")
308
309            def check_overloaded(arg):
310                if isinstance(arg, torch.Tensor) and has_python_key(arg):
311                    overloaded_args_list.append(arg)
312
313            for arg in (*args, *kwargs.values()):
314                check_overloaded(arg)
315                if isinstance(arg, (list, tuple)):
316                    for a in arg:
317                        check_overloaded(a)
318
319            overloaded_args = tuple(overloaded_args_list)
320            overloaded_types = tuple(type(arg) for arg in overloaded_args)
321
322            # Step 1: dispatch on any user TorchDispatchModes
323            from torch.utils._python_dispatch import _pop_mode_temporarily
324
325            curr_mode = _get_current_dispatch_mode()
326            if curr_mode is not None:
327                if type(curr_mode) in self.python_key_table:
328                    handler = self.python_key_table[type(curr_mode)]
329                    with _pop_mode_temporarily() as mode:
330                        # "natural" calling convention: (mode, *args, **kwargs)
331                        # TODO(rzou): we should support torch_dispatch calling convention too.
332                        result = handler(mode, *args, **kwargs)
333                else:
334                    raise NotImplementedError(
335                        f"There was no rule registered for HOP {self._name} and mode {curr_mode}. "
336                        f"We recommend filing an issue."
337                    )
338                if result is not NotImplemented:
339                    return result
340
341            # Step 2: dispatch on any subclasses
342            for arg in overloaded_args:
343                subclass_type = type(arg)
344                if (
345                    subclass_type.__torch_dispatch__
346                    == torch._C._disabled_torch_dispatch_impl
347                ):
348                    continue
349                if subclass_type in self.python_key_table:
350                    handler = self.python_key_table[subclass_type]
351                    # "natural" calling convention: (*args, **kwargs)
352                    # TODO(rzou): we should support torch_dispatch calling convention too.
353                    result = handler(*args, **kwargs)
354                else:
355                    raise NotImplementedError(
356                        f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. "
357                        f"We recommend filing an issue."
358                    )
359                if result is not NotImplemented:
360                    return result
361
362            # All handlers returned NotImplemented
363            raise TypeError(
364                f"Multiple dispatch failed for {self._name}. There was no registered that "
365                f"did not return NotImplemented. Use HOP.py_impl to register some. "
366                f"Tried mode: {curr_mode}) and subclasses: "
367                f"{[type(a) for a in overloaded_args]}"
368            )
369
370        functionality_key = torch._C._to_functionality_key(dispatch_key)  # type: ignore[attr-defined]
371        if functionality_key == DispatchKey.PreDispatch:
372            from torch.utils._python_dispatch import _pop_mode_temporarily
373
374            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
375            # calls inside of a mode.
376            if (
377                _len_torch_dispatch_stack_pre_dispatch() > 0
378            ) and not torch._C._dispatch_tls_is_dispatch_key_excluded(
379                DispatchKey.Python
380            ):
381                curr_mode = _get_current_dispatch_mode_pre_dispatch()
382                assert (
383                    curr_mode is not None
384                ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
385                assert (
386                    type(curr_mode) in self.python_key_table
387                ), f"Current active mode {curr_mode} not registered"
388                handler = self.python_key_table[type(curr_mode)]
389                with _pop_mode_temporarily(functionality_key) as mode:
390                    return handler(mode, *args, **kwargs)
391
392        final_key = resolve_key(self, dispatch_key)
393
394        # This can currently fail due to backend fallbacks.  You just have to
395        # register them by hand for HigherOrderOperator.
396        if final_key not in self.py_kernels:
397            raise NotImplementedError(
398                f"could not find kernel for HigherOrderOperator {self._name} "
399                f"at dispatch key {final_key} (resolved from {dispatch_key})"
400            )
401
402        # [NOTE] We shouldn't cache the PreDispatch kernel here because, depending
403        # on what modes are active, pre-dispatch behaviour is different.
404        # We also do the same thing for normal ops:
405        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
406        if dispatch_key != DispatchKey.PreDispatch:
407            self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
408        kernel = self.py_kernels[final_key]
409        # It's illegal to register DispatchKey to py_kernels, since there's no
410        # C++ kernel to call into
411        assert not isinstance(kernel, DispatchKey)
412        return kernel(*args, **kwargs)
413
414    @abc.abstractmethod
415    def __call__(self, /, *args, **kwargs):
416        # Dynamo already traces the body of the HigherOrderOperator beforehand,
417        # so there is no need to trace into it here.
418        from torch._dynamo import disable
419
420        @disable
421        def wrapper():
422            flat_args = _to_flat_tuple(args, kwargs)
423            if torch.overrides.has_torch_function(flat_args):
424                return torch.overrides.handle_torch_function(
425                    self, flat_args, *args, **kwargs
426                )
427
428            dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
429            return self.dispatch(
430                dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
431            )
432
433        return wrapper()
434
435    def __str__(self):
436        return f"{self.name()}"
437
438    def name(self):
439        return self._name
440
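# A hedged sketch of how a HigherOrderOperator is typically defined (`MyHop` and
# "my_hop" are hypothetical names; real examples live under torch._higher_order_ops):
#
#   class MyHop(HigherOrderOperator):
#       def __init__(self):
#           super().__init__("my_hop")
#
#       def __call__(self, fn, *args, **kwargs):
#           return super().__call__(fn, *args, **kwargs)
#
#   my_hop = MyHop()  # now reachable as torch.ops.higher_order.my_hop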
441
442def _to_flat_tuple(args, kwargs):
443    return pytree.arg_tree_leaves(*args, **kwargs)
444
445
446def _compute_keyset(args, kwargs, non_fallthrough_keys):
447    tensors = _get_tensors(args, kwargs)
448    return key_extractor(tensors, non_fallthrough_keys)
449
450
451def _get_tensors(args, kwargs):
452    flat_all = _to_flat_tuple(args, kwargs)
453    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
454    return tuple(tensor_args)
455
456
457# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
458# at ATen/core/dispatch/DispatchKeyExtractor.h
459def key_extractor(tensors, key_mask):
460    key_set = torch._C._dispatch_tls_local_include_set()
461    for tensor in tensors:
462        key_set = key_set | torch._C._dispatch_keys(tensor)
463    key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
464    key_set = key_set & key_mask
465    return key_set
466
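# A hedged sketch of what key_extractor computes (illustrative only; the resulting
# highest-priority key depends on the tensor and on the TLS include/exclude sets):
#
#   t = torch.randn(2)
#   ks = key_extractor((t,), torch._C._dispatch_keyset_full())
#   ks.highestPriorityTypeId()  # e.g. DispatchKey.AutogradCPU for a plain CPU tensor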
467
468# Mode stack for the PreDispatch key.
469# It should always have two infra-mode slots, with
470# priority given to FunctionalTensorMode and
471# then ProxyTorchDispatchMode: slot 0 belongs to
472# ProxyTorchDispatchMode and slot 1 belongs to
473# FunctionalTensorMode.
474#
475# SchemaCheckMode is separate from the other two,
476# and is only valid when the stack is empty.
477# SchemaCheckMode is for testing purposes, and
478# is meant to run in eager mode on concrete inputs,
479# checking for incorrect schemas with regard to
480# aliasing or mutating ops.
481class _ModeStackStateForPreDispatch:
482    def __init__(self):
483        self.__infra_modes = [None, None]
484        self._schema_check_mode = None
485
486    def set(self, index, mode):
487        assert index < len(self.__infra_modes)
488        self.__infra_modes[index] = mode
489
490    def get(self, index):
491        assert index < len(self.__infra_modes)
492        return self.__infra_modes[index]
493
494    def count(self):
495        return len([i for i in self.__infra_modes if i is not None]) + int(
496            self._schema_check_mode is not None
497        )
498
499
500_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch()
501
502
503def unset_mode_pre_dispatch(mode_key, schema_check=False):
504    current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch()
505    assert mode_key is None or mode_key in (
506        torch._C._TorchDispatchModeKey.PROXY,
507        torch._C._TorchDispatchModeKey.FUNCTIONAL,
508    )
509    if schema_check:
510        assert mode_key is None
511
512    def _unset_mode():
513        if mode_key == torch._C._TorchDispatchModeKey.PROXY:
514            current_mode = current_mode_stack_pre_dispatch.get(0)
515            mode_stack_state_for_pre_dispatch().set(0, None)
516            return current_mode
517        elif mode_key == torch._C._TorchDispatchModeKey.FUNCTIONAL:
518            current_mode = current_mode_stack_pre_dispatch.get(1)
519            mode_stack_state_for_pre_dispatch().set(1, None)
520            return current_mode
521        else:
522            current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
523            mode_stack_state_for_pre_dispatch()._schema_check_mode = None
524            return current_mode
525
526    current_mode = _unset_mode()
527
528    new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
529    # When we are unsetting a mode, we need to check if there is
530    # any active mode left on the PreDispatch key. If there is nothing
531    # active, we need to remove the PreDispatch key from the local dispatch
532    # include set.
533    if new_pre_dispatch_len == 0:
534        torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False)
535
536    return current_mode
537
538
539def _set_mode_pre_dispatch(mode):
540    from torch._subclasses.functional_tensor import FunctionalTensorMode
541    from torch._subclasses.schema_check_mode import SchemaCheckMode
542    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
543
544    assert isinstance(
545        mode,
546        (
547            FunctionalTensorMode,
548            ProxyTorchDispatchMode,
549            SchemaCheckMode,
550        ),
551    )
552
553    previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch()
554    if isinstance(mode, SchemaCheckMode):
555        current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
556        if previous_mode_stack_len > 0:
557            raise AssertionError(
558                "SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack"
559            )
560        mode_stack_state_for_pre_dispatch()._schema_check_mode = mode
561    elif isinstance(mode, FunctionalTensorMode):
562        current_mode = mode_stack_state_for_pre_dispatch().get(1)
563        assert current_mode is None
564        mode_stack_state_for_pre_dispatch().set(1, mode)
565    else:
566        current_mode = mode_stack_state_for_pre_dispatch().get(0)
567        assert current_mode is None
568        mode_stack_state_for_pre_dispatch().set(0, mode)
569
570    # When we are setting a mode, we need to check if there is
571    # any mode already active on the PreDispatch key. If there was nothing
572    # active before setting this mode, it means that the PreDispatch key
573    # was turned off, so we need to turn it on again.
574    if previous_mode_stack_len == 0:
575        torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True)
576
577
578def _pop_mode_from_pre_dispatch():
579    mode_stack = mode_stack_state_for_pre_dispatch()
580    pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
581
582    if pre_dispatch_len == 0:
583        raise AssertionError("Trying to pop empty mode stack")
584
585    if mode_stack._schema_check_mode is not None:
586        return unset_mode_pre_dispatch(None, schema_check=True)
587    if mode_stack.get(1) is not None:
588        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL)
589    if mode_stack.get(0) is not None:
590        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
591
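# A hedged sketch of the bookkeeping above (FunctionalTensorMode is used as an example
# infra mode; in practice these helpers are driven by the export/tracing machinery
# rather than called directly):
#
#   from torch._subclasses.functional_tensor import FunctionalTensorMode
#   mode = FunctionalTensorMode()
#   _set_mode_pre_dispatch(mode)              # fills slot 1 and enables PreDispatch in TLS
#   _len_torch_dispatch_stack_pre_dispatch()  # -> 1
#   _pop_mode_from_pre_dispatch()             # returns `mode` and disables PreDispatch again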
592
593def _len_torch_dispatch_stack_pre_dispatch():
594    return mode_stack_state_for_pre_dispatch().count()
595
596
597def _get_dispatch_mode_pre_dispatch(mode_key):
598    assert mode_key in (
599        torch._C._TorchDispatchModeKey.PROXY,
600        torch._C._TorchDispatchModeKey.FUNCTIONAL,
601    )
602    if mode_key == torch._C._TorchDispatchModeKey.PROXY:
603        return mode_stack_state_for_pre_dispatch().get(0)
604    else:
605        return mode_stack_state_for_pre_dispatch().get(1)
606
607
608def _get_current_dispatch_mode_pre_dispatch():
609    if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None:
610        return mode_stack_state_for_pre_dispatch()._schema_check_mode
611    else:
612        stack_len = mode_stack_state_for_pre_dispatch().count()
613        if stack_len == 2:
614            return mode_stack_state_for_pre_dispatch().get(1)
615        if stack_len == 1:
616            return (
617                mode_stack_state_for_pre_dispatch().get(1)
618                if mode_stack_state_for_pre_dispatch().get(1) is not None
619                else mode_stack_state_for_pre_dispatch().get(0)
620            )
621    return None
622
623
624def mode_stack_state_for_pre_dispatch():
625    global _mode_stack_state_for_pre_dispatch
626    return _mode_stack_state_for_pre_dispatch
627
628
629cached_ops: Set["OpOverload"] = set()
630
631
632def add_cached_op(op_overload):
633    global cached_ops
634    cached_ops.add(op_overload)
635
636
637def reset_cached_ops():
638    global cached_ops
639    cached_ops.clear()
640
641
642def get_cached_ops():
643    global cached_ops
644    return cached_ops
645
646
647# Each OpOverload object contains a pointer to a specific operator overload and a pointer to the parent `OpOverloadPacket` object.
648# You can obtain an OpOverload object through an attribute query on an OpOverloadPacket.
649class OpOverload(OperatorBase):
650    def __init__(self, overloadpacket, op, op_dk, schema, tags):
651        super().__init__()
652        self._op = op
653        self._op_dk = op_dk
654        self._schema = schema
655        self._overloadpacket = overloadpacket
656        self._tags = tags
657        self._overloadname = (
658            "default" if schema.overload_name == "" else schema.overload_name
659        )
660        self._name = self._schema.name
661        if schema.overload_name:
662            self._name += "." + schema.overload_name
663        self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}"
664        self.__module__ = overloadpacket.__module__
665        op.__module__ = overloadpacket.__module__
666        self.__qualname__ = self._name
667        self.__annotations__ = {}
668        # Only compute the OperatorHandle when we need it. Not all OpOverloads have
669        # OperatorHandles (the TorchScript ones don't...)
670        self._lazy_handle = None
671
672        # If the OpOverload was constructed from a Library.def in Python.
673        self._defined_in_python = self.__qualname__ in torch.library._defs
674
675        # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
676        is_write = None
677        for a in self._schema.arguments:
678            if a.alias_info is None:
679                continue
680            if is_write is None:
681                is_write = a.alias_info.is_write
682            else:
683                # We will conservatively call mixed mutable/non-mutable
684                # aliased inputs as NOT a view
685                is_write = a.alias_info.is_write or is_write
686        self.is_view = is_write is not None and not is_write
687
688    @property
689    def _namespace(self):
690        return self._schema.name.split("::")[0]
691
692    @property
693    def _opname(self):
694        return self._schema.name.split("::")[1]
695
696    @property
697    def _handle(self):
698        if self._lazy_handle is None:
699            self._lazy_handle = torch._C._dispatch_find_schema_or_throw(
700                self._schema.name, self._schema.overload_name
701            )
702        return self._lazy_handle
703
704    # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
705    def __deepcopy__(self, memo=None):
706        return self
707
708    def __repr__(self):
709        return "<OpOverload(op='{}.{}', overload='{}')>".format(
710            *self._schema.name.split("::"), self._overloadname
711        )
712
713    # Use positional-only argument to avoid naming collision with aten ops arguments
714    # that are named "self". This way, all the aten ops can be called by kwargs.
715    def __call__(self, /, *args, **kwargs):
716        return self._op(*args, **kwargs)
717
718    # Use positional-only argument to avoid naming collision with aten ops arguments
719    # that are named "self". This way, all the aten ops can be called by kwargs.
720    def redispatch(self, /, keyset, *args, **kwargs):
721        return self._handle.redispatch_boxed(keyset, *args, **kwargs)
722
723    def __hash__(self):
724        return hash(self._op)
725
726    # `my_namespace.my_op_name.overload_name`
727    def __str__(self):
728        return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
729
730    def has_kernel_for_dispatch_key(self, k):
731        return super().has_kernel_for_dispatch_key(
732            k
733        ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)
734
735    def has_kernel_for_any_dispatch_key(self, ks):
736        return torch._C._dispatch_has_kernel_for_any_dispatch_key(
737            self.name(), ks
738        ) or super().has_kernel_for_any_dispatch_key(ks)
739
740    @property
741    def namespace(self):
742        return self._schema.name.split("::")[0]
743
744    def _can_decompose(self):
745        dk = DispatchKey.CompositeImplicitAutograd
746        return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key(
747            self.name(), dk
748        )
749
750    def decompose(self, *args, **kwargs):
751        dk = DispatchKey.CompositeImplicitAutograd
752        if dk in self.py_kernels:
753            # NB: This branch is not too necessary anymore, because we can
754            # apply Python CompositeImplicitAutograd *before* tracing
755            # using Python dispatcher (also taking advantage of the autograd
756            # formula).  But it's included for completeness
757            return self.py_kernels[dk](*args, **kwargs)
758        elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
759            return self._op_dk(dk, *args, **kwargs)
760        else:
761            return NotImplemented
762
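    # A hedged usage sketch for decompose (aten.linear is only an example of an op that
    # often has a CompositeImplicitAutograd kernel; `x`, `weight`, and `bias` are
    # placeholder tensors):
    #
    #   op = torch.ops.aten.linear.default
    #   if op._can_decompose():
    #       out = op.decompose(x, weight, bias)  # may also return NotImplemented
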
763    # Remove a dispatch key from the dispatch cache.  This will force it to get
764    # recomputed the next time.  Does nothing if the key is not cached.
765    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
766    # calling _del_dispatch on that key is NOT sufficient to apply your change,
767    # because a single registration may affect MULTIPLE dispatch keys (e.g.,
768    # registering Autograd affects AutogradCPU).  del_dispatch is to be used
769    # only if you are specifically modifying how get_dispatch handles a
770    # particular input 'key'.
771    def _uncache_dispatch(self, key):
772        self._dispatch_cache.pop(key, None)
773
774    # This implements the pre-computation logic for the Python dispatcher.
775    def _get_dispatch(self, key):
776        # This is only called upon a cache miss
777        assert key not in self._dispatch_cache, f"{self} {key}"
778
779        if key == DispatchKey.Python:
780            if not isinstance(self, TorchBindOpOverload) and not self.python_key_table:
781                self._dispatch_cache[key] = key
782                add_cached_op(self)
783                return key
784
785            def handler(*args, **kwargs):
786                from torch.utils._python_dispatch import _get_current_dispatch_mode
787
788                # TODO: We also need to handle tensor subclasses here
789                # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
790                curr_mode = _get_current_dispatch_mode()
791                assert (
792                    curr_mode is not None
793                ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
794
795                if type(curr_mode) not in self.python_key_table:
796                    if isinstance(self, TorchBindOpOverload):
797                        with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
798                            return torch._library.utils.handle_dispatch_mode(
799                                mode, self, *args, **kwargs
800                            )
801                    else:
802                        return self._op_dk(key, *args, **kwargs)
803
804                with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
805                    return self.python_key_table[type(curr_mode)](mode, *args, **kwargs)
806
807            self._dispatch_cache[key] = handler
808            add_cached_op(self)
809            return handler
810
811        functionality_key = torch._C._to_functionality_key(key)  # type: ignore[attr-defined]
812        if functionality_key == DispatchKey.PreDispatch:
813            curr_stack_len = _len_torch_dispatch_stack_pre_dispatch()
814            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
815            # calls inside of a mode.
816            if (
817                curr_stack_len > 0
818                and not torch._C._dispatch_tls_is_dispatch_key_excluded(
819                    DispatchKey.Python
820                )
821            ):
822
823                def handler(*args, **kwargs):
824                    @contextlib.contextmanager
825                    def _temporarily_pop_modes_from_pre_dispatch():
826                        top_mode = _pop_mode_from_pre_dispatch()
827                        try:
828                            yield top_mode
829                        finally:
830                            _set_mode_pre_dispatch(top_mode)
831
832                    with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
833                        return torch._library.utils.handle_dispatch_mode(
834                            curr_mode, self, *args, **kwargs
835                        )
836
837                # Note [Not Caching Per-Dispatch-Key Mode Handlers]
838                # Note that we're not caching this handler.  There isn't really a point, since the slow bit
839                # is the handler itself (in python).
840                # Also, not caching means that we don't have to reset the cache when any existing
841                # modes go out of scope (which in and of itself takes time, since it loops through all operators).
842                return handler
843
844        final_key = resolve_key(self, key)
845
846        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
847        cache_result = key != DispatchKey.PreDispatch
848
849        # TODO: We could potentially have lots of debugging wrappers against
850        # dispatch keys; design some general registration mechanism instead of
851        # having if statement for each of them
852        if key == DispatchKey.Functionalize:
853            import torch._dispatch.python as pydispatch
854
855            if pydispatch.CROSSREF_FUNCTIONALIZE:
856                handler = pydispatch.make_crossref_functionalize(self, final_key)
857                if cache_result:
858                    self._dispatch_cache[key] = handler
859                    add_cached_op(self)
860                return handler
861
862        r = self.py_kernels.get(final_key, final_key)
863        if cache_result:
864            self._dispatch_cache[key] = r
865            add_cached_op(self)
866        return r
867
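    # A hedged illustration of the precomputation above (aten.add.Tensor is only an
    # example; the resolved entry depends on active modes and registrations):
    #
    #   op = torch.ops.aten.add.Tensor
    #   entry = op._get_dispatch(DispatchKey.CPU)
    #   # `entry` is either a DispatchKey (i.e. "call back into C++") or a Python callable
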
868    def name(self):
869        return self._name
870
871    @property
872    def overloadpacket(self):
873        return self._overloadpacket
874
875    @property
876    def op(self):
877        return self._op
878
879    @property
880    def tags(self):
881        return self._tags
882
883    # TODO: add more methods to expose information about input and output arguments
884
885
886# TorchBindOpOverload is used for those custom ops which have at least one overload
887# whose schema contains a torch.ScriptObject (i.e. custom class) input.
888# A TorchBindOpOverload skips the C++ dispatcher and is dispatched purely in Python
889# when its inputs contain a FakeScriptObject, in a similar way to higher order ops.
890class TorchBindOpOverload(OpOverload):
891    def _fallthrough_keys(self) -> List[DispatchKey]:
892        # TODO: we should be calling the fallback for these, but a fallthrough is almost close
893        # enough to the fallback in most cases that we care about.
894        _DEFAULT_FALLTHROUGH_KEYS = [
895            DispatchKey.Autograd,
896            DispatchKey.AutogradCPU,
897            DispatchKey.AutogradCUDA,
898            DispatchKey.ADInplaceOrView,
899            DispatchKey.BackendSelect,
900            DispatchKey.PythonTLSSnapshot,
901            DispatchKey.PythonDispatcher,
902        ]
903
904        def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
905            if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
906                return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
907                    self.name(), key
908                )
909
910            return (
911                key not in self.py_kernels
912                or self.py_kernels[key] is torch.library.fallthrough_kernel
913            )
914
915        return [
916            key
917            for key in _DEFAULT_FALLTHROUGH_KEYS
918            if _may_use_fallthrough_instead_of_fallback(key)
919        ]
920
921    @contextlib.contextmanager
922    def _register_as_effectful_op_temporarily(self):
923        from torch._higher_order_ops.effects import (
924            _EffectType,
925            _register_effectful_op,
926            SIDE_EFFECTS,
927        )
928
929        try:
930            if self not in SIDE_EFFECTS:
931                _register_effectful_op(self, _EffectType.ORDERED)
932            yield
933        finally:
934            if self in SIDE_EFFECTS:
935                del SIDE_EFFECTS[self]
936
937    # Use positional-only argument to avoid naming collision with aten ops arguments
938    # that are named "self". This way, all the aten ops can be called by kwargs.
939    def __call__(self, /, *args, **kwargs):
940        if _must_dispatch_in_python(args, kwargs):
941            # When any inputs are FakeScriptObject, we need to
942            # skip the C++ dispatcher and dispatch in Python through _get_dispatch of the Python dispatcher,
943            # because the C++ dispatcher will check the schema and cannot recognize FakeScriptObject.
944            #
945            # Note:
946            # 1. We only register the torchbind op temporarily as an effectful op because we only want
947            #    the effect token functionalization logic to be applied during tracing. Otherwise, the behavior
948            #    of eagerly executing the op might change after tracing.
949            # 2. We don't want to register the op as effectful for all torchbind ops in the constructor because this might
950            #    cause unexpected behavior for some autograd.profiler ops, e.g. profiler._record_function_exit._RecordFunction.
951            with self._register_as_effectful_op_temporarily():
952                return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
953        return self._op(*args, **kwargs)
954
955    def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
956        non_fallthrough_keys = torch._C._dispatch_keyset_full()
957        for key in fallthrough_keys:
958            non_fallthrough_keys = non_fallthrough_keys.remove(key)
959
960        dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
961        dispatch_key = dispatch_key_set.highestPriorityTypeId()
962
963        handler = (
964            self._get_dispatch(dispatch_key)
965            if dispatch_key not in self._dispatch_cache
966            else self._dispatch_cache[dispatch_key]
967        )
968
969        if isinstance(handler, DispatchKey):
970            # fallthrough keys can be registered at runtime via torch.library.impl
971            # so need to add it to fallthrough_keys and re-dispatch.
972            if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
973                self.name(), dispatch_key
974            ):
975                return self._dispatch_in_python(
976                    args, kwargs, fallthrough_keys + [dispatch_key]
977                )
978
979            raise RuntimeError(
980                f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler}."
981                f" but no python implementation is found."
982                f" Please file an issue on this when you encounter this error."
983                f" This error can happen when you export or compile the model."
984                f" It can still happpen even if a C++ implementation for {dispatch_key}. "
985                f" has been registered. That's because FakeScriptObject purely lives in python and cannot work "
986                f" with a C++ implementation."
987            )
988
989        assert isinstance(handler, Callable)  # type: ignore[arg-type]
990        return handler(*args, **kwargs)
991
992
993def _must_dispatch_in_python(args, kwargs):
994    return pytree.tree_any(
995        lambda obj: isinstance(
996            obj, torch._library.fake_class_registry.FakeScriptObject
997        ),
998        (args, kwargs),
999    )
1000
1001
1002def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
1003    return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)
1004
1005
1006# An OpOverloadPacket contains a pointer to a base unresolved operator that doesn't correspond to a specific operator overload.
1007# You can obtain an OpOverload object from it through an attribute query.
1008class OpOverloadPacket:
1009    def __init__(self, qualified_op_name, op_name, op, overload_names):
1010        # These attributes are accessible on the object through the properties
1011        # defined below but are immutable
1012        self._qualified_op_name = qualified_op_name
1013        self.__name__ = op_name
1014        self._op = op
1015        self._overload_names = overload_names
1016        self._dir = []
1017        self._has_torchbind_op_overload = any(
1018            _has_script_object_arg(schema) for schema in self._schemas.values()
1019        )
1020
1021    # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
1022    def __deepcopy__(self, memo=None):
1023        return self
1024
1025    def __repr__(self):
1026        return "<OpOverloadPacket(op='{}.{}')>".format(
1027            *self._qualified_op_name.split("::")
1028        )
1029
1030    def __hash__(self):
1031        return hash(self._op)
1032
1033    def __str__(self):
1034        return "{}.{}".format(*self._qualified_op_name.split("::"))
1035
1036    @property
1037    def op(self):
1038        return self._op
1039
1040    @property
1041    def _schemas(self):
1042        return {
1043            overload_name: torch._C._get_schema(self._qualified_op_name, overload_name)
1044            for overload_name in self._overload_names
1045        }
1046
1047    def __getattr__(self, key):
1048        # It is not a valid op_name when __file__ is passed in
1049        if key == "__file__":
1050            return "torch.ops"
1051
1052        # Ensure that queries for dunder attributes that do not exist on the
1053        # OpOverloadPacket but do exist on the self._op object do not unnecessarily call
1054        # `_get_operation_overload` (which is an expensive operation).
1055        # This is done to prevent any potential slowdown. This list can be extended
1056        # if there are other attributes like `__name__` that exist only on self._op and not on the
1057        # OpOverloadPacket.
1058        # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'.
1059        try:
1060            if key.startswith("__"):
1061                return getattr(self._op, key)
1062        except AttributeError:
1063            # for consistency because it seems weird to
1064            # throw an attribute error with a message containing
1065            # an object name different from the one the attribute
1066            # query was performed on.
1067            raise AttributeError(
1068                f"'{str(self)}' can't have an overload name beginning with '__' and the "
1069                f"underlying op {str(self._op)} has no attribute {key} either."
1070            ) from None
1071
1072        try:
1073            # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
1074            use_key = "" if key == "default" else key
1075            # TODO: disallow access to overloads registered by JIT
1076            op_dk_tags = torch._C._get_operation_overload(
1077                self._qualified_op_name, use_key
1078            )
1079            if op_dk_tags is None:
1080                raise AttributeError(
1081                    f"The underlying op of '{str(self)}' has no overload name '{key}'"
1082                )
1083
1084            op_, op_dk_, tags = op_dk_tags
1085            schema = torch._C._get_schema(self._qualified_op_name, use_key)
1086            overload = (
1087                OpOverload(self, op_, op_dk_, schema, tags)
1088                if not _has_script_object_arg(schema)
1089                else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
1090            )
1091            # cache the overload object
1092            setattr(self, key, overload)
1093            self._dir.append(key)
1094            return overload
1095        except RuntimeError:
1096            raise AttributeError(
1097                f"The underlying op of '{str(self)}' has no overload name '{key}'"
1098            ) from None
1099
1100    def __iter__(self):
1101        return iter(self._dir)
1102
1103    # Use positional-only argument to avoid naming collision with aten ops arguments
1104    # that are named "self". This way, all the aten ops can be called by kwargs.
1105    def __call__(self, /, *args, **kwargs):
1106        # overloading __call__ to ensure torch.ops.foo.bar()
1107        # is still callable from JIT
1108        # We save the function ptr as the `op` attribute on
1109        # OpOverloadPacket to access it here.
1110
1111        # Directly calling the OverloadPacket goes into C++, which will check
1112        # the schema and cause an error for torchbind ops when the inputs contain a FakeScriptObject, so we
1113        # intercept it here and call TorchBindOpOverload instead.
1114        if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
1115            return _call_overload_packet_from_python(self, args, kwargs)
1116        return self._op(*args, **(kwargs or {}))
1117
1118    # TODO: use this to make a __dir__
1119    def overloads(self):
1120        return [n if n else "default" for n in self._overload_names]
1121
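    # A hedged usage sketch for OpOverloadPacket (aten.add is only an example):
    #
    #   packet = torch.ops.aten.add             # OpOverloadPacket
    #   packet.overloads()                      # e.g. ['Tensor', 'Scalar', 'out', ...]
    #   overload = packet.Tensor                # OpOverload for aten::add.Tensor
    #   overload(torch.ones(2), torch.ones(2))  # calls the resolved overload directly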
1122
1123# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
1124# _jit_get_operations, which calls _get_operation_for_overload_or_packet.
1125def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
1126    # Re-use the torch function handling logic in cpp
1127    torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
1128        op, *args, **kwargs
1129    )
1130
1131    if torch_function_called:
1132        return ret
1133
1134    # The following mirrors getOpWithStack.
1135    # In cpp, we do schema matching for the arguments and call ToIValue
1136    # to check whether the arguments are valid. We need to do something similar here
1137    # and check against the schema whether each FakeScriptObject is the corresponding fake class
1138    # of the actual class used in the schema.
1139    exceptions = {}
1140    found_op = None
1141    for overload_name in op.overloads():
1142        op_overload = getattr(op, overload_name)
1143        try:
1144            _ = torch._C._check_schema_allow_fake_script_object(
1145                op_overload._schema, *args, **kwargs
1146            )
1147            found_op = op_overload
1148            break
1149        except RuntimeError as e:
1150            exceptions[overload_name] = e
1151
1152    if found_op:
1153        return found_op(*args, **kwargs)
1154
1155    err_msg = (
1156        f"Fail to match any TorchBindOverload of {op} with following exceptions:\n"
1157    )
1158    for key, msg in exceptions.items():
1159        err_msg += f"Overload name {key}:\n {msg}\n"
1160    raise RuntimeError(err_msg)
1161
1162
1163# Resolution of torch.fn is different from torch.ops.aten.fn:
1164# torch.fn uses the Python argparser, matches with the
1165# appropriate schema, and calls into the unboxed version of the method.
1166# torch.ops.aten.fn resolution is done via the mechanism defined in JIT:
1167# JIT creates a stack of all the overloads and then tries to match the
1168# correct one at runtime, and always calls into the boxed version of the method.
1169# Autograd codegen creates VariableType, TracerType,
1170# inplace or view type, and Python bindings.
1171# ATen codegen generates tensor methods for the tensor class.
1172
1173# _OpNamespace is a subclass of ModuleType because TorchScript
1174# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
1175# to work from script, we need to ensure ops and foo are modules.
1176
1177
1178class _OpNamespace(types.ModuleType):
1179    """
1180    An op namespace to dynamically bind Operators into Python.
1181
1182    Say a user has created a custom Operator called "my_namespace::my_op". To
1183    call this op, the user will write torch.ops.my_namespace.my_op(...).
1184    At startup, this operation will not yet be bound into Python. Instead, the
1185    following sequence of magic tricks will occur:
1186    1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
1187       on the `torch.ops` object, which will create a new `_OpNamespace`
1188       object called `my_namespace` and set it as an attribute on the `ops`
1189       object.
1190    2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
1191       the `my_namespace` object, which will retrieve the operation via
1192       `torch.get_operation`, a function bound from C++, and then in a similar
1193       fashion bind this new object onto the `my_namespace` object.
1194    3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
1195        and subsequent accesses will incur no further lookup (the namespace and
1196        operation will already exist).
1197    """
1198
1199    def __init__(self, name):
1200        super().__init__("torch.ops." + name)
1201        self.name = name
1202        self._dir = []
1203
1204    def __iter__(self):
1205        return iter(self._dir)
1206
1207    def __getattr__(self, op_name):
1208        # It is not a valid op_name when __file__ is passed in
1209        if op_name == "__file__":
1210            return "torch.ops"
1211        elif op_name in ["__origin__", "__self__"]:
1212            raise AttributeError(
1213                f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
1214            )
1215
1216        # Get the op `my_namespace::my_op` if available. This will also check
1217        # for overloads and raise an exception if there are more than one.
1218        namespace_name = self.name
1219        qualified_op_name = f"{namespace_name}::{op_name}"
1220        module_name = self.__module__ + "." + namespace_name
1221
1222        try:
1223            op, overload_names = _get_packet(qualified_op_name, module_name)
1224            if op is None:
1225                raise AttributeError(
1226                    f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
1227                )
1228        except RuntimeError as e:
1229            # Turn this into AttributeError so getattr(obj, key, default)
1230            # works (this is called by TorchScript with __origin__)
1231            raise AttributeError(
1232                f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
1233            ) from e
1234
1235        op.__module__ = module_name
1236        opoverloadpacket = OpOverloadPacket(
1237            qualified_op_name, op_name, op, overload_names
1238        )
1239        opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
1240        # cache the opoverloadpacket to ensure that each op corresponds to
1241        # a unique OpOverloadPacket object
1242        setattr(self, op_name, opoverloadpacket)
1243        self._dir.append(op_name)
1244        return opoverloadpacket
1245
1246
1247def _get_packet(qualname, op_module):
1248    op, overload_names = torch._C._jit_get_operation(qualname)
1249    if op is not None:
1250        # let the script frontend know that op is identical to the builtin op
1251        # with qualified_op_name
1252        torch.jit._builtins._register_builtin(op, qualname)
1253        op.__module__ = op_module
1254    return op, overload_names
1255
1256
1257def _refresh_packet(packet):
1258    op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__)
1259    assert op is not None
1260    packet._op = op
1261    packet._overload_names = overload_names
1262
1263
1264class _PyOpNamespace(_OpNamespace):
1265    def __init__(self, name, ops):
1266        super().__init__(name)
1267        self._ops = ops
1268
1269    def __getattr__(self, name):
1270        # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
1271        op = self._ops.get(name, None)
1272        if op is None:
1273            raise AttributeError(
1274                f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
1275            )
1276        setattr(self, name, op)
1277        return op
1278
1279
1280class _Ops(types.ModuleType):
1281    __file__ = "_ops.py"
1282
1283    def __init__(self):
1284        super().__init__("torch.ops")
1285        self.loaded_libraries = set()
1286        self._higher_order_op_namespace = _PyOpNamespace(
1287            "torch.ops.higher_order", _higher_order_ops
1288        )
1289        self._dir = []
1290
1291    def __getattr__(self, name):
1292        # Check if the name is a HigherOrderOperator
1293        if name == "higher_order":
1294            return self._higher_order_op_namespace
1295
1296        # Here we are creating `torch.ops.my_namespace`
1297        namespace = _OpNamespace(name)
1298        setattr(self, name, namespace)
1299        self._dir.append(name)
1300        return namespace
1301
1302    def __iter__(self):
1303        return iter(self._dir)
1304
1305    def import_module(self, module):
1306        """
1307        Imports a Python module that has torch.library registrations.
1308
1309        Generally, to extend PyTorch with custom operators, a user will
1310        create a Python module whose import triggers registration of
1311        the custom operators via a torch.ops.load_library call or a call
1312        to one or more torch.library.* APIs.
1313
1314        It is unexpected for Python modules to have side effects, so some
1315        linters and formatters will complain. Use this API to import Python
1316        modules that contain these torch.library side effects.
1317
1318        Args:
1319            module (str): The name of the Python module to import
1320
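        Example (``my_project.custom_ops`` is a hypothetical module name)::

            torch.ops.import_module("my_project.custom_ops")
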
1321        """
1322        importlib.import_module(module)
1323
1324    def load_library(self, path):
1325        """
1326        Loads a shared library from the given path into the current process.
1327
1328        The library being loaded may run global initialization code to register
1329        custom operators with the PyTorch JIT runtime. This allows dynamically
1330        loading custom operators. For this, you should compile your operator
1331        and the static registration code into a shared library object, and then
1332        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
1333        shared object.
1334
1335        After the library is loaded, it is added to the
1336        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
1337        for the paths of all libraries loaded using this function.
1338
1339        Args:
1340            path (str): A path to a shared library to load.
1341        """
1342        if torch._running_with_deploy():
1343            return
1344
1345        path = _utils_internal.resolve_library_path(path)
1346        with dl_open_guard():
1347            # Import the shared library into the process, thus running its
1348            # static (global) initialization code in order to register custom
1349            # operators with the JIT.
1350            ctypes.CDLL(path)
1351        self.loaded_libraries.add(path)
1352
1353
1354# The ops "namespace"
1355ops: _Ops = _Ops()
1356