xref: /aosp_15_r20/external/pytorch/torch/_inductor/jagged_lowerings.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3from typing import List, Optional, Tuple, Union
4
5import sympy
6
7import torch
8
9from .ir import Pointwise, TensorBox
10from .lowering import fallback_handler, is_integer_type, register_lowering
11from .virtualized import ops
12
13
14# pyre-ignore[2,3]
# pyre-ignore[2,3]
def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len):
    """Map a dense (batch, seq) position to an index into the jagged values.

    Returns the jagged index plus the exclusive end offset of the batch's
    sequence so callers can mask positions past the sequence end.
    """
    # Upper bound is jagged_len + 1 (not jagged_len) because the trailing
    # sequence may be empty, in which case offsets[batch] == jagged_len.
    sequence_start = ops.indirect_indexing(
        offsets_loader([batch_idx]),
        jagged_len + 1,
    )
    sequence_end = offsets_loader([batch_idx + 1])
    return sequence_start + seq_idx, sequence_end
25
26
def get_inverse_offsets(
    offsets: TensorBox,
    jagged_len: Union[int, sympy.Expr],
    realize: bool = True,
) -> TensorBox:
    """
    Compute "inverse_offsets": the inverse map of the offsets array.
    offsets maps batch index (dense) to jagged index (offset into the
    jagged tensor); inverse_offsets maps jagged index back to batch index.

    e.g. for offsets [0, 3, 4, 9, 10] this returns
    inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3]

    The result is cached on the offsets node the first time it is computed
    and reused on subsequent calls with the same offsets.
    """

    cached = getattr(offsets, "inverse_offsets", None)
    if cached is not None:
        # already computed for these offsets: reuse the cached node
        return cached

    # ops.bucketize needs offsets.get_name(), which only exists on realized
    # (non-Pointwise) nodes: offsets must live in global memory so the
    # binary search can scan the whole tensor.
    offsets.realize()
    offsets_device: torch.device = offsets.get_device()
    offsets_dtype: torch.dtype = offsets.get_dtype()

    # pyre-ignore[2,3]
    def bucketize_one(index):
        jagged_pos = index[0]
        one_based_bucket = ops.bucketize(
            values=ops.index_expr(jagged_pos, offsets_dtype),
            offsets_name=offsets.get_name(),
            offsets_size=offsets.get_size()[0],
            indexing_dtype=offsets_dtype,
            right=True,
        )
        # bucketize yields 1-based bucket indices; shift to 0-based batches
        return one_based_bucket - 1

    inverse_offsets = Pointwise.create(
        device=offsets_device,
        dtype=offsets_dtype,
        inner_fn=bucketize_one,
        ranges=[jagged_len],
    )

    if realize:
        # "freeze" the node so it is not inlined downstream
        inverse_offsets.realize()

    # stash on the offsets node for reuse by later lowerings
    offsets.inverse_offsets = inverse_offsets  # type: ignore[attr-defined]

    return inverse_offsets
86
87
def jagged_idx_to_dense_idx(
    jagged_idx,  # pyre-ignore[2]
    inverse_offsets_loader,  # pyre-ignore[2]
    offsets_loader,  # pyre-ignore[2]
    batch_size: Union[int, sympy.Expr],
    max_seq_len: Union[int, sympy.Expr],
    offsets_dtype: torch.dtype,
) -> Tuple[sympy.Expr, sympy.Expr]:
    """Map a jagged index back to its dense (batch, seq) coordinates."""
    which_batch = ops.indirect_indexing(
        inverse_offsets_loader([jagged_idx]),
        batch_size + 1,
    )
    batch_begin = offsets_loader([which_batch])
    seq_offset = ops.index_expr(jagged_idx, offsets_dtype) - batch_begin
    # check=False: sequences may be longer than max_seq_len, so out-of-range
    # values are allowed here and masked by the caller.
    within_seq = ops.indirect_indexing(seq_offset, max_seq_len, check=False)
    return which_batch, within_seq
