xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/ragged/ragged_string_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ragged operations for working with string Tensors."""
16
17import typing
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_spec
23from tensorflow.python.framework import tensor_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import gen_string_ops
27from tensorflow.python.ops import string_ops
28from tensorflow.python.ops.ragged import ragged_array_ops
29from tensorflow.python.ops.ragged import ragged_functional_ops
30from tensorflow.python.ops.ragged import ragged_math_ops
31from tensorflow.python.ops.ragged import ragged_tensor
32from tensorflow.python.util import compat as util_compat
33from tensorflow.python.util import deprecation
34from tensorflow.python.util import dispatch
35from tensorflow.python.util.lazy_loader import LazyLoader
36from tensorflow.python.util.tf_export import tf_export
37
38
# `map_fn` is bound through a LazyLoader rather than a regular import.
# NOTE(review): presumably this defers the import to avoid a circular
# dependency between the ragged ops and map_fn -- confirm.
map_fn_lib = LazyLoader(
    "map_fn_lib", globals(), "tensorflow.python.ops.map_fn")
41
42
@tf_export("strings.bytes_split")
@dispatch.add_dispatch_support
def string_bytes_split(input, name=None):  # pylint: disable=redefined-builtin
  """Breaks every string in `input` apart into its individual bytes.

  The split is byte-wise, not character-wise: a multi-byte unicode character
  yields several output elements.  Use `tf.strings.unicode_split` to split on
  unicode characters instead.

  Examples:

  >>> tf.strings.bytes_split('hello').numpy()
  array([b'h', b'e', b'l', b'l', b'o'], dtype=object)
  >>> tf.strings.bytes_split(['hello', '123'])
  <tf.RaggedTensor [[b'h', b'e', b'l', b'l', b'o'], [b'1', b'2', b'3']]>

  See also: `tf.io.decode_raw`, `tf.strings.split`, `tf.strings.unicode_split`.

  Args:
    input: A string `Tensor` or `RaggedTensor` holding the strings to split.
      Its rank (`N`) must be known statically.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` of rank `N+1` containing the bytes of each source string.

  Raises:
    ValueError: If the rank of a non-ragged `input` is not statically known.
  """
  with ops.name_scope(name, "StringsByteSplit", [input]):
    strings = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name="input")

    # Ragged input: split the flat values and keep the outer row partitions.
    if isinstance(strings, ragged_tensor.RaggedTensor):
      split_flat_values = string_bytes_split(strings.flat_values)
      return strings.with_flat_values(split_flat_values)

    ndims = strings.shape.ndims
    if ndims is None:
      raise ValueError("input must have a statically-known rank.")

    # Scalar input: process it as a one-element vector, then drop that axis.
    if ndims == 0:
      as_vector = array_ops.stack([strings])
      return string_bytes_split(as_vector)[0]

    # Higher-rank dense input: go through the ragged code path above.
    if ndims > 1:
      return string_bytes_split(
          ragged_tensor.RaggedTensor.from_tensor(strings))

    # Vector input: an empty delimiter makes the kernel emit one token per byte.
    split_indices, split_values, dense_shape = gen_string_ops.string_split(
        strings, delimiter="", skip_empty=False)
    row_ids = split_indices[:, 0]
    num_rows = dense_shape[0]
    return ragged_tensor.RaggedTensor.from_value_rowids(
        values=split_values,
        value_rowids=row_ids,
        nrows=num_rows,
        validate=False)
88
89
90# pylint: disable=redefined-builtin
@tf_export("strings.unicode_encode")
@dispatch.add_dispatch_support
def unicode_encode(input,
                   output_encoding,
                   errors="replace",
                   replacement_char=65533,
                   name=None):
  r"""Encodes each sequence of Unicode code points in `input` into a string.

  `result[i1...iN]` is the string formed by concatenating the Unicode
  codepoints `input[i1...iN, :]`, encoded using `output_encoding`.

  Args:
    input: An `N+1` dimensional potentially ragged integer tensor with shape
      `[D1...DN, num_chars]`.
    output_encoding: Unicode encoding that should be used to encode each
      codepoint sequence.  Can be `"UTF-8"`, `"UTF-16-BE"`, or `"UTF-32-BE"`.
    errors: Specifies the response when an invalid codepoint is encountered
      (optional). One of:
            * `'replace'`: Replace invalid codepoint with the
              `replacement_char`. (default)
            * `'ignore'`: Skip invalid codepoints.
            * `'strict'`: Raise an exception for any invalid codepoint.
    replacement_char: The replacement character codepoint to be used in place of
      any invalid input when `errors='replace'`. Any valid unicode codepoint may
      be used. The default value is the default unicode replacement character,
      U+FFFD (decimal 65533).
    name: A name for the operation (optional).

  Returns:
    A `N` dimensional `string` tensor with shape `[D1...DN]`.

  Raises:
    ValueError: If the rank of `input` is not statically known, or if `input`
      is a scalar.

  #### Example:

  >>> input = tf.ragged.constant(
  ...     [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]])
  >>> print(tf.strings.unicode_encode(input, 'UTF-8'))
  tf.Tensor([b'G\xc3\xb6\xc3\xb6dnight' b'\xf0\x9f\x98\x8a'],
            shape=(2,), dtype=string)
  """
  with ops.name_scope(name, "UnicodeEncode", [input]):
    input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
    if input_tensor.shape.ndims is None:
      raise ValueError("Rank of input_tensor must be statically known.")
    if ragged_tensor.is_ragged(input_tensor):
      if input_tensor.flat_values.shape.ndims > 1:
        # If the flat_values of our ragged tensor is multi-dimensional, we can
        # process it separately and our output will have the same nested splits
        # as our input.
        return input_tensor.with_flat_values(
            unicode_encode(input_tensor.flat_values, output_encoding, errors,
                           replacement_char))
      elif input_tensor.ragged_rank > 1:
        # Recursively process the values of the ragged tensor.
        return input_tensor.with_values(
            unicode_encode(input_tensor.values, output_encoding, errors,
                           replacement_char))
      else:
        # Our ragged tensor is of the correct shape (rank 1 flat_values tensor
        # with ragged_rank of 1) so we can process it as normal.
        return gen_string_ops.unicode_encode(
            input_values=input_tensor.values,
            input_splits=input_tensor.row_splits,
            output_encoding=output_encoding,
            errors=errors,
            replacement_char=replacement_char)
    else:
      if input_tensor.shape.ndims == 2:
        # The input tensor is of the correct 2-D shape, it's just not ragged.
        return unicode_encode(
            ragged_tensor.RaggedTensor.from_tensor(input_tensor),
            output_encoding, errors, replacement_char)
      elif input_tensor.shape.ndims > 2:
        # We need to initially flatten the input tensor to 2-D, and then can
        # reshape the output of our processed flattened tensor.
        flat_input_tensor = array_ops.reshape(
            input_tensor,
            array_ops.stack([-1, array_ops.shape(input_tensor)[-1]]))
        flat_output_tensor = unicode_encode(flat_input_tensor, output_encoding,
                                            errors, replacement_char)
        output_shape = input_tensor.shape[:-1]
        if not output_shape.is_fully_defined():
          # A static shape with unknown dimensions (e.g. an unknown batch size
          # in graph mode) cannot be used as a reshape target, so fall back to
          # the shape computed at run time.
          output_shape = array_ops.shape(input_tensor)[:-1]
        return array_ops.reshape(flat_output_tensor, output_shape)
      elif input_tensor.shape.ndims == 0:
        raise ValueError("input_tensor's rank must be at least 1.")
      else:
        # Our input tensor is rank 1, so we create a ragged tensor with an added
        # dimension to create the correct input shape & type, and then remove
        # the additional dimension from the output and return the string scalar.
        ragged_input_tensor = ragged_tensor.RaggedTensor.from_row_splits(
            input_tensor,
            array_ops.stack(
                [0, array_ops.shape(input_tensor, out_type=dtypes.int32)[0]]),
            validate=False)
        output_tensor = unicode_encode(ragged_input_tensor, output_encoding,
                                       errors, replacement_char)
        return array_ops.reshape(output_tensor, [])
186
187
188# pylint: disable=redefined-builtin
@tf_export("strings.unicode_decode")
@dispatch.add_dispatch_support
def unicode_decode(input,
                   input_encoding,
                   errors="replace",
                   replacement_char=0xFFFD,
                   replace_control_characters=False,
                   name=None):
  r"""Converts every string in `input` to its sequence of Unicode code points.

  For each index `i1...iN`, `result[i1...iN, j]` holds the codepoint of the
  `j`th character of `input[i1...iN]`, with the string interpreted according
  to `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name of the unicode encoding used to decode each
      string.
    errors: What to do when an input string cannot be converted using the
      indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The codepoint substituted for invalid substrings in
      `input` when `errors='replace'`, and for C0 control characters in
      `input` when `replace_control_characters=True`.
    replace_control_characters: Whether the C0 control characters
      `(U+0000 - U+001F)` are replaced with `replacement_char`.
    name: A name for the operation (optional).

  Returns:
    A `N+1` dimensional `int32` tensor with shape `[D1...DN, (num_chars)]`.
    The result is a `tf.Tensor` when `input` is a scalar and a
    `tf.RaggedTensor` otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> tf.strings.unicode_decode(input, 'UTF-8').to_list()
  [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]]
  """
  with ops.name_scope(name, "UnicodeDecode", [input]):
    # All of the work lives in the shared helper; this entry point only asks
    # for the codepoints, without the byte offsets.
    return _unicode_decode(
        input,
        input_encoding,
        errors,
        replacement_char,
        replace_control_characters,
        with_offsets=False)
233
234
@tf_export("strings.unicode_decode_with_offsets")
@dispatch.add_dispatch_support
def unicode_decode_with_offsets(input,
                                input_encoding,
                                errors="replace",
                                replacement_char=0xFFFD,
                                replace_control_characters=False,
                                name=None):
  r"""Decodes each string into a sequence of code points with start offsets.

  This op is similar to `tf.strings.unicode_decode(...)`, but it also returns
  the start offset for each character in its respective string.  This
  information can be used to align the characters with the original byte
  sequence.

  Returns a tuple `(codepoints, start_offsets)` where:

  * `codepoints[i1...iN, j]` is the Unicode codepoint for the `j`th character
    in `input[i1...iN]`, when decoded using `input_encoding`.
  * `start_offsets[i1...iN, j]` is the start byte offset for the `j`th
    character in `input[i1...iN]`, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used to
      decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`; and in place of C0 control
      characters in `input` when `replace_control_characters=True`.
    replace_control_characters: Whether to replace the C0 control characters
      `(U+0000 - U+001F)` with the `replacement_char`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `N+1` dimensional tensors `(codepoints, start_offsets)`.

    * `codepoints` is an `int32` tensor with shape `[D1...DN, (num_chars)]`.
    * `start_offsets` is an `int64` tensor with shape
      `[D1...DN, (num_chars)]`.

    The returned tensors are `tf.Tensor`s if `input` is a scalar, or
    `tf.RaggedTensor`s otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> result = tf.strings.unicode_decode_with_offsets(input, 'UTF-8')
  >>> result[0].to_list()  # codepoints
  [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]]
  >>> result[1].to_list()  # offsets
  [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]]

  """
  with ops.name_scope(name, "UnicodeDecodeWithOffsets", [input]):
    # Same helper as `unicode_decode`, but also asking for the byte offsets.
    return _unicode_decode(input, input_encoding, errors, replacement_char,
                           replace_control_characters, with_offsets=True)
295
296
@tf_export("strings.unicode_split")
@dispatch.add_dispatch_support
def unicode_split(input,
                  input_encoding,
                  errors="replace",
                  replacement_char=0xFFFD,
                  name=None):
  r"""Splits each string in `input` into a sequence of Unicode code points.

  `result[i1...iN, j]` is the substring of `input[i1...iN]` that encodes its
  `j`th character, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used to
      decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`.
    name: A name for the operation (optional).

  Returns:
    A `N+1` dimensional `string` tensor with shape `[D1...DN, (num_chars)]`.
    The returned tensor is a `tf.Tensor` if `input` is a scalar, or a
    `tf.RaggedTensor` otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> tf.strings.unicode_split(input, 'UTF-8').to_list()
  [[b'G', b'\xc3\xb6', b'\xc3\xb6', b'd', b'n', b'i', b'g', b'h', b't'],
   [b'\xf0\x9f\x98\x8a']]
  """
  with ops.name_scope(name, "UnicodeSplit", [input]):
    # Decode to codepoints first; control characters are left untouched.
    codepoints = _unicode_decode(input, input_encoding, errors,
                                 replacement_char,
                                 replace_control_characters=False,
                                 with_offsets=False)
    # Adding a trailing axis of size 1 turns every codepoint into its own
    # one-character sequence, which is then re-encoded with the input encoding.
    return unicode_encode(
        ragged_array_ops.expand_dims(codepoints, -1),
        output_encoding=input_encoding,
        errors=errors,
        replacement_char=replacement_char)
343
344
@tf_export("strings.unicode_split_with_offsets")
@dispatch.add_dispatch_support
def unicode_split_with_offsets(input,
                               input_encoding,
                               errors="replace",
                               replacement_char=0xFFFD,
                               name=None):
  r"""Splits each string into a sequence of code points with start offsets.

  This op is similar to `tf.strings.unicode_split(...)`, but it also returns
  the start offset for each character in its respective string.  This
  information can be used to align the characters with the original byte
  sequence.

  Returns a tuple `(chars, start_offsets)` where:

  * `chars[i1...iN, j]` is the substring of `input[i1...iN]` that encodes its
    `j`th character, when decoded using `input_encoding`.
  * `start_offsets[i1...iN, j]` is the start byte offset for the `j`th
    character in `input[i1...iN]`, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used to
      decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `N+1` dimensional tensors `(chars, start_offsets)`.

    * `chars` is a `string` tensor with shape `[D1...DN, (num_chars)]`.
    * `start_offsets` is an `int64` tensor with shape
      `[D1...DN, (num_chars)]`.

    The returned tensors are `tf.Tensor`s if `input` is a scalar, or
    `tf.RaggedTensor`s otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> result = tf.strings.unicode_split_with_offsets(input, 'UTF-8')
  >>> result[0].to_list()  # character substrings
  [[b'G', b'\xc3\xb6', b'\xc3\xb6', b'd', b'n', b'i', b'g', b'h', b't'],
   [b'\xf0\x9f\x98\x8a']]
  >>> result[1].to_list()  # offsets
  [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]]

  """
  with ops.name_scope(name, "UnicodeSplitWithOffsets", [input]):
    # Decode to codepoints plus byte offsets; control characters are left
    # untouched.
    codepoints, offsets = _unicode_decode(input, input_encoding, errors,
                                          replacement_char,
                                          replace_control_characters=False,
                                          with_offsets=True)
    # Adding a trailing axis of size 1 turns every codepoint into its own
    # one-character sequence, which is then re-encoded with the input encoding.
    chars = unicode_encode(
        ragged_array_ops.expand_dims(codepoints, -1),
        output_encoding=input_encoding,
        errors=errors,
        replacement_char=replacement_char)
    return chars, offsets
409
410
def _unicode_decode(input, input_encoding, errors, replacement_char,
                    replace_control_characters, with_offsets):
  """Shared implementation of the unicode decode/split ops.

  Args:
    input: The strings to decode; converted to a `Tensor` or `RaggedTensor`.
      Its rank must be statically known.
    input_encoding: Forwarded unchanged to the decode kernel.
    errors: Forwarded unchanged to the decode kernel.
    replacement_char: Forwarded unchanged to the decode kernel.
    replace_control_characters: Forwarded unchanged to the decode kernel.
    with_offsets: If true, `gen_string_ops.unicode_decode_with_offsets` is used
      and the byte start offsets are returned alongside the codepoints.

  Returns:
    `codepoints` when `with_offsets` is false, otherwise the tuple
    `(codepoints, offsets)`.  For scalar input the results are the kernel's
    flat outputs; otherwise they are `RaggedTensor`s with one extra ragged
    dimension added to the structure of `input`.

  Raises:
    ValueError: If the rank of `input` is not statically known.
  """
  strings = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      input, name="input")
  ndims = strings.shape.ndims
  if ndims is None:
    raise ValueError("Rank of `input` must be statically known.")

  if ndims > 1:
    # Make every dimension but the outermost ragged, so that flat_values ends
    # up as a vector of strings.
    target_ragged_rank = ndims - 1
    if not ragged_tensor.is_ragged(strings):
      strings = ragged_tensor.RaggedTensor.from_tensor(
          strings, ragged_rank=target_ragged_rank)
    elif strings.ragged_rank < target_ragged_rank:
      strings = strings.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              strings.flat_values,
              ragged_rank=target_ragged_rank - strings.ragged_rank))

  # The kernel takes a flat vector of strings.
  innermost = (
      strings.flat_values if ragged_tensor.is_ragged(strings) else strings)
  flat_strings = array_ops.reshape(innermost, [-1])

  decode_op = (
      gen_string_ops.unicode_decode_with_offsets
      if with_offsets else gen_string_ops.unicode_decode)
  decoded = decode_op(
      input=flat_strings,
      input_encoding=input_encoding,
      errors=errors,
      replacement_char=replacement_char,
      replace_control_characters=replace_control_characters)

  def _restore_structure(flat_values):
    """Re-attaches the per-string (and outer) row partitions to flat_values."""
    if ndims == 0:
      return flat_values
    per_string = ragged_tensor.RaggedTensor.from_row_splits(
        flat_values, decoded.row_splits, validate=False)
    if ndims > 1:
      return strings.with_flat_values(per_string)
    return per_string

  codepoints = _restore_structure(decoded.char_values)
  if not with_offsets:
    return codepoints
  offsets = _restore_structure(decoded.char_to_byte_starts)
  return codepoints, offsets
467
468
@tf_export("strings.split", v1=[])
@dispatch.add_dispatch_support
def string_split_v2(input, sep=None, maxsplit=-1, name=None):  # pylint: disable=redefined-builtin
  """Split elements of `input` based on `sep` into a `RaggedTensor`.

  Let N be the size of `input` (typically N will be the batch size). Each
  element of `input` is split on `sep`, and the resulting tokens are returned
  in a `RaggedTensor`. Empty tokens are ignored.

  Example:

  >>> tf.strings.split('hello world').numpy()
  array([b'hello', b'world'], dtype=object)
  >>> tf.strings.split(['hello world', 'a b c'])
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, `input` of `"1<>2<><>3"` and
  `sep` of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace are regarded as a single separator, and the
  result will contain no empty strings at the start or end if the string has
  leading or trailing whitespace.

  Note that the above mentioned behavior matches python's str.split.

  Args:
    input: A string `Tensor` of rank `N`, the strings to split.  If
      `rank(input)` is not known statically, then it is assumed to be `1`.
    sep: `0-D` string `Tensor`, the delimiter string.
    maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.
    name: A name for the operation (optional).

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `RaggedTensor` of rank `N+1`, the strings split according to the
    delimiter.
  """
  with ops.name_scope(name, "StringSplit", [input]):
    strings = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, dtype=dtypes.string, name="input")

    # Ragged input: split the flat values and keep the outer row partitions.
    if isinstance(strings, ragged_tensor.RaggedTensor):
      split_flat_values = string_split_v2(strings.flat_values, sep, maxsplit)
      return strings.with_flat_values(split_flat_values)

    ndims = strings.shape.ndims

    # Scalar input: process it as a one-element vector, then drop that axis.
    if ndims == 0:
      as_vector = array_ops.stack([strings])
      return string_split_v2(as_vector, sep, maxsplit)[0]

    # Higher-rank dense input: go through the ragged code path above.
    if ndims is not None and ndims != 1:
      return string_split_v2(
          ragged_tensor.RaggedTensor.from_tensor(strings), sep, maxsplit)

    # Vector input (an unknown rank is treated as a vector): run the sparse op
    # and regroup its tokens by the row each one came from.
    tokens = string_ops.string_split_v2(strings, sep=sep, maxsplit=maxsplit)
    return ragged_tensor.RaggedTensor.from_value_rowids(
        values=tokens.values,
        value_rowids=tokens.indices[:, 0],
        nrows=tokens.dense_shape[0],
        validate=False)
529
530
@tf_export(v1=["string_split"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
                             "delimiter is deprecated, please use sep instead.",
                             "delimiter")
def string_split(source, sep=None, skip_empty=True, delimiter=None,
                 result_type="SparseTensor", name=None):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter`.

  Let N be the size of `source` (typically N will be the batch size). Each
  element of `source` is split on `delimiter`, and the resulting tokens are
  returned in a `SparseTensor` or a `RaggedTensor`. Empty tokens are ignored.

  If `sep` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  Examples:

  >>> print(tf.compat.v1.string_split(['hello world', 'a b c']))
  SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [1 0] [1 1] [1 2]], ...),
               values=tf.Tensor([b'hello' b'world' b'a' b'b' b'c'], ...),
               dense_shape=tf.Tensor([2 3], shape=(2,), dtype=int64))

  >>> print(tf.compat.v1.string_split(['hello world', 'a b c'],
  ...     result_type="RaggedTensor"))
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1. Default is ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.
    name: A name for the operation (optional).

  Raises:
    ValueError: If delimiter is not a string, or if `result_type` is neither
      `"RaggedTensor"` nor `"SparseTensor"`.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split according
    to the delimiter.  The first column of the indices corresponds to the row
    in `source` and the second column corresponds to the index of the split
    component in this row.
  """
  with ops.name_scope(name, "StringSplit", [source]):
    # The sparse op always runs first; `result_type` only selects the packaging.
    tokens = string_ops.string_split(
        source, sep=sep, skip_empty=skip_empty, delimiter=delimiter)

    if result_type == "SparseTensor":
      return tokens
    if result_type != "RaggedTensor":
      raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")

    # Regroup the sparse tokens by the source row each one came from.
    return ragged_tensor.RaggedTensor.from_value_rowids(
        values=tokens.values,
        value_rowids=tokens.indices[:, 0],
        nrows=tokens.dense_shape[0],
        validate=False)
