1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
17
18 #include <vector>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/device_set.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/common_runtime/lower_case_op.h"
33 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
34 #include "tensorflow/core/common_runtime/lower_if_op.h"
35 #include "tensorflow/core/common_runtime/lower_while_op.h"
36 #include "tensorflow/core/common_runtime/placer.h"
37 #include "tensorflow/core/framework/attr_value_util.h"
38 #include "tensorflow/core/framework/function.h"
39 #include "tensorflow/core/framework/function.pb.h"
40 #include "tensorflow/core/framework/graph_def_util.h"
41 #include "tensorflow/core/framework/node_def.pb.h"
42 #include "tensorflow/core/framework/node_def_util.h"
43 #include "tensorflow/core/framework/op_def.pb.h"
44 #include "tensorflow/core/framework/versions.pb.h"
45 #include "tensorflow/core/graph/algorithm.h"
46 #include "tensorflow/core/graph/control_flow.h"
47 #include "tensorflow/core/graph/graph_node_util.h"
48 #include "tensorflow/core/graph/tensor_id.h"
49 #include "tensorflow/core/grappler/graph_view.h"
50 #include "tensorflow/core/grappler/grappler_item.h"
51 #include "tensorflow/core/grappler/op_types.h"
52 #include "tensorflow/core/grappler/utils.h"
53 #include "tensorflow/core/grappler/utils/functions.h"
54 #include "tensorflow/core/lib/gtl/map_util.h"
55
56 namespace tensorflow {
57 namespace grappler {
58 namespace {
59
60 constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr;
61
62 // Do not specialize functions marked with '_nospecialize' attribute.
63 constexpr const char* const kNoSpecializeAttr = "_nospecialize";
64
65 // Mark functions that were created as a result of function specialization.
66 constexpr const char* const kGrapplerSpecializedFuncAttr =
67 "_GrapplerSpecializedFunc";
68
69 // There are two ways of calling a Tensorflow function:
70 //
71 // 1. Direct function call: node.op() is the name of the function.
72 //
73 // 2. Indirect function call: the function name is passed through a node
74 // attribute, and special Tensorflow kernels are responsible for calling the
75 // function through the FunctionLibraryRuntime. Example: PartitionedCallOp.
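// For illustration only (a hypothetical sketch, not a node from this file),
// the two call styles look roughly like:
//
//   1. Direct:   node { name: "call" op: "MyFunc" input: "x" }
//   2. Indirect: node { name: "call" op: "PartitionedCall" input: "x"
//                       attr { key: "f" value { func { name: "MyFunc" } } } }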
76
77 // Check if func_node.op() matches the name in FunctionDef signature.
78 bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
79 return func_node.op() == func.signature().name();
80 }
81
82 // Check if func_node has function attribute with a function name matching
83 // FunctionDef signature.
84 bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
85 if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) {
86 return false;
87 }
88
89 auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
90 return func_attr != nullptr && func_attr->has_func() &&
91 func_attr->func().name() == func.signature().name();
92 }
93
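// Returns the attributes that parametrize the function instantiation: the call
// node attributes for a direct call, the attributes attached to the 'f' func
// attr for an indirect call, or an empty AttrSlice if the call site can't be
// resolved.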
94 AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
95 const NodeDef& func_node) {
96 if (IsDirectFunctionCall(func, func_node)) {
97 return AttrSlice(func_node);
98
99 } else if (IsIndirectFunctionCall(func, func_node)) {
100 auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
101 return AttrSlice(&func_attr->func().attr());
102
103 } else {
104 LOG(WARNING) << "Can't resolve function instantiation attributes: "
105 << SummarizeNodeDef(func_node);
106 return AttrSlice();
107 }
108 }
109
110 // This is a fake device that should not be used for any op kernel execution;
111 // its only purpose is to be passed as a part of the DeviceSet to the
112 // Placer.
113 class FakeDevice : public Device {
114 public:
115 FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
116 explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
117 Status Sync() override { return OkStatus(); }
118
119 private:
120 static DeviceAttributes attr(const string& device) {
121 DeviceNameUtils::ParsedName parsed_name;
122 bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
123 DCHECK(parsed) << "Failed to parse full device name: " << device;
124
125 DeviceAttributes attr;
126 attr.set_name(device);
127 attr.set_device_type(parsed_name.type);
128 return attr;
129 }
130 };
131
132 // -------------------------------------------------------------------------- //
133 // Function specialization.
134 //
135 // FunctionDef is somewhat similar to function template in C++, given all the
136 // type parameters (and attribute values) it generates a statically defined
137 // graph from the type parametrized "graph template" (function body).
138 //
139 // Function specialization instantiates a parametrized FunctionDef into a
140 // statically defined graph, and then converts it back to the fully defined
141 // FunctionDef (it doesn't have any unknown type parameters or attribute
142 // values, known as placeholders).
143 //
144 // Given the fully specified graph we can apply all the Grappler optimizers to
145 // it (see details in MetaOptimizer). Also we can push known constant inputs
146 // into the function body, and remove unused outputs/inputs.
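// For example (illustrative only): a call to MyFunc[T=float] with a Const
// input may be specialized into a new function in which T is fixed to float,
// the Const is moved into the function body, and outputs without consumers
// are removed from the signature.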
147
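// Returns true if the function definition is annotated with the
// '_nospecialize' attribute set to true.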
148 bool MarkedNoSpecialize(const FunctionDef& fdef) {
149 const auto attr = AttrSlice(&fdef.attr());
150 bool nospecialize = false;
151 return TryGetNodeAttr(attr, kNoSpecializeAttr, &nospecialize) && nospecialize;
152 }
153
154 // Specialized function instantiation type parameters, body parameters, and
155 // const inputs.
156 struct FunctionSpecializationSignature {
157 // Currently we do not support functions with tensor lists as inputs or
158 // outputs, so caller node input/output ports always match function
159 // input/output arguments.
160 using InputPort = int;
161 using OutputPort = int;
162
163 string func_name;
164 bool is_in_fetch_set;
165 absl::flat_hash_set<OutputPort> active_outputs;
166 absl::flat_hash_map<string, DataType> type_parameters;
167 absl::flat_hash_map<string, AttrValue> body_parameters;
168 absl::flat_hash_map<InputPort, string> const_inputs;
169
170 bool operator==(const FunctionSpecializationSignature& other) const {
171 bool equals = func_name == other.func_name &&
172 is_in_fetch_set == other.is_in_fetch_set &&
173 active_outputs == other.active_outputs &&
174 type_parameters == other.type_parameters &&
175 const_inputs == other.const_inputs;
176
177 if (!equals) return false;
178
179 // Equality is not defined for AttrValue.
180 if (body_parameters.size() != other.body_parameters.size()) return false;
181
182 for (const auto& lhs : body_parameters) {
183 auto it = other.body_parameters.find(lhs.first);
184 if (it == other.body_parameters.end()) return false;
185 if (!AreAttrValuesEqual(lhs.second, (*it).second,
186 /*allow_false_negatives=*/true)) {
187 return false;
188 }
189 }
190
191 return true;
192 }
193
194 template <typename H>
195 friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
196 H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
197
198 // First pre-compute hashes for all values in collections with
199 // non-deterministic iteration order.
200 std::vector<uint64> hashes;
201 hashes.reserve(s.active_outputs.size() //
202 + s.type_parameters.size() * 2 //
203 + s.body_parameters.size() * 2 //
204 + s.const_inputs.size() * 2);
205
206 absl::c_transform(s.active_outputs, std::back_inserter(hashes),
207 hash<OutputPort>());
208
209 using TypeParam = std::pair<const string, DataType>;
210 absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
211 AttrValue attr_value;
212 attr_value.set_type(type_param.second);
213 hashes.push_back(Hash64(type_param.first));
214 hashes.push_back(AttrValueHash(attr_value));
215 });
216
217 using BodyParam = std::pair<const string, AttrValue>;
218 absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
219 hashes.push_back(Hash64(body_param.first));
220 hashes.push_back(FastAttrValueHash(body_param.second));
221 });
222
223 using ConstInput = std::pair<const InputPort, string>;
224 absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
225 hashes.push_back(hash<InputPort>()(const_input.first));
226 hashes.push_back(Hash64(const_input.second));
227 });
228
229 // Combine all pre-computed hashes in a deterministic order.
230 absl::c_sort(hashes);
231 return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
232 }
233 };
234
235 struct FunctionSpecialization {
236 string specialized_func_name;
237 // True if the function caller node is in GrapplerItem fetch set.
238 bool is_in_fetch_set;
239 // Names of the tensors that were pushed down into the function body.
240 absl::flat_hash_set<string> const_inputs;
241 // Control dependencies of pushed down const inputs have to be attached to
242 // function caller node.
243 absl::flat_hash_set<string> control_deps;
244 // Output tensors (ports) that are consumed by other nodes in the graph or
245 // are in the GrapplerItem fetch set.
246 absl::flat_hash_set<int> active_outputs;
247 // Mapping from original function output port to the output port of
248 // specialized function. If function specialization changes the number of
249 // function outputs it's required to update all node consumers.
250 std::vector<std::pair<int, int>> output_mapping;
251 };
252
253 // Function optimizer context is initialized once for each optimization pass,
254 // and it uses the latest available graph (for the first iteration it will be
255 // the GrapplerItem.graph, for subsequent iterations it will be the output of
256 // the previous function optimizer pass).
257 class FunctionOptimizerContext {
258 public:
259 explicit FunctionOptimizerContext(const GrapplerItem& item,
260 RewriterConfig::Toggle opt_level,
261 const GraphDef& graph)
262 : item_(&item),
263 opt_level_(opt_level),
264 function_library_(OpRegistry::Global(), graph.library()),
265 truly_const_nodes_(InferTrulyConstNodes(item, graph)),
266 graph_view_(&graph) {}
267
268 const GrapplerItem& item() const { return *item_; }
269
270 const int graph_version() const { return item_->graph.versions().producer(); }
271
272 RewriterConfig::Toggle opt_level() const { return opt_level_; }
273
274 const FunctionLibraryDefinition& function_library() const {
275 return function_library_;
276 }
277 FunctionLibraryDefinition& function_library() { return function_library_; }
278
279 const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
280 tensor_mapping() const {
281 return tensor_mapping_;
282 }
283
284 const GraphView& graph_view() const { return graph_view_; }
285
286 bool IsFeedNode(const string& node_name) const {
287 return absl::c_any_of(
288 item_->feed, [&](const std::pair<std::string, Tensor>& feed) {
289 return ParseTensorName(feed.first).node() == node_name;
290 });
291 }
292
293 bool IsFetchNode(const string& node_name) const {
294 return absl::c_any_of(item_->fetch, [&](const string& fetch) {
295 return ParseTensorName(fetch).node() == node_name;
296 });
297 }
298
299 bool IsTrulyConst(const string& name) const {
300 return TrulyConstNode(name) != nullptr;
301 }
302
303 const NodeDef* TrulyConstNode(const string& name) const {
304 return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
305 }
306
307 const FunctionSpecialization* FindFunctionSpecialization(
308 const FunctionSpecializationSignature& sig) const {
309 return gtl::FindOrNull(specialized_functions_, sig);
310 }
311
312 void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
313 const FunctionSpecialization& specialized_func) {
314 specialized_functions_.emplace(sig, specialized_func);
315 }
316
317 void AddTensorMapping(const SafeTensorId& from, const SafeTensorId& to) {
318 DCHECK(from.index() != Graph::kControlSlot)
319 << "Tensor mapping must be from regular tensor";
320 DCHECK(to.index() != Graph::kControlSlot)
321 << "Tensor mapping must be to regular tensor";
322
323 auto inserted = tensor_mapping_.insert({from, to});
324 DCHECK(inserted.second)
325 << "Failed to insert duplicated tensor mapping: "
326 << "from=" << from.ToString() << " to=" << to.ToString();
327 }
328
329 void AddTensorMapping(const string& func_node,
330 const FunctionSpecialization& specialized_func) {
331 for (const auto& pair : specialized_func.output_mapping) {
332 int from_idx = pair.first;
333 int to_idx = pair.second;
334 if (from_idx != to_idx) {
335 SafeTensorId from_tensor(func_node, from_idx);
336 SafeTensorId to_tensor(func_node, to_idx);
337 AddTensorMapping(from_tensor, to_tensor);
338 }
339 }
340 }
341
342 private:
343 static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
344 const GrapplerItem& item, const GraphDef& graph) {
345 absl::flat_hash_set<absl::string_view> feed_nodes;
346 for (const auto& feed : item.feed) {
347 feed_nodes.insert(feed.first);
348 }
349
350 absl::flat_hash_map<string, const NodeDef*> const_nodes;
351 for (const NodeDef& node : graph.node()) {
352 if (IsConstant(node) && !feed_nodes.contains(node.name())) {
353 const_nodes[node.name()] = &node;
354 }
355 }
356
357 return const_nodes;
358 }
359
360 const GrapplerItem* item_; // must outlive this object
361 RewriterConfig::Toggle opt_level_;
362
363 // Function library constructed from current graph.
364 FunctionLibraryDefinition function_library_;
365
366 // Nodes that are Const and not in feed.
367 absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
368 // Specialized functions.
369 absl::flat_hash_map<FunctionSpecializationSignature,
370 const FunctionSpecialization>
371 specialized_functions_;
372
373 // After function specialization, the optimized graph might be in an invalid
374 // state: nodes can read from an output index that is no longer valid after
375 // unused output pruning.
376 //
377 // Tensor mapping that has to be applied to the graph after all function
378 // optimizations (invalidated tensor id -> optimized graph tensor id).
379 absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
380 tensor_mapping_;
381
382 // Use graph view to find active outputs of the function caller nodes.
383 GraphView graph_view_;
384
385 TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
386 };
387
388 // Returns a pointer to the called function definition iff the given node is
389 // indeed a function call. Otherwise returns nullptr.
390 const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
391 const NodeDef& node) {
392 // Check if a node does indirect function call via PartitionedCallOp.
393 if (IsPartitionedCall(node) || IsStatefulPartitionedCall(node)) {
394 const AttrValue* func_attr = AttrSlice(node).Find("f");
395 return (func_attr != nullptr && func_attr->has_func())
396 ? ctx.function_library().Find(func_attr->func().name())
397 : nullptr;
398 }
399
400 // Check if the function op itself is a function name.
401 return ctx.function_library().Find(node.op());
402 }
403
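// Returns the set of caller node output ports that are consumed by other nodes
// in the graph or referenced by the GrapplerItem fetch set.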
404 absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
405 const FunctionOptimizerContext& ctx,
406 int size_hint = 0) {
407 absl::flat_hash_set<int> active_outputs;
408 active_outputs.reserve(static_cast<size_t>(size_hint));
409
410 // 1. Output can be consumed by another graph node.
411 const auto node_fanout_edges =
412 ctx.graph_view().GetFanoutEdges(node, /*include_controlled_edges=*/false);
413 for (const GraphView::Edge& edge : node_fanout_edges) {
414 active_outputs.insert(edge.src.port_id);
415 }
416
417 // 2. Or it can be in a fetch set.
418 for (const string& fetch : ctx.item().fetch) {
419 TensorId fetch_tensor = ParseTensorName(fetch);
420 if (fetch_tensor.node() == node.name()) {
421 active_outputs.insert(fetch_tensor.index());
422 }
423 }
424
425 return active_outputs;
426 }
427
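// Returns true if at least one of the node inputs is produced by a truly-const
// node (a Const that is not in the feed set).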
428 bool HasTrulyConstInputs(const NodeDef& node,
429 const FunctionOptimizerContext& ctx) {
430 const auto is_truly_const = [&ctx](const string& input) {
431 return ctx.IsTrulyConst(NodeName(input));
432 };
433 return absl::c_any_of(node.input(), is_truly_const);
434 }
435
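// Returns true if some of the function caller node outputs are not active,
// i.e. they have no consumers and are not in the fetch set.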
436 bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
437 const FunctionOptimizerContext& ctx) {
438 // Functions with tensor list outputs are not supported right now, so the
439 // number of output args is the same as the number of possible function
440 // caller node outputs.
441 int num_outputs = func.signature().output_arg_size();
442 const absl::flat_hash_set<int> active_outputs =
443 GetActiveOutputs(func_node, ctx, /*size_hint=*/num_outputs);
444 int active_outputs_size = active_outputs.size();
445 return active_outputs_size != num_outputs;
446 }
447
448 // Return pruned FunctionDefLibrary with functions that are reachable from
449 // the optimized graph.
450 FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
451 const GraphDef& optimized_graph) {
452 FunctionLibraryDefinition pruned_flib =
453 flib.ReachableDefinitions(optimized_graph);
454
455 int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
456 static_cast<int>(flib.num_functions());
457
458 VLOG(3) << "Pruned function library: " << pruned_flib.num_functions()
459 << " functions (" << pruned_functions << ")";
460
461 return pruned_flib.ToProto();
462 }
463
464 // Push all constant inputs of an instantiating node into the function body.
465 Status PushDownConstInputs(const NodeDef& func_node,
466 const FunctionOptimizerContext& ctx,
467 GrapplerFunctionItem* item,
468 absl::flat_hash_set<string>* const_inputs,
469 absl::flat_hash_set<string>* control_deps) {
470 // Record node control dependencies in the control_deps set.
471 const auto record_control_deps = [&](const NodeDef* const_input) {
472 for (int i = const_input->input_size() - 1; i >= 0; --i) {
473 const string& input = const_input->input(i);
474 if (IsControlInput(input))
475 control_deps->insert(input);
476 else
477 break;
478 }
479 };
480
481 for (int i = func_node.input_size() - 1; i >= 0; --i) {
482 const string& input = func_node.input(i);
483 if (IsControlInput(input)) continue;
484
485 const string node_name = NodeName(input);
486 if (ctx.IsTrulyConst(node_name)) {
487 VLOG(3) << "Push const into function body: input=" << input;
488 const auto* const_input = CHECK_NOTNULL(ctx.TrulyConstNode(node_name));
489 const_inputs->insert(input);
490 record_control_deps(const_input);
491 TF_RETURN_IF_ERROR(ReplaceInputWithConst(*const_input, i, item));
492 }
493 }
494
495 return OkStatus();
496 }
497
498 // Remove inputs that were pushed into the function body, and attach their
499 // control dependencies to the function caller node.
500 void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
501 NodeDef* specialized_func_node) {
502 // Nothing to do if there were no const inputs to the function node.
503 if (specialization.const_inputs.empty()) return;
504
505 // Keep only non-const inputs.
506 std::vector<string> keep_inputs;
507 const auto& inputs = specialized_func_node->input();
508 absl::c_copy_if(inputs, std::back_inserter(keep_inputs),
509 [&](const string& input) {
510 return !specialization.const_inputs.contains(input);
511 });
512
513 specialized_func_node->clear_input();
514 for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
515
516 // Attach control dependencies of pushed down const input to the caller node.
517 if (!specialization.control_deps.empty()) {
518 absl::flat_hash_set<string> existing_control_deps;
519
520 for (const string& input : keep_inputs) {
521 existing_control_deps.insert(AsControlDependency(NodeName(input)));
522 }
523
524 for (const string& ctrl : specialization.control_deps) {
525 if (!existing_control_deps.contains(ctrl)) {
526 VLOG(3) << "Forward control dependency: input=" << ctrl;
527 specialized_func_node->add_input(ctrl);
528 }
529 }
530 }
531 }
532
533 // Remove Tin type parameters for pushed down const inputs.
534 void RemovePushedDownConstInputTypes(
535 const FunctionSpecialization& specialization, const NodeDef& func_node,
536 NodeDef* specialized_func_node) {
537 // Nothing to do if there were no const inputs to the function node.
538 if (specialization.const_inputs.empty()) return;
539
540 // Make sure that original function caller has Tin attribute.
541 const AttrValue* tin = AttrSlice(func_node).Find("Tin");
542 if (tin == nullptr || !tin->has_list()) return;
543
544 // Clear input types for the specialized node.
545 auto* attr = specialized_func_node->mutable_attr();
546 (*attr)["Tin"].mutable_list()->clear_type();
547
548 // Keep types of non-const inputs.
549 for (int i = 0; i < func_node.input_size(); ++i) {
550 const string& input = func_node.input(i);
551 if (IsControlInput(input)) break;
552
553 if (!specialization.const_inputs.contains(input)) {
554 DataType dt = tin->list().type(i);
555 (*attr)["Tin"].mutable_list()->add_type(dt);
556 }
557 }
558 }
559
560 // Remove Tout type parameters for pruned function outputs.
561 void RemoveUnusedOutputsTypes(const FunctionSpecialization& specialization,
562 const NodeDef& func_node,
563 NodeDef* specialized_func_node) {
564 // Make sure that original function caller has Tout attribute.
565 const AttrValue* tout = AttrSlice(func_node).Find("Tout");
566 if (tout == nullptr || !tout->has_list()) return;
567
568 // Nothing to do if all outputs are active.
569 int specialization_active_outputs_size = specialization.active_outputs.size();
570 if (specialization_active_outputs_size == tout->list().type_size()) return;
571
572 // Clear output types for the specialized node.
573 auto* attr = specialized_func_node->mutable_attr();
574 (*attr)["Tout"].mutable_list()->clear_type();
575
576 // Keep output types of active outputs only.
577 for (int i = 0; i < tout->list().type_size(); ++i) {
578 if (specialization.active_outputs.contains(i)) {
579 DataType dt = tout->list().type(i);
580 (*attr)["Tout"].mutable_list()->add_type(dt);
581 }
582 }
583 }
584
585 Status UpdateSpecializedFunctionCallSite(const FunctionDef& func,
586 const NodeDef& func_node,
587 const string& specialized_func_name,
588 NodeDef* specialized_func_node) {
589 if (IsDirectFunctionCall(func, func_node)) {
590 specialized_func_node->set_op(specialized_func_name);
591
592 } else if (IsIndirectFunctionCall(func, func_node)) {
593 auto* attr = specialized_func_node->mutable_attr();
594 (*attr)[kFuncAttr].mutable_func()->set_name(specialized_func_name);
595
596 } else {
597 return errors::InvalidArgument("Unknown function call site");
598 }
599
600 return OkStatus();
601 }
602
603 // Updates a graph node created from the original function caller node to call
604 // the function specialization. Function specialization might change the number
605 // of inputs and outputs, so we have to make sure that the graph node is
606 // updated accordingly.
607 Status UpdateSpecializedFunctionNode(
608 const FunctionDef& func, const NodeDef& func_node,
609 const FunctionSpecialization& specialization,
610 NodeDef* specialized_func_node) {
611 // Function called indirectly via custom kernel (e.g. PartitionedCallOp).
612 bool is_indirect_call = IsIndirectFunctionCall(func, func_node);
613
614 // 1. Call the specialized function instead of original one.
615 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionCallSite(
616 func, func_node, specialization.specialized_func_name,
617 specialized_func_node));
618
619 // 2. Remove inputs corresponding to the pushed down consts.
620 RemovePushedDownConstInputs(specialization, specialized_func_node);
621
622 // NOTE: PartitionedCallOp has `Tin` and `Tout` attributes for input/output
623 // types, that must be in sync with updated function signature.
624
625 // 3. Update input types for the indirect function calls.
626 if (is_indirect_call) {
627 RemovePushedDownConstInputTypes(specialization, func_node,
628 specialized_func_node);
629 }
630
631 // 4. Update output types for the indirect function call. It's unsafe to
632 // change the number of outputs for the fetch nodes, so we just skip them.
633 if (is_indirect_call && !specialization.is_in_fetch_set) {
634 RemoveUnusedOutputsTypes(specialization, func_node, specialized_func_node);
635 }
636
637 // 5. Remove custom gradient annotation.
638 specialized_func_node->mutable_attr()->erase("_gradient_op_type");
639
640 return OkStatus();
641 }
642
643 Status InitializeFunctionSpecializationSignature(
644 const NodeDef& func_node, const FunctionDef& func,
645 const AttrSlice& func_instantiation_attr,
646 const FunctionOptimizerContext& ctx, FunctionSpecializationSignature* sig) {
647 DCHECK(sig->const_inputs.empty());
648 DCHECK(sig->active_outputs.empty());
649
650 sig->func_name = func.signature().name();
651 sig->is_in_fetch_set = ctx.IsFetchNode(func_node.name());
652 sig->active_outputs = GetActiveOutputs(func_node, ctx);
653
654 TF_RETURN_IF_ERROR(InstantiationTypeParameters(func, func_instantiation_attr,
655 &sig->type_parameters));
656 TF_RETURN_IF_ERROR(InstantiationBodyParameters(func, func_instantiation_attr,
657 &sig->body_parameters));
658
659 for (int i = 0; i < func_node.input_size(); ++i) {
660 const string& input = func_node.input(i);
661 if (IsControlInput(input)) break;
662 if (ctx.IsTrulyConst(input)) {
663 sig->const_inputs.emplace(i, input);
664 }
665 }
666
667 return OkStatus();
668 }
669
670 // Creates a name for the function specialization. The name of the function,
671 // the name of the node instantiating it, and the Grappler item id together
672 // should generate a unique function name. The meta optimizer might create
673 // multiple Grappler items for the same graph when optimizing functions, but it
674 // is guaranteed that they all will have unique ids.
675 string SpecializedFunctionName(const FunctionOptimizerContext& ctx,
676 const FunctionDef& func,
677 const NodeDef& func_node) {
678 return absl::Substitute(
679 "$0_specialized_for_$1_at_$2", func.signature().name(),
680 absl::StrReplaceAll(func_node.name(), {{"/", "_"}}), ctx.item().id);
681 }
682
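// Specializes the function for the given call site: pushes truly-const inputs
// into the function body, prunes unused outputs (unless the caller is in the
// fetch set), adds the specialized FunctionDef to the library, and rewrites the
// caller node to call it. Reuses an existing specialization if one was already
// created for an identical signature.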
683 Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
684 FunctionOptimizerContext* ctx,
685 GraphDef* optimized_graph) {
686 VLOG(2) << "Specialize function call: " << SummarizeNodeDef(func_node);
687
688 const AttrSlice func_instantiation_attr =
689 FunctionInstantiationAttributes(func, func_node);
690
691 FunctionSpecializationSignature signature;
692 TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
693 func_node, func, func_instantiation_attr, *ctx, &signature));
694
695 // Check if function was already specialized for identical context.
696 const FunctionSpecialization* already_specialized =
697 ctx->FindFunctionSpecialization(signature);
698
699 if (already_specialized) {
700 VLOG(2) << "Function was already specialized in identical context: "
701 "specialized_name="
702 << already_specialized->specialized_func_name;
703
704 // Add a function call node for the specialized function.
705 NodeDef* specialized_func_node = optimized_graph->add_node();
706 *specialized_func_node = func_node;
707
708 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
709 func, func_node, *already_specialized, specialized_func_node));
710
711 ctx->AddTensorMapping(specialized_func_node->name(), *already_specialized);
712
713 return OkStatus();
714 }
715
716 // Add a new specialized function definition to the library.
717 const auto& flib = ctx->function_library();
718
719 // Make a GrapplerFunctionItem and convert it back to FunctionDef after
720 // pushing all constant inputs into the function body.
721 GrapplerFunctionItem item;
722 TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
723 func, func_instantiation_attr, flib, ctx->graph_version(), &item));
724
725 // Push const inputs into the function body, and keep track of their control
726 // dependencies.
727 absl::flat_hash_set<string> const_inputs;
728 absl::flat_hash_set<string> control_deps;
729 TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
730 &control_deps));
731
732 // Remove function outputs that do not have any consumers. We can't safely
733 // update outputs for the fetch nodes, so we just skip them.
734 std::vector<std::pair<int, int>> output_mapping;
735 if (!signature.is_in_fetch_set) {
736 int num_func_outputs = item.output_size();
737
738 absl::flat_hash_set<int> remove;
739 for (int i = 0; i < num_func_outputs; ++i) {
740 if (!signature.active_outputs.count(i)) remove.insert(i);
741 }
742
743 TF_RETURN_IF_ERROR(RemoveFunctionOutputs(remove, &item, &output_mapping));
744 }
745
746 // TODO(ezhulenev): Push down known input shapes.
747 FunctionDef specialized_func;
748 TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));
749
750 // Find a name for specialized function.
751 const string specialized_func_name =
752 SpecializedFunctionName(*ctx, func, func_node);
753 if (flib.Contains(specialized_func_name)) {
754 // NOTE(ezhulenev): This should never happen. If it happens, it's a sign of
755 // a serious internal error, that must be investigated.
756 return errors::Internal("Created duplicate function specialization");
757 }
758
759 specialized_func.mutable_signature()->set_name(specialized_func_name);
760 auto* specialized_attr = specialized_func.mutable_attr();
761 (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
762
763 // Add specialized function to the library.
764 TF_RETURN_IF_ERROR(ctx->function_library().AddFunctionDef(specialized_func));
765
766 // Add a function call node for the specialized function.
767 NodeDef* specialized_func_node = optimized_graph->add_node();
768 *specialized_func_node = func_node;
769
770 FunctionSpecialization func_specialization = {
771 specialized_func_name, signature.is_in_fetch_set, const_inputs,
772 control_deps, signature.active_outputs, output_mapping};
773
774 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
775 func, func_node, func_specialization, specialized_func_node));
776
777 ctx->AddSpecializedFunction(signature, func_specialization);
778 ctx->AddTensorMapping(specialized_func_node->name(), func_specialization);
779
780 return OkStatus();
781 }
782
783 // -------------------------------------------------------------------------- //
784 // Inline function calls into a graph using function inlining implementation
785 // from common_runtime:
786 //
787 // 1) Convert GraphDef to Graph.
788 // 2) Inline function calls.
789 // 3) Convert Graph back to the GraphDef.
790
791 constexpr const char* const kLowerUsingSwitchMergeAttr =
792 LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
793 constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
794 LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
795
796 using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
797 using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
798
799 // Checks if boolean attribute is defined and its value is 'true'.
800 bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
801 bool match;
802 bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
803 return found && match;
804 }
805
806 // Checks if string attribute is defined and it's not empty.
807 bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
808 const string& value = GetNodeAttrString(n->attrs(), attr_name);
809 return !value.empty();
810 }
811
812 bool LowerUsingSwitchMergeIsOn(const Node* n) {
813 return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr);
814 }
815
816 bool LowerAsMultiDeviceFunctionIsOn(const Node* n) {
817 return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr);
818 }
819
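// Returns true if the node is marked for XLA compilation or TPU replication
// (via the '_xla_compile_id', '_tpu_replicate', or kXlaMustCompileAttr
// attribute).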
820 bool MarkedForXlaCompilation(const NodeDef& n) {
821 auto is_enabled = [&](std::string attr_name) -> bool {
822 auto it = n.attr().find(attr_name);
823 return it != n.attr().end() && (!it->second.s().empty() || it->second.b());
824 };
825 return is_enabled("_xla_compile_id") || is_enabled("_tpu_replicate") ||
826 is_enabled(kXlaMustCompileAttr);
827 }
828
829 const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
830 static const auto* exemption = new absl::flat_hash_set<string>(
831 {// LINT.IfChange
832 // Op types that should not run in program order, e.g. because they need
833 // to run asynchronously to avoid deadlock.
834 "CollectiveGather", "CollectiveReduce", "CollectiveBcastSend",
835 "CollectiveBcastRecv", "CollectiveBcastSendV2", "CollectiveBcastRecvV2",
836 "NcclAllReduce", "Send", "Recv", "CollectiveAssignGroupsV2",
837 "CollectiveInitializeCommunicator",
838
839 // Legacy random ops.
840 // See details in tensorflow/python/framework/auto_control_deps.py.
841 "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
842 "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
843 "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
844 "RandomPoissonV2",
845
846 // ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
847 // but it can't generate any observable side-effect.
848 "ReadVariableOp",
849
850 // CudnnRNN ops are stateful but they can't generate any observable
851 // side-effect.
852 "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
853 "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
854
855 // TPUEmbedding EnqueueOps are stateful but this is only between ops with
856 // the same device_ordinal on the same host.
857 "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
858 "EnqueueTPUEmbeddingSparseTensorBatch",
859 "EnqueueTPUEmbeddingRaggedTensorBatch",
860 "EnqueueTPUEmbeddingArbitraryTensorBatch"
861
862 // SaveV2 and RestoreV2 should be allowed to operate in parallel on
863 // multiple hosts.
864 "SaveV2",
865 "RestoreV2"
866
867 // InfeedEnqueue are stateful but should not be serialized for the
868 // input pipeline
869 "InfeedEnqueue",
870 "InfeedEnqueueTuple"});
871 // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
872 return exemption->contains(op);
873 }
874
875 // Validates that all side effects inside the function body will be executed
876 // after function inlining. We do it by looking for a path from stateful ops to
877 // one of the output control sources.
878 //
879 // When the function is executed via FunctionLibraryRuntime we do not have to
880 // check this, because `PruneFunctionBody` has special pruning rules for stateful ops.
881 Status ValidateSideEffectsExecution(
882 const FunctionBody& fbody, OutputControlSource output_control_source,
883 bool has_outgoing_control_edges,
884 bool validate_outgoing_control_edge = true) {
885 // Find all nodes that can produce side effects in the function body graph. We
886 // use 'is_stateful()' bit as an approximation of "has side effects" property.
887 std::vector<const Node*> fbody_side_effects;
888 absl::c_copy_if(
889 fbody.graph->nodes(), std::back_inserter(fbody_side_effects),
890 [](const Node* n) {
891 return n->op_def().is_stateful() && !n->IsArg() && !n->IsRetval() &&
892 !IsExemptFromSideEffectsExecutionValidation(n->type_string());
893 });
894
895 // When the graph is executed in a TF-2.0 context with automatic control
896 // dependency tracking, the absence of outgoing control edges indicates that no
897 // one is interested in observing side effects, so it is safe to inline the
898 // function body, even if some side effects will not be executed.
899 if (!fbody_side_effects.empty() && !has_outgoing_control_edges) {
900 const string error_message =
901 "Can't guarantee execution of function side-effects after inlining. "
902 "Function call node has no outgoing control edges.";
903 if (validate_outgoing_control_edge) {
904 return errors::Internal(error_message);
905 } else {
906 VLOG(3) << error_message;
907 }
908 }
909
910 // Find all nodes in the function body that will be used as control sources.
911 absl::flat_hash_set<const Node*> control_sources;
912 if (output_control_source == OutputControlSource::kDataOutputs) {
913 control_sources = {fbody.ret_nodes.begin(), fbody.ret_nodes.end()};
914 } else if (output_control_source == OutputControlSource::kControlOutputs) {
915 control_sources = {fbody.control_ret_nodes.begin(),
916 fbody.control_ret_nodes.end()};
917 }
918
919 for (const Node* side_effect : fbody_side_effects) {
920 VLOG(4) << "Check that node " << side_effect->name()
921 << " will execute after inlining.";
922 bool will_execute = false;
923
924 const auto is_control_source = [&](const Node* n) -> void {
925 const auto it = control_sources.find(n);
926 if (it != control_sources.end()) {
927 VLOG(4) << "Found a path to control source: " << side_effect->name()
928 << " ---> " << (*it)->name();
929 will_execute = true;
930 }
931 };
932
933 DFSFrom(*fbody.graph, {side_effect}, /*enter=*/is_control_source,
934 /*leave=*/{}, NodeComparatorName{});
935
936 if (!will_execute) {
937 return errors::Internal(
938 "Can't guarantee execution of a side-effectful node, that is not "
939 "reachable from function control source. Function body node: ",
940 SummarizeNode(*side_effect));
941 }
942 }
943
944 return OkStatus();
945 }
946
947 // Validates that no dead tensor can reach function output.
948 Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def,
949 const FunctionBody& fbody) {
950 absl::flat_hash_set<const Node*> output_nodes = {fbody.ret_nodes.begin(),
951 fbody.ret_nodes.end()};
952
953 // Find all nodes that can produce dead tensors.
954 std::vector<const Node*> dead_tensor_sources;
955 for (const Node* n : fbody.graph->nodes()) {
956 if (n->IsSwitch()) {
957 VLOG(4) << "Add dead tensors source. Switch node: " << n->name();
958 dead_tensor_sources.push_back(n);
959 continue;
960 }
961
962 // Native function call can also produce dead tensors if the function body
963 // has mergeless switches.
964 const FunctionDef* fdef = flib_def.Find(n->type_string());
965 if (fdef != nullptr) {
966 std::unique_ptr<FunctionBody> nested_fbody;
967
968 NameAttrList func;
969 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(n->def(), &func));
970 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
971 &flib_def, &nested_fbody));
972
973 if (!ValidateNoDeadOutputs(flib_def, *nested_fbody).ok()) {
974 VLOG(4) << "Add dead tensors source. Function call: " << func.name()
975 << " node=" << n->name();
976 dead_tensor_sources.push_back(n);
977 }
978 }
979 }
980
981 for (const Node* dead_tensor_source : dead_tensor_sources) {
982 bool has_dead_output = false;
983
984 const auto is_output_node = [&](const Node* n) -> void {
985 const auto it = output_nodes.find(n);
986 if (it != output_nodes.end()) {
987 VLOG(4) << "Found a path to output node from dead tensor source: "
988 << dead_tensor_source->name() << " ---> " << (*it)->name();
989 has_dead_output = true;
990 }
991 };
992
993 // Stop DFS traversal at a Merge node or if already found a dead output.
994 const auto stop_traversal = [&has_dead_output](const Edge& edge) -> bool {
995 return !edge.src()->IsMerge() || has_dead_output;
996 };
997
998 DFSFrom(*fbody.graph, {dead_tensor_source}, /*enter=*/is_output_node,
999 /*leave=*/{}, NodeComparatorName{},
1000 /*edge_filter=*/stop_traversal);
1001
1002 if (has_dead_output) {
1003 return errors::Internal(
1004 "Can't inline a function with dead outputs. Dead tensor source: ",
1005 SummarizeNode(*dead_tensor_source));
1006 }
1007 }
1008
1009 return OkStatus();
1010 }
1011
1012 // Makes an instance of FunctionBody for inlining from a Node.
1013 Status MakeFunctionBodyForInlining(const Node& node,
1014 const FunctionLibraryDefinition& flib_def,
1015 std::unique_ptr<FunctionBody>* fbody) {
1016 VLOG(3) << "Make function body for inlining: " << SummarizeNode(node);
1017
1018 // Finds a FunctionDef in a library and verifies that it exists.
1019 const auto find_fdef = [&flib_def, &node](
1020 const string& name,
1021 const FunctionDef** fdef) -> Status {
1022 if ((*fdef = flib_def.Find(name)) == nullptr) {
1023 return errors::Internal(
1024 "Was not able to find a function definition (name=", name,
1025 ") for a function call: ", SummarizeNode(node));
1026 }
1027 return OkStatus();
1028 };
1029
1030 // SymbolicGradient is a special "function call" op, which has been
1031 // deprecated for a while, but we still support it for compatibility reasons.
1032 if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
1033 NameAttrList func;
1034 TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), kFuncAttr, &func));
1035
1036 const string grad = flib_def.FindGradient(func.name());
1037
1038 if (!grad.empty()) {
1039 // Function has a custom gradient registered in a library.
1040 const FunctionDef* grad_fdef;
1041 TF_RETURN_IF_ERROR(find_fdef(grad, &grad_fdef));
1042
1043 VLOG(4) << "Instantiate a custom SymbolicGradient: gradient=" << grad
1044 << " (function=" << func.name() << ")";
1045 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1046 *grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1047
1048 } else if (flib_def.Find(func.name()) == nullptr) {
1049 // Function is not really a function, but a primitive op.
1050 gradient::Creator creator;
1051 TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
1052 if (creator == nullptr) {
1053 return errors::InvalidArgument("No gradient is defined for ",
1054 func.name());
1055 }
1056 FunctionDef grad_fdef;
1057 TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
1058
1059 VLOG(4) << "Instantiate a SymbolicGradient for a primitive op: "
1060 << func.name();
1061 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1062 grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1063
1064 } else {
1065 // Build a gradient graph from the function body.
1066 const FunctionDef* fdef;
1067 TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1068
1069 VLOG(4) << "Instantiate a SymbolicGradient for a function: "
1070 << func.name();
1071 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1072 &flib_def, fbody));
1073 *fbody = SymbolicGradient(**fbody);
1074 }
1075
1076 } else {
1077 NameAttrList func;
1078 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node.def(), &func));
1079 const FunctionDef* fdef;
1080 TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1081
1082 VLOG(4) << "Instantiate a function call: function=" << func.name();
1083 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1084 &flib_def, fbody));
1085 }
1086
1087 return OkStatus();
1088 }
1089
1090 // Adds control edges from each data input to the 'caller' to enforce strict
1091 // input semantics (all inputs are ready and alive). This is required when:
1092 //
1093 // 1) The function takes resources as inputs, and it doesn't have incoming
1094 // control edges. In Tensorflow v2 context (eager mode) this should never
1095 // happen, because automatic control dependencies tracking will add a
1096 // control edge from the last op touching the resource. However such graphs
1097 // might be produced by legacy v1 code without automatic dependency
1098 // tracking. In this case strict function call semantics is required for
1099 // enforcing side effects execution order.
1100 //
1101 // 2) One of the inputs is consuming Enter[is_constant=true] node, in which
1102 // case it will be always alive, and potentially can lead to partial
1103 // function execution after the last loop execution.
1104 //
1105 // Both of these cases would be considered illegal by construction in Tensorflow
1106 // V2, however we have to guarantee that graphs constructed with Tensorflow V1
1107 // will produce correct results.
1108 void AddStrictInputSemantics(Node* caller, Graph* g) {
1109 absl::flat_hash_set<const Node*> existing_control_sources;
1110 for (const Edge* edge : caller->in_edges()) {
1111 if (edge->IsControlEdge()) {
1112 existing_control_sources.insert(edge->src());
1113 }
1114 }
1115
1116 const bool has_incoming_control_edges = !existing_control_sources.empty();
1117
1118 const bool has_resource_input =
1119 absl::c_any_of(caller->input_types(),
1120 [](const DataType dtype) { return dtype == DT_RESOURCE; });
1121
1122 const bool has_constant_enter_input =
1123 absl::c_any_of(caller->in_edges(), [](const Edge* edge) {
1124 Node* src = edge->src();
1125 return src->IsEnter() && CheckBoolAttr(src, "is_constant");
1126 });
1127
1128 const bool requires_strict_semantics =
1129 (!has_incoming_control_edges && has_resource_input) || // Case #1
1130 (has_constant_enter_input); // Case #2
1131 if (!requires_strict_semantics) return;
1132
1133 std::set<const Node*> data_inputs;
1134 for (const Edge* edge : caller->in_edges()) {
1135 if (!edge->IsControlEdge() &&
1136 !existing_control_sources.contains(edge->src())) {
1137 data_inputs.insert(edge->src());
1138 }
1139 }
1140
1141 VLOG(3) << "Add control edges from all data inputs to enforce strict "
1142 "semantics with regard to function inputs";
1143
1144 // Do not add control edges from placeholders, because that would prevent
1145 // pruning, and they can't produce any side effects anyway.
1146 const auto is_placeholder = [](const Node* node) -> bool {
1147 return node->type_string() == "Placeholder";
1148 };
1149
1150 for (const Node* node : data_inputs) {
1151 if (is_placeholder(node)) continue;
1152 g->AddControlEdge(g->FindNodeId(node->id()), caller,
1153 /*allow_duplicates=*/true);
1154 }
1155 }
1156
1157 // Adds a control edge from a frame node if the 'caller' is executing inside a
1158 // While loop (see control_flow.h for the 'frame' node explanation).
1159 void AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo>& info,
1160 Node* caller, Graph* g) {
1161 // All nodes added to the graph by v2 control flow lowering and function
1162 // inlining are guaranteed to have control edges to nested function calls.
1163 int info_size = info.size();
1164 if (caller->id() >= info_size) return;
1165
1166 // Check if a lowered node is executing inside a while loop.
1167 const Node* frame = info[caller->id()].frame;
1168 const bool is_in_while_loop = frame->id() != Graph::kSourceId;
1169 if (!is_in_while_loop) return;
1170
1171 // Check if a node already has an incoming control edge. All incoming edges
1172 // must be from the same execution frame (executor.cc invariant), so if we
1173 // already have an incoming control edge, it's guaranteed that it will "carry"
1174 // the same frame as all regular inputs.
1175 const bool has_incoming_control_edges =
1176 absl::c_any_of(caller->in_edges(),
1177 [](const Edge* edge) { return edge->IsControlEdge(); });
1178 if (has_incoming_control_edges) return;
1179
1180 VLOG(3) << "Add a frame forwarding control edge: from=" << frame->name()
1181 << " to=" << caller->name();
1182 Node* enter = g->FindNodeId(frame->id());
1183 bool is_constant_enter = enter->attrs().Find("is_constant")->b();
1184 if (is_constant_enter) {
1185 // Enter[is_constant=true] is always alive. So we directly add a control
1186 // edge from that.
1187 g->AddControlEdge(enter, caller);
1188 } else {
1189 // Enter[is_constant=false] activates nodes only in 0th iteration so we
1190 // add an edge from the Merge node which is activated in every iteration.
1191 // A non-constant Enter node must have an edge to a Merge node.
1192 auto it = absl::c_find_if(enter->out_edges(), [](const Edge* e) {
1193 return !e->IsControlEdge() && e->dst()->IsMerge();
1194 });
1195 if (it != enter->out_edges().end()) {
1196 g->AddControlEdge((*it)->dst(), caller);
1197 } else {
1198 LOG(WARNING) << "Enter[is_constant=false] node: " << enter->name()
1199 << " does not have an outgoing edge to a Merge.";
1200 }
1201 }
1202 }
1203
1204 // Inlines all function calls that are safe for inlining into the main graph.
1205 // Also lowers control flow V2 ops (functional If/While) into the V1 low level
1206 // ops (Switch/Merge/...).
1207 //
1208 // Runs a placer after inlining, to keep all nodes in a graph placed.
1209 Status InlineFunctionCalls(const GrapplerItem& item,
1210 const RewriterConfig::Toggle opt_level,
1211 const bool lower_control_flow,
1212 GraphDef* output_graph) {
1213 bool is_aggressive = opt_level == RewriterConfig::AGGRESSIVE;
1214 VLOG(2) << "Inline function calls: grappler_item_id=" << item.id
1215 << " (aggressive_mode=" << is_aggressive << ")";
1216
1217 FunctionLibraryDefinition flib_def =
1218 FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library());
1219 std::unique_ptr<Graph> graph = std::make_unique<Graph>(flib_def);
1220
1221 GraphConstructorOptions graph_constructor_options;
1222 graph_constructor_options.allow_internal_ops = true;
1223 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_constructor_options,
1224 item.graph, graph.get()));
1225
1226 using NodeNames = absl::flat_hash_set<absl::string_view>;
1227 NodeNames fetch_nodes;
1228 fetch_nodes.reserve(item.fetch.size());
1229 for (const string& fetch : item.fetch) {
1230 fetch_nodes.insert(ParseTensorName(fetch).node());
1231 }
1232 NodeNames keep_nodes(item.keep_ops.begin(), item.keep_ops.end());
1233
1234 std::vector<string> inlined_function_names;
1235
1236 // Do not inline function call nodes that are part of a feed set.
1237 NodeNames feed_nodes;
1238 feed_nodes.reserve(item.feed.size());
1239 for (const std::pair<std::string, Tensor>& feed : item.feed) {
1240 feed_nodes.insert(ParseTensorName(feed.first).node());
1241 }
1242
1243 // If a function call is inside a While loop, it must have an incoming control
1244 // edge, because it will be used to pass the execution frame into the function
1245 // body. All nodes without inputs in the function body (e.g. Const and NoOp)
1246 // will get an extra control edge from the 'input_control_node'.
1247 std::vector<ControlFlowInfo> control_flow_info;
1248 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &control_flow_info));
1249
1250 // Function inlining always adds new nodes to the end of the list, so we keep
1251 // iterating until we are out of nodes.
1252 for (int i = 2; i < graph->num_node_ids(); ++i) {
1253 Node* n = graph->FindNodeId(i);
1254 if (n == nullptr) continue; // deleted node
1255
1256 // Special case for lowering functional control flow ops. We do not rely on
1257 // LowerFunctionalOpsPass because in Grappler we have to be more restrictive
1258 // about what type of function calls we are allowed to inline.
1259 if (lower_control_flow && LowerUsingSwitchMergeIsOn(n)) {
1260 VLOG(2) << "Lower functional control flow op: " << SummarizeNode(*n);
1261 AddStrictInputSemantics(n, graph.get());
1262 AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
1263
1264 if (n->IsIfNode()) {
1265 TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), false));
1266 } else if (n->IsCaseNode()) {
1267 TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), false));
1268 } else if (n->IsWhileNode()) {
1269 TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), &flib_def, false));
1270 }
1271 continue;
1272 }
1273
1274 // Skip nodes that are not function calls.
1275 if (!IsFunctionCall(flib_def, *n)) continue;
1276 // Skip function calls that we plan to compile later.
1277 if (MarkedForXlaCompilation(n->def())) continue;
1278 // Skip nodes in a feed set.
1279 if (feed_nodes.contains(n->name())) continue;
1280
1281 // Function body that we will inline into the main graph. It can be a
1282 // function instantiation, or a gradient function instantiated from
1283 // SymbolicGradient op.
1284 std::unique_ptr<FunctionBody> fbody;
1285 TF_RETURN_IF_ERROR(MakeFunctionBodyForInlining(*n, flib_def, &fbody));
1286
1287 InlineFunctionBodyOptions inline_options;
1288 // Ignore '_noinline' flag in aggressive mode.
1289 inline_options.ignore_noinline = is_aggressive;
1290
    // Function calls created after inlining If/While ops are always inlined as
    // multi-device functions and are not required to pass the additional
    // Grappler validations (the side-effects execution validation below).
    bool force_inline_as_multi_device = LowerAsMultiDeviceFunctionIsOn(n);

    // `PartitionedCall` is the TF-2.0 function call mechanism for multi-device
    // functions:
    // a) The function can be placed on multiple devices.
    // b) Automatic control dependency tracking guarantees that all
    //    side-effectful nodes in the function have a path to one of the
    //    control outputs. Control outputs and control edges between
    //    side-effectful (stateful) nodes are used to explicitly mark the nodes
    //    that must execute, and to define their execution order.
    if (n->IsPartitionedCall() || force_inline_as_multi_device) {
      inline_options.output_control_src = OutputControlSource::kControlOutputs;
      inline_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::MultiDevice();
    } else {
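      // For legacy (non-PartitionedCall) function calls we conservatively use
      // the data outputs as the control source and place the whole function
      // body on the caller's device.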
      inline_options.output_control_src = OutputControlSource::kDataOutputs;
      inline_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::SingleDevice();
    }

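    // Decide what should happen to the function call node itself after
    // inlining: fetched callers must stay fetchable (their outputs can still
    // be requested), callers listed in `keep_ops` must stay targetable (they
    // can still be run), and everything else can be removed.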
    if (fetch_nodes.contains(n->name())) {
      inline_options.keep_caller_node = KeepCallerNode::kFetchable;
    } else if (keep_nodes.contains(n->name())) {
      inline_options.keep_caller_node = KeepCallerNode::kTargetable;
    } else {
      inline_options.keep_caller_node = KeepCallerNode::kDoNotKeep;
    }

    // Basic validation rules defined in common_runtime, shared by all
    // function calls.
    Status can_inline_function_call =
        ValidateInlining(n, fbody.get(), inline_options);

    // Additional validation rules defined only in Grappler.
    // TODO(ezhulenev): Move it to common_runtime InlineFunctionBodyOptions?
    if (can_inline_function_call.ok()) {
      bool has_outgoing_control_edges = absl::c_any_of(
          n->out_edges(),
          [](const Edge* edge) { return edge->IsControlEdge(); });

      can_inline_function_call = ValidateSideEffectsExecution(
          *fbody, inline_options.output_control_src,
          has_outgoing_control_edges);

      if (!can_inline_function_call.ok() &&
          (is_aggressive || force_inline_as_multi_device)) {
        VLOG(2) << "Ignore error: "
                << can_inline_function_call.error_message();
        can_inline_function_call = OkStatus();
      }
    }
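    // A function body whose outputs can be dead (e.g. they pass through only
    // one branch of a Switch) cannot be inlined safely: deadness would leak
    // into the surrounding graph, which does not expect dead tensors outside
    // of control flow constructs.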
    if (can_inline_function_call.ok()) {
      can_inline_function_call = ValidateNoDeadOutputs(flib_def, *fbody);
    }

    if (can_inline_function_call.ok()) {
      VLOG(2) << "Inline function call node: " << n->name();
      AddStrictInputSemantics(n, graph.get());
      AddFrameForwardingControlEdge(control_flow_info, n, graph.get());

      TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, graph.get(), n,
                                            fbody.get(), inline_options));
      inlined_function_names.push_back(fbody->fdef.signature().name());

    } else {
      VLOG(2) << "Failed to inline function call node: "
              << can_inline_function_call.error_message();
    }
  }

  VLOG(4) << "Inlined " << inlined_function_names.size()
          << " function calls: " << absl::StrJoin(inlined_function_names, ", ");

  // ------------------------------------------------------------------------ //
  // Grappler receives the graph after the PRE_PLACEMENT, Placer, and
  // POST_PLACEMENT passes, so each node has a valid device assignment. After
  // function inlining and control flow V2 lowering we have to keep the graph
  // fully placed.

  if (inlined_function_names.empty()) {
    VLOG(3) << "Not placing graph after function inlining"
            << " (did not inline any of the function calls).";

  } else if (item.devices().empty()) {
    // If there are no devices available for the placer, we do not place the
    // graph after function inlining. This happens when Grappler is optimizing
    // the function library, or when a graph is optimized "offline", without an
    // active runtime session, for example as part of a batch job for graph
    // analysis/optimization. A GrapplerItem instantiated from a function
    // library doesn't have to be fully placed after all optimizations; it will
    // be placed by the function library runtime before execution.
    VLOG(3) << "Not placing graph after function inlining"
            << " (device set is empty)";

  } else {
    // If we are running in an active runtime session, Grappler will get the
    // graph after initial placing is done, and we should have devices for the
    // placer.
    VLOG(3) << "Run placer for the graph after function inlining. "
            << "Devices: [" << absl::StrJoin(item.devices(), ", ") << "]";

    DeviceSet device_set;                               // does not own devices
    std::vector<std::unique_ptr<Device>> fake_devices;  // owns fake devices

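    // The placer only needs device names and parsed device attributes to make
    // placement decisions; it never allocates anything on the devices, so
    // lightweight fake devices built from the names in `item.devices()` are
    // sufficient here.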
    for (const string& name : item.devices()) {
      auto device = std::make_unique<FakeDevice>(name);
      device_set.AddDevice(device.get());
      fake_devices.push_back(std::move(device));
    }

    Placer placer(graph.get(), item.id, &flib_def, &device_set);
    TF_RETURN_IF_ERROR(placer.Run());
  }

  graph->ToGraphDef(output_graph);
  return OkStatus();
}

// Restores tensor mapping after function specialization: all inputs must be
// connected to valid nodes.
void RestoreTensorMapping(const FunctionOptimizerContext& ctx,
                          GraphDef* optimized_graph) {
  if (ctx.tensor_mapping().empty()) return;

  // During function specialization, we might prune unused function outputs. We
  // need to "close the holes" that might appear in the function outputs.
  //
  // Example: prune unused output "f:1"
  //
  //   f = my_func[T=float](...)       f = my_func_specialized[T=float](...)
  //   a = Identity(f:0)          ->   a = Identity(f:0)
  //   b = Identity(f:2)               b = Identity(f:1)
  //
  // Tensor mapping (size=1): [f:2 -> f:1]
  for (NodeDef& node : *optimized_graph->mutable_node()) {
    for (int idx = 0; idx < node.input_size(); ++idx) {
      TensorId input_tensor = ParseTensorName(node.input(idx));
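      // Control inputs are always listed after all data inputs in a NodeDef,
      // so we can stop scanning at the first control input.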
      if (input_tensor.index() == Graph::kControlSlot) break;

      auto mapping = ctx.tensor_mapping().find(input_tensor);
      if (mapping != ctx.tensor_mapping().end()) {
        node.set_input(idx, TensorIdToString(mapping->second));
      }
    }
  }
}

}  // namespace

Status FunctionOptimizer::RunFunctionOptimizerPass(
    const GrapplerItem& item, GraphDef* optimized_graph) const {
  VLOG(3) << "Run function optimizer pass: grappler_item_id=" << item.id;

  // Inline all function calls into the graph using the common_runtime/function
  // implementation (see the `InlineFunctionBody` documentation).
  GraphDef graph_after_inlining;
  TF_RETURN_IF_ERROR(InlineFunctionCalls(item, opt_level_, lower_control_flow_,
                                         &graph_after_inlining));

  // Specialize function calls that we could not inline.
  FunctionOptimizerContext ctx(item, opt_level_, graph_after_inlining);

  for (const NodeDef& node : graph_after_inlining.node()) {
    // Function specialization can modify the optimized graph only by adding
    // new nodes, so we can compare node counts to check whether the graph was
    // modified.
    const int num_nodes_before = optimized_graph->node_size();
    const auto is_graph_modified = [&]() {
      int num_nodes = optimized_graph->node_size();
      DCHECK_GE(num_nodes, num_nodes_before) << "Nodes should not be removed";
      return num_nodes > num_nodes_before;
    };

    // Copy the node from `graph_after_inlining` to `optimized_graph`.
    const auto copy_node = [&]() { *optimized_graph->add_node() = node; };

    // Check whether the node is a function call (direct or indirect).
    const FunctionDef* func = FindFunctionCall(ctx, node);
    if (func == nullptr) {
      copy_node();
      continue;
    }

    const string& func_name = func->signature().name();

    // Specialize the function to its instantiation context if it has something
    // worth specializing.
    const bool specialization_worthy = IsParametrized(*func) ||
                                       HasTrulyConstInputs(node, ctx) ||
                                       HasUnusedOutputs(node, *func, ctx);

    // Do not specialize if the function has a custom gradient or is marked
    // '_nospecialize'.
    const string grad_func = ctx.function_library().FindGradient(func_name);
    const bool no_specialize =
        !grad_func.empty() || ctx.IsFeedNode(node.name()) ||
        MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node);

    if (specialization_worthy && !no_specialize) {
      // TODO(ezhulenev): Specialize function call if input has a known shape.
      // Specialize the function body for its instantiation attributes and
      // inputs.
      Status status = SpecializeFunction(node, *func, &ctx, optimized_graph);
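      // If specialization failed after it already added nodes to the graph, we
      // cannot safely fall back to copying the original node, so propagate the
      // error. If nothing was added, we just keep the original function call.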
      if (!status.ok() && is_graph_modified()) {
        return status;
      } else if (!status.ok() && !is_graph_modified()) {
        VLOG(3) << "Skip specialization error: " << status.error_message();
        copy_node();
      }
      continue;
    } else {
      VLOG(2) << "Skip function specialization: " << func->signature().name();
      copy_node();
    }
  }

  RestoreTensorMapping(ctx, optimized_graph);

  // Preserve the graph version.
  *optimized_graph->mutable_versions() = item.graph.versions();
  // Prune unreachable functions from the library.
  *optimized_graph->mutable_library() =
      PruneFunctionLibrary(ctx.function_library(), *optimized_graph);

  return OkStatus();
}

Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
                                   GraphDef* optimized_graph) {
  // Nothing to do here.
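  // errors::Aborted is used to tell the meta-optimizer that this pass made no
  // changes; it is not treated as a hard failure.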
  if (item.graph.library().function_size() == 0) {
    return errors::Aborted("Nothing to do.");
  }

  TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(item, optimized_graph));

  return OkStatus();
}

}  // end namespace grappler
}  // end namespace tensorflow