# mypy: allow-untyped-defs
import abc
import contextlib
import ctypes
import importlib
import inspect
import sys
import types
from typing import Any, Callable, Dict, List, Set, Type, Union

import torch
import torch.utils._pytree as pytree
from torch import _utils_internal
from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
from torch._functorch.pyfunctorch import dispatch_functorch
from torch.utils._python_dispatch import TorchDispatchMode


# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
    shared library to load custom operators.
    """
    if not _SET_GLOBAL_FLAGS:
        yield
        return
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        sys.setdlopenflags(old_flags)


class OperatorBase:
    """
    Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator
    (which represents Python-only operators that are unrepresentable in TorchScript).
    """

    def __init__(self):
        # The dispatch cache precomputes a mapping of dispatch key that the
        # dispatcher wants to dispatch to, to an actual implementation of the
        # dispatch key. Confusingly, the actual implementation could *also* be a
        # dispatch key, but in this case, this refers to the C++ kernel that
        # was registered to some dispatch key. Aliases are permitted in the
        # latter but not the former; for example, you might look up the
        # entry for AutogradCPU, and this maps you to the Autograd key for
        # the generic autograd kernel that works for all devices. Since this
        # is the Python dispatcher, you can also put an arbitrary Python
        # callable to call instead. This handler gets precisely the
        # args/kwargs that the operator was __call__'ed with.
        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
        # for use with OpOverload; cache lookup is done entirely from C++
        # for speed.
        # TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
        self._dispatch_cache: Dict[
            DispatchKey, Union[DispatchKey, Callable[..., Any]]
        ] = {}

        # This table allows you to override the behavior of a particular
        # dispatch key to call a custom Python function, rather than the
        # ordinary C++ configured behavior. This is the raison d'etre of
        # Python dispatcher: to let you program the dispatcher from Python
        # in case you need something unusual, and don't want to clobber
        # the existing registrations using the Python operator registration
        # API.
        self.py_kernels: Dict[DispatchKey, Callable[..., Any]] = {}

        # This table allows you to override the behavior of a particular
        # operator for a particular TorchDispatchMode. In practice,
        # we are using this mostly for ProxyTensorMode. Modes can be
        # thought of as an open world extension of dispatch keys, so it
        # makes sense that you should be able to register them, the same
        # way you can register dispatch keys.
        self.python_key_table: Dict[
            Union[Type[TorchDispatchMode], Type[torch.Tensor]], Callable[..., Any]
        ] = {}
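
        # For instance (hypothetical rule name): after
        #   op.py_impl(ProxyTorchDispatchMode)(proxy_rule)
        # python_key_table maps ProxyTorchDispatchMode -> proxy_rule, and the
        # rule is invoked as proxy_rule(mode, *args, **kwargs) when that mode
        # is active.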

        # This table allows you to override the behavior of functorch
        # transformations. NB: this currently only does something for
        # HigherOrderOperator.
        self.functorch_table = {}

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def has_kernel_for_dispatch_key(self, k):
        return k in self.py_kernels

    def has_kernel_for_any_dispatch_key(self, ks):
        for k in self.py_kernels:
            if not torch._C._dispatch_is_alias_key(k) and ks.has(k):
                return True
        return False

    def py_impl(self, k):
        def inner(fn):
            if inspect.isclass(k) and (
                issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
            ):
                assert k not in self.python_key_table
                # TODO(voz): Should we replace setting DispatchKey.Python entirely with setting mode keys?
                self.python_key_table[k] = fn
                self._dispatch_cache.clear()
                return fn

            if isinstance(k, torch._C._functorch.TransformType):
                assert k not in self.functorch_table
                self.functorch_table[k] = fn
                return fn

            assert isinstance(k, DispatchKey)
            assert (
                k != DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

            if k in self.py_kernels:
                raise RuntimeError(
                    f"Trying to override a python impl for {k} on operator {self.name()}"
                )
            self.py_kernels[k] = fn
            self._dispatch_cache.clear()
            return fn

        return inner

    # Registers an implementation to all **3** variants of functionalization that we have:
    # - DispatchKey.Functionalize
    # - functorch.TransformType.Functionalize
    # - FunctionalTensorMode
    # Example:
    #   @py_functionalize_impl
    #   def functionalize_rule(ctx, inner_f, *args):
    #       args_unwrapped = ctx.unwrap_tensors(args)
    #       with ctx.redispatch_to_next():
    #           out = ctx.functionalize(inner_f)(*args_unwrapped)
    #           return ctx.wrap_tensors(out)
    def py_functionalize_impl(self, fn):
        from torch._subclasses.functional_tensor import (
            CppFunctionalizeAPI as _CppFunctionalizeAPI,
            FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
            PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
        )

        # Construct our three flavors of functionalization,
        # each of which has slightly different wrap/unwrap/redispatch policies.
        def functionalize_dk_fn(*args, **kwargs):
            return fn(_CppFunctionalizeAPI(), *args, **kwargs)

        def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
            return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)

        def functionalize_functorch_fn(interpreter, *args, **kwargs):
            return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)

        self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
        self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
            functionalize_dispatch_mode_fn
        )
        self.py_impl(torch._C._functorch.TransformType.Functionalize)(
            functionalize_functorch_fn
        )

        return fn

    def name(self):
        raise NotImplementedError
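

# A minimal sketch of how py_impl is used (hypothetical op and kernel names):
#
#   my_op = ...  # some OperatorBase subclass instance
#
#   @my_op.py_impl(DispatchKey.CPU)
#   def my_op_cpu(*args, **kwargs):
#       # This Python kernel is consulted instead of the C++ CPU behavior.
#       ...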


# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
    # 1. (Direct) operator registration
    if op.has_kernel_for_dispatch_key(k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    has_backend_kernel = op.has_kernel_for_any_dispatch_key(
        torch._C._dispatch_get_backend_keyset_from_autograd(k)
    ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
    # 2.3. Use CompositeImplicitAutograd kernel if available
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
        and op.has_kernel_for_dispatch_key(cand)
        and not has_backend_kernel
    ):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key(
            torch._C._dispatch_autogradother_backends
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
    cand = DispatchKey.FuncTorchBatchedDecomposition
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")


_higher_order_ops: Dict[str, "HigherOrderOperator"] = {}

_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
    DispatchKey.PythonDispatcher,  # type: ignore[attr-defined]
    DispatchKey.PythonTLSSnapshot,  # type: ignore[attr-defined]
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.AutocastCPU,  # type: ignore[attr-defined]
    DispatchKey.AutocastCUDA,  # type: ignore[attr-defined]
]
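
# For example (illustrative): for an op whose only kernel is registered to
# CompositeImplicitAutograd, resolve_key(op, DispatchKey.CPU) returns
# DispatchKey.CompositeImplicitAutograd (case 2.3), while an op with a direct
# CPU kernel resolves to DispatchKey.CPU itself (case 1).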


class HigherOrderOperator(OperatorBase, abc.ABC):
    # The HigherOrderOperator will appear as torch.ops.higher_order.{name}
    #
    # If you're creating a new HigherOrderOperator, please do not change the
    # default. Adding operators to the global torch.ops namespace is a bad
    # practice due to name collisions.
    def __init__(self, name):
        super().__init__()
        if type(self) is HigherOrderOperator:
            raise RuntimeError(
                "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
            )
        self._name = name

        # Make _OPNamespace not scream, this whole name based association needs a good hard look
        self.__name__ = name
        _higher_order_ops[name] = self
        self._ns = "higher_order"
        self.__module__ = "torch.ops.higher_order"

        self.non_fallthrough_keys = torch._C._dispatch_keyset_full()

        for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
            self.fallthrough(dispatch_key)

        # [NOTE] We have to register a pre-dispatch key implementation
        # because sometimes HOPs use aot-dispatch tracing to detect certain
        # mutations. This is problematic when we are functionalizing a HOP
        # during pre-dispatch, because when the inner tracer starts, it will see
        # that the PreDispatch key is still active. In that case, we just redispatch
        # it to the next key. This is only safe to do when the PreDispatch key stack
        # has no active modes.

    def py_impl(self, k):
        if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
            self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
        return super().py_impl(k)

    @property
    def namespace(self):
        return self._ns

    def fallthrough(self, dispatch_key):
        self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
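
    # For example (hypothetical HOP instance): my_hop.fallthrough(DispatchKey.AutocastCPU)
    # removes AutocastCPU from non_fallthrough_keys, so that key is skipped during
    # dispatch just as a C++ fallthrough kernel would be.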

    # Use positional-only argument to avoid naming collisions with custom ops arguments
    # that are named "self".
    def dispatch(self, /, dispatch_key, *args, **kwargs):
        from torch.utils._python_dispatch import _get_current_dispatch_mode

        if dispatch_key in self._dispatch_cache:
            kernel = self._dispatch_cache[dispatch_key]
            assert not isinstance(kernel, DispatchKey)
            return kernel(*args, **kwargs)

        if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
            return dispatch_functorch(self, args, kwargs)

        if dispatch_key == DispatchKey.Python:
            # Keep the following 1:1 with handle_torch_function_no_python_arg_parser
            # in torch/csrc/utils/python_arg_parser.cpp

            overloaded_args_list = []

            def has_python_key(tensor):
                return torch._C._dispatch_keys(tensor).has("Python")

            def check_overloaded(arg):
                if isinstance(arg, torch.Tensor) and has_python_key(arg):
                    overloaded_args_list.append(arg)

            for arg in (*args, *kwargs.values()):
                check_overloaded(arg)
                if isinstance(arg, (list, tuple)):
                    for a in arg:
                        check_overloaded(a)

            overloaded_args = tuple(overloaded_args_list)
            overloaded_types = tuple(type(arg) for arg in overloaded_args)

            # Step 1: dispatch on any user TorchDispatchModes
            from torch.utils._python_dispatch import _pop_mode_temporarily

            curr_mode = _get_current_dispatch_mode()
            if curr_mode is not None:
                if type(curr_mode) in self.python_key_table:
                    handler = self.python_key_table[type(curr_mode)]
                    with _pop_mode_temporarily() as mode:
                        # "natural" calling convention: (mode, *args, **kwargs)
                        # TODO(rzou): we should support torch_dispatch calling convention too.
                        result = handler(mode, *args, **kwargs)
                else:
                    raise NotImplementedError(
                        f"There was no rule registered for HOP {self._name} and mode {curr_mode}. "
                        f"We recommend filing an issue."
                    )
                if result is not NotImplemented:
                    return result

            # Step 2: dispatch on any subclasses
            for arg in overloaded_args:
                subclass_type = type(arg)
                if (
                    subclass_type.__torch_dispatch__
                    == torch._C._disabled_torch_dispatch_impl
                ):
                    continue
                if subclass_type in self.python_key_table:
                    handler = self.python_key_table[subclass_type]
                    # "natural" calling convention: (*args, **kwargs)
                    # TODO(rzou): we should support torch_dispatch calling convention too.
                    result = handler(*args, **kwargs)
                else:
                    raise NotImplementedError(
                        f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. "
                        f"We recommend filing an issue."
                    )
                if result is not NotImplemented:
                    return result

            # All handlers returned NotImplemented
            raise TypeError(
                f"Multiple dispatch failed for {self._name}. There was no registered handler "
                f"that did not return NotImplemented. Use HOP.py_impl to register some. "
                f"Tried mode: {curr_mode} and subclasses: "
                f"{[type(a) for a in overloaded_args]}"
            )

        functionality_key = torch._C._to_functionality_key(dispatch_key)  # type: ignore[attr-defined]
        if functionality_key == DispatchKey.PreDispatch:
            from torch.utils._python_dispatch import _pop_mode_temporarily

            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
            # calls inside of a mode.
            if (
                _len_torch_dispatch_stack_pre_dispatch() > 0
            ) and not torch._C._dispatch_tls_is_dispatch_key_excluded(
                DispatchKey.Python
            ):
                curr_mode = _get_current_dispatch_mode_pre_dispatch()
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
                assert (
                    type(curr_mode) in self.python_key_table
                ), f"Current active mode {curr_mode} not registered"
                handler = self.python_key_table[type(curr_mode)]
                with _pop_mode_temporarily(functionality_key) as mode:
                    return handler(mode, *args, **kwargs)

        final_key = resolve_key(self, dispatch_key)

        # This can currently fail due to backend fallbacks. You just have to
        # register them by hand for HigherOrderOperator.
        if final_key not in self.py_kernels:
            raise NotImplementedError(
                f"could not find kernel for HigherOrderOperator {self._name} "
                f"at dispatch key {final_key} (resolved from {dispatch_key})"
            )

        # [NOTE] We shouldn't cache the PreDispatch kernel here because, depending
        # on what modes are active, pre-dispatch behaviour is different.
        # We do the same thing for normal ops:
        # see Note [Not Caching Per-Dispatch-Key Mode Handlers]
        if dispatch_key != DispatchKey.PreDispatch:
            self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
        kernel = self.py_kernels[final_key]
        # It's illegal to register a DispatchKey to py_kernels, since there's no
        # C++ kernel to call into
        assert not isinstance(kernel, DispatchKey)
        return kernel(*args, **kwargs)
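
    # Sketch of the DispatchKey.Python protocol above (hypothetical names): a
    # handler registered for an active mode or tensor subclass is tried first;
    # if it returns NotImplemented, dispatch moves on to the next candidate:
    #
    #   @my_hop.py_impl(MyTensorSubclass)
    #   def subclass_rule(*args, **kwargs):
    #       return NotImplemented  # defer to other handlers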

    @abc.abstractmethod
    def __call__(self, /, *args, **kwargs):
        # Dynamo already traces the body of a HigherOrderOp beforehand when it
        # encounters one, so there is no need to trace into it here.
        from torch._dynamo import disable

        @disable
        def wrapper():
            flat_args = _to_flat_tuple(args, kwargs)
            if torch.overrides.has_torch_function(flat_args):
                return torch.overrides.handle_torch_function(
                    self, flat_args, *args, **kwargs
                )

            dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
            return self.dispatch(
                dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
            )

        return wrapper()

    def __str__(self):
        return f"{self.name()}"

    def name(self):
        return self._name


def _to_flat_tuple(args, kwargs):
    return pytree.arg_tree_leaves(*args, **kwargs)


def _compute_keyset(args, kwargs, non_fallthrough_keys):
    tensors = _get_tensors(args, kwargs)
    return key_extractor(tensors, non_fallthrough_keys)


def _get_tensors(args, kwargs):
    flat_all = _to_flat_tuple(args, kwargs)
    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
    return tuple(tensor_args)


# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors, key_mask):
    key_set = torch._C._dispatch_tls_local_include_set()
    for tensor in tensors:
        key_set = key_set | torch._C._dispatch_keys(tensor)
    key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
    key_set = key_set & key_mask
    return key_set
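
# In keyset terms, key_extractor computes:
#   ((tls_include | union of tensor key sets) - tls_exclude) & key_mask
# where the mask is the op's non-fallthrough keys, and the highest-priority
# key in the result is the one dispatched to first.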


# Mode stack for PreDispatchKey:
# it has two infra mode slots, with priority given to
# FunctionalTensorMode and then ProxyTorchDispatchMode. That means
# slot 0 belongs to ProxyTorchDispatchMode and
# slot 1 belongs to FunctionalTensorMode.
#
# SchemaCheckMode is separate from the other 2,
# and is only valid when the stack is empty.
# SchemaCheckMode is for testing purposes, and
# is meant to run in eager mode on concrete inputs,
# checking for incorrect schemas with regard to
# aliasing or mutating ops.
class _ModeStackStateForPreDispatch:
    def __init__(self):
        self.__infra_modes = [None, None]
        self._schema_check_mode = None

    def set(self, index, mode):
        assert index < len(self.__infra_modes)
        self.__infra_modes[index] = mode

    def get(self, index):
        assert index < len(self.__infra_modes)
        return self.__infra_modes[index]

    def count(self):
        return len([i for i in self.__infra_modes if i is not None]) + int(
            self._schema_check_mode is not None
        )


_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch()


def unset_mode_pre_dispatch(mode_key, schema_check=False):
    current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch()
    assert mode_key is None or mode_key in (
        torch._C._TorchDispatchModeKey.PROXY,
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
    )
    if schema_check:
        assert mode_key is None

    def _unset_mode():
        if mode_key == torch._C._TorchDispatchModeKey.PROXY:
            current_mode = current_mode_stack_pre_dispatch.get(0)
            mode_stack_state_for_pre_dispatch().set(0, None)
            return current_mode
        elif mode_key == torch._C._TorchDispatchModeKey.FUNCTIONAL:
            current_mode = current_mode_stack_pre_dispatch.get(1)
            mode_stack_state_for_pre_dispatch().set(1, None)
            return current_mode
        else:
            current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
            mode_stack_state_for_pre_dispatch()._schema_check_mode = None
            return current_mode

    current_mode = _unset_mode()

    new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
    # When we are unsetting a mode, we need to check if there is any
    # active mode left on the PreDispatch key. If there is nothing
    # active, we need to remove the PreDispatch key from the local dispatch
    # include set.
    if new_pre_dispatch_len == 0:
        torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False)

    return current_mode
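
# For example (illustrative):
#   unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL)
# clears slot 1 and returns the FunctionalTensorMode that was active there;
# if that leaves the stack empty, the PreDispatch key is dropped from the
# TLS include set.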


def _set_mode_pre_dispatch(mode):
    from torch._subclasses.functional_tensor import FunctionalTensorMode
    from torch._subclasses.schema_check_mode import SchemaCheckMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode

    assert isinstance(
        mode,
        (
            FunctionalTensorMode,
            ProxyTorchDispatchMode,
            SchemaCheckMode,
        ),
    )

    previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch()
    if isinstance(mode, SchemaCheckMode):
        current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
        if previous_mode_stack_len > 0:
            raise AssertionError(
                "SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack"
            )
        mode_stack_state_for_pre_dispatch()._schema_check_mode = mode
    elif isinstance(mode, FunctionalTensorMode):
        current_mode = mode_stack_state_for_pre_dispatch().get(1)
        assert current_mode is None
        mode_stack_state_for_pre_dispatch().set(1, mode)
    else:
        current_mode = mode_stack_state_for_pre_dispatch().get(0)
        assert current_mode is None
        mode_stack_state_for_pre_dispatch().set(0, mode)

    # When we are setting a mode, we need to check if there is any
    # active mode left on the PreDispatch key. If there was nothing
    # active before setting this mode, it means that the PreDispatch key
    # was turned off. So we need to turn it on again.
    if previous_mode_stack_len == 0:
        torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True)


def _pop_mode_from_pre_dispatch():
    mode_stack = mode_stack_state_for_pre_dispatch()
    pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()

    if pre_dispatch_len == 0:
        raise AssertionError("Trying to pop empty mode stack")

    if mode_stack._schema_check_mode is not None:
        return unset_mode_pre_dispatch(None, schema_check=True)
    if mode_stack.get(1) is not None:
        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL)
    if mode_stack.get(0) is not None:
        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)


def _len_torch_dispatch_stack_pre_dispatch():
    return mode_stack_state_for_pre_dispatch().count()


def _get_dispatch_mode_pre_dispatch(mode_key):
    assert mode_key in (
        torch._C._TorchDispatchModeKey.PROXY,
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
    )
    if mode_key == torch._C._TorchDispatchModeKey.PROXY:
        return mode_stack_state_for_pre_dispatch().get(0)
    else:
        return mode_stack_state_for_pre_dispatch().get(1)


def _get_current_dispatch_mode_pre_dispatch():
    if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None:
        return mode_stack_state_for_pre_dispatch()._schema_check_mode
    else:
        stack_len = mode_stack_state_for_pre_dispatch().count()
        if stack_len == 2:
            return mode_stack_state_for_pre_dispatch().get(1)
        if stack_len == 1:
            return (
                mode_stack_state_for_pre_dispatch().get(1)
                if mode_stack_state_for_pre_dispatch().get(1) is not None
                else mode_stack_state_for_pre_dispatch().get(0)
            )
        return None


def mode_stack_state_for_pre_dispatch():
    global _mode_stack_state_for_pre_dispatch
    return _mode_stack_state_for_pre_dispatch


cached_ops: Set["OpOverload"] = set()


def add_cached_op(op_overload):
    global cached_ops
    cached_ops.add(op_overload)


def reset_cached_ops():
    global cached_ops
    cached_ops.clear()


def get_cached_ops():
    global cached_ops
    return cached_ops
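
# A minimal sketch (illustrative) of invalidating all cached dispatch entries,
# e.g. after mutating Python dispatcher state:
#
#   for op in get_cached_ops():
#       op._dispatch_cache.clear()
#   reset_cached_ops()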


# Each OpOverload object contains a pointer to a specific operator overload
# and a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
class OpOverload(OperatorBase):
    def __init__(self, overloadpacket, op, op_dk, schema, tags):
        super().__init__()
        self._op = op
        self._op_dk = op_dk
        self._schema = schema
        self._overloadpacket = overloadpacket
        self._tags = tags
        self._overloadname = (
            "default" if schema.overload_name == "" else schema.overload_name
        )
        self._name = self._schema.name
        if schema.overload_name:
            self._name += "." + schema.overload_name
        self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}"
        self.__module__ = overloadpacket.__module__
        op.__module__ = overloadpacket.__module__
        self.__qualname__ = self._name
        self.__annotations__ = {}
        # Only compute the OperatorHandle when we need it. Not all OpOverloads have
        # OperatorHandles (the TorchScript ones don't...)
        self._lazy_handle = None

        # If the OpOverload was constructed from a Library.def in Python.
        self._defined_in_python = self.__qualname__ in torch.library._defs

        # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
        is_write = None
        for a in self._schema.arguments:
            if a.alias_info is None:
                continue
            if is_write is None:
                is_write = a.alias_info.is_write
            else:
                # We will conservatively call mixed mutable/non-mutable
                # aliased inputs as NOT a view
                is_write = a.alias_info.is_write or is_write
        self.is_view = is_write is not None and not is_write

    @property
    def _namespace(self):
        return self._schema.name.split("::")[0]

    @property
    def _opname(self):
        return self._schema.name.split("::")[1]

    @property
    def _handle(self):
        if self._lazy_handle is None:
            self._lazy_handle = torch._C._dispatch_find_schema_or_throw(
                self._schema.name, self._schema.overload_name
            )
        return self._lazy_handle

    # It's a no-op since the OpOverload object is immutable and must be unique for a given op overload.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverload(op='{}.{}', overload='{}')>".format(
            *self._schema.name.split("::"), self._overloadname
        )

    # Use positional-only argument to avoid naming collisions with aten ops arguments
    # that are named "self". This way, all the aten ops can be called by kwargs.
    def __call__(self, /, *args, **kwargs):
        return self._op(*args, **kwargs)

    # Use positional-only argument to avoid naming collisions with aten ops arguments
    # that are named "self". This way, all the aten ops can be called by kwargs.
    def redispatch(self, /, keyset, *args, **kwargs):
        return self._handle.redispatch_boxed(keyset, *args, **kwargs)

    def __hash__(self):
        return hash(self._op)

    # `my_namespace.my_op_name.overload_name`
    def __str__(self):
        return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)

    def has_kernel_for_dispatch_key(self, k):
        return super().has_kernel_for_dispatch_key(
            k
        ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)

    def has_kernel_for_any_dispatch_key(self, ks):
        return torch._C._dispatch_has_kernel_for_any_dispatch_key(
            self.name(), ks
        ) or super().has_kernel_for_any_dispatch_key(ks)

    @property
    def namespace(self):
        return self._schema.name.split("::")[0]

    def _can_decompose(self):
        dk = DispatchKey.CompositeImplicitAutograd
        return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key(
            self.name(), dk
        )

    def decompose(self, *args, **kwargs):
        dk = DispatchKey.CompositeImplicitAutograd
        if dk in self.py_kernels:
            # NB: This branch is not too necessary anymore, because we can
            # apply Python CompositeImplicitAutograd *before* tracing
            # using the Python dispatcher (also taking advantage of the autograd
            # formula). But it's included for completeness.
            return self.py_kernels[dk](*args, **kwargs)
        elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
            return self._op_dk(dk, *args, **kwargs)
        else:
            return NotImplemented
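
    # For example (illustrative): torch.ops.aten.matmul.default has a
    # CompositeImplicitAutograd kernel, so op.decompose(*args) runs that
    # decomposition; an op without one returns NotImplemented instead.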

    # Remove a dispatch key from the dispatch cache. This will force it to get
    # recomputed the next time. Does nothing if the key is not cached.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _del_dispatch on that key is NOT sufficient to apply your change,
    # because a single registration may affect MULTIPLE dispatch keys (e.g.,
    # registering Autograd affects AutogradCPU). del_dispatch is to be used
    # only if you are specifically modifying how get_dispatch handles a
    # particular input 'key'.
    def _uncache_dispatch(self, key):
        self._dispatch_cache.pop(key, None)

    # This implements the pre-computation logic for the Python dispatcher.
    def _get_dispatch(self, key):
        # This is only called upon a cache miss
        assert key not in self._dispatch_cache, f"{self} {key}"

        if key == DispatchKey.Python:
            if not isinstance(self, TorchBindOpOverload) and not self.python_key_table:
                self._dispatch_cache[key] = key
                add_cached_op(self)
                return key

            def handler(*args, **kwargs):
                from torch.utils._python_dispatch import _get_current_dispatch_mode

                # TODO: We also need to handle tensor subclasses here
                # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
                curr_mode = _get_current_dispatch_mode()
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
                curr_mode_type = type(curr_mode)

                if curr_mode_type not in self.python_key_table:
                    if isinstance(self, TorchBindOpOverload):
                        with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
                            return torch._library.utils.handle_dispatch_mode(
                                mode, self, *args, **kwargs
                            )
                    else:
                        return self._op_dk(key, *args, **kwargs)

                with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
                    return self.python_key_table[curr_mode_type](mode, *args, **kwargs)

            self._dispatch_cache[key] = handler
            add_cached_op(self)
            return handler

        functionality_key = torch._C._to_functionality_key(key)  # type: ignore[attr-defined]
        if functionality_key == DispatchKey.PreDispatch:
            curr_stack_len = _len_torch_dispatch_stack_pre_dispatch()
            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
            # calls inside of a mode.
            if (
                curr_stack_len > 0
                and not torch._C._dispatch_tls_is_dispatch_key_excluded(
                    DispatchKey.Python
                )
            ):

                def handler(*args, **kwargs):
                    @contextlib.contextmanager
                    def _temporarily_pop_modes_from_pre_dispatch():
                        top_mode = _pop_mode_from_pre_dispatch()
                        try:
                            yield top_mode
                        finally:
                            _set_mode_pre_dispatch(top_mode)

                    with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
                        return torch._library.utils.handle_dispatch_mode(
                            curr_mode, self, *args, **kwargs
                        )

                # Note [Not Caching Per-Dispatch-Key Mode Handlers]
                # Note that we're not caching this handler. There isn't really a point, since the slow bit
                # is the handler itself (in python).
                # Also, not caching means that we don't have to reset the cache when any existing
                # modes go out of scope (which in and of itself takes time to loop through all operators).
                return handler

        final_key = resolve_key(self, key)

        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        cache_result = key != DispatchKey.PreDispatch

        # TODO: We could potentially have lots of debugging wrappers against
        # dispatch keys; design some general registration mechanism instead of
        # having an if statement for each of them
        if key == DispatchKey.Functionalize:
            import torch._dispatch.python as pydispatch

            if pydispatch.CROSSREF_FUNCTIONALIZE:
                handler = pydispatch.make_crossref_functionalize(self, final_key)
                if cache_result:
                    self._dispatch_cache[key] = handler
                    add_cached_op(self)
                return handler

        r = self.py_kernels.get(final_key, final_key)
        if cache_result:
            self._dispatch_cache[key] = r
            add_cached_op(self)
        return r

    def name(self):
        return self._name

    @property
    def overloadpacket(self):
        return self._overloadpacket

    @property
    def op(self):
        return self._op

    @property
    def tags(self):
        return self._tags

    # TODO: add more methods to expose information about input and output arguments
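

# For example: torch.ops.aten.add.Tensor is an OpOverload; str() on it gives
# "aten.add.Tensor", .overloadpacket returns the parent torch.ops.aten.add
# packet, and .tags exposes torch.Tag entries (e.g. torch.Tag.pointwise).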


# TorchBindOpOverloads are those custom ops where at least one overload's
# schema contains a torch.ScriptObject (i.e. custom class) input.
# A TorchBindOpOverload skips the C++ dispatcher and is dispatched purely in
# Python when its inputs contain a FakeScriptObject, in a similar way to
# higher order ops.
class TorchBindOpOverload(OpOverload):
    def _fallthrough_keys(self) -> List[DispatchKey]:
        # TODO: we should be calling the fallback for these, but a fallthrough is almost close
        # enough to the fallback in most cases that we care about.
        _DEFAULT_FALLTHROUGH_KEYS = [
            DispatchKey.Autograd,
            DispatchKey.AutogradCPU,
            DispatchKey.AutogradCUDA,
            DispatchKey.ADInplaceOrView,
            DispatchKey.BackendSelect,
            DispatchKey.PythonTLSSnapshot,
            DispatchKey.PythonDispatcher,
        ]

        def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
            if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
                return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                    self.name(), key
                )

            return (
                key not in self.py_kernels
                or self.py_kernels[key] is torch.library.fallthrough_kernel
            )

        return [
            key
            for key in _DEFAULT_FALLTHROUGH_KEYS
            if _may_use_fallthrough_instead_of_fallback(key)
        ]

    @contextlib.contextmanager
    def _register_as_effectful_op_temporarily(self):
        from torch._higher_order_ops.effects import (
            _EffectType,
            _register_effectful_op,
            SIDE_EFFECTS,
        )

        try:
            if self not in SIDE_EFFECTS:
                _register_effectful_op(self, _EffectType.ORDERED)
            yield
        finally:
            if self in SIDE_EFFECTS:
                del SIDE_EFFECTS[self]

    # Use positional-only argument to avoid naming collisions with aten ops arguments
    # that are named "self". This way, all the aten ops can be called by kwargs.
    def __call__(self, /, *args, **kwargs):
        if _must_dispatch_in_python(args, kwargs):
            # When any inputs are FakeScriptObject, we need to skip the C++
            # dispatcher and dispatch in Python through _get_dispatch of the
            # Python dispatcher, because the C++ dispatcher will check the
            # schema and cannot recognize FakeScriptObject.
            #
            # Note:
            # 1. We only register the torchbind op temporarily as an effectful op, because we only
            #    want the effect token functionalization logic to be applied during tracing.
            #    Otherwise, the behavior of eagerly executing the op might change after tracing.
            # 2. We don't want to register the op as effectful for all torchbind ops in the
            #    constructor, because this might cause unexpected behavior for some autograd.profiler
            #    ops, e.g. profiler._record_function_exit._RecordFunction.
            with self._register_as_effectful_op_temporarily():
                return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
        return self._op(*args, **kwargs)

    def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
        non_fallthrough_keys = torch._C._dispatch_keyset_full()
        for key in fallthrough_keys:
            non_fallthrough_keys = non_fallthrough_keys.remove(key)

        dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
        dispatch_key = dispatch_key_set.highestPriorityTypeId()

        handler = (
            self._get_dispatch(dispatch_key)
            if dispatch_key not in self._dispatch_cache
            else self._dispatch_cache[dispatch_key]
        )

        if isinstance(handler, DispatchKey):
            # Fallthrough keys can be registered at runtime via torch.library.impl,
            # so we need to add the key to fallthrough_keys and re-dispatch.
            if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                self.name(), dispatch_key
            ):
                return self._dispatch_in_python(
                    args, kwargs, fallthrough_keys + [dispatch_key]
                )

            raise RuntimeError(
                f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler},"
                f" but no Python implementation was found."
                f" Please file an issue when you encounter this error."
                f" This can happen when you export or compile the model."
                f" It can happen even if a C++ implementation for {dispatch_key}"
                f" has been registered, because a FakeScriptObject lives purely in Python"
                f" and cannot work with a C++ implementation."
            )

        assert isinstance(handler, Callable)  # type: ignore[arg-type]
        return handler(*args, **kwargs)


def _must_dispatch_in_python(args, kwargs):
    return pytree.tree_any(
        lambda obj: isinstance(
            obj, torch._library.fake_class_registry.FakeScriptObject
        ),
        (args, kwargs),
    )


def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
    return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)
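
# For example (hypothetical schema): an op declared as
#   "my_ns::queue_push(__torch__.torch.classes.my.Queue q, Tensor x) -> ()"
# has a ClassType argument, so _has_script_object_arg returns True and its
# overloads are constructed as TorchBindOpOverload.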


# The OpOverloadPacket class contains a pointer to a base unresolved operator
# that doesn't correspond to a specific operator overload.
# You can obtain an OpOverload object through attribute query.
class OpOverloadPacket:
    def __init__(self, qualified_op_name, op_name, op, overload_names):
        # These attributes are accessible on the object through the properties
        # defined below but are immutable
        self._qualified_op_name = qualified_op_name
        self.__name__ = op_name
        self._op = op
        self._overload_names = overload_names
        self._dir = []
        self._has_torchbind_op_overload = any(
            _has_script_object_arg(schema) for schema in self._schemas.values()
        )

    # It's a no-op since the OpOverloadPacket object is immutable and must be unique for a given op.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverloadPacket(op='{}.{}')>".format(
            *self._qualified_op_name.split("::")
        )

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        return "{}.{}".format(*self._qualified_op_name.split("::"))

    @property
    def op(self):
        return self._op

    @property
    def _schemas(self):
        return {
            overload_name: torch._C._get_schema(self._qualified_op_name, overload_name)
            for overload_name in self._overload_names
        }

    def __getattr__(self, key):
        # It is not a valid op_name when __file__ is passed in
        if key == "__file__":
            return "torch.ops"

        # Ensure that queries for dunder attributes that do not exist on the
        # OpOverloadPacket but do exist on the self._op object do not
        # unnecessarily call `_get_operation_overload` (which is an expensive
        # operation). This is done to prevent any potential slowdown. This list
        # can be extended if there exist other attributes, like `__name__`,
        # that only exist on self._op and not on the OpOverloadPacket.
        # This is ok since we are guaranteed that an overload name for an aten
        # op can't start with '__'.
        try:
            if key.startswith("__"):
                return getattr(self._op, key)
        except AttributeError:
            # For consistency, because it seems weird to
            # throw an attribute error with a message containing
            # an object name different from the one the attribute
            # query was performed on.
            raise AttributeError(
                f"'{str(self)}' can't have an overload name beginning with '__' and the "
                f"underlying op {str(self._op)} has no attribute {key} either."
            ) from None

        try:
            # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'.
            use_key = "" if key == "default" else key
            # TODO: disallow access to overloads registered by JIT
            op_dk_tags = torch._C._get_operation_overload(
                self._qualified_op_name, use_key
            )
            if op_dk_tags is None:
                raise AttributeError(
                    f"The underlying op of '{str(self)}' has no overload name '{key}'"
                )

            op_, op_dk_, tags = op_dk_tags
            schema = torch._C._get_schema(self._qualified_op_name, use_key)
            overload = (
                OpOverload(self, op_, op_dk_, schema, tags)
                if not _has_script_object_arg(schema)
                else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
            )
            # cache the overload object
            setattr(self, key, overload)
            self._dir.append(key)
            return overload
        except RuntimeError:
            raise AttributeError(
                f"The underlying op of '{str(self)}' has no overload name '{key}'"
            ) from None
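
    # For example: the first access of torch.ops.aten.add.Tensor goes through
    # __getattr__("Tensor"), builds the OpOverload, and caches it with setattr;
    # subsequent accesses are ordinary attribute lookups that skip this method.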

    def __iter__(self):
        return iter(self._dir)

    # Use positional-only argument to avoid naming collisions with aten ops arguments
    # that are named "self". This way, all the aten ops can be called by kwargs.
    def __call__(self, /, *args, **kwargs):
        # Overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT.
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.

        # Directly calling the OverloadPacket goes into C++, which will check
        # the schema and cause an error for a torchbind op when inputs consist of
        # FakeScriptObject, so we intercept it here and call TorchBindOpOverload instead.
        if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
            return _call_overload_packet_from_python(self, args, kwargs)
        return self._op(*args, **(kwargs or {}))

    # TODO: use this to make a __dir__
    def overloads(self):
        return [n if n else "default" for n in self._overload_names]


# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
# _jit_get_operations, which calls _get_operation_for_overload_or_packet.
def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
    # Re-use the torch function handling logic in C++
    torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
        op, *args, **kwargs
    )

    if torch_function_called:
        return ret

    # The following mirrors getOpWithStack.
    # In C++, we do schema matching for the arguments and call toIValue
    # to check whether the arguments are valid. We need to do similar things here,
    # checking against the schema whether each FakeScriptObject is the corresponding
    # fake class of the actual class used in the schema.
    exceptions = {}
    found_op = None
    for overload_name in op.overloads():
        op_overload = getattr(op, overload_name)
        try:
            _ = torch._C._check_schema_allow_fake_script_object(
                op_overload._schema, *args, **kwargs
            )
            found_op = op_overload
            break
        except RuntimeError as e:
            exceptions[overload_name] = e

    if found_op:
        return found_op(*args, **kwargs)

    err_msg = (
        f"Failed to match any TorchBindOpOverload of {op} with the following exceptions:\n"
    )
    for key, msg in exceptions.items():
        err_msg += f"Overload name {key}:\n {msg}\n"
    raise RuntimeError(err_msg)
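

# For example (hypothetical op): during export, calling
# torch.ops.my_ns.queue_push(fake_queue, x) with a FakeScriptObject routes
# through _call_overload_packet_from_python, which tries each overload's
# schema in turn until one accepts the arguments.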


# Resolution of torch.fn is different from torch.ops.aten.fn:
# torch.fn uses the Python argparser, matches with the
# appropriate schema, and calls into the unboxed version of the method;
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
# JIT creates a stack of all the overloads and then tries to match the
# correct one at runtime and always calls into the boxed version of the method.
# Autograd codegen creates VariableType, TracerType,
# inplace or view type and python bindings.
# Aten codegen generates tensor methods for the tensor class.

# _OpNamespace is a subclass of ModuleType because TorchScript
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
# to work from script, we need to ensure ops and foo are modules


class _OpNamespace(types.ModuleType):
    """
    An op namespace to dynamically bind Operators into Python.

    Say a user has created a custom Operator called "my_namespace::my_op". To
    call this op, the user will write torch.ops.my_namespace.my_op(...).
    At startup, this operation will not yet be bound into Python. Instead, the
    following sequence of magic tricks will occur:
    1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
       on the `torch.ops` object, which will create a new `_OpNamespace`
       object called `my_namespace` and set it as an attribute on the `ops`
       object.
    2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
       the `my_namespace` object, which will retrieve the operation via
       `torch.get_operation`, a function bound from C++, and then in a similar
       fashion bind this new object onto the `my_namespace` object.
    3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
       and subsequent accesses will incur no further lookup (the namespace and
       operation will already exist).
    """

    def __init__(self, name):
        super().__init__("torch.ops." + name)
        self.name = name
        self._dir = []

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # It is not a valid op_name when __file__ is passed in
        if op_name == "__file__":
            return "torch.ops"
        elif op_name in ["__origin__", "__self__"]:
            raise AttributeError(
                f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
            )

        # Get the op `my_namespace::my_op` if available. This will also check
        # for overloads and raise an exception if there are more than one.
        namespace_name = self.name
        qualified_op_name = f"{namespace_name}::{op_name}"
        module_name = self.__module__ + "." + namespace_name

        try:
            op, overload_names = _get_packet(qualified_op_name, module_name)
            if op is None:
                raise AttributeError(
                    f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
                )
        except RuntimeError as e:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__)
            raise AttributeError(
                f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
            ) from e

        op.__module__ = module_name
        opoverloadpacket = OpOverloadPacket(
            qualified_op_name, op_name, op, overload_names
        )
        opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
        # cache the opoverloadpacket to ensure that each op corresponds to
        # a unique OpOverloadPacket object
        setattr(self, op_name, opoverloadpacket)
        self._dir.append(op_name)
        return opoverloadpacket


def _get_packet(qualname, op_module):
    op, overload_names = torch._C._jit_get_operation(qualname)
    if op is not None:
        # let the script frontend know that op is identical to the builtin op
        # with qualified_op_name
        torch.jit._builtins._register_builtin(op, qualname)
        op.__module__ = op_module
    return op, overload_names


def _refresh_packet(packet):
    op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__)
    assert op is not None
    packet._op = op
    packet._overload_names = overload_names
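
# For example (illustrative): a caller that registers a new overload on an
# already-bound packet can call _refresh_packet(torch.ops.aten.add) to rebind
# the packet's underlying op and refresh its overload name list.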


class _PyOpNamespace(_OpNamespace):
    def __init__(self, name, ops):
        super().__init__(name)
        self._ops = ops

    def __getattr__(self, name):
        # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
        op = self._ops.get(name, None)
        if op is None:
            raise AttributeError(
                f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
            )
        setattr(self, name, op)
        return op


class _Ops(types.ModuleType):
    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("torch.ops")
        self.loaded_libraries = set()
        self._higher_order_op_namespace = _PyOpNamespace(
            "torch.ops.higher_order", _higher_order_ops
        )
        self._dir = []

    def __getattr__(self, name):
        # Check if the name is a HigherOrderOperator
        if name == "higher_order":
            return self._higher_order_op_namespace

        # Here we are creating `torch.ops.my_namespace`
        namespace = _OpNamespace(name)
        setattr(self, name, namespace)
        self._dir.append(name)
        return namespace

    def __iter__(self):
        return iter(self._dir)

    def import_module(self, module):
        """
        Imports a Python module that has torch.library registrations.

        Generally, to extend PyTorch with custom operators, a user will
        create a Python module whose import triggers registration of
        the custom operators via a torch.ops.load_library call or a call
        to one or more torch.library.* APIs.

        It is unexpected for Python modules to have side effects, so some
        linters and formatters will complain. Use this API to import Python
        modules that contain these torch.library side effects.

        Args:
            module (str): The name of the Python module to import

        """
        importlib.import_module(module)

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom operators with the PyTorch JIT runtime. This allows dynamically
        loading custom operators. For this, you should compile your operator
        and the static registration code into a shared library object, and then
        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Args:
            path (str): A path to a shared library to load.
        """
        if torch._running_with_deploy():
            return

        path = _utils_internal.resolve_library_path(path)
        with dl_open_guard():
            # Import the shared library into the process, thus running its
            # static (global) initialization code in order to register custom
            # operators with the JIT.
            ctypes.CDLL(path)
        self.loaded_libraries.add(path)


# The ops "namespace"
ops: _Ops = _Ops()
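
# Example usage (illustrative paths and names):
#
#   torch.ops.load_library("path/to/libcustom_ops.so")  # register C++ ops
#   torch.ops.import_module("my_pkg.custom_ops")        # torch.library registrations
#   out = torch.ops.my_namespace.my_op(torch.randn(3))  # lazy packet binding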