592
593
594# In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit),
595# but we need to add the result_type argument.
@tf_export(v1=["strings.split"])
@dispatch.add_dispatch_support
def strings_split_v1(input=None, sep=None, maxsplit=-1,  # pylint: disable=redefined-builtin
                     result_type="SparseTensor", source=None, name=None):
  """Split elements of `input` based on `sep`.

  Let N be the size of `input` (typically N will be the batch size). Each
  element of `input` is split on `sep`, and the resulting tokens are returned
  in a `SparseTensor` or a `RaggedTensor`. Empty tokens are ignored.

  Examples:

  >>> print(tf.compat.v1.strings.split(['hello world', 'a b c']))
  SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [1 0] [1 1] [1 2]], ...),
               values=tf.Tensor([b'hello' b'world' b'a' b'b' b'c'], ...),
               dense_shape=tf.Tensor([2 3], shape=(2,), dtype=int64))

  >>> print(tf.compat.v1.strings.split(['hello world', 'a b c'],
  ...     result_type="RaggedTensor"))
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, `input` of `"1<>2<><>3"` and
  `sep` of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace are regarded as a single separator, and the
  result will contain no empty strings at the start or end if the string has
  leading or trailing whitespace.

  Note that the above mentioned behavior matches python's str.split.

  Args:
    input: A string `Tensor` of rank `N`, the strings to split.  If
      `rank(input)` is not known statically, then it is assumed to be `1`.
    sep: `0-D` string `Tensor`, the delimiter character.
    maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.
    source: alias for "input" argument.
    name: A name for the operation (optional).

  Raises:
    ValueError: If sep is not a string, or if `result_type` is neither
      `"RaggedTensor"` nor `"SparseTensor"`.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `N+1`, the strings split
    according to the delimiter.
  """
  input = deprecation.deprecated_argument_lookup(
      "input", input, "source", source)
  with ops.name_scope(name, "StringSplit", [input]):
    strings = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, dtype=dtypes.string, name="input")

    # A scalar is handled as a one-element vector.
    if strings.shape.rank == 0:
      strings = array_ops.expand_dims(strings, 0)

    wants_sparse = result_type == "SparseTensor"
    if not wants_sparse and result_type != "RaggedTensor":
      raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")

    # A vector with a sparse result maps directly onto the sparse op.
    if wants_sparse and strings.shape.rank == 1:
      return string_ops.string_split_v2(strings, sep=sep, maxsplit=maxsplit)

    # Everything else goes through the ragged implementation; the result is
    # converted to sparse afterwards if that is what was requested.
    ragged_tokens = string_split_v2(strings, sep=sep, maxsplit=maxsplit)
    return ragged_tokens.to_sparse() if wants_sparse else ragged_tokens
