xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/ragged/ragged_getitem.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python-style indexing and slicing for RaggedTensors."""
16
17from tensorflow.python.eager import context
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.framework import tensor_util
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import check_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops.ragged import ragged_gather_ops
27from tensorflow.python.ops.ragged import ragged_math_ops
28from tensorflow.python.ops.ragged import ragged_tensor
29from tensorflow.python.util import dispatch
30from tensorflow.python.util.tf_export import tf_export
31
32
@tf_export("__operators__.ragged_getitem", v1=[])
@dispatch.add_dispatch_support
def ragged_tensor_getitem(rt_input, key):
  """Returns the specified piece of this RaggedTensor.

  Supports multidimensional indexing and slicing, with one restriction:
  indexing into a ragged inner dimension is not allowed.  This case is
  problematic because the indicated value may exist in some rows but not
  others.  In such cases, it's not obvious whether we should (1) report an
  IndexError; (2) use a default value; or (3) skip that value and return a
  tensor with fewer rows than we started with.  Following the guiding
  principles of Python ("In the face of ambiguity, refuse the temptation to
  guess"), we simply disallow this operation.

  Args:
    rt_input: The RaggedTensor to slice.
    key: Indicates which piece of the RaggedTensor to return, using standard
      Python semantics (e.g., negative values index from the end).  `key`
      may have any of the following types:

      * `int` constant
      * Scalar integer `Tensor`
      * `slice` containing integer constants and/or scalar integer
        `Tensor`s
      * `Ellipsis`
      * `tf.newaxis`
      * `tuple` or `list` containing any of the above (for
        multidimensional indexing)

  Returns:
    A `Tensor` or `RaggedTensor` object.  Values that include at least one
    ragged dimension are returned as `RaggedTensor`.  Values that include no
    ragged dimensions are returned as `Tensor`.  See the examples below for
    expressions that return `Tensor`s vs `RaggedTensor`s.

  Raises:
    TypeError: If `rt_input` is not a `RaggedTensor`, or if the indices in
      `key` have an unsupported type.
    ValueError: If `key` is not supported (for example, if it indexes a
      single element of a ragged inner dimension).
    IndexError: If `key` has too many indices, or (when executing eagerly)
      if a row index is out of bounds.

  Examples:

  >>> # A 2-D ragged tensor with 1 ragged dimension.
  >>> rt = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e'], ['f'], ['g']])
  >>> rt[0].numpy()                 # First row (1-D `Tensor`)
  array([b'a', b'b', b'c'], dtype=object)
  >>> rt[:3].to_list()              # First three rows (2-D RaggedTensor)
  [[b'a', b'b', b'c'], [b'd', b'e'], [b'f']]
  >>> rt[3, 0].numpy()              # 1st element of 4th row (scalar)
  b'g'

  >>> # A 3-D ragged tensor with 2 ragged dimensions.
  >>> rt = tf.ragged.constant([[[1, 2, 3], [4]],
  ...                          [[5], [], [6]],
  ...                          [[7]],
  ...                          [[8, 9], [10]]])
  >>> rt[1].to_list()               # Second row (2-D RaggedTensor)
  [[5], [], [6]]
  >>> rt[3, 0].numpy()              # First element of fourth row (1-D Tensor)
  array([8, 9], dtype=int32)
  >>> rt[:, 1:3].to_list()          # Items 1 and 2 of each row (3-D RaggedTensor)
  [[[4]], [[], [6]], [], [[10]]]
  >>> rt[:, -1:].to_list()          # Last item of each row (3-D RaggedTensor)
  [[[4]], [[6]], [[7]], [[10]]]
  """
  if not isinstance(rt_input, ragged_tensor.RaggedTensor):
    raise TypeError("Ragged __getitem__ expects a ragged_tensor.")
  # Every Tensor mentioned in `key` (including slice bounds) is passed to the
  # name scope alongside `rt_input` as its `values`.
  scope_tensors = [rt_input, *_tensors_in_key_list(key)]
  # Normalize to a list with one entry per indexed dimension.
  key_list = list(key) if isinstance(key, (list, tuple)) else [key]
  with ops.name_scope(None, "RaggedGetItem", scope_tensors):
    return _ragged_getitem(rt_input, key_list)
106
107
def _ragged_getitem(rt_input, key_list):
  """Helper for indexing and slicing ragged tensors with __getitem__().

  Extracts the specified piece of the `rt_input`.  See
  `RaggedTensor.__getitem__` for examples and restrictions.

  Args:
    rt_input: The `RaggedTensor` from which a piece should be returned.
    key_list: The list of keys specifying which piece to return. Each key
      corresponds with a separate dimension.

  Returns:
    The indicated piece of rt_input: a `RaggedTensor`, or a `Tensor` when no
    ragged dimensions remain.

  Raises:
    ValueError: If `key_list` is not supported.
    TypeError: If any keys in `key_list` have an unsupported type.
    IndexError: If, when executing eagerly, an integer row key is out of
      bounds (either too large or too negative).
  """
  if not key_list:
    return rt_input
  row_key = key_list[0]
  inner_keys = key_list[1:]

  if row_key is Ellipsis:
    expanded_key_list = _expand_ellipsis(key_list, rt_input.shape.ndims)
    return _ragged_getitem(rt_input, expanded_key_list)

  # Adding a new axis: Get rt_input[inner_keys], and wrap it in a RaggedTensor
  # that puts all values in a single row.
  if row_key is array_ops.newaxis:
    inner_rt = _ragged_getitem(rt_input, inner_keys)
    if isinstance(inner_rt, ops.Tensor):
      # The inner keys indexed away every ragged dimension, so the result is
      # dense and has no row_splits; a new leading axis is just expand_dims.
      return array_ops.expand_dims(inner_rt, 0)
    nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
    if nsplits.value is not None:
      nsplits = nsplits.value
    else:
      nsplits = array_ops.shape(inner_rt.row_splits,
                                out_type=inner_rt.row_splits.dtype)[0]
    return ragged_tensor.RaggedTensor.from_uniform_row_length(
        inner_rt, nsplits - 1, nrows=1, validate=False)

  # Slicing a range of rows: first slice the outer dimension, and then
  # call `_ragged_getitem_inner_dimensions` to handle the inner keys.
  if isinstance(row_key, slice):
    sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key)
    if rt_input.uniform_row_length is not None:
      # If the inner dimension has uniform_row_length, then preserve it (by
      # re-wrapping the values in a new RaggedTensor).  Note that the row
      # length won't have changed, since we're slicing a range of rows (and not
      # slicing the rows themselves).
      sliced_rt_input = ragged_tensor.RaggedTensor.from_uniform_row_length(
          sliced_rt_input.values, rt_input.uniform_row_length,
          nrows=sliced_rt_input.nrows())
    return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys)

  # Indexing a single row: slice values to get the indicated row, and then
  # use a recursive call to __getitem__ to handle the inner keys.
  starts = rt_input.row_splits[:-1]
  limits = rt_input.row_splits[1:]
  if context.executing_eagerly():
    # In python, __getitem__ should throw IndexError for out of bound
    # indices.  This lets iteration (e.g. `for elem in rt: ...`) terminate
    # correctly.  In graph mode the bound is only known when the graph runs,
    # so the check is only possible when executing eagerly.  Keys that cannot
    # be converted to int are left for the slicing below to reject.
    try:
      row_index = int(row_key)
    except (TypeError, ValueError):
      row_index = None
    if row_index is not None:
      nrows = len(starts)
      # Mirror Python sequence semantics: valid keys are [-nrows, nrows).
      if row_index >= nrows or row_index < -nrows:
        raise IndexError("Row key {} out of bounds".format(row_key))
  row = rt_input.values[starts[row_key]:limits[row_key]]
  return row.__getitem__(inner_keys)
190
191
def _slice_ragged_row_dimension(rt_input, row_key):
  """Slice the outer dimension of `rt_input` according to the given `slice`.

  Args:
    rt_input: The `RaggedTensor` to slice.
    row_key: The `slice` object that should be used to slice `rt_input`.
      Its `start`, `stop` and `step` may be Python ints, scalar integer
      `Tensor`s, or None.

  Returns:
    A `RaggedTensor` containing the indicated slice of `rt_input`.
  """
  if row_key.start is None and row_key.stop is None and row_key.step is None:
    return rt_input

  # Use row_key to slice the starts & limits (this honors negative indices
  # and strides exactly like slicing any other 1-D Tensor).
  new_starts = rt_input.row_splits[:-1][row_key]
  new_limits = rt_input.row_splits[1:][row_key]

  # Only take the contiguous fast path when the step is statically known to
  # be 1.  `row_key.step` may be a Tensor, and `tensor == 1` yields a Tensor
  # whose truth value cannot be evaluated in graph mode (e.g. inside
  # tf.function).  A Tensor step therefore uses the strided path below, which
  # is also correct when the step happens to be 1.
  step = row_key.step
  if step is None or (not isinstance(step, ops.Tensor) and step == 1):
    # The selected rows form a single continuous span of `rt_input.values`.
    # Construct the new splits.  If new_starts and new_limits are empty,
    # then this reduces to [0].  Otherwise, this reduces to:
    #   concat([[new_starts[0]], new_limits])
    zero_pad = array_ops.zeros([1], rt_input.row_splits.dtype)
    new_splits = array_ops.concat(
        [zero_pad[array_ops.size(new_starts):], new_starts[:1], new_limits],
        axis=0)
    values_start = new_splits[0]
    values_limit = new_splits[-1]
    return ragged_tensor.RaggedTensor.from_row_splits(
        rt_input.values[values_start:values_limit], new_splits - values_start,
        validate=False)

  # Strided slice: the selected rows are not adjacent, so gather each row's
  # values.  The row stride is already reflected in new_starts/new_limits;
  # within a row the values are contiguous, hence a value step of 1.
  return _build_ragged_tensor_from_value_ranges(new_starts, new_limits, 1,
                                                rt_input.values)
230
231
def _ragged_getitem_inner_dimensions(rt_input, key_list):
  """Retrieve inner dimensions, keeping outermost dimension unchanged.

  Args:
    rt_input: The `RaggedTensor` or `Tensor` from which a piece should be
      extracted.
    key_list: The __getitem__ keys for slicing the inner dimensions.

  Returns:
    A `RaggedTensor`, or a `Tensor` when no ragged dimensions remain (for
    example when `rt_input` is a `Tensor`, or when a uniform inner dimension
    is indexed away).

  Raises:
    ValueError: If key_list is not supported (e.g. it indexes a single
      element of a ragged inner dimension).
    TypeError: If a slice in key_list has non-integer offsets.
  """
  if not key_list:
    return rt_input

  # Dense input: prepend a full slice so the outermost dimension is kept.
  if isinstance(rt_input, ops.Tensor):
    return rt_input.__getitem__([slice(None, None, None)] + key_list)

  column_key = key_list[0]
  if column_key is Ellipsis:
    expanded_key_list = _expand_ellipsis(key_list, rt_input.values.shape.ndims)
    return _ragged_getitem_inner_dimensions(rt_input, expanded_key_list)

  # Adding a new axis to a ragged inner dimension: recursively get the inner
  # dimensions of rt_input with key_list[1:], and then wrap the result in a
  # RaggedTensor that puts each value in its own row.
  if column_key is array_ops.newaxis:
    inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
    if isinstance(inner_rt, ops.Tensor):
      # The remaining keys indexed away every ragged dimension, so the result
      # is dense and has no row_splits.  Its outermost dimension is unchanged,
      # so the new axis goes immediately after it.
      return array_ops.expand_dims(inner_rt, 1)
    nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
    if nsplits.value is not None:
      nsplits = nsplits.value
    else:
      nsplits = array_ops.shape(inner_rt.row_splits,
                                out_type=inner_rt.row_splits.dtype)[0]
    return ragged_tensor.RaggedTensor.from_uniform_row_length(
        inner_rt, 1, nrows=nsplits - 1, validate=False)

  # Slicing a range of columns in a ragged inner dimension.  We use a
  # recursive call to process the values, and then assemble a RaggedTensor
  # with those values.
  if isinstance(column_key, slice):
    if (column_key.start is None and column_key.stop is None and
        column_key.step is None):
      # Trivial slice: recursively process all values, & splits is unchanged.
      return rt_input.with_values(
          _ragged_getitem_inner_dimensions(rt_input.values, key_list[1:]))
    else:
      if not (isinstance(column_key.start, (ops.Tensor, int, type(None))) and
              isinstance(column_key.stop, (ops.Tensor, int, type(None)))):
        raise TypeError("slice offsets must be integers or None")

      # Nontrivial slice: use ragged_gather to extract the indicated slice as
      # a new RaggedTensor (inner_rt), and then recursively process its values.
      starts = rt_input.row_splits[:-1]
      limits = rt_input.row_splits[1:]
      step = 1 if column_key.step is None else column_key.step
      # Per-row clamping bounds, mirroring Python slice clamping: for a
      # negative step the valid range shifts down by one.
      lower_bound = _if_ge_zero(step, lambda: starts, lambda: starts - 1)
      upper_bound = _if_ge_zero(step, lambda: limits, lambda: limits - 1)
      # inner_rt_starts[i] = index to start gathering for row i.
      if column_key.start is None:
        inner_rt_starts = _if_ge_zero(step, lambda: starts, lambda: limits - 1)
      else:
        start_offset = math_ops.cast(column_key.start, starts.dtype)
        inner_rt_starts = _if_ge_zero(
            column_key.start,
            lambda: math_ops.minimum(starts + start_offset, upper_bound),
            lambda: math_ops.maximum(limits + start_offset, lower_bound))
      # inner_rt_limits[i] = index to stop gathering for row i.
      if column_key.stop is None:
        inner_rt_limits = _if_ge_zero(step, lambda: limits, lambda: starts - 1)
      else:
        stop_offset = math_ops.cast(column_key.stop, starts.dtype)
        inner_rt_limits = _if_ge_zero(
            column_key.stop,
            lambda: math_ops.minimum(starts + stop_offset, upper_bound),
            lambda: math_ops.maximum(limits + stop_offset, lower_bound))
      inner_rt = _build_ragged_tensor_from_value_ranges(
          inner_rt_starts, inner_rt_limits, column_key.step, rt_input.values)
      # If the row dimension is uniform, then calculate the new
      # uniform_row_length, and rebuild inner_rt using that uniform_row_lengths.
      if rt_input.uniform_row_length is not None:
        new_row_length = _slice_length(rt_input.uniform_row_length, column_key)
        inner_rt = ragged_tensor.RaggedTensor.from_uniform_row_length(
            inner_rt.values, new_row_length, rt_input.nrows())
      return inner_rt.with_values(
          _ragged_getitem_inner_dimensions(inner_rt.values, key_list[1:]))

  # Indexing a single column in a ragged inner dimension: raise an Exception.
  # See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing
  # into a ragged inner dimension is problematic.
  if rt_input.uniform_row_length is None:
    raise ValueError("Cannot index into an inner ragged dimension.")

  # Indexing a single column in a uniform inner dimension: check that the
  # given index is in-bounds, and then use a strided slice over rt_input.values
  # to take the indicated element from each row.
  row_length = rt_input.uniform_row_length
  column_key = math_ops.cast(column_key, row_length.dtype)
  oob_err_msg = "Index out of bounds when indexing into a ragged tensor"
  oob_checks = [
      check_ops.assert_greater_equal(
          column_key, -row_length, message=oob_err_msg),
      check_ops.assert_less(column_key, row_length, message=oob_err_msg),
  ]
  with ops.control_dependencies(oob_checks):
    # Normalize a negative column index to its non-negative equivalent.
    offset = _if_ge_zero(column_key, lambda: column_key,
                         lambda: row_length + column_key)
    sliced_rt = rt_input.values[offset::row_length]
    return _ragged_getitem_inner_dimensions(sliced_rt, key_list[1:])
343
344
def _slice_length(value_length, slice_key):
  """Counts how many elements `slice_key` selects from a value of a length.

  Computes the equivalent of `len(range(value_length)[slice_key])` using
  TensorFlow ops, so `value_length` may be a Tensor.

  Args:
    value_length: Scalar int `Tensor`: the length of the value being sliced.
    slice_key: A `slice` object used to slice elements from the value.

  Returns:
    The number of elements in the sliced value, with dtype
    `value_length.dtype`.
  """
  # The count could be derived arithmetically from (stop - start) // step,
  # but that needs extra ops for bounds checks, negative indices and negative
  # steps.  This operation is expected to be uncommon, so instead slice a
  # throwaway boolean vector of the right length and measure what is left.
  dummy = array_ops.zeros(value_length, dtype=dtypes.bool)
  selected = dummy[slice_key]
  return array_ops.size(selected, out_type=value_length.dtype)
364
365
def _expand_ellipsis(key_list, num_remaining_dims):
  """Expands the ellipsis at the start of `key_list` by one step.

  Assumes that `key_list[0]` is `Ellipsis`.  If the ellipsis corresponds to
  zero dimensions, it is removed.  Otherwise a `slice(None, None, None)` is
  prepended, and the caller recurses on the result.

  Args:
    key_list: The arguments to `__getitem__()`.
    num_remaining_dims: The number of dimensions remaining, or None if
      unknown.

  Returns:
    A copy of `key_list` with the ellipsis expanded.

  Raises:
    ValueError: If `num_remaining_dims` is None.
    IndexError: If there are too many elements in `key_list`.
  """
  if num_remaining_dims is None:
    raise ValueError("Ellipsis not supported for unknown shape RaggedTensors")
  # newaxis entries don't consume a dimension.  The Ellipsis itself is
  # counted, which is why the limit below is `num_remaining_dims + 1`.
  num_indices = len([k for k in key_list if k is not array_ops.newaxis])
  max_indices = num_remaining_dims + 1
  if num_indices > max_indices:
    raise IndexError("Too many indices for RaggedTensor")
  if num_indices == max_indices:
    return key_list[1:]
  return [slice(None, None, None)] + key_list
392
393
def _tensors_in_key_list(key_list):
  """Yields every `Tensor` that appears in the given slice spec.

  Recurses through lists/tuples and into the `start`, `stop` and `step`
  fields of `slice` objects.  Any other value yields nothing.
  """
  if isinstance(key_list, ops.Tensor):
    yield key_list
  elif isinstance(key_list, (list, tuple)):
    for item in key_list:
      yield from _tensors_in_key_list(item)
  elif isinstance(key_list, slice):
    for bound in (key_list.start, key_list.stop, key_list.step):
      yield from _tensors_in_key_list(bound)
409
410
def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
  """Returns a `RaggedTensor` containing the specified sequences of values.

  Returns a RaggedTensor `output` where:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = gather(values, range(starts[i], limits[i], step))
  ```

  Requires that `starts.shape == limits.shape`, and that every index produced
  by those ranges is a valid index into `values`.  With a negative `step`,
  row `i` runs from `starts[i]` down to (but excluding) `limits[i]`.

  Args:
    starts: 1D integer Tensor specifying the start indices for the sequences of
      values to include.
    limits: 1D integer Tensor specifying the limit indices for the sequences of
      values to include.
    step: Integer value (or None, meaning 1) specifying the step size for
      strided slices.
    values: The set of values to select from (a `Tensor` or `RaggedTensor`).

  Returns:
    A `RaggedTensor` whose row_splits have dtype `starts.dtype`.

  Raises:
    TypeError: If `step` is not an integer.
  """
  stride = ops.convert_to_tensor(1 if step is None else step, name="step")
  if not stride.dtype.is_integer:
    raise TypeError("slice strides must be integers or None")
  stride = math_ops.cast(stride, starts.dtype)

  # One ragged row of indices per (start, limit) pair; its row_splits become
  # the row_splits of the result.
  index_rt = ragged_math_ops.range(starts, limits, stride,
                                   row_splits_dtype=starts.dtype)

  # Ragged values need the ragged-aware gather; dense values use the plain one.
  if isinstance(values, ragged_tensor.RaggedTensor):
    gather_fn = ragged_gather_ops.gather
  else:
    gather_fn = array_ops.gather
  gathered = gather_fn(params=values, indices=index_rt.values)

  return index_rt.with_values(gathered)
459
460
def _if_ge_zero(value, true_fn, false_fn):
  """Returns `true_fn()` if `value >= 0`, otherwise `false_fn()`.

  A `cond` op is only built when `value` is a Tensor whose value cannot be
  determined statically.  Otherwise the branch is chosen in Python.
  """
  if isinstance(value, ops.Tensor):
    static_value = tensor_util.constant_value(value)
    if static_value is None:
      return control_flow_ops.cond(value >= 0, true_fn, false_fn)
    value = static_value
  return true_fn() if value >= 0 else false_fn()
474