xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/op_def_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_def_builder.h"
17 
18 #include <limits>
19 #include <vector>
20 
21 #include "absl/strings/escaping.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/op_def_util.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/gtl/array_slice.h"
28 #include "tensorflow/core/lib/strings/scanner.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/errors.h"
32 
33 using ::tensorflow::strings::Scanner;
34 
35 namespace tensorflow {
36 
37 namespace {
38 
AttrError(StringPiece orig,const string & op_name)39 string AttrError(StringPiece orig, const string& op_name) {
40   return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
41 }
42 
ConsumeAttrName(StringPiece * sp,StringPiece * out)43 bool ConsumeAttrName(StringPiece* sp, StringPiece* out) {
44   return Scanner(*sp)
45       .One(Scanner::LETTER)
46       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
47       .StopCapture()
48       .AnySpace()
49       .OneLiteral(":")
50       .AnySpace()
51       .GetResult(sp, out);
52 }
53 
ConsumeListPrefix(StringPiece * sp)54 bool ConsumeListPrefix(StringPiece* sp) {
55   return Scanner(*sp)
56       .OneLiteral("list")
57       .AnySpace()
58       .OneLiteral("(")
59       .AnySpace()
60       .GetResult(sp);
61 }
62 
ConsumeQuotedString(char quote_ch,StringPiece * sp,StringPiece * out)63 bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) {
64   const string quote_str(1, quote_ch);
65   return Scanner(*sp)
66       .OneLiteral(quote_str.c_str())
67       .RestartCapture()
68       .ScanEscapedUntil(quote_ch)
69       .StopCapture()
70       .OneLiteral(quote_str.c_str())
71       .AnySpace()
72       .GetResult(sp, out);
73 }
74 
ConsumeAttrType(StringPiece * sp,StringPiece * out)75 bool ConsumeAttrType(StringPiece* sp, StringPiece* out) {
76   return Scanner(*sp)
77       .Many(Scanner::LOWERLETTER_DIGIT)
78       .StopCapture()
79       .AnySpace()
80       .GetResult(sp, out);
81 }
82 
ConsumeAttrNumber(StringPiece * sp,int64_t * out)83 bool ConsumeAttrNumber(StringPiece* sp, int64_t* out) {
84   Scanner scan(*sp);
85   StringPiece match;
86   StringPiece remaining;
87 
88   scan.AnySpace().RestartCapture();
89   if (scan.Peek() == '-') {
90     scan.OneLiteral("-");
91   }
92   if (!scan.Many(Scanner::DIGIT)
93            .StopCapture()
94            .AnySpace()
95            .GetResult(&remaining, &match)) {
96     return false;
97   }
98   int64_t value = 0;
99   if (!strings::safe_strto64(match, &value)) {
100     return false;
101   }
102   *out = value;
103   *sp = remaining;
104   return true;
105 }
106 
// Attr-parsing variant of VERIFY. When `expr` is false, appends an error made
// of the variadic message pieces plus the AttrError() suffix to `*errors` and
// returns from the enclosing function. Must be expanded in a scope that
// provides `errors`, `orig` and `op_def`.
#define VERIFY(expr, ...)                                                \
  do {                                                                   \
    if (!(expr)) {                                                       \
      string verify_message =                                            \
          strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name())); \
      errors->push_back(verify_message);                                 \
      return;                                                            \
    }                                                                    \
  } while (false)
