1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_def_builder.h"
17
18 #include <limits>
19 #include <vector>
20
21 #include "absl/strings/escaping.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/op_def_util.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/gtl/array_slice.h"
28 #include "tensorflow/core/lib/strings/scanner.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/errors.h"
32
33 using ::tensorflow::strings::Scanner;
34
35 namespace tensorflow {
36
37 namespace {
38
AttrError(StringPiece orig,const string & op_name)39 string AttrError(StringPiece orig, const string& op_name) {
40 return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
41 }
42
ConsumeAttrName(StringPiece * sp,StringPiece * out)43 bool ConsumeAttrName(StringPiece* sp, StringPiece* out) {
44 return Scanner(*sp)
45 .One(Scanner::LETTER)
46 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
47 .StopCapture()
48 .AnySpace()
49 .OneLiteral(":")
50 .AnySpace()
51 .GetResult(sp, out);
52 }
53
ConsumeListPrefix(StringPiece * sp)54 bool ConsumeListPrefix(StringPiece* sp) {
55 return Scanner(*sp)
56 .OneLiteral("list")
57 .AnySpace()
58 .OneLiteral("(")
59 .AnySpace()
60 .GetResult(sp);
61 }
62
ConsumeQuotedString(char quote_ch,StringPiece * sp,StringPiece * out)63 bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) {
64 const string quote_str(1, quote_ch);
65 return Scanner(*sp)
66 .OneLiteral(quote_str.c_str())
67 .RestartCapture()
68 .ScanEscapedUntil(quote_ch)
69 .StopCapture()
70 .OneLiteral(quote_str.c_str())
71 .AnySpace()
72 .GetResult(sp, out);
73 }
74
ConsumeAttrType(StringPiece * sp,StringPiece * out)75 bool ConsumeAttrType(StringPiece* sp, StringPiece* out) {
76 return Scanner(*sp)
77 .Many(Scanner::LOWERLETTER_DIGIT)
78 .StopCapture()
79 .AnySpace()
80 .GetResult(sp, out);
81 }
82
ConsumeAttrNumber(StringPiece * sp,int64_t * out)83 bool ConsumeAttrNumber(StringPiece* sp, int64_t* out) {
84 Scanner scan(*sp);
85 StringPiece match;
86 StringPiece remaining;
87
88 scan.AnySpace().RestartCapture();
89 if (scan.Peek() == '-') {
90 scan.OneLiteral("-");
91 }
92 if (!scan.Many(Scanner::DIGIT)
93 .StopCapture()
94 .AnySpace()
95 .GetResult(&remaining, &match)) {
96 return false;
97 }
98 int64_t value = 0;
99 if (!strings::safe_strto64(match, &value)) {
100 return false;
101 }
102 *out = value;
103 *sp = remaining;
104 return true;
105 }
106
// Attr-parsing flavor of VERIFY (undefined again after FinalizeAttr and
// redefined further down for inputs/outputs). If `expr` is false, appends an
// error built from the remaining arguments plus
// AttrError(orig, op_def->name()) to *errors and returns from the enclosing
// function. Requires `errors`, `orig` and `op_def` to be in scope at the point
// of use, and the enclosing function to return void.
#define VERIFY(expr, ...)                                                 \
  do {                                                                    \
    if (!(expr)) {                                                        \
      errors->push_back(                                                  \
          strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \
      return;                                                             \
    }                                                                     \
  } while (false)
115
ConsumeCompoundAttrType(StringPiece * sp,StringPiece * out)116 bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
117 auto capture_data = sp->data();
118 auto capture_begin = sp->begin();
119 if (absl::ConsumePrefix(sp, "numbertype") ||
120 absl::ConsumePrefix(sp, "numerictype") ||
121 absl::ConsumePrefix(sp, "quantizedtype") ||
122 absl::ConsumePrefix(sp, "realnumbertype") ||
123 absl::ConsumePrefix(sp, "realnumberictype")) {
124 *out = StringPiece(capture_data, sp->begin() - capture_begin);
125 return true;
126 }
127 return false;
128 }
129
ProcessCompoundType(const StringPiece type_string,AttrValue * allowed)130 bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
131 if (type_string == "numbertype" || type_string == "numerictype") {
132 for (DataType dt : NumberTypes()) {
133 allowed->mutable_list()->add_type(dt);
134 }
135 } else if (type_string == "quantizedtype") {
136 for (DataType dt : QuantizedTypes()) {
137 allowed->mutable_list()->add_type(dt);
138 }
139 } else if (type_string == "realnumbertype" ||
140 type_string == "realnumerictype") {
141 for (DataType dt : RealNumberTypes()) {
142 allowed->mutable_list()->add_type(dt);
143 }
144 } else {
145 return false;
146 }
147 return true;
148 }
149
FinalizeAttr(StringPiece spec,bool allow_attr_type_any,OpDef * op_def,std::vector<string> * errors)150 void FinalizeAttr(StringPiece spec, bool allow_attr_type_any, OpDef* op_def,
151 std::vector<string>* errors) {
152 OpDef::AttrDef* attr = op_def->add_attr();
153 StringPiece orig(spec);
154
155 // Parse "<name>:" at the beginning.
156 StringPiece tmp_name;
157 VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'");
158 attr->set_name(tmp_name.data(), tmp_name.size());
159
160 // Read "<type>" or "list(<type>)".
161 bool is_list = ConsumeListPrefix(&spec);
162 string type;
163 StringPiece type_string; // Used if type == "type"
164 if (absl::ConsumePrefix(&spec, "string")) {
165 type = "string";
166 } else if (absl::ConsumePrefix(&spec, "int")) {
167 type = "int";
168 } else if (absl::ConsumePrefix(&spec, "float")) {
169 type = "float";
170 } else if (absl::ConsumePrefix(&spec, "bool")) {
171 type = "bool";
172 } else if (absl::ConsumePrefix(&spec, "type")) {
173 type = "type";
174 } else if (absl::ConsumePrefix(&spec, "shape")) {
175 type = "shape";
176 } else if (absl::ConsumePrefix(&spec, "tensor")) {
177 type = "tensor";
178 } else if (absl::ConsumePrefix(&spec, "func")) {
179 type = "func";
180 } else if (absl::ConsumePrefix(&spec, "any") && allow_attr_type_any) {
181 type = "any";
182 } else if (ConsumeCompoundAttrType(&spec, &type_string)) {
183 type = "type";
184 AttrValue* allowed = attr->mutable_allowed_values();
185 VERIFY(ProcessCompoundType(type_string, allowed),
186 "Expected to see a compound type, saw: ", type_string);
187 } else if (absl::ConsumePrefix(&spec, "{")) {
188 // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
189 AttrValue* allowed = attr->mutable_allowed_values();
190 str_util::RemoveLeadingWhitespace(&spec);
191 if (absl::StartsWith(spec, "\"") || absl::StartsWith(spec, "'")) {
192 type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
193 while (true) {
194 StringPiece escaped_string;
195 VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) ||
196 ConsumeQuotedString('\'', &spec, &escaped_string),
197 "Trouble parsing allowed string at '", spec, "'");
198 string unescaped;
199 string error;
200 VERIFY(absl::CUnescape(escaped_string, &unescaped, &error),
201 "Trouble unescaping \"", escaped_string,
202 "\", got error: ", error);
203 allowed->mutable_list()->add_s(unescaped);
204 if (absl::ConsumePrefix(&spec, ",")) {
205 str_util::RemoveLeadingWhitespace(&spec);
206 if (absl::ConsumePrefix(&spec, "}"))
207 break; // Allow ending with ", }".
208 } else {
209 VERIFY(absl::ConsumePrefix(&spec, "}"),
210 "Expected , or } after strings in list, not: '", spec, "'");
211 break;
212 }
213 }
214 } else { // "{ bool, numbertype, string }"
215 type = "type";
216 while (true) {
217 VERIFY(ConsumeAttrType(&spec, &type_string),
218 "Trouble parsing type string at '", spec, "'");
219 if (ProcessCompoundType(type_string, allowed)) {
220 // Processed a compound type.
221 } else {
222 DataType dt;
223 VERIFY(DataTypeFromString(type_string, &dt),
224 "Unrecognized type string '", type_string, "'");
225 allowed->mutable_list()->add_type(dt);
226 }
227 if (absl::ConsumePrefix(&spec, ",")) {
228 str_util::RemoveLeadingWhitespace(&spec);
229 if (absl::ConsumePrefix(&spec, "}"))
230 break; // Allow ending with ", }".
231 } else {
232 VERIFY(absl::ConsumePrefix(&spec, "}"),
233 "Expected , or } after types in list, not: '", spec, "'");
234 break;
235 }
236 }
237 }
238 } else { // if spec.Consume("{")
239 VERIFY(false, "Trouble parsing type string at '", spec, "'");
240 }
241 str_util::RemoveLeadingWhitespace(&spec);
242
243 // Write the type into *attr.
244 if (is_list) {
245 VERIFY(absl::ConsumePrefix(&spec, ")"),
246 "Expected ) to close 'list(', not: '", spec, "'");
247 str_util::RemoveLeadingWhitespace(&spec);
248 attr->set_type(strings::StrCat("list(", type, ")"));
249 } else {
250 attr->set_type(type);
251 }
252
253 // Read optional minimum constraint at the end.
254 if ((is_list || type == "int") && absl::ConsumePrefix(&spec, ">=")) {
255 int64_t min_limit = -999;
256 VERIFY(ConsumeAttrNumber(&spec, &min_limit),
257 "Could not parse integer lower limit after '>=', found '", spec,
258 "' instead");
259 attr->set_has_minimum(true);
260 attr->set_minimum(min_limit);
261 }
262
263 // Parse default value, if present.
264 if (absl::ConsumePrefix(&spec, "=")) {
265 str_util::RemoveLeadingWhitespace(&spec);
266 VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
267 "Could not parse default value '", spec, "'");
268 } else {
269 VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
270 }
271 }
272
// The attr-specific VERIFY ends here; a different VERIFY is defined below for
// input/output parsing.
#undef VERIFY
274
// Builds the suffix appended to input/output parsing error messages: it quotes
// the offending spec `orig` as Input("...") or Output("...") depending on
// `is_output`, and names the op `op_name`.
string InOutError(bool is_output, StringPiece orig, const string& op_name) {
  const char* const kind = is_output ? "Output" : "Input";
  return strings::StrCat(" from ", kind, "(\"", orig, "\") for Op ", op_name);
}
279
ConsumeInOutName(StringPiece * sp,StringPiece * out)280 bool ConsumeInOutName(StringPiece* sp, StringPiece* out) {
281 return Scanner(*sp)
282 .One(Scanner::LOWERLETTER)
283 .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)
284 .StopCapture()
285 .AnySpace()
286 .OneLiteral(":")
287 .AnySpace()
288 .GetResult(sp, out);
289 }
290
ConsumeInOutRefOpen(StringPiece * sp)291 bool ConsumeInOutRefOpen(StringPiece* sp) {
292 return Scanner(*sp)
293 .OneLiteral("Ref")
294 .AnySpace()
295 .OneLiteral("(")
296 .AnySpace()
297 .GetResult(sp);
298 }
299
ConsumeInOutRefClose(StringPiece * sp)300 bool ConsumeInOutRefClose(StringPiece* sp) {
301 return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp);
302 }
303
ConsumeInOutNameOrType(StringPiece * sp,StringPiece * out)304 bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) {
305 return Scanner(*sp)
306 .One(Scanner::LETTER)
307 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
308 .StopCapture()
309 .AnySpace()
310 .GetResult(sp, out);
311 }
312
ConsumeInOutTimesType(StringPiece * sp,StringPiece * out)313 bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) {
314 return Scanner(*sp)
315 .OneLiteral("*")
316 .AnySpace()
317 .RestartCapture()
318 .One(Scanner::LETTER)
319 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
320 .StopCapture()
321 .AnySpace()
322 .GetResult(sp, out);
323 }
324
ConsumeControlOutName(StringPiece * sp,StringPiece * out)325 bool ConsumeControlOutName(StringPiece* sp, StringPiece* out) {
326 return Scanner(*sp)
327 .One(Scanner::LETTER)
328 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
329 .StopCapture()
330 .GetResult(sp, out);
331 }
332
// Input/output-parsing flavor of VERIFY. If `expr` is false, appends an error
// built from the remaining arguments plus
// InOutError(is_output, orig, op_def->name()) to *errors and returns from the
// enclosing function. Requires `errors`, `is_output`, `orig` and `op_def` to
// be in scope at the point of use, and the enclosing function to return void.
#define VERIFY(expr, ...)                                             \
  do {                                                                \
    if (!(expr)) {                                                    \
      errors->push_back(strings::StrCat(                              \
          __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \
      return;                                                         \
    }                                                                 \
  } while (false)
341
// Parses one Input()/Output() spec string and appends the resulting ArgDef to
// the op's output args (if `is_output`) or input args (otherwise).
//
// The spec has the form "<name>: <type-expr>" or "<name>: Ref(<type-expr>)",
// where <type-expr> is "<type or attr>" or "<number attr> * <type or attr>".
// An identifier that DataTypeFromString() does not recognize is looked up as
// an attr of *op_def, which must have type "type" or "list(type)".
//
// Side effects beyond the new arg:
//   * the attr named as number attr, or as list(type) attr, gets a minimum of
//     1 if it has no minimum yet;
//   * the op is marked stateful if the arg's type is DT_RESOURCE.
//
// On a parse failure, a message is appended to *errors and the function
// returns early (see VERIFY). The ArgDef added at the top is not removed in
// that case.
void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
                           std::vector<string>* errors) {
  OpDef::ArgDef* arg =
      is_output ? op_def->add_output_arg() : op_def->add_input_arg();

  // Unmodified copy of the spec, used by VERIFY in error messages.
  StringPiece orig(spec);

  // Parse "<name>:" at the beginning.
  StringPiece tmp_name;
  VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'");
  arg->set_name(tmp_name.data(), tmp_name.size());

  // Detect "Ref(...)".
  if (ConsumeInOutRefOpen(&spec)) {
    arg->set_is_ref(true);
  }

  {  // Parse "<name|type>" or "<name>*<name|type>".
    StringPiece first, second, type_or_attr;
    VERIFY(ConsumeInOutNameOrType(&spec, &first),
           "Trouble parsing either a type or an attr name at '", spec, "'");
    if (ConsumeInOutTimesType(&spec, &second)) {
      // "<first> * <second>": first names the number attr.
      arg->set_number_attr(first.data(), first.size());
      type_or_attr = second;
    } else {
      type_or_attr = first;
    }
    DataType dt;
    if (DataTypeFromString(type_or_attr, &dt)) {
      arg->set_type(dt);
    } else {
      // Not a known type name, so it has to reference an existing attr.
      const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def);
      VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'");
      if (attr->type() == "type") {
        arg->set_type_attr(type_or_attr.data(), type_or_attr.size());
      } else {
        VERIFY(attr->type() == "list(type)", "Reference to attr '",
               type_or_attr, "' with type ", attr->type(),
               " that isn't type or list(type)");
        arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size());
      }
    }
  }

  // Closing ) for Ref(.
  if (arg->is_ref()) {
    VERIFY(ConsumeInOutRefClose(&spec),
           "Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
  }

  // Should not have anything else.
  VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");

  // Int attrs that are the length of an input or output get a default
  // minimum of 1.
  if (!arg->number_attr().empty()) {
    OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def);
    if (attr != nullptr && !attr->has_minimum()) {
      attr->set_has_minimum(true);
      attr->set_minimum(1);
    }
  } else if (!arg->type_list_attr().empty()) {
    // If an input or output has type specified by a list(type) attr,
    // it gets a default minimum of 1 as well.
    OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def);
    if (attr != nullptr && attr->type() == "list(type)" &&
        !attr->has_minimum()) {
      attr->set_has_minimum(true);
      attr->set_minimum(1);
    }
  }

  // If the arg's dtype is resource we should mark the op as stateful as it
  // likely touches a resource manager. This deliberately doesn't cover inputs /
  // outputs which resolve to resource via Attrs as those mostly operate on
  // resource handles as an opaque type (as opposed to ops which explicitly take
  // / produce resources).
  if (arg->type() == DT_RESOURCE) {
    op_def->set_is_stateful(true);
  }
}
423
// The input/output-specific VERIFY ends here.
#undef VERIFY
425
ControlOutError(StringPiece orig,const string & op_name)426 string ControlOutError(StringPiece orig, const string& op_name) {
427 return strings::StrCat(" from ControlOutput(\"", orig, "\") for Op ",
428 op_name);
429 }
430
FinalizeControlOutput(StringPiece name,OpDef * op_def,std::vector<string> * errors)431 void FinalizeControlOutput(StringPiece name, OpDef* op_def,
432 std::vector<string>* errors) {
433 StringPiece orig(name);
434
435 // Parse control output name.
436 StringPiece tmp_name;
437 if (!ConsumeControlOutName(&orig, &tmp_name)) {
438 errors->push_back(strings::StrCat("Trouble parsing 'name:'",
439 ControlOutError(orig, op_def->name())));
440 }
441
442 *op_def->add_control_output() = string(tmp_name.data(), tmp_name.size());
443 }
444
num_leading_spaces(StringPiece s)445 int num_leading_spaces(StringPiece s) {
446 size_t i = 0;
447 while (i < s.size() && s[i] == ' ') {
448 ++i;
449 }
450 return i;
451 }
452
ConsumeDocNameColon(StringPiece * sp,StringPiece * out)453 bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) {
454 return Scanner(*sp)
455 .One(Scanner::LETTER)
456 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
457 .StopCapture()
458 .AnySpace()
459 .OneLiteral(":")
460 .AnySpace()
461 .GetResult(sp, out);
462 }
463
IsDocNameColon(StringPiece s)464 bool IsDocNameColon(StringPiece s) {
465 return ConsumeDocNameColon(&s, nullptr /* out */);
466 }
467
FinalizeDoc(const string & text,OpDef * op_def,std::vector<string> * errors)468 void FinalizeDoc(const string& text, OpDef* op_def,
469 std::vector<string>* errors) {
470 std::vector<string> lines = str_util::Split(text, '\n');
471
472 // Remove trailing spaces.
473 for (string& line : lines) {
474 absl::StripTrailingAsciiWhitespace(&line);
475 }
476
477 // First non-blank line -> summary.
478 int l = 0;
479 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
480 if (static_cast<size_t>(l) < lines.size()) {
481 op_def->set_summary(lines[l]);
482 ++l;
483 }
484 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
485
486 // Lines until we see name: -> description.
487 int start_l = l;
488 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
489 ++l;
490 }
491 int end_l = l;
492 // Trim trailing blank lines from the description.
493 while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
494 string desc = absl::StrJoin(
495 gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
496 if (!desc.empty()) op_def->set_description(desc);
497
498 // name: description
499 // possibly continued on the next line
500 // if so, we remove the minimum indent
501 StringPiece name;
502 std::vector<StringPiece> description;
503 while (static_cast<size_t>(l) < lines.size()) {
504 description.clear();
505 description.push_back(lines[l]);
506 ConsumeDocNameColon(&description.back(), &name);
507 ++l;
508 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
509 description.push_back(lines[l]);
510 ++l;
511 }
512 // Remove any trailing blank lines.
513 while (!description.empty() && description.back().empty()) {
514 description.pop_back();
515 }
516 // Compute the minimum indent of all lines after the first.
517 int min_indent = -1;
518 for (size_t i = 1; i < description.size(); ++i) {
519 if (!description[i].empty()) {
520 int indent = num_leading_spaces(description[i]);
521 if (min_indent < 0 || indent < min_indent) min_indent = indent;
522 }
523 }
524 // Remove min_indent spaces from all lines after the first.
525 for (size_t i = 1; i < description.size(); ++i) {
526 if (!description[i].empty()) description[i].remove_prefix(min_indent);
527 }
528 // Concatenate lines into a single string.
529 const string complete(absl::StrJoin(description, "\n"));
530
531 // Find name.
532 bool found = false;
533 for (int i = 0; !found && i < op_def->input_arg_size(); ++i) {
534 if (op_def->input_arg(i).name() == name) {
535 op_def->mutable_input_arg(i)->set_description(complete);
536 found = true;
537 }
538 }
539 for (int i = 0; !found && i < op_def->output_arg_size(); ++i) {
540 if (op_def->output_arg(i).name() == name) {
541 op_def->mutable_output_arg(i)->set_description(complete);
542 found = true;
543 }
544 }
545 for (int i = 0; !found && i < op_def->attr_size(); ++i) {
546 if (op_def->attr(i).name() == name) {
547 op_def->mutable_attr(i)->set_description(complete);
548 found = true;
549 }
550 }
551 if (!found) {
552 errors->push_back(
553 strings::StrCat("No matching input/output/attr for name '", name,
554 "' from Doc() for Op ", op_def->name()));
555 return;
556 }
557 }
558 }
559
560 } // namespace
561
OpDefBuilder(string op_name)562 OpDefBuilder::OpDefBuilder(string op_name) {
563 op_def()->set_name(std::move(op_name));
564 }
565
Attr(string spec)566 OpDefBuilder& OpDefBuilder::Attr(string spec) {
567 attrs_.push_back(std::move(spec));
568 return *this;
569 }
570
Input(string spec)571 OpDefBuilder& OpDefBuilder::Input(string spec) {
572 inputs_.push_back(std::move(spec));
573 return *this;
574 }
575
Output(string spec)576 OpDefBuilder& OpDefBuilder::Output(string spec) {
577 outputs_.push_back(std::move(spec));
578 return *this;
579 }
580
ControlOutput(string name)581 OpDefBuilder& OpDefBuilder::ControlOutput(string name) {
582 control_outputs_.push_back(std::move(name));
583 return *this;
584 }
585
Doc(string text)586 OpDefBuilder& OpDefBuilder::Doc(string text) {
587 #ifndef TF_LEAN_BINARY
588 if (!doc_.empty()) {
589 errors_.push_back(
590 strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
591 } else {
592 doc_ = std::move(text);
593 }
594 #endif
595 return *this;
596 }
597
SetIsCommutative()598 OpDefBuilder& OpDefBuilder::SetIsCommutative() {
599 op_def()->set_is_commutative(true);
600 return *this;
601 }
602
SetIsAggregate()603 OpDefBuilder& OpDefBuilder::SetIsAggregate() {
604 op_def()->set_is_aggregate(true);
605 return *this;
606 }
607
SetIsStateful()608 OpDefBuilder& OpDefBuilder::SetIsStateful() {
609 op_def()->set_is_stateful(true);
610 return *this;
611 }
612
SetAllowsUninitializedInput()613 OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
614 op_def()->set_allows_uninitialized_input(true);
615 return *this;
616 }
617
SetIsDistributedCommunication()618 OpDefBuilder& OpDefBuilder::SetIsDistributedCommunication() {
619 op_def()->set_is_distributed_communication(true);
620 return *this;
621 }
622
Deprecated(int version,string explanation)623 OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
624 if (op_def()->has_deprecation()) {
625 errors_.push_back(
626 strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
627 } else {
628 OpDeprecation* deprecation = op_def()->mutable_deprecation();
629 deprecation->set_version(version);
630 deprecation->set_explanation(std::move(explanation));
631 }
632 return *this;
633 }
634
SetTypeConstructor(OpTypeConstructor c)635 OpDefBuilder& OpDefBuilder::SetTypeConstructor(OpTypeConstructor c) {
636 op_reg_data_.type_ctor = c;
637 return *this;
638 }
639
SetForwardTypeFn(ForwardTypeInferenceFn f)640 OpDefBuilder& OpDefBuilder::SetForwardTypeFn(ForwardTypeInferenceFn f) {
641 op_reg_data_.fwd_type_fn = f;
642 return *this;
643 }
644
SetReverseTypeFn(int input_number,ForwardTypeInferenceFn f)645 OpDefBuilder& OpDefBuilder::SetReverseTypeFn(int input_number,
646 ForwardTypeInferenceFn f) {
647 op_reg_data_.rev_type_fn = f;
648 op_reg_data_.rev_type_input = input_number;
649 return *this;
650 }
651
SetShapeFn(OpShapeInferenceFn fn)652 OpDefBuilder& OpDefBuilder::SetShapeFn(OpShapeInferenceFn fn) {
653 if (op_reg_data_.shape_inference_fn != nullptr) {
654 errors_.push_back(
655 strings::StrCat("SetShapeFn called twice for Op ", op_def()->name()));
656 } else {
657 op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn);
658 }
659 return *this;
660 }
661
AllowAttrTypeAny()662 OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() {
663 allow_attr_type_any_ = true;
664 return *this;
665 }
666
Finalize(OpRegistrationData * op_reg_data) const667 Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
668 std::vector<string> errors = errors_;
669 *op_reg_data = op_reg_data_;
670
671 OpDef* op_def = &op_reg_data->op_def;
672 for (StringPiece attr : attrs_) {
673 FinalizeAttr(attr, allow_attr_type_any_, op_def, &errors);
674 }
675 for (StringPiece input : inputs_) {
676 FinalizeInputOrOutput(input, false, op_def, &errors);
677 }
678 for (StringPiece output : outputs_) {
679 FinalizeInputOrOutput(output, true, op_def, &errors);
680 }
681 for (StringPiece control_output : control_outputs_) {
682 FinalizeControlOutput(control_output, op_def, &errors);
683 }
684 FinalizeDoc(doc_, op_def, &errors);
685
686 if (op_reg_data->type_ctor != nullptr) {
687 TF_RETURN_IF_ERROR(op_reg_data->type_ctor(op_def));
688 }
689
690 if (errors.empty()) return OkStatus();
691 return errors::InvalidArgument(absl::StrJoin(errors, "\n"));
692 }
693
694 } // namespace tensorflow
695