661
662
@dispatch.dispatch_for_api(string_ops.reduce_join_v2)
def reduce_join(inputs: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                separator="",
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  # Use the default op name when the caller did not supply one.
  op_name = name if name else "RaggedSegmentJoin"
  # Delegate to the generic ragged reduction, handing it the dense join op
  # and the segment-wise join op.
  return ragged_math_ops.ragged_reduce_aggregate(
      string_ops.reduce_join,
      string_ops.unsorted_segment_join,
      inputs,
      axis,
      keepdims,
      separator,
      op_name)
673
674
@tf_export("strings.ngrams")
@dispatch.add_dispatch_support
def ngrams(data,
           ngram_width,
           separator=" ",
           pad_values=None,
           padding_width=None,
           preserve_short_sequences=False,
           name=None):
  """Create a tensor of n-grams based on `data`.

  Creates a tensor of n-grams based on `data`. The n-grams are created by
  joining windows of `ngram_width` adjacent strings from the inner axis of
  `data` using `separator`.

  The input data can be padded on both the start and end of the sequence, if
  desired, using the `pad_values` argument. If set, `pad_values` should contain
  either a tuple of strings or a single string; the 0th element of the tuple
  will be used to pad the left side of the sequence and the 1st element of the
  tuple will be used to pad the right side of the sequence. The `padding_width`
  arg controls how many padding values are added to each side; it defaults to
  `ngram_width-1`.

  If this op is configured to not have padding, or if it is configured to add
  padding with `padding_width` set to less than ngram_width-1, it is possible
  that a sequence, or a sequence plus padding, is smaller than the ngram
  width. In that case, no ngrams will be generated for that sequence. This can
  be prevented by setting `preserve_short_sequences`, which will cause the op
  to always generate at least one ngram per non-empty sequence.

  Examples:

  >>> tf.strings.ngrams(["A", "B", "C", "D"], 2).numpy()
  array([b'A B', b'B C', b'C D'], dtype=object)
  >>> tf.strings.ngrams(["TF", "and", "keras"], 1).numpy()
  array([b'TF', b'and', b'keras'], dtype=object)

  Args:
    data: A Tensor or RaggedTensor containing the source data for the ngrams.
    ngram_width: The width(s) of the ngrams to create. If this is a list or
      tuple, the op will return ngrams of all specified arities in list order.
      Values must be non-Tensor integers greater than 0.
    separator: The separator string used between ngram elements. Must be a
      string constant, not a Tensor.
    pad_values: A tuple of (left_pad_value, right_pad_value), a single string,
      or None. If None, no padding will be added; if a single string, then that
      string will be used for both left and right padding. Values must be Python
      strings.
    padding_width: If set, `padding_width` pad values will be added to both
      sides of each sequence. Defaults to `ngram_width`-1. Must be greater than
      0. (Note that 1-grams are never padded, regardless of this value.)
    preserve_short_sequences: If true, then ensure that at least one ngram is
      generated for each input sequence.  In particular, if an input sequence
      (together with its padding) is shorter than the smallest `ngram_width`,
      then generate a single ngram containing the entire sequence.  If false,
      then no ngrams are generated for these short input sequences.
    name: The op name.

  Returns:
    The ngrams, as a `Tensor` if `data` is a `Tensor` and as a `RaggedTensor`
    if `data` is a `RaggedTensor`. If `data.shape=[D1...DN, S]`, then
    `output.shape=[D1...DN, NUM_NGRAMS]`, where
    `NUM_NGRAMS=S-ngram_width+1+2*padding_width`.

  Raises:
    TypeError: if `pad_values` is set to an invalid type.
    ValueError: if `pad_values`, `padding_width`, or `ngram_width` is set to an
      invalid value (including a `pad_values` list or tuple with fewer than
      two entries), or if `data` has unknown rank or rank 0.
  """

  with ops.name_scope(name, "StringNGrams", [data]):
    if pad_values is None:
      left_pad = ""
      right_pad = ""
    elif isinstance(pad_values, (list, tuple)):
      # Type-check the entries that are present first, so that a non-string
      # entry is still reported as a TypeError.
      if not all(
          isinstance(pad, util_compat.bytes_or_text_types)
          for pad in pad_values[:2]):
        raise TypeError(
            "pad_values must be a string, tuple of strings, or None.")
      # A sequence with fewer than two entries would otherwise escape as a
      # bare IndexError; report it as the documented ValueError instead.
      if len(pad_values) < 2:
        raise ValueError(
            "pad_values must contain both a left and a right pad value when "
            "it is a list or tuple. Got %r" % (pad_values,))
      # Only the first two entries are used.
      left_pad = pad_values[0]
      right_pad = pad_values[1]
    else:
      if not isinstance(pad_values, util_compat.bytes_or_text_types):
        raise TypeError(
            "pad_values must be a string, tuple of strings, or None.")
      left_pad = pad_values
      right_pad = pad_values

    if padding_width is not None and padding_width < 1:
      raise ValueError("padding_width must be greater than 0.")

    if padding_width is not None and pad_values is None:
      raise ValueError("pad_values must be provided if padding_width is set.")

    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        data, name="data", dtype=dtypes.string)

    # Preserve the shape of the data if it is a dense tensor: remember the
    # outer dimensions so the ragged result can be reshaped back at the end.
    to_tensor = False
    if isinstance(data, ops.Tensor):
      dense_shape = array_ops.concat([array_ops.shape(data)[:-1], [-1]], axis=0)
      to_tensor = True

    if not isinstance(data, ragged_tensor.RaggedTensor):
      if data.shape.ndims is None:
        raise ValueError("Rank of data must be known.")
      elif data.shape.ndims == 0:
        raise ValueError("Data must have rank>0")
      elif data.shape.ndims == 1:
        # Wrap the vector as a single ragged row, recurse, and unwrap that row.
        rt = ragged_tensor.RaggedTensor.from_row_starts(
            data, [0], validate=False)
        return ngrams(rt, ngram_width, separator, pad_values, padding_width,
                      preserve_short_sequences, name)[0]
      else:
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, ragged_rank=data.shape.ndims - 1)

    if data.ragged_rank > 1:
      # Peel off the outermost ragged dimension and recurse on the values.
      output = data.with_values(
          ngrams(data.values, ngram_width, separator, pad_values, padding_width,
                 preserve_short_sequences, name))
      return array_ops.reshape(output.flat_values,
                               dense_shape) if to_tensor else output

    if pad_values is None:
      padding_width = 0

    # -1 is passed to the kernel when pad values were given without an explicit
    # width (the documented default is `ngram_width`-1).
    if pad_values is not None and padding_width is None:
      padding_width = -1

    if not isinstance(ngram_width, (list, tuple)):
      ngram_widths = [ngram_width]
    else:
      ngram_widths = ngram_width
    for width in ngram_widths:
      if width < 1:
        raise ValueError("All ngram_widths must be greater than 0. Got %s" %
                         ngram_width)

    output, output_splits = gen_string_ops.string_n_grams(
        data=data.flat_values,
        data_splits=data.row_splits,
        separator=separator,
        ngram_widths=ngram_widths,
        left_pad=left_pad,
        right_pad=right_pad,
        pad_width=padding_width,
        preserve_short_sequences=preserve_short_sequences)

    # If the input was a dense tensor, the output should also be a dense tensor.
    output = ragged_tensor.RaggedTensor.from_row_splits(
        values=output, row_splits=output_splits, validate=False)
    return array_ops.reshape(output.flat_values,
                             dense_shape) if to_tensor else output
