# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
import logging
import weakref
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch import _C, _ops, Tensor
from torch.utils._exposed_in import exposed_in

from . import autograd, utils


device_types_t = Optional[Union[str, Sequence[str]]]
log = logging.getLogger(__name__)


@exposed_in("torch.library")
def custom_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    device_types: device_types_t = None,
    schema: Optional[str] = None,
) -> Callable:
    """Wraps a function into a custom operator.

    Reasons why you may want to create a custom op include:
    - Wrapping a third-party library or custom kernel to work with PyTorch
      subsystems like Autograd.
    - Preventing torch.compile/export/FX tracing from peeking inside your function.

    This API is used as a decorator around a function (please see examples).
    The provided function must have type hints; these are needed to interface
    with PyTorch's various subsystems.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        device_types (None | str | Sequence[str]): The device type(s) the function
            is valid for. If no device type is provided, then the function
            is used as the default implementation for all device types.
            Examples: "cpu", "cuda".
            When registering a device-specific implementation for an operator that accepts no Tensors,
            we require the operator to have a "device: torch.device" argument.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".

    .. note::
        We recommend not passing in a ``schema`` arg and instead letting us infer
        it from the type annotations. It is error-prone to write your own schema.
        You may wish to provide your own schema if our interpretation of
        the type annotation is not what you want.
        For more info on how to write a schema string, see
        `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func>`_

    Examples::
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> @custom_op("mylib::numpy_sin", mutates_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that only works for one device type.
        >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
        >>> def numpy_sin_cpu(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin_cpu(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that mutates an input
        >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
        >>> def numpy_sin_inplace(x: Tensor) -> None:
        >>>     x_np = x.numpy()
        >>>     np.sin(x_np, out=x_np)
        >>>
        >>> x = torch.randn(3)
        >>> expected = x.sin()
        >>> numpy_sin_inplace(x)
        >>> assert torch.allclose(x, expected)
        >>>
        >>> # Example of a factory function
        >>> @torch.library.custom_op("mylib::bar", mutates_args=(), device_types="cpu")
        >>> def bar(device: torch.device) -> Tensor:
        >>>     return torch.ones(3)
        >>>
        >>> bar("cpu")

    """

    def inner(fn):
        import torch

        if schema is None:
            schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
        else:
            schema_str = schema

        namespace, opname = name.split("::")
        result = CustomOpDef(namespace, opname, schema_str, fn)
        if schema is not None:
            # Check that schema's alias annotations match those of `mutates_args`.
            expected = set()
            for arg in result._opoverload._schema.arguments:
                if arg.alias_info is not None and arg.alias_info.is_write:
                    expected.add(arg.name)
            if expected != set(mutates_args):
                raise ValueError(
                    f"Attempted to create a custom op with `mutates_args={mutates_args}` "
                    f"and `schema={schema}`. The schema suggests that the op mutates {expected}, "
                    f"which is different from what was provided to us in `mutates_args`. "
                    f"Please make these consistent."
                )
        result.register_kernel(device_types)(fn)
        return result

    if fn is None:
        return inner
    return inner(fn)


class CustomOpDef:
    """CustomOpDef is a wrapper around a function that turns it into a custom op.

    It has various methods for registering additional behavior for this
    custom op.

    You should not instantiate CustomOpDef directly; instead, use the
    :func:`torch.library.custom_op` API.
    """

    def __init__(self, namespace: str, name: str, schema: str, fn: Callable) -> None:
        # Fields used to interface with the PyTorch dispatcher
        self._namespace = namespace
        self._name = name
        self._schema = schema

        self._init_fn = fn

        self._backend_fns: Dict[Union[str, None], Callable] = {}
        self._abstract_fn: Optional[Callable] = None
        self._setup_context_fn: Optional[Callable] = None
        self._backward_fn: Optional[Callable] = None
        self._torch_dispatch_fns: Dict[type, Callable] = {}
        self._vmap_fn: Optional[Callable] = None

        self._lib = get_library_allowing_overwrite(self._namespace, self._name)
        self._register_to_dispatcher()
        self._disabled_kernel: Set = set()
        OPDEFS[self._qualname] = self

    @property
    def _qualname(self) -> str:
        return f"{self._namespace}::{self._name}"

    def __repr__(self) -> str:
        return f"<CustomOpDef({self._qualname})>"

    @contextmanager
    def set_kernel_enabled(self, device_type: str, enabled: bool = True):
        """
        Disable or re-enable an already registered kernel for this custom operator.

        If the kernel is already disabled/enabled, this is a no-op.

        Note:
            If a kernel is first disabled and then registered, it is disabled until enabled again.

        Args:
            device_type (str): The device type to disable/enable the kernel for.
            enabled (bool): Whether to enable or disable the kernel.

        Example:
            >>> inp = torch.randn(1)
            >>>
            >>> # define custom op `f`.
            >>> @custom_op("mylib::f", mutates_args=())
            >>> def f(x: Tensor) -> Tensor:
            >>>     return torch.zeros(1)
            >>>
            >>> print(f(inp))  # tensor([0.]), default kernel
            >>>
            >>> @f.register_kernel("cpu")
            >>> def _(x):
            >>>     return torch.ones(1)
            >>>
            >>> print(f(inp))  # tensor([1.]), CPU kernel
            >>>
            >>> # temporarily disable the CPU kernel
            >>> with f.set_kernel_enabled("cpu", enabled=False):
            >>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled

        """
        action = "enable" if enabled else "disable"
        originally_disabled = device_type in self._disabled_kernel
        if device_type not in self._backend_fns:
            log.warning(
                "Attempted to %s kernel for %s but no kernel was registered for this device type.",
                action,
                device_type,
            )

        if not enabled:
            if originally_disabled:
                log.warning(
                    "Attempted to disable kernel for %s but it was already disabled.",
                    device_type,
                )
            else:
                self._disabled_kernel.add(device_type)
        else:  # enable the kernel
            if not originally_disabled:
                log.warning(
                    "Attempted to enable kernel for %s but it was already enabled.",
                    device_type,
                )
            else:
                self._disabled_kernel.remove(device_type)

        try:
            yield
        finally:
            # restore original state
            if originally_disabled:
                self._disabled_kernel.add(device_type)
            else:
                self._disabled_kernel.discard(device_type)

    def register_kernel(
        self, device_types: device_types_t, fn: Optional[Callable] = None, /
    ) -> Callable:
        """Register an implementation for a device type for this operator.

        Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
        This API may be used as a decorator.

        Args:
            fn (Callable): The function to register as the implementation for
                the given device types.
            device_types (str | Sequence[str]): The device type(s) to register an impl to.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import torch
            >>> from torch import Tensor
            >>> from torch.library import custom_op
            >>> import numpy as np
            >>>
            >>> # Create a custom op that works on cpu
            >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
            >>> def numpy_sin(x: Tensor) -> Tensor:
            >>>     x_np = x.numpy()
            >>>     y_np = np.sin(x_np)
            >>>     return torch.from_numpy(y_np)
            >>>
            >>> # Add implementations for the cuda device
            >>> @numpy_sin.register_kernel("cuda")
            >>> def _(x):
            >>>     x_np = x.cpu().numpy()
            >>>     y_np = np.sin(x_np)
            >>>     return torch.from_numpy(y_np).to(device=x.device)
            >>>
            >>> x_cpu = torch.randn(3)
            >>> x_cuda = x_cpu.cuda()
            >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
            >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())

        """

        def inner(fn):
            if device_types is None or isinstance(device_types, str):
                dtypes: List[Union[str, None]] = [device_types]
            else:
                dtypes = list(device_types)
            for device_type in dtypes:
                if device_type not in self._backend_fns:

                    def backend_impl(*args, **kwargs):
                        # Checks the assumption that outputs cannot alias
                        # inputs or other outputs.
                        storages = {
                            id(tensor.untyped_storage())
                            for tensor in iter_tensors(args, kwargs)
                        }

                        result = self._backend_fns[device_type](*args, **kwargs)

                        tuple_result = result
                        if not isinstance(result, tuple):
                            tuple_result = (result,)
                        for tensor in iter_tensors(tuple_result, {}):
                            key = id(tensor.untyped_storage())
                            if id(tensor.untyped_storage()) in storages:
                                fn = self._backend_fns[device_type]
                                module = inspect.getmodule(fn)
                                raise RuntimeError(
                                    f"{self._name} (with implementation in {module}): "
                                    f"The output of this custom operator (1) must not "
                                    f"also be an input to this custom operator and "
                                    f"(2) may not alias any inputs to this custom operator "
                                    f"or other returns. "
                                    f"The most common way to trigger this error is if "
                                    f"we have y = custom_op(x) and y and x are the same Tensor. "
                                    f"Please instead return a clone of the offending output "
                                    f"tensor(s) (e.g. return x.clone()) or refactor the custom "
                                    f"operator to not return y."
                                )
                            storages.add(key)
                        return result

                    if device_type is None:
                        self._lib.impl(
                            self._name, backend_impl, "CompositeExplicitAutograd"
                        )
                    else:
                        self._lib.impl(
                            self._name,
                            backend_impl,
                            _C._dispatch_key_for_device(device_type),
                        )

                # Wrap function to choose between the default implementation or the device-specific
                # implementation depending on whether the kernel is disabled.
                @torch._disable_dynamo
                def wrapped_fn(*args, **kwargs):
                    if device_type in self._disabled_kernel:
                        return self._init_fn(*args, **kwargs)
                    else:
                        return fn(*args, **kwargs)

                self._backend_fns[device_type] = wrapped_fn
            return fn

        if device_types is not None and not utils.has_tensor_arg(
            self._opoverload._schema
        ):
            device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
            if device_arg_index is None:
                raise ValueError(
                    "Functions without tensor inputs are required to have a `device: torch.device` argument"
                )
            self._register_backend_select_dispatcher(device_arg_index)

        # See NOTE: [Supporting decorator and non-decorator usage]
        if fn is None:
            return inner
        return inner(fn)

    def register_fake(self, fn: Callable, /) -> Callable:
        r"""Register a FakeTensor implementation for this custom op.

        This is necessary to get the operator to work efficiently with torch.compile.

        The Fake impl (sometimes also known as a meta kernel or abstract impl)
        specifies the behavior of this operator on Tensors that carry no data.
        Given some input Tensors with certain properties
        (sizes/strides/storage_offset/device), it specifies what the properties of
        the output Tensors are.

        Please see :func:`torch.library.impl_abstract` for more details.

        Args:
            fn (Callable): The function to register as the FakeTensor
                implementation.

        Examples:
            >>> import torch
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> # Example 1: an operator without data-dependent output shape
            >>> @torch.library.custom_op("mylib::linear", mutates_args=())
            >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            >>>     return (x @ weight.t()) + bias
            >>>
            >>> @linear.register_fake
            >>> def _(x, weight, bias):
            >>>     assert x.dim() == 2
            >>>     assert weight.dim() == 2
            >>>     assert bias.dim() == 1
            >>>     assert x.shape[1] == weight.shape[1]
            >>>     assert weight.shape[0] == bias.shape[0]
            >>>     assert x.device == weight.device
            >>>     return x.new_empty(x.size(0), weight.size(0))
            >>>
            >>> x = torch.randn(2, 2)
            >>> weight = torch.randn(2, 2)
            >>> bias = torch.randn(2)
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
            >>>
            >>> # Example 2: an operator with data-dependent output shape
            >>> @torch.library.custom_op("mylib::nonzero", mutates_args=())
            >>> def nonzero(x: Tensor) -> Tensor:
            >>>     x_np = x.cpu().numpy()
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     return torch.tensor(res, device=x.device)
            >>>
            >>> @nonzero.register_fake
            >>> def _(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch.library.get_ctx()
            >>>     nnz = ctx.new_dynamic_size()
            >>>     shape = [nnz, x.dim()]
            >>>     result = x.new_empty(shape, dtype=torch.int64)
            >>>     return result
            >>>
            >>> x = torch.tensor([0, 1, 2, 0, 0, 1])
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> out = torch.compile(nonzero, fullgraph=True)(x)
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> assert torch.allclose(out, x.nonzero())

        """
        self._abstract_fn = fn
        return fn

    def register_torch_dispatch(
        self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
    ) -> Callable:
        r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.

        This allows for open registration to specify the behavior between the operator
        and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
        or the operator directly.

        Please see :func:`torch.library.register_torch_dispatch` for examples and more details.
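
        Example (a minimal sketch; ``MyMode`` is a hypothetical
        ``TorchDispatchMode`` used only for illustration, and the rule
        signature shown follows :func:`torch.library.register_torch_dispatch`):
            >>> import torch
            >>> from torch.utils._python_dispatch import TorchDispatchMode
            >>>
            >>> @torch.library.custom_op("mylib::foo", mutates_args=())
            >>> def foo(x: torch.Tensor) -> torch.Tensor:
            >>>     return x.clone()
            >>>
            >>> class MyMode(TorchDispatchMode):
            >>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            >>>         return func(*args, **(kwargs or {}))
            >>>
            >>> @foo.register_torch_dispatch(MyMode)
            >>> def _(mode, func, types, args, kwargs):
            >>>     # Runs instead of the usual kernel when MyMode is active.
            >>>     (x,) = args
            >>>     return x + 1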
        """

        def register(fn):
            if torch_dispatch_class not in self._torch_dispatch_fns:

                def inner(*args, **kwargs):
                    return self._torch_dispatch_fns[torch_dispatch_class](
                        *args, **kwargs
                    )

                self._lib._register_torch_dispatch_rule(
                    self._name, torch_dispatch_class, inner
                )
            self._torch_dispatch_fns[torch_dispatch_class] = fn
            return fn

        if fn is None:
            return register
        else:
            return register(fn)

    def register_autograd(
        self,
        backward: Callable,
        /,
        *,
        setup_context: Optional[Callable] = None,
    ) -> None:
        r"""Register a backward formula for this custom op.

        In order for an operator to work with autograd, you need to register
        a backward formula:
        1. You must tell us how to compute gradients during the backward pass
           by providing us a "backward" function.
        2. If you need any values from the forward to compute gradients, you can
           use `setup_context` to save values for backward.

        ``backward_fn`` runs during the backward pass. It accepts ``(ctx, *grads)``:
        - ``grads`` is one or more gradients. The number of gradients matches
          the number of outputs of the operator.
        The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
        :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
        same as :meth:`torch.autograd.Function.backward`.

        ``setup_context(ctx, inputs, output)`` runs during the forward pass.
        Please save quantities needed for backward onto the ``ctx`` object via
        either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
        or assigning them as attributes of ``ctx``. If your custom op has
        kwarg-only arguments, we expect the signature of ``setup_context``
        to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.

        Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
        they may not directly access :meth:`torch.Tensor.data_ptr` and they must
        not depend on or mutate global state. If you need a non-traceable backward,
        you can make it a separate custom_op that you call inside ``backward_fn``.

        Examples:
            >>> import torch
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
            >>> def numpy_sin(x: Tensor) -> Tensor:
            >>>     x_np = x.cpu().numpy()
            >>>     y_np = np.sin(x_np)
            >>>     return torch.from_numpy(y_np).to(device=x.device)
            >>>
            >>> def setup_context(ctx, inputs, output) -> None:
            >>>     x, = inputs
            >>>     ctx.save_for_backward(x)
            >>>
            >>> def backward(ctx, grad):
            >>>     x, = ctx.saved_tensors
            >>>     return grad * x.cos()
            >>>
            >>> numpy_sin.register_autograd(backward, setup_context=setup_context)
            >>>
            >>> x = torch.randn(3, requires_grad=True)
            >>> y = numpy_sin(x)
            >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
            >>> assert torch.allclose(grad_x, x.cos())
            >>>
            >>> # Example with a keyword-only arg
            >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
            >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
            >>>     x_np = x.cpu().numpy()
            >>>     y_np = x_np * val
            >>>     return torch.from_numpy(y_np).to(device=x.device)
            >>>
            >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> None:
            >>>     ctx.val = keyword_only_inputs["val"]
            >>>
            >>> def backward(ctx, grad):
            >>>     return grad * ctx.val
            >>>
            >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
            >>>
            >>> x = torch.randn(3, requires_grad=True)
            >>> y = numpy_mul(x, val=3.14)
            >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
            >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

        """
        schema = self._opoverload._schema
        if not utils.is_functional_schema(schema):
            raise RuntimeError(
                f"Cannot register autograd formula for non-functional operator "
                f"{self} with schema {schema}. Please create "
                f"a functional operator and register an autograd formula for that."
            )

        self._backward_fn = backward
        self._setup_context_fn = setup_context

    def _register_to_dispatcher(self) -> None:
        lib = self._lib
        schema_str = self._name + self._schema
        cpp_schema = _C.parse_schema(schema_str)
        if utils.has_kwarg_only_tensors(cpp_schema):
            # If you want to support this, the progression is:
            # - supporting kwarg-only Tensors that are non-differentiable
            # - supporting kwarg-only Tensors (regardless of differentiability)
            raise NotImplementedError(
                f"custom_op with kwarg-only Tensor args. Please make your "
                f"tensors not kwarg-only. Got: {schema_str}"
            )

        lib.define(
            schema_str,
            tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order],
        )
        self._opoverload = utils.lookup_op(self._qualname)

        def fake_impl(*args, **kwargs):
            if self._abstract_fn is None:
                if utils.can_generate_trivial_fake_impl(self._opoverload):
                    return None
                raise RuntimeError(
                    f"There was no fake impl registered for {self}. "
                    f"This is necessary for torch.compile/export/fx tracing to work. "
                    f"Please use `{self._init_fn.__name__}.register_fake` to add a "
                    f"fake impl."
                )
            return self._abstract_fn(*args, **kwargs)

        lib._register_fake(self._name, fake_impl, _stacklevel=4)

        autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
        lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)

        schema = self._opoverload._schema
        if schema.is_mutable:

            def adinplaceorview_impl(keyset, *args, **kwargs):
                for arg, val in utils.zip_schema(schema, args, kwargs):
                    if not arg.alias_info:
                        continue
                    if not arg.alias_info.is_write:
                        continue
                    if isinstance(val, Tensor):
                        torch.autograd.graph.increment_version(val)
                    elif isinstance(val, (tuple, list)):
                        for v in val:
                            if isinstance(v, Tensor):
                                torch.autograd.graph.increment_version(v)
                with _C._AutoDispatchBelowADInplaceOrView():
                    return self._opoverload.redispatch(
                        keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
                    )

            lib.impl(
                self._name,
                adinplaceorview_impl,
                "ADInplaceOrView",
                with_keyset=True,
            )

    def _register_backend_select_dispatcher(self, device_arg_index: int):
        """
        Switch on the device argument to select the correct backend to dispatch to.
        """

        def backend_select(keyset, *args, **kwargs):
            device = args[device_arg_index].type
            if device not in self._backend_fns:
                raise RuntimeError(
                    f"{self._name} does not have a kernel registered for {device}. "
                    "Please use register_kernel to do so."
                )
            dispatch_key = _C._dispatch_key_for_device(device)
            dispatch_key = getattr(_C.DispatchKey, dispatch_key)
            return self._opoverload.redispatch(
                _C.DispatchKeySet(dispatch_key), *args, **kwargs
            )

        self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True)

    def __call__(self, *args, **kwargs):
        return self._opoverload(*args, **kwargs)

    def register_vmap(
        self,
        func: Optional[Callable] = None,
    ):
        r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.

        This API may be used as a decorator.

        In order for an operator to work with :func:`torch.vmap`, you may need to register a
        vmap implementation in the following signature:

        ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,

        where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.

        It specifies how we compute the batched version of ``op`` given inputs with an additional
        dimension (specified by ``in_dims``).

        For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
        if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer
        specifying what dimension of the Tensor is being vmapped over.

        ``info`` is a collection of additional metadata that may be helpful:
        ``info.batch_size`` specifies the size of the dimension being vmapped over, while
        ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.

        The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
        ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
        per output that specifies if the output has the vmapped dimension and what index it is in.

        Examples:
            >>> import torch
            >>> import numpy as np
            >>> from torch import Tensor
            >>> from typing import Tuple
            >>>
            >>> def to_numpy(tensor):
            >>>     return tensor.cpu().numpy()
            >>>
            >>> lib = torch.library.Library("mylib", "FRAGMENT")
            >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
            >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
            >>>     x_np = to_numpy(x)
            >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
            >>>     return torch.tensor(x_np ** 3, device=x.device), dx
            >>>
            >>> def numpy_cube_vmap(info, in_dims, x):
            >>>     result = numpy_cube(x)
            >>>     return result, (in_dims[0], in_dims[0])
            >>>
            >>> numpy_cube.register_vmap(numpy_cube_vmap)
            >>>
            >>> x = torch.randn(3)
            >>> torch.vmap(numpy_cube)(x)
            >>>
            >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
            >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
            >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
            >>>
            >>> @numpy_mul.register_vmap
            >>> def numpy_mul_vmap(info, in_dims, x, y):
            >>>     x_bdim, y_bdim = in_dims
            >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            >>>     result = x * y
            >>>     result = result.movedim(-1, 0)
            >>>     return result, 0
            >>>
            >>> x = torch.randn(3)
            >>> y = torch.randn(3)
            >>> torch.vmap(numpy_mul)(x, y)
        """
        from torch._functorch.autograd_function import custom_function_call_vmap_helper
        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter

        def register(func):
            need_register = self._vmap_fn is None
            self._vmap_fn = func

            if need_register:

                def wrapped_func(keyset, *args, **kwargs):
                    interpreter = retrieve_current_functorch_interpreter()
                    return custom_function_call_vmap_helper(
                        interpreter, self._vmap_fn, self._opoverload, *args, **kwargs
                    )

                self._lib.impl(
                    self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
                )

            # Return `func` so decorator usage leaves the decorated name intact,
            # matching register_kernel and register_torch_dispatch.
            return func

        if func is None:
            return register
        else:
            return register(func)


# NOTE: [Supporting decorator and non-decorator usage]
#
# Some APIs may be used both as a decorator and not as a decorator.
# For example:
#
# >>> def fn(x):
# >>>     return x.sin()
# >>>
# >>> # Usage 1: not as a decorator
# >>> numpy_sin.register_kernel("cuda", fn)
# >>>
# >>> # Usage 2: as a decorator
# >>> @numpy_sin.register_kernel("cuda")
# >>> def fn2(x):
# >>>     return x.sin()
#
# The way we support this is that `register_kernel` accepts an optional `fn`.
# If `fn` is provided (Usage 1), then we know that the user is using it not
# as a decorator.
# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a
# decorator.
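#
# A minimal sketch of that pattern (a hypothetical `register` helper shown
# for illustration; `register_kernel` and `register_vmap` above follow it):
#
# >>> def register(arg, fn=None, /):
# >>>     def inner(fn):
# >>>         ...  # perform the actual registration here
# >>>         return fn
# >>>     if fn is None:
# >>>         return inner      # Usage 2: return a decorator
# >>>     return inner(fn)      # Usage 1: register `fn` immediately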


OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {}
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()


def get_library_allowing_overwrite(
    namespace: str, name: str
) -> "torch.library.Library":
    qualname = f"{namespace}::{name}"

    # If a Library already exists for this op (e.g. the op is being re-defined),
    # destroy it so the new definition can overwrite the old one.
    if qualname in OPDEF_TO_LIB:
        OPDEF_TO_LIB[qualname]._destroy()
        del OPDEF_TO_LIB[qualname]

    lib = torch.library.Library(namespace, "FRAGMENT")  # noqa: TOR901
    OPDEF_TO_LIB[qualname] = lib
    return lib


def iter_tensors(
    args: Tuple[Any], kwargs: Dict[str, Any], allowed_nesting: int = 1
) -> Iterator[Tensor]:
    # Yields every Tensor in args/kwargs, descending up to `allowed_nesting`
    # levels of list/tuple nesting.
    def check(arg):
        if isinstance(arg, Tensor):
            yield arg
        elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
            yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)

    for arg in args:
        yield from check(arg)
    for kwarg in kwargs.values():
        yield from check(kwarg)


def _maybe_get_opdef(
    op: Union[CustomOpDef, _ops.OpOverload, str]
) -> Optional[CustomOpDef]:
    if isinstance(op, CustomOpDef):
        return op
    if isinstance(op, _ops.OpOverload):
        op = op._name
    assert isinstance(op, str)
    if op in OPDEFS:
        return OPDEFS[op]
    return None
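

# A small illustration of the lookup above (assuming "mylib::numpy_sin" was
# defined via torch.library.custom_op; shown for illustration only):
#
# >>> _maybe_get_opdef("mylib::numpy_sin")  # the corresponding CustomOpDef
# >>> _maybe_get_opdef("aten::sin")         # None: not defined via custom_op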