xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/function_optimizer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
17 
18 #include <vector>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/device_set.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/common_runtime/lower_case_op.h"
33 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
34 #include "tensorflow/core/common_runtime/lower_if_op.h"
35 #include "tensorflow/core/common_runtime/lower_while_op.h"
36 #include "tensorflow/core/common_runtime/placer.h"
37 #include "tensorflow/core/framework/attr_value_util.h"
38 #include "tensorflow/core/framework/function.h"
39 #include "tensorflow/core/framework/function.pb.h"
40 #include "tensorflow/core/framework/graph_def_util.h"
41 #include "tensorflow/core/framework/node_def.pb.h"
42 #include "tensorflow/core/framework/node_def_util.h"
43 #include "tensorflow/core/framework/op_def.pb.h"
44 #include "tensorflow/core/framework/versions.pb.h"
45 #include "tensorflow/core/graph/algorithm.h"
46 #include "tensorflow/core/graph/control_flow.h"
47 #include "tensorflow/core/graph/graph_node_util.h"
48 #include "tensorflow/core/graph/tensor_id.h"
49 #include "tensorflow/core/grappler/graph_view.h"
50 #include "tensorflow/core/grappler/grappler_item.h"
51 #include "tensorflow/core/grappler/op_types.h"
52 #include "tensorflow/core/grappler/utils.h"
53 #include "tensorflow/core/grappler/utils/functions.h"
54 #include "tensorflow/core/lib/gtl/map_util.h"
55 
56 namespace tensorflow {
57 namespace grappler {
58 namespace {
59 
60 constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr;
61 
62 // Do not specialize functions marked with '_nospecialize' attribute.
63 constexpr const char* const kNoSpecializeAttr = "_nospecialize";
64 
65 // Mark functions that were created as a result of function specialization.
66 constexpr const char* const kGrapplerSpecializedFuncAttr =
67     "_GrapplerSpecializedFunc";
68 
69 // There are two ways of calling a Tensorflow function:
70 //
71 // 1. Direct function call: node.op() is the name of the function.
72 //
73 // 2. Indirect function call: the function name is passed through a node
74 //    attribute, and special Tensorflow kernels are responsible for calling the
75 //    function through the FunctionLibraryRuntime. Example: PartitionedCallOp.
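//
// A rough illustration with hypothetical node names (textproto sketch, not
// taken from any real graph): a direct call might look like
//   node { name: "y" op: "MyFunc" input: "x" }
// while the indirect form routes the function name through the 'f' attribute:
//   node { name: "y" op: "PartitionedCall" input: "x"
//          attr { key: "f" value { func { name: "MyFunc" } } } }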
76 
77 // Check if func_node.op() matches the name in FunctionDef signature.
78 bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
79   return func_node.op() == func.signature().name();
80 }
81 
82 // Check if func_node has function attribute with a function name matching
83 // FunctionDef signature.
84 bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
85   if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) {
86     return false;
87   }
88 
89   auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
90   return func_attr != nullptr && func_attr->has_func() &&
91          func_attr->func().name() == func.signature().name();
92 }
93 
94 AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
95                                           const NodeDef& func_node) {
96   if (IsDirectFunctionCall(func, func_node)) {
97     return AttrSlice(func_node);
98 
99   } else if (IsIndirectFunctionCall(func, func_node)) {
100     auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
101     return AttrSlice(&func_attr->func().attr());
102 
103   } else {
104     LOG(WARNING) << "Can't resolve function instantiation attributes: "
105                  << SummarizeNodeDef(func_node);
106     return AttrSlice();
107   }
108 }
109 
110 // This is a fake device that should not be used for any op kernel execution;
111 // its only purpose is to be passed as part of the DeviceSet given to the
112 // Placer.
113 class FakeDevice : public Device {
114  public:
115   FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
116   explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
117   Status Sync() override { return OkStatus(); }
118 
119  private:
120   static DeviceAttributes attr(const string& device) {
121     DeviceNameUtils::ParsedName parsed_name;
122     bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
123     DCHECK(parsed) << "Failed to parse full device name: " << device;
124 
125     DeviceAttributes attr;
126     attr.set_name(device);
127     attr.set_device_type(parsed_name.type);
128     return attr;
129   }
130 };
131 
132 // -------------------------------------------------------------------------- //
133 // Function specialization.
134 //
135 // A FunctionDef is somewhat similar to a function template in C++: given all
136 // the type parameters (and attribute values) it generates a statically defined
137 // graph from the type-parametrized "graph template" (function body).
138 //
139 // Function specialization instantiates a parametrized FunctionDef into a
140 // statically defined graph, and then converts it back to the fully defined
141 // FunctionDef (it doesn't have any unknown type parameters or attribute
142 // values, known as placeholders).
143 //
144 // Given the fully specified graph we can apply all the Grappler optimizers to
145 // it (see details in MetaOptimizer). Also we can push known constant inputs
146 // into the function body, and remove unused outputs/inputs.
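//
// A small sketch with hypothetical names: a call to MyFunc[T=float] whose
// second input is a Const and whose callers only read output:0 may be turned
// into a new FunctionDef in which T is bound to float, the Const is embedded
// into the function body, and the unused outputs are removed.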
147 
148 bool MarkedNoSpecialize(const FunctionDef& fdef) {
149   const auto attr = AttrSlice(&fdef.attr());
150   bool nospecialize = false;
151   return TryGetNodeAttr(attr, kNoSpecializeAttr, &nospecialize) && nospecialize;
152 }
153 
154 // Specialized function instantiation type parameters, body parameters, and
155 // const inputs.
156 struct FunctionSpecializationSignature {
157   // Currently we do not support functions with tensor lists as inputs or
158   // outputs, so caller node input/output ports always match function
159   // input/output arguments.
160   using InputPort = int;
161   using OutputPort = int;
162 
163   string func_name;
164   bool is_in_fetch_set;
165   absl::flat_hash_set<OutputPort> active_outputs;
166   absl::flat_hash_map<string, DataType> type_parameters;
167   absl::flat_hash_map<string, AttrValue> body_parameters;
168   absl::flat_hash_map<InputPort, string> const_inputs;
169 
170   bool operator==(const FunctionSpecializationSignature& other) const {
171     bool equals = func_name == other.func_name &&
172                   is_in_fetch_set == other.is_in_fetch_set &&
173                   active_outputs == other.active_outputs &&
174                   type_parameters == other.type_parameters &&
175                   const_inputs == other.const_inputs;
176 
177     if (!equals) return false;
178 
179     // Equality is not defined for AttrValue.
180     if (body_parameters.size() != other.body_parameters.size()) return false;
181 
182     for (const auto& lhs : body_parameters) {
183       auto it = other.body_parameters.find(lhs.first);
184       if (it == other.body_parameters.end()) return false;
185       if (!AreAttrValuesEqual(lhs.second, (*it).second,
186                               /*allow_false_negatives=*/true)) {
187         return false;
188       }
189     }
190 
191     return true;
192   }
193 
194   template <typename H>
195   friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
196     H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
197 
198     // First pre-compute hashes for all values in collections with
199     // non-deterministic iteration order.
200     std::vector<uint64> hashes;
201     hashes.reserve(s.active_outputs.size()         //
202                    + s.type_parameters.size() * 2  //
203                    + s.body_parameters.size() * 2  //
204                    + s.const_inputs.size() * 2);
205 
206     absl::c_transform(s.active_outputs, std::back_inserter(hashes),
207                       hash<OutputPort>());
208 
209     using TypeParam = std::pair<const string, DataType>;
210     absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
211       AttrValue attr_value;
212       attr_value.set_type(type_param.second);
213       hashes.push_back(Hash64(type_param.first));
214       hashes.push_back(AttrValueHash(attr_value));
215     });
216 
217     using BodyParam = std::pair<const string, AttrValue>;
218     absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
219       hashes.push_back(Hash64(body_param.first));
220       hashes.push_back(FastAttrValueHash(body_param.second));
221     });
222 
223     using ConstInput = std::pair<const InputPort, string>;
224     absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
225       hashes.push_back(hash<InputPort>()(const_input.first));
226       hashes.push_back(Hash64(const_input.second));
227     });
228 
229     // Combine all pre-computed hashes in a deterministic order.
230     absl::c_sort(hashes);
231     return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
232   }
233 };
234 
235 struct FunctionSpecialization {
236   string specialized_func_name;
237   // True if the function caller node is in GrapplerItem fetch set.
238   bool is_in_fetch_set;
239   // Names of the tensors that were pushed down into the function body.
240   absl::flat_hash_set<string> const_inputs;
241   // Control dependencies of pushed down const inputs have to be attached to
242   // function caller node.
243   absl::flat_hash_set<string> control_deps;
244   // Output tensors (ports) that are consumed by other nodes in the graph or
245   // are in the GrapplerItem fetch set.
246   absl::flat_hash_set<int> active_outputs;
247   // Mapping from an original function output port to the output port of the
248   // specialized function. If function specialization changes the number of
249   // function outputs, all consumer nodes must be updated (example below).
250   std::vector<std::pair<int, int>> output_mapping;
251 };
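// Illustrative example (hypothetical values) for output_mapping: if only
// outputs {0, 2} of a three-output function stay active, output 2 becomes
// output 1 of the specialized function, so output_mapping contains {2, 1} and
// consumers of port 2 have to be rewired to port 1 (see AddTensorMapping).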
252 
253 // Function optimizer context is initialized once for each optimization pass,
254 // and it uses the latest available graph (for the first iteration it will be
255 // the GrapplerItem.graph, for the next iterations it will be the output of the
256 // previous function optimizer pass).
257 class FunctionOptimizerContext {
258  public:
259   explicit FunctionOptimizerContext(const GrapplerItem& item,
260                                     RewriterConfig::Toggle opt_level,
261                                     const GraphDef& graph)
262       : item_(&item),
263         opt_level_(opt_level),
264         function_library_(OpRegistry::Global(), graph.library()),
265         truly_const_nodes_(InferTrulyConstNodes(item, graph)),
266         graph_view_(&graph) {}
267 
268   const GrapplerItem& item() const { return *item_; }
269 
270   const int graph_version() const { return item_->graph.versions().producer(); }
271 
272   RewriterConfig::Toggle opt_level() const { return opt_level_; }
273 
274   const FunctionLibraryDefinition& function_library() const {
275     return function_library_;
276   }
277   FunctionLibraryDefinition& function_library() { return function_library_; }
278 
279   const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
280   tensor_mapping() const {
281     return tensor_mapping_;
282   }
283 
284   const GraphView& graph_view() const { return graph_view_; }
285 
286   bool IsFeedNode(const string& node_name) const {
287     return absl::c_any_of(
288         item_->feed, [&](const std::pair<std::string, Tensor>& feed) {
289           return ParseTensorName(feed.first).node() == node_name;
290         });
291   }
292 
293   bool IsFetchNode(const string& node_name) const {
294     return absl::c_any_of(item_->fetch, [&](const string& fetch) {
295       return ParseTensorName(fetch).node() == node_name;
296     });
297   }
298 
299   bool IsTrulyConst(const string& name) const {
300     return TrulyConstNode(name) != nullptr;
301   }
302 
303   const NodeDef* TrulyConstNode(const string& name) const {
304     return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
305   }
306 
307   const FunctionSpecialization* FindFunctionSpecialization(
308       const FunctionSpecializationSignature& sig) const {
309     return gtl::FindOrNull(specialized_functions_, sig);
310   }
311 
312   void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
313                               const FunctionSpecialization& specialized_func) {
314     specialized_functions_.emplace(sig, specialized_func);
315   }
316 
317   void AddTensorMapping(const SafeTensorId& from, const SafeTensorId& to) {
318     DCHECK(from.index() != Graph::kControlSlot)
319         << "Tensor mapping must be from regular tensor";
320     DCHECK(to.index() != Graph::kControlSlot)
321         << "Tensor mapping must be to regular tensor";
322 
323     auto inserted = tensor_mapping_.insert({from, to});
324     DCHECK(inserted.second)
325         << "Failed to insert duplicated tensor mapping: "
326         << "from=" << from.ToString() << " to=" << to.ToString();
327   }
328 
329   void AddTensorMapping(const string& func_node,
330                         const FunctionSpecialization& specialized_func) {
331     for (const auto& pair : specialized_func.output_mapping) {
332       int from_idx = pair.first;
333       int to_idx = pair.second;
334       if (from_idx != to_idx) {
335         SafeTensorId from_tensor(func_node, from_idx);
336         SafeTensorId to_tensor(func_node, to_idx);
337         AddTensorMapping(from_tensor, to_tensor);
338       }
339     }
340   }
341 
342  private:
343   static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
344       const GrapplerItem& item, const GraphDef& graph) {
345     absl::flat_hash_set<absl::string_view> feed_nodes;
346     for (const auto& feed : item.feed) {
347       feed_nodes.insert(feed.first);
348     }
349 
350     absl::flat_hash_map<string, const NodeDef*> const_nodes;
351     for (const NodeDef& node : graph.node()) {
352       if (IsConstant(node) && !feed_nodes.contains(node.name())) {
353         const_nodes[node.name()] = &node;
354       }
355     }
356 
357     return const_nodes;
358   }
359 
360   const GrapplerItem* item_;  // must outlive this object
361   RewriterConfig::Toggle opt_level_;
362 
363   // Function library constructed from current graph.
364   FunctionLibraryDefinition function_library_;
365 
366   // Nodes that are Const and not in feed.
367   absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
368   // Specialized functions.
369   absl::flat_hash_map<FunctionSpecializationSignature,
370                       const FunctionSpecialization>
371       specialized_functions_;
372 
373   // After function specialization, the optimized graph might be in an invalid
374   // state: nodes can read from an output index that is no longer valid after
375   // unused output pruning.
376   //
377   // Tensor mapping that has to be applied to the graph after all functions
378   // optimizations (invalidated tensor id -> optimized graph tensor id).
379   absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
380       tensor_mapping_;
381 
382   // Use graph view to find active outputs of the function caller nodes.
383   GraphView graph_view_;
384 
385   TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
386 };
387 
388 // Returns a pointer to the called function definition iff the given node is
389 // indeed a function call. Otherwise returns nullptr.
390 const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
391                                     const NodeDef& node) {
392   // Check if a node does indirect function call via PartitionedCallOp.
393   if (IsPartitionedCall(node) || IsStatefulPartitionedCall(node)) {
394     const AttrValue* func_attr = AttrSlice(node).Find("f");
395     return (func_attr != nullptr && func_attr->has_func())
396                ? ctx.function_library().Find(func_attr->func().name())
397                : nullptr;
398   }
399 
400   // Check if the function op itself is a function name.
401   return ctx.function_library().Find(node.op());
402 }
403 
404 absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
405                                           const FunctionOptimizerContext& ctx,
406                                           int size_hint = 0) {
407   absl::flat_hash_set<int> active_outputs;
408   active_outputs.reserve(static_cast<size_t>(size_hint));
409 
410   // 1. An output can be consumed by another graph node.
411   const auto node_fanout_edges =
412       ctx.graph_view().GetFanoutEdges(node, /*include_controlled_edges=*/false);
413   for (const GraphView::Edge& edge : node_fanout_edges) {
414     active_outputs.insert(edge.src.port_id);
415   }
416 
417   // 2. Or it can be in a fetch set.
418   for (const string& fetch : ctx.item().fetch) {
419     TensorId fetch_tensor = ParseTensorName(fetch);
420     if (fetch_tensor.node() == node.name()) {
421       active_outputs.insert(fetch_tensor.index());
422     }
423   }
424 
425   return active_outputs;
426 }
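// E.g. for GetActiveOutputs (illustrative): if other nodes read "fn:0" and
// "fn:2", and "fn:1" is in the GrapplerItem fetch set, the active outputs of
// node "fn" are {0, 1, 2}.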
427 
428 bool HasTrulyConstInputs(const NodeDef& node,
429                          const FunctionOptimizerContext& ctx) {
430   const auto is_truly_const = [&ctx](const string& input) {
431     return ctx.IsTrulyConst(NodeName(input));
432   };
433   return absl::c_any_of(node.input(), is_truly_const);
434 }
435 
436 bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
437                       const FunctionOptimizerContext& ctx) {
438   // Functions with tensor list outputs are not supported right now, so the
439   // number of output args is the same as the number of possible function caller
440   // node outputs.
441   int num_outputs = func.signature().output_arg_size();
442   const absl::flat_hash_set<int> active_outputs =
443       GetActiveOutputs(func_node, ctx, /*size_hint=*/num_outputs);
444   int active_outputs_size = active_outputs.size();
445   return active_outputs_size != num_outputs;
446 }
447 
448 // Return pruned FunctionDefLibrary with functions that are reachable from
449 // the optimized graph.
450 FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
451                                         const GraphDef& optimized_graph) {
452   FunctionLibraryDefinition pruned_flib =
453       flib.ReachableDefinitions(optimized_graph);
454 
455   int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
456                          static_cast<int>(flib.num_functions());
457 
458   VLOG(3) << "Pruned function library: " << pruned_flib.num_functions()
459           << " functions (" << pruned_functions << ")";
460 
461   return pruned_flib.ToProto();
462 }
463 
464 // Push all constant inputs of an instantiating node into the function body.
465 Status PushDownConstInputs(const NodeDef& func_node,
466                            const FunctionOptimizerContext& ctx,
467                            GrapplerFunctionItem* item,
468                            absl::flat_hash_set<string>* const_inputs,
469                            absl::flat_hash_set<string>* control_deps) {
470   // Record node control dependencies in the control_deps set.
471   const auto record_control_deps = [&](const NodeDef* const_input) {
472     for (int i = const_input->input_size() - 1; i >= 0; --i) {
473       const string& input = const_input->input(i);
474       if (IsControlInput(input))
475         control_deps->insert(input);
476       else
477         break;
478     }
479   };
480 
481   for (int i = func_node.input_size() - 1; i >= 0; --i) {
482     const string& input = func_node.input(i);
483     if (IsControlInput(input)) continue;
484 
485     const string node_name = NodeName(input);
486     if (ctx.IsTrulyConst(node_name)) {
487       VLOG(3) << "Push const into function body: input=" << input;
488       const auto* const_input = CHECK_NOTNULL(ctx.TrulyConstNode(node_name));
489       const_inputs->insert(input);
490       record_control_deps(const_input);
491       TF_RETURN_IF_ERROR(ReplaceInputWithConst(*const_input, i, item));
492     }
493   }
494 
495   return OkStatus();
496 }
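// Sketch with hypothetical names: for a caller with inputs ["c", "x"] where
// "c" is a truly const node that itself has a control input "^init", the value
// of "c" is embedded into the function body, "c" is recorded in 'const_inputs'
// and "^init" in 'control_deps', so the control dependency can later be
// re-attached to the caller node.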
497 
498 // Remove inputs that were pushed into the function body, and attach their
499 // control dependencies to the function caller node.
500 void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
501                                  NodeDef* specialized_func_node) {
502   // Nothing to do if there were no const inputs to the function node.
503   if (specialization.const_inputs.empty()) return;
504 
505   // Keep only non-const inputs.
506   std::vector<string> keep_inputs;
507   const auto& inputs = specialized_func_node->input();
508   absl::c_copy_if(inputs, std::back_inserter(keep_inputs),
509                   [&](const string& input) {
510                     return !specialization.const_inputs.contains(input);
511                   });
512 
513   specialized_func_node->clear_input();
514   for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
515 
516   // Attach control dependencies of pushed down const input to the caller node.
517   if (!specialization.control_deps.empty()) {
518     absl::flat_hash_set<string> existing_control_deps;
519 
520     for (const string& input : keep_inputs) {
521       existing_control_deps.insert(AsControlDependency(NodeName(input)));
522     }
523 
524     for (const string& ctrl : specialization.control_deps) {
525       if (!existing_control_deps.contains(ctrl)) {
526         VLOG(3) << "Forward control dependency: input=" << ctrl;
527         specialized_func_node->add_input(ctrl);
528       }
529     }
530   }
531 }
532 
533 // Remove Tin type parameters for pushed down const inputs.
534 void RemovePushedDownConstInputTypes(
535     const FunctionSpecialization& specialization, const NodeDef& func_node,
536     NodeDef* specialized_func_node) {
537   // Nothing to do if there were no const inputs to the function node.
538   if (specialization.const_inputs.empty()) return;
539 
540   // Make sure that original function caller has Tin attribute.
541   const AttrValue* tin = AttrSlice(func_node).Find("Tin");
542   if (tin == nullptr || !tin->has_list()) return;
543 
544   // Clear input types for the specialized node.
545   auto* attr = specialized_func_node->mutable_attr();
546   (*attr)["Tin"].mutable_list()->clear_type();
547 
548   // Keep types of non-const inputs.
549   for (int i = 0; i < func_node.input_size(); ++i) {
550     const string& input = func_node.input(i);
551     if (IsControlInput(input)) break;
552 
553     if (!specialization.const_inputs.contains(input)) {
554       DataType dt = tin->list().type(i);
555       (*attr)["Tin"].mutable_list()->add_type(dt);
556     }
557   }
558 }
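// E.g. (illustrative): a caller with Tin=[DT_FLOAT, DT_INT32] whose second
// input was pushed down as a const ends up with Tin=[DT_FLOAT].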
559 
560 // Remove Tout type parameters for pruned function outputs.
561 void RemoveUnusedOutputsTypes(const FunctionSpecialization& specialization,
562                               const NodeDef& func_node,
563                               NodeDef* specialized_func_node) {
564   // Make sure that original function caller has Tout attribute.
565   const AttrValue* tout = AttrSlice(func_node).Find("Tout");
566   if (tout == nullptr || !tout->has_list()) return;
567 
568   // Nothing to do if all outputs are active.
569   int specialization_active_outputs_size = specialization.active_outputs.size();
570   if (specialization_active_outputs_size == tout->list().type_size()) return;
571 
572   // Clear output types for the specialized node.
573   auto* attr = specialized_func_node->mutable_attr();
574   (*attr)["Tout"].mutable_list()->clear_type();
575 
576   // Keep output types of active outputs only.
577   for (int i = 0; i < tout->list().type_size(); ++i) {
578     if (specialization.active_outputs.contains(i)) {
579       DataType dt = tout->list().type(i);
580       (*attr)["Tout"].mutable_list()->add_type(dt);
581     }
582   }
583 }
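// E.g. (illustrative): a caller with Tout=[DT_FLOAT, DT_INT32, DT_FLOAT] and
// active outputs {0, 2} ends up with Tout=[DT_FLOAT, DT_FLOAT].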
584 
585 Status UpdateSpecializedFunctionCallSite(const FunctionDef& func,
586                                          const NodeDef& func_node,
587                                          const string& specialized_func_name,
588                                          NodeDef* specialized_func_node) {
589   if (IsDirectFunctionCall(func, func_node)) {
590     specialized_func_node->set_op(specialized_func_name);
591 
592   } else if (IsIndirectFunctionCall(func, func_node)) {
593     auto* attr = specialized_func_node->mutable_attr();
594     (*attr)[kFuncAttr].mutable_func()->set_name(specialized_func_name);
595 
596   } else {
597     return errors::InvalidArgument("Unknown function call site");
598   }
599 
600   return OkStatus();
601 }
602 
603 // Update a graph node created from the original function caller node to call
604 // the function specialization. Function specialization might change the number
605 // of inputs and outputs, so we have to make sure that the graph node is updated
606 // accordingly.
607 Status UpdateSpecializedFunctionNode(
608     const FunctionDef& func, const NodeDef& func_node,
609     const FunctionSpecialization& specialization,
610     NodeDef* specialized_func_node) {
611   // Function called indirectly via custom kernel (e.g. PartitionedCallOp).
612   bool is_indirect_call = IsIndirectFunctionCall(func, func_node);
613 
614   // 1. Call the specialized function instead of original one.
615   TF_RETURN_IF_ERROR(UpdateSpecializedFunctionCallSite(
616       func, func_node, specialization.specialized_func_name,
617       specialized_func_node));
618 
619   // 2. Remove inputs corresponding to the pushed down consts.
620   RemovePushedDownConstInputs(specialization, specialized_func_node);
621 
622   // NOTE: PartitionedCallOp has `Tin` and `Tout` attributes for input/output
623   // types, that must be in sync with updated function signature.
624 
625   // 3. Update input types for the indirect function calls.
626   if (is_indirect_call) {
627     RemovePushedDownConstInputTypes(specialization, func_node,
628                                     specialized_func_node);
629   }
630 
631   // 4. Update output types for the indirect function call. It's unsafe to
632   // change the number of outputs for the fetch nodes, so we just skip them.
633   if (is_indirect_call && !specialization.is_in_fetch_set) {
634     RemoveUnusedOutputsTypes(specialization, func_node, specialized_func_node);
635   }
636 
637   // 5. Remove custom gradient annotation.
638   specialized_func_node->mutable_attr()->erase("_gradient_op_type");
639 
640   return OkStatus();
641 }
642 
643 Status InitializeFunctionSpecializationSignature(
644     const NodeDef& func_node, const FunctionDef& func,
645     const AttrSlice& func_instantiation_attr,
646     const FunctionOptimizerContext& ctx, FunctionSpecializationSignature* sig) {
647   DCHECK(sig->const_inputs.empty());
648   DCHECK(sig->active_outputs.empty());
649 
650   sig->func_name = func.signature().name();
651   sig->is_in_fetch_set = ctx.IsFetchNode(func_node.name());
652   sig->active_outputs = GetActiveOutputs(func_node, ctx);
653 
654   TF_RETURN_IF_ERROR(InstantiationTypeParameters(func, func_instantiation_attr,
655                                                  &sig->type_parameters));
656   TF_RETURN_IF_ERROR(InstantiationBodyParameters(func, func_instantiation_attr,
657                                                  &sig->body_parameters));
658 
659   for (int i = 0; i < func_node.input_size(); ++i) {
660     const string& input = func_node.input(i);
661     if (IsControlInput(input)) break;
662     if (ctx.IsTrulyConst(input)) {
663       sig->const_inputs.emplace(i, input);
664     }
665   }
666 
667   return OkStatus();
668 }
669 
670 // Create a name for the function specialization. The name of the function, the
671 // name of the node instantiating it, and the Grappler item id together generate
672 // a unique function name. The meta optimizer might create multiple Grappler
673 // items for the same graph when optimizing functions, but it is guaranteed that
674 // they all have unique ids.
675 string SpecializedFunctionName(const FunctionOptimizerContext& ctx,
676                                const FunctionDef& func,
677                                const NodeDef& func_node) {
678   return absl::Substitute(
679       "$0_specialized_for_$1_at_$2", func.signature().name(),
680       absl::StrReplaceAll(func_node.name(), {{"/", "_"}}), ctx.item().id);
681 }
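// E.g. (hypothetical names): function "MyFunc" called from node "scope/call"
// in a Grappler item with id "item_0" would be named
// "MyFunc_specialized_for_scope_call_at_item_0".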
682 
683 Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
684                           FunctionOptimizerContext* ctx,
685                           GraphDef* optimized_graph) {
686   VLOG(2) << "Specialize function call: " << SummarizeNodeDef(func_node);
687 
688   const AttrSlice func_instantiation_attr =
689       FunctionInstantiationAttributes(func, func_node);
690 
691   FunctionSpecializationSignature signature;
692   TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
693       func_node, func, func_instantiation_attr, *ctx, &signature));
694 
695   // Check if function was already specialized for identical context.
696   const FunctionSpecialization* already_specialized =
697       ctx->FindFunctionSpecialization(signature);
698 
699   if (already_specialized) {
700     VLOG(2) << "Function was already specialized in identical context: "
701                "specialized_name="
702             << already_specialized->specialized_func_name;
703 
704     // Add a function call node for the specialized function.
705     NodeDef* specialized_func_node = optimized_graph->add_node();
706     *specialized_func_node = func_node;
707 
708     TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
709         func, func_node, *already_specialized, specialized_func_node));
710 
711     ctx->AddTensorMapping(specialized_func_node->name(), *already_specialized);
712 
713     return OkStatus();
714   }
715 
716   // Add a new specialized function definition to the library.
717   const auto& flib = ctx->function_library();
718 
719   // Make a GrapplerFunctionItem and convert it back to FunctionDef after
720   // pushing all constant inputs into the function body.
721   GrapplerFunctionItem item;
722   TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
723       func, func_instantiation_attr, flib, ctx->graph_version(), &item));
724 
725   // Push const inputs into the function body, and keep track of their control
726   // dependencies.
727   absl::flat_hash_set<string> const_inputs;
728   absl::flat_hash_set<string> control_deps;
729   TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
730                                          &control_deps));
731 
732   // Remove function outputs that do not have any consumers. We can't safely
733   // update outputs for the fetch nodes, so we just skip them.
734   std::vector<std::pair<int, int>> output_mapping;
735   if (!signature.is_in_fetch_set) {
736     int num_func_outputs = item.output_size();
737 
738     absl::flat_hash_set<int> remove;
739     for (int i = 0; i < num_func_outputs; ++i) {
740       if (!signature.active_outputs.count(i)) remove.insert(i);
741     }
742 
743     TF_RETURN_IF_ERROR(RemoveFunctionOutputs(remove, &item, &output_mapping));
744   }
745 
746   // TODO(ezhulenev): Push down known input shapes.
747   FunctionDef specialized_func;
748   TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));
749 
750   // Find a name for specialized function.
751   const string specialized_func_name =
752       SpecializedFunctionName(*ctx, func, func_node);
753   if (flib.Contains(specialized_func_name)) {
754     // NOTE(ezhulenev): This should never happen. If it happens, it's a sign of
755     // a serious internal error, that must be investigated.
756     return errors::Internal("Created duplicate function specialization");
757   }
758 
759   specialized_func.mutable_signature()->set_name(specialized_func_name);
760   auto* specialized_attr = specialized_func.mutable_attr();
761   (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
762 
763   // Add specialized function to the library.
764   TF_RETURN_IF_ERROR(ctx->function_library().AddFunctionDef(specialized_func));
765 
766   // Add a function call node for the specialized function.
767   NodeDef* specialized_func_node = optimized_graph->add_node();
768   *specialized_func_node = func_node;
769 
770   FunctionSpecialization func_specialization = {
771       specialized_func_name, signature.is_in_fetch_set, const_inputs,
772       control_deps,          signature.active_outputs,  output_mapping};
773 
774   TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
775       func, func_node, func_specialization, specialized_func_node));
776 
777   ctx->AddSpecializedFunction(signature, func_specialization);
778   ctx->AddTensorMapping(specialized_func_node->name(), func_specialization);
779 
780   return OkStatus();
781 }
782 
783 // -------------------------------------------------------------------------- //
784 // Inline function calls into a graph using function inlining implementation
785 // from common_runtime:
786 //
787 // 1) Convert GraphDef to Graph.
788 // 2) Inline function calls.
789 // 3) Convert Graph back to the GraphDef.
790 
791 constexpr const char* const kLowerUsingSwitchMergeAttr =
792     LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
793 constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
794     LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
795 
796 using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
797 using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
798 
799 // Checks if boolean attribute is defined and its value is 'true'.
800 bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
801   bool match;
802   bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
803   return found && match;
804 }
805 
806 // Checks if string attribute is defined and it's not empty.
807 bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
808   const string& value = GetNodeAttrString(n->attrs(), attr_name);
809   return !value.empty();
810 }
811 
812 bool LowerUsingSwitchMergeIsOn(const Node* n) {
813   return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr);
814 }
815 
816 bool LowerAsMultiDeviceFunctionIsOn(const Node* n) {
817   return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr);
818 }
819 
820 bool MarkedForXlaCompilation(const NodeDef& n) {
821   auto is_enabled = [&](std::string attr_name) -> bool {
822     auto it = n.attr().find(attr_name);
823     return it != n.attr().end() && (!it->second.s().empty() || it->second.b());
824   };
825   return is_enabled("_xla_compile_id") || is_enabled("_tpu_replicate") ||
826          is_enabled(kXlaMustCompileAttr);
827 }
828 
829 const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
830   static const auto* exemption = new absl::flat_hash_set<string>(
831       {// LINT.IfChange
832        // Op types that should not run in program order, e.g. because they need
833        // to run asynchronously to avoid deadlock.
834        "CollectiveGather", "CollectiveReduce", "CollectiveBcastSend",
835        "CollectiveBcastRecv", "CollectiveBcastSendV2", "CollectiveBcastRecvV2",
836        "NcclAllReduce", "Send", "Recv", "CollectiveAssignGroupsV2",
837        "CollectiveInitializeCommunicator",
838 
839        // Legacy random ops.
840        // See details in tensorflow/python/framework/auto_control_deps.py.
841        "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
842        "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
843        "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
844        "RandomPoissonV2",
845 
846        // ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
847        // but it can't generate any observable side-effect.
848        "ReadVariableOp",
849 
850        // CudnnRNN ops are stateful but they can't generate any observable
851        // side-effect.
852        "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
853        "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
854 
855        // TPUEmbedding EnqueueOps are stateful but this is only between ops with
856        // the same device_ordinal on the same host.
857        "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
858        "EnqueueTPUEmbeddingSparseTensorBatch",
859        "EnqueueTPUEmbeddingRaggedTensorBatch",
860        "EnqueueTPUEmbeddingArbitraryTensorBatch",
861 
862        // SaveV2 and RestoreV2 should be allowed to operate in parallel on
863        // multiple hosts.
864        "SaveV2",
865        "RestoreV2",
866 
867        // InfeedEnqueue are stateful but should not be serialized for the
868        // input pipeline
869        "InfeedEnqueue",
870        "InfeedEnqueueTuple"});
871   // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
872   return exemption->contains(op);
873 }
874 
875 // Validates that all side effects inside function body will be executed after
876 // function inlining. We do it by looking for a path from stateful ops to one
877 // of the output control sources.
878 //
879 // When a function is executed via FunctionLibraryRuntime we do not have to check
880 // this, because `PruneFunctionBody` has special pruning rules for stateful ops.
881 Status ValidateSideEffectsExecution(
882     const FunctionBody& fbody, OutputControlSource output_control_source,
883     bool has_outgoing_control_edges,
884     bool validate_outgoing_control_edge = true) {
885   // Find all nodes that can produce side effects in the function body graph. We
886   // use 'is_stateful()' bit as an approximation of "has side effects" property.
887   std::vector<const Node*> fbody_side_effects;
888   absl::c_copy_if(
889       fbody.graph->nodes(), std::back_inserter(fbody_side_effects),
890       [](const Node* n) {
891         return n->op_def().is_stateful() && !n->IsArg() && !n->IsRetval() &&
892                !IsExemptFromSideEffectsExecutionValidation(n->type_string());
893       });
894 
895   // When graph executed in TF-2.0 context with automatic control dependencies
896   // tracking, absence of outgoing control edge indicates that no one is
897   // interested in observing side effects, so it is safe to inline the function
898   // body, even if some side-effects will not be executed.
899   if (!fbody_side_effects.empty() && !has_outgoing_control_edges) {
900     const string error_message =
901         "Can't guarantee execution of function side-effects after inlining. "
902         "Function call node has no outgoing control edges.";
903     if (validate_outgoing_control_edge) {
904       return errors::Internal(error_message);
905     } else {
906       VLOG(3) << error_message;
907     }
908   }
909 
910   // Find all nodes in the function body that will be used as control sources.
911   absl::flat_hash_set<const Node*> control_sources;
912   if (output_control_source == OutputControlSource::kDataOutputs) {
913     control_sources = {fbody.ret_nodes.begin(), fbody.ret_nodes.end()};
914   } else if (output_control_source == OutputControlSource::kControlOutputs) {
915     control_sources = {fbody.control_ret_nodes.begin(),
916                        fbody.control_ret_nodes.end()};
917   }
918 
919   for (const Node* side_effect : fbody_side_effects) {
920     VLOG(4) << "Check that node " << side_effect->name()
921             << " will execute after inlining.";
922     bool will_execute = false;
923 
924     const auto is_control_source = [&](const Node* n) -> void {
925       const auto it = control_sources.find(n);
926       if (it != control_sources.end()) {
927         VLOG(4) << "Found a path to control source: " << side_effect->name()
928                 << " ---> " << (*it)->name();
929         will_execute = true;
930       }
931     };
932 
933     DFSFrom(*fbody.graph, {side_effect}, /*enter=*/is_control_source,
934             /*leave=*/{}, NodeComparatorName{});
935 
936     if (!will_execute) {
937       return errors::Internal(
938           "Can't guarantee execution of a side-effectful node, that is not "
939           "reachable from function control source. Function body node: ",
940           SummarizeNode(*side_effect));
941     }
942   }
943 
944   return OkStatus();
945 }
946 
947 // Validates that no dead tensor can reach function output.
948 Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def,
949                              const FunctionBody& fbody) {
950   absl::flat_hash_set<const Node*> output_nodes = {fbody.ret_nodes.begin(),
951                                                    fbody.ret_nodes.end()};
952 
953   // Find all nodes that can produce dead tensors.
954   std::vector<const Node*> dead_tensor_sources;
955   for (const Node* n : fbody.graph->nodes()) {
956     if (n->IsSwitch()) {
957       VLOG(4) << "Add dead tensors source. Switch node: " << n->name();
958       dead_tensor_sources.push_back(n);
959       continue;
960     }
961 
962     // Native function call can also produce dead tensors if the function body
963     // has mergeless switches.
964     const FunctionDef* fdef = flib_def.Find(n->type_string());
965     if (fdef != nullptr) {
966       std::unique_ptr<FunctionBody> nested_fbody;
967 
968       NameAttrList func;
969       TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(n->def(), &func));
970       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
971                                                  &flib_def, &nested_fbody));
972 
973       if (!ValidateNoDeadOutputs(flib_def, *nested_fbody).ok()) {
974         VLOG(4) << "Add dead tensors source. Function call: " << func.name()
975                 << " node=" << n->name();
976         dead_tensor_sources.push_back(n);
977       }
978     }
979   }
980 
981   for (const Node* dead_tensor_source : dead_tensor_sources) {
982     bool has_dead_output = false;
983 
984     const auto is_output_node = [&](const Node* n) -> void {
985       const auto it = output_nodes.find(n);
986       if (it != output_nodes.end()) {
987         VLOG(4) << "Found a path to output node from dead tensor source: "
988                 << dead_tensor_source->name() << " ---> " << (*it)->name();
989         has_dead_output = true;
990       }
991     };
992 
993     // Stop DFS traversal at a Merge node or if already found a dead output.
994     const auto stop_traversal = [&has_dead_output](const Edge& edge) -> bool {
995       return !edge.src()->IsMerge() || has_dead_output;
996     };
997 
998     DFSFrom(*fbody.graph, {dead_tensor_source}, /*enter=*/is_output_node,
999             /*leave=*/{}, NodeComparatorName{},
1000             /*edge_filter=*/stop_traversal);
1001 
1002     if (has_dead_output) {
1003       return errors::Internal(
1004           "Can't inline a function with dead outputs. Dead tensor source: ",
1005           SummarizeNode(*dead_tensor_source));
1006     }
1007   }
1008 
1009   return OkStatus();
1010 }
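// Sketch of what ValidateNoDeadOutputs catches (hypothetical body): a Switch
// whose untaken output reaches a function return node without passing through
// a Merge would propagate a dead tensor to the function output, so inlining
// such a body is rejected.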
1011 
1012 // Makes an instance of FunctionBody for inlining from a Node.
1013 Status MakeFunctionBodyForInlining(const Node& node,
1014                                    const FunctionLibraryDefinition& flib_def,
1015                                    std::unique_ptr<FunctionBody>* fbody) {
1016   VLOG(3) << "Make function body for inlining: " << SummarizeNode(node);
1017 
1018   // Finds a FunctionDef in a library and verifies that it exists.
1019   const auto find_fdef = [&flib_def, &node](
1020                              const string& name,
1021                              const FunctionDef** fdef) -> Status {
1022     if ((*fdef = flib_def.Find(name)) == nullptr) {
1023       return errors::Internal(
1024           "Was not able to find a function definition (name=", name,
1025           ") for a function call: ", SummarizeNode(node));
1026     }
1027     return OkStatus();
1028   };
1029 
1030   // SymbolicGradient is a special "function call" op, which has been
1031 // deprecated for a while, but we still support it for compatibility reasons.
1032   if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
1033     NameAttrList func;
1034     TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), kFuncAttr, &func));
1035 
1036     const string grad = flib_def.FindGradient(func.name());
1037 
1038     if (!grad.empty()) {
1039       // Function has a custom gradient registered in a library.
1040       const FunctionDef* grad_fdef;
1041       TF_RETURN_IF_ERROR(find_fdef(grad, &grad_fdef));
1042 
1043       VLOG(4) << "Instantiate a custom SymbolicGradient: gradient=" << grad
1044               << " (function=" << func.name() << ")";
1045       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1046           *grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1047 
1048     } else if (flib_def.Find(func.name()) == nullptr) {
1049       // Function is not really a function, but a primitive op.
1050       gradient::Creator creator;
1051       TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
1052       if (creator == nullptr) {
1053         return errors::InvalidArgument("No gradient is defined for ",
1054                                        func.name());
1055       }
1056       FunctionDef grad_fdef;
1057       TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
1058 
1059       VLOG(4) << "Instantiate a SymbolicGradient for a primitive op: "
1060               << func.name();
1061       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1062           grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1063 
1064     } else {
1065       // Build a gradient graph from the function body.
1066       const FunctionDef* fdef;
1067       TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1068 
1069       VLOG(4) << "Instantiate a SymbolicGradient for a function: "
1070               << func.name();
1071       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1072                                                  &flib_def, fbody));
1073       *fbody = SymbolicGradient(**fbody);
1074     }
1075 
1076   } else {
1077     NameAttrList func;
1078     TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node.def(), &func));
1079     const FunctionDef* fdef;
1080     TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1081 
1082     VLOG(4) << "Instantiate a function call: function=" << func.name();
1083     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1084                                                &flib_def, fbody));
1085   }
1086 
1087   return OkStatus();
1088 }
1089 
1090 // Adds control edges from each data input to the 'caller' to enforce strict
1091 // inputs semantics (all inputs are ready and alive). This is required when:
1092 //
1093 //  1) The function takes resources as inputs, and it doesn't have incoming
1094 //     control edges. In Tensorflow v2 context (eager mode) this should never
1095 //     happen, because automatic control dependencies tracking will add a
1096 //     control edge from the last op touching the resource. However such graphs
1097 //     might be produced by legacy v1 code without automatic dependency
1098 //     tracking. In this case strict function call semantics is required for
1099 //     enforcing side effects execution order.
1100 //
1101 //  2) One of the inputs is consuming Enter[is_constant=true] node, in which
1102 //     case it will be always alive, and potentially can lead to partial
1103 //     function execution after the last loop execution.
1104 //
1105 // Both of these cases would be considered illegal by construction in Tensorflow
1106 // V2, however we have to guarantee that graphs constructed with Tensorflow V1
1107 // will produce correct results.
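//
// A sketch of the effect (hypothetical graph): for a caller F(x, v) where v is
// a DT_RESOURCE input and F has no incoming control edges, control edges
// x -> F and v -> F are added, so F starts executing only after all of its
// data inputs are ready and alive.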
1108 void AddStrictInputSemantics(Node* caller, Graph* g) {
1109   absl::flat_hash_set<const Node*> existing_control_sources;
1110   for (const Edge* edge : caller->in_edges()) {
1111     if (edge->IsControlEdge()) {
1112       existing_control_sources.insert(edge->src());
1113     }
1114   }
1115 
1116   const bool has_incoming_control_edges = !existing_control_sources.empty();
1117 
1118   const bool has_resource_input =
1119       absl::c_any_of(caller->input_types(),
1120                      [](const DataType dtype) { return dtype == DT_RESOURCE; });
1121 
1122   const bool has_constant_enter_input =
1123       absl::c_any_of(caller->in_edges(), [](const Edge* edge) {
1124         Node* src = edge->src();
1125         return src->IsEnter() && CheckBoolAttr(src, "is_constant");
1126       });
1127 
1128   const bool requires_strict_semantics =
1129       (!has_incoming_control_edges && has_resource_input) ||  // Case #1
1130       (has_constant_enter_input);                             // Case #2
1131   if (!requires_strict_semantics) return;
1132 
1133   std::set<const Node*> data_inputs;
1134   for (const Edge* edge : caller->in_edges()) {
1135     if (!edge->IsControlEdge() &&
1136         !existing_control_sources.contains(edge->src())) {
1137       data_inputs.insert(edge->src());
1138     }
1139   }
1140 
1141   VLOG(3) << "Add control edges from all data inputs to enforce strict "
1142              "semantics with regard to function inputs";
1143 
1144   // Do not add control edges from placeholders, because it will prevent
1145   // pruning, and they can't produce any side effects anyway.
1146   const auto is_placeholder = [](const Node* node) -> bool {
1147     return node->type_string() == "Placeholder";
1148   };
1149 
1150   for (const Node* node : data_inputs) {
1151     if (is_placeholder(node)) continue;
1152     g->AddControlEdge(g->FindNodeId(node->id()), caller,
1153                       /*allow_duplicates=*/true);
1154   }
1155 }
1156 
1157 // Adds a control edge from a frame node if the 'caller' is executing inside a
1158 // While loop (see control_flow.h for the 'frame' node explanation).
1159 void AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo>& info,
1160                                    Node* caller, Graph* g) {
1161   // All nodes added to the graph by v2 control flow lowering and function
1162   // inlining are guaranteed to have control edges to nested function calls.
1163   int info_size = info.size();
1164   if (caller->id() >= info_size) return;
1165 
1166   // Check if a lowered node is executing inside a while loop.
1167   const Node* frame = info[caller->id()].frame;
1168   const bool is_in_while_loop = frame->id() != Graph::kSourceId;
1169   if (!is_in_while_loop) return;
1170 
1171   // Check if a node already has an incoming control edge. All incoming edges
1172   // must be from the same execution frame (executor.cc invariant), so if we
1173   // already have an incoming control edge, it's guaranteed that it will "carry"
1174   // the same frame as all regular inputs.
1175   const bool has_incoming_control_edges =
1176       absl::c_any_of(caller->in_edges(),
1177                      [](const Edge* edge) { return edge->IsControlEdge(); });
1178   if (has_incoming_control_edges) return;
1179 
1180   VLOG(3) << "Add a frame forwarding control edge: from=" << frame->name()
1181           << " to=" << caller->name();
1182   Node* enter = g->FindNodeId(frame->id());
1183   bool is_constant_enter = enter->attrs().Find("is_constant")->b();
1184   if (is_constant_enter) {
1185     // Enter[is_constant=true] is always alive. So we directly add a control
1186     // edge from that.
1187     g->AddControlEdge(enter, caller);
1188   } else {
1189     // Enter[is_constant=false] activates nodes only in the 0th iteration, so we
1190     // add an edge from the Merge node, which is activated in every iteration.
1191     // A non-constant Enter node must have an edge to a Merge node.
1192     auto it = absl::c_find_if(enter->out_edges(), [](const Edge* e) {
1193       return !e->IsControlEdge() && e->dst()->IsMerge();
1194     });
1195     if (it != enter->out_edges().end()) {
1196       g->AddControlEdge((*it)->dst(), caller);
1197     } else {
1198       LOG(WARNING) << "Enter[is_constant=false] node: " << enter->name()
1199                    << " does not have an outgoing edge to a Merge.";
1200     }
1201   }
1202 }
1203 
1204 // Inlines all function calls that are safe for inlining into the main graph.
1205 // Also lowers control flow V2 ops (functional If/Case/While) into the V1
1206 // low-level ops (Switch/Merge/...).
1207 //
1208 // Runs a placer after inlining, to keep all nodes in a graph placed.
1209 Status InlineFunctionCalls(const GrapplerItem& item,
1210                            const RewriterConfig::Toggle opt_level,
1211                            const bool lower_control_flow,
1212                            GraphDef* output_graph) {
1213   bool is_aggressive = opt_level == RewriterConfig::AGGRESSIVE;
1214   VLOG(2) << "Inline function calls: grappler_item_id=" << item.id
1215           << " (aggressive_mode=" << is_aggressive << ")";
1216 
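  // Build a Graph from the GrapplerItem's GraphDef so that the common_runtime
  // inlining and lowering helpers (InlineFunctionBody, RewriteIfNode, ...) can
  // be used directly.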
1217   FunctionLibraryDefinition flib_def =
1218       FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library());
1219   std::unique_ptr<Graph> graph = std::make_unique<Graph>(flib_def);
1220 
1221   GraphConstructorOptions graph_constructor_options;
1222   graph_constructor_options.allow_internal_ops = true;
1223   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_constructor_options,
1224                                             item.graph, graph.get()));
1225 
1226   using NodeNames = absl::flat_hash_set<absl::string_view>;
1227   NodeNames fetch_nodes;
1228   fetch_nodes.reserve(item.fetch.size());
1229   for (const string& fetch : item.fetch) {
1230     fetch_nodes.insert(ParseTensorName(fetch).node());
1231   }
1232   NodeNames keep_nodes(item.keep_ops.begin(), item.keep_ops.end());
1233 
1234   std::vector<string> inlined_function_names;
1235 
1236   // Do not inline function call nodes that are part of a feed set.
1237   NodeNames feed_nodes;
1238   feed_nodes.reserve(item.feed.size());
1239   for (const std::pair<std::string, Tensor>& feed : item.feed) {
1240     feed_nodes.insert(ParseTensorName(feed.first).node());
1241   }
1242 
1243   // If a function call is inside a While loop, it must have an incoming control
1244   // edge, because that edge is used to pass the execution frame into the function
1245   // body. All nodes without inputs in the function body (e.g. Const and NoOp)
1246   // will get an extra control edge from the 'input_control_node'.
1247   std::vector<ControlFlowInfo> control_flow_info;
1248   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &control_flow_info));
1249 
1250   // Function inlining always adds new nodes to the end of the list, so we keep
1251   // iterating until we are out of nodes.
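  // Node ids 0 and 1 are reserved for the _SOURCE and _SINK nodes, so start
  // iterating from id 2.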
1252   for (int i = 2; i < graph->num_node_ids(); ++i) {
1253     Node* n = graph->FindNodeId(i);
1254     if (n == nullptr) continue;  // deleted node
1255 
1256     // Special case for lowering functional control flow ops. We do not rely on
1257     // LowerFunctionalOpsPass, because in Grappler we have to be more restrictive
1258     // about what type of function calls we are allowed to inline.
1259     if (lower_control_flow && LowerUsingSwitchMergeIsOn(n)) {
1260       VLOG(2) << "Lower functional control flow op: " << SummarizeNode(*n);
1261       AddStrictInputSemantics(n, graph.get());
1262       AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
1263 
1264       if (n->IsIfNode()) {
1265         TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), false));
1266       } else if (n->IsCaseNode()) {
1267         TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), false));
1268       } else if (n->IsWhileNode()) {
1269         TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), &flib_def, false));
1270       }
1271       continue;
1272     }
1273 
1274     // Skip nodes that are not function calls.
1275     if (!IsFunctionCall(flib_def, *n)) continue;
1276     // Skip function calls that we plan to compile later.
1277     if (MarkedForXlaCompilation(n->def())) continue;
1278     // Skip nodes in a feed set.
1279     if (feed_nodes.contains(n->name())) continue;
1280 
1281     // Function body that we will inline into the main graph. It can be a
1282     // function instantiation, or a gradient function instantiated from a
1283     // SymbolicGradient op.
1284     std::unique_ptr<FunctionBody> fbody;
1285     TF_RETURN_IF_ERROR(MakeFunctionBodyForInlining(*n, flib_def, &fbody));
1286 
1287     InlineFunctionBodyOptions inline_options;
1288     // Ignore '_noinline' flag in aggressive mode.
1289     inline_options.ignore_noinline = is_aggressive;
1290 
1291     // Function calls created after inlining If/While ops are always inlined as
1292     // multi-device functions and are not required to pass additional Grappler
1293     // validations (side effects execution validation below).
1294     bool force_inline_as_multi_device = LowerAsMultiDeviceFunctionIsOn(n);
1295 
1296     // `PartitionedCall` is a TF-2.0 function call mechanism for multi-device
1297     // functions:
1298     // a) Function can be multi-device.
1299     // b) Automatic control dependencies tracking guarantees that all function
1300     //    side-effectful nodes will have a path to one of the control outputs.
1301     //    Control outputs and control edges between side-effectful (stateful)
1302     //    nodes are used to explicitly mark the nodes that must execute, and to
1303     //    define their execution order.
1304     if (n->IsPartitionedCall() || force_inline_as_multi_device) {
1305       inline_options.output_control_src = OutputControlSource::kControlOutputs;
1306       inline_options.inlined_function_body_placer =
1307           InlinedFunctionBodyPlacer::MultiDevice();
1308     } else {
1309       inline_options.output_control_src = OutputControlSource::kDataOutputs;
1310       inline_options.inlined_function_body_placer =
1311           InlinedFunctionBodyPlacer::SingleDevice();
1312     }
1313 
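    // Keep the caller node if it is fetched or must be preserved, so that
    // fetch tensors and execution targets stay valid after the body is inlined.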
1314     if (fetch_nodes.contains(n->name())) {
1315       inline_options.keep_caller_node = KeepCallerNode::kFetchable;
1316     } else if (keep_nodes.contains(n->name())) {
1317       inline_options.keep_caller_node = KeepCallerNode::kTargetable;
1318     } else {
1319       inline_options.keep_caller_node = KeepCallerNode::kDoNotKeep;
1320     }
1321 
1322     // Basic validation rules defined in common_runtime and shared by all functions.
1323     Status can_inline_function_call =
1324         ValidateInlining(n, fbody.get(), inline_options);
1325 
1326     // Additional validation rules defined only in Grappler.
1327     // TODO(ezhulenev): Move it to common_runtime InlineFunctionBodyOptions?
1328     if (can_inline_function_call.ok()) {
1329       bool has_outgoing_control_edges = absl::c_any_of(
1330           n->out_edges(),
1331           [](const Edge* edge) { return edge->IsControlEdge(); });
1332 
1333       can_inline_function_call = ValidateSideEffectsExecution(
1334           *fbody, inline_options.output_control_src,
1335           has_outgoing_control_edges);
1336 
1337       if (!can_inline_function_call.ok() &&
1338           (is_aggressive || force_inline_as_multi_device)) {
1339         VLOG(2) << "Ignore error: " << can_inline_function_call.error_message();
1340         can_inline_function_call = OkStatus();
1341       }
1342     }
1343     if (can_inline_function_call.ok()) {
1344       can_inline_function_call = ValidateNoDeadOutputs(flib_def, *fbody);
1345     }
1346 
1347     if (can_inline_function_call.ok()) {
1348       VLOG(2) << "Inline function call node: " << n->name();
1349       AddStrictInputSemantics(n, graph.get());
1350       AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
1351 
1352       TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, graph.get(), n,
1353                                             fbody.get(), inline_options));
1354       inlined_function_names.push_back(fbody->fdef.signature().name());
1355 
1356     } else {
1357       VLOG(2) << "Failed to inline function call node: "
1358               << can_inline_function_call.error_message();
1359     }
1360   }
1361 
1362   VLOG(4) << "Inlined " << inlined_function_names.size()
1363           << " function calls: " << absl::StrJoin(inlined_function_names, ", ");
1364 
1365   // ------------------------------------------------------------------------ //
1366   // Grappler receives the graph after PRE_PLACEMENT, Placer, and POST_PLACEMENT
1367   // passes, so each node has a valid device assignment. After function inlining
1368   // and control flow V2 lowering we have to keep the graph placed.
1369 
1370   if (inlined_function_names.empty()) {
1371     VLOG(3) << "Not placing graph after function inlining"
1372             << " (did not inline any of the function calls).";
1373 
1374   } else if (item.devices().empty()) {
1375     // If there are no devices available for the placer, we do not place the
1376     // graph after function inlining. This happens when Grappler is optimizing
1377     // the function library, or when a graph is optimized "offline", without an
1378     // active runtime session, for example as part of a batch job for graph
1379     // analysis/optimization. A GrapplerItem instantiated from a function library
1380     // doesn't have to be fully placed after all optimizations; it will be
1381     // placed by the function library runtime before execution.
1382     VLOG(3) << "Not placing graph after function inlining"
1383             << " (device set is empty)";
1384 
1385   } else {
1386     // If we are running in an active runtime session, Grappler will get the
1387     // graph after initial placing is done, and we should have devices for the
1388     // placer.
1389     VLOG(3) << "Run placer for the graph after function inlining. "
1390             << "Devices: [" << absl::StrJoin(item.devices(), ", ") << "]";
1391 
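    // The Placer works on a DeviceSet; build one out of fake devices that wrap
    // the device names recorded in the GrapplerItem.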
1392     DeviceSet device_set;                               // does not own devices
1393     std::vector<std::unique_ptr<Device>> fake_devices;  // owns fake devices
1394 
1395     for (const string& name : item.devices()) {
1396       auto device = std::make_unique<FakeDevice>(name);
1397       device_set.AddDevice(device.get());
1398       fake_devices.push_back(std::move(device));
1399     }
1400 
1401     Placer placer(graph.get(), item.id, &flib_def, &device_set);
1402     TF_RETURN_IF_ERROR(placer.Run());
1403   }
1404 
1405   graph->ToGraphDef(output_graph);
1406   return OkStatus();
1407 }
1408 
1409 // Restores tensor mapping after function specialization: all inputs must be
1410 // connected to valid nodes.
1411 void RestoreTensorMapping(const FunctionOptimizerContext& ctx,
1412                           GraphDef* optimized_graph) {
1413   if (ctx.tensor_mapping().empty()) return;
1414 
1415   // During function specialization, we might prune unused function outputs. We
1416   // need to "close the holes" that might appear in the function outputs.
1417   //
1418   // Example: prune unused output "f:1"
1419   //
1420   //   f = my_func[T=float](...)          f = my_func_specialized[T=float](...)
1421   //   a = Identity(f:0)             ->   a = Identity(f:0)
1422   //   b = Identity(f:2)                  b = Identity(f:1)
1423   //
1424   // Tensor mapping (size=1): [f:2 -> f:1]
1425   for (NodeDef& node : *optimized_graph->mutable_node()) {
1426     for (int idx = 0; idx < node.input_size(); ++idx) {
1427       TensorId input_tensor = ParseTensorName(node.input(idx));
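      // Control inputs always come after data inputs in a NodeDef, so we can
      // stop at the first control input.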
1428       if (input_tensor.index() == Graph::kControlSlot) break;
1429 
1430       auto mapping = ctx.tensor_mapping().find(input_tensor);
1431       if (mapping != ctx.tensor_mapping().end()) {
1432         node.set_input(idx, TensorIdToString(mapping->second));
1433       }
1434     }
1435   }
1436 }
1437 
1438 }  // namespace
1439 
1440 Status FunctionOptimizer::RunFunctionOptimizerPass(
1441     const GrapplerItem& item, GraphDef* optimized_graph) const {
1442   VLOG(3) << "Run function optimizer pass: grappler_item_id=" << item.id;
1443 
1444   // Inline all function calls into a graph using common_runtime/function
1445   // implementation (see `InlineFunctionBody` function documentation).
1446   GraphDef graph_after_inlining;
1447   TF_RETURN_IF_ERROR(InlineFunctionCalls(item, opt_level_, lower_control_flow_,
1448                                          &graph_after_inlining));
1449 
1450   // Specialize function calls that we could not inline.
1451   FunctionOptimizerContext ctx(item, opt_level_, graph_after_inlining);
1452 
1453   for (const NodeDef& node : graph_after_inlining.node()) {
1454     // Function specialization can modify the optimized graph only by adding new
1455     // nodes, so we can compare node counts to check whether the graph was modified.
1456     const int num_nodes_before = optimized_graph->node_size();
1457     const auto is_graph_modified = [&]() {
1458       int num_nodes = optimized_graph->node_size();
1459       DCHECK_GE(num_nodes, num_nodes_before) << "Nodes should not be removed";
1460       return num_nodes > num_nodes_before;
1461     };
1462 
1463     // Copy node from the `graph_after_inlining` to the `optimized_graph`.
1464     const auto copy_node = [&]() { *optimized_graph->add_node() = node; };
1465 
1466     // Check whether the node is a function call (direct or indirect).
1467     const FunctionDef* func = FindFunctionCall(ctx, node);
1468     if (func == nullptr) {
1469       copy_node();
1470       continue;
1471     }
1472 
1473     const string& func_name = func->signature().name();
1474 
1475     // Specialize it to its instantiation context if it has something worth
1476     // specializing.
1477     const bool specialization_worthy = IsParametrized(*func) ||
1478                                        HasTrulyConstInputs(node, ctx) ||
1479                                        HasUnusedOutputs(node, *func, ctx);
1480 
1481     // Do not specialize if the function has a custom gradient or is marked nospecialize.
1482     const string grad_func = ctx.function_library().FindGradient(func_name);
1483     const bool no_specialize =
1484         !grad_func.empty() || ctx.IsFeedNode(node.name()) ||
1485         MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node);
1486 
1487     if (specialization_worthy && !no_specialize) {
1488       // TODO(ezhulenev): Specialize function call if input has a known shape.
1489       // Specialize function body for its instantiation attributes and inputs.
1490       Status status = SpecializeFunction(node, *func, &ctx, optimized_graph);
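      // If specialization failed after new nodes were already added, we can't
      // fall back to copying the original node, so propagate the error.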
1491       if (!status.ok() && is_graph_modified()) {
1492         return status;
1493       } else if (!status.ok() && !is_graph_modified()) {
1494         VLOG(3) << "Skip specialization error: " << status.error_message();
1495         copy_node();
1496       }
1497       continue;
1498     } else {
1499       VLOG(2) << "Skip function specialization: " << func->signature().name();
1500       copy_node();
1501     }
1502   }
1503 
1504   RestoreTensorMapping(ctx, optimized_graph);
1505 
1506   // Preserve the graph version.
1507   *optimized_graph->mutable_versions() = item.graph.versions();
1508   // Prune unreachable functions from the library.
1509   *optimized_graph->mutable_library() =
1510       PruneFunctionLibrary(ctx.function_library(), *optimized_graph);
1511 
1512   return OkStatus();
1513 }
1514 
1515 Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
1516                                    GraphDef* optimized_graph) {
1517   // Nothing to do if the graph's function library is empty.
1518   if (item.graph.library().function_size() == 0) {
1519     return errors::Aborted("Nothing to do.");
1520   }
1521 
1522   TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(item, optimized_graph));
1523 
1524   return OkStatus();
1525 }
1526 
1527 }  // end namespace grappler
1528 }  // end namespace tensorflow
1529