115 
ConsumeCompoundAttrType(StringPiece * sp,StringPiece * out)116 bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
117   auto capture_data = sp->data();
118   auto capture_begin = sp->begin();
119   if (absl::ConsumePrefix(sp, "numbertype") ||
120       absl::ConsumePrefix(sp, "numerictype") ||
121       absl::ConsumePrefix(sp, "quantizedtype") ||
122       absl::ConsumePrefix(sp, "realnumbertype") ||
123       absl::ConsumePrefix(sp, "realnumberictype")) {
124     *out = StringPiece(capture_data, sp->begin() - capture_begin);
125     return true;
126   }
127   return false;
128 }
129 
ProcessCompoundType(const StringPiece type_string,AttrValue * allowed)130 bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
131   if (type_string == "numbertype" || type_string == "numerictype") {
132     for (DataType dt : NumberTypes()) {
133       allowed->mutable_list()->add_type(dt);
134     }
135   } else if (type_string == "quantizedtype") {
136     for (DataType dt : QuantizedTypes()) {
137       allowed->mutable_list()->add_type(dt);
138     }
139   } else if (type_string == "realnumbertype" ||
140              type_string == "realnumerictype") {
141     for (DataType dt : RealNumberTypes()) {
142       allowed->mutable_list()->add_type(dt);
143     }
144   } else {
145     return false;
146   }
147   return true;
148 }
149 
FinalizeAttr(StringPiece spec,bool allow_attr_type_any,OpDef * op_def,std::vector<string> * errors)150 void FinalizeAttr(StringPiece spec, bool allow_attr_type_any, OpDef* op_def,
151                   std::vector<string>* errors) {
152   OpDef::AttrDef* attr = op_def->add_attr();
153   StringPiece orig(spec);
154 
155   // Parse "<name>:" at the beginning.
156   StringPiece tmp_name;
157   VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'");
158   attr->set_name(tmp_name.data(), tmp_name.size());
159 
160   // Read "<type>" or "list(<type>)".
161   bool is_list = ConsumeListPrefix(&spec);
162   string type;
163   StringPiece type_string;  // Used if type == "type"
164   if (absl::ConsumePrefix(&spec, "string")) {
165     type = "string";
166   } else if (absl::ConsumePrefix(&spec, "int")) {
167     type = "int";
168   } else if (absl::ConsumePrefix(&spec, "float")) {
169     type = "float";
170   } else if (absl::ConsumePrefix(&spec, "bool")) {
171     type = "bool";
172   } else if (absl::ConsumePrefix(&spec, "type")) {
173     type = "type";
174   } else if (absl::ConsumePrefix(&spec, "shape")) {
175     type = "shape";
176   } else if (absl::ConsumePrefix(&spec, "tensor")) {
177     type = "tensor";
178   } else if (absl::ConsumePrefix(&spec, "func")) {
179     type = "func";
180   } else if (absl::ConsumePrefix(&spec, "any") && allow_attr_type_any) {
181     type = "any";
182   } else if (ConsumeCompoundAttrType(&spec, &type_string)) {
183     type = "type";
184     AttrValue* allowed = attr->mutable_allowed_values();
185     VERIFY(ProcessCompoundType(type_string, allowed),
186            "Expected to see a compound type, saw: ", type_string);
187   } else if (absl::ConsumePrefix(&spec, "{")) {
188     // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
189     AttrValue* allowed = attr->mutable_allowed_values();
190     str_util::RemoveLeadingWhitespace(&spec);
191     if (absl::StartsWith(spec, "\"") || absl::StartsWith(spec, "'")) {
192       type = "string";  // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
193       while (true) {
194         StringPiece escaped_string;
195         VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) ||
196                    ConsumeQuotedString('\'', &spec, &escaped_string),
197                "Trouble parsing allowed string at '", spec, "'");
198         string unescaped;
199         string error;
200         VERIFY(absl::CUnescape(escaped_string, &unescaped, &error),
201                "Trouble unescaping \"", escaped_string,
202                "\", got error: ", error);
203         allowed->mutable_list()->add_s(unescaped);
204         if (absl::ConsumePrefix(&spec, ",")) {
205           str_util::RemoveLeadingWhitespace(&spec);
206           if (absl::ConsumePrefix(&spec, "}"))
207             break;  // Allow ending with ", }".
208         } else {
209           VERIFY(absl::ConsumePrefix(&spec, "}"),
210                  "Expected , or } after strings in list, not: '", spec, "'");
211           break;
212         }
213       }
214     } else {  // "{ bool, numbertype, string }"
215       type = "type";
216       while (true) {
217         VERIFY(ConsumeAttrType(&spec, &type_string),
218                "Trouble parsing type string at '", spec, "'");
219         if (ProcessCompoundType(type_string, allowed)) {
220           // Processed a compound type.
221         } else {
222           DataType dt;
223           VERIFY(DataTypeFromString(type_string, &dt),
224                  "Unrecognized type string '", type_string, "'");
225           allowed->mutable_list()->add_type(dt);
226         }
227         if (absl::ConsumePrefix(&spec, ",")) {
228           str_util::RemoveLeadingWhitespace(&spec);
229           if (absl::ConsumePrefix(&spec, "}"))
230             break;  // Allow ending with ", }".
231         } else {
232           VERIFY(absl::ConsumePrefix(&spec, "}"),
233                  "Expected , or } after types in list, not: '", spec, "'");
234           break;
235         }
236       }
237     }
238   } else {  // if spec.Consume("{")
239     VERIFY(false, "Trouble parsing type string at '", spec, "'");
240   }
241   str_util::RemoveLeadingWhitespace(&spec);
242 
243   // Write the type into *attr.
244   if (is_list) {
245     VERIFY(absl::ConsumePrefix(&spec, ")"),
246            "Expected ) to close 'list(', not: '", spec, "'");
247     str_util::RemoveLeadingWhitespace(&spec);
248     attr->set_type(strings::StrCat("list(", type, ")"));
249   } else {
250     attr->set_type(type);
251   }
252 
253   // Read optional minimum constraint at the end.
254   if ((is_list || type == "int") && absl::ConsumePrefix(&spec, ">=")) {
255     int64_t min_limit = -999;
256     VERIFY(ConsumeAttrNumber(&spec, &min_limit),
257            "Could not parse integer lower limit after '>=', found '", spec,
258            "' instead");
259     attr->set_has_minimum(true);
260     attr->set_minimum(min_limit);
261   }
262 
263   // Parse default value, if present.
264   if (absl::ConsumePrefix(&spec, "=")) {
265     str_util::RemoveLeadingWhitespace(&spec);
266     VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
267            "Could not parse default value '", spec, "'");
268   } else {
269     VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
270   }
271 }
272 
273 #undef VERIFY
274 
// Builds the message suffix that names the Input()/Output() spec and the Op a
// parse error belongs to. `orig` is the unmodified spec text.
string InOutError(bool is_output, StringPiece orig, const string& op_name) {
  const char* const kind = is_output ? "Output" : "Input";
  return strings::StrCat(" from ", kind, "(\"", orig, "\") for Op ", op_name);
}
279 
ConsumeInOutName(StringPiece * sp,StringPiece * out)280 bool ConsumeInOutName(StringPiece* sp, StringPiece* out) {
281   return Scanner(*sp)
282       .One(Scanner::LOWERLETTER)
283       .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)
284       .StopCapture()
285       .AnySpace()
286       .OneLiteral(":")
287       .AnySpace()
288       .GetResult(sp, out);
289 }
290 
ConsumeInOutRefOpen(StringPiece * sp)291 bool ConsumeInOutRefOpen(StringPiece* sp) {
292   return Scanner(*sp)
293       .OneLiteral("Ref")
294       .AnySpace()
295       .OneLiteral("(")
296       .AnySpace()
297       .GetResult(sp);
298 }
299 
ConsumeInOutRefClose(StringPiece * sp)300 bool ConsumeInOutRefClose(StringPiece* sp) {
301   return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp);
302 }
303 
ConsumeInOutNameOrType(StringPiece * sp,StringPiece * out)304 bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) {
305   return Scanner(*sp)
306       .One(Scanner::LETTER)
307       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
308       .StopCapture()
309       .AnySpace()
310       .GetResult(sp, out);
311 }
312 
ConsumeInOutTimesType(StringPiece * sp,StringPiece * out)313 bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) {
314   return Scanner(*sp)
315       .OneLiteral("*")
316       .AnySpace()
317       .RestartCapture()
318       .One(Scanner::LETTER)
319       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
320       .StopCapture()
321       .AnySpace()
322       .GetResult(sp, out);
323 }
324 
ConsumeControlOutName(StringPiece * sp,StringPiece * out)325 bool ConsumeControlOutName(StringPiece* sp, StringPiece* out) {
326   return Scanner(*sp)
327       .One(Scanner::LETTER)
328       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
329       .StopCapture()
330       .GetResult(sp, out);
331 }
332 
// Input/Output-parsing variant of VERIFY. When `expr` is false, appends an
// error made of the variadic message pieces plus the InOutError() suffix to
// `*errors` and returns from the enclosing function. Must be expanded in a
// scope that provides `errors`, `is_output`, `orig` and `op_def`.
#define VERIFY(expr, ...)                                             \
  do {                                                                \
    if (!(expr)) {                                                    \
      string verify_message = strings::StrCat(                        \
          __VA_ARGS__, InOutError(is_output, orig, op_def->name()));  \
      errors->push_back(verify_message);                              \
      return;                                                         \
    }                                                                 \
  } while (false)
