# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import SymInt, Tensor
from torch._C import _add_docstr, _nested  # type: ignore[attr-defined]

from torch.types import _device as Device, _dtype as DType

__all__ = [
    "to_padded_tensor",
    "as_nested_tensor",
    "nested_tensor",
    "nested_tensor_from_jagged",
    "narrow",
    "masked_select",
]

# Nested Tensor constructor functions


def as_nested_tensor(
    ts: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
    dtype: Optional[DType] = None,
    device: Optional[Device] = None,
    layout=None
) -> Tensor:
    r"""
    Constructs a nested tensor preserving autograd history from a tensor or a list / tuple of
    tensors.

    If a nested tensor is passed, it will be returned directly unless the device / dtype / layout
    differ. Note that converting device / dtype will result in a copy, while converting layout
    is not currently supported by this function.

    If a non-nested tensor is passed, it is treated as a batch of constituents of consistent size.
    A copy will be incurred if the passed device / dtype differ from those of the input OR if
    the input is non-contiguous. Otherwise, the input's storage will be used directly.

    If a tensor list is provided, tensors in the list are always copied during construction of
    the nested tensor.

    Args:
        ts (Tensor or List[Tensor] or Tuple[Tensor]): a tensor to treat as a nested tensor OR a
            list / tuple of tensors with the same ndim.

    Keyword arguments:
        dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
            Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
        device (:class:`torch.device`, optional): the desired device of returned nested tensor.
            Default: if None, same :class:`torch.device` as leftmost tensor in the list.
        layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
            Only strided and jagged layouts are supported. Default: if None, the strided layout.

    Example::

        >>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
        >>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
        >>> nt = torch.nested.as_nested_tensor([a, b])
        >>> nt.is_leaf
        False
        >>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)])
        >>> nt.backward(fake_grad)
        >>> a.grad
        tensor([1., 1., 1.])
        >>> b.grad
        tensor([0., 0., 0., 0., 0.])
        >>> c = torch.randn(3, 5, requires_grad=True)
        >>> nt2 = torch.nested.as_nested_tensor(c)
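        >>> # A jagged layout nested tensor can also be constructed from a non-nested
        >>> # tensor; per the notes above, this reuses c's storage since c is contiguous
        >>> # (a minimal illustrative sketch):
        >>> nt3 = torch.nested.as_nested_tensor(c, layout=torch.jagged)
        >>> nt3.layout
        torch.jagged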
    """
    is_tensor_list = isinstance(ts, (list, tuple)) and all(isinstance(t, Tensor) for t in ts)
    if not isinstance(ts, Tensor) and not is_tensor_list:
        raise TypeError(
            "as_nested_tensor(): Expected first argument to be a tensor or a list / tuple of tensors"
        )
    # convert tuple -> list if needed
    if is_tensor_list and not isinstance(ts, list):
        ts = list(ts)

    if isinstance(ts, Tensor) and ts.dim() < 2:
        raise RuntimeError("as_nested_tensor(): Expected tensor argument to have dim() > 1")

    if isinstance(ts, Tensor) and ts.is_nested:
        if layout == ts.layout:
            # return input directly or input copied to device / dtype
            return ts.to(device=device, dtype=dtype)
        else:
            # TODO: Just use nt.to(layout=layout) when it exists.
            raise RuntimeError(
                "as_nested_tensor(): Converting between nested tensor layouts is not supported")

    if layout is None:
        layout = torch.strided
    if layout == torch.strided:
        if isinstance(ts, Tensor):
            # contiguous() might be necessary to get flattened view.
            # we could probably be more precise about when to do this as an optimization
            buffer = ts.contiguous().view(-1).to(device=device, dtype=dtype)
            nested_sizes = torch.tensor([t.shape for t in ts])
            return torch._nested_view_from_buffer(
                buffer,
                nested_sizes,
                *torch._nested_compute_contiguous_strides_offsets(nested_sizes))
        else:
            assert isinstance(ts, list)
            return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None)
    elif layout == torch.jagged:
        if isinstance(ts, Tensor):
            if device is None:
                device = ts.device

            # contiguous() might be necessary to get flattened view.
            # we could probably be more precise about when to do this as an optimization
            values = ts.contiguous().flatten(0, 1).to(device=device, dtype=dtype)
            batch_size = ts.shape[0]
            seq_len = ts.shape[1]
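            # constituents all share seq_len here, so offsets are evenly spaced:
            # [0, seq_len, 2 * seq_len, ..., batch_size * seq_len]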
            offsets = torch.arange(0, batch_size * seq_len + 1, seq_len,
                                   device=device, dtype=torch.int64)

            from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

            return nested_view_from_values_offsets(
                values, offsets, min_seqlen=seq_len, max_seqlen=seq_len
            )
        else:
            from torch.nested._internal.nested_tensor import jagged_from_list

            assert isinstance(ts, list)
            nt, _ = jagged_from_list(ts, offsets=None, device=device, dtype=dtype)
            return nt
    else:
        raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")


# Note: This not only adds doc strings for the nested ops, but
# also connects the torch.nested Python namespace to the torch._C._nested builtins.

to_padded_tensor = _add_docstr(
    _nested.nested_to_padded_tensor,
    r"""
to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor

Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor.
The leading entries will be filled with the nested data,
while the trailing entries will be padded.

.. warning::

    :func:`to_padded_tensor` always copies the underlying data,
    since the nested and the non-nested tensors differ in memory layout.

Args:
    padding (float): The padding value for the trailing entries.

Keyword args:
    output_size (Tuple[int]): The size of the output tensor.
                              If given, it must be large enough to contain all nested data;
                              otherwise, it is inferred by taking the max size of each nested sub-tensor along each dimension.
    out (Tensor, optional): the output tensor.

Example::

    >>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))])
    >>> nt
    nested_tensor([
      tensor([[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
              [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995]]),
      tensor([[-1.8546, -0.7194, -0.2918, -0.1846],
              [ 0.2773,  0.8793, -0.5183, -0.6447],
              [ 1.8009,  1.8468, -0.9832, -1.5272]])
    ])
    >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0)
    >>> pt_infer
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  0.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  0.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  0.0000]]])
    >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6))
    >>> pt_large
    tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276,  1.0000],
             [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]],
            [[-1.8546, -0.7194, -0.2918, -0.1846,  1.0000,  1.0000],
             [ 0.2773,  0.8793, -0.5183, -0.6447,  1.0000,  1.0000],
             [ 1.8009,  1.8468, -0.9832, -1.5272,  1.0000,  1.0000],
             [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]]])
    >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
    RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.

""",
)


def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor:
    r"""
Constructs a nested tensor with no autograd history (also known as a "leaf tensor", see
:ref:`Autograd mechanics <autograd-mechanics>`) from :attr:`tensor_list`, a list of tensors.

Args:
    tensor_list (List[array_like]): a list of tensors, or anything that can be passed to torch.tensor,
        where each element of the list has the same dimensionality.

Keyword arguments:
    dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
        Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
    layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
        Only strided and jagged layouts are supported. Default: if None, the strided layout.
    device (:class:`torch.device`, optional): the desired device of returned nested tensor.
        Default: if None, same :class:`torch.device` as leftmost tensor in the list.
    requires_grad (bool, optional): If autograd should record operations on the
        returned nested tensor. Default: ``False``.
    pin_memory (bool, optional): If set, the returned nested tensor will be allocated in
        pinned memory. Works only for CPU tensors. Default: ``False``.

Example::

    >>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
    >>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
    >>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
    >>> nt.is_leaf
    True
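    >>> # The jagged layout is also supported through this constructor; a brief
    >>> # sketch using the layout kwarg documented above:
    >>> nt_jagged = torch.nested.nested_tensor([a, b], layout=torch.jagged, requires_grad=True)
    >>> nt_jagged.is_leaf
    True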
    """
    if layout is None:
        layout = torch.strided
    if layout == torch.strided:
        return _nested.nested_tensor(
            tensor_list,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
            pin_memory=pin_memory)
    elif layout == torch.jagged:
        # Need to wrap lists of scalars as tensors
        list_of_tensors = [t if isinstance(t, Tensor) else torch.as_tensor(t) for t in tensor_list]

        from torch.nested._internal.nested_tensor import jagged_from_list

        with torch.no_grad():
            nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)

        nt.requires_grad_(requires_grad)
        if pin_memory:
            nt = nt.pin_memory()  # type: ignore[assignment]

        return nt
    else:
        raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")


def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor:
    r"""
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
shows only the elements in the interval `[start, start+length)`. As nested representations
allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length`
can also be tensors of shape `tensor.shape[0]`.

There are some differences depending on the layout you use for the nested tensor. With the strided
layout, torch.narrow copies the narrowed data into a contiguous NT with strided layout, while with
the jagged layout, narrow() creates a non-contiguous view of your original strided tensor. This particular
representation is really useful for representing kv-caches in Transformer models, as specialized
SDPA kernels can deal with this format easily, resulting in performance improvements.

Args:
    tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data
        for the nested tensor if using the jagged layout or will be copied for the strided layout.
    dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the
        jagged layout, while the strided layout supports all dims.
    start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation.
    length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op.

Keyword arguments:
    layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
        Only strided and jagged layouts are supported. Default: the strided layout.

Example::

    >>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
    >>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
    >>> narrow_base = torch.randn(5, 10, 20)
    >>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged)
    >>> nt_narrowed.is_contiguous()
    False
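    >>> # With the strided layout, the narrowed data is instead copied into a
    >>> # contiguous nested tensor; start / length must then be plain integers
    >>> # (a brief sketch of the behavior described above):
    >>> nt_copied = torch.nested.narrow(narrow_base, 1, 2, 3, layout=torch.strided)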
    """
    if not isinstance(start, (int, SymInt, Tensor)):
        raise RuntimeError("start must be an integer or a tensor")

    if not isinstance(length, (int, SymInt, Tensor)):
        raise RuntimeError("length must be an integer or a tensor")

    if layout == torch.strided:
        if isinstance(start, Tensor) or isinstance(length, Tensor):
            raise RuntimeError("start and length must be integers for the strided layout NT impl")
        # TODO: switch to as_nested_tensor(tensor) when it is available
        nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length)
    elif layout == torch.jagged:
        if dim != 1:
            raise RuntimeError("jagged layout only supports dim=1")

        from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths

        if isinstance(start, (int, SymInt)):
            start = torch.tensor([start], device=tensor.device, dtype=torch.int64)

        if isinstance(length, (int, SymInt)):
            length = torch.tensor([length], device=tensor.device, dtype=torch.int64)

        nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)
    else:
        raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}")

    return nt


def nested_tensor_from_jagged(
    values: Tensor,
    offsets: Optional[Tensor] = None,
    lengths: Optional[Tensor] = None,
    jagged_dim: Optional[int] = None,
    min_seqlen: Optional[int] = None,
    max_seqlen: Optional[int] = None,
) -> Tensor:
    r"""
Constructs a jagged layout nested tensor from the given jagged components. The jagged layout
consists of a required values buffer with the jagged dimension packed into a single dimension.
The offsets / lengths metadata determines how this dimension is split into batch elements
and is expected to be allocated on the same device as the values buffer.

Expected metadata formats:
    * offsets: Indices within the packed dimension splitting it into heterogeneously-sized
      batch elements. Example: [0, 2, 3, 6] indicates that a packed jagged dim of size 6
      should be conceptually split into batch elements of length [2, 1, 3]. Note that both the
      beginning and ending offsets are required for kernel convenience (i.e. shape batch_size + 1).
    * lengths: Lengths of the individual batch elements; shape == batch_size. Example: [2, 1, 3]
      indicates that a packed jagged dim of size 6 should be conceptually split into batch
      elements of length [2, 1, 3].

Note that it can be useful to provide both offsets and lengths. This describes a nested tensor
with "holes", where the offsets indicate the start position of each batch item and the lengths
specify the number of elements within each batch item (see example below).

The returned jagged layout nested tensor will be a view of the input values tensor.

Args:
    values (:class:`torch.Tensor`): The underlying buffer in the shape of
        (sum_B(*), D_1, ..., D_N). The jagged dimension is packed into a single dimension,
        with the offsets / lengths metadata used to distinguish batch elements.
    offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1.
    lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B.
    jagged_dim (optional int): Indicates which dimension in values is the packed jagged
        dimension. If None, this is set to dim=1 (i.e. the dimension immediately following
        the batch dimension). Default: None
    min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence
        length for the returned nested tensor. This can be a useful alternative to computing
        this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None
    max_seqlen (optional int): If set, uses the specified value as the cached maximum sequence
        length for the returned nested tensor. This can be a useful alternative to computing
        this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None

Example::

    >>> values = torch.randn(12, 5)
    >>> offsets = torch.tensor([0, 3, 5, 6, 10, 12])
    >>> nt = nested_tensor_from_jagged(values, offsets)
    >>> # 3D shape with the middle dimension jagged
    >>> nt.shape
    torch.Size([5, j2, 5])
    >>> # Length of each item in the batch:
    >>> offsets.diff()
    tensor([3, 2, 1, 4, 2])

    >>> values = torch.randn(6, 5)
    >>> offsets = torch.tensor([0, 2, 3, 6])
    >>> lengths = torch.tensor([1, 1, 2])
    >>> # NT with holes
    >>> nt = nested_tensor_from_jagged(values, offsets, lengths)
    >>> a, b, c = nt.unbind()
    >>> # Batch item 1 consists of indices [0, 1)
    >>> torch.equal(a, values[0:1, :])
    True
    >>> # Batch item 2 consists of indices [2, 3)
    >>> torch.equal(b, values[2:3, :])
    True
    >>> # Batch item 3 consists of indices [3, 5)
    >>> torch.equal(c, values[3:5, :])
    True
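    >>> # Known min / max sequence lengths can be cached at construction to avoid
    >>> # computing them on-demand later; the values here match offsets.diff() above:
    >>> nt = nested_tensor_from_jagged(values, offsets, min_seqlen=1, max_seqlen=3)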
    """
    from torch.fx._symbolic_trace import is_fx_tracing
    if is_fx_tracing():
        raise RuntimeError(
            "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace. "
            "Use fx.wrap to wrap the function that calls nested_tensor_from_jagged."
        )

    if offsets is None:
        if lengths is None:
            raise RuntimeError(
                "nested_tensor_from_jagged(): At least one of offsets or lengths is required."
            )
        else:
            # TODO: Truly support offsets=None at some point?
            # For now, just convert lengths -> offsets for kernel convenience
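            # e.g. lengths [2, 1, 3] -> cumsum [2, 3, 6] -> padded offsets [0, 2, 3, 6]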
            offsets = F.pad(lengths.cumsum(0), (1, 0))
            lengths = None

    if jagged_dim is None:
        jagged_dim = 1

    from torch.nested._internal.nested_tensor import nested_view_from_values_offsets_lengths

    return nested_view_from_values_offsets_lengths(
        values, offsets, lengths, ragged_idx=jagged_dim, min_seqlen=min_seqlen, max_seqlen=max_seqlen)


def masked_select(tensor: Tensor, mask: Tensor) -> Tensor:
    r"""
    Constructs a nested tensor given a strided tensor input and a strided mask; the resulting
    jagged layout nested tensor retains the values where the mask is True. The dimensionality
    of the mask is preserved and is represented with the offsets; this is unlike
    :func:`masked_select`, where the output is collapsed to a 1D tensor.

    Args:
        tensor (:class:`torch.Tensor`): a strided tensor from which the jagged layout nested
            tensor is constructed.
        mask (:class:`torch.Tensor`): a strided mask tensor which is applied to the tensor input.

    Example::

        >>> tensor = torch.randn(3, 3)
        >>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]])
        >>> nt = torch.nested.masked_select(tensor, mask)
        >>> nt.shape
        torch.Size([3, j4])
        >>> # Length of each item in the batch:
        >>> nt.offsets().diff()
        tensor([1, 2, 1])

        >>> tensor = torch.randn(6, 5)
        >>> mask = torch.tensor([False])
        >>> nt = torch.nested.masked_select(tensor, mask)
        >>> nt.shape
        torch.Size([6, j5])
        >>> # Length of each item in the batch:
        >>> nt.offsets().diff()
        tensor([0, 0, 0, 0, 0, 0])
    """
    if tensor.layout != torch.strided:
        raise RuntimeError(
            f"torch.nested.masked_select requires a strided tensor, given {tensor.layout}"
        )

    if mask.layout != torch.strided:
        raise RuntimeError(
            f"torch.nested.masked_select requires a strided mask, given: {mask.layout}"
        )
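    # masked_select flattens the kept values into a 1D buffer; summing the
    # broadcasted mask over the last dim yields per-row True counts, which become
    # the jagged lengths (e.g. row sums [1, 2, 1] -> offsets [0, 1, 3, 4])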
    res_values = tensor.masked_select(mask)
    expanded_mask = mask.expand(tensor.shape)
    res_lengths = expanded_mask.sum(dim=tensor.ndim - 1).view(-1)

    from torch.nested._internal.nested_tensor import (
        nested_view_from_values_offsets,
    )

    return nested_view_from_values_offsets(
        values=res_values,
        offsets=F.pad(res_lengths.cumsum(dim=0), (1, 0)),
    )