# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for converting between row_splits and segment_ids."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.row_splits_to_segment_ids")
@dispatch.add_dispatch_support
def row_splits_to_segment_ids(splits, name=None, out_type=None):
  """Generates the segmentation corresponding to a RaggedTensor `row_splits`.

  Returns an integer vector `segment_ids`, where `segment_ids[i] == j` if
  `splits[j] <= i < splits[j+1]`.  Example:

  >>> print(tf.ragged.row_splits_to_segment_ids([0, 3, 3, 5, 6, 9]))
  tf.Tensor([0 0 0 2 2 3 4 4 4], shape=(9,), dtype=int64)

  Args:
    splits: A sorted 1-D integer Tensor.  `splits[0]` must be zero.
    name: A name prefix for the returned tensor (optional).
    out_type: The dtype for the return value.  Defaults to `splits.dtype`,
      or `tf.int64` if `splits` does not have a dtype.

  Returns:
    A sorted 1-D integer Tensor, with `shape=[splits[-1]]`

  Raises:
    ValueError: If `splits` is invalid: its dtype is neither int32 nor int64,
      or its static shape shows that it is empty.
  """
  with ops.name_scope(name, "RaggedSplitsToSegmentIds", [splits]) as name:
    # Values without a dtype of their own (e.g. Python lists) become int64.
    splits = ops.convert_to_tensor(
        splits, name="splits",
        preferred_dtype=dtypes.int64)
    if splits.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("splits must have dtype int32 or int64")
    splits.shape.assert_has_rank(1)
    # A valid row_splits always holds at least the leading zero.  This check
    # only fires when the length is statically known.
    if tensor_shape.dimension_value(splits.shape[0]) == 0:
      raise ValueError("Invalid row_splits: []")
    if out_type is None:
      out_type = splits.dtype
    else:
      out_type = dtypes.as_dtype(out_type)
    # Row `j` holds `splits[j+1] - splits[j]` values; repeating each row index
    # that many times gives the segment id of every value.
    row_lengths = splits[1:] - splits[:-1]
    # The shape is requested as `out_type` so that the row indices are built
    # from an `out_type` scalar.
    nrows = array_ops.shape(splits, out_type=out_type)[-1] - 1
    indices = math_ops.range(nrows)
    return ragged_util.repeat(indices, repeats=row_lengths, axis=0)


# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.segment_ids_to_row_splits")
@dispatch.add_dispatch_support
def segment_ids_to_row_splits(segment_ids, num_segments=None,
                              out_type=None, name=None):
  """Generates the RaggedTensor `row_splits` corresponding to a segmentation.

  Returns an integer vector `splits`, where `splits[0] = 0` and
  `splits[i] = splits[i-1] + count(segment_ids==i-1)`.  Example:

  >>> print(tf.ragged.segment_ids_to_row_splits([0, 0, 0, 2, 2, 3, 4, 4, 4]))
  tf.Tensor([0 3 3 5 6 9], shape=(6,), dtype=int64)

  Args:
    segment_ids: A 1-D integer Tensor.
    num_segments: A scalar integer indicating the number of segments.  Defaults
      to `max(segment_ids) + 1` (or zero if `segment_ids` is empty).
    out_type: The dtype for the return value.  Defaults to `segment_ids.dtype`
      if `segment_ids` is a `Tensor`; otherwise to `num_segments.dtype` if
      `num_segments` is a `Tensor`; otherwise to `tf.int64`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A sorted 1-D integer Tensor, with `shape=[num_segments + 1]`.
  """
  # Local import bincount_ops to avoid import-cycle.
  from tensorflow.python.ops import bincount_ops  # pylint: disable=g-import-not-at-top
  # Resolve the output dtype before the inputs are cast to int32 below, so
  # that the caller's original dtype is the one that is preserved.
  if out_type is None:
    if isinstance(segment_ids, ops.Tensor):
      out_type = segment_ids.dtype
    elif isinstance(num_segments, ops.Tensor):
      out_type = num_segments.dtype
    else:
      out_type = dtypes.int64
  else:
    out_type = dtypes.as_dtype(out_type)
  with ops.name_scope(name, "SegmentIdsToRaggedSplits", [segment_ids]) as name:
    # Note: we cast int64 tensors to int32, since bincount currently only
    # supports int32 inputs.
    segment_ids = ragged_util.convert_to_int_tensor(segment_ids, "segment_ids",
                                                    dtype=dtypes.int32)
    segment_ids.shape.assert_has_rank(1)
    if num_segments is not None:
      num_segments = ragged_util.convert_to_int_tensor(num_segments,
                                                       "num_segments",
                                                       dtype=dtypes.int32)
      num_segments.shape.assert_has_rank(0)

    # row_lengths[j] is the number of entries of `segment_ids` equal to `j`.
    # `num_segments` is passed as both the lower and the upper bound on the
    # number of bins; when it is None, bincount picks the length itself.
    row_lengths = bincount_ops.bincount(
        segment_ids,
        minlength=num_segments,
        maxlength=num_segments,
        dtype=out_type)
    # Prepend the leading zero, then accumulate the row lengths.
    splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)

    # Update shape information, if possible.
    if num_segments is not None:
      const_num_segments = tensor_util.constant_value(num_segments)
      if const_num_segments is not None:
        splits.set_shape(tensor_shape.TensorShape([const_num_segments + 1]))

    return splits