# mypy: allow-untyped-defs
from typing import *  # noqa: F403
from typing import Tuple

import torch
from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.utils.weak import WeakTensorKeyDictionary


_tensor_id_counter = 0
_tensor_symint_registry = WeakTensorKeyDictionary()


def get_tensor_symint(tensor, *, coeff=1):
    from torch._subclasses.fake_tensor import FakeTensor
    from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor

    # NB: Only FakeTensor is associated with a memo
    tensor = mb_unwrap_functional_tensor(tensor)
    if isinstance(tensor, FakeTensor):
        return tensor.get_nested_int(coeff=coeff)

    global _tensor_id_counter

    tensor_symint = _tensor_symint_registry.get(tensor)
    if tensor_symint is None:
        tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
        _tensor_id_counter += 1
        _tensor_symint_registry[tensor] = tensor_symint
    return tensor_symint

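# Illustrative sketch of the caching behavior above (values assumed for
# exposition, not executed at import time): for regular (non-fake) tensors,
# repeated calls with the same offsets tensor return the same cached nested
# int, while a distinct tensor gets a fresh one.
#
#   offsets_a = torch.tensor([0, 2, 5, 6])
#   offsets_b = torch.tensor([0, 3, 4])
#   assert get_tensor_symint(offsets_a) is get_tensor_symint(offsets_a)
#   assert get_tensor_symint(offsets_a) is not get_tensor_symint(offsets_b)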

# SDPA metadata; max / min seqlens are needed for e.g. flash
def _get_sdpa_extreme_seqlen(func, tensor):
    return int(func(tensor).item())


def _store_val_in_tensor(val) -> torch.Tensor:
    # hack to get dynamic shapes support: store in a (val, 0) shaped tensor
    return torch.zeros(val, 0)


def _load_val_from_tensor(t: torch.Tensor):
    return t.shape[0]

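# Round-trip sketch (value assumed, shown for clarity only): _store_val_in_tensor(5)
# returns a tensor of shape (5, 0) that holds no data, and _load_val_from_tensor
# reads the value back out of the shape:
#
#   t = _store_val_in_tensor(5)   # t.shape == (5, 0), t.numel() == 0
#   assert _load_val_from_tensor(t) == 5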

class NestedTensor(torch.Tensor):
    _values: torch.Tensor  # type: ignore[assignment]
    _offsets: torch.Tensor
    _lengths: Optional[torch.Tensor]
    # NOTE [ Nested ints for ragged sizes and strides ]
    #
    # Jagged layout tensors represent an n-dim tensor with a ragged dimension,
    # but are backed by an (n-1)-dim tensor underneath, e.g., a jagged tensor
    # with outer shape [B, x, D] is represented internally by a tensor with
    # shape [sum(x), D], where "x" is what we call a nested int (sometimes
    # denoted "*") representing the ragged dimension, and sum(x) is the size of
    # the corresponding dim of the inner tensor, or equivalently the sum of the
    # constituent tensors' varying lengths.
    #
    # We also use nested ints to represent the strides of this tensor.
    # For example, a jagged tensor with shape [B, x, D] can be strided in two
    # ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
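    #
    # Concrete sketch (values assumed for exposition): three sequences of
    # lengths 2, 3, and 1 with D=4 give an outer shape [3, x, 4], where x is a
    # nested int summarizing the ragged lengths (2, 3, 1). The backing values
    # tensor has shape [6, 4] (6 == sum(x)) and offsets is [0, 2, 5, 6], so
    # component i of the batch spans values[offsets[i]:offsets[i+1]].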
    _size: Tuple[int, ...]
    _strides: Tuple[int, ...]
    # Indicates that the nth dimension is ragged
    _ragged_idx: int
    _metadata_cache: Dict[str, Any]

    @staticmethod
    def __new__(
        cls,
        values,
        offsets,
        *,
        lengths=None,
        **kwargs,
    ):
        ks = DispatchKeySet(DispatchKey.NestedTensor)
        ks = ks.add(DispatchKey.AutogradNestedTensor)

        # Only support jagged for now.
        assert offsets is not None
        assert offsets.ndim == 1
        assert not isinstance(values, NestedTensor)
        assert values.device == offsets.device

        # Query cache for the symint associated with offsets or lengths
        # (create a new one if needed).
        ragged_source = offsets if lengths is None else lengths
        ragged_size = get_tensor_symint(ragged_source, coeff=1)
        _ragged_idx = kwargs.get("_ragged_idx", 1)
        B = offsets.shape[0] - 1
        if lengths is not None:
            assert B == lengths.shape[0]

        # subtract 1 to convert to values dim space
        r = _ragged_idx - 1
        _size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :])
        stride = values.stride()
        _strides = (ragged_size * stride[r], *stride)
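        # Worked sketch (assuming the common contiguous case): for values of
        # shape [sum(x), D] with stride (D, 1) and _ragged_idx == 1, this gives
        # _size == (B, x, D) and _strides == (x*D, D, 1).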

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            _size,
            _strides,
            0,
            torch.contiguous_format,
            values.dtype,
            torch.jagged,
            values.device,
            False,
            kwargs.get("requires_grad", False),
            "sizes",
            False,
            True,  # dispatch_layout
            ks,
            # don't try to calculate storage based on non-zero size
            storage_size=values.untyped_storage().size(),
        )
        r._ragged_idx = _ragged_idx
        r._size = _size
        r._strides = _strides

        return r

    def __init__(self, values, offsets, *, lengths=None, **kwargs):
        super().__init__()

        self._values = values
        self._offsets = offsets
        self._lengths = lengths

        # holds properties that are computed lazily
        self._metadata_cache = kwargs.get("_metadata_cache") or {}

        # collapsed ragged dim must always be dynamic
        torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx)
        torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1)

        # min / max sequence length should be dynamic if present
        max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None)
        if max_seqlen_tensor is not None:
            torch._dynamo.mark_dynamic(max_seqlen_tensor, 0)
        min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None)
        if min_seqlen_tensor is not None:
            torch._dynamo.mark_dynamic(min_seqlen_tensor, 0)

    def values(self):
        # dispatch to get proper view relationship
        return torch._nested_get_values(self)  # type: ignore[attr-defined]

    def offsets(self):
        return self._offsets

    def lengths(self):
        return self._lengths

    # Private accessor functions for min / max sequence length. They're
    # purposefully not @properties because those don't work with PT2 (yet).
    # These compute / cache if not present.
    # TODO: Revisit this when @properties are better supported by PT2. I think the ideal
    # state would be to have public @properties for min / max sequence length that compile
    # (including setters).
    def _get_max_seqlen(self):
        max_seqlen_tensor = self._max_seqlen_tensor
        if max_seqlen_tensor is None:
            # compute & cache
            max_val = _get_sdpa_extreme_seqlen(
                torch.max,
                self._offsets.diff() if self._lengths is None else self._lengths,
            )
            max_seqlen_tensor = _store_val_in_tensor(max_val)
            self._metadata_cache["max_seqlen"] = max_seqlen_tensor
        return _load_val_from_tensor(max_seqlen_tensor)

    def _get_min_seqlen(self):
        min_seqlen_tensor = self._min_seqlen_tensor
        if min_seqlen_tensor is None:
            # compute & cache
            min_val = _get_sdpa_extreme_seqlen(
                torch.min,
                self._offsets.diff() if self._lengths is None else self._lengths,
            )
            min_seqlen_tensor = _store_val_in_tensor(min_val)
            self._metadata_cache["min_seqlen"] = min_seqlen_tensor
        return _load_val_from_tensor(min_seqlen_tensor)

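    # Worked sketch (values assumed): for a contiguous NT with offsets
    # [0, 2, 5, 6], offsets.diff() is [2, 3, 1], so _get_max_seqlen() returns 3
    # and _get_min_seqlen() returns 1; the results are cached in the metadata
    # cache as tensors of shape (3, 0) and (1, 0) respectively.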
    # Private accessors used for treating min / max seqlen as inner tensors for
    # flatten / unflatten. These must be properties to work with the traceable wrapper
    # subclass logic. These do not compute / cache if not present.
    @property
    def _max_seqlen_tensor(self) -> Optional[torch.Tensor]:
        return self._metadata_cache.get("max_seqlen", None)

    @property
    def _min_seqlen_tensor(self) -> Optional[torch.Tensor]:
        return self._metadata_cache.get("min_seqlen", None)

    # These are old private @property accessors that are kept around for internal BC
    # reasons. TODO: Remove these!
    @property
    def _max_seqlen(self):
        return self._get_max_seqlen()

    @property
    def _min_seqlen(self):
        return self._get_min_seqlen()

    def __repr__(self):
        # We should implement this in torch/_tensor_str.py instead
        grad_fn_str = (
            f", requires_grad={self.requires_grad}" if self.requires_grad else ""
        )
        if self.grad_fn:
            grad_fn_str = f", grad_fn={self.grad_fn}"
        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"

    def __reduce_ex__(self, proto):
        state = torch._utils._get_obj_state(self)

        # SymNodes are not serializable
        assert "_size" in state and "_strides" in state
        state = dict(state)
        del state["_size"]
        del state["_strides"]

        # TODO: Update this to handle the other inner tensors
        func = NestedTensor
        args = (self._values, self._offsets)
        return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))

    def __tensor_flatten__(self):
        ctx = {
            "requires_grad": self.requires_grad,
            "ragged_idx": self._ragged_idx,
        }
        inner_tensors = ["_values", "_offsets"]
        if self._lengths is not None:
            inner_tensors.append("_lengths")
        if self._min_seqlen_tensor is not None:
            inner_tensors.append("_min_seqlen_tensor")
        if self._max_seqlen_tensor is not None:
            inner_tensors.append("_max_seqlen_tensor")
        return inner_tensors, ctx

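    # Example of the flatten contract (a sketch; exact contents depend on which
    # optional inner tensors are present): a contiguous NT with no cached seqlens
    # flattens to (["_values", "_offsets"], {"requires_grad": False, "ragged_idx": 1}).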
    @staticmethod
    def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
        from torch._subclasses.fake_tensor import FakeTensor

        # inner tensors: _values, _offsets, [_lengths], [_min_seqlen_tensor], [_max_seqlen_tensor]
        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5
        values = inner_tensors["_values"]
        offsets = inner_tensors["_offsets"]
        lengths = inner_tensors.get("_lengths", None)
        min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None)
        max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None)

        metadata_cache = {}
        if min_seqlen_tensor is not None:
            metadata_cache["min_seqlen"] = min_seqlen_tensor
        if max_seqlen_tensor is not None:
            metadata_cache["max_seqlen"] = max_seqlen_tensor
        ragged_idx = meta["ragged_idx"]

        # Alternatively, we could make it the caller's responsibility to
        # cache it. But this heuristic seems simple enough.
        ragged_source = offsets if lengths is None else lengths
        if isinstance(ragged_source, FakeTensor):
            ragged_size = outer_size[ragged_idx]
            ragged_source.nested_int_memo = ragged_size

        return NestedTensor(
            values,
            offsets=offsets,
            lengths=lengths,
            requires_grad=meta["requires_grad"],
            _ragged_idx=ragged_idx,
            _metadata_cache=metadata_cache,
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        # Lazy import to avoid circular dependency
        from .ops import lookup_jagged

        fn = lookup_jagged(func, *args, **kwargs)
        if fn is not None:
            return fn(*args, **kwargs)

        raise NotImplementedError(func)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify

        from .ops import jagged_torch_function

        # This should be removed after
        # https://github.com/pytorch/pytorch/pull/125941/ lands
        with maybe_enable_thunkify():
            try:
                return jagged_torch_function(func, *args, **kwargs)
            except NotImplementedError:
                pass
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)


# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!
# TODO: Remove ViewBufferFromNested, ViewNestedFromBuffer, and buffer_from_jagged once the
# internal BC period has passed.


# Not actually a view!
class ViewBufferFromNested(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: NestedTensor):  # type: ignore[override]
        ctx.save_for_backward(x.offsets())
        ctx.metadata_cache = x._metadata_cache
        ctx.ragged_idx = x._ragged_idx
        return x._values

    @staticmethod
    def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
        (offsets,) = ctx.saved_tensors
        return NestedTensor(
            gO,
            offsets=offsets,
            _metadata_cache=ctx.metadata_cache,
            _ragged_idx=ctx.ragged_idx,
        )


# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        values: torch.Tensor,
        offsets: torch.Tensor,
        metadata_cache: Optional[Dict[str, Any]] = None,
    ):  # type: ignore[override]
        # maintain BC with old usages of this where the seqlens are stuffed
        # directly into the metadata cache as non-Tensors / ints
        if metadata_cache is not None:
            min_seqlen = metadata_cache.get("min_seqlen", None)
            max_seqlen = metadata_cache.get("max_seqlen", None)
            if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor):
                metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen)
            if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor):
                metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen)
        return NestedTensor(
            values.detach(),
            offsets=offsets,
            _metadata_cache=metadata_cache,
        )

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        return gO._values, None, None


def buffer_from_jagged(jagged):
    return ViewBufferFromNested.apply(jagged)


# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
    tensors: List[torch.Tensor],
    offsets: Optional[torch.Tensor],
    dtype=None,
    device=None,
) -> Tuple[NestedTensor, torch.Tensor]:
    """Constructs a NestedTensor backed by jagged layout from a list of tensors"""

    if not len(set(t.dtype for t in tensors)) == 1:  # noqa: C401
        raise RuntimeError(
            "When constructing a nested tensor, all tensors in list must have the same dtype"
        )
    if not len(set(t.device for t in tensors)) == 1:  # noqa: C401
        raise RuntimeError(
            "When constructing a nested tensor, all tensors in list must be on the same device"
        )

    # Check that the NT is representable by the jagged layout.
    # Jagged layout represents (B, *, D_0, D_1, ..., D_N), where the only
    # raggedness allowed is for the single dim immediately adjacent to the batch dim.
    sizes = [t.shape for t in tensors]
    non_first_sizes = [s[1:] for s in sizes]
    at_most_first_ragged = all(s == non_first_sizes[0] for s in non_first_sizes)
    if not at_most_first_ragged:
        raise RuntimeError(
            "Cannot represent given tensor list as a nested tensor with the jagged layout. "
            "Note that the jagged layout only represents shapes of the form "
            "(B, *, D_0, D_1, ..., D_N), with only * allowed to be ragged."
        )

    # Set properties appropriately.
    values = torch.cat(tensors, dim=0)
    to_kwargs = {}
    if device is not None:
        to_kwargs["device"] = device
    if dtype is not None:
        to_kwargs["dtype"] = dtype
    values = values.to(**to_kwargs)

    # Calculate jagged offsets if not provided.
    if offsets is None:
        # Jagged layout specifies that offsets are stored as int64 on the same device as values.
        # TODO: An alternative way to construct offsets is to use F.pad. This avoids creating
        # an extra leaf tensor during the forward, potentially resolving compatibility issues.
        offsets = torch.cat(
            [
                torch.zeros(1, dtype=torch.int64, device=values.device),
                torch.tensor([s[0] for s in sizes], device=values.device).cumsum(dim=0),
            ]
        )
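
    # Worked sketch (hypothetical inputs): for tensors of shapes (2, 4), (3, 4),
    # and (1, 4), values has shape (6, 4) and the computed offsets are
    # tensor([0, 2, 5, 6]); component i of the batch occupies
    # values[offsets[i]:offsets[i+1]].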

    # compute this now since it's easy
    min_seqlen = min(t.shape[0] for t in tensors)
    max_seqlen = max(t.shape[0] for t in tensors)
    ret_nt = nested_view_from_values_offsets(
        values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
    )
    return (ret_nt, offsets)  # type: ignore[return-value]


def jagged_from_tensor_and_lengths(
    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
    batch_size = tensor.shape[0]
    if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
        lengths.shape, (batch_size,)
    ):
        start_list = starts.expand(batch_size)
        length_list = lengths.expand(batch_size)
    else:
        raise RuntimeError(
            "When constructing a jagged nested tensor using narrow(), "
            "your start and length must be Tensors that broadcast to input.shape[0]"
        )

    # Calculate jagged offsets
    assert (
        len(tensor.shape) >= 2
    ), "tensor must be at least 2D for the nested narrow op to work"
    max_seq_len = tensor.shape[1]
    offset_lengths = max_seq_len * torch.arange(
        0, batch_size, dtype=torch.int64, device=tensor.device
    )
    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
    offsets = torch.cat(
        [
            start_list + offset_lengths,
            (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
        ]
    )
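    # Worked sketch (hypothetical inputs): for a tensor of shape (2, 5, D) with
    # starts = [1, 0] and lengths = [3, 4], offset_lengths is [0, 5], so
    # offsets == tensor([1, 5, 9]); the two jagged components are rows 1:4 and
    # 5:9 of the flattened (10, D) values buffer.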

    # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
    if len(tensor.shape) > 2:
        values = tensor.view(-1, *tensor.shape[2:])
    else:
        values = tensor.view(-1)

    # Check if offsets and lengths make it possibly contiguous and return a regular NT
    is_contiguous = True
    orig_dim = tensor.shape[1]
    if torch.any(length_list[1:-1].ne(orig_dim)):
        is_contiguous = False
    if torch.any(offsets[1:-2].diff().ne(orig_dim)):
        is_contiguous = False
    if offsets[0] + length_list[0] != orig_dim:
        is_contiguous = False
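    # e.g. (hedged illustration): with starts = [0, 0] and lengths = [5, 5] on a
    # (2, 5, D) tensor, offsets is [0, 5, 10] and every check above passes, so
    # the result can be returned as a contiguous (offsets-only) nested tensor.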

    actual_max_seqlen = int(torch.max(lengths).item())
    min_seqlen = int(torch.min(lengths).item())

    if is_contiguous:
        ret_nt = nested_view_from_values_offsets(
            values[offsets[0] : offsets[-1]],
            offsets - offsets[0],
            min_seqlen=min_seqlen,
            max_seqlen=actual_max_seqlen,
        )
    else:
        ret_nt = nested_view_from_values_offsets_lengths(
            values,
            offsets,
            length_list,
            min_seqlen=min_seqlen,
            max_seqlen=actual_max_seqlen,
        )

    return (ret_nt, offsets, None if is_contiguous else length_list)


# NB: A dummy arg is required so that NestedTensor.__torch_dispatch__() is invoked
# for _nested_view_from_values_offsets(). Sizes don't matter much, but they shouldn't be
# 0/1 because the dummy can be fake-ified and we want to avoid specializing.
# This arg is otherwise unused.
_dummy_instance: Optional[torch.Tensor] = None


def _nt_view_dummy() -> torch.Tensor:
    global _dummy_instance
    if _dummy_instance is None:
        _dummy_instance = NestedTensor(
            values=torch.zeros(3, 3, device="meta"),
            offsets=torch.zeros(3, device="meta", dtype=torch.int64),
        ).detach()
    return _dummy_instance


def nested_view_from_values_offsets(
    values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None
):
    min_seqlen_tensor = None
    if min_seqlen is not None:
        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)

    max_seqlen_tensor = None
    if max_seqlen is not None:
        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)

    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
        values,
        offsets,
        _nt_view_dummy(),
        None,
        ragged_idx,
        min_seqlen_tensor,
        max_seqlen_tensor,
    )  # type: ignore[return-value]

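# Illustrative usage sketch (tensors and shapes assumed; not executed at import
# time): with values = torch.randn(6, 4) and offsets = torch.tensor([0, 2, 5, 6]),
# nt = nested_view_from_values_offsets(values, offsets) produces a jagged
# NestedTensor with batch size 3 that is a view of `values`, so autograd sees
# the proper view relationship (unlike the deprecated ViewNestedFromBuffer path).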

def nested_view_from_values_offsets_lengths(
    values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None
):
    min_seqlen_tensor = None
    if min_seqlen is not None:
        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)

    max_seqlen_tensor = None
    if max_seqlen is not None:
        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)

    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
        values,
        offsets,
        _nt_view_dummy(),
        lengths,
        ragged_idx,
        min_seqlen_tensor,
        max_seqlen_tensor,
    )  # type: ignore[return-value]
565