xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/deadness_analysis.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
24 #include "tensorflow/compiler/jit/xla_cluster_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/graph/algorithm.h"
28 #include "tensorflow/core/graph/control_flow.h"
29 #include "tensorflow/core/graph/graph_node_util.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 
33 // ALGORITHM OVERVIEW
34 // ==================
35 //
36 // We map every output produced by each node in the TensorFlow graph (including
37 // control dependence) into an instance of the Predicate class.  Instances of
38 // Predicate denote logical formulas and mapping a node `n` to a predicate
39 // `pred` implies that `n` is live whenever `pred` is true.  Then we can deduce
40 // mismatching liveness in the inputs to a node by comparing the predicates
41 // those inputs are mapped to.  The core logic of this pass resides in creating
42 // the map from TensorFlow nodes to predicates.
43 //
44 //
45 // MAPPING NODES TO PREDICATES, MODULO CYCLES
46 // ------------------------------------------
47 //
48 // If we ignore cycles for a moment, computing predicates is fairly
49 // straightforward.  We traverse the graph in a topological order, mapping each
50 // node to a predicate based on the predicates its inputs are mapped to.  For
51 // instance a Merge(X, Y) node will be mapped to OR(PredicateFor(X),
52 // PredicateFor(Y)).  Roughly speaking, we abstractly interpret each node on
53 // the "liveness" domain, where values in the domain represent if a tensor
54 // carries a dead signal or not.
55 //
56 //
57 // DEALING WITH CYCLES
58 // -------------------
59 //
60 // We map Merge nodes that are the target of a backedge to AndRecurrence
61 // instances.  An AndRecurrence with start() = S and step() = X, printed as
62 // {S,&,X}, *roughly* represents the infinite list of predicates
63 // [S,S&X,S&X&X,S&X&X&X, ...].  So {S,&,X} can be used to represent the predicate
64 // for Merge in a graph like:
65 //
66 //     Init
67 //       |
68 //       v
69 //     Merge <-----------+
70 //       |               |
71 //       v               |
72 //      Incr             |
73 //       |               |
74 //       v               |
75 //      Switch <- Cond   |
76 //       |               |
77 //       v (oidx: 1)     |
78 //       |               |
79 //       +---------------+
80 //
81 // Where S is the predicate for Init and X is the predicate that asserts that
82 // Cond is true.  {S,&,X} states that Merge is live on the first "iteration" iff
83 // S is true, live on the second iteration iff "S&X" is true, live on the third
84 // iteration iff "S&X&X" is true etc.  There is a subtlety here, S&X&X would
85 // normally be equivalent to S&X which isn't quite what we want to represent.
86 // Instead we want {S,&,X} to denote the infinite list [S, S&X,
87 // S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is
88 // true on iteration 0, 1, 2 respectively.  This is made more precise in the
89 // comment on the AndRecurrence class.
90 //
91 // The general algorithm that deals with cycles does two topological-order
92 // iterations over the graph.  On the first iteration it assigns a symbolic
93 // predicate to merge nodes with backedges.  On the second iteration it tries
94 // to pattern match the predicates for the backedges of these merges and infer
95 // an AndRecurrence for the merge.  In other words, we do a data flow analysis
96 // where the data-flow lattice has two elements, Symbolic and NonSymbolic with
97 // Symbolic > NonSymbolic.  The lattice has height = 2 so two iterations are
98 // sufficient to converge.
99 //
100 // We first do an optimistic analysis and, if it does not converge, we then fall
101 // back to a pessimistic analysis.  The optimistic analysis assigns the same
102 // symbolic predicate to all the merge nodes whose preceding enter nodes have
103 // the same frame name on the first iteration.  On the second iteration, if all
104 // the merge nodes are pattern matched into the same AndRecurrence predicate
105 // instance, the optimistic assignment of the same symbolic predicate is correct
106 // and the analyzed result is taken.
107 //
108 // Otherwise, if the optimistic analysis fails to converge, we then obtain the
109 // result by falling back to the pessimistic analysis which assigns a unique
110 // symbolic predicate to each merge on the first iteration.  We still use
111 // symbolic predicates for merges for which we can't pattern match on the
112 // backedge predicate.  This is conservatively correct.
113 
114 namespace tensorflow {
115 
116 namespace {
117 
118 using se::port::StatusOr;
119 
120 // Represents a logical predicate, used as described in the algorithm overview
121 // above.
122 class Predicate {
123  public:
124   enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol, kIntSymbol };
125 
126   virtual string ToString() const = 0;
127 
128   // An ID assigned to the Predicate at construction time.  Conceptually like a
129   // pointer, except that it is stable across runs.
id() const130   int64_t id() const { return id_; }
131 
132   virtual absl::Span<Predicate* const> GetOperands() const = 0;
133 
134   virtual Kind kind() const = 0;
~Predicate()135   virtual ~Predicate() {}
136 
137   // Invokes func on p and on all of its operands recursively.  Does not invoke
138   // `func` on the same Predicate instance twice.  Aborts the search if `func`
139   // returns true.
140   template <typename FunctionTy>
141   static void Visit(Predicate* p, const FunctionTy& func);
142 
143  protected:
Predicate(int64_t id)144   explicit Predicate(int64_t id) : id_(id) {}
145 
146  private:
147   const int64_t id_;
148 
149   TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
150 };
151 
152 // Represents a logical conjunction of a set of predicates.
153 class AndPredicate : public Predicate {
154  public:
AndPredicate(int64_t id,std::vector<Predicate * > operands)155   explicit AndPredicate(int64_t id, std::vector<Predicate*> operands)
156       : Predicate(id), operands_(std::move(operands)) {}
157 
ToString() const158   string ToString() const override {
159     if (operands().empty()) {
160       return "#true";
161     }
162 
163     std::vector<string> operands_str;
164     std::transform(operands().begin(), operands().end(),
165                    std::back_inserter(operands_str),
166                    [](Predicate* pred) { return pred->ToString(); });
167 
168     return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
169   }
170 
kind() const171   Kind kind() const override { return Kind::kAnd; }
172 
GetOperands() const173   absl::Span<Predicate* const> GetOperands() const override {
174     return operands_;
175   }
operands() const176   absl::Span<Predicate* const> operands() const { return operands_; }
177 
178  private:
179   std::vector<Predicate*> operands_;
180 };
181 
182 // Represents a logical disjunction of a set of predicates.
183 class OrPredicate : public Predicate {
184  public:
OrPredicate(int64_t id,std::vector<Predicate * > operands)185   explicit OrPredicate(int64_t id, std::vector<Predicate*> operands)
186       : Predicate(id), operands_(std::move(operands)) {}
187 
ToString() const188   string ToString() const override {
189     if (operands().empty()) {
190       return "#false";
191     }
192 
193     std::vector<string> operands_str;
194     std::transform(operands().begin(), operands().end(),
195                    std::back_inserter(operands_str),
196                    [](Predicate* pred) { return pred->ToString(); });
197 
198     return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
199   }
200 
kind() const201   Kind kind() const override { return Kind::kOr; }
GetOperands() const202   absl::Span<Predicate* const> GetOperands() const override {
203     return operands_;
204   }
operands() const205   absl::Span<Predicate* const> operands() const { return operands_; }
206 
207  private:
208   std::vector<Predicate*> operands_;
209 };
210 
211 // Represents a logical negation of a set of predicates.
212 class NotPredicate : public Predicate {
213  public:
NotPredicate(int64_t id,Predicate * operand)214   explicit NotPredicate(int64_t id, Predicate* operand)
215       : Predicate(id), operands_({operand}) {}
216 
ToString() const217   string ToString() const override {
218     return absl::StrCat("~", operand()->ToString());
219   }
220 
kind() const221   Kind kind() const override { return Kind::kNot; }
operand() const222   Predicate* operand() const { return operands_[0]; }
GetOperands() const223   absl::Span<Predicate* const> GetOperands() const override {
224     return operands_;
225   }
226 
227  private:
228   std::array<Predicate*, 1> operands_;
229 };
230 
231 // Represents the liveness of an induction variable.  For users inside the loop
232 // this represents the "current" liveness of the induction variable.  For users
233 // outside the loop it represents the "last" liveness of the induction variable.
234 //
235 // More concretely, an and recurrence {S,&,X}<loop> represents the liveness of V
236 // in the following graph:
237 //
238 //   V = Merge(S', V_NextIt)
239 //   V = Op(V, X')
240 //   V_NextIt = NextIteration(V)
241 //
242 // where Predicate(S') = S and Predicate(X') = X.
243 //
244 // `X` may contain symbolic predicates and the operations corresponding to these
245 // symbolic predicates are either in frame `loop` or outside it.  The symbols
246 // that are inside frame `loop` are loop variant (i.e. can have different
247 // liveness in each loop iteration) and the symbols that are outside frame
248 // `loop` are loop invariant (i.e. have the same liveness across all
249 // iterations).
250 class AndRecurrencePredicate : public Predicate {
251  public:
AndRecurrencePredicate(int64_t id,Predicate * start,Predicate * step,std::vector<string> frame)252   explicit AndRecurrencePredicate(int64_t id, Predicate* start, Predicate* step,
253                                   std::vector<string> frame)
254       : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {}
255 
start() const256   Predicate* start() const { return operands_[0]; }
step() const257   Predicate* step() const { return operands_[1]; }
frame() const258   absl::Span<const string> frame() const { return frame_; }
259 
ToString() const260   string ToString() const override {
261     return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
262                         "}<", absl::StrJoin(frame(), ";"), ">");
263   }
264 
kind() const265   Kind kind() const override { return Kind::kAndRecurrence; }
266 
GetOperands() const267   absl::Span<Predicate* const> GetOperands() const override {
268     return operands_;
269   }
270 
271  private:
272   std::array<Predicate*, 2> operands_;
273   std::vector<string> frame_;
274 };
275 
276 // Represents an uninterpreted symbol in a logical predicate.
277 //
278 // Two predicates are equivalent iff they are equivalent for all assignments to
279 // the symbols contained in them, i.e. predicates are forall qualified over
280 // symbols.
281 class SymbolPredicate : public Predicate {
282  public:
SymbolPredicate(int64_t id,TensorId tensor_id,bool must_be_true)283   explicit SymbolPredicate(int64_t id, TensorId tensor_id, bool must_be_true)
284       : Predicate(id),
285         tensor_id_(std::move(tensor_id)),
286         must_be_true_(must_be_true) {}
287 
ToString() const288   string ToString() const override {
289     return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
290                           : tensor_id_.ToString();
291   }
292 
kind() const293   Kind kind() const override { return Kind::kSymbol; }
GetOperands() const294   absl::Span<Predicate* const> GetOperands() const override { return {}; }
295 
296   // If `must_be_true()` is true this SymbolPredicate represents the proposition
297   // "tensor_id() is live and evaluates to true".
298   //
299   // If `must_be_true()` is false then this SymbolPredicate represents the
300   // proposition "tensor_id() is live (and may evaluate to any value)"
tensor_id() const301   TensorId tensor_id() const { return tensor_id_; }
must_be_true() const302   bool must_be_true() const { return must_be_true_; }
303 
304  private:
305   TensorId tensor_id_;
306   bool must_be_true_;
307 };
308 
309 // Represents an uninterpreted symbol in a logical predicate.
310 //
311 // Two predicates are equivalent iff they are equivalent for all assignments to
312 // the symbols contained in them, i.e. predicates are forall qualified over
313 // symbols.
314 class IntSymbolPredicate : public Predicate {
315  public:
IntSymbolPredicate(int64_t id,TensorId tensor_id,std::optional<int> must_have_value)316   explicit IntSymbolPredicate(int64_t id, TensorId tensor_id,
317                               std::optional<int> must_have_value)
318       : Predicate(id),
319         tensor_id_(std::move(tensor_id)),
320         must_have_value_(must_have_value) {}
321 
ToString() const322   string ToString() const override {
323     return must_have_value().has_value()
324                ? absl::StrCat(tensor_id_.ToString(), "=", *must_have_value_)
325                : tensor_id_.ToString();
326   }
327 
kind() const328   Kind kind() const override { return Kind::kIntSymbol; }
GetOperands() const329   absl::Span<Predicate* const> GetOperands() const override { return {}; }
330 
331   // If `must_have_value().has_value()` is true, then this IntSymbolPredicate
332   // represents the proposition "tensor_id() is live and evaluates to
333   // `*must_have_value()`".
334   //
335   // If `must_have_value().has_value()` is false, then this IntSymbolPredicate
336   // represents the proposition "tensor_id() is live (and may evaluate to any
337   // value)".
tensor_id() const338   TensorId tensor_id() const { return tensor_id_; }
must_have_value() const339   const std::optional<int>& must_have_value() const { return must_have_value_; }
340 
341  private:
342   TensorId tensor_id_;
343   std::optional<int> must_have_value_;
344 };
345 
346 template <typename FunctionTy>
Visit(Predicate * p,const FunctionTy & func)347 /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
348   absl::flat_hash_set<Predicate*> visited;
349   std::vector<Predicate*> stack;
350 
351   stack.push_back(p);
352   visited.insert(p);
353 
354   while (!stack.empty()) {
355     Predicate* current = stack.back();
356     stack.pop_back();
357     bool done = func(current);
358     if (done) {
359       return;
360     }
361     for (Predicate* op : current->GetOperands()) {
362       if (visited.insert(op).second) {
363         stack.push_back(op);
364       }
365     }
366   }
367 }
368 
369 // Creates and owns Predicate instances.  Simplifies predicates as it creates
370 // them.
371 class PredicateFactory {
372  public:
MakeAndPredicate(absl::Span<Predicate * const> operands)373   Predicate* MakeAndPredicate(absl::Span<Predicate* const> operands) {
374     return MakeAndOrImpl(operands, /*is_and=*/true);
375   }
376 
MakeOrPredicate(absl::Span<Predicate * const> operands)377   Predicate* MakeOrPredicate(absl::Span<Predicate* const> operands) {
378     return MakeAndOrImpl(operands, /*is_and=*/false);
379   }
380 
MakeNotPredicate(Predicate * pred)381   Predicate* MakeNotPredicate(Predicate* pred) {
382     auto it = make_not_predicate_cache_.find(pred);
383     if (it != make_not_predicate_cache_.end()) {
384       return it->second;
385     }
386 
387     Predicate* result = MakeNotPredicateImpl(pred);
388 
389     bool insert_successful =
390         make_not_predicate_cache_.insert({pred, result}).second;
391     (void)insert_successful;
392     DCHECK(insert_successful);
393 
394     return result;
395   }
396 
MakeAndRecurrencePredicate(Predicate * start,Predicate * step,std::vector<string> frame)397   Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step,
398                                         std::vector<string> frame) {
399     SignatureForAndRec signature(start, step, std::move(frame));
400     auto it = interned_and_rec_instances_.find(signature);
401     if (it != interned_and_rec_instances_.end()) {
402       return it->second.get();
403     }
404 
405     std::unique_ptr<Predicate> new_pred = Make<AndRecurrencePredicate>(
406         std::get<0>(signature), std::get<1>(signature), std::get<2>(signature));
407     Predicate* new_pred_ptr = new_pred.get();
408     bool inserted =
409         interned_and_rec_instances_.emplace(signature, std::move(new_pred))
410             .second;
411     (void)inserted;
412     DCHECK(inserted);
413     return new_pred_ptr;
414   }
415 
MakeSymbolPredicate(Node * node,int output_idx,bool must_be_true,Predicate ** predicate)416   Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true,
417                              Predicate** predicate) {
418     TensorId tensor_id(node->name(), output_idx);
419 
420     bool is_boolean_tensor =
421         BaseType(node->output_type(tensor_id.index())) == DT_BOOL;
422     TF_RET_CHECK(!must_be_true || is_boolean_tensor);
423 
424     if (node->type_string() == "Const" && must_be_true) {
425       const TensorProto* proto = nullptr;
426       TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));
427 
428       Tensor tensor(proto->dtype());
429       TF_RET_CHECK(tensor.FromProto(*proto));
430 
431       *predicate = tensor.scalar<bool>()() ? MakeTrue() : MakeFalse();
432       return OkStatus();
433     }
434 
435     SignatureForSymbol signature = {tensor_id, must_be_true};
436     auto it = interned_symbol_instances_.find(signature);
437     if (it == interned_symbol_instances_.end()) {
438       std::unique_ptr<Predicate> new_pred =
439           Make<SymbolPredicate>(tensor_id, must_be_true);
440       Predicate* new_pred_ptr = new_pred.get();
441       interned_symbol_instances_.emplace(std::move(signature),
442                                          std::move(new_pred));
443       *predicate = new_pred_ptr;
444     } else {
445       *predicate = it->second.get();
446     }
447 
448     return OkStatus();
449   }
450 
MakeSymbolPredicate(Node * node,int output_idx,std::optional<int> must_have_value,Predicate ** predicate)451   Status MakeSymbolPredicate(Node* node, int output_idx,
452                              std::optional<int> must_have_value,
453                              Predicate** predicate) {
454     TensorId tensor_id(node->name(), output_idx);
455 
456     TF_RET_CHECK(BaseType(node->output_type(tensor_id.index())) == DT_INT32);
457 
458     if (must_have_value.has_value() && node->type_string() == "Const") {
459       const TensorProto* proto = nullptr;
460       TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));
461 
462       Tensor tensor(proto->dtype());
463       TF_RET_CHECK(tensor.FromProto(*proto));
464 
465       *predicate = tensor.scalar<int32>()() == *must_have_value ? MakeTrue()
466                                                                 : MakeFalse();
467       return OkStatus();
468     }
469     SignatureForIntSymbol signature = {tensor_id, must_have_value};
470     auto it = interned_int_symbol_instances_.find(signature);
471     if (it == interned_int_symbol_instances_.end()) {
472       std::unique_ptr<Predicate> new_pred =
473           Make<IntSymbolPredicate>(tensor_id, must_have_value);
474       Predicate* new_pred_ptr = new_pred.get();
475       interned_int_symbol_instances_.emplace(std::move(signature),
476                                              std::move(new_pred));
477       *predicate = new_pred_ptr;
478     } else {
479       *predicate = it->second.get();
480     }
481 
482     return OkStatus();
483   }
484 
MakeTrue()485   Predicate* MakeTrue() { return MakeAndPredicate({}); }
MakeFalse()486   Predicate* MakeFalse() { return MakeOrPredicate({}); }
487 
~PredicateFactory()488   ~PredicateFactory() {
489     DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?";
490   }
491 
492  private:
MakeNotPredicateImpl(Predicate * pred)493   Predicate* MakeNotPredicateImpl(Predicate* pred) {
494     IncrementStackDepth stack_frame(this);
495     if (!stack_frame.HasOverflowed()) {
496       if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) {
497         return simplified;
498       }
499 
500       // ~~A => A
501       if (auto* not_pred = dynamic_cast<NotPredicate*>(pred)) {
502         return not_pred->operand();
503       }
504     }
505 
506     SignatureForNot signature = pred;
507     auto it = interned_not_instances_.find(signature);
508     if (it == interned_not_instances_.end()) {
509       std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
510       Predicate* new_pred_ptr = new_pred.get();
511       interned_not_instances_.emplace(signature, std::move(new_pred));
512       return new_pred_ptr;
513     } else {
514       return it->second.get();
515     }
516   }
517 
SimplifyUsingDeMorgan(Predicate * pred)518   Predicate* SimplifyUsingDeMorgan(Predicate* pred) {
519     // ~(A & B & C & ...) => ~A | ~B | ~C | ~...
520     // ~(A | B | C | ...) -> ~A & ~B & ~C & ~...
521     Predicate::Kind kind = pred->kind();
522 
523     if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) {
524       std::vector<Predicate*> new_operands;
525       absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands),
526                         [&](Predicate* p) { return MakeNotPredicate(p); });
527       return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands)
528                                           : MakeOrPredicate(new_operands);
529     }
530 
531     return nullptr;
532   }
533 
534   template <typename PredicateT, typename... Args>
Make(Args &&...args)535   std::unique_ptr<Predicate> Make(Args&&... args) {
536     // If we ever expose the Predicate class outside this .cc file then we may
537     // want to make this hard to misuse (by accidentally passing in an arbitrary
538     // integer to the Predicate constructor for instance).
539     return std::unique_ptr<PredicateT>(
540         new PredicateT(id_counter_++, std::forward<Args>(args)...));
541   }
542 
543   Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
544   Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
545                                Predicate::Kind pred_kind);
546 
547   // Predicate instances are interned, meaning that there is only a single
548   // instance of a Predicate object with a given content.  This makes checking
549   // for structural equality super-cheap -- we can just compare pointers.
550   //
551   // We intern predicates by maintaining a map from the content of a Predicate
552   // to the only instance of said predicate we allow to exist in the
553   // interned_and_or_instances_, interned_not_instances_ and
554   // interned_symbol_instances_ fields.  These maps also double up as storage
555   // for the owning pointers to predicate instances.
556 
557   using SignatureForAndOr =
558       std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
559   using SignatureForNot = Predicate*;
560   using SignatureForAndRec =
561       std::tuple<Predicate*, Predicate*, std::vector<string>>;
562   using SignatureForSymbol = std::pair<SafeTensorId, bool>;
563   using SignatureForIntSymbol = std::pair<SafeTensorId, std::optional<int32>>;
564 
565   struct HashSignatureForAndOr {
operator ()tensorflow::__anonfdc77a490111::PredicateFactory::HashSignatureForAndOr566     size_t operator()(const SignatureForAndOr& signature) const {
567       size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first);
568       for (Predicate* p : signature.second) {
569         hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p));
570       }
571       return hash;
572     }
573   };
574 
575   struct HashSignatureForSymbol {
operator ()tensorflow::__anonfdc77a490111::PredicateFactory::HashSignatureForSymbol576     size_t operator()(const SignatureForSymbol& signature) const {
577       return Hash64Combine(SafeTensorId::Hasher()(signature.first),
578                            ::tensorflow::hash<bool>()(signature.second));
579     }
580   };
581 
582   struct HashSignatureForIntSymbol {
operator ()tensorflow::__anonfdc77a490111::PredicateFactory::HashSignatureForIntSymbol583     size_t operator()(const SignatureForIntSymbol& signature) const {
584       return Hash64Combine(
585           SafeTensorId::Hasher()(signature.first),
586           Hash64Combine(
587               ::tensorflow::hash<bool>()(signature.second.has_value()),
588               ::tensorflow::hash<int32>()(
589                   signature.second.has_value() ? *signature.second : 0)));
590     }
591   };
592 
593   // Used to limit recursion to avoid blowing up the stack and cap compile time.
594   class IncrementStackDepth {
595    public:
IncrementStackDepth(PredicateFactory * parent)596     explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) {
597       parent_->stack_depth_++;
598     }
599 
HasOverflowed() const600     bool HasOverflowed() const {
601       const int kMaxStackDepth = 8;
602       return parent_->stack_depth_ >= kMaxStackDepth;
603     }
604 
~IncrementStackDepth()605     ~IncrementStackDepth() { parent_->stack_depth_--; }
606 
607    private:
608     PredicateFactory* parent_;
609   };
610 
611   // A cache for the MakeNotPredicate function.
612   //
613   // NB! This is *not* the same as `interned_not_instances_`.
614   // `interned_not_instances_` maps ensures pointer identity for `NotPredicate`
615   // instances, i.e., it ensures there at most one instance of Not(predicate)
616   // for any given predicate whereas `make_not_predicate_cache_` simply caches
617   // the result of the `MakeNotPredicate` function.  The values in
618   // `interned_not_instances_` are always instance of `NotPredicate` whereas the
619   // values in `make_not_predicate_cache_` may not be (for instance it will map
620   // Not(Not(A)) to A).
621   absl::flat_hash_map<Predicate*, Predicate*> make_not_predicate_cache_;
622 
623   absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
624                       HashSignatureForAndOr>
625       interned_and_or_instances_;
626   absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
627       interned_not_instances_;
628   absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
629       interned_and_rec_instances_;
630   absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
631                       HashSignatureForSymbol>
632       interned_symbol_instances_;
633   absl::flat_hash_map<SignatureForIntSymbol, std::unique_ptr<Predicate>,
634                       HashSignatureForIntSymbol>
635       interned_int_symbol_instances_;
636   int64_t id_counter_ = 0;
637   int stack_depth_ = 0;
638 };
639 
MakeInternedAndOr(std::vector<Predicate * > simplified_ops,Predicate::Kind pred_kind)640 Predicate* PredicateFactory::MakeInternedAndOr(
641     std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
642   std::stable_sort(
643       simplified_ops.begin(), simplified_ops.end(),
644       [](Predicate* a, Predicate* b) { return a->id() < b->id(); });
645 
646   auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
647   if (it != interned_and_or_instances_.end()) {
648     return it->second.get();
649   }
650 
651   simplified_ops.shrink_to_fit();
652   // NB!  Because we'll use a non-owning reference to simplified_ops in the
653   // key for interned_and_or_instances_ we need to be careful to std::move()
654   // it all the way through.
655   absl::Span<Predicate* const> operands_slice = simplified_ops;
656   std::unique_ptr<Predicate> new_pred =
657       pred_kind == Predicate::Kind::kAnd
658           ? Make<AndPredicate>(std::move(simplified_ops))
659           : Make<OrPredicate>(std::move(simplified_ops));
660 
661   Predicate* new_pred_ptr = new_pred.get();
662   interned_and_or_instances_.emplace(
663       SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
664   return new_pred_ptr;
665 }
666 
667 // Common code to create AndPredicate or OrPredicate instances.
MakeAndOrImpl(absl::Span<Predicate * const> operands,bool is_and)668 Predicate* PredicateFactory::MakeAndOrImpl(
669     absl::Span<Predicate* const> operands, bool is_and) {
670   Predicate::Kind pred_kind =
671       is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
672 
673   IncrementStackDepth stack_frame(this);
674   if (stack_frame.HasOverflowed()) {
675     return MakeInternedAndOr(
676         std::vector<Predicate*>(operands.begin(), operands.end()), pred_kind);
677   }
678 
679   Predicate::Kind other_pred_kind =
680       is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
681   absl::flat_hash_set<Predicate*> simplified_ops_set;
682   std::vector<Predicate*> simplified_ops;
683   for (Predicate* op : operands) {
684     // Simplify A&A => A and  A|A => A.
685     if (!simplified_ops_set.insert(op).second) {
686       continue;
687     }
688 
689     if (op->kind() == pred_kind) {
690       // "Inline" the operands of an inner And/Or into the parent And/Or.
691       for (Predicate* subop : op->GetOperands()) {
692         if (simplified_ops_set.insert(subop).second) {
693           simplified_ops.push_back(subop);
694         }
695       }
696     } else {
697       simplified_ops.push_back(op);
698     }
699   }
700 
701   if (simplified_ops.size() == 1) {
702     return simplified_ops[0];
703   }
704 
705   // Simplify "A&~A=>False" and "A|~A=>True".
706   absl::flat_hash_set<Predicate*> negated_ops;
707   for (Predicate* op : simplified_ops) {
708     if (negated_ops.count(op)) {
709       // Simple case:
710       //
711       //   A & ~A & ... == False
712       //   A | ~A | ... == True
713       return is_and ? MakeFalse() : MakeTrue();
714     }
715 
716     Predicate* negated_op = MakeNotPredicate(op);
717     if (negated_op->kind() == pred_kind) {
718       // Slightly more complicated case:
719       //
720       //   (~A | ~B | ~C) & A & B & C & ... ==
721       //   ~(A & B & C) & (A & B & C) & ... == False
722       //
723       //   (~A & ~B & ~C) | A | B | C | ... ==
724       //   ~(A | B | C) | (A | B | C) | ... == True
725       if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) {
726             return simplified_ops_set.contains(p);
727           })) {
728         return is_and ? MakeFalse() : MakeTrue();
729       }
730     }
731     negated_ops.insert(negated_op);
732   }
733 
734   // Simplify {S,&,X} & ~X & ... => S & ...
735   if (is_and) {
736     absl::flat_hash_set<Predicate*> to_remove;
737     std::vector<Predicate*> to_add;
738     for (Predicate* op : simplified_ops) {
739       if (op->kind() == Predicate::Kind::kAndRecurrence) {
740         auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
741         if (negated_ops.contains(and_rec->step())) {
742           // Remove and_rec and ~X and insert S.  Note that checking the
743           // existence of ~X through negated_ops is sufficient since it makes
744           // sure the predicate is in the input operands.  It does not need to
745           // be in simplified_ops if it was already cancelled out.
746           to_remove.insert(and_rec);
747           to_remove.insert(MakeNotPredicate(and_rec->step()));
748           to_add.push_back(and_rec->start());
749         }
750       }
751     }
752     auto it = simplified_ops.begin();
753     while (it != simplified_ops.end()) {
754       if (to_remove.contains(*it)) {
755         it = simplified_ops.erase(it);
756       } else {
757         ++it;
758       }
759     }
760     simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
761   }
762 
763   // If all ops contain the same subop, then factor it out thanks to the
764   // distributive property. Such as:
765   // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
766   // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
767   //
768   // First find any predicates contained in all subops.
769   std::vector<Predicate*> common_inner_operands;
770   absl::flat_hash_set<Predicate*> common_inner_operands_set;
771   for (Predicate* op : simplified_ops) {
772     if (op->kind() != other_pred_kind) {
773       common_inner_operands.clear();
774       break;
775     }
776 
777     if (common_inner_operands.empty()) {
778       common_inner_operands.insert(common_inner_operands.end(),
779                                    op->GetOperands().begin(),
780                                    op->GetOperands().end());
781     } else {
782       common_inner_operands.clear();
783       absl::c_copy_if(op->GetOperands(),
784                       std::back_inserter(common_inner_operands),
785                       [&](Predicate* sub_op) {
786                         return common_inner_operands_set.count(sub_op) == 1;
787                       });
788     }
789     if (common_inner_operands.empty()) break;
790     common_inner_operands_set.clear();
791     common_inner_operands_set.insert(common_inner_operands.begin(),
792                                      common_inner_operands.end());
793   }
794 
795   if (common_inner_operands.empty()) {
796     return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
797   }
798 
799   // For all predicates that can be factored out, remove them and recreate the
800   // subops.
801   std::vector<Predicate*> factored_ops;
802   for (Predicate* op : simplified_ops) {
803     std::vector<Predicate*> new_sub_op_ops;
804     absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
805                     [&](Predicate* sub_op) {
806                       return std::find(common_inner_operands.begin(),
807                                        common_inner_operands.end(),
808                                        sub_op) == common_inner_operands.end();
809                     });
810     factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
811   }
812 
813   Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
814   std::vector<Predicate*> outer_ops;
815   outer_ops.push_back(new_inner_op);
816   outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
817                    common_inner_operands.end());
818   return MakeAndOrImpl(outer_ops, !is_and);
819 }
820 
821 class DeadnessAnalysisImpl : public DeadnessAnalysis {
822  public:
DeadnessAnalysisImpl(const Graph * graph)823   explicit DeadnessAnalysisImpl(const Graph* graph)
824       : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
825 
826   Status Populate(bool enable_optimistic);
827   Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
828                        bool* success);
829   StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
830       Node* n, int oidx) const override;
831   void Print() const override;
832   absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
833       const;
834 
835  private:
836   enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
837 
838   Status GetInputPreds(Node* n, EdgeKind edge_kind,
839                        std::vector<Predicate*>* result);
840 
841   // Sets the predicate for output `output_idx` of `n` to `pred`.  Sets the i'th
842   // bit of `should_revisit` if `pred` is different from the current predicate
843   // for the `output_idx` output of `n`.
SetPredicate(Node * n,int output_idx,Predicate * pred,std::vector<bool> * should_revisit)844   void SetPredicate(Node* n, int output_idx, Predicate* pred,
845                     std::vector<bool>* should_revisit) {
846     auto insert_result =
847         predicate_map_.insert({TensorId(n->name(), output_idx), pred});
848     if (!insert_result.second && insert_result.first->second != pred) {
849       VLOG(4) << "For " << n->name() << ":" << output_idx << " from "
850               << insert_result.first->second->ToString() << " "
851               << insert_result.first->second << " to " << pred->ToString()
852               << " " << pred;
853       insert_result.first->second = pred;
854       if (should_revisit != nullptr) {
855         for (const Edge* e : n->out_edges()) {
856           (*should_revisit)[e->dst()->id()] = true;
857         }
858       }
859     }
860   }
861 
SetPredicate(Node * n,absl::Span<const int> output_idxs,Predicate * pred,std::vector<bool> * should_revisit)862   void SetPredicate(Node* n, absl::Span<const int> output_idxs, Predicate* pred,
863                     std::vector<bool>* should_revisit) {
864     for (int output_idx : output_idxs) {
865       SetPredicate(n, output_idx, pred, should_revisit);
866     }
867   }
868 
869   Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
870   Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
871                      bool use_optimistic_mode);
872   Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
873   Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
874   Status HandleNode(Node* n, std::vector<bool>* should_revisit,
875                     bool use_optimistic_mode = false);
876 
877   Status GetFrameBasedTopologicalOrder(std::vector<Node*>* order);
878 
IsRootEnter(const Node * n) const879   bool IsRootEnter(const Node* n) const {
880     return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource();
881   }
882 
IsRootExit(const Node * n) const883   bool IsRootExit(const Node* n) const {
884     return IsExit(n) && control_flow_info_[n->id()].parent_frame->IsSource();
885   }
886 
887   const Graph& graph_;
888   absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
889   PredicateFactory predicate_factory_;
890   std::vector<ControlFlowInfo> control_flow_info_;
891   bool vlog_;
892   absl::flat_hash_map<absl::string_view, Node*> frame_to_merge_node_;
893 };
894 
InputEdgeToTensorId(const Edge * e)895 TensorId InputEdgeToTensorId(const Edge* e) {
896   return TensorId(e->src()->name(), e->src_output());
897 }
898 
GetInputPreds(Node * n,DeadnessAnalysisImpl::EdgeKind edge_kind,std::vector<Predicate * > * result)899 Status DeadnessAnalysisImpl::GetInputPreds(
900     Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind,
901     std::vector<Predicate*>* result) {
902   result->clear();
903   for (const Edge* in_edge : n->in_edges()) {
904     bool should_process =
905         edge_kind == EdgeKind::kDataAndControl ||
906         (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
907         (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);
908 
909     if (should_process) {
910       auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
911       if (it == predicate_map_.end()) {
912         GraphCycles graph_cycles;
913         TF_RETURN_IF_ERROR(
914             CreateCycleDetectionGraph(&graph_, &graph_cycles).status());
915 
916         // If we didn't return with an error above then the graph is probably
917         // fine and we have a bug in deadness analysis.
918         return errors::Internal("Could not find input ", in_edge->DebugString(),
919                                 " to ", n->name(),
920                                 " when visiting the graph in post-order.  Most "
921                                 "likely indicates a bug in deadness analysis.");
922       }
923       result->push_back(it->second);
924     }
925   }
926   return OkStatus();
927 }
928 
HandleSwitch(Node * n,std::vector<bool> * should_revisit)929 Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
930                                           std::vector<bool>* should_revisit) {
931   std::vector<Predicate*> input_preds;
932   TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
933   const Edge* pred_edge;
934   TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
935 
936   if (n->type_string() != "_SwitchN") {  // bool pred branch selector.
937     Predicate* true_switch;
938     TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
939         pred_edge->src(), pred_edge->src_output(),
940         /*must_be_true=*/true, &true_switch));
941 
942     Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
943 
944     // Output 0 is alive iff all inputs are alive and the condition is false.
945     input_preds.push_back(false_switch);
946     SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
947                  should_revisit);
948     input_preds.pop_back();
949 
950     // Output 1 is alive iff all inputs are alive and the condition is true.
951     input_preds.push_back(true_switch);
952     SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
953                  should_revisit);
954     input_preds.pop_back();
955   } else {  // N-way switch case. Exactly one of N branches is alive.
956     Predicate* branch_pred;
957     for (int i = 0; i < n->num_outputs() - 1; i++) {
958       TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
959           pred_edge->src(), pred_edge->src_output(),
960           /*must_have_value=*/std::optional<int32>(i), &branch_pred));
961       input_preds.push_back(branch_pred);
962       SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds),
963                    should_revisit);
964       input_preds.pop_back();
965       input_preds.push_back(predicate_factory_.MakeNotPredicate(branch_pred));
966     }
967     // The default (last) branch does not need its own symbol, is simply the
968     // nor of all other branches.
969     SetPredicate(n, n->num_outputs() - 1,
970                  predicate_factory_.MakeAndPredicate(input_preds),
971                  should_revisit);
972   }
973 
974   // Control is alive iff all inputs are alive.
975   SetPredicate(n, Graph::kControlSlot,
976                predicate_factory_.MakeAndPredicate(input_preds),
977                should_revisit);
978 
979   return OkStatus();
980 }
981 
982 namespace {
CreateMultipleNextIterationInputsError(Node * merge)983 Status CreateMultipleNextIterationInputsError(Node* merge) {
984   std::vector<string> backedges;
985   for (const Edge* backedge : merge->in_edges()) {
986     if (backedge->src()->IsNextIteration()) {
987       backedges.push_back(absl::StrCat("  ", SummarizeNode(*backedge->src())));
988     }
989   }
990   return errors::InvalidArgument(
991       "Multiple NextIteration inputs to merge node ",
992       FormatNodeForError(*merge), ": \n", absl::StrJoin(backedges, "\n"),
993       "\nMerge nodes can have at most one incoming NextIteration edge.");
994 }
995 
FindUniqueBackedge(Node * merge,const Edge ** result)996 Status FindUniqueBackedge(Node* merge, const Edge** result) {
997   *result = nullptr;
998   CHECK(merge->IsMerge());
999   for (const Edge* e : merge->in_edges()) {
1000     if (e->src()->IsNextIteration()) {
1001       if (*result != nullptr) {
1002         return CreateMultipleNextIterationInputsError(merge);
1003       }
1004       *result = e;
1005     }
1006   }
1007   return OkStatus();
1008 }
1009 
1010 // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
1011 // does not contain `symbolic_predicate` as an inner (not top-level) operand
1012 // then returns `Step`.  Otherwise returns nullptr.
DeduceStepPredicate(PredicateFactory * predicate_factory,Predicate * symbolic_predicate,Predicate * backedge_predicate)1013 Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
1014                                Predicate* symbolic_predicate,
1015                                Predicate* backedge_predicate) {
1016   CHECK(dynamic_cast<SymbolPredicate*>(symbolic_predicate));
1017   if (backedge_predicate->kind() != Predicate::Kind::kAnd) {
1018     return nullptr;
1019   }
1020 
1021   std::vector<Predicate*> and_ops;
1022   absl::Span<Predicate* const> recurrent_pred_ops =
1023       backedge_predicate->GetOperands();
1024 
1025   bool found_sym = false;
1026   for (Predicate* and_op : recurrent_pred_ops) {
1027     // We want the `symbol_predicate` to be the one of the operands of
1028     // `backedge_predicate`,
1029     if (and_op == symbolic_predicate) {
1030       found_sym = true;
1031       continue;
1032     }
1033 
1034     // but we don't want it to be present anywhere else in the formula.  E.g. we
1035     // don't want the recurrent predicate to be
1036     // symbol_predicate&(X|symbol_predicate).
1037     bool found_sym_as_inner_operand = false;
1038     auto has_self_as_inner_operand = [&](Predicate* p) {
1039       if (p == symbolic_predicate) {
1040         found_sym_as_inner_operand = true;
1041         return true;  // Stop searching, we're done.
1042       }
1043 
1044       // Continue searching.
1045       return false;
1046     };
1047 
1048     Predicate::Visit(and_op, has_self_as_inner_operand);
1049     if (found_sym_as_inner_operand) {
1050       return nullptr;
1051     }
1052     and_ops.push_back(and_op);
1053   }
1054 
1055   return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr;
1056 }
1057 
GetFullFrame(const Node * n,absl::Span<const ControlFlowInfo> cfi_infos,std::vector<string> * frame)1058 Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
1059                     std::vector<string>* frame) {
1060   int depth = 0;
1061   for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource();
1062        n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) {
1063     frame->push_back(cfi_iter->frame_name);
1064 
1065     if (depth++ > 5000) {
1066       return errors::Internal(
1067           "Frame of depth > 5000:  Probably malformed graph or a bug in "
1068           "BuildControlFlowInfo");
1069     }
1070   }
1071 
1072   return OkStatus();
1073 }
1074 
1075 // If the node is inside some frames, get the name of the outermost non-empty
1076 // frame.  Otherwise, get an empty frame name.
GetRootFrame(const Node * n,absl::Span<const ControlFlowInfo> cfi_infos,absl::string_view * frame)1077 Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
1078                     absl::string_view* frame) {
1079   int depth = 0;
1080   const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
1081   while (!cfi_iter->parent_frame->IsSource()) {
1082     n = cfi_iter->parent_frame;
1083     cfi_iter = &cfi_infos[n->id()];
1084 
1085     if (depth++ > 5000) {
1086       return errors::Internal(
1087           "Frame of depth > 5000:  Probably malformed graph or a bug in "
1088           "BuildControlFlowInfo");
1089     }
1090   }
1091 
1092   *frame = cfi_iter->frame_name;
1093   return OkStatus();
1094 }
1095 }  // namespace
1096 
HandleMerge(Node * n,std::vector<bool> * should_revisit,bool use_optimistic_mode)1097 Status DeadnessAnalysisImpl::HandleMerge(Node* n,
1098                                          std::vector<bool>* should_revisit,
1099                                          bool use_optimistic_mode) {
1100   // Merge ignores deadness of its control inputs.  A merge that isn't the
1101   // target of a backedge has is alive iff any of its data inputs are.  The
1102   // liveness of a merge that is the target of a backedge can sometimes be
1103   // represented using a AndRecurrencePredicate.  If neither apply, we represent
1104   // the liveness of the merge symbolically.
1105 
1106   bool has_unvisited_backedge = false;
1107   for (const Edge* e : n->in_edges()) {
1108     if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
1109       has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e));
1110     }
1111   }
1112 
1113   auto it = predicate_map_.find(TensorId(n->name(), 0));
1114   if (it == predicate_map_.end()) {
1115     if (has_unvisited_backedge) {
1116       // We're visiting this merge for the first time and it has an unvisited
1117       // backedge.
1118       Predicate* input_data_pred;
1119       if (use_optimistic_mode) {
1120         // In the optimistic mode, we use the first-seen Merge node per
1121         // frame as the representative Merge node.  It is just convenient and
1122         // does not affect the result after pattern-matching into the
1123         // AndRecurrence form.
1124         absl::string_view frame_name = control_flow_info_[n->id()].frame_name;
1125         auto insert_result = frame_to_merge_node_.insert({frame_name, n});
1126         Node* representative = insert_result.first->second;
1127         TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
1128             representative, /*output_idx=*/0, /*must_be_true=*/false,
1129             &input_data_pred));
1130       } else {
1131         TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
1132             n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
1133       }
1134 
1135       SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
1136                    should_revisit);
1137       return OkStatus();
1138     }
1139 
1140     std::vector<Predicate*> input_preds;
1141     TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
1142 
1143     // We're visiting this merge for the first time and it is an acyclic merge.
1144     Predicate* input_data_pred =
1145         predicate_factory_.MakeOrPredicate(input_preds);
1146     SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
1147                  should_revisit);
1148     return OkStatus();
1149   }
1150 
1151   if (it->second->kind() == Predicate::Kind::kSymbol) {
1152     // Last time we visited this merge we only got a symbolic predicate because
1153     // of an unvisited backedge.  Try to pattern match the predicate expression
1154     // for that backedge (which should be visited now) into an and recurrence
1155     // for the merge node.
1156     const Edge* unique_backedge;
1157     TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge));
1158     if (unique_backedge) {
1159       if (Predicate* step = DeduceStepPredicate(
1160               &predicate_factory_, it->second,
1161               predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
1162         // If the predicate for the backedge is "Sym&X" where "Sym" is the
1163         // predicate for the merge then the merge has predicate {S,&,X} where S
1164         // is the predicate for the merge ignoring the backedge.
1165         std::vector<Predicate*> non_recurrent_inputs;
1166         for (const Edge* e : n->in_edges()) {
1167           if (e != unique_backedge) {
1168             non_recurrent_inputs.push_back(
1169                 predicate_map_[InputEdgeToTensorId(e)]);
1170           }
1171         }
1172 
1173         Predicate* start =
1174             predicate_factory_.MakeOrPredicate(non_recurrent_inputs);
1175         std::vector<string> frame;
1176         TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame));
1177         Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate(
1178             start, step, std::move(frame));
1179         SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
1180         return OkStatus();
1181       }
1182     }
1183   }
1184   return OkStatus();
1185 }
1186 
HandleRecv(Node * n,std::vector<bool> * should_revisit)1187 Status DeadnessAnalysisImpl::HandleRecv(Node* n,
1188                                         std::vector<bool>* should_revisit) {
1189   // In addition to being alive or dead based on the inputs, a _Recv can also
1190   // acquire a dead signal from a _Send.
1191   std::vector<Predicate*> input_preds;
1192   TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
1193   Predicate* signal_is_alive;
1194   TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
1195       n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive));
1196   input_preds.push_back(signal_is_alive);
1197   SetPredicate(n, {0, Graph::kControlSlot},
1198                predicate_factory_.MakeAndPredicate(input_preds),
1199                should_revisit);
1200   return OkStatus();
1201 }
1202 
HandleGeneric(Node * n,std::vector<bool> * should_revisit)1203 Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
1204                                            std::vector<bool>* should_revisit) {
1205   // Generally nodes are alive iff all their inputs are alive.
1206   std::vector<Predicate*> input_preds;
1207   TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
1208   Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds);
1209   for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
1210     SetPredicate(n, output_idx, pred, should_revisit);
1211   }
1212   SetPredicate(n, Graph::kControlSlot, pred, should_revisit);
1213   return OkStatus();
1214 }
1215 
HandleNode(Node * n,std::vector<bool> * should_revisit,bool use_optimistic_mode)1216 Status DeadnessAnalysisImpl::HandleNode(Node* n,
1217                                         std::vector<bool>* should_revisit,
1218                                         bool use_optimistic_mode) {
1219   if (n->IsSwitch()) {
1220     TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
1221   } else if (n->IsMerge()) {
1222     TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
1223   } else if (n->IsControlTrigger()) {
1224     SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
1225                  nullptr);
1226   } else if (n->IsRecv() || n->IsHostRecv()) {
1227     TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit));
1228   } else if (n->IsNextIteration()) {
1229     TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1230   } else {
1231     TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1232   }
1233   return OkStatus();
1234 }
1235 
1236 // Compute a special topological order for the Graph, where nodes having the
1237 // same root frame are placed adjacent to each other.  The traversal uses a
1238 // variant of Kahn's algorithm.  num_ready_inputs is used to keep track of how
1239 // many inputs of each node are ready; a node is ready to be scheduled if all
1240 // of its inputs are ready.
1241 // Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
GetFrameBasedTopologicalOrder(std::vector<Node * > * order)1242 Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
1243     std::vector<Node*>* order) {
1244   absl::flat_hash_map<absl::string_view, size_t> num_enters_for_frame;
1245   absl::flat_hash_map<absl::string_view, size_t> num_exits_for_frame;
1246   std::vector<size_t> num_ready_inputs(graph_.num_node_ids(), 0);
1247   Node* src_node = graph_.source_node();
1248   for (const auto* node : graph_.op_nodes()) {
1249     const ControlFlowInfo& cf = control_flow_info_[node->id()];
1250     if (IsRootEnter(node)) {
1251       // Since we care only the root-level frame, full frame names are the same
1252       // as frame names.
1253       ++num_enters_for_frame[cf.frame_name];
1254     } else if (IsRootExit(node)) {
1255       ++num_exits_for_frame[cf.frame_name];
1256     }
1257     // Edge NextIteration->Merge is counted before starting the traversal to
1258     // break the backedges.
1259     if (IsMerge(node)) {
1260       for (const Edge* e : node->in_edges()) {
1261         if (IsNextIteration(e->src())) {
1262           ++num_ready_inputs[node->id()];
1263         }
1264       }
1265     }
1266   }
1267 
1268   // dequeue is used to ensure that the nodes are first-in-first-out.  This
1269   // order guarantees that the exits in the ready queue are visited before
1270   // nodes that will become ready in the future.
1271   std::deque<Node*> ready;
1272   ready.push_back(src_node);
1273   // ready_enters_per_frame and ready_exits serve as a staging area to buffer
1274   // the ready enters/exits before they are moved to the `ready` queue for
1275   // controlling the start and end of a processing frame.
1276   absl::flat_hash_map<absl::string_view, std::vector<Node*>>
1277       ready_enters_per_frame;
1278   // Exit nodes shall all be from the same frame, as we process a frame at a
1279   // time. So, one vector is enough.
1280   std::vector<Node*> ready_exits;
1281   while (!ready.empty()) {
1282     Node* curr_node = ready.front();
1283     ready.pop_front();
1284 
1285     VLOG(4) << "Visiting " << curr_node->name();
1286     order->push_back(curr_node);
1287 
1288     for (const Edge* out_edge : curr_node->out_edges()) {
1289       Node* out = out_edge->dst();
1290       int out_id = out->id();
1291       if (IsNextIteration(curr_node) && IsMerge(out)) {
1292         // Edge NextIteration->Merge has been counted.
1293         continue;
1294       }
1295       ++num_ready_inputs[out->id()];
1296       if (!out->IsOp()) continue;  // Skip Sink/Source nodes.
1297       if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
1298 
1299       absl::string_view frame_name = control_flow_info_[out_id].frame_name;
1300       if (IsRootEnter(out)) {
1301         ready_enters_per_frame[frame_name].push_back(out);
1302       } else if (IsRootExit(out)) {
1303         ready_exits.push_back(out);
1304       } else {
1305         ready.push_back(out);
1306       }
1307     }
1308 
1309     if (ready.empty()) {
1310       // Try moving nodes from ready_enters_per_frame and ready_exits to
1311       // `ready`.
1312       if (!ready_exits.empty()) {
1313         // If there are nodes in ready_exits we must process them before
1314         // processing ready_enters_per_frame to make sure all nodes in the
1315         // currently processing frame are visited before starting processing
1316         // other frames.
1317         absl::string_view frame_name =
1318             control_flow_info_[ready_exits.front()->id()].frame_name;
1319         CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]);
1320         ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
1321         ready_exits.clear();
1322       } else {
1323         // Otherwise, try moving nodes from ready_enters to `ready`.
1324         for (auto iter = ready_enters_per_frame.begin();
1325              iter != ready_enters_per_frame.end(); ++iter) {
1326           absl::string_view frame_name = iter->first;
1327           const std::vector<Node*>& ready_enters = iter->second;
1328           if (ready_enters.size() == num_enters_for_frame[frame_name]) {
1329             ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
1330             ready_enters_per_frame.erase(iter);
1331             break;
1332           }
1333         }
1334       }
1335     }
1336   }
1337 
1338   if (!ready_enters_per_frame.empty() || !ready_exits.empty()) {
1339     return errors::InvalidArgument(
1340         "Some enters/exits have never been visited in the traversal."
1341         " Most probably the input graph is malformed.");
1342   }
1343   return OkStatus();
1344 }
1345 
1346 // We populate the nodes along a special topological order where nodes having
1347 // the same root frame are placed adjacent to each other.  This grouping enables
1348 // processing the graph per root frame at a time and guarantees that when a root
1349 // frame is being processed, nodes in the downstream frames have not yet been
1350 // processed.  This property is important because we need to process an entire
1351 // frame to know whether the optimistic mode converges or not.  In other words,
1352 // nodes in the downstream frames shall not be populated until all of its
1353 // upstream frames are populated.  In effect, this order enables processing each
1354 // (nested) tf.while one-by-one, as each (nested) tf.while creates a unique
1355 // (root) frame.  Note that we don't separate while loops belonging to the same
1356 // nested while, as there is no clean cut for separating them in the topological
1357 // order.
Populate(bool enable_optimistic)1358 Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
1359   std::vector<string> unreachable_nodes;
1360   // Compute the loop structure of the graph.
1361   TF_RETURN_IF_ERROR(
1362       BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes));
1363 
1364   // Do some opportunistic error checking:
1365   if (!unreachable_nodes.empty()) {
1366     if (unreachable_nodes.size() > 5) {
1367       unreachable_nodes.erase(unreachable_nodes.begin() + 5,
1368                               unreachable_nodes.end());
1369     }
1370 
1371     return errors::InvalidArgument(
1372         "Found unreachable nodes, most likely source and sink nodes not "
1373         "connected: ",
1374         absl::StrJoin(unreachable_nodes, ", "));
1375   }
1376 
1377   std::vector<Node*> topo;
1378   TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo));
1379 
1380   size_t frame_start = 0;
1381   while (frame_start < topo.size()) {
1382     // Batching nodes who have the same root frame.
1383     absl::string_view cur_frame_name;
1384     TF_RETURN_IF_ERROR(
1385         GetRootFrame(topo[frame_start], control_flow_info_, &cur_frame_name));
1386     size_t frame_end = frame_start;
1387     for (size_t i = frame_start + 1; i < topo.size(); ++i) {
1388       absl::string_view i_frame_name;
1389       TF_RETURN_IF_ERROR(
1390           GetRootFrame(topo[i], control_flow_info_, &i_frame_name));
1391       if (i_frame_name == cur_frame_name) {
1392         frame_end = i;
1393       } else {
1394         break;
1395       }
1396     }
1397     absl::Span<Node*> sub_topo(topo.data() + frame_start,
1398                                /*length=*/frame_end - frame_start + 1);
1399     frame_start = frame_end + 1;
1400 
1401     // First, try the optimistic mode.
1402     bool success = false;
1403     if (enable_optimistic && !cur_frame_name.empty()) {
1404       TF_RETURN_IF_ERROR(
1405           PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
1406     }
1407     if (!success) {
1408       // The optimistic mode does not converge.  Let's fall back to the
1409       // pessimistic mode.
1410       TF_RETURN_IF_ERROR(
1411           PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
1412     }
1413     VLOG(2) << "Done populating frame " << cur_frame_name << " using the "
1414             << (success ? "optimistic" : "pessimistic") << " mode.";
1415   }
1416 
1417   return OkStatus();
1418 }
1419 
PopulateFrame(absl::Span<Node * const> topo,bool use_optimistic_mode,bool * success)1420 Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
1421                                            bool use_optimistic_mode,
1422                                            bool* success) {
1423   CHECK(use_optimistic_mode && success != nullptr ||
1424         !use_optimistic_mode && success == nullptr);
1425 
1426   // This an abstract interpretation over the deadness propagation semantics of
1427   // the graph executor.
1428   //
1429   // We iterate over the graph twice, each time in a topological order.  On the
1430   // first iteration merge nodes with backedges are mapped to symbolic
1431   // predicates.  On the second iteration we use the predicates assigned to the
1432   // backedges in the previous iteration to infer a more precise predicate for
1433   // the backedge merge nodes and all the nodes that transitively use it.
1434   //
1435   // We don't track the output indices for should_revisit.  Instead, putting a
1436   // node in `should_revisit` denotes that the deadness flowing out from any
1437   // output from said node may have changed.  This is fine; only switches
1438   // propagate different deadness along different output edges, and since the
1439   // delta is solely due to the input *values* (and not input deadness), the
1440   // delta should not change in the second iteration.
1441   std::vector<bool> should_revisit;
1442   should_revisit.resize(graph_.num_node_ids());
1443   for (Node* n : topo) {
1444     VLOG(4) << "Visiting " << n->name();
1445     TF_RETURN_IF_ERROR(
1446         HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
1447     if (n->IsNextIteration()) {
1448       // If this is a backedge for a merge node then remember to reprocess the
1449       // merge the next time we run.
1450       for (const Edge* e : n->out_edges()) {
1451         if (e->dst()->IsMerge()) {
1452           should_revisit[e->dst()->id()] = true;
1453         }
1454       }
1455     }
1456   }
1457 
1458   for (Node* n : topo) {
1459     // The nodes added to should_revisit in the previous loop need to be
1460     // revisited now.  Reprocessing these initial nodes may add *their*
1461     // consumers to should_revisit, and these newly added nodes will also be
1462     // processed by this very same loop.  Since we're traversing the graph in
1463     // topological order (producers before consumers) and HandleNode(n) can only
1464     // ever add n's consumers to should_revisit, we won't "miss" an addition to
1465     // should_revisit.
1466     if (should_revisit[n->id()]) {
1467       VLOG(4) << "Revisiting " << n->name();
1468       TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit));
1469     }
1470   }
1471 
1472   // Check if the optimistic analysis converges.  Specifically, check whether
1473   // all the predicates of the merge nodes in the same frame are the same.  If
1474   // yes, report success.  If not, report failure and clear the assigned
1475   // predicates.
1476   if (use_optimistic_mode) {
1477     bool is_converged = true;
1478     absl::flat_hash_map<absl::string_view, Predicate*> frame_to_pred;
1479     for (Node* n : topo) {
1480       if (!n->IsMerge()) {
1481         continue;
1482       }
1483       const Edge* e;
1484       TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
1485       if (e == nullptr) {
1486         // Skip acyclic merge nodes.
1487         continue;
1488       }
1489       Node* merge = n;
1490       // Note that here uses frame names instead of root frame names.  In the
1491       // case of a nested while loop, each level of while loops can have merges
1492       // with different predicate instances, while the merge nodes on the same
1493       // level must have the same predicate instances.
1494       absl::string_view frame_name = control_flow_info_[merge->id()].frame_name;
1495       auto it = predicate_map_.find(TensorId(merge->name(), 0));
1496       Predicate* merge_pred = it->second;
1497       if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
1498         is_converged = false;
1499         VLOG(2) << "Running the optimistic mode on frame " << frame_name
1500                 << " does not converge because node " << merge->name()
1501                 << " cannot be mapped into the AndRecurrence form.";
1502         break;
1503       }
1504 
1505       auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
1506       if (!insert_result.second) {
1507         // If we have already seen this frame name, verify the predicate is the
1508         // same as the previously seen one's.
1509         Predicate* curr_andrec = merge_pred;
1510         Predicate* prev_andrec = insert_result.first->second;
1511         if (curr_andrec != prev_andrec) {
1512           is_converged = false;
1513           VLOG(2) << "Running the optimistic mode on frame " << frame_name
1514                   << " does not converge. Seeing different Merge predicates: \n"
1515                   << curr_andrec->ToString() << " and \n"
1516                   << prev_andrec->ToString();
1517           break;
1518         }
1519       }
1520     }
1521 
1522     // Clear the assigned predicates if the optimistic mode does not converge.
1523     if (!is_converged) {
1524       for (Node* n : topo) {
1525         for (int oid = 0; oid < n->num_outputs(); ++oid) {
1526           predicate_map_.erase(TensorId(n->name(), oid));
1527         }
1528         predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
1529       }
1530     }
1531 
1532     if (success != nullptr) {
1533       *success = is_converged;
1534     }
1535   }
1536 
1537   return OkStatus();
1538 }
1539 
1540 StatusOr<DeadnessAnalysis::DeadnessPredicate>
GetPredicateFor(Node * n,int oidx) const1541 DeadnessAnalysisImpl::GetPredicateFor(Node* n, int oidx) const {
1542   auto it = predicate_map_.find(TensorId(n->name(), oidx));
1543   TF_RET_CHECK(it != predicate_map_.end())
1544       << "could not find " << TensorId(n->name(), oidx).ToString()
1545       << " in predicate map";
1546   return MakeDeadnessPredicate(it->second);
1547 }
1548 
Print() const1549 void DeadnessAnalysisImpl::Print() const {
1550   std::vector<TensorId> tensor_ids;
1551   tensor_ids.reserve(predicate_map_.size());
1552   for (const auto& kv_pair : predicate_map_) {
1553     tensor_ids.push_back(kv_pair.first);
1554   }
1555 
1556   std::sort(tensor_ids.begin(), tensor_ids.end());
1557 
1558   for (TensorId tensor_id : tensor_ids) {
1559     auto it = predicate_map_.find(tensor_id);
1560     CHECK(it != predicate_map_.end()) << tensor_id.ToString();
1561     VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
1562   }
1563 }
1564 
1565 }  // namespace
1566 
~DeadnessAnalysis()1567 DeadnessAnalysis::~DeadnessAnalysis() {}
1568 
Run(const Graph & graph,std::unique_ptr<DeadnessAnalysis> * result)1569 /*static*/ Status DeadnessAnalysis::Run(
1570     const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
1571   std::unique_ptr<DeadnessAnalysisImpl> analysis(
1572       new DeadnessAnalysisImpl(&graph));
1573   TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true));
1574 
1575   if (VLOG_IS_ON(2)) {
1576     analysis->Print();
1577   }
1578 
1579   *result = std::move(analysis);
1580   return OkStatus();
1581 }
1582 
1583 absl::flat_hash_map<TensorId, string, TensorId::Hasher>
PredicateMapAsString() const1584 DeadnessAnalysisImpl::PredicateMapAsString() const {
1585   absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
1586   for (const auto& kv_pair : predicate_map_) {
1587     CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
1588   }
1589   return result;
1590 }
1591 
1592 namespace deadness_analysis_internal {
ComputePredicates(const Graph & graph,PredicateMapTy * out_predicate_map,bool enable_optimistic)1593 Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
1594                          bool enable_optimistic) {
1595   DeadnessAnalysisImpl impl(&graph);
1596   TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic));
1597   *out_predicate_map = impl.PredicateMapAsString();
1598   return OkStatus();
1599 }
1600 
1601 }  // namespace deadness_analysis_internal
1602 
DebugString(DeadnessPredicate predicate) const1603 string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
1604   return static_cast<Predicate*>(predicate.pred_)->ToString();
1605 }
1606 
1607 }  // namespace tensorflow
1608