828
829
@dispatch.dispatch_for_api(string_ops.string_format)
def string_format(
    template: str,
    inputs: typing.Union[ragged_tensor.Ragged,
                         typing.List[ragged_tensor.RaggedOrDense]],
    placeholder="{}",
    summarize=3,
    name=None):
  """Version of tf.strings.format that handles RaggedTensors.

  The template is split on `placeholder`; each input is rendered to a string
  (ragged inputs via `ragged_tensor_to_string`, everything else via
  `string_ops.string_format`) and interleaved with the literal template
  fragments before the pieces are joined.

  Args:
    template: The format template; contains one `placeholder` per input.
    inputs: A single tensor / ragged tensor, or a list of them.
    placeholder: The substring of `template` that marks where inputs go.
    summarize: Passed through to the per-input string conversion.
    name: Optional op name.

  Returns:
    A string tensor holding the formatted result.

  Raises:
    ValueError: If the number of placeholders in `template` differs from the
      number of inputs.
  """
  # A lone tensor or ragged tensor is treated as a one-element input list.
  if tensor_util.is_tf_type(inputs) or ragged_tensor.is_ragged(inputs):
    inputs = [inputs]

  literal_parts = template.split(placeholder)
  num_placeholders = len(literal_parts) - 1
  if len(inputs) != num_placeholders:
    raise ValueError("num placeholders in template and num inputs must match"
                     ": {} vs {}".format(num_placeholders, len(inputs)))

  with ops.name_scope(name, "StringFormat", [inputs]):
    output_pieces = [constant_op.constant(literal_parts[0])]
    # Each input is followed by the literal text that trails its placeholder.
    for value, trailing_text in zip(inputs, literal_parts[1:]):
      if ragged_tensor.is_ragged(value):
        rendered = ragged_tensor_to_string(value, summarize)
      else:
        rendered = string_ops.string_format(
            "{}", [value], summarize=summarize)
      output_pieces.append(rendered)
      output_pieces.append(constant_op.constant(trailing_text))

    # With no inputs there is only the template text itself; nothing to join.
    if len(output_pieces) == 1:
      return output_pieces[0]
    return string_ops.reduce_join(output_pieces)