341 
FinalizeInputOrOutput(StringPiece spec,bool is_output,OpDef * op_def,std::vector<string> * errors)342 void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
343                            std::vector<string>* errors) {
344   OpDef::ArgDef* arg =
345       is_output ? op_def->add_output_arg() : op_def->add_input_arg();
346 
347   StringPiece orig(spec);
348 
349   // Parse "<name>:" at the beginning.
350   StringPiece tmp_name;
351   VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'");
352   arg->set_name(tmp_name.data(), tmp_name.size());
353 
354   // Detect "Ref(...)".
355   if (ConsumeInOutRefOpen(&spec)) {
356     arg->set_is_ref(true);
357   }
358 
359   {  // Parse "<name|type>" or "<name>*<name|type>".
360     StringPiece first, second, type_or_attr;
361     VERIFY(ConsumeInOutNameOrType(&spec, &first),
362            "Trouble parsing either a type or an attr name at '", spec, "'");
363     if (ConsumeInOutTimesType(&spec, &second)) {
364       arg->set_number_attr(first.data(), first.size());
365       type_or_attr = second;
366     } else {
367       type_or_attr = first;
368     }
369     DataType dt;
370     if (DataTypeFromString(type_or_attr, &dt)) {
371       arg->set_type(dt);
372     } else {
373       const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def);
374       VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'");
375       if (attr->type() == "type") {
376         arg->set_type_attr(type_or_attr.data(), type_or_attr.size());
377       } else {
378         VERIFY(attr->type() == "list(type)", "Reference to attr '",
379                type_or_attr, "' with type ", attr->type(),
380                " that isn't type or list(type)");
381         arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size());
382       }
383     }
384   }
385 
386   // Closing ) for Ref(.
387   if (arg->is_ref()) {
388     VERIFY(ConsumeInOutRefClose(&spec),
389            "Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
390   }
391 
392   // Should not have anything else.
393   VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
394 
395   // Int attrs that are the length of an input or output get a default
396   // minimum of 1.
397   if (!arg->number_attr().empty()) {
398     OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def);
399     if (attr != nullptr && !attr->has_minimum()) {
400       attr->set_has_minimum(true);
401       attr->set_minimum(1);
402     }
403   } else if (!arg->type_list_attr().empty()) {
404     // If an input or output has type specified by a list(type) attr,
405     // it gets a default minimum of 1 as well.
406     OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def);
407     if (attr != nullptr && attr->type() == "list(type)" &&
408         !attr->has_minimum()) {
409       attr->set_has_minimum(true);
410       attr->set_minimum(1);
411     }
412   }
413 
414   // If the arg's dtype is resource we should mark the op as stateful as it
415   // likely touches a resource manager. This deliberately doesn't cover inputs /
416   // outputs which resolve to resource via Attrs as those mostly operate on
417   // resource handles as an opaque type (as opposed to ops which explicitly take
418   // / produce resources).
419   if (arg->type() == DT_RESOURCE) {
420     op_def->set_is_stateful(true);
421   }
422 }
423 
424 #undef VERIFY
425 
ControlOutError(StringPiece orig,const string & op_name)426 string ControlOutError(StringPiece orig, const string& op_name) {
427   return strings::StrCat(" from ControlOutput(\"", orig, "\") for Op ",
428                          op_name);
429 }
430 
FinalizeControlOutput(StringPiece name,OpDef * op_def,std::vector<string> * errors)431 void FinalizeControlOutput(StringPiece name, OpDef* op_def,
432                            std::vector<string>* errors) {
433   StringPiece orig(name);
434 
435   // Parse control output name.
436   StringPiece tmp_name;
437   if (!ConsumeControlOutName(&orig, &tmp_name)) {
438     errors->push_back(strings::StrCat("Trouble parsing 'name:'",
439                                       ControlOutError(orig, op_def->name())));
440   }
441 
442   *op_def->add_control_output() = string(tmp_name.data(), tmp_name.size());
443 }
444 
num_leading_spaces(StringPiece s)445 int num_leading_spaces(StringPiece s) {
446   size_t i = 0;
447   while (i < s.size() && s[i] == ' ') {
448     ++i;
449   }
450   return i;
451 }
452 
ConsumeDocNameColon(StringPiece * sp,StringPiece * out)453 bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) {
454   return Scanner(*sp)
455       .One(Scanner::LETTER)
456       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
457       .StopCapture()
458       .AnySpace()
459       .OneLiteral(":")
460       .AnySpace()
461       .GetResult(sp, out);
462 }
463 
IsDocNameColon(StringPiece s)464 bool IsDocNameColon(StringPiece s) {
465   return ConsumeDocNameColon(&s, nullptr /* out */);
466 }
467 
FinalizeDoc(const string & text,OpDef * op_def,std::vector<string> * errors)468 void FinalizeDoc(const string& text, OpDef* op_def,
469                  std::vector<string>* errors) {
470   std::vector<string> lines = str_util::Split(text, '\n');
471 
472   // Remove trailing spaces.
473   for (string& line : lines) {
474     absl::StripTrailingAsciiWhitespace(&line);
475   }
476 
477   // First non-blank line -> summary.
478   int l = 0;
479   while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
480   if (static_cast<size_t>(l) < lines.size()) {
481     op_def->set_summary(lines[l]);
482     ++l;
483   }
484   while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
485 
486   // Lines until we see name: -> description.
487   int start_l = l;
488   while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
489     ++l;
490   }
491   int end_l = l;
492   // Trim trailing blank lines from the description.
493   while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
494   string desc = absl::StrJoin(
495       gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
496   if (!desc.empty()) op_def->set_description(desc);
497 
498   // name: description
499   //   possibly continued on the next line
500   //   if so, we remove the minimum indent
501   StringPiece name;
502   std::vector<StringPiece> description;
503   while (static_cast<size_t>(l) < lines.size()) {
504     description.clear();
505     description.push_back(lines[l]);
506     ConsumeDocNameColon(&description.back(), &name);
507     ++l;
508     while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
509       description.push_back(lines[l]);
510       ++l;
511     }
512     // Remove any trailing blank lines.
513     while (!description.empty() && description.back().empty()) {
514       description.pop_back();
515     }
516     // Compute the minimum indent of all lines after the first.
517     int min_indent = -1;
518     for (size_t i = 1; i < description.size(); ++i) {
519       if (!description[i].empty()) {
520         int indent = num_leading_spaces(description[i]);
521         if (min_indent < 0 || indent < min_indent) min_indent = indent;
522       }
523     }
524     // Remove min_indent spaces from all lines after the first.
525     for (size_t i = 1; i < description.size(); ++i) {
526       if (!description[i].empty()) description[i].remove_prefix(min_indent);
527     }
528     // Concatenate lines into a single string.
529     const string complete(absl::StrJoin(description, "\n"));
530 
531     // Find name.
532     bool found = false;
533     for (int i = 0; !found && i < op_def->input_arg_size(); ++i) {
534       if (op_def->input_arg(i).name() == name) {
535         op_def->mutable_input_arg(i)->set_description(complete);
536         found = true;
537       }
538     }
539     for (int i = 0; !found && i < op_def->output_arg_size(); ++i) {
540       if (op_def->output_arg(i).name() == name) {
541         op_def->mutable_output_arg(i)->set_description(complete);
542         found = true;
543       }
544     }
545     for (int i = 0; !found && i < op_def->attr_size(); ++i) {
546       if (op_def->attr(i).name() == name) {
547         op_def->mutable_attr(i)->set_description(complete);
548         found = true;
549       }
550     }
551     if (!found) {
552       errors->push_back(
553           strings::StrCat("No matching input/output/attr for name '", name,
554                           "' from Doc() for Op ", op_def->name()));
555       return;
556     }
557   }
558 }
559 
560 }  // namespace
561 
OpDefBuilder(string op_name)562 OpDefBuilder::OpDefBuilder(string op_name) {
563   op_def()->set_name(std::move(op_name));
564 }
565 
Attr(string spec)566 OpDefBuilder& OpDefBuilder::Attr(string spec) {
567   attrs_.push_back(std::move(spec));
568   return *this;
569 }
570 
Input(string spec)571 OpDefBuilder& OpDefBuilder::Input(string spec) {
572   inputs_.push_back(std::move(spec));
573   return *this;
574 }
575 
Output(string spec)576 OpDefBuilder& OpDefBuilder::Output(string spec) {
577   outputs_.push_back(std::move(spec));
578   return *this;
579 }
580 
ControlOutput(string name)581 OpDefBuilder& OpDefBuilder::ControlOutput(string name) {
582   control_outputs_.push_back(std::move(name));
583   return *this;
584 }
585 
Doc(string text)586 OpDefBuilder& OpDefBuilder::Doc(string text) {
587 #ifndef TF_LEAN_BINARY
588   if (!doc_.empty()) {
589     errors_.push_back(
590         strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
591   } else {
592     doc_ = std::move(text);
593   }
594 #endif
595   return *this;
596 }
597 
SetIsCommutative()598 OpDefBuilder& OpDefBuilder::SetIsCommutative() {
599   op_def()->set_is_commutative(true);
600   return *this;
601 }
602 
SetIsAggregate()603 OpDefBuilder& OpDefBuilder::SetIsAggregate() {
604   op_def()->set_is_aggregate(true);
605   return *this;
606 }
607 
SetIsStateful()608 OpDefBuilder& OpDefBuilder::SetIsStateful() {
609   op_def()->set_is_stateful(true);
610   return *this;
611 }
612 
SetAllowsUninitializedInput()613 OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
614   op_def()->set_allows_uninitialized_input(true);
615   return *this;
616 }
617 
SetIsDistributedCommunication()618 OpDefBuilder& OpDefBuilder::SetIsDistributedCommunication() {
619   op_def()->set_is_distributed_communication(true);
620   return *this;
621 }
622 
Deprecated(int version,string explanation)623 OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
624   if (op_def()->has_deprecation()) {
625     errors_.push_back(
626         strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
627   } else {
628     OpDeprecation* deprecation = op_def()->mutable_deprecation();
629     deprecation->set_version(version);
630     deprecation->set_explanation(std::move(explanation));
631   }
632   return *this;
633 }
634 
SetTypeConstructor(OpTypeConstructor c)635 OpDefBuilder& OpDefBuilder::SetTypeConstructor(OpTypeConstructor c) {
636   op_reg_data_.type_ctor = c;
637   return *this;
638 }
639 
SetForwardTypeFn(ForwardTypeInferenceFn f)640 OpDefBuilder& OpDefBuilder::SetForwardTypeFn(ForwardTypeInferenceFn f) {
641   op_reg_data_.fwd_type_fn = f;
642   return *this;
643 }
644 
SetReverseTypeFn(int input_number,ForwardTypeInferenceFn f)645 OpDefBuilder& OpDefBuilder::SetReverseTypeFn(int input_number,
646                                              ForwardTypeInferenceFn f) {
647   op_reg_data_.rev_type_fn = f;
648   op_reg_data_.rev_type_input = input_number;
649   return *this;
650 }
651 
SetShapeFn(OpShapeInferenceFn fn)652 OpDefBuilder& OpDefBuilder::SetShapeFn(OpShapeInferenceFn fn) {
653   if (op_reg_data_.shape_inference_fn != nullptr) {
654     errors_.push_back(
655         strings::StrCat("SetShapeFn called twice for Op ", op_def()->name()));
656   } else {
657     op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn);
658   }
659   return *this;
660 }
661 
AllowAttrTypeAny()662 OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() {
663   allow_attr_type_any_ = true;
664   return *this;
665 }
666 
Finalize(OpRegistrationData * op_reg_data) const667 Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
668   std::vector<string> errors = errors_;
669   *op_reg_data = op_reg_data_;
670 
671   OpDef* op_def = &op_reg_data->op_def;
672   for (StringPiece attr : attrs_) {
673     FinalizeAttr(attr, allow_attr_type_any_, op_def, &errors);
674   }
675   for (StringPiece input : inputs_) {
676     FinalizeInputOrOutput(input, false, op_def, &errors);
677   }
678   for (StringPiece output : outputs_) {
679     FinalizeInputOrOutput(output, true, op_def, &errors);
680   }
681   for (StringPiece control_output : control_outputs_) {
682     FinalizeControlOutput(control_output, op_def, &errors);
683   }
684   FinalizeDoc(doc_, op_def, &errors);
685 
686   if (op_reg_data->type_ctor != nullptr) {
687     TF_RETURN_IF_ERROR(op_reg_data->type_ctor(op_def));
688   }
689 
690   if (errors.empty()) return OkStatus();
691   return errors::InvalidArgument(absl::StrJoin(errors, "\n"));
692 }
693 
694 }  // namespace tensorflow
695