// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/string_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include <string>
17 #include <vector>
18 
19 #include "absl/strings/str_split.h"
20 #include "tensorflow/core/framework/common_shape_fns.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/shape_inference.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace tensorflow {
29 
30 namespace shape_inference {
31 class InferenceContext;
32 }  // namespace shape_inference
33 
34 using shape_inference::DimensionHandle;
35 using shape_inference::InferenceContext;
36 using shape_inference::ShapeHandle;
37 
38 REGISTER_OP("RegexReplace")
39     .Input("input: string")
40     .Input("pattern: string")
41     .Input("rewrite: string")
42     .Output("output: string")
43     .Attr("replace_global: bool = true")
__anonf5c73fa00102(InferenceContext* c) 44     .SetShapeFn([](InferenceContext* c) {
45       ShapeHandle unused;
46       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
47       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
48       c->set_output(0, c->input(0));
49       return OkStatus();
50     });
51 
52 REGISTER_OP("StaticRegexReplace")
53     .Input("input: string")
54     .Attr("pattern: string")
55     .Attr("rewrite: string")
56     .Output("output: string")
57     .Attr("replace_global: bool = true")
58     .SetShapeFn(shape_inference::UnchangedShape);
59 
60 REGISTER_OP("RegexFullMatch")
61     .Input("input: string")
62     .Input("pattern: string")
63     .Output("output: bool")
__anonf5c73fa00202(InferenceContext* c) 64     .SetShapeFn([](InferenceContext* c) {
65       ShapeHandle unused;
66       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
67       c->set_output(0, c->input(0));
68       return OkStatus();
69     });
70 
71 REGISTER_OP("StaticRegexFullMatch")
72     .Input("input: string")
73     .Attr("pattern: string")
74     .Output("output: bool")
75     .SetShapeFn(shape_inference::UnchangedShape);
76 
77 REGISTER_OP("StringToHashBucketFast")
78     .Input("input: string")
79     .Output("output: int64")
80     .Attr("num_buckets: int >= 1")
81     .SetShapeFn(shape_inference::UnchangedShape);
82 
83 REGISTER_OP("_TensorToHashBucketFast")
84     .Input("input: T")
85     .Output("output: int64")
86     .Attr("T: {int8, uint8, int16, uint16, int32, uint32, int64, uint64}")
87     .Attr("num_buckets: int >= 1")
88     .SetShapeFn(shape_inference::UnchangedShape)
89     .Doc(R"doc(
90 Internal operation which is a composition of converting the tensor to a string
91 tensor (AsString) and then calling hash functions (StringToHashBucketFast):
92 reserved for internal use.
93 
94 Do not invoke this operator directly in Python. A fusion optimization is
95 expected to create these operators.
96 )doc");
97 
98 REGISTER_OP("StringToHashBucketStrong")
99     .Input("input: string")
100     .Output("output: int64")
101     .Attr("num_buckets: int >= 1")
102     .Attr("key: list(int)")
103     .SetShapeFn(shape_inference::UnchangedShape);
104 
105 REGISTER_OP("StringToHashBucket")
106     .Input("string_tensor: string")
107     .Output("output: int64")
108     .Attr("num_buckets: int >= 1")
109     .SetShapeFn(shape_inference::UnchangedShape);
110 
111 REGISTER_OP("ReduceJoin")
112     .Input("inputs: string")
113     .Input("reduction_indices: int32")
114     .Attr("keep_dims: bool = false")
115     .Attr("separator: string = ''")
116     .Output("output: string")
117     .SetShapeFn(shape_inference::ReductionShape);
118 
119 REGISTER_OP("UnsortedSegmentJoin")
120     .Input("inputs: string")
121     .Input("segment_ids: Tindices")
122     .Input("num_segments: Tnumsegments")
123     .Attr("separator: string = ''")
124     .Attr("Tindices: {int32,int64}")
125     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
126     .Output("output: string")
127     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
128 
129 REGISTER_OP("AsString")
130     .Input("input: T")
131     .Output("output: string")
132     .Attr("T: {realnumbertype, complex64, complex128, bool, variant}")
133     .Attr("precision: int = -1")
134     .Attr("scientific: bool = false")
135     .Attr("shortest: bool = false")
136     .Attr("width: int = -1")
137     .Attr("fill: string = ''")
138     .SetShapeFn(shape_inference::UnchangedShape);
139 
140 REGISTER_OP("StringFormat")
141     .Input("inputs: T")
142     .Output("output: string")
143     .Attr("T: list(type) >= 0")
144     .Attr("template: string = '%s'")
145     .Attr("placeholder: string = '%s'")
146     .Attr("summarize: int = 3")
__anonf5c73fa00302(InferenceContext* c) 147     .SetShapeFn([](InferenceContext* c) {
148       string template_;
149       string placeholder;
150       TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
151       TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
152 
153       std::vector<std::string> split_template;
154       split_template = absl::StrSplit(template_, placeholder);
155       int64_t num_placeholders = split_template.size() - 1;
156       if (c->num_inputs() != num_placeholders) {
157         return errors::InvalidArgument(strings::StrCat(
158             "num placeholders in template and num inputs must match: ",
159             num_placeholders, " vs. ", c->num_inputs()));
160       }
161 
162       c->set_output(0, c->Scalar());
163       return OkStatus();
164     });
165 
166 REGISTER_OP("StringJoin")
167     .Input("inputs: N * string")
168     .Attr("N: int")
169     .Attr("separator: string = ''")
170     .Output("output: string")
__anonf5c73fa00402(InferenceContext* c) 171     .SetShapeFn([](InferenceContext* c) {
172       // If all inputs are scalars, then return a scalar.
173       bool all_scalar = true;
174       for (int i = 0; i < c->num_inputs(); ++i) {
175         if (c->Rank(c->input(i)) != 0) all_scalar = false;
176       }
177       if (all_scalar) {
178         c->set_output(0, c->Scalar());
179         return OkStatus();
180       }
181 
182       // At least one input is unknown or a scalar.
183       // Merge the non-scalars to find the output shape.
184       // Don't merge inputs with unknown rank, as they can actually be scalars
185       // or the output shape.
186       ShapeHandle out = c->UnknownShape();
187       for (int i = 0; i < c->num_inputs(); ++i) {
188         if (c->RankKnown(c->input(i)) && c->Rank(c->input(i)) != 0) {
189           TF_RETURN_IF_ERROR(c->Merge(out, c->input(i), &out));
190         }
191       }
192       c->set_output(0, out);
193       return OkStatus();
194     });
195 
196 REGISTER_OP("StringSplit")
197     .Input("input: string")
198     .Input("delimiter: string")
199     .Output("indices: int64")
200     .Output("values: string")
201     .Output("shape: int64")
202     .Attr("skip_empty: bool = true")
__anonf5c73fa00502(InferenceContext* c) 203     .SetShapeFn([](InferenceContext* c) {
204       ShapeHandle unused;
205       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
206       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
207 
208       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
209       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
210       c->set_output(2, c->Vector(2));
211       return OkStatus();
212     });
213 
214 REGISTER_OP("StringSplitV2")
215     .Input("input: string")
216     .Input("sep: string")
217     .Output("indices: int64")
218     .Output("values: string")
219     .Output("shape: int64")
220     .Attr("maxsplit: int = -1")
__anonf5c73fa00602(InferenceContext* c) 221     .SetShapeFn([](InferenceContext* c) {
222       ShapeHandle unused;
223       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
224       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
225 
226       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
227       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
228       c->set_output(2, c->Vector(2));
229       return OkStatus();
230     });
231 
232 REGISTER_OP("StringLower")
233     .Input("input: string")
234     .Output("output: string")
235     .Attr("encoding: string =''")
236     .SetShapeFn(shape_inference::UnchangedShape);
237 
238 REGISTER_OP("StringUpper")
239     .Input("input: string")
240     .Output("output: string")
241     .Attr("encoding: string =''")
242     .SetShapeFn(shape_inference::UnchangedShape);
243 
244 REGISTER_OP("StringStrip")
245     .Input("input: string")
246     .Output("output: string")
247     .SetShapeFn(shape_inference::UnchangedShape);
248 
249 REGISTER_OP("StringLength")
250     .Input("input: string")
251     .Output("output: int32")
252     .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
253     .SetShapeFn(shape_inference::UnchangedShape);
254 
255 REGISTER_OP("EncodeBase64")
256     .Input("input: string")
257     .Output("output: string")
258     .Attr("pad: bool = false")
259     .SetShapeFn(shape_inference::UnchangedShape);
260 
261 REGISTER_OP("DecodeBase64")
262     .Input("input: string")
263     .Output("output: string")
264     .SetShapeFn(shape_inference::UnchangedShape);
265 
266 REGISTER_OP("Substr")
267     .Input("input: string")
268     .Input("pos: T")
269     .Input("len: T")
270     .Output("output: string")
271     .Attr("T: {int32, int64}")
272     .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
__anonf5c73fa00702(InferenceContext* c) 273     .SetShapeFn([](InferenceContext* c) {
274       ShapeHandle pos_shape = c->input(1);
275       ShapeHandle len_shape = c->input(2);
276       ShapeHandle unused;
277       // If len rank is known, check that pos and len have the same rank
278       if (c->RankKnown(len_shape)) {
279         TF_RETURN_IF_ERROR(c->WithRank(pos_shape, c->Rank(len_shape), &unused));
280       }
281       // Check that dimensions are equal
282       for (int32_t i = 0; i < c->Rank(pos_shape); ++i) {
283         DimensionHandle pos_dim = c->Dim(pos_shape, i);
284         DimensionHandle len_dim = c->Dim(len_shape, i);
285         if (c->Value(pos_dim) != c->Value(len_dim)) {
286           return errors::InvalidArgument(
287               "pos and len shapes must match: ", c->DebugString(pos_shape),
288               " vs. ", c->DebugString(len_shape));
289         }
290       }
291       // c->input(0) is the ShapeHandle to input strings
292       // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1).
293       return shape_inference::BroadcastBinaryOpShapeFn(c);
294     });
295 
296 REGISTER_OP("UnicodeScript")
297     .Input("input: int32")
298     .Output("output: int32")
299     .SetShapeFn(shape_inference::UnchangedShape);
300 
301 REGISTER_OP("UnicodeEncode")
302     .Input("input_values: int32")
303     .Input("input_splits: Tsplits")
304     .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'")
305     .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
306     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
307     .Attr("Tsplits: {int32, int64} = DT_INT64")
308     .Output("output: string")
__anonf5c73fa00802(InferenceContext* c) 309     .SetShapeFn([](InferenceContext* c) {
310       // Check rank of inner values
311       ShapeHandle input_inner_values_shape = c->input(0);
312       ShapeHandle unused;
313       TF_RETURN_IF_ERROR(c->WithRank(input_inner_values_shape, 1, &unused));
314 
315       // Check rank of input_splits
316       ShapeHandle splits_shape = c->input(1);
317       TF_RETURN_IF_ERROR(c->WithRank(splits_shape, 1, &unused));
318 
319       // Output shape is a 1-D tensor with size equal to number of splits.
320       std::vector<DimensionHandle> dims(1);
321       TF_RETURN_IF_ERROR(c->Subtract(c->Dim(splits_shape, 0), 1, &dims[0]));
322       c->set_output(0, c->MakeShape(dims));
323 
324       return OkStatus();
325     });
326 
327 REGISTER_OP("UnicodeTranscode")
328     .Input("input: string")
329     .Output("output: string")
330     .Attr("input_encoding: string")
331     .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
332     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
333     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
334     .Attr("replace_control_characters: bool = false")
335     .SetShapeFn(shape_inference::UnchangedShape);
336 
337 REGISTER_OP("UnicodeDecode")
338     .Input("input: string")
339     .Output("row_splits: Tsplits")
340     .Output("char_values: int32")
341     .Attr("input_encoding: string")
342     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
343     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
344     .Attr("replace_control_characters: bool = false")
345     .Attr("Tsplits: {int32, int64} = DT_INT64")
__anonf5c73fa00902(InferenceContext* c) 346     .SetShapeFn([](InferenceContext* c) {
347       // row_splits.shape == [input.size() + 1]
348       DimensionHandle num_row_splits;
349       DimensionHandle input_size = c->NumElements(c->input(0));
350       TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits));
351       c->set_output(0, c->Vector(num_row_splits));
352 
353       // char_values.shape == [num_chars]
354       DimensionHandle num_chars = c->UnknownDim();
355       c->set_output(1, c->Vector(num_chars));
356       return OkStatus();
357     });
358 
359 REGISTER_OP("UnicodeDecodeWithOffsets")
360     .Input("input: string")
361     .Output("row_splits: Tsplits")
362     .Output("char_values: int32")
363     .Output("char_to_byte_starts: int64")
364     .Attr("input_encoding: string")
365     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
366     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
367     .Attr("replace_control_characters: bool = false")
368     .Attr("Tsplits: {int32, int64} = DT_INT64")
__anonf5c73fa00a02(InferenceContext* c) 369     .SetShapeFn([](InferenceContext* c) {
370       // row_splits.shape == [input.size() + 1]
371       DimensionHandle num_row_splits;
372       DimensionHandle input_size = c->NumElements(c->input(0));
373       TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits));
374       c->set_output(0, c->Vector(num_row_splits));
375 
376       // char_values.shape == offset_values.shape == [num_chars]
377       DimensionHandle num_chars = c->UnknownDim();
378       c->set_output(1, c->Vector(num_chars));
379       c->set_output(2, c->Vector(num_chars));
380       return OkStatus();
381     });
382 
383 REGISTER_OP("StringNGrams")
384     .Attr("separator: string")
385     .Attr("ngram_widths: list(int) >= 0")
386     .Attr("left_pad: string")
387     .Attr("right_pad: string")
388     .Attr("pad_width: int")
389     .Attr("preserve_short_sequences: bool")
390     .Attr("Tsplits: {int32, int64} = DT_INT64")
391     .Input("data: string")
392     .Input("data_splits: Tsplits")
393     .Output("ngrams: string")
394     .Output("ngrams_splits: Tsplits")
__anonf5c73fa00b02(InferenceContext* c) 395     .SetShapeFn([](InferenceContext* c) {
396       c->set_output(0, c->UnknownShapeOfRank(1));
397       ShapeHandle data = c->input(0);
398       TF_RETURN_IF_ERROR(c->WithRank(data, 1, &data));
399       ShapeHandle data_splits = c->input(1);
400       TF_RETURN_IF_ERROR(c->WithRank(data_splits, 1, &data_splits));
401       c->set_output(1, data_splits);
402       return OkStatus();
403     });
404 
405 }  // namespace tensorflow
406