# mypy: allow-untyped-defs
from typing import *  # noqa: F403
from typing import Tuple

import torch
from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.utils.weak import WeakTensorKeyDictionary


_tensor_id_counter = 0
_tensor_symint_registry = WeakTensorKeyDictionary()


def get_tensor_symint(tensor, *, coeff=1):
    from torch._subclasses.fake_tensor import FakeTensor
    from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor

    # NB: Only FakeTensor is associated with a memo
    tensor = mb_unwrap_functional_tensor(tensor)
    if isinstance(tensor, FakeTensor):
        return tensor.get_nested_int(coeff=coeff)

    global _tensor_id_counter

    tensor_symint = _tensor_symint_registry.get(tensor)
    if tensor_symint is None:
        tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
        _tensor_id_counter += 1
        _tensor_symint_registry[tensor] = tensor_symint
    return tensor_symint


# SDPA metadata; max / min seqlens are needed for e.g. flash
def _get_sdpa_extreme_seqlen(func, tensor):
    return int(func(tensor).item())


def _store_val_in_tensor(val) -> torch.Tensor:
    # hack to get dynamic shapes support: store in a (val, 0) shaped tensor
    return torch.zeros(val, 0)


def _load_val_from_tensor(t: torch.Tensor):
    return t.shape[0]
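

# Illustrative round trip for the (val, 0)-shaped tensor trick above (assumed
# values, written as a comment so nothing executes at import time):
#
#   seqlen_tensor = _store_val_in_tensor(5)   # torch.zeros(5, 0); shape is (5, 0)
#   _load_val_from_tensor(seqlen_tensor)      # 5
#
# Keeping the value in a tensor's shape rather than as a plain Python int lets
# dynamo mark dim 0 as dynamic (see NestedTensor.__init__ below), so cached
# min / max seqlens don't get baked into a compiled graph as constants.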


class NestedTensor(torch.Tensor):
    _values: torch.Tensor  # type: ignore[assignment]
    _offsets: torch.Tensor
    _lengths: Optional[torch.Tensor]
    # NOTE [ Nested ints for ragged sizes and strides ]
    #
    # Jagged layout tensors are tensors that represent an n-dim tensor with a
    # ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
    # a jagged tensor with outer shape [B, x, D] is represented internally by a
    # tensor with shape [sum(x), D], where we introduce what we call a nested int
    # denoted as "x" here (but sometimes denoted with "*") to represent the
    # ragged dimension; sum(x) represents the dim of the inner tensor, or
    # equivalently the sum of the sizes of the constituent tensors' varying
    # lengths.
    #
    # We also use nested ints to represent the strides of this tensor.
    # For example, a jagged tensor with shape [B, x, D] can be strided in two
    # ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
    _size: Tuple[int, ...]
    _strides: Tuple[int, ...]
    # Indicates that the nth dimension is ragged
    _ragged_idx: int
    _metadata_cache: Dict[str, Any]

    @staticmethod
    def __new__(
        cls,
        values,
        offsets,
        *,
        lengths=None,
        **kwargs,
    ):
        ks = DispatchKeySet(DispatchKey.NestedTensor)
        ks = ks.add(DispatchKey.AutogradNestedTensor)

        # Only support jagged for now.
        assert offsets is not None
        assert offsets.ndim == 1
        assert not isinstance(values, NestedTensor)
        assert values.device == offsets.device

        # Query cache for the symint associated with offsets or lengths
        # (create a new one if needed).
        ragged_source = offsets if lengths is None else lengths
        ragged_size = get_tensor_symint(ragged_source, coeff=1)
        _ragged_idx = kwargs.get("_ragged_idx", 1)
        B = offsets.shape[0] - 1
        if lengths is not None:
            assert B == lengths.shape[0]

        # subtract 1 to convert to values dim space
        r = _ragged_idx - 1
        _size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :])
        stride = values.stride()
        _strides = (ragged_size * stride[r], *stride)

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            _size,
            _strides,
            0,
            torch.contiguous_format,
            values.dtype,
            torch.jagged,
            values.device,
            False,
            kwargs.get("requires_grad", False),
            "sizes",
            False,
            True,  # dispatch_layout
            ks,
            # don't try to calculate storage based on non-zero size
            storage_size=values.untyped_storage().size(),
        )
        r._ragged_idx = _ragged_idx
        r._size = _size
        r._strides = _strides

        return r

    def __init__(self, values, offsets, *, lengths=None, **kwargs):
        super().__init__()

        self._values = values
        self._offsets = offsets
        self._lengths = lengths

        # holds properties that are computed lazily
        self._metadata_cache = kwargs.get("_metadata_cache") or {}

        # collapsed ragged dim must always be dynamic
        torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx)
        torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1)

        # min / max sequence length should be dynamic if present
        max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None)
        if max_seqlen_tensor is not None:
            torch._dynamo.mark_dynamic(max_seqlen_tensor, 0)
        min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None)
        if min_seqlen_tensor is not None:
            torch._dynamo.mark_dynamic(min_seqlen_tensor, 0)

    def values(self):
        # dispatch to get proper view relationship
        return torch._nested_get_values(self)  # type: ignore[attr-defined]

    def offsets(self):
        return self._offsets

    def lengths(self):
        return self._lengths

    # Private accessor functions for min / max sequence length. They're
    # purposefully not @properties because those don't work with PT2 (yet).
    # These compute / cache if not present.
    # TODO: Revisit this when @properties are better supported by PT2. I think the ideal
    # state would be to have public @properties for min / max sequence length that compile
    # (including setters).
    def _get_max_seqlen(self):
        max_seqlen_tensor = self._max_seqlen_tensor
        if max_seqlen_tensor is None:
            # compute & cache
            max_val = _get_sdpa_extreme_seqlen(
                torch.max,
                self._offsets.diff() if self._lengths is None else self._lengths,
            )
            max_seqlen_tensor = _store_val_in_tensor(max_val)
            self._metadata_cache["max_seqlen"] = max_seqlen_tensor
        return _load_val_from_tensor(max_seqlen_tensor)

    def _get_min_seqlen(self):
        min_seqlen_tensor = self._min_seqlen_tensor
        if min_seqlen_tensor is None:
            # compute & cache
            min_val = _get_sdpa_extreme_seqlen(
                torch.min,
                self._offsets.diff() if self._lengths is None else self._lengths,
            )
            min_seqlen_tensor = _store_val_in_tensor(min_val)
            self._metadata_cache["min_seqlen"] = min_seqlen_tensor
        return _load_val_from_tensor(min_seqlen_tensor)
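
    # Illustrative example for the two accessors above (assumed offsets): with
    # offsets = tensor([0, 2, 5, 7]) and no lengths, offsets.diff() yields the
    # per-component lengths tensor([2, 3, 2]), so _get_max_seqlen() returns 3 and
    # _get_min_seqlen() returns 2; each value is cached as a (val, 0)-shaped tensor.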

    # Private accessors used for treating min / max seqlen as inner tensors for
    # flatten / unflatten. These must be properties to work with the traceable
    # wrapper subclass logic. These do not compute / cache if not present.
    @property
    def _max_seqlen_tensor(self) -> Optional[torch.Tensor]:
        return self._metadata_cache.get("max_seqlen", None)

    @property
    def _min_seqlen_tensor(self) -> Optional[torch.Tensor]:
        return self._metadata_cache.get("min_seqlen", None)

    # These are old private @property accessors that are kept around for internal BC
    # reasons. TODO: Remove these!
    @property
    def _max_seqlen(self):
        return self._get_max_seqlen()

    @property
    def _min_seqlen(self):
        return self._get_min_seqlen()

    def __repr__(self):
        # We should implement this in torch/_tensor_str.py instead
        grad_fn_str = (
            f", requires_grad={self.requires_grad}" if self.requires_grad else ""
        )
        if self.grad_fn:
            grad_fn_str = f", grad_fn={self.grad_fn}"
        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"

    def __reduce_ex__(self, proto):
        state = torch._utils._get_obj_state(self)

        # SymNodes are not serializable
        assert "_size" in state and "_strides" in state
        state = dict(state)
        del state["_size"]
        del state["_strides"]

        # TODO: Update this to handle the other inner tensors
        func = NestedTensor
        args = (self._values, self._offsets)
        return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))

    def __tensor_flatten__(self):
        ctx = {
            "requires_grad": self.requires_grad,
            "ragged_idx": self._ragged_idx,
        }
        inner_tensors = ["_values", "_offsets"]
        if self._lengths is not None:
            inner_tensors.append("_lengths")
        if self._min_seqlen_tensor is not None:
            inner_tensors.append("_min_seqlen_tensor")
        if self._max_seqlen_tensor is not None:
            inner_tensors.append("_max_seqlen_tensor")
        return inner_tensors, ctx
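
    # Illustrative example of the flatten contract above (assumed state): for a
    # contiguous NestedTensor (no lengths) with ragged_idx=1, requires_grad=False,
    # and only max_seqlen cached, __tensor_flatten__ returns
    #   (["_values", "_offsets", "_max_seqlen_tensor"],
    #    {"requires_grad": False, "ragged_idx": 1})
    # and __tensor_unflatten__ below rebuilds an equivalent NestedTensor from
    # exactly those pieces plus outer_size / outer_stride.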

    @staticmethod
    def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
        from torch._subclasses.fake_tensor import FakeTensor

        # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen]
        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5
        values = inner_tensors["_values"]
        offsets = inner_tensors["_offsets"]
        lengths = inner_tensors.get("_lengths", None)
        min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None)
        max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None)

        metadata_cache = {}
        if min_seqlen_tensor is not None:
            metadata_cache["min_seqlen"] = min_seqlen_tensor
        if max_seqlen_tensor is not None:
            metadata_cache["max_seqlen"] = max_seqlen_tensor
        ragged_idx = meta["ragged_idx"]

        # If the offsets / lengths source is fake, associate the nested int from
        # outer_size with it here so get_tensor_symint() sees a consistent value.
        # Alternatively, we could make it the caller's responsibility to cache
        # it, but this heuristic seems simple enough.
        ragged_source = offsets if lengths is None else lengths
        if isinstance(ragged_source, FakeTensor):
            ragged_size = outer_size[ragged_idx]
            ragged_source.nested_int_memo = ragged_size

        return NestedTensor(
            values,
            offsets=offsets,
            lengths=lengths,
            requires_grad=meta["requires_grad"],
            _ragged_idx=ragged_idx,
            _metadata_cache=metadata_cache,
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        # Lazy import to avoid circular dependency
        from .ops import lookup_jagged

        fn = lookup_jagged(func, *args, **kwargs)
        if fn is not None:
            return fn(*args, **kwargs)

        raise NotImplementedError(func)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify

        from .ops import jagged_torch_function

        # This should be removed after
        # https://github.com/pytorch/pytorch/pull/125941/ lands
        with maybe_enable_thunkify():
            try:
                return jagged_torch_function(func, *args, **kwargs)
            except NotImplementedError:
                pass
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)


# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!
# TODO: Remove ViewBufferFromNested, ViewNestedFromBuffer, and buffer_from_jagged once the
# internal BC period has passed.


# Not actually a view!
class ViewBufferFromNested(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: NestedTensor):  # type: ignore[override]
        ctx.save_for_backward(x.offsets())
        ctx.metadata_cache = x._metadata_cache
        ctx.ragged_idx = x._ragged_idx
        return x._values

    @staticmethod
    def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
        (offsets,) = ctx.saved_tensors
        return NestedTensor(
            gO,
            offsets=offsets,
            _metadata_cache=ctx.metadata_cache,
            _ragged_idx=ctx.ragged_idx,
        )


# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        values: torch.Tensor,
        offsets: torch.Tensor,
        metadata_cache: Optional[Dict[str, Any]] = None,
    ):  # type: ignore[override]
        # maintain BC with usages of this where seqlens are stuffed directly
        # into the metadata cache as non-Tensors / ints
        if metadata_cache is not None:
            min_seqlen = metadata_cache.get("min_seqlen", None)
            max_seqlen = metadata_cache.get("max_seqlen", None)
            if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor):
                metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen)
            if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor):
                metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen)
        return NestedTensor(
            values.detach(),
            offsets=offsets,
            _metadata_cache=metadata_cache,
        )

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        return gO._values, None, None


def buffer_from_jagged(jagged):
    return ViewBufferFromNested.apply(jagged)


# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
    tensors: List[torch.Tensor],
    offsets: Optional[torch.Tensor],
    dtype=None,
    device=None,
) -> Tuple[NestedTensor, torch.Tensor]:
    """Constructs a NestedTensor backed by jagged layout from a list of tensors"""

    if not len(set(t.dtype for t in tensors)) == 1:  # noqa: C401
        raise RuntimeError(
            "When constructing a nested tensor, all tensors in list must have the same dtype"
        )
    if not len(set(t.device for t in tensors)) == 1:  # noqa: C401
        raise RuntimeError(
            "When constructing a nested tensor, all tensors in list must be on the same device"
        )

    # Check that the NT is representable by the jagged layout.
    # Jagged layout represents (B, *, D_0, D_1, ..., D_N), where the only
    # raggedness allowed is for the single dim immediately adjacent to the batch dim.
    sizes = [t.shape for t in tensors]
    non_first_sizes = [s[1:] for s in sizes]
    at_most_first_ragged = all(s == non_first_sizes[0] for s in non_first_sizes)
    if not at_most_first_ragged:
        raise RuntimeError(
            "Cannot represent given tensor list as a nested tensor with the jagged layout. "
            "Note that the jagged layout only represents shapes of the form "
            "(B, *, D_0, D_1, ..., D_N), with only * allowed to be ragged."
        )

    # Set properties appropriately.
    values = torch.cat(tensors, dim=0)
    to_kwargs = {}
    if device is not None:
        to_kwargs["device"] = device
    if dtype is not None:
        to_kwargs["dtype"] = dtype
    values = values.to(**to_kwargs)

    # Calculate jagged offsets if not provided.
    if offsets is None:
        # Jagged layout specifies that offsets are stored as int64 on the same device as values.
        # TODO: An alternative way to construct offsets is to use F.pad. This avoids creating
        # an extra leaf tensor during the forward, potentially resolving compatibility issues.
        offsets = torch.cat(
            [
                torch.zeros(1, dtype=torch.int64, device=values.device),
                torch.tensor([s[0] for s in sizes], device=values.device).cumsum(dim=0),
            ]
        )

    # compute this now since it's easy
    min_seqlen = min(t.shape[0] for t in tensors)
    max_seqlen = max(t.shape[0] for t in tensors)
    ret_nt = nested_view_from_values_offsets(
        values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
    )
    return (ret_nt, offsets)  # type: ignore[return-value]


def jagged_from_tensor_and_lengths(
    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
    batch_size = tensor.shape[0]
    if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
        lengths.shape, (batch_size,)
    ):
        start_list = starts.expand(batch_size)
        length_list = lengths.expand(batch_size)
    else:
        raise RuntimeError(
            "When constructing a jagged nested tensor using narrow(), "
            "your start and length must be Tensors that broadcast to input.shape[0]"
        )

    # Calculate jagged offsets
    assert (
        len(tensor.shape) >= 2
    ), "tensor must at least be 2D for the nested narrow op to work"
    max_seq_len = tensor.shape[1]
    offset_lengths = max_seq_len * torch.arange(
        0, batch_size, dtype=torch.int64, device=tensor.device
    )
    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
    offsets = torch.cat(
        [
            start_list + offset_lengths,
            (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
        ]
    )

    # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
    if len(tensor.shape) > 2:
        values = tensor.view(-1, *tensor.shape[2:])
    else:
        values = tensor.view(-1)

    # Check if offsets and lengths make it possibly contiguous and return a regular NT
    is_contiguous = True
    orig_dim = tensor.shape[1]
    if torch.any(length_list[1:-1].ne(orig_dim)):
        is_contiguous = False
    if torch.any(offsets[1:-2].diff().ne(orig_dim)):
        is_contiguous = False
    if offsets[0] + length_list[0] != orig_dim:
        is_contiguous = False

    actual_max_seqlen = int(torch.max(lengths).item())
    min_seqlen = int(torch.min(lengths).item())

    if is_contiguous:
        ret_nt = nested_view_from_values_offsets(
            values[offsets[0] : offsets[-1]],
            offsets - offsets[0],
            min_seqlen=min_seqlen,
            max_seqlen=actual_max_seqlen,
        )
    else:
        ret_nt = nested_view_from_values_offsets_lengths(
            values,
            offsets,
            length_list,
            min_seqlen=min_seqlen,
            max_seqlen=actual_max_seqlen,
        )

    return (ret_nt, offsets, None if is_contiguous else length_list)
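

# Illustrative usage of the two constructors above (a sketch with assumed shapes,
# kept as a comment so nothing runs at import time):
#
#   ts = [torch.randn(2, 16), torch.randn(3, 16), torch.randn(2, 16)]
#   nt, offsets = jagged_from_list(ts, offsets=None)
#   # offsets == tensor([0, 2, 5, 7]); nt views the packed [7, 16] values buffer
#
#   dense = torch.randn(3, 10, 16)
#   starts = torch.zeros(3, dtype=torch.int64)
#   lengths = torch.tensor([2, 3, 2], dtype=torch.int64)
#   nt2, offsets2, lengths2 = jagged_from_tensor_and_lengths(dense, starts, lengths)
#   # non-contiguous case here, so lengths2 is returned alongside the offsets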


# NB: A dummy arg is required so that NestedTensor.__torch_dispatch__() is invoked
# for _nested_view_from_values_offsets(). Sizes don't matter much, but they shouldn't be
# 0/1 because the dummy can be fake-ified and we want to avoid specializing.
# This arg is otherwise unused.
_dummy_instance: Optional[torch.Tensor] = None


def _nt_view_dummy() -> torch.Tensor:
    global _dummy_instance
    if _dummy_instance is None:
        _dummy_instance = NestedTensor(
            values=torch.zeros(3, 3, device="meta"),
            offsets=torch.zeros(3, device="meta", dtype=torch.int64),
        ).detach()
    return _dummy_instance


def nested_view_from_values_offsets(
    values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None
):
    min_seqlen_tensor = None
    if min_seqlen is not None:
        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)

    max_seqlen_tensor = None
    if max_seqlen is not None:
        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)

    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
        values,
        offsets,
        _nt_view_dummy(),
        None,
        ragged_idx,
        min_seqlen_tensor,
        max_seqlen_tensor,
    )  # type: ignore[return-value]


def nested_view_from_values_offsets_lengths(
    values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None
):
    min_seqlen_tensor = None
    if min_seqlen is not None:
        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)

    max_seqlen_tensor = None
    if max_seqlen is not None:
        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)

    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
        values,
        offsets,
        _nt_view_dummy(),
        lengths,
        ragged_idx,
        min_seqlen_tensor,
        max_seqlen_tensor,
    )  # type: ignore[return-value]
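

# Illustrative use of the view-based constructors above (assumed shapes, kept as
# a comment only). As far as we know, torch.nested.nested_tensor_from_jagged() is
# the public entry point that ultimately funnels into these helpers:
#
#   values = torch.randn(7, 16)             # packed [sum(x), D] buffer
#   offsets = torch.tensor([0, 2, 5, 7])    # B = 3 ragged components
#   nt = nested_view_from_values_offsets(
#       values, offsets, min_seqlen=2, max_seqlen=3
#   )
#   # nt.shape -> (3, j, 16), where j is the nested int standing in for the ragged dim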