# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union

import sympy

import torch

from .ir import Pointwise, TensorBox
from .lowering import fallback_handler, is_integer_type, register_lowering
from .virtualized import ops


# pyre-ignore[2,3]
def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len):
    # jagged_len + 1 is used as the upper bound,
    # because the last sequence length may be zero
    begin_idx = ops.indirect_indexing(
        offsets_loader([batch_idx]),
        jagged_len + 1,
    )
    end_idx = offsets_loader([batch_idx + 1])
    jagged_idx = begin_idx + seq_idx
    return jagged_idx, end_idx


def get_inverse_offsets(
    offsets: TensorBox,
    jagged_len: Union[int, sympy.Expr],
    realize: bool = True,
) -> TensorBox:
    """
    Returns "inverse_offsets" - the inverse of the offsets array.
    offsets maps batch index (dense) to jagged index (i.e. offset into the jagged tensor).
    inverse_offsets maps jagged index to batch index.

    e.g. for offsets [0, 3, 4, 9, 10] this will return
    inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3]

    For the given offsets, the computed inverse_offsets are cached
    on the first call and reused on subsequent calls.
    """

    if hasattr(offsets, "inverse_offsets"):
        # inverse_offsets have already been computed
        # for these offsets: reuse them
        return offsets.inverse_offsets

    # ops.bucketize takes offsets.get_name(), which doesn't exist on Pointwise
    # kernels, i.e. we need to realize offsets before using it. In other words,
    # offsets must be in global memory so that we can binary search over the
    # entire tensor.
    offsets.realize()
    device: torch.device = offsets.get_device()
    dtype: torch.dtype = offsets.get_dtype()

    # pyre-ignore[2,3]
    def inner_fn(index):
        idx = index[0]
        bucket = ops.bucketize(
            values=ops.index_expr(idx, dtype),
            offsets_name=offsets.get_name(),
            offsets_size=offsets.get_size()[0],
            indexing_dtype=dtype,
            right=True,
        )
        # ops.bucketize above returns 1-based bucket indices,
        # but we need 0-based, hence we subtract 1 from the result
        return bucket - 1

    inverse_offsets = Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=[jagged_len],
    )

    if realize:
        # "freeze" the node so that it doesn't get inlined downstream.
        inverse_offsets.realize()

    # cache inverse_offsets for further reuse
    offsets.inverse_offsets = inverse_offsets  # type: ignore[attr-defined]

    return inverse_offsets
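

# The sketch below restates, in eager PyTorch, the mapping that
# get_inverse_offsets() expresses with ops.bucketize. It is illustration only:
# the helper name and the use of torch.searchsorted are assumptions made for
# this sketch, and nothing in this module calls it.
def _inverse_offsets_reference(offsets: torch.Tensor, jagged_len: int) -> torch.Tensor:
    # For each jagged position j in [0, jagged_len), find the batch b with
    # offsets[b] <= j < offsets[b + 1]. searchsorted(..., right=True) returns
    # 1-based bucket indices (like ops.bucketize), so subtract 1.
    jagged_positions = torch.arange(
        jagged_len, device=offsets.device, dtype=offsets.dtype
    )
    return torch.searchsorted(offsets, jagged_positions, right=True) - 1


# e.g. _inverse_offsets_reference(torch.tensor([0, 3, 4, 9, 10]), 10)
# returns tensor([0, 0, 0, 1, 2, 2, 2, 2, 2, 3]), matching the docstring above.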


def jagged_idx_to_dense_idx(
    jagged_idx,  # pyre-ignore[2]
    inverse_offsets_loader,  # pyre-ignore[2]
    offsets_loader,  # pyre-ignore[2]
    batch_size: Union[int, sympy.Expr],
    max_seq_len: Union[int, sympy.Expr],
    offsets_dtype: torch.dtype,
) -> Tuple[sympy.Expr, sympy.Expr]:
    batch_idx = ops.indirect_indexing(
        inverse_offsets_loader([jagged_idx]),
        batch_size + 1,
    )
    batch_start = offsets_loader([batch_idx])
    seq = ops.index_expr(jagged_idx, offsets_dtype) - batch_start
    # check=False because there may be sequences longer than max_seq_len
    seq_idx = ops.indirect_indexing(seq, max_seq_len, check=False)
    return batch_idx, seq_idx
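

# For reference, the two index helpers above translate between the dense
# coordinate (batch_idx, seq_idx) and the flat jagged coordinate via
# jagged_idx = offsets[batch_idx] + seq_idx. The eager sketch below spells out
# the inverse direction; the helper name is an assumption and the function is
# not used by the lowerings in this file.
def _jagged_to_dense_idx_reference(
    jagged_idx: torch.Tensor, offsets: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    jagged_idx = jagged_idx.to(offsets.dtype)
    # batch_idx plays the role of inverse_offsets[jagged_idx] above
    batch_idx = torch.searchsorted(offsets, jagged_idx, right=True) - 1
    seq_idx = jagged_idx - offsets[batch_idx]
    return batch_idx, seq_idx


# e.g. for offsets [0, 3, 4, 9, 10], jagged index 5 lands in batch 2 at
# sequence position 1, and dense_idx_to_jagged_idx inverts this: offsets[2] + 1 == 5.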


def register_jagged_ops():
    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
    def _jagged_to_padded_dense_forward(
        jagged_values: TensorBox,
        jagged_offsets: List[TensorBox],
        max_lengths: List[int],  # list of ints/SymInts
        padding_value: float = 0.0,
    ) -> TensorBox:
        device = jagged_values.get_device()
        dtype = jagged_values.get_dtype()

        jagged_values_size = jagged_values.get_size()

        # only handle the common case of a single jagged dimension
        if (
            len(jagged_offsets) != 1
            or device.type != "cuda"
            or device != jagged_offsets[0].get_device()
            or len(jagged_values_size) != 2
            or len(jagged_offsets[0].get_size()) != 1
            or len(max_lengths) != len(jagged_offsets)
            or not is_integer_type(jagged_offsets[0])
        ):
            return fallback_handler(
                torch.ops.aten._jagged_to_padded_dense_forward.default,
                add_to_fallback_set=False,
            )(
                jagged_values,
                jagged_offsets,
                max_lengths,
                padding_value,
            )

        offsets: TensorBox = jagged_offsets[0]
        offsets_len = offsets.get_size()[0]
        offsets_dtype = offsets.get_dtype()
        batch_size = offsets_len - 1
        max_seq_len = max_lengths[0]
        embedding_len = jagged_values_size[1]
        jagged_len = jagged_values_size[0]

        output_size = [batch_size, max_seq_len, embedding_len]

        values_loader = jagged_values.make_loader()
        offsets_loader = offsets.make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # dense tensor size: [B, N, D]
            batch_idx, seq_idx, emb_idx = index
            jagged_idx, end_idx = dense_idx_to_jagged_idx(
                batch_idx=batch_idx,
                seq_idx=seq_idx,
                offsets_loader=offsets_loader,
                jagged_len=jagged_len,
            )
            return ops.masked(
                ops.lt(
                    ops.index_expr(jagged_idx, offsets_dtype),
                    end_idx,
                ),
                lambda: values_loader([jagged_idx, emb_idx]),
                padding_value,
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=output_size,
        )

    def _dense_to_jagged_forward_impl(
        fallback_op,  # pyre-ignore[2]
        dense: TensorBox,
        jagged_offsets: List[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> TensorBox:
        device = dense.get_device()
        dtype = dense.get_dtype()

        dense_size = dense.get_size()

        # only handle the common case of a single jagged dimension
        if (
            len(jagged_offsets) != 1
            or device.type != "cuda"
            or device != jagged_offsets[0].get_device()
            or len(jagged_offsets[0].get_size()) != 1
            or len(dense_size) != 3
            or jagged_len is None
            or not is_integer_type(jagged_offsets[0])
        ):
            return fallback_handler(fallback_op, add_to_fallback_set=False)(
                dense,
                jagged_offsets,
                jagged_len,
            )

        offsets: TensorBox = jagged_offsets[0]
        offsets_dtype = offsets.get_dtype()
        batch_size = dense_size[0]
        max_seq_len = dense_size[1]
        embedding_len = dense_size[-1]

        output_size = [jagged_len, embedding_len]

        dense_loader = dense.make_loader()
        offsets_loader = offsets.make_loader()

        inverse_offsets = get_inverse_offsets(
            offsets=offsets,
            jagged_len=jagged_len,
        )
        inverse_offsets_loader = inverse_offsets.make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # jagged tensor size: [sum_B(N_B), D]
            jagged_idx, emb_idx = index
            batch_idx, seq_idx = jagged_idx_to_dense_idx(
                jagged_idx=jagged_idx,
                offsets_loader=offsets_loader,
                inverse_offsets_loader=inverse_offsets_loader,
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                offsets_dtype=offsets_dtype,
            )
            return ops.masked(
                ops.lt(
                    ops.index_expr(seq_idx, offsets_dtype),
                    ops.index_expr(max_seq_len, offsets_dtype),
                ),
                lambda: dense_loader([batch_idx, seq_idx, emb_idx]),
                0.0,  # jagged sequence longer than max_seq_len
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=output_size,
        )

    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward)
    def _dense_to_jagged_forward(
        dense: TensorBox,
        jagged_offsets: List[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> TensorBox:
        return _dense_to_jagged_forward_impl(
            fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default,
            dense=dense,
            jagged_offsets=jagged_offsets,
            jagged_len=jagged_len,
        )
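

# The reference helpers below are a rough eager-mode restatement of the two
# lowerings registered above, kept for illustration only. Their names are
# assumptions, nothing in Inductor calls them, and they simply spell out the
# masked gathers that the Pointwise kernels express via ops.masked and
# ops.indirect_indexing.
def _jagged_to_padded_dense_reference(
    values: torch.Tensor,  # [sum_B(N_B), D]
    offsets: torch.Tensor,  # [B + 1]
    max_seq_len: int,
    padding_value: float = 0.0,
) -> torch.Tensor:
    seq = torch.arange(max_seq_len, device=values.device)
    jagged_idx = offsets[:-1, None] + seq  # [B, N] gather indices into values
    in_bounds = jagged_idx < offsets[1:, None]  # mask positions past each sequence end
    safe_idx = jagged_idx.clamp(max=values.size(0) - 1)
    gathered = values[safe_idx]  # [B, N, D]
    padding = torch.full_like(gathered, padding_value)
    return torch.where(in_bounds[..., None], gathered, padding)


def _padded_dense_to_jagged_reference(
    dense: torch.Tensor,  # [B, N, D]
    offsets: torch.Tensor,  # [B + 1]; offsets[-1] == jagged_len is assumed
    jagged_len: int,
) -> torch.Tensor:
    max_seq_len = dense.size(1)
    jagged_idx = torch.arange(jagged_len, device=dense.device, dtype=offsets.dtype)
    batch_idx = torch.searchsorted(offsets, jagged_idx, right=True) - 1
    seq_idx = jagged_idx - offsets[batch_idx]
    in_bounds = seq_idx < max_seq_len  # positions beyond max_seq_len become 0.0, as in the lowering
    safe_seq = seq_idx.clamp(max=max_seq_len - 1)
    gathered = dense[batch_idx, safe_seq]  # [jagged_len, D]
    return torch.where(in_bounds[:, None], gathered, torch.zeros_like(gathered))


# e.g. round-tripping _jagged_to_padded_dense_reference(values, offsets, N)
# through _padded_dense_to_jagged_reference(..., offsets, values.size(0))
# recovers values when every sequence fits within N.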