# mypy: ignore-errors

import torch
from copy import deepcopy
from torch.utils._pytree import tree_map
import torch.utils._pytree as pytree


# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor


# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
        if "size" not in kwargs:
            size = t.size()
        else:
            size = kwargs["size"]
            del kwargs["size"]
        if "dtype" not in kwargs:
            kwargs["dtype"] = t.dtype
        if "layout" not in kwargs:
            kwargs["layout"] = t.layout
        if "device" not in kwargs:
            kwargs["device"] = t.device
        if "requires_grad" not in kwargs:
            kwargs["requires_grad"] = False
        # Ignore memory_format and pin_memory for now as I don't know how to
        # safely access them on a Tensor (if possible??)

        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
        wrapper._validate_methods()
        return wrapper

    @classmethod
    def get_wrapper_properties(cls, *args, **kwargs):
        # Should return both an example Tensor and a dictionary of kwargs
        # to override any of that example Tensor's properties.
        # This is very similar to the `t.new_*(args)` API.
        raise NotImplementedError("You need to implement get_wrapper_properties")

    def _validate_methods(self):
        # Skip this if not in debug mode?
        # Changing these on the Python side is wrong as it would not be
        # properly reflected on the C++ side.
        # This doesn't catch attributes set in the __init__
        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
        for el in forbidden_overrides:
            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
                                   f"property {el} but this is not allowed as such a change "
                                   "would not be reflected to C++ callers.")


class DiagTensorBelow(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, diag, requires_grad=False):
        assert diag.ndim == 1
        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}

    def __init__(self, diag, requires_grad=False):
        self.diag = diag

    handled_ops = {}

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # For everything else, call the handler:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **(kwargs or {}))
        else:
            # Note that here, because we don't need to provide the autograd formulas,
            # we can have a default "fallback" that creates a plain Tensor based
            # on the diag elements and calls the func again.
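            #
            # A sketch of that round trip for, e.g., torch.add (the values are
            # illustrative):
            #   DiagTensorBelow(diag=tensor([1., 2.]))
            #     --unwrap--> tensor([[1., 0.],
            #                         [0., 2.]])  (plain dense Tensor)
            #     --func----> result computed by the regular dense kernel
            #     --wrap----> re-wrapped as DiagTensorBelow whenever the result
            #                 still has nonzeros only on its diagonal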

            def unwrap(e):
                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e

            def wrap(e):
                if isinstance(e, torch.Tensor) and e.ndim == 1:
                    return DiagTensorBelow(e)
                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
                    return DiagTensorBelow(e.diag())
                return e

            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
            return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"diag={self.diag}")


class SparseTensor(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
        assert values.device == indices.device
        return values, {"size": size, "requires_grad": requires_grad}

    def __init__(self, size, values, indices, requires_grad=False):
        self.values = values
        self.indices = indices

    def __repr__(self):
        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")

    def sparse_to_dense(self):
        res = torch.zeros(self.size(), dtype=self.values.dtype)
        res[self.indices.unbind(1)] = self.values
        return res

    @staticmethod
    def from_dense(t):
        indices = t.nonzero()
        values = t[indices.unbind(1)]
        return SparseTensor(t.size(), values, indices)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        func_name = f"{func.__module__}.{func.__name__}"

        res = cls._try_call_special_impl(func_name, args, kwargs)
        if res is not NotImplemented:
            return res

        # Otherwise, use a default implementation that constructs dense
        # tensors and uses them to compute the values.
        def unwrap(e):
            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e

        # Wrap back all Tensors into our custom class
        def wrap(e):
            # Check for zeros and use that to get indices
            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
        return rs

    # To show how things happen later
    def __rmul__(self, other):
        return super().__rmul__(other)

    _SPECIAL_IMPLS = {}

    @classmethod
    def _try_call_special_impl(cls, func, args, kwargs):
        if func not in cls._SPECIAL_IMPLS:
            return NotImplemented
        return cls._SPECIAL_IMPLS[func](args, kwargs)


# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
    def __new__(cls, data):
        t = torch.Tensor._make_subclass(cls, data)
        t.extra_state = {
            'last_func_called': None
        }
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs)

        if isinstance(result, cls):
            # Do something with the extra state. For the example here, just store the name of the
            # last function called (skip for deepcopy so the copy has the same extra state).
            if func is torch.Tensor.__deepcopy__:
                result.extra_state = deepcopy(args[0].extra_state)
            else:
                result.extra_state = {
                    'last_func_called': func.__name__,
                }

        return result

    # new_empty() must be defined for deepcopy to work
    def new_empty(self, shape):
        return type(self)(torch.empty(shape))
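

# A minimal usage sketch for NonWrapperTensor (illustrative only; nothing in
# this file runs it):
#
#   t = NonWrapperTensor(torch.randn(3))
#   t = t.mul(2)
#   assert t.extra_state['last_func_called'] == 'mul'
#
#   t2 = deepcopy(t)  # extra_state is carried over via __deepcopy__
#   assert t2.extra_state == t.extra_state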


# Class used to store info about subclass tensors used in testing.
class SubclassInfo:

    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


subclass_db = {
    torch.Tensor: SubclassInfo(
        'base_tensor', create_fn=torch.randn
    ),
    NonWrapperTensor: SubclassInfo(
        'non_wrapper_tensor',
        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
    ),
    LoggingTensor: SubclassInfo(
        'logging_tensor',
        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
    ),
    SparseTensor: SubclassInfo(
        'sparse_tensor',
        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
    ),
    DiagTensorBelow: SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
}


class SubclassWithTensorFactory(torch.Tensor):
    @staticmethod
    def __new__(cls, src):
        shape = src.shape
        kwargs = {}
        kwargs["strides"] = src.stride()
        kwargs["storage_offset"] = src.storage_offset()
        kwargs["device"] = src.device
        kwargs["layout"] = src.layout
        kwargs["requires_grad"] = src.requires_grad
        kwargs["dtype"] = src.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
        return out

    def __init__(self, src):
        self.src = src

    def __repr__(self):
        return f"{self.__class__.__name__}"

    def __tensor_flatten__(self):
        return ["src"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
        src = inner_tensors["src"]
        return cls(src)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}

        def _fn(x):
            return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src

        _args = pytree.tree_map_only(cls, _fn, args)
        _kwargs = pytree.tree_map_only(cls, _fn, kwargs)

        _out = func(*_args, **_kwargs)

        _out_flat, _out_spec = pytree.tree_flatten(_out)

        out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat]
        return pytree.tree_unflatten(out_flat, _out_spec)
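

# A minimal, hedged smoke check of the helpers above. It is illustrative only
# (it assumes a CPU torch build and a small 1-D shape, and it is not part of
# the testing API); it runs only when this module is executed directly.
if __name__ == "__main__":
    for cls, info in subclass_db.items():
        # Each create_fn takes a shape; DiagTensorBelow requires a 1-D diag.
        inst = info.create_fn((5,))
        print(f"{info.name}: {type(inst).__name__}, shape={tuple(inst.shape)}")

    # Round-trip a SubclassWithTensorFactory through __tensor_flatten__ /
    # __tensor_unflatten__ to show the contract between the two hooks.
    s = SubclassWithTensorFactory(torch.randn(2, 2))
    attrs, meta = s.__tensor_flatten__()  # (["src"], None)
    inner = {name: getattr(s, name) for name in attrs}
    s2 = SubclassWithTensorFactory.__tensor_unflatten__(inner, meta, s.shape, s.stride())
    assert torch.equal(s.src, s2.src)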