860
861
def ragged_tensor_to_string(rt, summarize=None):
  """Renders a whole RaggedTensor as one scalar string tensor.

  `rt.shape.rank` must be statically known (not `None`).

  Note: the entire `RaggedTensor` is collapsed into a single string scalar.
  To convert elements individually, use `tf.strings.as_string(rt)` instead.

  >>> rt1 = tf.ragged.constant([[1, 2, 3], [4, 5]])
  >>> ragged_tensor_to_string(rt1).numpy()
  b'[[1, 2, 3], [4, 5]]'

  >>> rt2 = tf.ragged.constant([[['a'], ['b', 'c']], [['d', 'e', 'f'], []]])
  >>> ragged_tensor_to_string(rt2).numpy()
  b"[[['a'], ['b', 'c']], [['d', 'e', 'f'], []]]"

  >>> rt3 = tf.ragged.constant([[1], [2, 3, 4, 5, 6], [], [], [7], [8, 9]])
  >>> ragged_tensor_to_string(rt3, summarize=2).numpy()
  b'[[1], [2, 3, ..., 5, 6], ..., [7], [8, 9]]'

  Args:
    rt: The RaggedTensor that should be converted to a string.
    summarize: If specified, only the first and last `summarize` elements of
      each dimension appear in the string. `-1` or `None` means include every
      element.

  Returns:
    A scalar string tensor.

  Raises:
    ValueError: If `summarize` is neither `None`, `-1`, nor a positive int, or
      if the rank of `rt` is unknown.
  """
  if summarize is not None and summarize != -1:
    if not isinstance(summarize, int) or summarize <= 0:
      raise ValueError("Expected summarize to be -1 or a positive int, got %r" %
                       summarize)

  with ops.name_scope(None, "AsString", [rt]):
    rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt)
    if rt.shape.rank is None:
      raise ValueError("RaggedTensor to_string requires that rt.shape.rank "
                       "is not None.")

    # Turn every leaf element into its string form.
    flat = rt.flat_values
    if rt.dtype == dtypes.string:
      # Backslash-escape quotes and backslashes, then wrap in single quotes.
      escaped = string_ops.regex_replace(flat, r"(['\\])", r"\\\1")
      leaf_strings = "'" + escaped + "'"
    else:
      leaf_strings = string_ops.as_string(flat)

    return _ragged_tensor_to_string(
        rt.with_flat_values(leaf_strings), summarize)
