# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for ragged tensors."""

import functools
import typing

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


#===============================================================================
# ragged.range
#===============================================================================
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
@dispatch.add_dispatch_support
def range(starts,
          limits=None,
          deltas=1,
          dtype=None,
          name=None,
          row_splits_dtype=dtypes.int64):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] >= limits[i]` and `deltas[i] > 0`, then `output[i]` will be an
  empty list.  Similarly, if `starts[i] <= limits[i]` and `deltas[i] < 0`, then
  `output[i]` will be an empty list.  This behavior is consistent with the
  Python `range` function, but differs from the `tf.range` op, which returns
  an error for these cases.

  Examples:

  >>> tf.ragged.range([3, 5, 2]).to_list()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12]).to_list()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12], 2).to_list()
  [[0, 2], [], [8, 10]]

  The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
  The vector inputs must all have the same size.  Scalar inputs are broadcast
  to match the size of the vector inputs.

  Args:
    starts: Vector or scalar `Tensor`.  Specifies the first entry for each
      range if `limits` is not `None`; otherwise, specifies the range limits,
      and the first entries default to `0`.
    limits: Vector or scalar `Tensor`.  Specifies the exclusive upper limits
      for each range.
    deltas: Vector or scalar `Tensor`.  Specifies the increment for each range.
      Defaults to `1`.
    dtype: The type of the elements of the resulting tensor.
      If not specified, then a value is chosen based on the other args.
    name: A name for the operation.
    row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
      tensor.  One of `tf.int32` or `tf.int64`.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # infer dtype if not explicitly provided
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(
        starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(
        result.rt_dense_values, result.rt_nested_splits, validate=False)


def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Infers a matching dtype for tensors, and casts them to that dtype."""
  assert all(t.dtype in dtype_hierarchy for t in tensors)
  inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index)
  return [math_ops.cast(t, inferred_dtype) for t in tensors]


ops.no_gradient('RaggedRange')
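
# Illustrative example (comments only; nothing here runs at import time).
# Mixed-dtype inputs to `tf.ragged.range` are promoted to the widest dtype in
# the hierarchy passed to `_infer_matching_dtype` above.  Assuming eager
# execution:
#
#   tf.ragged.range([0, 1], [4.0, 6.0])
#   # int32 starts + float32 limits -> a float32 RaggedTensor:
#   # [[0.0, 1.0, 2.0, 3.0], [1.0, 2.0, 3.0, 4.0, 5.0]]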


#===============================================================================
# ragged_segment_<AGGREGATE>
#===============================================================================

# Docstring template used for the ragged_segment_<AGGREGATE> ops.
_RAGGED_SEGMENT_DOCSTRING = """\
Computes the %(combination)s along segments of a RaggedTensor.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by taking the %(combination)s of all rows of `data`
  whose corresponding `segment_id` is `i`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    data: A `RaggedTensor` containing the values to combine.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      Values must be greater than or equal to zero, and less than
      `num_segments`.  `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar specifying the number of
      distinct segment ids.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
"""


def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              separator=None,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`.  The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row.  Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    separator: An optional string.  Defaults to None.  The separator to use
      when joining.  Only used for string types.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
  """
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    if separator is not None:
      # A separator is only passed for string types, where
      # `unsorted_segment_op` is expected to be `unsorted_segment_join`.
      return unsorted_segment_op(data, segment_ids, num_segments, separator,
                                 name)
    else:
      return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         separator)

    # Find the length of each row in data.  (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have.  The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`.  (shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
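    # E.g., if output_row_lengths = [3, 0, 2], then
    # output_splits = [0, 3, 3, 5].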
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ],
                                     axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values`, i.e. the target output values index for
    # each data values index.
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1], separator)
    return ragged_tensor.RaggedTensor.from_row_splits(
        output_values, output_splits, validate=False)


@dispatch.dispatch_for_api(math_ops.unsorted_segment_sum)
def segment_sum(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_sum,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentSum'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_prod)
def segment_prod(data: ragged_tensor.RaggedOrDense,
                 segment_ids: ragged_tensor.RaggedOrDense,
                 num_segments,
                 name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_prod,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentProd'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_min)
def segment_min(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_min,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMin'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_max)
def segment_max(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_max,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMax'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_mean)
def segment_mean(data: ragged_tensor.RaggedOrDense,
                 segment_ids: ragged_tensor.RaggedOrDense,
                 num_segments,
                 name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count
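

# Illustrative example (comments only), assuming eager execution: rows 0 and
# 1 below share segment id 0, so their mean is taken elementwise, padded to
# the longest participating row:
#
#   rt = tf.ragged.constant([[1., 2.], [4.], [3., 4.]])
#   tf.math.unsorted_segment_mean(rt, segment_ids=[0, 0, 1], num_segments=2)
#   # -> [[2.5, 2.0], [3.0, 4.0]]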


@dispatch.dispatch_for_api(math_ops.unsorted_segment_sqrt_n)
def segment_sqrt_n(data: ragged_tensor.RaggedOrDense,
                   segment_ids: ragged_tensor.RaggedOrDense,
                   num_segments,
                   name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values /
                                    math_ops.sqrt(count.flat_values))
    else:
      return total / math_ops.sqrt(count)


def _set_ragged_segment_docstring(func, combination, combined):
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % dict(
      combination=combination, combined=combined)


_set_ragged_segment_docstring(segment_sum, 'sum', 'summed')
_set_ragged_segment_docstring(segment_prod, 'product', 'multiplied')
_set_ragged_segment_docstring(segment_min, 'minimum', 'minimized')
_set_ragged_segment_docstring(segment_max, 'maximum', 'maximized')
_set_ragged_segment_docstring(segment_mean, 'mean', 'averaged')
_set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
                              'summed')

#===============================================================================
# ragged_reduce_<AGGREGATE>
#===============================================================================

# Docstring template used for ragged_reduce_<AGGREGATE> ops.
_RAGGED_REDUCE_DOCSTRING = """\
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.

  Reduces `input_tensor` along the dimensions given in `axis` by taking the
  %(combination)s of values.  If a reduced dimension has no elements for
  some index, then the value for that index will be %(default)s.

  The rank of the tensor is reduced by `1` for each entry in `axis`.  If
  `axis` is not specified, then all dimensions are reduced, and a scalar
  value is returned.
  Args:
    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
    axis: The dimensions to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value.  Must be in
      the range `[0, input_tensor.rank]`.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `input_tensor.shape`.  The
    `ragged_rank` of the returned tensor is given by subtracting any ragged
    dimensions specified in `axis` from `input_tensor.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  #### Example:
    %(example)s
"""
_RAGGED_REDUCE_SUM_EXAMPLE = """
  >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
  >>> tf.reduce_sum(rt, axis=0).numpy()  # = [3+1+9+2, 1+5+6, 4]
  array([15, 12, 4], dtype=int32)
  >>> tf.reduce_sum(rt, axis=1).numpy()  # = [3+1+4, 1+5, 9, 2+6]
  array([8, 6, 9, 8], dtype=int32)
"""
_RAGGED_REDUCE_PROD_EXAMPLE = """
  >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
  >>> tf.reduce_prod(rt, axis=0).numpy()  # = [3*1*9*2, 1*5*6, 4]
  array([54, 30, 4], dtype=int32)
  >>> tf.reduce_prod(rt, axis=1).numpy()  # = [3*1*4, 1*5, 9, 2*6]
  array([12, 5, 9, 12], dtype=int32)
"""
_RAGGED_REDUCE_MIN_EXAMPLE = """
  >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
  >>> tf.reduce_min(rt, axis=0).numpy()
  array([1, 1, 4], dtype=int32)
  >>> tf.reduce_min(rt, axis=1).numpy()
  array([1, 1, 9, 2], dtype=int32)
"""
_RAGGED_REDUCE_MAX_EXAMPLE = """
  >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
  >>> tf.reduce_max(rt, axis=0).numpy()
  array([9, 6, 4], dtype=int32)
  >>> tf.reduce_max(rt, axis=1).numpy()
  array([4, 5, 9, 6], dtype=int32)
"""
_RAGGED_REDUCE_MEAN_EXAMPLE = """
  >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
  >>> tf.reduce_mean(rt, axis=0).numpy()
  array([3.75, 4.  , 4.  ])
  >>> tf.reduce_mean(rt, axis=1).numpy()
  array([2.66666667, 3.        , 9.        , 4.        ])
"""
_RAGGED_REDUCE_VARIANCE_EXAMPLE = """
  >>> rt = tf.ragged.constant([[1, 1, 4], [2, 1], [3], [4, 1]],
  ...                         dtype=tf.float64)
  >>> tf.math.reduce_variance(rt, axis=0).numpy()
  array([1.25, 0.  , 0.  ])
  >>> tf.math.reduce_variance(rt, axis=1).numpy()
  array([2.  , 0.25, 0.  , 2.25])
"""
_RAGGED_REDUCE_STD_EXAMPLE = """
  >>> rt = tf.ragged.constant([[1, 0], [2, 1], [3], [4, 1]],
  ...                         dtype=tf.float64)
  >>> tf.math.reduce_std(rt, axis=0).numpy()
  array([1.11803399, 0.47140452])
  >>> tf.math.reduce_std(rt, axis=1).numpy()
  array([0.5, 0.5, 0. , 1.5])
"""
_RAGGED_REDUCE_ALL_EXAMPLE = """
  >>> rt = tf.ragged.constant(
  ...     [[True, True], [True, True, False, True], [False, True]])
  >>> tf.reduce_all(rt, axis=0).numpy()
  array([False,  True, False,  True])
  >>> tf.reduce_all(rt, axis=1).numpy()
  array([ True, False, False])
"""
_RAGGED_REDUCE_ANY_EXAMPLE = """
  >>> rt = tf.ragged.constant(
  ...     [[True, True], [True, True, False, True], [False, True]])
  >>> tf.reduce_any(rt, axis=0).numpy()
  array([ True,  True, False,  True])
  >>> tf.reduce_any(rt, axis=1).numpy()
  array([ True,  True,  True])
"""


def ragged_reduce_aggregate(reduce_op,
                            unsorted_segment_op,
                            rt_input,
                            axis,
                            keepdims,
                            separator=None,
                            name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`.  The rank of the
  tensor is reduced by 1 for each entry in `axis`.  If `axis` is not specified,
  then all dimensions are reduced, and a scalar value is returned.

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results.  (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions.  Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions.  Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value.  Must be in
      the range `[0, rt_input.rank)`.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: An optional string.  Defaults to None.  The separator to use
      when joining.  The separator must not be set for non-string data types;
      i.e., if `separator` is not None, then string ops are used.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values.  The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`.  The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  """
  if not ragged_tensor.is_ragged(rt_input):
    if separator is None:
      return reduce_op(rt_input, axis, keepdims=keepdims, name=name)
    else:
      # When separator is not None, we infer that the dtype is string and
      # that reduce_join will be called.
      return reduce_op(
          rt_input, axis, keepdims=keepdims, name=name, separator=separator)

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    result = reduce_op(rt_input.flat_values, None, keepdims=keepdims, name=name)
    if keepdims:
      # Expand the result to the input number of dimensions.
      for _ in rt_input.shape[1:]:
        result = array_ops.expand_dims(result, axis=0)
    return result

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # When reducing multiple axes, as we reduce one at a time (see below),
        # any negative axes have to be converted to positive on the first
        # pass, since sorting with negative axes would give a different order.
        # See GitHub issue 27497.
        axis = [
            array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i,
                                        'rank(input_tensor)')
            for i, a in enumerate(axis)
        ]
        # When reducing multiple axes, just reduce one at a time.  This is less
        # efficient, and only works for associative ops.  (In particular, it
        # does not work for reduce_mean.)  However, reducing multiple axes at
        # once will probably require a nontrivial c++ op.
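        # E.g., reducing axes [0, 1] first reduces the innermost axis (1),
        # then reduces axis 0 on that result:
        #   [[1, 2], [3]]  --axis=1-->  [3, 3]  --axis=0-->  6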
        axis = sorted(axis)
        inner_reduced = ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                                rt_input, axis[-1], keepdims,
                                                separator)
        return ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                       inner_reduced, axis[:-1], keepdims,
                                       separator)

    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    axis = array_ops.get_positive_axis(
        axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      segment_ids = range(row_lengths).values
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=0)
      return result
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=1)
      return result
    else:
      # out[i_0, ..., i_[axis-1], i_axis+1], ..., i_N] =
      #     sum_{j} rt_input[i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
      return rt_input.with_values(
          ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                  rt_input.values, axis - 1, keepdims,
                                  separator))


@dispatch.dispatch_for_api(math_ops.reduce_sum)
def reduce_sum(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""

  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_sum,
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceSum'))


@dispatch.dispatch_for_api(math_ops.reduce_prod)
def reduce_prod(input_tensor: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_prod,
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceProd'))


@dispatch.dispatch_for_api(math_ops.reduce_min)
def reduce_min(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_min,
      unsorted_segment_op=math_ops.unsorted_segment_min,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMin'))


@dispatch.dispatch_for_api(math_ops.reduce_max)
def reduce_max(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_max,
      unsorted_segment_op=math_ops.unsorted_segment_max,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMax'))
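

# Illustrative example (comments only), assuming eager execution:
# `keepdims=True` retains the reduced axis with size 1; reducing the ragged
# axis of a 2D RaggedTensor yields a dense Tensor:
#
#   rt = tf.ragged.constant([[3, 1, 4], [1, 5]])
#   tf.reduce_sum(rt, axis=1, keepdims=True)   # shape [2, 1]
#   # -> [[8], [6]]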


@dispatch.dispatch_for_api(math_ops.reduce_mean)
def reduce_mean(input_tensor: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)
    if ragged_tensor.is_ragged(input_tensor):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits,
          validate=False)
    else:
      ones = array_ops.ones_like(input_tensor)
    count = reduce_sum(ones, axis, keepdims)
    if ragged_tensor.is_ragged(total):
      return ragged_tensor.RaggedTensor.from_nested_row_splits(
          total.flat_values / count.flat_values,
          total.nested_row_splits,
          validate=False)
    else:
      return total / count


@dispatch.dispatch_for_api(math_ops.reduce_variance)
def reduce_variance(input_tensor: ragged_tensor.Ragged,
                    axis=None,
                    keepdims=False,
                    name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceVariance', [input_tensor, axis]):
    input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input_tensor, name='input_tensor')
    if input_tensor.dtype.is_complex:
      raise ValueError(
          'reduce_variance is not supported for RaggedTensors with complex '
          'dtypes.')
    square_of_input = math_ops.square(input_tensor)
    mean_of_square = reduce_mean(square_of_input, axis=axis, keepdims=keepdims)
    mean = reduce_mean(input_tensor, axis=axis, keepdims=keepdims)
    square_of_mean = math_ops.square(mean)
    # Note: the above method of computing variance is not numerically stable,
    # and can result in negative variances.  Here we clip to >= 0.
    return math_ops.maximum(mean_of_square - square_of_mean, 0)


@dispatch.dispatch_for_api(math_ops.reduce_std)
def reduce_std(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=False,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceStd', [input_tensor, axis]):
    variance = reduce_variance(input_tensor, axis=axis, keepdims=keepdims)
    return math_ops.sqrt(variance)


def _cast(input_tensor, dtype):
  return ragged_functional_ops.map_flat_values(math_ops.cast, input_tensor,
                                               dtype)


@dispatch.dispatch_for_api(math_ops.reduce_all)
def reduce_all(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    return _cast(
        reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


@dispatch.dispatch_for_api(math_ops.reduce_any)
def reduce_any(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    return _cast(
        reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)
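

# Illustrative note: `reduce_all` and `reduce_any` above are implemented by
# casting the booleans to int32 and reusing `reduce_prod` / `reduce_sum`:
#   [True, False, True] -> [1, 0, 1] -> prod = 0 -> False  (reduce_all)
#   [True, False, True] -> [1, 0, 1] -> sum  = 2 -> True   (reduce_any)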


def _set_ragged_reduce_docstring(func, combination, combined, default, example):
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % dict(
      combination=combination,
      combined=combined,
      default=default,
      example=example)


_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_variance, 'variance', 'averaged', 'NaN',
                             _RAGGED_REDUCE_VARIANCE_EXAMPLE)
_set_ragged_reduce_docstring(reduce_std, 'std', 'averaged', 'NaN',
                             _RAGGED_REDUCE_STD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)


#===============================================================================
# ragged.matmul
#===============================================================================
@dispatch.dispatch_for_api(math_ops.matmul)
def matmul(a: ragged_tensor.RaggedOrDense,
           b: ragged_tensor.RaggedOrDense,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           a_is_sparse=False,
           b_is_sparse=False,
           output_type=None,
           name=None):
  """Multiplies matrix `a` by matrix `b`.

  If all transpose or adjoint attributes are `False` then:

  ```
  output[..., i, j] = sum_k (a[..., i, k] * b[..., k, j]), for all indices i, j.
  ```

  The inputs `a` and `b` must have `rank >= 2`, where the outermost `rank - 2`
  dimensions are batch dimensions.  The inputs must have the same dtype.  See
  `tf.matmul` for more information.

  Args:
    a: `tf.Tensor` or `RaggedTensor` with `rank > 1`.
    b: `tf.Tensor` or `RaggedTensor` with same type and rank as `a`.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated & transposed before multiplication.
    adjoint_b: If `True`, `b` is conjugated & transposed before multiplication.
    a_is_sparse: If `True`, optimize assuming `a` is mostly zero.
    b_is_sparse: If `True`, optimize assuming `b` is mostly zero.
    output_type: The output datatype (optional).
    name: Name for the operation (optional).

  Returns:
    A `Tensor` or `RaggedTensor` with the same rank and shape as `a`, where
    each inner-most matrix is the product of the corresponding matrices in `a`
    and `b`.
829 """ 830 if transpose_a and adjoint_a: 831 raise ValueError('Only one of transpose_a and adjoint_a can be True.') 832 if transpose_b and adjoint_b: 833 raise ValueError('Only one of transpose_b and adjoint_b can be True.') 834 835 kwargs = dict( 836 transpose_a=transpose_a, 837 transpose_b=transpose_b, 838 adjoint_a=adjoint_a, 839 adjoint_b=adjoint_b, 840 a_is_sparse=a_is_sparse, 841 b_is_sparse=b_is_sparse, 842 output_type=output_type) 843 844 with ops.name_scope(name, 'RaggedMatMul', [a, b]) as name: 845 a = ragged_tensor.convert_to_tensor_or_ragged_tensor(a, name='a') 846 b = ragged_tensor.convert_to_tensor_or_ragged_tensor(b, name='b') 847 848 a_is_ragged = isinstance(a, ragged_tensor.RaggedTensor) 849 b_is_ragged = isinstance(b, ragged_tensor.RaggedTensor) 850 if not (a_is_ragged or b_is_ragged): 851 return math_ops.matmul(a, b, **kwargs) 852 853 if a.dtype != b.dtype: 854 raise ValueError('`a` and `b` must have the same dtype.') 855 856 # TODO(edloper): Support broadcasting inputs. (Broadcast support is not 857 # documented by https://www.tensorflow.org/api_docs/python/tf/linalg/matmul, 858 # but it is supported by the op.) 859 860 # Find the rank of the input tensors. 861 if a.shape.rank is None: 862 if b.shape.rank is None: 863 raise ValueError('matmul requires at least one input to have known ' 864 'rank if either input is ragged.') 865 rank = b.shape.rank 866 else: 867 if b.shape.rank is not None and a.shape.rank != b.shape.rank: 868 raise ValueError('`a` and `b` must have the same rank.') 869 rank = a.shape.rank 870 871 # At least one of `a` and `b` is ragged; and ragged tensors always have 872 # rank>=2. 873 if rank < 2: 874 # This can happen if e.g. `a` is a 1D dense tensor and `b` is a 875 # ragged tensor with unknown rank. Since ragged tensors always have 876 # `rank>=2`, this implies that `a` and `b` have different ranks. 877 raise ValueError('`a` and `b` must have the same rank.') 878 879 # Rank>3: We have multiple batch dimensions. Merge them into a single 880 # batch dimension, recursively call `matmul`, and then restore the original 881 # batch dimension (using a.row_splits). 882 if rank > 3: 883 shape_err = 'Batch dimensions of `a` and `b` do not have the same size.' 884 if not a_is_ragged: 885 a = ragged_tensor.RaggedTensor.from_tensor(a, ragged_rank=1) 886 if not b_is_ragged: 887 b = ragged_tensor.RaggedTensor.from_tensor(b, ragged_rank=1) 888 with ops.control_dependencies([ 889 check_ops.assert_equal(a.row_splits, b.row_splits, message=shape_err) 890 ]): 891 flat_result = matmul(a.values, b.values, **kwargs) 892 return a.with_values(flat_result) 893 894 if rank == 2: 895 return _matmul_2d(a, b, **kwargs) 896 897 assert rank == 3 # I.e., we have a single batch dimension. 898 899 a_ragged_rank = a.ragged_rank if a_is_ragged else 0 900 if a_ragged_rank == 1 and not (b_is_ragged or transpose_a or adjoint_a): 901 # If `a.shape=[B, (I), J]` and `b.shape=[B, J, K], then we can compute 902 # the result with a single dense `matmul`. 903 return _matmul_3d_with_batch_dim_folding(a, b, **kwargs) 904 else: 905 # Otherwie, fall back on using `map_fn`. 906 return _matmul_3d_with_map_fn(a, b, **kwargs) 907 908 909def _matmul_2d(a, b, **kwargs): 910 """Multiplies potentially ragged 2D tensors. 911 912 Args: 913 a: A 2D Tensor or RaggedTensor with `shape=[I, J]` 914 b: A 2D Tensor or RaggedTensor with `shape=[J, K]` 915 **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a). 916 917 Returns: 918 A 2D Tensor with `shape=[I, K]`. 
919 """ 920 # multiplying `a` and `b` is only well-defined if `a` and `b` are 921 # actually uniform (and just happened to be stored as ragged tensors). 922 # Check that they're uniform, convert them to tf.Tensor. 923 ragged_err = ('The matrices in `a` and `b` may not be ' 924 'ragged in their innermost dimension.') 925 checks = [] 926 if isinstance(a, ragged_tensor.RaggedTensor): 927 original_size = array_ops.size(a.flat_values) 928 a = a.to_tensor() 929 checks.append( 930 check_ops.assert_equal( 931 original_size, array_ops.size(a), message=ragged_err)) 932 if isinstance(b, ragged_tensor.RaggedTensor): 933 original_size = array_ops.size(b.flat_values) 934 b = b.to_tensor() 935 checks.append( 936 check_ops.assert_equal( 937 original_size, array_ops.size(b), message=ragged_err)) 938 with ops.control_dependencies(checks): 939 return math_ops.matmul(a, b, **kwargs) 940 941 942def _matmul_3d_with_map_fn(a, b, **kwargs): 943 """Multiplies batches of 2D matrices using map_fn. 944 945 `output[n, i, k]` = sum_j (a[n, i, j] * b[n, j, k])` (for all `n`, `i`, `k`). 946 947 Requires that `a[n, i].nrows()` == `b[n].nrows()` (for all `n` and `i`). 948 949 Args: 950 a: A 3D Tensor or RaggedTensor with `shape=[B, I, J]`, where dimensions `I` 951 and `J` may be ragged. 952 b: A 3D Tensor or RaggedTensor with `shape=[B, J, K]`, where dimensions `J` 953 and `K` may be ragged. 954 **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a). 955 956 Returns: 957 A 3D RaggedTensor with `shape=[B, (I), (K)]`. 958 """ 959 # Determine the ragged rank of the result. In the normal case, we have: 960 # [B, I, J] * [B, J, K] -> [B, I, K] 961 # Or if we're using transpose_b, then we have: 962 # [B, I, J] * [B, K, J] -> [B, I, K] 963 # In either case, output_ragged_rank=2 iff the K dimension is ragged. 964 if (isinstance(b, ragged_tensor.RaggedTensor) and 965 (b.ragged_rank == 2 or kwargs.get('transpose_b') or 966 kwargs.get('adjoint_b'))): 967 output_ragged_rank = 2 968 else: 969 output_ragged_rank = 1 970 971 def single_batch_matmul(x): 972 out = _matmul_2d(x[0], x[1], **kwargs) 973 if output_ragged_rank == 2: 974 out = ragged_tensor.RaggedTensor.from_tensor(out) 975 return out 976 977 fn_out_shape = None # Figure out proper shape. 978 row_splits_dtype = ( 979 a.row_splits.dtype 980 if isinstance(a, ragged_tensor.RaggedTensor) else b.row_splits.dtype) 981 output_type = kwargs['output_type'] 982 if output_type is None: 983 output_type = a.dtype 984 spec = ragged_tensor.RaggedTensorSpec( 985 shape=fn_out_shape, 986 dtype=output_type, 987 ragged_rank=output_ragged_rank - 1, 988 row_splits_dtype=row_splits_dtype) 989 result = map_fn.map_fn( 990 single_batch_matmul, elems=(a, b), fn_output_signature=spec) 991 992 # map_fn loses shape information; restore it, where possible. 993 # pylint: disable=protected-access 994 if kwargs.get('transpose_a') or kwargs.get('adjoint_a'): 995 result._set_shape(a.shape[:-2] + a.shape[-1:] + [None]) 996 else: 997 result._set_shape(a.shape[:-2] + a.shape[-2:-1] + [None]) 998 if kwargs.get('transpose_b') or kwargs.get('adjoint_b'): 999 result._set_shape(b.shape[:-2] + [None] + b.shape[-2:-1]) 1000 else: 1001 result._set_shape(b.shape[:-2] + [None] + b.shape[-1:]) 1002 1003 return result 1004 1005 1006def _matmul_3d_with_batch_dim_folding(a, b, **kwargs): 1007 """Multiply batches of 2D matrices where only `a.shape[1]` is ragged. 1008 1009 Args: 1010 a: A RaggedTensor with `shape=[B, (I), J]`. (ragged_rank must be 1.) 


def _matmul_3d_with_batch_dim_folding(a, b, **kwargs):
  """Multiplies batches of 2D matrices where only `a.shape[1]` is ragged.

  Args:
    a: A RaggedTensor with `shape=[B, (I), J]`.  (ragged_rank must be 1.)
    b: A Tensor with `shape=[B, J, K]`.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).
      transpose_a and adjoint_a must not be true.

  Returns:
    A RaggedTensor with `shape=[B, (I), K]`.
  """
  # reshaped_a.shape = [sum(i_1, i_2, ..., i_B), 1, J]
  reshaped_a = array_ops.expand_dims(a.values, 1)
  # reshaped_b.shape = [sum(i_1, i_2, ..., i_B), J, K]
  reshaped_b = array_ops.repeat(b, a.row_lengths(), axis=0)
  # flat_result.shape = [sum(i_1, i_2, ..., i_B), 1, K]
  flat_result = math_ops.matmul(reshaped_a, reshaped_b, **kwargs)
  # result.shape = [B, (I), K]
  return a.with_values(array_ops.squeeze(flat_result, axis=1))
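

# Illustrative trace of the batch-dim folding above, for a ragged `a` with
# row lengths [1, 2] (so a.values stacks three rows x0, x1, x2 of size J=2)
# and a dense `b` of shape [2, 2, 1]:
#   a.values                   -> shape [3, 2]
#   expand_dims(a.values, 1)   -> shape [3, 1, 2]
#   repeat(b, [1, 2], axis=0)  -> shape [3, 2, 1]  (b[0] once, b[1] twice)
#   matmul                     -> shape [3, 1, 1]; squeezed to [3, 1] and
#                                 re-wrapped with a's row partition.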
1064 """ 1065 if axis is None: 1066 axis = -1 1067 1068 with ops.name_scope(name, 'RaggedSoftmax', [logits]) as name: 1069 max_input = reduce_max(logits, axis=axis, keepdims=True) 1070 logits_exp = math_ops.exp(math_ops.subtract(logits, max_input)) 1071 denominator = reduce_sum(logits_exp, axis=axis, keepdims=True) 1072 return math_ops.divide(logits_exp, denominator) 1073 1074 1075#=============================================================================== 1076# ragged.add_n 1077#=============================================================================== 1078@dispatch.dispatch_for_api(math_ops.add_n) 1079def add_n(inputs: typing.List[ragged_tensor.RaggedOrDense], name=None): 1080 """RaggedTensor implementation for tf.math.add_n.""" 1081 if len(inputs) < 0: 1082 raise ValueError('tf.add_n: expected at least one input.') 1083 with ops.name_scope(name, 'RaggedAddN', inputs): 1084 return ragged_functional_ops.map_flat_values(math_ops.add_n, inputs) 1085 1086 1087#=============================================================================== 1088# Ragged version of nn_ops.dropout 1089#=============================================================================== 1090@dispatch.dispatch_for_api(nn_ops.dropout) 1091def dropout_v1(x: ragged_tensor.Ragged, 1092 keep_prob=None, 1093 noise_shape=None, 1094 seed=None, 1095 name=None, 1096 rate=None): 1097 """Ragged dispatch target for tf.nn.dropout.""" 1098 if noise_shape is not None: 1099 raise ValueError('noise_shape is not supported yet for RaggedTensor x') 1100 with ops.name_scope(name, 'RaggedNNDropout', [x, rate]): 1101 x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x') 1102 return x.with_flat_values( 1103 nn_ops.dropout( 1104 x.flat_values, keep_prob=keep_prob, seed=seed, rate=rate)) 1105 1106 1107@dispatch.dispatch_for_api(nn_ops.dropout_v2) 1108def dropout_v2(x: ragged_tensor.Ragged, 1109 rate, 1110 noise_shape=None, 1111 seed=None, 1112 name=None): 1113 """Ragged dispatch target for tf.nn.dropout.""" 1114 if noise_shape is not None: 1115 raise ValueError('noise_shape is not supported yet for RaggedTensor x') 1116 with ops.name_scope(name, 'RaggedNNDropout', [x, rate]): 1117 x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x') 1118 return x.with_flat_values( 1119 nn_ops.dropout_v2(x.flat_values, rate=rate, seed=seed)) 1120 1121 1122@dispatch.dispatch_for_api(nn_ops.stateless_dropout) 1123def stateless_dropout(x: ragged_tensor.Ragged, 1124 rate, 1125 seed, 1126 rng_alg=None, 1127 noise_shape=None, 1128 name=None): 1129 """Ragged dispatch target for tf.nn.experimental.stateless_dropout.""" 1130 if noise_shape is not None: 1131 raise ValueError('noise_shape is not supported yet for RaggedTensor x') 1132 with ops.name_scope(name, 'RaggedNNStatelessDropout', [x, rate]): 1133 x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x') 1134 return x.with_flat_values( 1135 nn_ops.stateless_dropout( 1136 x.flat_values, rate=rate, seed=seed, rng_alg=rng_alg)) 1137 1138 1139#=============================================================================== 1140# Ragged version of Tensor.__eq__ and Tensor.__ne__ 1141#=============================================================================== 1142@dispatch.dispatch_for_api(math_ops.tensor_equals) 1143def tensor_equals(self: ragged_tensor.RaggedOrDense, 1144 other: ragged_tensor.RaggedOrDense): 1145 """Ragged version of the operation invoked by `Tensor.__eq__`.""" 1146 if other is None: 1147 return False 1148 elif _use_legacy_mode_for_tensor_equality(self): 1149 


#===============================================================================
# Ragged version of Tensor.__eq__ and Tensor.__ne__
#===============================================================================
@dispatch.dispatch_for_api(math_ops.tensor_equals)
def tensor_equals(self: ragged_tensor.RaggedOrDense,
                  other: ragged_tensor.RaggedOrDense):
  """Ragged version of the operation invoked by `Tensor.__eq__`."""
  if other is None:
    return False
  elif _use_legacy_mode_for_tensor_equality(self):
    return self is other
  else:
    try:
      return math_ops.equal(self, other)
    except (errors.InvalidArgumentError, ValueError):
      return False  # values are not broadcast-compatible.


@dispatch.dispatch_for_api(math_ops.tensor_not_equals)
def tensor_not_equals(self: ragged_tensor.RaggedOrDense,
                      other: ragged_tensor.RaggedOrDense):
  """Ragged version of the operation invoked by `Tensor.__ne__`."""
  if other is None:
    return False
  elif _use_legacy_mode_for_tensor_equality(self):
    return self is not other
  else:
    try:
      return math_ops.not_equal(self, other)
    except (errors.InvalidArgumentError, ValueError):
      return True  # values are not broadcast-compatible.


def _use_legacy_mode_for_tensor_equality(self):
  g = getattr(self, 'graph', None)
  return not (ops.Tensor._USE_EQUALITY and  # pylint: disable=protected-access
              ops.executing_eagerly_outside_functions() and
              (g is None or g.building_function))


def _cumsum_flat_values_at_ragged_rank(last_rp, flat_values, exclusive=False,
                                       reverse=False):
  """Calculates flat_values for math_ops.cumsum when axis==ragged_rank."""
  if not exclusive:
    partial = _cumsum_flat_values_at_ragged_rank(
        last_rp, flat_values, exclusive=True, reverse=reverse)
    return partial + flat_values

  if reverse:
    youngest_sibling = array_ops.gather(
        params=last_rp.row_splits(), indices=last_rp.value_rowids() + 1) - 1
    new_flat_values = math_ops.cumsum(flat_values, exclusive=True, reverse=True)
    initial_values = array_ops.gather(params=new_flat_values,
                                      indices=youngest_sibling)

    return new_flat_values - initial_values
  else:
    eldest_sibling = array_ops.gather(
        params=last_rp.row_splits(), indices=last_rp.value_rowids())
    new_flat_values = math_ops.cumsum(flat_values, exclusive=True)
    initial_values = array_ops.gather(params=new_flat_values,
                                      indices=eldest_sibling)
    return new_flat_values - initial_values
1229 """ 1230 with ops.name_scope(name, 'RaggedCumSum', [x, axis, exclusive, reverse]): 1231 axis = array_ops.get_positive_axis(axis, x.shape.rank, ndims_name='rank') 1232 if axis == x.ragged_rank: 1233 last_rp = x._nested_row_partitions[-1] # pylint: disable=protected-access 1234 return x.with_flat_values( 1235 _cumsum_flat_values_at_ragged_rank(last_rp, x.flat_values, 1236 exclusive=exclusive, 1237 reverse=reverse)) 1238 elif axis > x.ragged_rank: 1239 new_axis = axis - x.ragged_rank 1240 cumsum_bound = functools.partial( 1241 math_ops.cumsum, axis=new_axis, exclusive=exclusive, reverse=reverse) 1242 return ragged_functional_ops.map_flat_values(cumsum_bound, x) 1243 else: 1244 dense_version = x.to_tensor() 1245 result = math_ops.cumsum( 1246 dense_version, axis, exclusive=exclusive, reverse=reverse, name=name) 1247 return ragged_tensor.RaggedTensor.from_tensor( 1248 result, lengths=x.nested_row_lengths()) 1249