1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
24 #include "tensorflow/compiler/jit/xla_cluster_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/graph/algorithm.h"
28 #include "tensorflow/core/graph/control_flow.h"
29 #include "tensorflow/core/graph/graph_node_util.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32
33 // ALGORITHM OVERVIEW
34 // ==================
35 //
36 // We map every output produced by each node in the TensorFlow graph (including
37 // control dependence) into an instance of the Predicate class. Instances of
38 // Predicate denote logical formulas and mapping a node `n` to a predicate
39 // `pred` implies that `n` is live whenever `pred` is true. Then we can deduce
// mismatching liveness in the inputs to a node by comparing the predicates
// those inputs are mapped to. The core logic of this pass resides in creating
// the map from TensorFlow nodes to predicates.
43 //
44 //
45 // MAPPING NODES TO PREDICATES, MODULO CYCLES
46 // ------------------------------------------
47 //
48 // If we ignore cycles for a moment, computing predicates is fairly
49 // straightforward. We traverse the graph in a topological order, mapping each
50 // node to a predicate based on the predicates its inputs are mapped to. For
51 // instance a Merge(X, Y) node will be mapped to OR(PredicateFor(X),
// PredicateFor(Y)). Roughly speaking, we abstractly interpret each node on
53 // the "liveness" domain, where values in the domain represent if a tensor
54 // carries a dead signal or not.
55 //
56 //
57 // DEALING WITH CYCLES
58 // -------------------
59 //
60 // We map Merge nodes that are the target of a backedge to AndRecurrence
61 // instances. An AndRecurrence with start() = S and step() = X, printed as
62 // {S,&,X}, *roughly* represents the infinite list of predicates
// [S,S&X,S&X&X,S&X&X&X, ...]. So {S,&,X} can be used to represent the predicate
64 // for Merge in a graph like:
65 //
//   Init
//     |
//     v
//   Merge <-----------+
//     |               |
//     v               |
//   Incr              |
//     |               |
//     v               |
//   Switch <- Cond    |
//     |               |
//     v (oidx: 1)     |
//     |               |
//     +---------------+
80 //
81 // Where S is the predicate for Init and X is the predicate that asserts that
82 // Cond is true. {S,&,X} states that Merge is live on the first "iteration" iff
83 // S is true, live on the second iteration iff "S&X" is true, live on the third
84 // iteration iff "S&X&X" is true etc. There is a subtlety here, S&X&X would
85 // normally be equivalent to S&X which isn't quite what we want to represent.
86 // Instead we want {S,&,X} to denote the infinite list [S, S&X,
87 // S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is
88 // true on iteration 0, 1, 2 respectively. This is made more precise in the
89 // comment on the AndRecurrence class.
90 //
91 // The general algorithm that deals with cycles does two topological-order
92 // iterations over the graph. On the first iteration it assigns a symbolic
93 // predicate to merge nodes with backedges. On the second iteration it tries
94 // to pattern match the predicates for the backedges of these merges and infer
95 // an AndRecurrence for the merge. In other words, we do a data flow analysis
96 // where the data-flow lattice has two elements, Symbolic and NonSymbolic with
97 // Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are
98 // sufficient to converge.
99 //
100 // We first do an optimistic analysis and, if it does not converge, we then fall
101 // back to a pessimistic analysis. The optimistic analysis assigns the same
102 // symbolic predicate to all the merge nodes whose preceding enter nodes have
103 // the same frame name on the first iteration. On the second iteration, if all
104 // the merge nodes are pattern matched into the same AndRecurrence predicate
105 // instance, the optimistic assignment of the same symbolic predicate is correct
106 // and the analyzed result is taken.
107 //
108 // Otherwise, if the optimistic analysis fails to converge, we then obtain the
109 // result by falling back to the pessimistic analysis which assigns a unique
110 // symbolic predicate to each merge on the first iteration. We still use
111 // symbolic predicates for merges for which we can't pattern match on the
112 // backedge predicate. This is conservatively correct.
113
114 namespace tensorflow {
115
116 namespace {
117
118 using se::port::StatusOr;
119
120 // Represents a logical predicate, used as described in the algorithm overview
121 // above.
122 class Predicate {
123 public:
124 enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol, kIntSymbol };
125
126 virtual string ToString() const = 0;
127
128 // An ID assigned to the Predicate at construction time. Conceptually like a
129 // pointer, except that it is stable across runs.
id() const130 int64_t id() const { return id_; }
131
132 virtual absl::Span<Predicate* const> GetOperands() const = 0;
133
134 virtual Kind kind() const = 0;
~Predicate()135 virtual ~Predicate() {}
136
137 // Invokes func on p and on all of its operands recursively. Does not invoke
138 // `func` on the same Predicate instance twice. Aborts the search if `func`
139 // returns true.
140 template <typename FunctionTy>
141 static void Visit(Predicate* p, const FunctionTy& func);
142
143 protected:
Predicate(int64_t id)144 explicit Predicate(int64_t id) : id_(id) {}
145
146 private:
147 const int64_t id_;
148
149 TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
150 };
151
152 // Represents a logical conjunction of a set of predicates.
153 class AndPredicate : public Predicate {
154 public:
AndPredicate(int64_t id,std::vector<Predicate * > operands)155 explicit AndPredicate(int64_t id, std::vector<Predicate*> operands)
156 : Predicate(id), operands_(std::move(operands)) {}
157
ToString() const158 string ToString() const override {
159 if (operands().empty()) {
160 return "#true";
161 }
162
163 std::vector<string> operands_str;
164 std::transform(operands().begin(), operands().end(),
165 std::back_inserter(operands_str),
166 [](Predicate* pred) { return pred->ToString(); });
167
168 return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
169 }
170
kind() const171 Kind kind() const override { return Kind::kAnd; }
172
GetOperands() const173 absl::Span<Predicate* const> GetOperands() const override {
174 return operands_;
175 }
operands() const176 absl::Span<Predicate* const> operands() const { return operands_; }
177
178 private:
179 std::vector<Predicate*> operands_;
180 };
181
182 // Represents a logical disjunction of a set of predicates.
183 class OrPredicate : public Predicate {
184 public:
OrPredicate(int64_t id,std::vector<Predicate * > operands)185 explicit OrPredicate(int64_t id, std::vector<Predicate*> operands)
186 : Predicate(id), operands_(std::move(operands)) {}
187
ToString() const188 string ToString() const override {
189 if (operands().empty()) {
190 return "#false";
191 }
192
193 std::vector<string> operands_str;
194 std::transform(operands().begin(), operands().end(),
195 std::back_inserter(operands_str),
196 [](Predicate* pred) { return pred->ToString(); });
197
198 return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
199 }
200
kind() const201 Kind kind() const override { return Kind::kOr; }
GetOperands() const202 absl::Span<Predicate* const> GetOperands() const override {
203 return operands_;
204 }
operands() const205 absl::Span<Predicate* const> operands() const { return operands_; }
206
207 private:
208 std::vector<Predicate*> operands_;
209 };
210
211 // Represents a logical negation of a set of predicates.
212 class NotPredicate : public Predicate {
213 public:
NotPredicate(int64_t id,Predicate * operand)214 explicit NotPredicate(int64_t id, Predicate* operand)
215 : Predicate(id), operands_({operand}) {}
216
ToString() const217 string ToString() const override {
218 return absl::StrCat("~", operand()->ToString());
219 }
220
kind() const221 Kind kind() const override { return Kind::kNot; }
operand() const222 Predicate* operand() const { return operands_[0]; }
GetOperands() const223 absl::Span<Predicate* const> GetOperands() const override {
224 return operands_;
225 }
226
227 private:
228 std::array<Predicate*, 1> operands_;
229 };
230
231 // Represents the liveness of an induction variable. For users inside the loop
232 // this represents the "current" liveness of the induction variable. For users
233 // outside the loop it represents the "last" liveness of the induction variable.
234 //
235 // More concretely, an and recurrence {S,&,X}<loop> represents the liveness of V
236 // in the following graph:
237 //
238 // V = Merge(S', V_NextIt)
239 // V = Op(V, X')
240 // V_NextIt = NextIteration(V)
241 //
242 // where Predicate(S') = S and Predicate(X') = X.
243 //
244 // `X` may contain symbolic predicates and the operations corresponding to these
245 // symbolic predicates are either in frame `loop` or outside it. The symbols
246 // that are inside frame `loop` are loop variant (i.e. can have different
247 // liveness in each loop iteration) and the symbols that are outside frame
248 // `loop` are loop invariant (i.e. have the same liveness across all
249 // iterations).
250 class AndRecurrencePredicate : public Predicate {
251 public:
AndRecurrencePredicate(int64_t id,Predicate * start,Predicate * step,std::vector<string> frame)252 explicit AndRecurrencePredicate(int64_t id, Predicate* start, Predicate* step,
253 std::vector<string> frame)
254 : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {}
255
start() const256 Predicate* start() const { return operands_[0]; }
step() const257 Predicate* step() const { return operands_[1]; }
frame() const258 absl::Span<const string> frame() const { return frame_; }
259
ToString() const260 string ToString() const override {
261 return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
262 "}<", absl::StrJoin(frame(), ";"), ">");
263 }
264
kind() const265 Kind kind() const override { return Kind::kAndRecurrence; }
266
GetOperands() const267 absl::Span<Predicate* const> GetOperands() const override {
268 return operands_;
269 }
270
271 private:
272 std::array<Predicate*, 2> operands_;
273 std::vector<string> frame_;
274 };
275
276 // Represents an uninterpreted symbol in a logical predicate.
277 //
278 // Two predicates are equivalent iff they are equivalent for all assignments to
279 // the symbols contained in them, i.e. predicates are forall qualified over
280 // symbols.
281 class SymbolPredicate : public Predicate {
282 public:
SymbolPredicate(int64_t id,TensorId tensor_id,bool must_be_true)283 explicit SymbolPredicate(int64_t id, TensorId tensor_id, bool must_be_true)
284 : Predicate(id),
285 tensor_id_(std::move(tensor_id)),
286 must_be_true_(must_be_true) {}
287
ToString() const288 string ToString() const override {
289 return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
290 : tensor_id_.ToString();
291 }
292
kind() const293 Kind kind() const override { return Kind::kSymbol; }
GetOperands() const294 absl::Span<Predicate* const> GetOperands() const override { return {}; }
295
296 // If `must_be_true()` is true this SymbolPredicate represents the proposition
297 // "tensor_id() is live and evaluates to true".
298 //
299 // If `must_be_true()` is false then this SymbolPredicate represents the
300 // proposition "tensor_id() is live (and may evaluate to any value)"
tensor_id() const301 TensorId tensor_id() const { return tensor_id_; }
must_be_true() const302 bool must_be_true() const { return must_be_true_; }
303
304 private:
305 TensorId tensor_id_;
306 bool must_be_true_;
307 };
308
309 // Represents an uninterpreted symbol in a logical predicate.
310 //
311 // Two predicates are equivalent iff they are equivalent for all assignments to
312 // the symbols contained in them, i.e. predicates are forall qualified over
313 // symbols.
314 class IntSymbolPredicate : public Predicate {
315 public:
IntSymbolPredicate(int64_t id,TensorId tensor_id,std::optional<int> must_have_value)316 explicit IntSymbolPredicate(int64_t id, TensorId tensor_id,
317 std::optional<int> must_have_value)
318 : Predicate(id),
319 tensor_id_(std::move(tensor_id)),
320 must_have_value_(must_have_value) {}
321
ToString() const322 string ToString() const override {
323 return must_have_value().has_value()
324 ? absl::StrCat(tensor_id_.ToString(), "=", *must_have_value_)
325 : tensor_id_.ToString();
326 }
327
kind() const328 Kind kind() const override { return Kind::kIntSymbol; }
GetOperands() const329 absl::Span<Predicate* const> GetOperands() const override { return {}; }
330
331 // If `must_have_value().has_value()` is true, then this IntSymbolPredicate
332 // represents the proposition "tensor_id() is live and evaluates to
333 // `*must_have_value()`".
334 //
335 // If `must_have_value().has_value()` is false, then this IntSymbolPredicate
336 // represents the proposition "tensor_id() is live (and may evaluate to any
337 // value)".
tensor_id() const338 TensorId tensor_id() const { return tensor_id_; }
must_have_value() const339 const std::optional<int>& must_have_value() const { return must_have_value_; }
340
341 private:
342 TensorId tensor_id_;
343 std::optional<int> must_have_value_;
344 };
345
346 template <typename FunctionTy>
Visit(Predicate * p,const FunctionTy & func)347 /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
348 absl::flat_hash_set<Predicate*> visited;
349 std::vector<Predicate*> stack;
350
351 stack.push_back(p);
352 visited.insert(p);
353
354 while (!stack.empty()) {
355 Predicate* current = stack.back();
356 stack.pop_back();
357 bool done = func(current);
358 if (done) {
359 return;
360 }
361 for (Predicate* op : current->GetOperands()) {
362 if (visited.insert(op).second) {
363 stack.push_back(op);
364 }
365 }
366 }
367 }
368
369 // Creates and owns Predicate instances. Simplifies predicates as it creates
370 // them.
371 class PredicateFactory {
372 public:
MakeAndPredicate(absl::Span<Predicate * const> operands)373 Predicate* MakeAndPredicate(absl::Span<Predicate* const> operands) {
374 return MakeAndOrImpl(operands, /*is_and=*/true);
375 }
376
MakeOrPredicate(absl::Span<Predicate * const> operands)377 Predicate* MakeOrPredicate(absl::Span<Predicate* const> operands) {
378 return MakeAndOrImpl(operands, /*is_and=*/false);
379 }
380
MakeNotPredicate(Predicate * pred)381 Predicate* MakeNotPredicate(Predicate* pred) {
382 auto it = make_not_predicate_cache_.find(pred);
383 if (it != make_not_predicate_cache_.end()) {
384 return it->second;
385 }
386
387 Predicate* result = MakeNotPredicateImpl(pred);
388
389 bool insert_successful =
390 make_not_predicate_cache_.insert({pred, result}).second;
391 (void)insert_successful;
392 DCHECK(insert_successful);
393
394 return result;
395 }
396
MakeAndRecurrencePredicate(Predicate * start,Predicate * step,std::vector<string> frame)397 Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step,
398 std::vector<string> frame) {
399 SignatureForAndRec signature(start, step, std::move(frame));
400 auto it = interned_and_rec_instances_.find(signature);
401 if (it != interned_and_rec_instances_.end()) {
402 return it->second.get();
403 }
404
405 std::unique_ptr<Predicate> new_pred = Make<AndRecurrencePredicate>(
406 std::get<0>(signature), std::get<1>(signature), std::get<2>(signature));
407 Predicate* new_pred_ptr = new_pred.get();
408 bool inserted =
409 interned_and_rec_instances_.emplace(signature, std::move(new_pred))
410 .second;
411 (void)inserted;
412 DCHECK(inserted);
413 return new_pred_ptr;
414 }
415
MakeSymbolPredicate(Node * node,int output_idx,bool must_be_true,Predicate ** predicate)416 Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true,
417 Predicate** predicate) {
418 TensorId tensor_id(node->name(), output_idx);
419
420 bool is_boolean_tensor =
421 BaseType(node->output_type(tensor_id.index())) == DT_BOOL;
422 TF_RET_CHECK(!must_be_true || is_boolean_tensor);
423
424 if (node->type_string() == "Const" && must_be_true) {
425 const TensorProto* proto = nullptr;
426 TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));
427
428 Tensor tensor(proto->dtype());
429 TF_RET_CHECK(tensor.FromProto(*proto));
430
431 *predicate = tensor.scalar<bool>()() ? MakeTrue() : MakeFalse();
432 return OkStatus();
433 }
434
435 SignatureForSymbol signature = {tensor_id, must_be_true};
436 auto it = interned_symbol_instances_.find(signature);
437 if (it == interned_symbol_instances_.end()) {
438 std::unique_ptr<Predicate> new_pred =
439 Make<SymbolPredicate>(tensor_id, must_be_true);
440 Predicate* new_pred_ptr = new_pred.get();
441 interned_symbol_instances_.emplace(std::move(signature),
442 std::move(new_pred));
443 *predicate = new_pred_ptr;
444 } else {
445 *predicate = it->second.get();
446 }
447
448 return OkStatus();
449 }
450
MakeSymbolPredicate(Node * node,int output_idx,std::optional<int> must_have_value,Predicate ** predicate)451 Status MakeSymbolPredicate(Node* node, int output_idx,
452 std::optional<int> must_have_value,
453 Predicate** predicate) {
454 TensorId tensor_id(node->name(), output_idx);
455
456 TF_RET_CHECK(BaseType(node->output_type(tensor_id.index())) == DT_INT32);
457
458 if (must_have_value.has_value() && node->type_string() == "Const") {
459 const TensorProto* proto = nullptr;
460 TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));
461
462 Tensor tensor(proto->dtype());
463 TF_RET_CHECK(tensor.FromProto(*proto));
464
465 *predicate = tensor.scalar<int32>()() == *must_have_value ? MakeTrue()
466 : MakeFalse();
467 return OkStatus();
468 }
469 SignatureForIntSymbol signature = {tensor_id, must_have_value};
470 auto it = interned_int_symbol_instances_.find(signature);
471 if (it == interned_int_symbol_instances_.end()) {
472 std::unique_ptr<Predicate> new_pred =
473 Make<IntSymbolPredicate>(tensor_id, must_have_value);
474 Predicate* new_pred_ptr = new_pred.get();
475 interned_int_symbol_instances_.emplace(std::move(signature),
476 std::move(new_pred));
477 *predicate = new_pred_ptr;
478 } else {
479 *predicate = it->second.get();
480 }
481
482 return OkStatus();
483 }
484
MakeTrue()485 Predicate* MakeTrue() { return MakeAndPredicate({}); }
MakeFalse()486 Predicate* MakeFalse() { return MakeOrPredicate({}); }
487
~PredicateFactory()488 ~PredicateFactory() {
489 DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?";
490 }
491
492 private:
MakeNotPredicateImpl(Predicate * pred)493 Predicate* MakeNotPredicateImpl(Predicate* pred) {
494 IncrementStackDepth stack_frame(this);
495 if (!stack_frame.HasOverflowed()) {
496 if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) {
497 return simplified;
498 }
499
500 // ~~A => A
501 if (auto* not_pred = dynamic_cast<NotPredicate*>(pred)) {
502 return not_pred->operand();
503 }
504 }
505
506 SignatureForNot signature = pred;
507 auto it = interned_not_instances_.find(signature);
508 if (it == interned_not_instances_.end()) {
509 std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
510 Predicate* new_pred_ptr = new_pred.get();
511 interned_not_instances_.emplace(signature, std::move(new_pred));
512 return new_pred_ptr;
513 } else {
514 return it->second.get();
515 }
516 }
517
SimplifyUsingDeMorgan(Predicate * pred)518 Predicate* SimplifyUsingDeMorgan(Predicate* pred) {
519 // ~(A & B & C & ...) => ~A | ~B | ~C | ~...
520 // ~(A | B | C | ...) -> ~A & ~B & ~C & ~...
521 Predicate::Kind kind = pred->kind();
522
523 if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) {
524 std::vector<Predicate*> new_operands;
525 absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands),
526 [&](Predicate* p) { return MakeNotPredicate(p); });
527 return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands)
528 : MakeOrPredicate(new_operands);
529 }
530
531 return nullptr;
532 }
533
534 template <typename PredicateT, typename... Args>
Make(Args &&...args)535 std::unique_ptr<Predicate> Make(Args&&... args) {
536 // If we ever expose the Predicate class outside this .cc file then we may
537 // want to make this hard to misuse (by accidentally passing in an arbitrary
538 // integer to the Predicate constructor for instance).
539 return std::unique_ptr<PredicateT>(
540 new PredicateT(id_counter_++, std::forward<Args>(args)...));
541 }
542
543 Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
544 Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
545 Predicate::Kind pred_kind);
546
547 // Predicate instances are interned, meaning that there is only a single
548 // instance of a Predicate object with a given content. This makes checking
549 // for structural equality super-cheap -- we can just compare pointers.
550 //
551 // We intern predicates by maintaining a map from the content of a Predicate
552 // to the only instance of said predicate we allow to exist in the
553 // interned_and_or_instances_, interned_not_instances_ and
554 // interned_symbol_instances_ fields. These maps also double up as storage
555 // for the owning pointers to predicate instances.
556
557 using SignatureForAndOr =
558 std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
559 using SignatureForNot = Predicate*;
560 using SignatureForAndRec =
561 std::tuple<Predicate*, Predicate*, std::vector<string>>;
562 using SignatureForSymbol = std::pair<SafeTensorId, bool>;
563 using SignatureForIntSymbol = std::pair<SafeTensorId, std::optional<int32>>;
564
565 struct HashSignatureForAndOr {
operator ()tensorflow::__anonfdc77a490111::PredicateFactory::HashSignatureForAndOr566 size_t operator()(const SignatureForAndOr& signature) const {
567 size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first);
568 for (Predicate* p : signature.second) {
569 hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p));
570 }
571 return hash;
572 }
573 };
574
575 struct HashSignatureForSymbol {
operator ()tensorflow::__anonfdc77a490111::PredicateFactory::HashSignatureForSymbol576 size_t operator()(const SignatureForSymbol& signature) const {
577 return Hash64Combine(SafeTensorId::Hasher()(signature.first),
578 ::tensorflow::hash<bool>()(signature.second));
579 }
580 };
581
582 struct HashSignatureForIntSymbol {
operator ()tensorflow::__anonfdc77a490111::PredicateFactory::HashSignatureForIntSymbol583 size_t operator()(const SignatureForIntSymbol& signature) const {
584 return Hash64Combine(
585 SafeTensorId::Hasher()(signature.first),
586 Hash64Combine(
587 ::tensorflow::hash<bool>()(signature.second.has_value()),
588 ::tensorflow::hash<int32>()(
589 signature.second.has_value() ? *signature.second : 0)));
590 }
591 };
592
593 // Used to limit recursion to avoid blowing up the stack and cap compile time.
594 class IncrementStackDepth {
595 public:
IncrementStackDepth(PredicateFactory * parent)596 explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) {
597 parent_->stack_depth_++;
598 }
599
HasOverflowed() const600 bool HasOverflowed() const {
601 const int kMaxStackDepth = 8;
602 return parent_->stack_depth_ >= kMaxStackDepth;
603 }
604
~IncrementStackDepth()605 ~IncrementStackDepth() { parent_->stack_depth_--; }
606
607 private:
608 PredicateFactory* parent_;
609 };
610
611 // A cache for the MakeNotPredicate function.
612 //
613 // NB! This is *not* the same as `interned_not_instances_`.
614 // `interned_not_instances_` maps ensures pointer identity for `NotPredicate`
615 // instances, i.e., it ensures there at most one instance of Not(predicate)
616 // for any given predicate whereas `make_not_predicate_cache_` simply caches
617 // the result of the `MakeNotPredicate` function. The values in
618 // `interned_not_instances_` are always instance of `NotPredicate` whereas the
619 // values in `make_not_predicate_cache_` may not be (for instance it will map
620 // Not(Not(A)) to A).
621 absl::flat_hash_map<Predicate*, Predicate*> make_not_predicate_cache_;
622
623 absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
624 HashSignatureForAndOr>
625 interned_and_or_instances_;
626 absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
627 interned_not_instances_;
628 absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
629 interned_and_rec_instances_;
630 absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
631 HashSignatureForSymbol>
632 interned_symbol_instances_;
633 absl::flat_hash_map<SignatureForIntSymbol, std::unique_ptr<Predicate>,
634 HashSignatureForIntSymbol>
635 interned_int_symbol_instances_;
636 int64_t id_counter_ = 0;
637 int stack_depth_ = 0;
638 };
639
MakeInternedAndOr(std::vector<Predicate * > simplified_ops,Predicate::Kind pred_kind)640 Predicate* PredicateFactory::MakeInternedAndOr(
641 std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
642 std::stable_sort(
643 simplified_ops.begin(), simplified_ops.end(),
644 [](Predicate* a, Predicate* b) { return a->id() < b->id(); });
645
646 auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
647 if (it != interned_and_or_instances_.end()) {
648 return it->second.get();
649 }
650
651 simplified_ops.shrink_to_fit();
652 // NB! Because we'll use a non-owning reference to simplified_ops in the
653 // key for interned_and_or_instances_ we need to be careful to std::move()
654 // it all the way through.
655 absl::Span<Predicate* const> operands_slice = simplified_ops;
656 std::unique_ptr<Predicate> new_pred =
657 pred_kind == Predicate::Kind::kAnd
658 ? Make<AndPredicate>(std::move(simplified_ops))
659 : Make<OrPredicate>(std::move(simplified_ops));
660
661 Predicate* new_pred_ptr = new_pred.get();
662 interned_and_or_instances_.emplace(
663 SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
664 return new_pred_ptr;
665 }
666
667 // Common code to create AndPredicate or OrPredicate instances.
MakeAndOrImpl(absl::Span<Predicate * const> operands,bool is_and)668 Predicate* PredicateFactory::MakeAndOrImpl(
669 absl::Span<Predicate* const> operands, bool is_and) {
670 Predicate::Kind pred_kind =
671 is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
672
673 IncrementStackDepth stack_frame(this);
674 if (stack_frame.HasOverflowed()) {
675 return MakeInternedAndOr(
676 std::vector<Predicate*>(operands.begin(), operands.end()), pred_kind);
677 }
678
679 Predicate::Kind other_pred_kind =
680 is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
681 absl::flat_hash_set<Predicate*> simplified_ops_set;
682 std::vector<Predicate*> simplified_ops;
683 for (Predicate* op : operands) {
684 // Simplify A&A => A and A|A => A.
685 if (!simplified_ops_set.insert(op).second) {
686 continue;
687 }
688
689 if (op->kind() == pred_kind) {
690 // "Inline" the operands of an inner And/Or into the parent And/Or.
691 for (Predicate* subop : op->GetOperands()) {
692 if (simplified_ops_set.insert(subop).second) {
693 simplified_ops.push_back(subop);
694 }
695 }
696 } else {
697 simplified_ops.push_back(op);
698 }
699 }
700
701 if (simplified_ops.size() == 1) {
702 return simplified_ops[0];
703 }
704
705 // Simplify "A&~A=>False" and "A|~A=>True".
706 absl::flat_hash_set<Predicate*> negated_ops;
707 for (Predicate* op : simplified_ops) {
708 if (negated_ops.count(op)) {
709 // Simple case:
710 //
711 // A & ~A & ... == False
712 // A | ~A | ... == True
713 return is_and ? MakeFalse() : MakeTrue();
714 }
715
716 Predicate* negated_op = MakeNotPredicate(op);
717 if (negated_op->kind() == pred_kind) {
718 // Slightly more complicated case:
719 //
720 // (~A | ~B | ~C) & A & B & C & ... ==
721 // ~(A & B & C) & (A & B & C) & ... == False
722 //
723 // (~A & ~B & ~C) | A | B | C | ... ==
724 // ~(A | B | C) | (A | B | C) | ... == True
725 if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) {
726 return simplified_ops_set.contains(p);
727 })) {
728 return is_and ? MakeFalse() : MakeTrue();
729 }
730 }
731 negated_ops.insert(negated_op);
732 }
733
734 // Simplify {S,&,X} & ~X & ... => S & ...
735 if (is_and) {
736 absl::flat_hash_set<Predicate*> to_remove;
737 std::vector<Predicate*> to_add;
738 for (Predicate* op : simplified_ops) {
739 if (op->kind() == Predicate::Kind::kAndRecurrence) {
740 auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
741 if (negated_ops.contains(and_rec->step())) {
742 // Remove and_rec and ~X and insert S. Note that checking the
743 // existence of ~X through negated_ops is sufficient since it makes
744 // sure the predicate is in the input operands. It does not need to
745 // be in simplified_ops if it was already cancelled out.
746 to_remove.insert(and_rec);
747 to_remove.insert(MakeNotPredicate(and_rec->step()));
748 to_add.push_back(and_rec->start());
749 }
750 }
751 }
752 auto it = simplified_ops.begin();
753 while (it != simplified_ops.end()) {
754 if (to_remove.contains(*it)) {
755 it = simplified_ops.erase(it);
756 } else {
757 ++it;
758 }
759 }
760 simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
761 }
762
763 // If all ops contain the same subop, then factor it out thanks to the
764 // distributive property. Such as:
765 // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
766 // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
767 //
768 // First find any predicates contained in all subops.
769 std::vector<Predicate*> common_inner_operands;
770 absl::flat_hash_set<Predicate*> common_inner_operands_set;
771 for (Predicate* op : simplified_ops) {
772 if (op->kind() != other_pred_kind) {
773 common_inner_operands.clear();
774 break;
775 }
776
777 if (common_inner_operands.empty()) {
778 common_inner_operands.insert(common_inner_operands.end(),
779 op->GetOperands().begin(),
780 op->GetOperands().end());
781 } else {
782 common_inner_operands.clear();
783 absl::c_copy_if(op->GetOperands(),
784 std::back_inserter(common_inner_operands),
785 [&](Predicate* sub_op) {
786 return common_inner_operands_set.count(sub_op) == 1;
787 });
788 }
789 if (common_inner_operands.empty()) break;
790 common_inner_operands_set.clear();
791 common_inner_operands_set.insert(common_inner_operands.begin(),
792 common_inner_operands.end());
793 }
794
795 if (common_inner_operands.empty()) {
796 return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
797 }
798
799 // For all predicates that can be factored out, remove them and recreate the
800 // subops.
801 std::vector<Predicate*> factored_ops;
802 for (Predicate* op : simplified_ops) {
803 std::vector<Predicate*> new_sub_op_ops;
804 absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
805 [&](Predicate* sub_op) {
806 return std::find(common_inner_operands.begin(),
807 common_inner_operands.end(),
808 sub_op) == common_inner_operands.end();
809 });
810 factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
811 }
812
813 Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
814 std::vector<Predicate*> outer_ops;
815 outer_ops.push_back(new_inner_op);
816 outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
817 common_inner_operands.end());
818 return MakeAndOrImpl(outer_ops, !is_and);
819 }
820
821 class DeadnessAnalysisImpl : public DeadnessAnalysis {
822 public:
DeadnessAnalysisImpl(const Graph * graph)823 explicit DeadnessAnalysisImpl(const Graph* graph)
824 : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
825
826 Status Populate(bool enable_optimistic);
827 Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
828 bool* success);
829 StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
830 Node* n, int oidx) const override;
831 void Print() const override;
832 absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
833 const;
834
835 private:
836 enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
837
838 Status GetInputPreds(Node* n, EdgeKind edge_kind,
839 std::vector<Predicate*>* result);
840
841 // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th
842 // bit of `should_revisit` if `pred` is different from the current predicate
843 // for the `output_idx` output of `n`.
SetPredicate(Node * n,int output_idx,Predicate * pred,std::vector<bool> * should_revisit)844 void SetPredicate(Node* n, int output_idx, Predicate* pred,
845 std::vector<bool>* should_revisit) {
846 auto insert_result =
847 predicate_map_.insert({TensorId(n->name(), output_idx), pred});
848 if (!insert_result.second && insert_result.first->second != pred) {
849 VLOG(4) << "For " << n->name() << ":" << output_idx << " from "
850 << insert_result.first->second->ToString() << " "
851 << insert_result.first->second << " to " << pred->ToString()
852 << " " << pred;
853 insert_result.first->second = pred;
854 if (should_revisit != nullptr) {
855 for (const Edge* e : n->out_edges()) {
856 (*should_revisit)[e->dst()->id()] = true;
857 }
858 }
859 }
860 }
861
SetPredicate(Node * n,absl::Span<const int> output_idxs,Predicate * pred,std::vector<bool> * should_revisit)862 void SetPredicate(Node* n, absl::Span<const int> output_idxs, Predicate* pred,
863 std::vector<bool>* should_revisit) {
864 for (int output_idx : output_idxs) {
865 SetPredicate(n, output_idx, pred, should_revisit);
866 }
867 }
868
869 Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
870 Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
871 bool use_optimistic_mode);
872 Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
873 Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
874 Status HandleNode(Node* n, std::vector<bool>* should_revisit,
875 bool use_optimistic_mode = false);
876
877 Status GetFrameBasedTopologicalOrder(std::vector<Node*>* order);
878
IsRootEnter(const Node * n) const879 bool IsRootEnter(const Node* n) const {
880 return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource();
881 }
882
IsRootExit(const Node * n) const883 bool IsRootExit(const Node* n) const {
884 return IsExit(n) && control_flow_info_[n->id()].parent_frame->IsSource();
885 }
886
887 const Graph& graph_;
888 absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
889 PredicateFactory predicate_factory_;
890 std::vector<ControlFlowInfo> control_flow_info_;
891 bool vlog_;
892 absl::flat_hash_map<absl::string_view, Node*> frame_to_merge_node_;
893 };
894
InputEdgeToTensorId(const Edge * e)895 TensorId InputEdgeToTensorId(const Edge* e) {
896 return TensorId(e->src()->name(), e->src_output());
897 }
898
GetInputPreds(Node * n,DeadnessAnalysisImpl::EdgeKind edge_kind,std::vector<Predicate * > * result)899 Status DeadnessAnalysisImpl::GetInputPreds(
900 Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind,
901 std::vector<Predicate*>* result) {
902 result->clear();
903 for (const Edge* in_edge : n->in_edges()) {
904 bool should_process =
905 edge_kind == EdgeKind::kDataAndControl ||
906 (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
907 (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);
908
909 if (should_process) {
910 auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
911 if (it == predicate_map_.end()) {
912 GraphCycles graph_cycles;
913 TF_RETURN_IF_ERROR(
914 CreateCycleDetectionGraph(&graph_, &graph_cycles).status());
915
916 // If we didn't return with an error above then the graph is probably
917 // fine and we have a bug in deadness analysis.
918 return errors::Internal("Could not find input ", in_edge->DebugString(),
919 " to ", n->name(),
920 " when visiting the graph in post-order. Most "
921 "likely indicates a bug in deadness analysis.");
922 }
923 result->push_back(it->second);
924 }
925 }
926 return OkStatus();
927 }
928
HandleSwitch(Node * n,std::vector<bool> * should_revisit)929 Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
930 std::vector<bool>* should_revisit) {
931 std::vector<Predicate*> input_preds;
932 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
933 const Edge* pred_edge;
934 TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
935
936 if (n->type_string() != "_SwitchN") { // bool pred branch selector.
937 Predicate* true_switch;
938 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
939 pred_edge->src(), pred_edge->src_output(),
940 /*must_be_true=*/true, &true_switch));
941
942 Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
943
944 // Output 0 is alive iff all inputs are alive and the condition is false.
945 input_preds.push_back(false_switch);
946 SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
947 should_revisit);
948 input_preds.pop_back();
949
950 // Output 1 is alive iff all inputs are alive and the condition is true.
951 input_preds.push_back(true_switch);
952 SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
953 should_revisit);
954 input_preds.pop_back();
955 } else { // N-way switch case. Exactly one of N branches is alive.
956 Predicate* branch_pred;
957 for (int i = 0; i < n->num_outputs() - 1; i++) {
958 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
959 pred_edge->src(), pred_edge->src_output(),
960 /*must_have_value=*/std::optional<int32>(i), &branch_pred));
961 input_preds.push_back(branch_pred);
962 SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds),
963 should_revisit);
964 input_preds.pop_back();
965 input_preds.push_back(predicate_factory_.MakeNotPredicate(branch_pred));
966 }
967 // The default (last) branch does not need its own symbol, is simply the
968 // nor of all other branches.
969 SetPredicate(n, n->num_outputs() - 1,
970 predicate_factory_.MakeAndPredicate(input_preds),
971 should_revisit);
972 }
973
974 // Control is alive iff all inputs are alive.
975 SetPredicate(n, Graph::kControlSlot,
976 predicate_factory_.MakeAndPredicate(input_preds),
977 should_revisit);
978
979 return OkStatus();
980 }
981
982 namespace {
CreateMultipleNextIterationInputsError(Node * merge)983 Status CreateMultipleNextIterationInputsError(Node* merge) {
984 std::vector<string> backedges;
985 for (const Edge* backedge : merge->in_edges()) {
986 if (backedge->src()->IsNextIteration()) {
987 backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src())));
988 }
989 }
990 return errors::InvalidArgument(
991 "Multiple NextIteration inputs to merge node ",
992 FormatNodeForError(*merge), ": \n", absl::StrJoin(backedges, "\n"),
993 "\nMerge nodes can have at most one incoming NextIteration edge.");
994 }
995
FindUniqueBackedge(Node * merge,const Edge ** result)996 Status FindUniqueBackedge(Node* merge, const Edge** result) {
997 *result = nullptr;
998 CHECK(merge->IsMerge());
999 for (const Edge* e : merge->in_edges()) {
1000 if (e->src()->IsNextIteration()) {
1001 if (*result != nullptr) {
1002 return CreateMultipleNextIterationInputsError(merge);
1003 }
1004 *result = e;
1005 }
1006 }
1007 return OkStatus();
1008 }
1009
1010 // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
1011 // does not contain `symbolic_predicate` as an inner (not top-level) operand
1012 // then returns `Step`. Otherwise returns nullptr.
DeduceStepPredicate(PredicateFactory * predicate_factory,Predicate * symbolic_predicate,Predicate * backedge_predicate)1013 Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
1014 Predicate* symbolic_predicate,
1015 Predicate* backedge_predicate) {
1016 CHECK(dynamic_cast<SymbolPredicate*>(symbolic_predicate));
1017 if (backedge_predicate->kind() != Predicate::Kind::kAnd) {
1018 return nullptr;
1019 }
1020
1021 std::vector<Predicate*> and_ops;
1022 absl::Span<Predicate* const> recurrent_pred_ops =
1023 backedge_predicate->GetOperands();
1024
1025 bool found_sym = false;
1026 for (Predicate* and_op : recurrent_pred_ops) {
1027 // We want the `symbol_predicate` to be the one of the operands of
1028 // `backedge_predicate`,
1029 if (and_op == symbolic_predicate) {
1030 found_sym = true;
1031 continue;
1032 }
1033
1034 // but we don't want it to be present anywhere else in the formula. E.g. we
1035 // don't want the recurrent predicate to be
1036 // symbol_predicate&(X|symbol_predicate).
1037 bool found_sym_as_inner_operand = false;
1038 auto has_self_as_inner_operand = [&](Predicate* p) {
1039 if (p == symbolic_predicate) {
1040 found_sym_as_inner_operand = true;
1041 return true; // Stop searching, we're done.
1042 }
1043
1044 // Continue searching.
1045 return false;
1046 };
1047
1048 Predicate::Visit(and_op, has_self_as_inner_operand);
1049 if (found_sym_as_inner_operand) {
1050 return nullptr;
1051 }
1052 and_ops.push_back(and_op);
1053 }
1054
1055 return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr;
1056 }
1057
GetFullFrame(const Node * n,absl::Span<const ControlFlowInfo> cfi_infos,std::vector<string> * frame)1058 Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
1059 std::vector<string>* frame) {
1060 int depth = 0;
1061 for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource();
1062 n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) {
1063 frame->push_back(cfi_iter->frame_name);
1064
1065 if (depth++ > 5000) {
1066 return errors::Internal(
1067 "Frame of depth > 5000: Probably malformed graph or a bug in "
1068 "BuildControlFlowInfo");
1069 }
1070 }
1071
1072 return OkStatus();
1073 }
1074
1075 // If the node is inside some frames, get the name of the outermost non-empty
1076 // frame. Otherwise, get an empty frame name.
GetRootFrame(const Node * n,absl::Span<const ControlFlowInfo> cfi_infos,absl::string_view * frame)1077 Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
1078 absl::string_view* frame) {
1079 int depth = 0;
1080 const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
1081 while (!cfi_iter->parent_frame->IsSource()) {
1082 n = cfi_iter->parent_frame;
1083 cfi_iter = &cfi_infos[n->id()];
1084
1085 if (depth++ > 5000) {
1086 return errors::Internal(
1087 "Frame of depth > 5000: Probably malformed graph or a bug in "
1088 "BuildControlFlowInfo");
1089 }
1090 }
1091
1092 *frame = cfi_iter->frame_name;
1093 return OkStatus();
1094 }
1095 } // namespace
1096
HandleMerge(Node * n,std::vector<bool> * should_revisit,bool use_optimistic_mode)1097 Status DeadnessAnalysisImpl::HandleMerge(Node* n,
1098 std::vector<bool>* should_revisit,
1099 bool use_optimistic_mode) {
1100 // Merge ignores deadness of its control inputs. A merge that isn't the
1101 // target of a backedge has is alive iff any of its data inputs are. The
1102 // liveness of a merge that is the target of a backedge can sometimes be
1103 // represented using a AndRecurrencePredicate. If neither apply, we represent
1104 // the liveness of the merge symbolically.
1105
1106 bool has_unvisited_backedge = false;
1107 for (const Edge* e : n->in_edges()) {
1108 if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
1109 has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e));
1110 }
1111 }
1112
1113 auto it = predicate_map_.find(TensorId(n->name(), 0));
1114 if (it == predicate_map_.end()) {
1115 if (has_unvisited_backedge) {
1116 // We're visiting this merge for the first time and it has an unvisited
1117 // backedge.
1118 Predicate* input_data_pred;
1119 if (use_optimistic_mode) {
1120 // In the optimistic mode, we use the first-seen Merge node per
1121 // frame as the representative Merge node. It is just convenient and
1122 // does not affect the result after pattern-matching into the
1123 // AndRecurrence form.
1124 absl::string_view frame_name = control_flow_info_[n->id()].frame_name;
1125 auto insert_result = frame_to_merge_node_.insert({frame_name, n});
1126 Node* representative = insert_result.first->second;
1127 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
1128 representative, /*output_idx=*/0, /*must_be_true=*/false,
1129 &input_data_pred));
1130 } else {
1131 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
1132 n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
1133 }
1134
1135 SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
1136 should_revisit);
1137 return OkStatus();
1138 }
1139
1140 std::vector<Predicate*> input_preds;
1141 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
1142
1143 // We're visiting this merge for the first time and it is an acyclic merge.
1144 Predicate* input_data_pred =
1145 predicate_factory_.MakeOrPredicate(input_preds);
1146 SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
1147 should_revisit);
1148 return OkStatus();
1149 }
1150
1151 if (it->second->kind() == Predicate::Kind::kSymbol) {
1152 // Last time we visited this merge we only got a symbolic predicate because
1153 // of an unvisited backedge. Try to pattern match the predicate expression
1154 // for that backedge (which should be visited now) into an and recurrence
1155 // for the merge node.
1156 const Edge* unique_backedge;
1157 TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge));
1158 if (unique_backedge) {
1159 if (Predicate* step = DeduceStepPredicate(
1160 &predicate_factory_, it->second,
1161 predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
1162 // If the predicate for the backedge is "Sym&X" where "Sym" is the
1163 // predicate for the merge then the merge has predicate {S,&,X} where S
1164 // is the predicate for the merge ignoring the backedge.
1165 std::vector<Predicate*> non_recurrent_inputs;
1166 for (const Edge* e : n->in_edges()) {
1167 if (e != unique_backedge) {
1168 non_recurrent_inputs.push_back(
1169 predicate_map_[InputEdgeToTensorId(e)]);
1170 }
1171 }
1172
1173 Predicate* start =
1174 predicate_factory_.MakeOrPredicate(non_recurrent_inputs);
1175 std::vector<string> frame;
1176 TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame));
1177 Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate(
1178 start, step, std::move(frame));
1179 SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
1180 return OkStatus();
1181 }
1182 }
1183 }
1184 return OkStatus();
1185 }
1186
HandleRecv(Node * n,std::vector<bool> * should_revisit)1187 Status DeadnessAnalysisImpl::HandleRecv(Node* n,
1188 std::vector<bool>* should_revisit) {
1189 // In addition to being alive or dead based on the inputs, a _Recv can also
1190 // acquire a dead signal from a _Send.
1191 std::vector<Predicate*> input_preds;
1192 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
1193 Predicate* signal_is_alive;
1194 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
1195 n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive));
1196 input_preds.push_back(signal_is_alive);
1197 SetPredicate(n, {0, Graph::kControlSlot},
1198 predicate_factory_.MakeAndPredicate(input_preds),
1199 should_revisit);
1200 return OkStatus();
1201 }
1202
HandleGeneric(Node * n,std::vector<bool> * should_revisit)1203 Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
1204 std::vector<bool>* should_revisit) {
1205 // Generally nodes are alive iff all their inputs are alive.
1206 std::vector<Predicate*> input_preds;
1207 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
1208 Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds);
1209 for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
1210 SetPredicate(n, output_idx, pred, should_revisit);
1211 }
1212 SetPredicate(n, Graph::kControlSlot, pred, should_revisit);
1213 return OkStatus();
1214 }
1215
HandleNode(Node * n,std::vector<bool> * should_revisit,bool use_optimistic_mode)1216 Status DeadnessAnalysisImpl::HandleNode(Node* n,
1217 std::vector<bool>* should_revisit,
1218 bool use_optimistic_mode) {
1219 if (n->IsSwitch()) {
1220 TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
1221 } else if (n->IsMerge()) {
1222 TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
1223 } else if (n->IsControlTrigger()) {
1224 SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
1225 nullptr);
1226 } else if (n->IsRecv() || n->IsHostRecv()) {
1227 TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit));
1228 } else if (n->IsNextIteration()) {
1229 TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1230 } else {
1231 TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1232 }
1233 return OkStatus();
1234 }
1235
1236 // Compute a special topological order for the Graph, where nodes having the
1237 // same root frame are placed adjacent to each other. The traversal uses a
1238 // variant of Kahn's algorithm. num_ready_inputs is used to keep track of how
1239 // many inputs of each node are ready; a node is ready to be scheduled if all
1240 // of its inputs are ready.
1241 // Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
GetFrameBasedTopologicalOrder(std::vector<Node * > * order)1242 Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
1243 std::vector<Node*>* order) {
1244 absl::flat_hash_map<absl::string_view, size_t> num_enters_for_frame;
1245 absl::flat_hash_map<absl::string_view, size_t> num_exits_for_frame;
1246 std::vector<size_t> num_ready_inputs(graph_.num_node_ids(), 0);
1247 Node* src_node = graph_.source_node();
1248 for (const auto* node : graph_.op_nodes()) {
1249 const ControlFlowInfo& cf = control_flow_info_[node->id()];
1250 if (IsRootEnter(node)) {
1251 // Since we care only the root-level frame, full frame names are the same
1252 // as frame names.
1253 ++num_enters_for_frame[cf.frame_name];
1254 } else if (IsRootExit(node)) {
1255 ++num_exits_for_frame[cf.frame_name];
1256 }
1257 // Edge NextIteration->Merge is counted before starting the traversal to
1258 // break the backedges.
1259 if (IsMerge(node)) {
1260 for (const Edge* e : node->in_edges()) {
1261 if (IsNextIteration(e->src())) {
1262 ++num_ready_inputs[node->id()];
1263 }
1264 }
1265 }
1266 }
1267
1268 // dequeue is used to ensure that the nodes are first-in-first-out. This
1269 // order guarantees that the exits in the ready queue are visited before
1270 // nodes that will become ready in the future.
1271 std::deque<Node*> ready;
1272 ready.push_back(src_node);
1273 // ready_enters_per_frame and ready_exits serve as a staging area to buffer
1274 // the ready enters/exits before they are moved to the `ready` queue for
1275 // controlling the start and end of a processing frame.
1276 absl::flat_hash_map<absl::string_view, std::vector<Node*>>
1277 ready_enters_per_frame;
1278 // Exit nodes shall all be from the same frame, as we process a frame at a
1279 // time. So, one vector is enough.
1280 std::vector<Node*> ready_exits;
1281 while (!ready.empty()) {
1282 Node* curr_node = ready.front();
1283 ready.pop_front();
1284
1285 VLOG(4) << "Visiting " << curr_node->name();
1286 order->push_back(curr_node);
1287
1288 for (const Edge* out_edge : curr_node->out_edges()) {
1289 Node* out = out_edge->dst();
1290 int out_id = out->id();
1291 if (IsNextIteration(curr_node) && IsMerge(out)) {
1292 // Edge NextIteration->Merge has been counted.
1293 continue;
1294 }
1295 ++num_ready_inputs[out->id()];
1296 if (!out->IsOp()) continue; // Skip Sink/Source nodes.
1297 if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
1298
1299 absl::string_view frame_name = control_flow_info_[out_id].frame_name;
1300 if (IsRootEnter(out)) {
1301 ready_enters_per_frame[frame_name].push_back(out);
1302 } else if (IsRootExit(out)) {
1303 ready_exits.push_back(out);
1304 } else {
1305 ready.push_back(out);
1306 }
1307 }
1308
1309 if (ready.empty()) {
1310 // Try moving nodes from ready_enters_per_frame and ready_exits to
1311 // `ready`.
1312 if (!ready_exits.empty()) {
1313 // If there are nodes in ready_exits we must process them before
1314 // processing ready_enters_per_frame to make sure all nodes in the
1315 // currently processing frame are visited before starting processing
1316 // other frames.
1317 absl::string_view frame_name =
1318 control_flow_info_[ready_exits.front()->id()].frame_name;
1319 CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]);
1320 ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
1321 ready_exits.clear();
1322 } else {
1323 // Otherwise, try moving nodes from ready_enters to `ready`.
1324 for (auto iter = ready_enters_per_frame.begin();
1325 iter != ready_enters_per_frame.end(); ++iter) {
1326 absl::string_view frame_name = iter->first;
1327 const std::vector<Node*>& ready_enters = iter->second;
1328 if (ready_enters.size() == num_enters_for_frame[frame_name]) {
1329 ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
1330 ready_enters_per_frame.erase(iter);
1331 break;
1332 }
1333 }
1334 }
1335 }
1336 }
1337
1338 if (!ready_enters_per_frame.empty() || !ready_exits.empty()) {
1339 return errors::InvalidArgument(
1340 "Some enters/exits have never been visited in the traversal."
1341 " Most probably the input graph is malformed.");
1342 }
1343 return OkStatus();
1344 }
1345
1346 // We populate the nodes along a special topological order where nodes having
1347 // the same root frame are placed adjacent to each other. This grouping enables
1348 // processing the graph per root frame at a time and guarantees that when a root
1349 // frame is being processed, nodes in the downstream frames have not yet been
1350 // processed. This property is important because we need to process an entire
1351 // frame to know whether the optimistic mode converges or not. In other words,
1352 // nodes in the downstream frames shall not be populated until all of its
1353 // upstream frames are populated. In effect, this order enables processing each
1354 // (nested) tf.while one-by-one, as each (nested) tf.while creates a unique
1355 // (root) frame. Note that we don't separate while loops belonging to the same
1356 // nested while, as there is no clean cut for separating them in the topological
1357 // order.
Populate(bool enable_optimistic)1358 Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
1359 std::vector<string> unreachable_nodes;
1360 // Compute the loop structure of the graph.
1361 TF_RETURN_IF_ERROR(
1362 BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes));
1363
1364 // Do some opportunistic error checking:
1365 if (!unreachable_nodes.empty()) {
1366 if (unreachable_nodes.size() > 5) {
1367 unreachable_nodes.erase(unreachable_nodes.begin() + 5,
1368 unreachable_nodes.end());
1369 }
1370
1371 return errors::InvalidArgument(
1372 "Found unreachable nodes, most likely source and sink nodes not "
1373 "connected: ",
1374 absl::StrJoin(unreachable_nodes, ", "));
1375 }
1376
1377 std::vector<Node*> topo;
1378 TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo));
1379
1380 size_t frame_start = 0;
1381 while (frame_start < topo.size()) {
1382 // Batching nodes who have the same root frame.
1383 absl::string_view cur_frame_name;
1384 TF_RETURN_IF_ERROR(
1385 GetRootFrame(topo[frame_start], control_flow_info_, &cur_frame_name));
1386 size_t frame_end = frame_start;
1387 for (size_t i = frame_start + 1; i < topo.size(); ++i) {
1388 absl::string_view i_frame_name;
1389 TF_RETURN_IF_ERROR(
1390 GetRootFrame(topo[i], control_flow_info_, &i_frame_name));
1391 if (i_frame_name == cur_frame_name) {
1392 frame_end = i;
1393 } else {
1394 break;
1395 }
1396 }
1397 absl::Span<Node*> sub_topo(topo.data() + frame_start,
1398 /*length=*/frame_end - frame_start + 1);
1399 frame_start = frame_end + 1;
1400
1401 // First, try the optimistic mode.
1402 bool success = false;
1403 if (enable_optimistic && !cur_frame_name.empty()) {
1404 TF_RETURN_IF_ERROR(
1405 PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
1406 }
1407 if (!success) {
1408 // The optimistic mode does not converge. Let's fall back to the
1409 // pessimistic mode.
1410 TF_RETURN_IF_ERROR(
1411 PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
1412 }
1413 VLOG(2) << "Done populating frame " << cur_frame_name << " using the "
1414 << (success ? "optimistic" : "pessimistic") << " mode.";
1415 }
1416
1417 return OkStatus();
1418 }
1419
PopulateFrame(absl::Span<Node * const> topo,bool use_optimistic_mode,bool * success)1420 Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
1421 bool use_optimistic_mode,
1422 bool* success) {
1423 CHECK(use_optimistic_mode && success != nullptr ||
1424 !use_optimistic_mode && success == nullptr);
1425
1426 // This an abstract interpretation over the deadness propagation semantics of
1427 // the graph executor.
1428 //
1429 // We iterate over the graph twice, each time in a topological order. On the
1430 // first iteration merge nodes with backedges are mapped to symbolic
1431 // predicates. On the second iteration we use the predicates assigned to the
1432 // backedges in the previous iteration to infer a more precise predicate for
1433 // the backedge merge nodes and all the nodes that transitively use it.
1434 //
1435 // We don't track the output indices for should_revisit. Instead, putting a
1436 // node in `should_revisit` denotes that the deadness flowing out from any
1437 // output from said node may have changed. This is fine; only switches
1438 // propagate different deadness along different output edges, and since the
1439 // delta is solely due to the input *values* (and not input deadness), the
1440 // delta should not change in the second iteration.
1441 std::vector<bool> should_revisit;
1442 should_revisit.resize(graph_.num_node_ids());
1443 for (Node* n : topo) {
1444 VLOG(4) << "Visiting " << n->name();
1445 TF_RETURN_IF_ERROR(
1446 HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
1447 if (n->IsNextIteration()) {
1448 // If this is a backedge for a merge node then remember to reprocess the
1449 // merge the next time we run.
1450 for (const Edge* e : n->out_edges()) {
1451 if (e->dst()->IsMerge()) {
1452 should_revisit[e->dst()->id()] = true;
1453 }
1454 }
1455 }
1456 }
1457
1458 for (Node* n : topo) {
1459 // The nodes added to should_revisit in the previous loop need to be
1460 // revisited now. Reprocessing these initial nodes may add *their*
1461 // consumers to should_revisit, and these newly added nodes will also be
1462 // processed by this very same loop. Since we're traversing the graph in
1463 // topological order (producers before consumers) and HandleNode(n) can only
1464 // ever add n's consumers to should_revisit, we won't "miss" an addition to
1465 // should_revisit.
1466 if (should_revisit[n->id()]) {
1467 VLOG(4) << "Revisiting " << n->name();
1468 TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit));
1469 }
1470 }
1471
1472 // Check if the optimistic analysis converges. Specifically, check whether
1473 // all the predicates of the merge nodes in the same frame are the same. If
1474 // yes, report success. If not, report failure and clear the assigned
1475 // predicates.
1476 if (use_optimistic_mode) {
1477 bool is_converged = true;
1478 absl::flat_hash_map<absl::string_view, Predicate*> frame_to_pred;
1479 for (Node* n : topo) {
1480 if (!n->IsMerge()) {
1481 continue;
1482 }
1483 const Edge* e;
1484 TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
1485 if (e == nullptr) {
1486 // Skip acyclic merge nodes.
1487 continue;
1488 }
1489 Node* merge = n;
1490 // Note that here uses frame names instead of root frame names. In the
1491 // case of a nested while loop, each level of while loops can have merges
1492 // with different predicate instances, while the merge nodes on the same
1493 // level must have the same predicate instances.
1494 absl::string_view frame_name = control_flow_info_[merge->id()].frame_name;
1495 auto it = predicate_map_.find(TensorId(merge->name(), 0));
1496 Predicate* merge_pred = it->second;
1497 if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
1498 is_converged = false;
1499 VLOG(2) << "Running the optimistic mode on frame " << frame_name
1500 << " does not converge because node " << merge->name()
1501 << " cannot be mapped into the AndRecurrence form.";
1502 break;
1503 }
1504
1505 auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
1506 if (!insert_result.second) {
1507 // If we have already seen this frame name, verify the predicate is the
1508 // same as the previously seen one's.
1509 Predicate* curr_andrec = merge_pred;
1510 Predicate* prev_andrec = insert_result.first->second;
1511 if (curr_andrec != prev_andrec) {
1512 is_converged = false;
1513 VLOG(2) << "Running the optimistic mode on frame " << frame_name
1514 << " does not converge. Seeing different Merge predicates: \n"
1515 << curr_andrec->ToString() << " and \n"
1516 << prev_andrec->ToString();
1517 break;
1518 }
1519 }
1520 }
1521
1522 // Clear the assigned predicates if the optimistic mode does not converge.
1523 if (!is_converged) {
1524 for (Node* n : topo) {
1525 for (int oid = 0; oid < n->num_outputs(); ++oid) {
1526 predicate_map_.erase(TensorId(n->name(), oid));
1527 }
1528 predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
1529 }
1530 }
1531
1532 if (success != nullptr) {
1533 *success = is_converged;
1534 }
1535 }
1536
1537 return OkStatus();
1538 }
1539
1540 StatusOr<DeadnessAnalysis::DeadnessPredicate>
GetPredicateFor(Node * n,int oidx) const1541 DeadnessAnalysisImpl::GetPredicateFor(Node* n, int oidx) const {
1542 auto it = predicate_map_.find(TensorId(n->name(), oidx));
1543 TF_RET_CHECK(it != predicate_map_.end())
1544 << "could not find " << TensorId(n->name(), oidx).ToString()
1545 << " in predicate map";
1546 return MakeDeadnessPredicate(it->second);
1547 }
1548
Print() const1549 void DeadnessAnalysisImpl::Print() const {
1550 std::vector<TensorId> tensor_ids;
1551 tensor_ids.reserve(predicate_map_.size());
1552 for (const auto& kv_pair : predicate_map_) {
1553 tensor_ids.push_back(kv_pair.first);
1554 }
1555
1556 std::sort(tensor_ids.begin(), tensor_ids.end());
1557
1558 for (TensorId tensor_id : tensor_ids) {
1559 auto it = predicate_map_.find(tensor_id);
1560 CHECK(it != predicate_map_.end()) << tensor_id.ToString();
1561 VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
1562 }
1563 }
1564
1565 } // namespace
1566
~DeadnessAnalysis()1567 DeadnessAnalysis::~DeadnessAnalysis() {}
1568
Run(const Graph & graph,std::unique_ptr<DeadnessAnalysis> * result)1569 /*static*/ Status DeadnessAnalysis::Run(
1570 const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
1571 std::unique_ptr<DeadnessAnalysisImpl> analysis(
1572 new DeadnessAnalysisImpl(&graph));
1573 TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true));
1574
1575 if (VLOG_IS_ON(2)) {
1576 analysis->Print();
1577 }
1578
1579 *result = std::move(analysis);
1580 return OkStatus();
1581 }
1582
1583 absl::flat_hash_map<TensorId, string, TensorId::Hasher>
PredicateMapAsString() const1584 DeadnessAnalysisImpl::PredicateMapAsString() const {
1585 absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
1586 for (const auto& kv_pair : predicate_map_) {
1587 CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
1588 }
1589 return result;
1590 }
1591
1592 namespace deadness_analysis_internal {
ComputePredicates(const Graph & graph,PredicateMapTy * out_predicate_map,bool enable_optimistic)1593 Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
1594 bool enable_optimistic) {
1595 DeadnessAnalysisImpl impl(&graph);
1596 TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic));
1597 *out_predicate_map = impl.PredicateMapAsString();
1598 return OkStatus();
1599 }
1600
1601 } // namespace deadness_analysis_internal
1602
DebugString(DeadnessPredicate predicate) const1603 string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
1604 return static_cast<Predicate*>(predicate.pred_)->ToString();
1605 }
1606
1607 } // namespace tensorflow
1608