# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for ragged tensors."""

import functools
import typing

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


#===============================================================================
# ragged.range
#===============================================================================
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
@dispatch.add_dispatch_support
def range(starts,
          limits=None,
          deltas=1,
          dtype=None,
          name=None,
          row_splits_dtype=dtypes.int64):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] >= limits[i] and deltas[i] > 0`, then `output[i]` will be an
  empty list.  Similarly, if `starts[i] <= limits[i] and deltas[i] < 0`, then
  `output[i]` will be an empty list.  This behavior is consistent with the
  Python `range` function, but differs from the `tf.range` op, which returns
  an error for these cases.

  Examples:

  >>> tf.ragged.range([3, 5, 2]).to_list()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12]).to_list()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12], 2).to_list()
  [[0, 2], [], [8, 10]]

  The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
  The vector inputs must all have the same size.  Scalar inputs are broadcast
  to match the size of the vector inputs.

  Args:
    starts: Vector or scalar `Tensor`.  Specifies the first entry for each range
      if `limits` is not `None`; otherwise, specifies the range limits, and the
      first entries default to `0`.
    limits: Vector or scalar `Tensor`.  Specifies the exclusive upper limits for
      each range.
    deltas: Vector or scalar `Tensor`.  Specifies the increment for each range.
      Defaults to `1`.
    dtype: The type of the elements of the resulting tensor.  If not specified,
      then a value is chosen based on the other args.
    name: A name for the operation.
    row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
      tensor.  One of `tf.int32` or `tf.int64`.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # Infer dtype if not explicitly provided.
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(
        starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(
        result.rt_dense_values, result.rt_nested_splits, validate=False)


def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Infers a matching dtype for tensors, and casts them to that dtype."""
  assert all(t.dtype in dtype_hierarchy for t in tensors)
  inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index)
  return [math_ops.cast(t, inferred_dtype) for t in tensors]
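
# Illustrative sketch (assumes `import tensorflow as tf`): mixed input dtypes
# to `tf.ragged.range` are promoted to the "largest" dtype in the hierarchy
# that `range` passes to `_infer_matching_dtype`.  E.g., int32 starts with
# float32 limits produce float32 values:
#
#   >>> tf.ragged.range([1, 2], [3.5, 4.5]).dtype
#   tf.float32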


ops.no_gradient('RaggedRange')

#===============================================================================
# ragged_segment_<AGGREGATE>
#===============================================================================

# Docstring template used for the ragged_segment_<AGGREGATE> ops.
_RAGGED_SEGMENT_DOCSTRING = """\
Computes the %(combination)s along segments of a RaggedTensor.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by taking the %(combination)s of all rows of `data`
  whose corresponding `segment_id` is `i`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    data: A `RaggedTensor` containing the values to combine.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      Values must be greater than or equal to zero, and less than
      `num_segments`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar specifying the number of
      distinct segment ids.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
"""


def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              separator=None,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`.  The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row.  Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    separator: An optional string. Defaults to None. The separator to use when
      joining. Only used for string types.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
  """
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    if separator is not None:
      # A separator is only used by string ops such as unsorted_segment_join.
      return unsorted_segment_op(data, segment_ids, num_segments, separator,
                                 name)
    else:
      return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         separator)

    # Find the length of each row in data.  (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have.  The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i] for all i
    # where segment_ids[i] == id)`.  (shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ],
                                     axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values` (i.e., the target output values index
    # for each data values index).
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1], separator)
    return ragged_tensor.RaggedTensor.from_row_splits(
        output_values, output_splits, validate=False)
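
# Illustrative sketch (assumes `import tensorflow as tf`): via the dispatchers
# below, the standard unsorted-segment ops accept RaggedTensors directly.
# Rows sharing a segment id are combined elementwise, and the output row is as
# long as the longest input row for that segment:
#
#   >>> data = tf.ragged.constant([[1, 2], [3, 4, 5], [6]])
#   >>> tf.math.unsorted_segment_sum(data, [0, 0, 1], num_segments=2).to_list()
#   [[4, 6, 5], [6]]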


@dispatch.dispatch_for_api(math_ops.unsorted_segment_sum)
def segment_sum(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_sum,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentSum'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_prod)
def segment_prod(data: ragged_tensor.RaggedOrDense,
                 segment_ids: ragged_tensor.RaggedOrDense,
                 num_segments,
                 name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_prod,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentProd'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_min)
def segment_min(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_min,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMin'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_max)
def segment_max(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_max,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMax'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_mean)
def segment_mean(data: ragged_tensor.RaggedOrDense,
                 segment_ids: ragged_tensor.RaggedOrDense,
                 num_segments,
                 name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count


@dispatch.dispatch_for_api(math_ops.unsorted_segment_sqrt_n)
def segment_sqrt_n(data: ragged_tensor.RaggedOrDense,
                   segment_ids: ragged_tensor.RaggedOrDense,
                   num_segments,
                   name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values /
                                    math_ops.sqrt(count.flat_values))
    else:
      return total / math_ops.sqrt(count)


def _set_ragged_segment_docstring(func, combination, combined):
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % dict(
      combination=combination, combined=combined)


_set_ragged_segment_docstring(segment_sum, 'sum', 'summed')
_set_ragged_segment_docstring(segment_prod, 'product', 'multiplied')
_set_ragged_segment_docstring(segment_min, 'minimum', 'minimized')
_set_ragged_segment_docstring(segment_max, 'maximum', 'maximized')
_set_ragged_segment_docstring(segment_mean, 'mean', 'averaged')
_set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
                              'summed')

#===============================================================================
# ragged_reduce_<AGGREGATE>
#===============================================================================

# Docstring template used for ragged_reduce_<AGGREGATE> ops.
_RAGGED_REDUCE_DOCSTRING = """\
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.

  Reduces `input_tensor` along the dimensions given in `axis` by taking the
  %(combination)s of values.  If a reduced dimension has no elements for
  some index, then the value for that index will be %(default)s.

  The rank of the tensor is reduced by `1` for each entry in `axis`.  If
  `axis` is not specified, then all dimensions are reduced, and a scalar
  value is returned.
  Args:
    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
    axis: The dimensions to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value.  Must be in
      the range `[0, input_tensor.rank]`.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `input_tensor`, and its shape is given by removing
    the dimensions specified in `axis` from `input_tensor.shape`.  The
    `ragged_rank` of the returned tensor is given by subtracting any ragged
    dimensions specified in `axis` from `input_tensor.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  #### Example:
    %(example)s
"""
_RAGGED_REDUCE_SUM_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_sum(rt, axis=0).numpy()  # = [3+1+9+2, 1+5+6, 4]
    array([15, 12, 4], dtype=int32)
    >>> tf.reduce_sum(rt, axis=1).numpy()  # = [3+1+4, 1+5, 9, 2+6]
    array([8, 6, 9, 8], dtype=int32)
"""
_RAGGED_REDUCE_PROD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_prod(rt, axis=0).numpy()  # = [3*1*9*2, 1*5*6, 4]
    array([54, 30, 4], dtype=int32)
    >>> tf.reduce_prod(rt, axis=1).numpy()  # = [3*1*4, 1*5, 9, 2*6]
    array([12, 5, 9, 12], dtype=int32)
"""
_RAGGED_REDUCE_MIN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_min(rt, axis=0).numpy()
    array([1, 1, 4], dtype=int32)
    >>> tf.reduce_min(rt, axis=1).numpy()
    array([1, 1, 9, 2], dtype=int32)
"""
_RAGGED_REDUCE_MAX_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_max(rt, axis=0).numpy()
    array([9, 6, 4], dtype=int32)
    >>> tf.reduce_max(rt, axis=1).numpy()
    array([4, 5, 9, 6], dtype=int32)
"""
_RAGGED_REDUCE_MEAN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_mean(rt, axis=0).numpy()
    array([3.75, 4.  , 4. ])
    >>> tf.reduce_mean(rt, axis=1).numpy()
    array([2.66666667, 3.  , 9.  , 4.  ])
"""
_RAGGED_REDUCE_VARIANCE_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 1, 4], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_variance(rt, axis=0).numpy()
    array([1.25, 0., 0.])
    >>> tf.math.reduce_variance(rt, axis=1).numpy()
    array([2., 0.25, 0., 2.25])
"""
_RAGGED_REDUCE_STD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 0], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_std(rt, axis=0).numpy()
    array([1.11803399, 0.47140452])
    >>> tf.math.reduce_std(rt, axis=1).numpy()
    array([0.5, 0.5, 0., 1.5])
"""
_RAGGED_REDUCE_ALL_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_all(rt, axis=0).numpy()
    array([False,  True, False,  True])
    >>> tf.reduce_all(rt, axis=1).numpy()
    array([ True, False, False])
"""
_RAGGED_REDUCE_ANY_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_any(rt, axis=0).numpy()
    array([ True,  True, False,  True])
    >>> tf.reduce_any(rt, axis=1).numpy()
    array([ True,  True,  True])
"""


def ragged_reduce_aggregate(reduce_op,
                            unsorted_segment_op,
                            rt_input,
                            axis,
                            keepdims,
                            separator=None,
                            name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`.  The rank of the
  tensor is reduced by 1 for each entry in `axis`.  If `axis` is not specified,
  then all dimensions are reduced, and a scalar value is returned.

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results.  (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions.  Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions.  Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
      given set of axes), or a `Tensor` with a constant value.  Must be in the
      range `[0, rt_input.rank)`.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: An optional string. Defaults to None. The separator to use when
      joining. Must be None for non-string data types (a non-None separator
      implies that string ops will be used).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values.  The returned tensor
    has the same dtype as `rt_input`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`.  The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  """
  if not ragged_tensor.is_ragged(rt_input):
    if separator is None:
      return reduce_op(rt_input, axis, keepdims=keepdims, name=name)
    else:
      # When separator is not None, we infer that the dtype is string and
      # that reduce_join will be called.
      return reduce_op(
          rt_input, axis, keepdims=keepdims, name=name, separator=separator)

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    result = reduce_op(rt_input.flat_values, None, keepdims=keepdims, name=name)
    if keepdims:
      # Expand the result to the input number of dimensions.
      for _ in rt_input.shape[1:]:
        result = array_ops.expand_dims(result, axis=0)
    return result

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # When reducing multiple axes, we reduce one at a time (see below), so
        # negative axes must be converted to positive up front; otherwise,
        # sorting a mix of negative and positive axes would give the wrong
        # reduction order.  See GitHub issue 27497.
        axis = [
            array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i,
                                        'rank(input_tensor)')
            for i, a in enumerate(axis)
        ]
        # When reducing multiple axes, just reduce one at a time.  This is less
        # efficient, and only works for associative ops.  (In particular, it
        # does not work for reduce_mean.)  However, reducing multiple axes at
        # once will probably require a nontrivial c++ op.
        axis = sorted(axis)
        inner_reduced = ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                                rt_input, axis[-1], keepdims,
                                                separator)
        return ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                       inner_reduced, axis[:-1], keepdims,
                                       separator)

    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    axis = array_ops.get_positive_axis(
        axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      segment_ids = range(row_lengths).values
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=0)
      return result
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=1)
      return result
    else:
      # out[i_0, ..., i_[axis-1], i_[axis+1], ..., i_N] =
      #     sum_{j} rt_input[i_0, ..., i_[axis-1], j, i_[axis+1], ..., i_N]
      return rt_input.with_values(
          ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                  rt_input.values, axis - 1, keepdims,
                                  separator))
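
# Illustrative sketch (assumes `import tensorflow as tf`): multiple axes are
# reduced one at a time, so reducing a 2-D RaggedTensor over both axes
# collapses it to a scalar:
#
#   >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
#   >>> tf.reduce_sum(rt, axis=[0, 1]).numpy()
#   31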


@dispatch.dispatch_for_api(math_ops.reduce_sum)
def reduce_sum(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""

  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_sum,
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceSum'))


@dispatch.dispatch_for_api(math_ops.reduce_prod)
def reduce_prod(input_tensor: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_prod,
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceProd'))


@dispatch.dispatch_for_api(math_ops.reduce_min)
def reduce_min(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_min,
      unsorted_segment_op=math_ops.unsorted_segment_min,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMin'))


@dispatch.dispatch_for_api(math_ops.reduce_max)
def reduce_max(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_max,
      unsorted_segment_op=math_ops.unsorted_segment_max,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMax'))


@dispatch.dispatch_for_api(math_ops.reduce_mean)
def reduce_mean(input_tensor: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)
    if ragged_tensor.is_ragged(input_tensor):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits,
          validate=False)
    else:
      ones = array_ops.ones_like(input_tensor)
    count = reduce_sum(ones, axis, keepdims)
    if ragged_tensor.is_ragged(total):
      return ragged_tensor.RaggedTensor.from_nested_row_splits(
          total.flat_values / count.flat_values,
          total.nested_row_splits,
          validate=False)
    else:
      return total / count


@dispatch.dispatch_for_api(math_ops.reduce_variance)
def reduce_variance(input_tensor: ragged_tensor.Ragged,
                    axis=None,
                    keepdims=False,
                    name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceVariance', [input_tensor, axis]):
    input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input_tensor, name='input_tensor')
    if input_tensor.dtype.is_complex:
      raise ValueError(
          'reduce_variance is not supported for RaggedTensors with complex dtypes.'
      )
    square_of_input = math_ops.square(input_tensor)
    mean_of_square = reduce_mean(square_of_input, axis=axis, keepdims=keepdims)
    mean = reduce_mean(input_tensor, axis=axis, keepdims=keepdims)
    square_of_mean = math_ops.square(mean)
    # Note: the above method of computing variance is not numerically stable,
    # and can result in negative variances.  Here we clip to >= 0.
    return math_ops.maximum(mean_of_square - square_of_mean, 0)


@dispatch.dispatch_for_api(math_ops.reduce_std)
def reduce_std(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=False,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceStd', [input_tensor, axis]):
    variance = reduce_variance(input_tensor, axis=axis, keepdims=keepdims)
    return math_ops.sqrt(variance)


def _cast(input_tensor, dtype):
  return ragged_functional_ops.map_flat_values(math_ops.cast, input_tensor,
                                               dtype)


@dispatch.dispatch_for_api(math_ops.reduce_all)
def reduce_all(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    return _cast(
        reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


@dispatch.dispatch_for_api(math_ops.reduce_any)
def reduce_any(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    return _cast(
        reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


def _set_ragged_reduce_docstring(func, combination, combined, default, example):
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % dict(
      combination=combination,
      combined=combined,
      default=default,
      example=example)


_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_variance, 'variance', 'averaged', 'NaN',
                             _RAGGED_REDUCE_VARIANCE_EXAMPLE)
_set_ragged_reduce_docstring(reduce_std, 'std', 'averaged', 'NaN',
                             _RAGGED_REDUCE_STD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)


#===============================================================================
# ragged.matmul
#===============================================================================
@dispatch.dispatch_for_api(math_ops.matmul)
def matmul(a: ragged_tensor.RaggedOrDense,
           b: ragged_tensor.RaggedOrDense,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           a_is_sparse=False,
           b_is_sparse=False,
           output_type=None,
           name=None):
  """Multiplies matrix `a` by matrix `b`.

  If all transpose or adjoint attributes are `False` then:

  ```
  output[..., i, j] = sum_k (a[..., i, k] * b[..., k, j]), for all indices i, j.
  ```

  The inputs `a` and `b` must have `rank >= 2`, where the outermost `rank - 2`
  dimensions are batch dimensions.  The inputs must have the same dtype.  See
  `tf.matmul` for more information.

  Args:
    a: `tf.Tensor` or `RaggedTensor` with `rank > 1`.
    b: `tf.Tensor` or `RaggedTensor` with same type and rank as `a`.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated & transposed before multiplication.
    adjoint_b: If `True`, `b` is conjugated & transposed before multiplication.
    a_is_sparse: If `True`, optimize assuming `a` is mostly zero.
    b_is_sparse: If `True`, optimize assuming `b` is mostly zero.
    output_type: The output datatype (optional).
    name: Name for the operation (optional).

  Returns:
    A `Tensor` or `RaggedTensor` with the same rank and shape as `a`, where
    each inner-most matrix is the product of the corresponding matrices in `a`
    and `b`.
  """
  if transpose_a and adjoint_a:
    raise ValueError('Only one of transpose_a and adjoint_a can be True.')
  if transpose_b and adjoint_b:
    raise ValueError('Only one of transpose_b and adjoint_b can be True.')

  kwargs = dict(
      transpose_a=transpose_a,
      transpose_b=transpose_b,
      adjoint_a=adjoint_a,
      adjoint_b=adjoint_b,
      a_is_sparse=a_is_sparse,
      b_is_sparse=b_is_sparse,
      output_type=output_type)

  with ops.name_scope(name, 'RaggedMatMul', [a, b]) as name:
    a = ragged_tensor.convert_to_tensor_or_ragged_tensor(a, name='a')
    b = ragged_tensor.convert_to_tensor_or_ragged_tensor(b, name='b')

    a_is_ragged = isinstance(a, ragged_tensor.RaggedTensor)
    b_is_ragged = isinstance(b, ragged_tensor.RaggedTensor)
    if not (a_is_ragged or b_is_ragged):
      return math_ops.matmul(a, b, **kwargs)

    if a.dtype != b.dtype:
      raise ValueError('`a` and `b` must have the same dtype.')

    # TODO(edloper): Support broadcasting inputs.  (Broadcast support is not
    # documented by https://www.tensorflow.org/api_docs/python/tf/linalg/matmul,
    # but it is supported by the op.)

    # Find the rank of the input tensors.
    if a.shape.rank is None:
      if b.shape.rank is None:
        raise ValueError('matmul requires at least one input to have known '
                         'rank if either input is ragged.')
      rank = b.shape.rank
    else:
      if b.shape.rank is not None and a.shape.rank != b.shape.rank:
        raise ValueError('`a` and `b` must have the same rank.')
      rank = a.shape.rank

    # At least one of `a` and `b` is ragged; and ragged tensors always have
    # rank>=2.
    if rank < 2:
      # This can happen if e.g. `a` is a 1D dense tensor and `b` is a
      # ragged tensor with unknown rank.  Since ragged tensors always have
      # `rank>=2`, this implies that `a` and `b` have different ranks.
      raise ValueError('`a` and `b` must have the same rank.')

    # Rank>3: We have multiple batch dimensions.  Merge them into a single
    # batch dimension, recursively call `matmul`, and then restore the original
    # batch dimension (using a.row_splits).
    if rank > 3:
      shape_err = 'Batch dimensions of `a` and `b` do not have the same size.'
      if not a_is_ragged:
        a = ragged_tensor.RaggedTensor.from_tensor(a, ragged_rank=1)
      if not b_is_ragged:
        b = ragged_tensor.RaggedTensor.from_tensor(b, ragged_rank=1)
      with ops.control_dependencies([
          check_ops.assert_equal(a.row_splits, b.row_splits, message=shape_err)
      ]):
        flat_result = matmul(a.values, b.values, **kwargs)
        return a.with_values(flat_result)

    if rank == 2:
      return _matmul_2d(a, b, **kwargs)

    assert rank == 3  # I.e., we have a single batch dimension.

    a_ragged_rank = a.ragged_rank if a_is_ragged else 0
    if a_ragged_rank == 1 and not (b_is_ragged or transpose_a or adjoint_a):
      # If `a.shape=[B, (I), J]` and `b.shape=[B, J, K]`, then we can compute
      # the result with a single dense `matmul`.
      return _matmul_3d_with_batch_dim_folding(a, b, **kwargs)
    else:
      # Otherwise, fall back on using `map_fn`.
      return _matmul_3d_with_map_fn(a, b, **kwargs)
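
# Illustrative sketch (assumes `import tensorflow as tf`): with this dispatch,
# `tf.linalg.matmul` multiplies the corresponding (possibly ragged) 2-D
# matrices in each batch:
#
#   >>> a = tf.ragged.constant([[[1., 2.]], [[3.], [4.]]])
#   >>> b = tf.ragged.constant([[[5.], [6.]], [[7.]]])
#   >>> tf.linalg.matmul(a, b).to_list()
#   [[[17.0]], [[21.0], [28.0]]]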


def _matmul_2d(a, b, **kwargs):
  """Multiplies potentially ragged 2D tensors.

  Args:
    a: A 2D Tensor or RaggedTensor with `shape=[I, J]`
    b: A 2D Tensor or RaggedTensor with `shape=[J, K]`
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 2D Tensor with `shape=[I, K]`.
  """
  # Multiplying `a` and `b` is only well-defined if `a` and `b` are
  # actually uniform (and just happened to be stored as ragged tensors).
  # Check that they're uniform, and convert them to tf.Tensor.
  ragged_err = ('The matrices in `a` and `b` may not be '
                'ragged in their innermost dimension.')
  checks = []
  if isinstance(a, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(a.flat_values)
    a = a.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(a), message=ragged_err))
  if isinstance(b, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(b.flat_values)
    b = b.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(b), message=ragged_err))
  with ops.control_dependencies(checks):
    return math_ops.matmul(a, b, **kwargs)


def _matmul_3d_with_map_fn(a, b, **kwargs):
  """Multiplies batches of 2D matrices using map_fn.

  `output[n, i, k] = sum_j(a[n, i, j] * b[n, j, k])` (for all `n`, `i`, `k`).

  Requires that each row `a[n, i]` has length `b[n].nrows()` (for all `n` and
  `i`), i.e. that the inner dimensions of each batch member match.

  Args:
    a: A 3D Tensor or RaggedTensor with `shape=[B, I, J]`, where dimensions `I`
      and `J` may be ragged.
    b: A 3D Tensor or RaggedTensor with `shape=[B, J, K]`, where dimensions `J`
      and `K` may be ragged.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 3D RaggedTensor with `shape=[B, (I), (K)]`.
  """
  # Determine the ragged rank of the result.  In the normal case, we have:
  #   [B, I, J] * [B, J, K] -> [B, I, K]
  # Or if we're using transpose_b, then we have:
  #   [B, I, J] * [B, K, J] -> [B, I, K]
  # In either case, output_ragged_rank=2 iff the K dimension is ragged.
  if (isinstance(b, ragged_tensor.RaggedTensor) and
      (b.ragged_rank == 2 or kwargs.get('transpose_b') or
       kwargs.get('adjoint_b'))):
    output_ragged_rank = 2
  else:
    output_ragged_rank = 1

  def single_batch_matmul(x):
    out = _matmul_2d(x[0], x[1], **kwargs)
    if output_ragged_rank == 2:
      out = ragged_tensor.RaggedTensor.from_tensor(out)
    return out

  fn_out_shape = None  # Figure out proper shape.
  row_splits_dtype = (
      a.row_splits.dtype
      if isinstance(a, ragged_tensor.RaggedTensor) else b.row_splits.dtype)
  output_type = kwargs['output_type']
  if output_type is None:
    output_type = a.dtype
  spec = ragged_tensor.RaggedTensorSpec(
      shape=fn_out_shape,
      dtype=output_type,
      ragged_rank=output_ragged_rank - 1,
      row_splits_dtype=row_splits_dtype)
  result = map_fn.map_fn(
      single_batch_matmul, elems=(a, b), fn_output_signature=spec)

  # map_fn loses shape information; restore it, where possible.
  # pylint: disable=protected-access
  if kwargs.get('transpose_a') or kwargs.get('adjoint_a'):
    result._set_shape(a.shape[:-2] + a.shape[-1:] + [None])
  else:
    result._set_shape(a.shape[:-2] + a.shape[-2:-1] + [None])
  if kwargs.get('transpose_b') or kwargs.get('adjoint_b'):
    result._set_shape(b.shape[:-2] + [None] + b.shape[-2:-1])
  else:
    result._set_shape(b.shape[:-2] + [None] + b.shape[-1:])

  return result


def _matmul_3d_with_batch_dim_folding(a, b, **kwargs):
  """Multiply batches of 2D matrices where only `a.shape[1]` is ragged.

  Args:
    a: A RaggedTensor with `shape=[B, (I), J]`.  (ragged_rank must be 1.)
    b: A Tensor with `shape=[B, J, K]`
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).
      transpose_a and adjoint_a must not be true.

  Returns:
    A RaggedTensor with `shape=[B, (I), K]`.
  """
  # reshaped_a.shape = [sum(i_1, i_2, ..., i_B), 1, J]
  reshaped_a = array_ops.expand_dims(a.values, 1)
  # reshaped_b.shape = [sum(i_1, i_2, ..., i_B), J, K]
  reshaped_b = array_ops.repeat(b, a.row_lengths(), axis=0)
  # flat_result.shape = [sum(i_1, i_2, ..., i_B), 1, K]
  flat_result = math_ops.matmul(reshaped_a, reshaped_b, **kwargs)
  # result.shape = [B, (I), K]
  return a.with_values(array_ops.squeeze(flat_result, axis=1))


#===============================================================================
# ragged.softmax
#===============================================================================
@dispatch.dispatch_for_api(nn_ops.softmax_v2)
def softmax(logits: ragged_tensor.Ragged, axis=None, name=None):
  """Computes softmax activations.

  Used for multi-class predictions. The sum of all outputs generated by softmax
  is 1.

  This function performs the equivalent of

      softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

  Example usage:

  >>> softmax = tf.nn.softmax([-1, 0., 1.])
  >>> softmax
  <tf.Tensor: shape=(3,), dtype=float32,
  numpy=array([0.09003057, 0.24472848, 0.66524094], dtype=float32)>
  >>> sum(softmax)
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  if axis is None:
    axis = -1

  with ops.name_scope(name, 'RaggedSoftmax', [logits]) as name:
    max_input = reduce_max(logits, axis=axis, keepdims=True)
    logits_exp = math_ops.exp(math_ops.subtract(logits, max_input))
    denominator = reduce_sum(logits_exp, axis=axis, keepdims=True)
    return math_ops.divide(logits_exp, denominator)
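
# Illustrative sketch (assumes `import tensorflow as tf`): each ragged row is
# normalized independently, so every row of the result sums to 1 (values shown
# approximately):
#
#   >>> rt = tf.ragged.constant([[0., 0.], [0., 0., 0.]])
#   >>> tf.nn.softmax(rt).to_list()
#   [[0.5, 0.5], [0.333..., 0.333..., 0.333...]]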


#===============================================================================
# ragged.add_n
#===============================================================================
@dispatch.dispatch_for_api(math_ops.add_n)
def add_n(inputs: typing.List[ragged_tensor.RaggedOrDense], name=None):
  """RaggedTensor implementation for tf.math.add_n."""
  if len(inputs) < 1:
    raise ValueError('tf.add_n: expected at least one input.')
  with ops.name_scope(name, 'RaggedAddN', inputs):
    return ragged_functional_ops.map_flat_values(math_ops.add_n, inputs)
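
# Illustrative sketch (assumes `import tensorflow as tf`): all inputs to
# `tf.math.add_n` must share the same ragged structure; values are added
# position-wise:
#
#   >>> x = tf.ragged.constant([[1, 2], [3]])
#   >>> y = tf.ragged.constant([[10, 20], [30]])
#   >>> tf.math.add_n([x, y]).to_list()
#   [[11, 22], [33]]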


#===============================================================================
# Ragged version of nn_ops.dropout
#===============================================================================
@dispatch.dispatch_for_api(nn_ops.dropout)
def dropout_v1(x: ragged_tensor.Ragged,
               keep_prob=None,
               noise_shape=None,
               seed=None,
               name=None,
               rate=None):
  """Ragged dispatch target for tf.nn.dropout."""
  if noise_shape is not None:
    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
  with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    return x.with_flat_values(
        nn_ops.dropout(
            x.flat_values, keep_prob=keep_prob, seed=seed, rate=rate))


@dispatch.dispatch_for_api(nn_ops.dropout_v2)
def dropout_v2(x: ragged_tensor.Ragged,
               rate,
               noise_shape=None,
               seed=None,
               name=None):
  """Ragged dispatch target for tf.nn.dropout."""
  if noise_shape is not None:
    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
  with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    return x.with_flat_values(
        nn_ops.dropout_v2(x.flat_values, rate=rate, seed=seed))


@dispatch.dispatch_for_api(nn_ops.stateless_dropout)
def stateless_dropout(x: ragged_tensor.Ragged,
                      rate,
                      seed,
                      rng_alg=None,
                      noise_shape=None,
                      name=None):
  """Ragged dispatch target for tf.nn.experimental.stateless_dropout."""
  if noise_shape is not None:
    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
  with ops.name_scope(name, 'RaggedNNStatelessDropout', [x, rate]):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    return x.with_flat_values(
        nn_ops.stateless_dropout(
            x.flat_values, rate=rate, seed=seed, rng_alg=rng_alg))
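
# Illustrative sketch (assumes `import tensorflow as tf`): dropout only acts
# on the flat values, so the ragged row structure is always preserved; with
# rate=0.0 every value is kept:
#
#   >>> rt = tf.ragged.constant([[1., 2.], [3.]])
#   >>> tf.nn.dropout(rt, rate=0.0).to_list()
#   [[1.0, 2.0], [3.0]]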


#===============================================================================
# Ragged version of Tensor.__eq__ and Tensor.__ne__
#===============================================================================
@dispatch.dispatch_for_api(math_ops.tensor_equals)
def tensor_equals(self: ragged_tensor.RaggedOrDense,
                  other: ragged_tensor.RaggedOrDense):
  """Ragged version of the operation invoked by `Tensor.__eq__`."""
  if other is None:
    return False
  elif _use_legacy_mode_for_tensor_equality(self):
    return self is other
  else:
    try:
      return math_ops.equal(self, other)
    except (errors.InvalidArgumentError, ValueError):
      return False  # values are not broadcast-compatible.


@dispatch.dispatch_for_api(math_ops.tensor_not_equals)
def tensor_not_equals(self: ragged_tensor.RaggedOrDense,
                      other: ragged_tensor.RaggedOrDense):
  """Ragged version of the operation invoked by `Tensor.__ne__`."""
  if other is None:
    return True
  elif _use_legacy_mode_for_tensor_equality(self):
    return self is not other
  else:
    try:
      return math_ops.not_equal(self, other)
    except (errors.InvalidArgumentError, ValueError):
      return True  # values are not broadcast-compatible.


def _use_legacy_mode_for_tensor_equality(self):
  g = getattr(self, 'graph', None)
  return not (ops.Tensor._USE_EQUALITY and  # pylint: disable=protected-access
              ops.executing_eagerly_outside_functions() and
              (g is None or g.building_function))
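
# Illustrative sketch (assumes `import tensorflow as tf` and eager execution):
# `==` and `!=` on RaggedTensors compare elementwise and preserve the ragged
# structure:
#
#   >>> x = tf.ragged.constant([[1, 2], [3]])
#   >>> y = tf.ragged.constant([[1, 5], [3]])
#   >>> (x == y).to_list()
#   [[True, False], [True]]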


def _cumsum_flat_values_at_ragged_rank(last_rp, flat_values, exclusive=False,
                                       reverse=False):
  """Calculate flat_values for math_ops.cumsum when axis==ragged_rank."""
  if not exclusive:
    partial = _cumsum_flat_values_at_ragged_rank(
        last_rp, flat_values, exclusive=True, reverse=reverse)
    return partial + flat_values

  if reverse:
    youngest_sibling = array_ops.gather(
        params=last_rp.row_splits(), indices=last_rp.value_rowids() + 1) - 1
    new_flat_values = math_ops.cumsum(flat_values, exclusive=True, reverse=True)
    initial_values = array_ops.gather(params=new_flat_values,
                                      indices=youngest_sibling)

    return new_flat_values - initial_values
  else:
    eldest_sibling = array_ops.gather(
        params=last_rp.row_splits(), indices=last_rp.value_rowids())
    new_flat_values = math_ops.cumsum(flat_values, exclusive=True)
    initial_values = array_ops.gather(params=new_flat_values,
                                      indices=eldest_sibling)
    return new_flat_values - initial_values


@dispatch.dispatch_for_api(math_ops.cumsum)
def ragged_cumsum(x: ragged_tensor.Ragged,
                  axis: int = 0,
                  exclusive: bool = False,
                  reverse: bool = False,
                  name: typing.Optional[str] = None):
  """Calculate math_ops.cumsum for a RaggedTensor.

  Given a ragged tensor `x`, the `result` is a ragged tensor with the same
  shape. One can calculate the value of `result[i_1...i_k]` as follows:
  ```
  dense_result = tf.math.cumsum(x.to_tensor(), axis=axis, exclusive=exclusive,
                                reverse=reverse)
  result[i_1...i_k] = dense_result[i_1...i_k]
  ```

  Args:
    x: the original ragged tensor to sum.
    axis: the axis along which to sum; must be in the range
      `-rank <= axis < rank`.
    exclusive: is the sum exclusive or inclusive? If True, then `result[0]=0`.
      If False, then `result[0]=x[0]`.
    reverse: If True, sum from back to front.
    name: the name of the op.
  Returns:
    the cumulative sum.
  """
  with ops.name_scope(name, 'RaggedCumSum', [x, axis, exclusive, reverse]):
    axis = array_ops.get_positive_axis(axis, x.shape.rank, ndims_name='rank')
    if axis == x.ragged_rank:
      last_rp = x._nested_row_partitions[-1]  # pylint: disable=protected-access
      return x.with_flat_values(
          _cumsum_flat_values_at_ragged_rank(last_rp, x.flat_values,
                                             exclusive=exclusive,
                                             reverse=reverse))
    elif axis > x.ragged_rank:
      new_axis = axis - x.ragged_rank
      cumsum_bound = functools.partial(
          math_ops.cumsum, axis=new_axis, exclusive=exclusive, reverse=reverse)
      return ragged_functional_ops.map_flat_values(cumsum_bound, x)
    else:
      dense_version = x.to_tensor()
      result = math_ops.cumsum(
          dense_version, axis, exclusive=exclusive, reverse=reverse, name=name)
      return ragged_tensor.RaggedTensor.from_tensor(
          result, lengths=x.nested_row_lengths())
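
# Illustrative sketch (assumes `import tensorflow as tf`): a cumulative sum
# along the ragged axis runs within each row independently:
#
#   >>> rt = tf.ragged.constant([[1, 2, 3], [4, 5]])
#   >>> tf.math.cumsum(rt, axis=1).to_list()
#   [[1, 3, 6], [4, 9]]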