# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import SymInt, Tensor
from torch._C import _add_docstr, _nested  # type: ignore[attr-defined]

from torch.types import _device as Device, _dtype as DType

__all__ = [
    "to_padded_tensor",
    "as_nested_tensor",
    "nested_tensor",
    "nested_tensor_from_jagged",
    "narrow",
    "masked_select",
]

# Nested Tensor constructor functions


def as_nested_tensor(
    ts: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
    dtype: Optional[DType] = None,
    device: Optional[Device] = None,
    layout=None
) -> Tensor:
    r"""
    Constructs a nested tensor preserving autograd history from a tensor or a list / tuple of
    tensors.

    If a nested tensor is passed, it will be returned directly unless the device / dtype / layout
    differ. Note that converting device / dtype will result in a copy, while converting layout
    is not currently supported by this function.

    If a non-nested tensor is passed, it is treated as a batch of constituents of consistent size.
    A copy will be incurred if the passed device / dtype differ from those of the input OR if
    the input is non-contiguous. Otherwise, the input's storage will be used directly.

    If a tensor list is provided, tensors in the list are always copied during construction of
    the nested tensor.

    Args:
        ts (Tensor or List[Tensor] or Tuple[Tensor]): a tensor to treat as a nested tensor OR a
            list / tuple of tensors with the same ndim

    Keyword arguments:
        dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
            Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
        device (:class:`torch.device`, optional): the desired device of returned nested tensor.
            Default: if None, same :class:`torch.device` as leftmost tensor in the list.
        layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
            Only strided and jagged layouts are supported. Default: if None, the strided layout.

    Example::

        >>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
        >>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
        >>> nt = torch.nested.as_nested_tensor([a, b])
        >>> nt.is_leaf
        False
        >>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)])
        >>> nt.backward(fake_grad)
        >>> a.grad
        tensor([1., 1., 1.])
        >>> b.grad
        tensor([0., 0., 0., 0., 0.])
        >>> c = torch.randn(3, 5, requires_grad=True)
        >>> nt2 = torch.nested.as_nested_tensor(c)
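        >>> # Hedged sketch: the same list as a jagged-layout nested tensor (assumes a
        >>> # PyTorch build where layout=torch.jagged is available); autograd history is
        >>> # preserved just as in the strided case.
        >>> nt3 = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)
        >>> nt3.layout
        torch.jagged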
    """
    is_tensor_list = isinstance(ts, (list, tuple)) and all(isinstance(t, Tensor) for t in ts)
    if not isinstance(ts, Tensor) and not is_tensor_list:
        raise TypeError(
            "as_nested_tensor(): Expected first argument to be a tensor or a list / tuple of tensors"
        )
    # convert tuple -> list if needed
    if is_tensor_list and not isinstance(ts, list):
        ts = list(ts)

    if isinstance(ts, Tensor) and ts.dim() < 2:
        raise RuntimeError("as_nested_tensor(): Expected tensor argument to have dim() > 1")

    if isinstance(ts, Tensor) and ts.is_nested:
        if layout == ts.layout:
            # return input directly or input copied to device / dtype
            return ts.to(device=device, dtype=dtype)
        else:
            # TODO: Just use nt.to(layout=layout) when it exists.
            raise RuntimeError(
                "as_nested_tensor(): Converting between nested tensor layouts is not supported")

    if layout is None:
        layout = torch.strided
    if layout == torch.strided:
        if isinstance(ts, Tensor):
            # contiguous() might be necessary to get a flattened view.
            # we could probably be more precise about when to do this as an optimization
            buffer = ts.contiguous().view(-1).to(device=device, dtype=dtype)
            nested_sizes = torch.tensor([t.shape for t in ts])
            return torch._nested_view_from_buffer(
                buffer,
                nested_sizes,
                *torch._nested_compute_contiguous_strides_offsets(nested_sizes))
        else:
            assert isinstance(ts, list)
            return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None)
    elif layout == torch.jagged:
        if isinstance(ts, Tensor):
            if device is None:
                device = ts.device

            # contiguous() might be necessary to get a flattened view.
            # we could probably be more precise about when to do this as an optimization
            values = ts.contiguous().flatten(0, 1).to(device=device, dtype=dtype)
            batch_size = ts.shape[0]
            seq_len = ts.shape[1]
            offsets = torch.arange(0, batch_size * seq_len + 1, seq_len,
                                   device=device, dtype=torch.int64)

            from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

            return nested_view_from_values_offsets(
                values, offsets, min_seqlen=seq_len, max_seqlen=seq_len
            )
        else:
            from torch.nested._internal.nested_tensor import jagged_from_list

            assert isinstance(ts, list)
            nt, _ = jagged_from_list(ts, offsets=None, device=device, dtype=dtype)
            return nt
    else:
        raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")


# Note: This not only adds doc strings for the nested ops, but
# also connects the torch.nested Python namespace to the torch._C._nested builtins.

to_padded_tensor = _add_docstr(
    _nested.nested_to_padded_tensor,
    r"""
to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor

Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor.
The leading entries will be filled with the nested data,
while the trailing entries will be padded.

.. warning::

    :func:`to_padded_tensor` always copies the underlying data,
    since the nested and the non-nested tensors differ in memory layout.

Args:
    padding (float): The padding value for the trailing entries.

Keyword args:
    output_size (Tuple[int]): The size of the output tensor.
        If given, it must be large enough to contain all nested data;
        else, it will be inferred by taking the max size of each nested sub-tensor along each dimension.
    out (Tensor, optional): the output tensor.

Example::

    >>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))])
    >>> nt
    nested_tensor([
      tensor([[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
              [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995]]),
      tensor([[-1.8546, -0.7194, -0.2918, -0.1846],
              [ 0.2773,  0.8793, -0.5183, -0.6447],
              [ 1.8009,  1.8468, -0.9832, -1.5272]])
    ])
    >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0)
    >>> pt_infer
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  0.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  0.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  0.0000]]])
    >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6))
    >>> pt_large
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276,  1.0000],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  1.0000,  1.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  1.0000,  1.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]]])
    >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
    RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.
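
    >>> # Hedged sketch: jagged-layout nested tensors can also be padded (assumes a
    >>> # build where layout=torch.jagged is available; shapes are illustrative).
    >>> njt = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged)
    >>> torch.nested.to_padded_tensor(njt, 0.0).shape
    torch.Size([2, 3, 4])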

""",
)


def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor:
    r"""
Constructs a nested tensor with no autograd history (also known as a "leaf tensor", see
:ref:`Autograd mechanics <autograd-mechanics>`) from :attr:`tensor_list`, a list of tensors.

Args:
    tensor_list (List[array_like]): a list of tensors, or anything that can be passed to torch.tensor,
        where each element of the list has the same dimensionality.

Keyword arguments:
    dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
        Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
    layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
        Only strided and jagged layouts are supported. Default: if None, the strided layout.
    device (:class:`torch.device`, optional): the desired device of returned nested tensor.
        Default: if None, same :class:`torch.device` as leftmost tensor in the list.
    requires_grad (bool, optional): If autograd should record operations on the
        returned nested tensor. Default: ``False``.
    pin_memory (bool, optional): If set, the returned nested tensor will be allocated in
        pinned memory. Works only for CPU tensors. Default: ``False``.

Example::

    >>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
    >>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
    >>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
    >>> nt.is_leaf
    True
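    >>> # Hedged sketch: the same construction with layout=torch.jagged (assumes a
    >>> # build where the jagged layout is available); the result is still a leaf.
    >>> njt = torch.nested.nested_tensor([a, b], layout=torch.jagged, requires_grad=True)
    >>> njt.is_leaf
    True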
    """
    if layout is None:
        layout = torch.strided
    if layout == torch.strided:
        return _nested.nested_tensor(
            tensor_list,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
            pin_memory=pin_memory)
    elif layout == torch.jagged:
        # Need to wrap lists of scalars as tensors
        list_of_tensors = [t if isinstance(t, Tensor) else torch.as_tensor(t) for t in tensor_list]

        from torch.nested._internal.nested_tensor import jagged_from_list

        with torch.no_grad():
            nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)

        nt.requires_grad_(requires_grad)
        if pin_memory:
            nt = nt.pin_memory()  # type: ignore[assignment]

        return nt
    else:
        raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")


def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor:
    r"""
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
shows only the elements in the interval `[start, start+length)`. As nested representations
allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length`
can also be tensors of shape `tensor.shape[0]`.

There are some differences depending on the layout used for the nested tensor. With the strided layout,
narrow copies the narrowed data into a contiguous nested tensor, while with the jagged layout,
narrow() creates a non-contiguous view of the original strided tensor. This particular
representation is really useful for representing kv-caches in Transformer models, as specialized
SDPA kernels can deal with the format easily, resulting in performance improvements.

Args:
    tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data
        for the nested tensor if using the jagged layout or will be copied for the strided layout.
    dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the
        jagged layout, while strided supports all dims.
    start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation
    length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op

Keyword arguments:
    layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
        Only strided and jagged layouts are supported. Default: the strided layout.

Example::

    >>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
    >>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
    >>> narrow_base = torch.randn(5, 10, 20)
    >>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged)
    >>> nt_narrowed.is_contiguous()
    False
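    >>> # Hedged sketch: with the default strided layout, start and length must be
    >>> # plain ints and the narrowed data is copied into a new nested tensor.
    >>> nt_strided = torch.nested.narrow(narrow_base, 1, 2, 3)
    >>> nt_strided.is_nested
    True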
    """
    if not isinstance(start, (int, SymInt, Tensor)):
        raise RuntimeError("start must be an integer or a tensor")

    if not isinstance(length, (int, SymInt, Tensor)):
        raise RuntimeError("length must be an integer or a tensor")

    if layout == torch.strided:
        if isinstance(start, Tensor) or isinstance(length, Tensor):
            raise RuntimeError("start and length must be integers for the strided layout NT impl")
        # TODO: switch to as_nested_tensor(tensor) when it is available
        nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length)
    elif layout == torch.jagged:
        if dim != 1:
            raise RuntimeError("jagged layout only supports dim=1")

        from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths

        if isinstance(start, (int, SymInt)):
            start = torch.tensor([start], device=tensor.device, dtype=torch.int64)

        if isinstance(length, (int, SymInt)):
            length = torch.tensor([length], device=tensor.device, dtype=torch.int64)

        nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)
    else:
        raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}")

    return nt


def nested_tensor_from_jagged(
    values: Tensor,
    offsets: Optional[Tensor] = None,
    lengths: Optional[Tensor] = None,
    jagged_dim: Optional[int] = None,
    min_seqlen: Optional[int] = None,
    max_seqlen: Optional[int] = None,
) -> Tensor:
    r"""
Constructs a jagged layout nested tensor from the given jagged components. The jagged layout
consists of a required values buffer with the jagged dimension packed into a single dimension.
The offsets / lengths metadata determines how this dimension is split into batch elements
and is expected to be allocated on the same device as the values buffer.

Expected metadata formats:
    * offsets: Indices within the packed dimension splitting it into heterogeneously-sized
      batch elements. Example: [0, 2, 3, 6] indicates that a packed jagged dim of size 6
      should be conceptually split into batch elements of length [2, 1, 3]. Note that both the
      beginning and ending offsets are required for kernel convenience (i.e. shape batch_size + 1).
    * lengths: Lengths of the individual batch elements; shape == batch_size. Example: [2, 1, 3]
      indicates that a packed jagged dim of size 6 should be conceptually split into batch
      elements of length [2, 1, 3].

Note that it can be useful to provide both offsets and lengths. This describes a nested tensor
with "holes", where the offsets indicate the start position of each batch item and the lengths
specify the total number of elements (see example below).

The returned jagged layout nested tensor will be a view of the input values tensor.

Args:
    values (:class:`torch.Tensor`): The underlying buffer in the shape of
        (sum_B(*), D_1, ..., D_N). The jagged dimension is packed into a single dimension,
        with the offsets / lengths metadata used to distinguish batch elements.
    offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1.
    lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B.
    jagged_dim (optional int): Indicates which dimension in values is the packed jagged
        dimension. If None, this is set to dim=1 (i.e. the dimension immediately following
        the batch dimension). Default: None
    min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence
        length for the returned nested tensor. This can be a useful alternative to computing
        this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None
    max_seqlen (optional int): If set, uses the specified value as the cached maximum sequence
        length for the returned nested tensor. This can be a useful alternative to computing
        this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None

Example::

    >>> values = torch.randn(12, 5)
    >>> offsets = torch.tensor([0, 3, 5, 6, 10, 12])
    >>> nt = nested_tensor_from_jagged(values, offsets)
    >>> # 3D shape with the middle dimension jagged
    >>> nt.shape
    torch.Size([5, j2, 5])
    >>> # Length of each item in the batch:
    >>> offsets.diff()
    tensor([3, 2, 1, 4, 2])

    >>> values = torch.randn(6, 5)
    >>> offsets = torch.tensor([0, 2, 3, 6])
    >>> lengths = torch.tensor([1, 1, 2])
    >>> # NT with holes
    >>> nt = nested_tensor_from_jagged(values, offsets, lengths)
    >>> a, b, c = nt.unbind()
    >>> # Batch item 1 consists of indices [0, 1)
    >>> torch.equal(a, values[0:1, :])
    True
    >>> # Batch item 2 consists of indices [2, 3)
    >>> torch.equal(b, values[2:3, :])
    True
    >>> # Batch item 3 consists of indices [3, 5)
    >>> torch.equal(c, values[3:5, :])
    True
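
    >>> # Hedged sketch: passing lengths alone also works; they are converted to
    >>> # offsets internally (values / lengths here are illustrative).
    >>> values = torch.randn(6, 5)
    >>> lengths = torch.tensor([2, 1, 3])
    >>> nt = nested_tensor_from_jagged(values, lengths=lengths)
    >>> nt.offsets()
    tensor([0, 2, 3, 6])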
    """
    from torch.fx._symbolic_trace import is_fx_tracing
    if is_fx_tracing():
        raise RuntimeError(
            "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace. "
            "Use fx.wrap to wrap the function that calls nested_tensor_from_jagged."
        )

    if offsets is None:
        if lengths is None:
            raise RuntimeError(
                "nested_tensor_from_jagged(): At least one of offsets or lengths is required."
            )
        else:
            # TODO: Truly support offsets=None at some point?
            # For now, just convert lengths -> offsets for kernel convenience
            offsets = F.pad(lengths.cumsum(0), (1, 0))
            lengths = None

    if jagged_dim is None:
        jagged_dim = 1

    from torch.nested._internal.nested_tensor import nested_view_from_values_offsets_lengths

    return nested_view_from_values_offsets_lengths(
        values, offsets, lengths, ragged_idx=jagged_dim, min_seqlen=min_seqlen, max_seqlen=max_seqlen)


def masked_select(tensor: Tensor, mask: Tensor) -> Tensor:
    r"""
    Constructs a jagged layout nested tensor given a strided tensor input and a strided mask. The
    resulting nested tensor retains the values where the mask is True. Unlike
    :func:`torch.masked_select`, which collapses the output to a 1D tensor, here the
    dimensionality of the mask is preserved and represented with the offsets.

    Args:
        tensor (:class:`torch.Tensor`): a strided tensor from which the jagged layout nested
            tensor is constructed.
        mask (:class:`torch.Tensor`): a strided mask tensor which is applied to the tensor input

    Example::

        >>> tensor = torch.randn(3, 3)
        >>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]])
        >>> nt = torch.nested.masked_select(tensor, mask)
        >>> nt.shape
        torch.Size([3, j4])
        >>> # Length of each item in the batch:
        >>> nt.offsets().diff()
        tensor([1, 2, 1])

        >>> tensor = torch.randn(6, 5)
        >>> mask = torch.tensor([False])
        >>> nt = torch.nested.masked_select(tensor, mask)
        >>> nt.shape
        torch.Size([6, j5])
        >>> # Length of each item in the batch:
        >>> nt.offsets().diff()
        tensor([0, 0, 0, 0, 0, 0])
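
        >>> # Hedged sketch: the mask broadcasts against the input, so a 1D mask can
        >>> # keep the same columns of every row (values here are illustrative).
        >>> tensor = torch.randn(2, 4)
        >>> mask = torch.tensor([True, False, True, False])
        >>> torch.nested.masked_select(tensor, mask).offsets().diff()
        tensor([2, 2])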
    """
    if tensor.layout != torch.strided:
        raise RuntimeError(
            f"torch.nested.masked_select requires a strided tensor, given {tensor.layout}"
        )

    if mask.layout != torch.strided:
        raise RuntimeError(
            f"torch.nested.masked_select requires a strided mask, given: {mask.layout}"
        )
    # Flatten the selected values into the jagged values buffer, then recover
    # per-row lengths from the broadcasted mask to build the offsets metadata.
    res_values = tensor.masked_select(mask)
    expanded_mask = mask.expand(tensor.shape)
    res_lengths = expanded_mask.sum(dim=tensor.ndim - 1).view(-1)

    from torch.nested._internal.nested_tensor import (
        nested_view_from_values_offsets,
    )

    return nested_view_from_values_offsets(
        values=res_values,
        offsets=F.pad(res_lengths.cumsum(dim=0), (1, 0)),
    )