105
106
def register_jagged_ops():
    """Register inductor lowerings for jagged <-> padded-dense conversions.

    Called for its side effect of installing the nested functions into the
    lowering table via ``@register_lowering``. Each lowering handles only the
    common single-jagged-dimension CUDA case and otherwise falls back to the
    eager aten kernel.
    """

    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
    def _jagged_to_padded_dense_forward(
        jagged_values: TensorBox,
        jagged_offsets: List[TensorBox],
        max_lengths: List[int],  # list of ints/SymInts
        padding_value: float = 0.0,
    ) -> TensorBox:
        """Lower values [sum_B(N_B), D] + offsets [B + 1] to dense [B, N, D].

        Positions past a sequence's end are filled with padding_value.
        """
        device = jagged_values.get_device()
        dtype = jagged_values.get_dtype()

        jagged_values_size = jagged_values.get_size()

        # only handle the common case of a single jagged dimension
        if (
            len(jagged_offsets) != 1
            or device.type != "cuda"
            or device != jagged_offsets[0].get_device()
            or len(jagged_values_size) != 2
            or len(jagged_offsets[0].get_size()) != 1
            or len(max_lengths) != len(jagged_offsets)
            or not is_integer_type(jagged_offsets[0])
        ):
            # unsupported layout/device/dtype: defer to the eager kernel
            return fallback_handler(
                torch.ops.aten._jagged_to_padded_dense_forward.default,
                add_to_fallback_set=False,
            )(
                jagged_values,
                jagged_offsets,
                max_lengths,
                padding_value,
            )

        offsets: TensorBox = jagged_offsets[0]
        offsets_len = offsets.get_size()[0]
        offsets_dtype = offsets.get_dtype()
        # offsets holds one entry per batch plus a trailing total length
        batch_size = offsets_len - 1
        max_seq_len = max_lengths[0]
        embedding_len = jagged_values_size[1]
        jagged_len = jagged_values_size[0]

        output_size = [batch_size, max_seq_len, embedding_len]

        values_loader = jagged_values.make_loader()
        offsets_loader = offsets.make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # dense tensor size: [B, N, D]
            batch_idx, seq_idx, emb_idx = index
            jagged_idx, end_idx = dense_idx_to_jagged_idx(
                batch_idx=batch_idx,
                seq_idx=seq_idx,
                offsets_loader=offsets_loader,
                jagged_len=jagged_len,
            )
            # load the jagged value while in range; otherwise emit padding
            return ops.masked(
                ops.lt(
                    ops.index_expr(jagged_idx, offsets_dtype),
                    end_idx,
                ),
                lambda: values_loader([jagged_idx, emb_idx]),
                padding_value,
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=output_size,
        )

    def _dense_to_jagged_forward_impl(
        fallback_op,  # pyre-ignore[2]
        dense: TensorBox,
        jagged_offsets: List[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> TensorBox:
        """Lower dense [B, N, D] + offsets [B + 1] to jagged [jagged_len, D].

        jagged_len is required to size the output; when it is None we cannot
        lower the op here and fall back to the eager kernel.
        """
        device = dense.get_device()
        dtype = dense.get_dtype()

        dense_size = dense.get_size()

        # only handle the common case of a single jagged dimension
        if (
            len(jagged_offsets) != 1
            or device.type != "cuda"
            or device != jagged_offsets[0].get_device()
            or len(jagged_offsets[0].get_size()) != 1
            or len(dense_size) != 3
            or jagged_len is None
            or not is_integer_type(jagged_offsets[0])
        ):
            # unsupported layout/device/dtype: defer to the eager kernel
            return fallback_handler(fallback_op, add_to_fallback_set=False)(
                dense,
                jagged_offsets,
                jagged_len,
            )

        offsets: TensorBox = jagged_offsets[0]
        offsets_dtype = offsets.get_dtype()
        batch_size = dense_size[0]
        max_seq_len = dense_size[1]
        embedding_len = dense_size[-1]

        output_size = [jagged_len, embedding_len]

        dense_loader = dense.make_loader()
        offsets_loader = offsets.make_loader()

        # jagged index -> batch index map (binary search over offsets);
        # cached on the offsets node and reused across lowerings
        inverse_offsets = get_inverse_offsets(
            offsets=offsets,
            jagged_len=jagged_len,
        )
        inverse_offsets_loader = inverse_offsets.make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # jagged tensor size: [sum_B(N_B), D]
            jagged_idx, emb_idx = index
            batch_idx, seq_idx = jagged_idx_to_dense_idx(
                jagged_idx=jagged_idx,
                offsets_loader=offsets_loader,
                inverse_offsets_loader=inverse_offsets_loader,
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                offsets_dtype=offsets_dtype,
            )
            return ops.masked(
                ops.lt(
                    ops.index_expr(seq_idx, offsets_dtype),
                    ops.index_expr(max_seq_len, offsets_dtype),
                ),
                lambda: dense_loader([batch_idx, seq_idx, emb_idx]),
                0.0,  # jagged sequence longer than max_seq_len
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=output_size,
        )

    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward)
    def _dense_to_jagged_forward(
        dense: TensorBox,
        jagged_offsets: List[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> TensorBox:
        """Registered entry point; binds the fallback op and delegates."""
        return _dense_to_jagged_forward_impl(
            fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default,
            dense=dense,
            jagged_offsets=jagged_offsets,
            jagged_len=jagged_len,
        )
265