# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python-style indexing and slicing for RaggedTensors."""

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export("__operators__.ragged_getitem", v1=[])
@dispatch.add_dispatch_support
def ragged_tensor_getitem(rt_input, key):
  """Returns the specified piece of this RaggedTensor.

  Supports multidimensional indexing and slicing, with one restriction:
  indexing into a ragged inner dimension is not allowed. This case is
  problematic because the indicated value may exist in some rows but not
  others. In such cases, it's not obvious whether we should (1) report an
  IndexError; (2) use a default value; or (3) skip that value and return a
  tensor with fewer rows than we started with. Following the guiding
  principles of Python ("In the face of ambiguity, refuse the temptation to
  guess"), we simply disallow this operation.

  Args:
    rt_input: The RaggedTensor to slice.
    key: Indicates which piece of the RaggedTensor to return, using standard
      Python semantics (e.g., negative values index from the end). `key`
      may have any of the following types:

      * `int` constant
      * Scalar integer `Tensor`
      * `slice` containing integer constants and/or scalar integer
        `Tensor`s
      * `Ellipsis`
      * `tf.newaxis`
      * `tuple` containing any of the above (for multidimensional indexing)

  Returns:
    A `Tensor` or `RaggedTensor` object. Values that include at least one
    ragged dimension are returned as `RaggedTensor`. Values that include no
    ragged dimensions are returned as `Tensor`. See above for examples of
    expressions that return `Tensor`s vs `RaggedTensor`s.

  Raises:
    ValueError: If `key` is out of bounds.
    ValueError: If `key` is not supported.
    TypeError: If the indices in `key` have an unsupported type.

  Examples:

  >>> # A 2-D ragged tensor with 1 ragged dimension.
  >>> rt = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e'], ['f'], ['g']])
  >>> rt[0].numpy()                 # First row (1-D `Tensor`)
  array([b'a', b'b', b'c'], dtype=object)
  >>> rt[:3].to_list()              # First three rows (2-D RaggedTensor)
  [[b'a', b'b', b'c'], [b'd', b'e'], [b'f']]
  >>> rt[3, 0].numpy()              # First element of fourth row (scalar)
  b'g'

  >>> # A 3-D ragged tensor with 2 ragged dimensions.
  >>> rt = tf.ragged.constant([[[1, 2, 3], [4]],
  ...                          [[5], [], [6]],
  ...                          [[7]],
  ...                          [[8, 9], [10]]])
  >>> rt[1].to_list()               # Second row (2-D RaggedTensor)
  [[5], [], [6]]
  >>> rt[3, 0].numpy()              # First element of fourth row (1-D Tensor)
  array([8, 9], dtype=int32)
  >>> rt[:, 1:3].to_list()          # Items 1-2 of each row (3-D RaggedTensor)
  [[[4]], [[], [6]], [], [[10]]]
  >>> rt[:, -1:].to_list()          # Last item of each row (3-D RaggedTensor)
  [[[4]], [[6]], [[7]], [[10]]]
  """
  if not isinstance(rt_input, ragged_tensor.RaggedTensor):
    raise TypeError("Ragged __getitem__ expects a ragged_tensor.")
  scope_tensors = [rt_input] + list(_tensors_in_key_list(key))
  if isinstance(key, (list, tuple)):
    key = list(key)
  else:
    key = [key]
  with ops.name_scope(None, "RaggedGetItem", scope_tensors):
    return _ragged_getitem(rt_input, key)


def _ragged_getitem(rt_input, key_list):
  """Helper for indexing and slicing ragged tensors with __getitem__().

  Extracts the specified piece of `rt_input`. See
  `RaggedTensor.__getitem__` for examples and restrictions.

  Args:
    rt_input: The `RaggedTensor` from which a piece should be returned.
    key_list: The list of keys specifying which piece to return. Each key
      corresponds to a separate dimension.

  Returns:
    The indicated piece of rt_input.

  Raises:
    ValueError: If `key_list` is not supported.
    TypeError: If any keys in `key_list` have an unsupported type.
  """
  if not key_list:
    return rt_input
  row_key = key_list[0]
  inner_keys = key_list[1:]

  if row_key is Ellipsis:
    expanded_key_list = _expand_ellipsis(key_list, rt_input.shape.ndims)
    return _ragged_getitem(rt_input, expanded_key_list)

  # Adding a new axis: Get rt_input[inner_keys], and wrap it in a RaggedTensor
  # that puts all values in a single row.
  if row_key is array_ops.newaxis:
    inner_rt = _ragged_getitem(rt_input, inner_keys)
    nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
    if nsplits.value is not None:
      nsplits = nsplits.value
    else:
      nsplits = array_ops.shape(inner_rt.row_splits,
                                out_type=inner_rt.row_splits.dtype)[0]
    return ragged_tensor.RaggedTensor.from_uniform_row_length(
        inner_rt, nsplits - 1, nrows=1, validate=False)

  # Slicing a range of rows: first slice the outer dimension, and then
  # call `_ragged_getitem_inner_dimensions` to handle the inner keys.
  if isinstance(row_key, slice):
    sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key)
    if rt_input.uniform_row_length is not None:
      # If the inner dimension has uniform_row_length, then preserve it (by
      # re-wrapping the values in a new RaggedTensor). Note that the row
      # length won't have changed, since we're slicing a range of rows (and
      # not slicing the rows themselves).
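      # (Illustrative sketch, not executed here: if `rt_input` was built with
      # `RaggedTensor.from_uniform_row_length(values, 2)`, then a row slice
      # such as `rt_input[0:2]` should still report uniform_row_length == 2.)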
      sliced_rt_input = ragged_tensor.RaggedTensor.from_uniform_row_length(
          sliced_rt_input.values, rt_input.uniform_row_length,
          nrows=sliced_rt_input.nrows())
    return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys)

  # Indexing a single row: slice values to get the indicated row, and then
  # use a recursive call to __getitem__ to handle the inner keys.
  else:
    starts = rt_input.row_splits[:-1]
    limits = rt_input.row_splits[1:]
    if context.executing_eagerly():
      # In Python, __getitem__ should raise IndexError for out-of-bound
      # indices. This allows iteration to terminate correctly, since Python
      # translates IndexError into StopIteration for next()/__next__().
      # For example:
      #   import tensorflow as tf
      #   r = tf.ragged.constant([[1., 2.], [3., 4., 5.], [6.]])
      #   for elem in r:
      #     print(elem)
      # In graph (non-eager) mode, the exception is only raised when the
      # session runs, so we cannot detect an out-of-bound index ahead of
      # time. In eager mode, however, we can check the bound here and raise
      # IndexError immediately.
      # Below, `row_key >= len(starts)` is checked. If `row_key` is not an
      # integer, the resulting TypeError (or ValueError) is simply ignored,
      # since the key will be processed later anyway.
      try:
        if int(row_key) >= len(starts):
          raise IndexError("Row key {} out of bounds".format(row_key))
      except (TypeError, ValueError):
        pass
    row = rt_input.values[starts[row_key]:limits[row_key]]
    return row.__getitem__(inner_keys)


def _slice_ragged_row_dimension(rt_input, row_key):
  """Slices the outer dimension of `rt_input` according to the given `slice`.

  Args:
    rt_input: The `RaggedTensor` to slice.
    row_key: The `slice` object that should be used to slice `rt_input`.

  Returns:
    A `RaggedTensor` containing the indicated slice of `rt_input`.
  """
  if row_key.start is None and row_key.stop is None and row_key.step is None:
    return rt_input

  # Use row_key to slice the starts & limits.
  new_starts = rt_input.row_splits[:-1][row_key]
  new_limits = rt_input.row_splits[1:][row_key]
  zero_pad = array_ops.zeros([1], rt_input.row_splits.dtype)

  # If there's no slice step, then we can just select a single continuous
  # span of `ragged.values(rt_input)`.
  if row_key.step is None or row_key.step == 1:
    # Construct the new splits. If new_starts and new_limits are empty,
    # then this reduces to [0]. Otherwise, this reduces to:
    #   concat([[new_starts[0]], new_limits])
    new_splits = array_ops.concat(
        [zero_pad[array_ops.size(new_starts):], new_starts[:1], new_limits],
        axis=0)
    values_start = new_splits[0]
    values_limit = new_splits[-1]
    return ragged_tensor.RaggedTensor.from_row_splits(
        rt_input.values[values_start:values_limit], new_splits - values_start,
        validate=False)

  # If there is a slice step (aka a strided slice), then use ragged_gather to
  # collect the necessary elements of `ragged.values(rt_input)`.
  else:
    return _build_ragged_tensor_from_value_ranges(new_starts, new_limits, 1,
                                                  rt_input.values)


def _ragged_getitem_inner_dimensions(rt_input, key_list):
  """Retrieves inner dimensions, keeping the outermost dimension unchanged.

  Args:
    rt_input: The `RaggedTensor` or `Tensor` from which a piece should be
      extracted.
    key_list: The __getitem__ keys for slicing the inner dimensions.

  Returns:
    A `RaggedTensor`.
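    (If `rt_input` is a `Tensor`, the result of indexing it directly is
    returned, which is a `Tensor` rather than a `RaggedTensor`.)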

  Raises:
    ValueError: If `key_list` is not supported.
  """
  if not key_list:
    return rt_input

  if isinstance(rt_input, ops.Tensor):
    return rt_input.__getitem__([slice(None, None, None)] + key_list)

  column_key = key_list[0]
  if column_key is Ellipsis:
    expanded_key_list = _expand_ellipsis(key_list, rt_input.values.shape.ndims)
    return _ragged_getitem_inner_dimensions(rt_input, expanded_key_list)

  # Adding a new axis to a ragged inner dimension: recursively get the inner
  # dimensions of rt_input with key_list[1:], and then wrap the result in a
  # RaggedTensor that puts each value in its own row.
  if column_key is array_ops.newaxis:
    inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
    nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
    if nsplits.value is not None:
      nsplits = nsplits.value
    else:
      nsplits = array_ops.shape(inner_rt.row_splits,
                                out_type=inner_rt.row_splits.dtype)[0]
    return ragged_tensor.RaggedTensor.from_uniform_row_length(
        inner_rt, 1, nrows=nsplits - 1, validate=False)

  # Slicing a range of columns in a ragged inner dimension. We use a
  # recursive call to process the values, and then assemble a RaggedTensor
  # with those values.
  if isinstance(column_key, slice):
    if (column_key.start is None and column_key.stop is None and
        column_key.step is None):
      # Trivial slice: recursively process all values; the row splits are
      # unchanged.
      return rt_input.with_values(
          _ragged_getitem_inner_dimensions(rt_input.values, key_list[1:]))
    else:
      if not (isinstance(column_key.start, (ops.Tensor, int, type(None))) and
              isinstance(column_key.stop, (ops.Tensor, int, type(None)))):
        raise TypeError("slice offsets must be integers or None")

      # Nontrivial slice: use ragged_gather to extract the indicated slice as
      # a new RaggedTensor (inner_rt), and then recursively process its
      # values.
      starts = rt_input.row_splits[:-1]
      limits = rt_input.row_splits[1:]
      step = 1 if column_key.step is None else column_key.step
      lower_bound = _if_ge_zero(step, lambda: starts, lambda: starts - 1)
      upper_bound = _if_ge_zero(step, lambda: limits, lambda: limits - 1)
      # inner_rt_starts[i] = index to start gathering for row i.
      if column_key.start is None:
        inner_rt_starts = _if_ge_zero(step, lambda: starts, lambda: limits - 1)
      else:
        start_offset = math_ops.cast(column_key.start, starts.dtype)
        inner_rt_starts = _if_ge_zero(
            column_key.start,
            lambda: math_ops.minimum(starts + start_offset, upper_bound),
            lambda: math_ops.maximum(limits + start_offset, lower_bound))
      # inner_rt_limits[i] = index to stop gathering for row i.
      if column_key.stop is None:
        inner_rt_limits = _if_ge_zero(step, lambda: limits, lambda: starts - 1)
      else:
        stop_offset = math_ops.cast(column_key.stop, starts.dtype)
        inner_rt_limits = _if_ge_zero(
            column_key.stop,
            lambda: math_ops.minimum(starts + stop_offset, upper_bound),
            lambda: math_ops.maximum(limits + stop_offset, lower_bound))
      inner_rt = _build_ragged_tensor_from_value_ranges(
          inner_rt_starts, inner_rt_limits, column_key.step, rt_input.values)
      # If the row dimension is uniform, then calculate the new
      # uniform_row_length, and rebuild inner_rt using that uniform row
      # length.
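      # (Illustrative example: with uniform_row_length=5 and
      # column_key=slice(1, None, 2), _slice_length returns
      # len(range(5)[1::2]) == 2, so the rebuilt inner_rt has
      # uniform_row_length=2.)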
      if rt_input.uniform_row_length is not None:
        new_row_length = _slice_length(rt_input.uniform_row_length, column_key)
        inner_rt = ragged_tensor.RaggedTensor.from_uniform_row_length(
            inner_rt.values, new_row_length, rt_input.nrows())
      return inner_rt.with_values(
          _ragged_getitem_inner_dimensions(inner_rt.values, key_list[1:]))

  # Indexing a single column in a ragged inner dimension: raise an Exception.
  # See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing
  # into a ragged inner dimension is problematic.
  if rt_input.uniform_row_length is None:
    raise ValueError("Cannot index into an inner ragged dimension.")

  # Indexing a single column in a uniform inner dimension: check that the
  # given index is in-bounds, and then use a strided slice over
  # rt_input.values to take the indicated element from each row.
  row_length = rt_input.uniform_row_length
  column_key = math_ops.cast(column_key, row_length.dtype)
  oob_err_msg = "Index out of bounds when indexing into a ragged tensor"
  oob_checks = [
      check_ops.assert_greater_equal(
          column_key, -row_length, message=oob_err_msg),
      check_ops.assert_less(column_key, row_length, message=oob_err_msg),
  ]
  with ops.control_dependencies(oob_checks):
    offset = _if_ge_zero(column_key, lambda: column_key,
                         lambda: row_length + column_key)
    sliced_rt = rt_input.values[offset::row_length]
    return _ragged_getitem_inner_dimensions(sliced_rt, key_list[1:])


def _slice_length(value_length, slice_key):
  """Computes the number of elements in a slice of a value of a given length.

  Returns the equivalent of: `len(range(value_length)[slice_key])`

  Args:
    value_length: Scalar int `Tensor`: the length of the value being sliced.
    slice_key: A `slice` object used to slice elements from the value.

  Returns:
    The number of elements in the sliced value.
  """
  # Note: we could compute the slice length without creating a zeros tensor
  # with some variant of (stop-start)//step, but doing so would require more
  # ops (for checking bounds, handling negative indices, negative step sizes,
  # etc.); and we expect this to be an uncommon operation, so we use this
  # simpler implementation.
  zeros = array_ops.zeros(value_length, dtype=dtypes.bool)
  return array_ops.size(zeros[slice_key], out_type=value_length.dtype)


def _expand_ellipsis(key_list, num_remaining_dims):
  """Expands the ellipsis at the start of `key_list`.

  Assumes that the first element of `key_list` is Ellipsis. This will either
  remove the Ellipsis (if it corresponds to zero indices) or prepend a new
  `slice(None, None, None)` (if it corresponds to more than zero indices).

  Args:
    key_list: The arguments to `__getitem__()`.
    num_remaining_dims: The number of dimensions remaining.

  Returns:
    A copy of `key_list` with the ellipsis expanded.

  Raises:
    ValueError: If `num_remaining_dims` is None (i.e., the rank of the
      tensor being indexed is unknown).
    IndexError: If there are too many elements in `key_list`.
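
  For example (an illustration of the logic below, not a doctest): with
  `key_list = [Ellipsis, 0]`, this returns `[0]` when `num_remaining_dims`
  is 1 (the ellipsis covers zero dimensions), and
  `[slice(None, None, None), Ellipsis, 0]` when `num_remaining_dims` is 3
  (the ellipsis still covers additional dimensions).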
  """
  if num_remaining_dims is None:
    raise ValueError("Ellipsis not supported for unknown shape RaggedTensors")
  num_indices = sum(1 for idx in key_list if idx is not array_ops.newaxis)
  if num_indices > num_remaining_dims + 1:
    raise IndexError("Too many indices for RaggedTensor")
  elif num_indices == num_remaining_dims + 1:
    return key_list[1:]
  else:
    return [slice(None, None, None)] + key_list


def _tensors_in_key_list(key_list):
  """Generates all Tensors in the given slice spec."""
  if isinstance(key_list, ops.Tensor):
    yield key_list
  if isinstance(key_list, (list, tuple)):
    for v in key_list:
      for tensor in _tensors_in_key_list(v):
        yield tensor
  if isinstance(key_list, slice):
    for tensor in _tensors_in_key_list(key_list.start):
      yield tensor
    for tensor in _tensors_in_key_list(key_list.stop):
      yield tensor
    for tensor in _tensors_in_key_list(key_list.step):
      yield tensor


def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
  """Returns a `RaggedTensor` containing the specified sequences of values.

  Returns a RaggedTensor `output` where:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = values[starts[i]:limits[i]:step]
  ```

  Requires that `starts.shape == limits.shape` and
  `0 <= starts[i] <= limits[i] <= values.shape[0]`.

  Args:
    starts: 1D integer Tensor specifying the start indices for the sequences
      of values to include.
    limits: 1D integer Tensor specifying the limit indices for the sequences
      of values to include.
    step: Integer value specifying the step size for strided slices.
    values: The set of values to select from.

  Returns:
    A `RaggedTensor`.

  Raises:
    TypeError: If `step` is not an integer (or None).
  """
  # Use `ragged_range` to get the index of each value we should include.
  if step is None:
    step = 1
  step = ops.convert_to_tensor(step, name="step")
  if step.dtype.is_integer:
    step = math_ops.cast(step, starts.dtype)
  else:
    raise TypeError("slice strides must be integers or None")
  value_indices = ragged_math_ops.range(starts, limits, step,
                                        row_splits_dtype=starts.dtype)

  # Use `ragged_gather` or `array_ops.gather` to collect the values.
  if isinstance(values, ragged_tensor.RaggedTensor):
    gathered_values = ragged_gather_ops.gather(
        params=values, indices=value_indices.values)
  else:
    gathered_values = array_ops.gather(
        params=values, indices=value_indices.values)

  # Assemble the RaggedTensor from splits & values.
  return value_indices.with_values(gathered_values)


def _if_ge_zero(value, true_fn, false_fn):
  """Returns `true_fn() if value >= 0 else false_fn()`."""
  # If `value` is statically known, then don't use a control flow op.
  if isinstance(value, ops.Tensor):
    const_value = tensor_util.constant_value(value)
    if const_value is None:
      return control_flow_ops.cond(value >= 0, true_fn, false_fn)
    else:
      value = const_value
  if value >= 0:
    return true_fn()
  else:
    return false_fn()