# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Private convenience functions for RaggedTensors.

None of these methods are exposed in the main "ragged" package.
"""

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import math_ops


def assert_splits_match(nested_splits_lists):
  """Verifies that every splits list agrees with the first one.

  The number of splits tensors in each list is compared up front, and a
  mismatch raises immediately. The contents of the splits tensors are then
  compared pairwise against the first list with `check_ops.assert_equal`, and
  the results of those checks are returned so that the caller can use them as
  control dependencies.

  Args:
    nested_splits_lists: A list of splits lists. Each splits list holds the
      `splits` tensors of one `RaggedTensor`, ordered from the outermost
      ragged dimension to the innermost ragged dimension.

  Returns:
    A list containing one `assert_equal` result for every splits tensor of
    every list after the first. Empty when fewer than two lists are given.

  Raises:
    ValueError: If the splits lists do not all contain the same number of
      splits tensors.
  """
  error_msg = "Inputs must have identical ragged splits"

  # Cheap structural check first: every list must have as many splits tensors
  # as the first list.
  for candidate in nested_splits_lists:
    if len(candidate) != len(nested_splits_lists[0]):
      raise ValueError(error_msg)

  # Value check: compare each remaining list, tensor by tensor, against the
  # first list.
  checks = []
  for other in nested_splits_lists[1:]:
    for expected, actual in zip(nested_splits_lists[0], other):
      checks.append(
          check_ops.assert_equal(expected, actual, message=error_msg))
  return checks


# Note: these helpers are defined in array_ops and re-exported here under the
# names this module has always used; defining them there avoids a circular
# dependency with array_ops.
get_positive_axis = array_ops.get_positive_axis
convert_to_int_tensor = array_ops.convert_to_int_tensor
repeat = array_ops.repeat_with_axis


def lengths_to_splits(lengths):
  """Builds a splits tensor from row lengths.

  Args:
    lengths: The length of each row.

  Returns:
    A leading `0` followed by the running total of `lengths`, concatenated
    along the last axis.
  """
  running_total = math_ops.cumsum(lengths)
  return array_ops.concat([[0], running_total], axis=-1)


def repeat_ranges(params, splits, repeats):
  """Tiles each `splits`-delimited range of `params` `repeats` times.

  Range `i` of `params` is `params[splits[i]:splits[i + 1]]`. The result
  consists of range 0 written out `repeats[0]` times, then range 1 written
  out `repeats[1]` times, and so on, ending with the last range written out
  `repeats[-1]` times.

  Args:
    params: The `Tensor` whose values should be repeated.
    splits: A splits tensor marking the boundaries of the ranges of `params`
      that should be repeated.
    repeats: How many times each range should be repeated. A scalar value is
      broadcast to all ranges.

  Returns:
    A `Tensor` with the same rank and type as `params`.

  #### Example:

  >>> print(repeat_ranges(
  ...     params=tf.constant(['a', 'b', 'c']),
  ...     splits=tf.constant([0, 2, 3]),
  ...     repeats=tf.constant(3)))
  tf.Tensor([b'a' b'b' b'a' b'b' b'a' b'b' b'c' b'c' b'c'],
      shape=(9,), dtype=string)
  """
  # Work out the start and limit of every output range. Any `repeats` whose
  # rank is not statically 0 takes the general (per-range) path.
  if repeats.shape.ndims == 0:
    # Scalar fast path: repeat all of `splits` once, then view the result
    # through two windows offset from each other by `repeats` entries.
    tiled_splits = repeat(splits, repeats, axis=0)
    tiled_len = array_ops.shape(tiled_splits, out_type=repeats.dtype)[0]
    range_starts = tiled_splits[:tiled_len - repeats]
    range_limits = tiled_splits[repeats:]
  else:
    range_starts = repeat(splits[:-1], repeats, axis=0)
    range_limits = repeat(splits[1:], repeats, axis=0)

  # Expand every (start, limit) pair into the run of indices it covers, and
  # gather `params` at those indices to produce the repeated pattern.
  unit_delta = array_ops.ones((), range_starts.dtype)
  index_ranges = gen_ragged_math_ops.ragged_range(
      range_starts, range_limits, unit_delta)
  gather_indices = index_ranges.rt_dense_values
  return array_ops.gather(params, gather_indices)