905
906
def _ragged_tensor_to_string(string_tensor, summarize):
  """Recursively formats `string_tensor` as a bracketed, comma-joined string.

  Args:
    string_tensor: A potentially ragged tensor with dtype=string.
    summarize: Keep only the first and last `summarize` elements of each
      dimension, with "..." in between.  `-1` or `None` keeps all elements.

  Returns:
    A scalar string Tensor.
  """
  if string_tensor.shape.rank != 1:
    # Render each row to its own string first (one level of recursion per
    # dimension).
    def render_row(row):
      return _ragged_tensor_to_string(row, summarize)

    elements = map_fn_lib.map_fn(
        render_row,
        string_tensor,
        fn_output_signature=tensor_spec.TensorSpec(None, dtypes.string))
  else:
    elements = string_tensor

  if summarize in (-1, None):
    shown = elements
  else:
    def keep_all():
      return elements

    def keep_ends():
      return array_ops.concat(
          [elements[:summarize], ["..."], elements[-summarize:]], axis=0)

    # Elide the middle only when there are more than 2 * summarize rows.
    shown = control_flow_ops.cond(
        _nrows(string_tensor) <= 2 * summarize, keep_all, keep_ends)

  joined = string_ops.reduce_join(shown, separator=", ")
  return "[" + joined + "]"
