/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <vector>

#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace shape_inference {
class InferenceContext;
}  // namespace shape_inference

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("RegexReplace")
    .Input("input: string")
    .Input("pattern: string")
    .Input("rewrite: string")
    .Output("output: string")
    .Attr("replace_global: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(0, c->input(0));
      return OkStatus();
    });

REGISTER_OP("StaticRegexReplace")
    .Input("input: string")
    .Attr("pattern: string")
    .Attr("rewrite: string")
    .Output("output: string")
    .Attr("replace_global: bool = true")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("RegexFullMatch")
    .Input("input: string")
    .Input("pattern: string")
    .Output("output: bool")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      c->set_output(0, c->input(0));
      return OkStatus();
    });

REGISTER_OP("StaticRegexFullMatch")
    .Input("input: string")
    .Attr("pattern: string")
    .Output("output: bool")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("StringToHashBucketFast")
    .Input("input: string")
    .Output("output: int64")
    .Attr("num_buckets: int >= 1")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("_TensorToHashBucketFast")
    .Input("input: T")
    .Output("output: int64")
    .Attr("T: {int8, uint8, int16, uint16, int32, uint32, int64, uint64}")
    .Attr("num_buckets: int >= 1")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Internal operation which is a composition of converting the tensor to a string
tensor (AsString) and then calling hash functions (StringToHashBucketFast):
reserved for internal use.

Do not invoke this operator directly in Python. A fusion optimization is
expected to create these operators.
)doc");
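
// Illustrative sketch (not part of the op definition): per the doc above, a
// fusion optimization pass is expected to rewrite the unfused pair
//   StringToHashBucketFast(AsString(input), num_buckets)
// into a single
//   _TensorToHashBucketFast(input, num_buckets)
// node. Both forms leave the input shape unchanged, which is why the fused op
// also uses shape_inference::UnchangedShape.
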
96 )doc"); 97 98 REGISTER_OP("StringToHashBucketStrong") 99 .Input("input: string") 100 .Output("output: int64") 101 .Attr("num_buckets: int >= 1") 102 .Attr("key: list(int)") 103 .SetShapeFn(shape_inference::UnchangedShape); 104 105 REGISTER_OP("StringToHashBucket") 106 .Input("string_tensor: string") 107 .Output("output: int64") 108 .Attr("num_buckets: int >= 1") 109 .SetShapeFn(shape_inference::UnchangedShape); 110 111 REGISTER_OP("ReduceJoin") 112 .Input("inputs: string") 113 .Input("reduction_indices: int32") 114 .Attr("keep_dims: bool = false") 115 .Attr("separator: string = ''") 116 .Output("output: string") 117 .SetShapeFn(shape_inference::ReductionShape); 118 119 REGISTER_OP("UnsortedSegmentJoin") 120 .Input("inputs: string") 121 .Input("segment_ids: Tindices") 122 .Input("num_segments: Tnumsegments") 123 .Attr("separator: string = ''") 124 .Attr("Tindices: {int32,int64}") 125 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 126 .Output("output: string") 127 .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn); 128 129 REGISTER_OP("AsString") 130 .Input("input: T") 131 .Output("output: string") 132 .Attr("T: {realnumbertype, complex64, complex128, bool, variant}") 133 .Attr("precision: int = -1") 134 .Attr("scientific: bool = false") 135 .Attr("shortest: bool = false") 136 .Attr("width: int = -1") 137 .Attr("fill: string = ''") 138 .SetShapeFn(shape_inference::UnchangedShape); 139 140 REGISTER_OP("StringFormat") 141 .Input("inputs: T") 142 .Output("output: string") 143 .Attr("T: list(type) >= 0") 144 .Attr("template: string = '%s'") 145 .Attr("placeholder: string = '%s'") 146 .Attr("summarize: int = 3") __anonf5c73fa00302(InferenceContext* c) 147 .SetShapeFn([](InferenceContext* c) { 148 string template_; 149 string placeholder; 150 TF_RETURN_IF_ERROR(c->GetAttr("template", &template_)); 151 TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder)); 152 153 std::vector<std::string> split_template; 154 split_template = absl::StrSplit(template_, placeholder); 155 int64_t num_placeholders = split_template.size() - 1; 156 if (c->num_inputs() != num_placeholders) { 157 return errors::InvalidArgument(strings::StrCat( 158 "num placeholders in template and num inputs must match: ", 159 num_placeholders, " vs. ", c->num_inputs())); 160 } 161 162 c->set_output(0, c->Scalar()); 163 return OkStatus(); 164 }); 165 166 REGISTER_OP("StringJoin") 167 .Input("inputs: N * string") 168 .Attr("N: int") 169 .Attr("separator: string = ''") 170 .Output("output: string") __anonf5c73fa00402(InferenceContext* c) 171 .SetShapeFn([](InferenceContext* c) { 172 // If all inputs are scalars, then return a scalar. 173 bool all_scalar = true; 174 for (int i = 0; i < c->num_inputs(); ++i) { 175 if (c->Rank(c->input(i)) != 0) all_scalar = false; 176 } 177 if (all_scalar) { 178 c->set_output(0, c->Scalar()); 179 return OkStatus(); 180 } 181 182 // At least one input is unknown or a scalar. 183 // Merge the non-scalars to find the output shape. 184 // Don't merge inputs with unknown rank, as they can actually be scalars 185 // or the output shape. 
REGISTER_OP("StringJoin")
    .Input("inputs: N * string")
    .Attr("N: int")
    .Attr("separator: string = ''")
    .Output("output: string")
    .SetShapeFn([](InferenceContext* c) {
      // If all inputs are scalars, then return a scalar.
      bool all_scalar = true;
      for (int i = 0; i < c->num_inputs(); ++i) {
        if (c->Rank(c->input(i)) != 0) all_scalar = false;
      }
      if (all_scalar) {
        c->set_output(0, c->Scalar());
        return OkStatus();
      }

      // At least one input has unknown rank or is not a scalar.
      // Merge the non-scalars to find the output shape.
      // Don't merge inputs with unknown rank, as they can actually be scalars
      // or the output shape.
      ShapeHandle out = c->UnknownShape();
      for (int i = 0; i < c->num_inputs(); ++i) {
        if (c->RankKnown(c->input(i)) && c->Rank(c->input(i)) != 0) {
          TF_RETURN_IF_ERROR(c->Merge(out, c->input(i), &out));
        }
      }
      c->set_output(0, out);
      return OkStatus();
    });

REGISTER_OP("StringSplit")
    .Input("input: string")
    .Input("delimiter: string")
    .Output("indices: int64")
    .Output("values: string")
    .Output("shape: int64")
    .Attr("skip_empty: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, c->Vector(2));
      return OkStatus();
    });

REGISTER_OP("StringSplitV2")
    .Input("input: string")
    .Input("sep: string")
    .Output("indices: int64")
    .Output("values: string")
    .Output("shape: int64")
    .Attr("maxsplit: int = -1")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, c->Vector(2));
      return OkStatus();
    });

REGISTER_OP("StringLower")
    .Input("input: string")
    .Output("output: string")
    .Attr("encoding: string = ''")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("StringUpper")
    .Input("input: string")
    .Output("output: string")
    .Attr("encoding: string = ''")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("StringStrip")
    .Input("input: string")
    .Output("output: string")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("StringLength")
    .Input("input: string")
    .Output("output: int32")
    .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("EncodeBase64")
    .Input("input: string")
    .Output("output: string")
    .Attr("pad: bool = false")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("DecodeBase64")
    .Input("input: string")
    .Output("output: string")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Substr")
    .Input("input: string")
    .Input("pos: T")
    .Input("len: T")
    .Output("output: string")
    .Attr("T: {int32, int64}")
    .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle pos_shape = c->input(1);
      ShapeHandle len_shape = c->input(2);
      ShapeHandle unused;
      // If len rank is known, check that pos and len have the same rank.
      if (c->RankKnown(len_shape)) {
        TF_RETURN_IF_ERROR(
            c->WithRank(pos_shape, c->Rank(len_shape), &unused));
      }
      // Check that dimensions are equal.
      for (int32_t i = 0; i < c->Rank(pos_shape); ++i) {
        DimensionHandle pos_dim = c->Dim(pos_shape, i);
        DimensionHandle len_dim = c->Dim(len_shape, i);
        if (c->Value(pos_dim) != c->Value(len_dim)) {
          return errors::InvalidArgument(
              "pos and len shapes must match: ", c->DebugString(pos_shape),
              " vs. ", c->DebugString(len_shape));
        }
      }
      // c->input(0) is the ShapeHandle to input strings.
      // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1).
      return shape_inference::BroadcastBinaryOpShapeFn(c);
    });
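
// Illustrative example for Substr above: with an input of shape [2] and
// scalar pos and len, the pos/len shapes match trivially and
// BroadcastBinaryOpShapeFn broadcasts [2] with [] to give an output of shape
// [2]; mismatched pos/len shapes such as [2] vs. [3] are rejected by the
// dimension loop above.
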
", c->DebugString(len_shape)); 289 } 290 } 291 // c->input(0) is the ShapeHandle to input strings 292 // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1). 293 return shape_inference::BroadcastBinaryOpShapeFn(c); 294 }); 295 296 REGISTER_OP("UnicodeScript") 297 .Input("input: int32") 298 .Output("output: int32") 299 .SetShapeFn(shape_inference::UnchangedShape); 300 301 REGISTER_OP("UnicodeEncode") 302 .Input("input_values: int32") 303 .Input("input_splits: Tsplits") 304 .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'") 305 .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}") 306 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 307 .Attr("Tsplits: {int32, int64} = DT_INT64") 308 .Output("output: string") __anonf5c73fa00802(InferenceContext* c) 309 .SetShapeFn([](InferenceContext* c) { 310 // Check rank of inner values 311 ShapeHandle input_inner_values_shape = c->input(0); 312 ShapeHandle unused; 313 TF_RETURN_IF_ERROR(c->WithRank(input_inner_values_shape, 1, &unused)); 314 315 // Check rank of input_splits 316 ShapeHandle splits_shape = c->input(1); 317 TF_RETURN_IF_ERROR(c->WithRank(splits_shape, 1, &unused)); 318 319 // Output shape is a 1-D tensor with size equal to number of splits. 320 std::vector<DimensionHandle> dims(1); 321 TF_RETURN_IF_ERROR(c->Subtract(c->Dim(splits_shape, 0), 1, &dims[0])); 322 c->set_output(0, c->MakeShape(dims)); 323 324 return OkStatus(); 325 }); 326 327 REGISTER_OP("UnicodeTranscode") 328 .Input("input: string") 329 .Output("output: string") 330 .Attr("input_encoding: string") 331 .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}") 332 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") 333 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 334 .Attr("replace_control_characters: bool = false") 335 .SetShapeFn(shape_inference::UnchangedShape); 336 337 REGISTER_OP("UnicodeDecode") 338 .Input("input: string") 339 .Output("row_splits: Tsplits") 340 .Output("char_values: int32") 341 .Attr("input_encoding: string") 342 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") 343 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 344 .Attr("replace_control_characters: bool = false") 345 .Attr("Tsplits: {int32, int64} = DT_INT64") __anonf5c73fa00902(InferenceContext* c) 346 .SetShapeFn([](InferenceContext* c) { 347 // row_splits.shape == [input.size() + 1] 348 DimensionHandle num_row_splits; 349 DimensionHandle input_size = c->NumElements(c->input(0)); 350 TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits)); 351 c->set_output(0, c->Vector(num_row_splits)); 352 353 // char_values.shape == [num_chars] 354 DimensionHandle num_chars = c->UnknownDim(); 355 c->set_output(1, c->Vector(num_chars)); 356 return OkStatus(); 357 }); 358 359 REGISTER_OP("UnicodeDecodeWithOffsets") 360 .Input("input: string") 361 .Output("row_splits: Tsplits") 362 .Output("char_values: int32") 363 .Output("char_to_byte_starts: int64") 364 .Attr("input_encoding: string") 365 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") 366 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 367 .Attr("replace_control_characters: bool = false") 368 .Attr("Tsplits: {int32, int64} = DT_INT64") __anonf5c73fa00a02(InferenceContext* c) 369 .SetShapeFn([](InferenceContext* c) { 370 // row_splits.shape == [input.size() + 1] 371 DimensionHandle num_row_splits; 372 DimensionHandle input_size = c->NumElements(c->input(0)); 373 
REGISTER_OP("StringNGrams")
    .Attr("separator: string")
    .Attr("ngram_widths: list(int) >= 0")
    .Attr("left_pad: string")
    .Attr("right_pad: string")
    .Attr("pad_width: int")
    .Attr("preserve_short_sequences: bool")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .Input("data: string")
    .Input("data_splits: Tsplits")
    .Output("ngrams: string")
    .Output("ngrams_splits: Tsplits")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->UnknownShapeOfRank(1));
      ShapeHandle data = c->input(0);
      TF_RETURN_IF_ERROR(c->WithRank(data, 1, &data));
      ShapeHandle data_splits = c->input(1);
      TF_RETURN_IF_ERROR(c->WithRank(data_splits, 1, &data_splits));
      c->set_output(1, data_splits);
      return OkStatus();
    });

}  // namespace tensorflow