# mypy: allow-untyped-defs
import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Optional, Sequence

import torch
import torch._decomp
import torch._prims
import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides
from torch._prims_common import torch_function_passthrough


@functools.lru_cache(None)
def torch_to_refs_map():
    """
    Mapping of torch API functions to torch._refs functions.
    E.g. torch_to_refs_map()[torch.add] == torch._refs.add

    The mapping is built once and cached; every caller receives the same dict.
    Entries come from four sources:
      * a fixed table of Tensor dunder/in-place methods;
      * every name in a refs module's ``__all__`` that also exists in the
        matching torch module;
      * ``torch.Tensor`` methods whose name is in ``torch._refs.__all__``;
      * the dtype-conversion refs in ``torch._refs._conversions``.
    Names missing on either side are skipped, so ``None`` never appears as a
    key or a value.
    """
    modules = [
        (torch, torch._refs),
        (torch.nn, torch._refs.nn),
        (torch.nn.functional, torch._refs.nn.functional),
        (torch.special, torch._refs.special),
        (torch.fft, torch._refs.fft),
        (torch.linalg, torch._refs.linalg),
    ]
    r: Dict[Any, Any] = {
        torch.Tensor.__invert__: torch._refs.bitwise_not,
        torch.Tensor.__xor__: torch._refs.bitwise_xor,
        torch.Tensor.__and__: torch._refs.bitwise_and,
        torch.Tensor.__or__: torch._refs.bitwise_or,
        torch.Tensor.__eq__: torch._refs.eq,
        torch.Tensor.__rsub__: torch._refs.rsub,
        torch.Tensor.__rtruediv__: torch._refs.rtruediv,
        torch.Tensor.__floordiv__: torch._refs.floor_divide,
        torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
        torch.Tensor.__pow__: torch._refs.pow,
        torch.Tensor.__rpow__: torch._refs.rpow,
        torch.Tensor.new_empty: torch._refs.new_empty,
        torch.Tensor.new_full: torch._refs.new_full,
        torch.Tensor.new_zeros: torch._refs.new_zeros,
        torch.Tensor.new_ones: torch._refs.new_ones,
        torch.Tensor.fill_: torch._refs.fill_,
        torch.Tensor.zero_: torch._refs.zero_,
        torch.Tensor.to: torch._refs.to,
        torch.Tensor.sum_to_size: torch._refs.sum_to_size,
        # TODO: Should these methods be mapped some other way?
        torch.Tensor.copy_: torch._prims.copy_to,
        torch.Tensor.resize: torch._prims.resize,
    }
    for mod_torch, mod_refs in modules:
        for name in mod_refs.__all__:  # type: ignore[attr-defined]
            torch_fn = mod_torch.__dict__.get(name)
            ref_fn = mod_refs.__dict__.get(name)
            # A ref may have no public counterpart in this torch module (or the
            # ref object may be absent). Skip it rather than inserting a
            # spurious ``None`` key, or a ``None`` value that would shadow a
            # valid mapping added earlier.
            if torch_fn is None or ref_fn is None:
                continue
            r[torch_fn] = ref_fn

    # Support remapping torch.Tensor.foo to _refs.foo
    # Build the membership set once instead of scanning the list per name.
    refs_all = set(torch._refs.__all__)
    for name in dir(torch.Tensor):
        if name not in refs_all:
            continue
        ref_fn = torch._refs.__dict__.get(name)
        if ref_fn is not None:
            r[getattr(torch.Tensor, name)] = ref_fn

    # Support conversions: prefer the Tensor method, otherwise the torch.* attr
    # (raises AttributeError if neither exists, as before).
    for name in torch._refs._conversions.__all__:
        tensor_attr = getattr(torch.Tensor, name, None) or getattr(torch, name)
        ref_fn = torch._refs._conversions.__dict__.get(name)
        if ref_fn is not None:
            r[tensor_attr] = ref_fn

    return r


@functools.lru_cache(None)
def all_prims():
    """
    Set of all prim functions, e.g., torch._prims.add in all_prims()

    Built once from ``torch._prims.__all__`` and cached.
    """
    return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}


class TorchRefsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.* functions and Tensor methods to
    use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.)

    >>> # xdoctest: +SKIP
    >>> with TorchRefsMode():
    ...     torch.add(x, y)  # calls torch._refs.add(x, y)

    By default, this context manager falls back on the torch.* function if the
    ref does not exist; set strict=True to raise a RuntimeError instead.
    Even when a ref exists, you may still want to fall back on the torch.*
    function sometimes. Customize this by passing should_fallback_fn; it is
    called as ``should_fallback_fn(mode, orig_func, ref_func, args, kwargs)``
    and should return True to run ``orig_func`` instead of the ref.

    Args:
        strict: raise when no ref or decomposition is found for a call.
        should_fallback_fn: predicate described above; defaults to never
            falling back.
        prims_mode_cls: zero-argument callable returning a context manager.
            Prim and passthrough functions are executed inside it.
            The default, ``nullcontext``, does nothing.
    """

    def __init__(
        self,
        strict: bool = False,
        should_fallback_fn: Callable[..., bool] = lambda *_: False,
        prims_mode_cls: Callable[[], Any] = nullcontext,
    ):
        super().__init__()
        self.strict = strict
        self.should_fallback_fn = should_fallback_fn
        self.prims_mode_cls = prims_mode_cls

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        if kwargs is None:
            kwargs = {}
        # Primitive operations (prims and passthrough functions) run as-is
        # without interception, inside whatever context prims_mode_cls()
        # provides (e.g. a mode that swaps in alternative prim implementations).
        if orig_func in torch_function_passthrough or orig_func in all_prims():
            with self.prims_mode_cls():
                return orig_func(*args, **kwargs)
        mapping = torch_to_refs_map()
        func = mapping.get(orig_func, None)

        # For torch.ops.aten.*, use registered decompositions from torch._decomp
        # torch._decomp.decomposition_table provides a mapping from
        # torch.ops.aten.* to torch._refs or torch._decomp.decompositions
        # implementations.
        # There are other ways to implement this functionality,
        # see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
        if func is None and isinstance(orig_func, torch._ops.OpOverload):
            func = torch._decomp.decomposition_table.get(orig_func, None)
        elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
            # A packet has no table entry of its own; use its .default overload.
            default = getattr(orig_func, "default", None)
            if default is not None:
                func = torch._decomp.decomposition_table.get(default, None)

        if func is not None:
            # If the ref exists query whether we should use it or not
            if self.should_fallback_fn(self, orig_func, func, args, kwargs):
                return orig_func(*args, **kwargs)
            # torch calls inside func should be interpreted as refs calls
            with self:
                return func(*args, **kwargs)
        if self.strict:
            raise RuntimeError(
                f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
            )
        return orig_func(*args, **kwargs)