933
934
def _nrows(tensor, out_type=dtypes.int32):
  """Returns the size of the outermost dimension of `tensor`.

  Args:
    tensor: A `RaggedTensor` or a dense tensor.
    out_type: The dtype requested for the result.

  Returns:
    `tensor.nrows(...)` for a `RaggedTensor`; otherwise element 0 of
    `array_ops.shape(tensor, ...)`.
  """
  if not isinstance(tensor, ragged_tensor.RaggedTensor):
    return array_ops.shape(tensor, out_type=out_type)[0]
  return tensor.nrows(out_type=out_type)
940
941
@dispatch.dispatch_for_api(string_ops.string_join)
def string_join(inputs: typing.List[ragged_tensor.RaggedOrDense],
                separator="",
                name=None):
  """RaggedTensor implementation for tf.strings.join.

  Applies `string_ops.string_join` to the flat values of `inputs` through
  `ragged_functional_ops.map_flat_values`.

  Args:
    inputs: A non-empty list of ragged or dense string tensors to join.
    separator: The string placed between the joined elements.
    name: Optional op name.

  Returns:
    The result of `map_flat_values(string_ops.string_join, inputs, separator)`.

  Raises:
    ValueError: If `inputs` is empty.
  """
  # `len(inputs) < 0` can never be true; reject an empty list as the error
  # message intends.
  if len(inputs) < 1:
    raise ValueError("tf.strings.join: expected at least one input.")
  with ops.name_scope(name, "RaggedStringJoin", inputs):
    return ragged_functional_ops.map_flat_values(string_ops.string_join, inputs,
                                                 separator)
952