/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// ALGORITHM OVERVIEW
// ==================
//
// An XLA cluster hoists all resource reads to the beginning of the cluster's
// execution and all resource writes to the end.  This means it cannot enforce
// arbitrary ordering dependencies (via control or data edges) between resource
// operations.  Since all resource reads happen before all resource writes,
// edges constraining resource reads to happen before resource writes are fine,
// but all other kinds of edges are problematic.  This analysis computes the
// set of pairs of resource operations that cannot be put in the same cluster
// because XLA cannot respect the dependencies between them in the TensorFlow
// program.
//
// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
// dependencies.
//
// Specifically, the result computed by this analysis contains the edge {W, R}
// iff all of these hold true:
//
//   - In the graph (g - {edges from NextIteration to Merge}) there is a path
//     from W to R.
//   - IsEdgeSafe(W, R) == False [defined below]
//   - W != R (note: some resource operations both read from and write to
//     resource variables).
//
// The result is incorrect around loops because we ignore edges from
// NextIteration to Merge.  For instance, in:
//
// Init -----> Merge <-------+
//               |           |
//               v           |
//             Read          |
//               |           |
//               v           |
//             Write         |
//               |           |
//               v           |
//           NextIteration --+
//
// we won't put (Read, Write) in the returned set.  This is fine if
// auto-clustering can only cluster the Read->Write edge, but it is a problem
// if it clusters the Write->NextIteration->Merge->Read edges instead.  So we
// rely on auto-clustering to not cluster NextIteration->Merge edges.  The same
// problem is present for the functional version of the loop above, and we
// also rely on auto-clustering not clustering functional while loops
// containing resource operations.
//
// One way to think about this is that we only care about cases where two
// nodes, A and B, would normally have been put in the same cluster but cannot
// legally be in the same cluster because of resource variable dependencies.
// If A and B would normally have been put in the same cluster then all paths
// between A and B would have to be clusterable (otherwise we'd have introduced
// a cycle).  Ergo there could not have been a NextIteration->Merge edge
// between A and B since we don't cluster these edges.
//
// IMPLEMENTATION
// --------------
//
// We traverse the graph minus back edges in reverse post order, mapping each
// node to the set of resource operations reaching that node.  Since we visit
// producers before consumers, we can construct the set of reaching operations
// by taking the union of the operations reaching the input nodes.  These
// "reaching resource operations" can then be used to create the pairs of
// incompatible nodes using `IsEdgeSafe`.
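//
// As a concrete example: for the two-node graph
//
//   AssignVariableOp (Write) --> ReadVariableOp (Read)
//
// the analysis returns the single pair {Write, Read}, because a cluster that
// hoists the read to the beginning and the write to the end cannot honor that
// edge.  (The op names are illustrative; the analysis is driven only by each
// node's resource op kind.)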

#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {
// Maps `n` to the XlaResourceOpKind corresponding to its operation.  If `n` is
// not a resource operation recognized by XLA then sets `out_resource_op_kind`
// to nullopt.
Status XlaResourceOpKindForNode(
    const Node& n, const FunctionLibraryDefinition* flib_def,
    const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
    std::optional<XlaResourceOpKind>* out_resource_op_kind) {
  bool should_ignore = false;
  if (resource_ops_to_ignore) {
    TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore));
  }
  if (should_ignore) {
    *out_resource_op_kind = std::nullopt;
    return OkStatus();
  }

  const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
  if (op_info) {
    *out_resource_op_kind = op_info->kind();
    return OkStatus();
  }

  // We conservatively assume that functions will both read and write resource
  // variables.  In the future we may consider doing some form of
  // inter-procedural analysis.
  if (MayCallFunction(n, flib_def)) {
    *out_resource_op_kind = XlaResourceOpKind::kReadWrite;
  } else {
    *out_resource_op_kind = std::nullopt;
  }

  return OkStatus();
}
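
// For illustration, a caller-supplied `resource_ops_to_ignore` callback might
// look like the following sketch (the op name is an arbitrary example, not an
// op this pass special-cases):
//
//   Status IgnoreStackPushes(const Node& n, bool* ignore) {
//     *ignore = n.type_string() == "StackPushV2";
//     return OkStatus();
//   }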

// Returns true if a control or data dependence from a TensorFlow operation of
// resource op kind `from` to a TensorFlow operation of resource op kind `to`
// can be represented by an XLA cluster and needs no special handling around
// auto-jit.
bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) {
  // XLA clusters force all reads to happen before all writes.  Moreover, the
  // set of reads is executed as one atomic operation, and the set of writes as
  // another.  This means we can faithfully represent the following edges:
  // Read->*, *->Write.

  return from == XlaResourceOpKind::kRead || to == XlaResourceOpKind::kWrite;
}
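
// For reference, enumerating the predicate above over the three kinds gives:
//
//   from \ to   kRead    kWrite   kReadWrite
//   kRead       safe     safe     safe
//   kWrite      unsafe   safe     unsafe
//   kReadWrite  unsafe   safe     unsafe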

using ResourceOp = std::pair<int, XlaResourceOpKind>;

string ResourceOpToString(const ResourceOp& resource_op) {
  return absl::StrCat(
      resource_op.first, ": ",
      XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second));
}
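
// For example, ResourceOpToString({4, XlaResourceOpKind::kRead}) produces a
// string like "4: Read" (the exact kind spelling comes from
// XlaResourceOpInfo::XlaResourceOpKindToString).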

// A copy-on-write set used to store the set of ResourceOps reaching a node in
// a TensorFlow graph.
//
// TODO(sanjoy): It may be useful to pull this out into its own header at some
// point.
class ResourceOpSet {
 private:
  using Impl = absl::flat_hash_set<ResourceOp>;

 public:
  ResourceOpSet() = default;

  // Adds all ResourceOps in `other` to this set.
  void Add(const ResourceOpSet& other) {
    CHECK(!frozen_);
    if (other.impl_ == impl_) {
      other.frozen_ = true;
      return;
    }

    if (!impl_) {
      other.frozen_ = true;
      impl_ = other.impl_;
      return;
    }

    for (ResourceOp resource_op : other) {
      Add(resource_op);
    }
  }

  void Add(const ResourceOp& resource_op) {
    CHECK(!frozen_);
    if (!IsCopy() && Contains(resource_op)) {
      // We can avoid the copy if the item we want to insert already exists.
      return;
    }

    EnsureIsCopied();
    impl_->insert(resource_op);
  }

  Impl::const_iterator begin() const {
    return impl_ ? impl_->begin() : GetEmptyImpl()->begin();
  }

  Impl::const_iterator end() const {
    return impl_ ? impl_->end() : GetEmptyImpl()->end();
  }

  bool Contains(const ResourceOp& resource_op) const {
    return impl_ != nullptr && impl_->count(resource_op);
  }

 private:
  bool IsCopy() const { return storage_ != nullptr; }

  void EnsureIsCopied() {
    if (storage_ == nullptr) {
      storage_ = std::make_unique<Impl>();
      for (ResourceOp op : *this) {
        storage_->insert(op);
      }
      impl_ = storage_.get();
    }
  }

  static Impl* GetEmptyImpl() {
    static Impl* empty_impl = new Impl;
    return empty_impl;
  }

  Impl* impl_ = nullptr;
  std::unique_ptr<Impl> storage_;

  // frozen_ is true if there is another set pointing to this set's impl_.  We
  // can no longer add elements to this set in that case since the sets
  // pointing to this set expect the contents of this set to be stable.
  mutable bool frozen_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet);
};
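
// A minimal usage sketch of the copy-on-write behavior above (illustrative
// only):
//
//   ResourceOpSet a;
//   a.Add({/*node_id=*/1, XlaResourceOpKind::kWrite});
//   ResourceOpSet b;
//   b.Add(a);  // `b` now aliases `a`'s storage and `a` becomes frozen.
//   b.Add({/*node_id=*/2, XlaResourceOpKind::kWrite});  // `b` copies first.
//   CHECK(a.Contains({1, XlaResourceOpKind::kWrite}));
//   CHECK(!a.Contains({2, XlaResourceOpKind::kWrite}));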

string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
  std::vector<string> elements_debug_string;
  std::transform(resource_op_set.begin(), resource_op_set.end(),
                 std::back_inserter(elements_debug_string), ResourceOpToString);
  return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
}

string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) {
  return absl::StrCat(
      "[", n.name(), ": ", n.type_string(), "(",
      XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]");
}
}  // namespace

Status ComputeIncompatibleResourceOperationPairs(
    const Graph& g, const FunctionLibraryDefinition* flib_def,
    const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
    std::vector<std::pair<int, int>>* result) {
  CHECK(result->empty());

  std::vector<Node*> rpo;
  GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(),
                      /*edge_filter=*/[](const Edge& edge) {
                        return !edge.src()->IsNextIteration();
                      });
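  // The edge filter above drops every edge leaving a NextIteration node, so
  // the traversal runs on the graph minus back edges (a DAG) and reverse post
  // order visits producers before consumers.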

  auto resource_op_set_for_node =
      std::make_unique<ResourceOpSet[]>(g.num_node_ids());

  const bool vlog = VLOG_IS_ON(2);

  for (Node* n : rpo) {
    std::optional<XlaResourceOpKind> op_kind;
    TF_RETURN_IF_ERROR(XlaResourceOpKindForNode(
        *n, flib_def, resource_ops_to_ignore, &op_kind));

    ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()];

    // Merge the reaching resource operations for all the incoming edges to
    // create the set of all possible resource ops reaching `n`.
    for (const Edge* e : n->in_edges()) {
      if (n->IsMerge() && e->src()->IsNextIteration()) {
        // Ignore back-edges (see file comment).
        continue;
      }

      const ResourceOpSet& incoming_op_set =
          resource_op_set_for_node[e->src()->id()];
      resource_op_set->Add(incoming_op_set);
    }

    // Add to the "incompatible resource ops" set if necessary.
    if (op_kind) {
      for (ResourceOp incoming_op : *resource_op_set) {
        if (IsEdgeSafe(incoming_op.second, *op_kind)) {
          continue;
        }

        if (vlog) {
          VLOG(2) << "Unsafe edge: "
                  << NodeToString(*g.FindNodeId(incoming_op.first),
                                  incoming_op.second)
                  << " -> " << NodeToString(*n, *op_kind);
        }
        result->push_back({incoming_op.first, n->id()});
      }

      // Some graphs might have a lot of 'kRead' kinds, and since edges out of
      // reads are always safe, not storing them can save a lot of memory.
      if (op_kind != XlaResourceOpKind::kRead) {
        resource_op_set->Add({n->id(), *op_kind});
      }
    }

    if (vlog) {
      VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set);
    }
  }

  std::sort(result->begin(), result->end());
  CHECK(std::unique(result->begin(), result->end()) == result->end());

  return OkStatus();
}
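
// Example invocation (a sketch; `graph` and `flib_def` are assumed to be in
// scope at the call site):
//
//   std::vector<std::pair<int, int>> incompatible_pairs;
//   TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
//       *graph, flib_def, /*resource_ops_to_ignore=*/{},
//       &incompatible_pairs));
//
// Each returned pair holds the node ids of two resource operations that must
// not be placed in the same XLA cluster.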
}  // namespace tensorflow