# mypy: allow-untyped-defs
import contextlib
import functools
import inspect
import re
import sys
import traceback
import weakref
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import deprecated

import torch
import torch._library as _library
from torch._library.custom_ops import (
    _maybe_get_opdef,
    custom_op,
    CustomOpDef,
    device_types_t,
)
from torch._library.infer_schema import infer_schema  # noqa: F401
from torch._ops import OpOverload


__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "register_fake",
    "register_torch_dispatch",
    "register_vmap",
    "get_ctx",
    "custom_op",
    "infer_schema",
]

# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered.
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality;
# otherwise calls could be routed into kernels that were never intended to be called.
_impls: Set[str] = set()
_defs: Set[str] = set()

# The "prim" namespace is reserved by the TorchScript interpreter.
_reserved_namespaces = ["prim"]


def fallthrough_kernel():
    """
    A dummy function to pass to ``Library.impl`` in order to register a fallthrough.
    """
    raise NotImplementedError("fallthrough_kernel() should never be called.")
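
# For example, a per-operator fallthrough makes one dispatch key defer to the
# next kernel in the dispatch order. A minimal sketch (assuming "mylib::my_op"
# is an operator you have already defined; the names are placeholders):
#
#   my_lib = Library("mylib", "IMPL")
#   my_lib.impl("my_op", fallthrough_kernel, "AutocastCPU")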


class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch keyname if they only want to register
    kernels corresponding to one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".
    To create a fragment of a possibly existing library to register operators (and bypass
    the limitation that there is only one library for a given namespace), set the kind to
    "FRAGMENT".

    Args:
        ns: library name
        kind: "DEF", "IMPL", or "FRAGMENT"
        dispatch_key: PyTorch dispatch key (default: "")
    """

    def __init__(self, ns, kind, dispatch_key=""):
        if kind not in ("IMPL", "DEF", "FRAGMENT"):
            raise ValueError(f"Unsupported kind: {kind}")

        if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
            raise ValueError(
                f"{ns} is a reserved namespace. Please try creating a library with another name."
            )

        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m: Optional[Any] = torch._C._dispatch_library(
            kind, ns, dispatch_key, filename, lineno
        )
        self.ns = ns
        self._op_defs: Set[str] = set()
        self._op_impls: Set[str] = set()
        self._registration_handles: List[torch._library.utils.RegistrationHandle] = []
        self.kind = kind
        self.dispatch_key = dispatch_key
        # Use a finalizer to set up the "destructor" instead of __del__.
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). Finalizers help the
        # situation because they let us capture references and keep them alive.
        weakref.finalize(
            self,
            _del_library,
            _impls,
            self._op_impls,
            _defs,
            self._op_defs,
            self._registration_handles,
        )

    def __repr__(self):
        return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})"

    def define(self, schema, alias_analysis="", *, tags=()):
        r"""Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                inferred from the schema (default behavior) or not ("CONSERVATIVE").
            tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
                operator. Tagging an operator changes the operator's behavior
                under various PyTorch subsystems; please read the docs for the
                torch.Tag carefully before applying it.

        Returns:
            name of the operator as inferred from the schema.

        Example::
            >>> my_lib = Library("mylib", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        """
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
        assert self.m is not None
        if isinstance(tags, torch.Tag):
            tags = (tags,)

        name = schema.split("(")[0]
        packet_name = name.split(".")[0] if "." in name else name
        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
            getattr(torch.ops, self.ns), packet_name
        )

        result = self.m.define(schema, alias_analysis, tuple(tags))
        qualname = self.ns + "::" + name

        # If the OpOverloadPacket exists already, then this means we're adding a
        # new OpOverload for it. Refresh the packet to include the new OpOverload.
        if has_preexisting_packet:
            ns = getattr(torch.ops, self.ns)
            packet = getattr(ns, packet_name)
            torch._ops._refresh_packet(packet)

        self._op_defs.add(qualname)
        _defs.add(qualname)
        return result
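
    # For example, given the "sum" definition from the docstring above, adding
    # a second overload to the same packet (a sketch; the ".dim" overload is
    # hypothetical):
    #
    #   my_lib.define("sum.dim(Tensor self, int dim) -> Tensor")
    #
    # goes through the packet refresh above, so torch.ops.mylib.sum.dim then
    # resolves to the newly added overload.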

    def _register_fake(self, op_name, fn, _stacklevel=1):
        r"""Registers the fake impl for an operator defined in the library."""
        source = torch._library.utils.get_source(_stacklevel + 1)
        frame = sys._getframe(_stacklevel)
        caller_module = inspect.getmodule(frame)
        # Can be None if you call register_fake from somewhere without a module
        # (e.g. __main__)
        caller_module_name = None if caller_module is None else caller_module.__name__

        # TODO(rzou): We're gonna need to stage this change with torchvision,
        # since torchvision is github first.
        if caller_module_name is not None and caller_module_name.startswith(
            "torchvision."
        ):
            caller_module_name = None

        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        if caller_module_name is not None:
            func_to_register = _check_pystubs_once(fn, qualname, caller_module_name)
        else:
            func_to_register = fn

        handle = entry.fake_impl.register(func_to_register, source)
        self._registration_handles.append(handle)

    def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn):
        r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class.

        This allows for open registration to specify the behavior between the operator
        and the torch_dispatch_class without needing to modify the torch_dispatch_class
        or the operator directly.

        The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a
        TorchDispatchMode.

        If it is a Tensor subclass, we expect fn to have the following signature:
        (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

        If it is a TorchDispatchMode, we expect fn to have the following signature:
        (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
        """
        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn)
        self._registration_handles.append(handle)

    def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
        r"""Register the operator to use the AOTI-compiled implementation.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                the dispatch key that the library was created with.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key
        assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
                "as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        assert self.m is not None
        impl_fn: Callable = self.m.impl_with_aoti_compile
        impl_fn(self.ns, name.split("::")[-1], dispatch_key)

        _impls.add(key)
        self._op_impls.add(key)

    def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
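
        When ``with_keyset=True``, the kernel receives the current dispatcher
        keyset as its first positional argument, ahead of the operator's own
        arguments. A minimal sketch, where ``my_kernel`` and
        ``redispatch_mask`` are placeholders::

            def my_kernel(keyset, *args, **kwargs):
                # The keyset is typically forwarded to a redispatch, e.g.
                # op.redispatch(keyset & redispatch_mask, *args, **kwargs)
                ...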

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        """
        if not callable(fn):
            raise TypeError(
                f"Input function is required to be a callable but found type {type(fn)}"
            )
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "impl should be passed either a name or an OpOverload object as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if "::" not in dispatcher_op_name:
                dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(
                dispatcher_op_name, "CompositeImplicitAutograd"
            ):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into."
                )

        assert self.m is not None
        self.m.impl(
            name,
            dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
            fn,
            with_keyset,
        )

        _impls.add(key)
        self._op_impls.add(key)

    def fallback(self, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation as the fallback for the given key.

        This function only works for a library with global namespace ("_").

        Args:
            fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Example::
            >>> my_lib = Library("_", "IMPL")
            >>> def fallback_kernel(op, *args, **kwargs):
            >>>     # Handle all autocast ops generically
            >>>     # ...
            >>> my_lib.fallback(fallback_kernel, "Autocast")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if self.ns != "_":
            raise RuntimeError(
                f"""Fallback can only be registered using a library fragment on the global namespace "_" but it is {self.ns}"""
            )

        assert dispatch_key != ""
        assert self.m is not None

        self.m.fallback(dispatch_key, fn, with_keyset)

    def _destroy(self):
        if self.m is not None:
            self.m.reset()
        self.m = None
        for handle in self._registration_handles:
            handle.destroy()
        self._registration_handles.clear()
        global _impls
        _impls -= self._op_impls
        for name in self._op_defs:
            # Delete the cached torch.ops.ns.foo if it was registered.
            # Otherwise, accessing it leads to a segfault.
            # It's possible that we only registered an overload in this Library
            # and another library owns an alive overload.
            # That's OK - the next time torch.ops.ns.foo gets called, it'll be
            # recomputed to point at the right collection of overloads.
            ns, name_with_overload = name.split("::")
            name = name_with_overload.split(".")[0]
            if not hasattr(torch.ops, ns):
                continue
            namespace = getattr(torch.ops, ns)
            if not hasattr(namespace, name):
                continue
            delattr(namespace, name)


def _del_library(
    captured_impls,
    op_impls,
    captured_defs,
    op_defs,
    registration_handles,
):
    captured_impls -= op_impls
    captured_defs -= op_defs
    for handle in registration_handles:
        handle.destroy()


@contextlib.contextmanager
def _scoped_library(*args, **kwargs):
    # Construct the Library before entering the try block; otherwise a failed
    # constructor would leave `lib` unbound when the finally clause runs.
    lib = Library(*args, **kwargs)
    try:
        yield lib
    finally:
        lib._destroy()
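
# A minimal usage sketch for ``_scoped_library`` (the namespace and schema are
# placeholders): registrations only live for the duration of the ``with``
# block, which is how tests typically use it.
#
#   with _scoped_library("mylib", "FRAGMENT") as lib:
#       lib.define("foo(Tensor x) -> Tensor")
#       lib.impl("foo", lambda x: x.clone(), "CompositeExplicitAutograd")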


_keep_alive: List[Library] = []


NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")


@functools.singledispatch
def define(qualname, schema, *, lib=None, tags=()):
    r"""Defines a new operator.

    In PyTorch, defining an op (short for "operator") is a two-step process:
    - we need to define the op (by providing an operator name and schema)
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the custom operator (the first step);
    you must then perform the second step by calling various
    ``impl_*`` APIs, like :func:`torch.library.impl` or
    :func:`torch.library.register_fake`.

    Args:
        qualname (str): The qualified name for the operator. Should be
            a string that looks like "namespace::name", e.g. "aten::sin".
            Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module.
        schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor"
            for an op that accepts one Tensor and returns one Tensor. It does
            not contain the operator name (that is passed in ``qualname``).
        lib (Optional[Library]): If provided, the lifetime of this operator
            will be tied to the lifetime of the Library object.
        tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
            operator. Tagging an operator changes the operator's behavior
            under various PyTorch subsystems; please read the docs for the
            torch.Tag carefully before applying it.

    Example::
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the operator
        >>> @torch.library.impl("mylib::sin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Call the new operator from torch.ops.
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.sin(x)
        >>> assert torch.allclose(y, x.sin())

    """
    if not isinstance(qualname, str):
        raise ValueError(
            f"define(qualname, schema): expected qualname "
            f"to be instance of str, got {type(qualname)}"
        )
    namespace, name = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)
    if not NAMELESS_SCHEMA.fullmatch(schema):
        raise ValueError(
            f"define(qualname, schema, ...): expected schema "
            f'to look like e.g. "(Tensor x) -> Tensor" but '
            f'got "{schema}"'
        )
    lib.define(name + schema, alias_analysis="", tags=tags)


@define.register
def _(lib: Library, schema, alias_analysis=""):
    """The old torch.library.define.
    We're keeping this around for BC reasons.
    """

    def wrap(f):
        name = lib.define(schema, alias_analysis)
        lib.impl(name, f)
        return f

    return wrap


@functools.singledispatch
def impl(qualname, types, func=None, *, lib=None):
    """Register an implementation for a device type for this operator.

    You may pass "default" for ``types`` to register this implementation as the
    default implementation for ALL device types.
    Please only use this if the implementation truly supports all device types;
    for example, this is true if it is a composition of built-in PyTorch operators.

    Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".

    Args:
        qualname (str): Should be a string that looks like "namespace::operator_name".
        types (str | Sequence[str]): The device types to register an impl to.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the cpu device
        >>> @torch.library.impl("mylib::mysin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.mysin(x)
        >>> assert torch.allclose(y, x.sin())
    """
    return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
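
# For example, an implementation that is just a composition of built-in
# PyTorch ops can be registered for all device types at once (a sketch;
# "mylib::twice" is a hypothetical operator assumed to be already defined):
#
#   @torch.library.impl("mylib::twice", "default")
#   def _(x):
#       return x + x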


def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
    if isinstance(types, str):
        types = (types,)
    keys = set()
    for typ in types:
        is_dispatch_key = torch._C._parse_dispatch_key(typ)
        if is_dispatch_key:
            # We also support passing a DispatchKey to impl. Please prefer using
            # the higher-level torch.library APIs and only pass DispatchKey to
            # torch.library.impl with caution (or even better, don't use this
            # option and file an issue on GitHub for what you need).
            # We don't advertise this to users because
            # it is very easy to shoot yourself in the foot.
            keys.add(typ)
        else:
            keys.add(_device_type_to_key(typ))

    def register(func):
        namespace, _ = torch._library.utils.parse_namespace(qualname)

        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        if disable_dynamo:

            @torch._disable_dynamo
            def func_no_dynamo(*args, **kwargs):
                return func(*args, **kwargs)

            for key in keys:
                use_lib.impl(qualname, func_no_dynamo, key)
        else:
            for key in keys:
                use_lib.impl(qualname, func, key)

    if func is None:
        return register
    else:
        register(func)


def _device_type_to_key(device_type: str) -> str:
    if device_type == "default":
        # This is technically not correct, because although all device_type
        # DispatchKeys are included in CompositeExplicitAutograd,
        # not everything in CompositeExplicitAutograd is associated with a
        # device_type. I don't really care that much about the difference.
        return "CompositeExplicitAutograd"
    return torch._C._dispatch_key_for_device(device_type)


@impl.register
def _(lib: Library, name, dispatch_key=""):
    """Legacy torch.library.impl API. Kept around for BC"""

    def wrap(f):
        lib.impl(name, f, dispatch_key)
        return f

    return wrap


@deprecated(
    "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that "
    "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.",
    category=FutureWarning,
)
def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
    r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4.
    Please use that instead.
    """
    if func is not None:
        _stacklevel = _stacklevel + 1
    return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)


_op_identifier = Union[
    str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
]


def register_kernel(
    op: _op_identifier,
    device_types: device_types_t,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
):
    """Register an implementation for a device type for this operator.

    Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
    This API may be used as a decorator.

    Args:
        op (str | OpOverload): The operator to register an impl to.
        device_types (None | str | Sequence[str]): The device_types to register an impl to.
            If None, we will register to all device types -- please only use
            this option if your implementation is truly device-type-agnostic.
        func (Callable): The function to register as the implementation for
            the given device types.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> # Create a custom op that works on cpu
        >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> # Add implementations for the cuda device
        >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
        >>> def _(x):
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x_cpu = torch.randn(3)
        >>> x_cuda = x_cpu.cuda()
        >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
        >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())

    """

    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_kernel(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_kernel(device_types, func)
    assert isinstance(op, str)
    if device_types is None:
        device_types = "CompositeExplicitAutograd"

    return _impl(op, device_types, func, lib=lib, disable_dynamo=True)


def register_fake(
    op: _op_identifier,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
    _stacklevel: int = 1,
):
    r"""Register a FakeTensor implementation ("fake impl") for this operator.

    Also sometimes known as a "meta kernel" or "abstract impl".

    A "FakeTensor implementation" specifies the behavior of this operator on
    Tensors that carry no data ("FakeTensor"). Given some input Tensors with
    certain properties (sizes/strides/storage_offset/device), it specifies
    what the properties of the output Tensors are.

    The FakeTensor implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write a FakeTensor
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The FakeTensor implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError("Implementation goes here")
        >>>
        >>> @torch.library.register_fake("mylib::custom_linear")
        >>> def _(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> with torch._subclasses.fake_tensor.FakeTensorMode():
        >>>     x = torch.randn(2, 3)
        >>>     w = torch.randn(3, 3)
        >>>     b = torch.randn(3)
        >>>     y = torch.ops.mylib.custom_linear(x, w, b)
        >>>
        >>> assert y.shape == (2, 3)
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy(force=True)
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     return torch.tensor(res, device=x.device)
        >>>
        >>> @torch.library.register_fake("mylib::custom_nonzero")
        >>> def _(x):
        >>>     # Number of nonzero-elements is data-dependent.
        >>>     # Since we cannot peek at the data in a fake impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch.library.get_ctx()
        >>>     nnz = ctx.new_dynamic_size()
        >>>     shape = [nnz, x.dim()]
        >>>     result = x.new_empty(shape, dtype=torch.int64)
        >>>     return result
        >>>
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>>
        >>> x = torch.tensor([0, 1, 2, 3, 4, 0])
        >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
        >>> trace.print_readable()
        >>>
        >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_fake(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        if func is None:
            return opdef.register_fake
        else:
            return opdef.register_fake(func)
    assert isinstance(op, str)

    stacklevel = _stacklevel

    def register(func):
        namespace, op_name = torch._library.utils.parse_namespace(op)
        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1)
        return func

    if func is None:
        return register
    else:
        stacklevel += 1
        return register(func)


def register_autograd(
    op: _op_identifier,
    backward: Callable,
    /,
    *,
    setup_context: Optional[Callable] = None,
    lib=None,
) -> None:
    r"""Register a backward formula for this custom op.

    In order for an operator to work with autograd, you need to register
    a backward formula:
    1. You must tell us how to compute gradients during the backward pass
       by providing us a "backward" function.
    2. If you need any values from the forward to compute gradients, you can
       use ``setup_context`` to save values for backward.

    ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``:
    - ``grads`` is one or more gradients. The number of gradients matches
      the number of outputs of the operator.
    The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
    :class:`torch.autograd.Function`. The semantics of ``backward`` are the
    same as :meth:`torch.autograd.Function.backward`.

    ``setup_context(ctx, inputs, output)`` runs during the forward pass.
    Please save quantities needed for backward onto the ``ctx`` object via
    either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
    or assigning them as attributes of ``ctx``. If your custom op has
    kwarg-only arguments, we expect the signature of ``setup_context``
    to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.

    Both ``setup_context`` and ``backward`` must be traceable. That is,
    they may not directly access :meth:`torch.Tensor.data_ptr` and they must
    not depend on or mutate global state. If you need a non-traceable backward,
    you can make it a separate custom_op that you call inside ``backward``.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> def setup_context(ctx, inputs, output) -> Tensor:
        >>>     x, = inputs
        >>>     ctx.save_for_backward(x)
        >>>
        >>> def backward(ctx, grad):
        >>>     x, = ctx.saved_tensors
        >>>     return grad * x.cos()
        >>>
        >>> torch.library.register_autograd(
        ...     "mylib::numpy_sin", backward, setup_context=setup_context
        ... )
        >>>
        >>> x = torch.randn(3, requires_grad=True)
        >>> y = numpy_sin(x)
        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
        >>> assert torch.allclose(grad_x, x.cos())
        >>>
        >>> # Example with a keyword-only arg
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = x_np * val
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
        >>>     ctx.val = keyword_only_inputs["val"]
        >>>
        >>> def backward(ctx, grad):
        >>>     return grad * ctx.val
        >>>
        >>> torch.library.register_autograd(
        ...     "mylib::numpy_mul", backward, setup_context=setup_context
        ... )
        >>>
        >>> x = torch.randn(3, requires_grad=True)
        >>> y = numpy_mul(x, val=3.14)
        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
        >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_autograd(op): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        opdef.register_autograd(backward, setup_context=setup_context)
        return

    assert isinstance(op, str)
    qualname = op
    op = torch._library.utils.lookup_op(qualname)
    schema = op._schema
    if not _library.utils.is_functional_schema(schema):
        raise RuntimeError(
            f"Cannot register autograd formula for non-functional operator "
            f"{op} with schema {schema}. Please create "
            f"a functional operator and register an autograd formula for that."
        )
    if _library.utils.has_kwarg_only_tensors(schema):
        raise NotImplementedError(
            f"register_autograd with kwarg-only Tensor args. In the original "
            f"definition of the op, please make your tensors not kwarg-only. "
            f"Got: {schema}"
        )

    info = _library.autograd.Info(backward, setup_context)
    autograd_kernel = _library.autograd.make_autograd_impl(op, info)
    namespace, opname = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)
    lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True)


def register_torch_dispatch(
    op: _op_identifier,
    torch_dispatch_class: Any,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
):
    r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.

    This allows for open registration to specify the behavior between the operator
    and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
    or the operator directly.

    The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a
    TorchDispatchMode.

    If it is a Tensor subclass, we expect ``func`` to have the following signature:
    ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``

    If it is a TorchDispatchMode, we expect ``func`` to have the following signature:
    ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``

    ``args`` and ``kwargs`` will have been normalized the same way they are
    in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`).

    Examples:

        >>> import torch
        >>>
        >>> @torch.library.custom_op("mylib::foo", mutates_args={})
        >>> def foo(x: torch.Tensor) -> torch.Tensor:
        >>>     return x.clone()
        >>>
        >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
        >>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        >>>         return func(*args, **kwargs)
        >>>
        >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
        >>> def _(mode, func, types, args, kwargs):
        >>>     x, = args
        >>>     return x + 1
        >>>
        >>> x = torch.randn(3)
        >>> y = foo(x)
        >>> assert torch.allclose(y, x)
        >>>
        >>> with MyMode():
        >>>     y = foo(x)
        >>> assert torch.allclose(y, x + 1)

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_torch_dispatch(op): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_torch_dispatch(torch_dispatch_class, func)
    assert isinstance(op, str)

    def register(func):
        namespace, op_name = torch._library.utils.parse_namespace(op)
        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func)
        return func

    if func is None:
        return register
    else:
        return register(func)


def register_vmap(
    op: _op_identifier,
    func: Optional[Callable] = None,
    /,
    *,
    lib=None,
):
    r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.

    This API may be used as a decorator (see examples).

    In order for an operator to work with :func:`torch.vmap`, you may need to register a
    vmap implementation with the following signature:

        ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,

    where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
    We do not support kwarg-only Tensor args.

    It specifies how we compute the batched version of ``op`` given inputs with an additional
    dimension (specified by ``in_dims``).

    For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
    if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer
    specifying what dimension of the Tensor is being vmapped over.

    ``info`` is a collection of additional metadata that may be helpful:
    ``info.batch_size`` specifies the size of the dimension being vmapped over, while
    ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.

    The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
    ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
    per output that specifies if the output has the vmapped dimension and what index it is in.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>> from typing import Tuple
        >>>
        >>> def to_numpy(tensor):
        >>>     return tensor.cpu().numpy()
        >>>
        >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
        >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
        >>>     x_np = to_numpy(x)
        >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
        >>>     return torch.tensor(x_np ** 3, device=x.device), dx
        >>>
        >>> def numpy_cube_vmap(info, in_dims, x):
        >>>     result = numpy_cube(x)
        >>>     return result, (in_dims[0], in_dims[0])
        >>>
        >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
        >>>
        >>> x = torch.randn(3)
        >>> torch.vmap(numpy_cube)(x)
        >>>
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
        >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
        >>>
        >>> @torch.library.register_vmap("mylib::numpy_mul")
        >>> def numpy_mul_vmap(info, in_dims, x, y):
        >>>     x_bdim, y_bdim = in_dims
        >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        >>>     result = x * y
        >>>     result = result.movedim(-1, 0)
        >>>     return result, 0
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.randn(3)
        >>> torch.vmap(numpy_mul)(x, y)

    .. note::
        The vmap function should aim to preserve the semantics of the entire custom operator.
        That is, ``grad(vmap(op))`` should be replaceable with ``grad(map(op))``.

        If your custom operator has any custom behavior in the backward pass, please
        keep this in mind.

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_vmap(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_vmap(func)
    assert isinstance(op, str)
    qualname = op
    op = torch._library.utils.lookup_op(qualname)
    schema = op._schema
    if _library.utils.has_kwarg_only_tensors(schema):
        raise NotImplementedError(
            f"register_vmap with kwarg-only Tensor args. In the original "
            f"definition of the op, please make your tensors not kwarg-only. "
            f"Got: {schema}"
        )

    def register(func):
        nonlocal op, lib

        namespace, opname = torch._library.utils.parse_namespace(qualname)
        if lib is None:
            lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(lib)

        from torch._functorch.autograd_function import custom_function_call_vmap_helper
        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter

        def wrapped_func(keyset, *args, **kwargs):
            interpreter = retrieve_current_functorch_interpreter()
            return custom_function_call_vmap_helper(
                interpreter, func, op, *args, **kwargs
            )

        lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True)

    if func is None:
        return register
    else:
        return register(func)


# If the op was defined in C++, then we want to make sure there was an
# m.set_python_module(module, ...) call and that the module is the
# same as the module that called torch.library.register_fake.
def _check_pystubs_once(func, qualname, actual_module_name):
    checked = False

    def inner(*args, **kwargs):
        nonlocal checked
        if checked:
            return func(*args, **kwargs)

        op = torch._library.utils.lookup_op(qualname)
        if op._defined_in_python:
            checked = True
            return func(*args, **kwargs)

        maybe_pystub = torch._C._dispatch_pystub(
            op._schema.name, op._schema.overload_name
        )
        if maybe_pystub is None:
            if torch._library.utils.requires_set_python_module():
                namespace = op.namespace
                cpp_filename = op._handle.debug()
                raise RuntimeError(
                    f"Operator '{qualname}' was defined in C++ and has a Python "
                    f"fake impl. In this situation, we require there to also be a "
                    f'companion C++ `m.set_python_module("{actual_module_name}")` '
                    f"call, but we could not find one. Please add that to "
                    f"the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
                    f"operator was registered in ({cpp_filename})"
                )
        else:
            pystub_module = maybe_pystub[0]
            if actual_module_name != pystub_module:
                cpp_filename = op._handle.debug()
                raise RuntimeError(
                    f"Operator '{qualname}' specified that its python fake impl "
                    f"is in the Python module '{pystub_module}' but it was actually found "
                    f"in '{actual_module_name}'. Please either move the fake impl "
                    f"or correct the m.set_python_module call ({cpp_filename})"
                )
        checked = True
        return func(*args, **kwargs)

    return inner


# NOTE [ctx inside the fake implementation]
# If a user has an operator with data-dependent output shape, then when writing
# a fake implementation they must query the current ctx and use methods on the
# ctx to construct a new unbacked symint.
#
# This is done via us setting the global_ctx_getter function every time a fake
# implementation is invoked.
def get_ctx() -> "torch._library.fake_impl.FakeImplCtx":
    """get_ctx() returns the current FakeImplCtx object.

    Calling ``get_ctx()`` is only valid inside of a fake impl
    (see :func:`torch.library.register_fake` for more usage details).
    """
    return torch._library.fake_impl.global_ctx_getter()
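
# For example, a fake impl with a data-dependent output size asks the ctx for
# a new unbacked symint (a sketch; "mylib::custom_nonzero" is the hypothetical
# op from the register_fake docstring above):
#
#   @torch.library.register_fake("mylib::custom_nonzero")
#   def _(x):
#       ctx = torch.library.get_ctx()
#       nnz = ctx.new_dynamic_size()
#       return x.new_empty([nnz, x.dim()], dtype=torch.int64)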


_OPCHECK_DEFAULT_UTILS = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


def opcheck(
    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
    raise_exception: bool = True,
) -> Dict[str, str]:
    """Given an operator and some sample arguments, tests if the operator is
    registered correctly.

    That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
    custom op, you specify metadata (e.g. mutability info) about the custom op,
    and these APIs require that the functions you pass them satisfy certain
    properties (e.g. no data pointer access in the fake/meta/abstract kernel).
    ``opcheck`` tests these metadata and properties.

    Concretely, we test the following:

    - test_schema: If the schema matches the implementation of
      the operator. For example: if the schema specifies a Tensor is mutated,
      then we check the implementation mutates the Tensor. If the schema
      specifies that we return a new Tensor, then we check that the
      implementation returns a new Tensor (instead of an existing one or
      a view of an existing one).
    - test_autograd_registration: If the operator supports training
      (autograd): we check that its autograd formula is registered via
      torch.library.register_autograd or a manual registration to one
      or more DispatchKey::Autograd keys. Any other DispatchKey-based
      registrations may lead to undefined behavior.
    - test_faketensor: If the operator has a FakeTensor kernel
      (and if it is correct). The FakeTensor kernel is necessary (
      but not sufficient) for the operator to work with PyTorch compilation
      APIs (torch.compile/export/FX). We check that a FakeTensor kernel
      (also sometimes known as a meta kernel) was registered for the
      operator and that it is correct. This test takes the result of
      running the operator on real tensors and the result of running
      the operator on FakeTensors and checks that they have the same
      Tensor metadata (sizes/strides/dtype/device/etc).
    - test_aot_dispatch_dynamic: If the operator has correct behavior
      with PyTorch compilation APIs (torch.compile/export/FX).
      This checks that the outputs (and gradients, if applicable) are the
      same under eager-mode PyTorch and torch.compile.
      This test is a superset of ``test_faketensor`` and is an e2e test;
      other things it tests are that the operator supports
      functionalization and that the backward pass (if it exists) also
      supports FakeTensor and functionalization.

    For best results, please call ``opcheck`` multiple times with a
    representative set of inputs. If your operator supports
    autograd, please use ``opcheck`` with inputs with ``requires_grad = True``;
    if your operator supports multiple devices (e.g. CPU and CUDA), please
    use ``opcheck`` with inputs on all supported devices.

    Args:
        op: The operator. Must either be a function decorated with
            :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket
            found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
        args: The args to the operator
        kwargs: The kwargs to the operator
        test_utils: Tests that we should run. Default: all of them.
            Example: ("test_schema", "test_faketensor")
        raise_exception: If we should raise an exception on the first
            error. If False, we will return a dict with information
            on if each test passed or not.

    .. warning::

        opcheck and :func:`torch.autograd.gradcheck` test different things;
        opcheck tests if your usage of torch.library APIs is correct while
        :func:`torch.autograd.gradcheck` tests if your autograd formula is
        mathematically correct. Use both to test custom ops that support
        gradient computation.
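
    For example, to collect all results instead of raising on the first
    failure, pass ``raise_exception=False``; the returned dict then maps each
    test in ``test_utils`` to a status string::

        results = torch.library.opcheck(op, args, raise_exception=False)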

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> from torch import Tensor
        >>>
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, y: float) -> Tensor:
        >>>     x_np = x.numpy(force=True)
        >>>     z_np = x_np * y
        >>>     return torch.from_numpy(z_np).to(x.device)
        >>>
        >>> @numpy_mul.register_fake
        >>> def _(x, y):
        >>>     return torch.empty_like(x)
        >>>
        >>> def setup_context(ctx, inputs, output):
        >>>     x, y = inputs
        >>>     ctx.y = y
        >>>
        >>> def backward(ctx, grad):
        >>>     return grad * ctx.y, None
        >>>
        >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
        >>>
        >>> sample_inputs = [
        >>>     (torch.randn(3), 3.14),
        >>>     (torch.randn(2, 3, device='cuda'), 2.718),
        >>>     (torch.randn(1, 10, requires_grad=True), 1.234),
        >>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
        >>> ]
        >>>
        >>> for args in sample_inputs:
        >>>     torch.library.opcheck(numpy_mul, args)

    """
    import torch.testing._internal.optests as optests

    return optests.opcheck(
        op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
    )