// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/forward_type_inference.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/common_runtime/forward_type_inference.h"

#include <functional>
#include <list>
#include <queue>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/full_type_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/dump_graph.h"
29 
30 namespace tensorflow {
31 
32 namespace {
33 
34 int MAX_VISITS_PER_NODE = 3;
35 
36 typedef absl::flat_hash_map<
37     int, std::reference_wrapper<ForwardTypeInferenceFn const>>
38     ForwardInferMap;
39 typedef absl::flat_hash_map<
40     int, std::pair<int, std::reference_wrapper<ForwardTypeInferenceFn const>>>
41     ReverseInferMap;
42 
all_sources_closed(const Node & n,const absl::flat_hash_set<int> & closed,const ForwardInferMap & forward,const ReverseInferMap & reverse)43 bool all_sources_closed(const Node& n, const absl::flat_hash_set<int>& closed,
44                         const ForwardInferMap& forward,
45                         const ReverseInferMap& reverse) {
46   for (const auto& e : n.out_edges()) {
47     if (e->IsControlEdge()) {
48       continue;
49     }
50     int dst_id = e->dst()->id();
51     if (reverse.contains(dst_id) && !closed.contains(dst_id)) {
52       return false;
53     }
54   }
55   if (forward.contains(n.id())) {
56     for (const auto& e : n.in_edges()) {
57       if (e->IsControlEdge()) {
58         continue;
59       }
60       if (!closed.contains(e->src()->id())) {
61         return false;
62       }
63     }
64   }
65   return true;
66 }
67 
input_types(const Node & n)68 std::vector<std::reference_wrapper<const FullTypeDef>> input_types(
69     const Node& n) {
70   static FullTypeDef* no_type = new FullTypeDef();
71 
72   std::vector<std::reference_wrapper<const FullTypeDef>> input_types;
73   for (const auto& in_edge : n.in_edges()) {
74     if (in_edge->IsControlEdge()) {
75       continue;
76     }
77     input_types.push_back(*no_type);
78   }
79   for (const auto& in_edge : n.in_edges()) {
80     if (in_edge->IsControlEdge()) {
81       continue;
82     }
83     VLOG(5) << "  in edge: " << in_edge->DebugString();
84     NodeDef* ndef = in_edge->src()->mutable_def();
85     if (ndef->has_experimental_type()) {
86       const auto& t = ndef->experimental_type();
87       if (t.type_id() != TFT_UNSET) {
88         DCHECK(t.type_id() == TFT_PRODUCT) << ndef->DebugString();
89         DCHECK(t.args_size() > in_edge->src_output()) << ndef->DebugString();
90         input_types.at(in_edge->dst_input()) = t.args(in_edge->src_output());
91       }
92     }
93   }
94   return input_types;
95 }
96 
updated_inferred_type(Node * target,const FullTypeDef & t,bool & updated)97 Status updated_inferred_type(Node* target, const FullTypeDef& t,
98                              bool& updated) {
99   if (t.type_id() == TFT_UNSET) {
100     VLOG(3) << "  " << target->name() << " no inferred type";
101     return OkStatus();
102   }
103 
104   if (target->def().has_experimental_type()) {
105     const auto existing = target->def().experimental_type();
106     if (full_type::IsSubtype(existing, t)) {
107       VLOG(3) << "  " << target->name() << " no new type info";
108       return OkStatus();
109     } else if (!full_type::IsSubtype(t, existing)) {
110       // The only allowable type mismatches are those which would further
111       // specialize the existing type.
112       return Status(
113           error::INVALID_ARGUMENT,
114           absl::StrCat("type mismatch for node '", target->name(),
115                        "': expected a subtype of:\n", existing.DebugString(),
116                        "\n  got:\n", t.DebugString(), "\n  "));
117     }
118   }
119 
120   *(target->mutable_def()->mutable_experimental_type()) = t;
121   updated = true;
122   VLOG(3) << "  " << target->name() << " updated";
123   return OkStatus();
124 }
125 
126 }  // namespace
127 
Run(const GraphOptimizationPassOptions & options)128 Status ForwardTypeInferencePass::Run(
129     const GraphOptimizationPassOptions& options) {
130   VLOG(1) << "ForwardTypeInferencePass::Run";
131 
132   DCHECK(options.graph != nullptr);
133   Graph* g = options.graph->get();
134   DCHECK(g != nullptr);
135   FunctionLibraryDefinition* flib_def = options.flib_def;
136   DCHECK(flib_def != nullptr);
137 
138   if (VLOG_IS_ON(1)) {
139     DumpGraphToFile("forward_type_inference_before", *g, flib_def);
140   }
141 
142   for (Node* n : g->nodes()) {
143     // TODO(mdan): Needed?
144     n->UpdateProperties();
145   }
146 
147   // Cache type inference functions, to avoid repeated flib_def lookups.
148   ForwardInferMap forward;
149   ReverseInferMap reverse;
150   for (Node* n : g->nodes()) {
151     VLOG(4) << "\n  node: " << n->def().DebugString()
152             << "\n  op def: " << n->op_def().DebugString();
153     const OpRegistrationData* reg;
154     TF_RETURN_IF_ERROR(flib_def->LookUp(n->op_def().name(), &reg));
155     if (reg->fwd_type_fn != nullptr) {
156       forward.emplace(n->id(), reg->fwd_type_fn);
157     }
158     if (reg->rev_type_fn != nullptr) {
159       reverse.emplace(n->id(), std::make_pair(reg->rev_type_input,
160                                               std::cref(reg->rev_type_fn)));
161     }
162   }
163 
164   auto infer_forward = [&forward](Node* n, bool& updated) {
165     if (!forward.contains(n->id())) {
166       return OkStatus();
167     }
168     VLOG(4) << "  " << n->name() << " has forward function";
169 
170     // TODO(b/224775462): Populate with types from function references.
171     TypeRefMap type_vars;
172     auto in_types = input_types(*n);
173     const auto& infer_ret = forward.at(n->id())(in_types, type_vars);
174 
175     TF_RETURN_WITH_CONTEXT_IF_ERROR(
176         infer_ret.status(),
177         absl::StrCat("while inferring type of node '", n->name(), "'"));
178 
179     TF_RETURN_WITH_CONTEXT_IF_ERROR(
180         updated_inferred_type(n, *infer_ret, updated),
181         "while updating its output type.");
182 
183     return OkStatus();
184   };
185 
186   auto infer_reverse = [&reverse](Node* n, bool& updated) {
187     if (!reverse.contains(n->id())) {
188       return OkStatus();
189     }
190     VLOG(4) << "  " << n->name() << " has reverse function";
191 
192     // TODO(b/224775462): Populate with types from function references.
193     TypeRefMap type_vars;
194     auto in_types = input_types(*n);
195     auto rev_idx_and_fn = reverse.at(n->id());
196     const auto& infer_ret = rev_idx_and_fn.second(in_types, type_vars);
197 
198     const Edge* e;
199     TF_RETURN_WITH_CONTEXT_IF_ERROR(
200         n->input_edge(rev_idx_and_fn.first, &e),
201         absl::StrCat("while querying input ", rev_idx_and_fn.first, " of '",
202                      n->name(), "'"));
203 
204     TF_RETURN_WITH_CONTEXT_IF_ERROR(
205         infer_ret.status(),
206         absl::StrCat("while inferring type of node '", e->src()->name(),
207                      "' via '", n->name(), "'"));
208 
209     TF_RETURN_WITH_CONTEXT_IF_ERROR(
210         updated_inferred_type(e->src(), *infer_ret, updated),
211         absl::StrCat("while updating its output type inferred from '",
212                      n->name(), ","));
213 
214     return OkStatus();
215   };
216 
217   std::list<int> queue;
218   absl::flat_hash_set<int> in_queue;
219   absl::flat_hash_map<int, int> visit_count;
220   // Open nodes. A node is open if it has never been visited.
221   absl::flat_hash_set<int> open;
222   // Closed nodes. A closed node will never be visited again.
223   absl::flat_hash_set<int> closed;
224 
225   // Upper bound. Worst-case is a cycle in which no nodes have type info,
226   // case in which there will be max_passes iterations, each visiting one node.
227   int max_passes = g->num_nodes();
228 
229   int visits = 0;
230 
231   // Start with niladic nodes. If none exist, a random one will be selected at
232   // the end of first iteration.
233   for (Node* n : g->nodes()) {
234     const int nid = n->id();
235     bool niladic = true;
236     for (const auto& e : n->in_edges()) {
237       if (!e->IsControlEdge()) {
238         niladic = false;
239         break;
240       }
241     }
242     if (niladic) {
243       queue.emplace_back(nid);
244       in_queue.emplace(nid);
245     }
246     open.emplace(nid);
247     visit_count.emplace(nid, 0);
248   }
249 
250   for (int i = 0; i < max_passes; i++) {
251     VLOG(2) << "Iteration " << i << ", " << queue.size() << " nodes in queue";
252 
253     while (!queue.empty()) {
254       int nid = queue.front();
255       Node* n = g->FindNodeId(nid);
256       VLOG(3) << "  visiting " << n->name();
257       visits++;
258       visit_count[nid]++;
259       DCHECK(!closed.contains(nid));
260 
261       bool updated = false;
262       TF_RETURN_IF_ERROR(infer_forward(n, updated));
263       TF_RETURN_IF_ERROR(infer_reverse(n, updated));
264 
265       VLOG(4) << "  done " << n->def().DebugString();
266 
267       queue.pop_front();
268       in_queue.erase(nid);
269       open.erase(nid);
270 
271       // Update the graph to fixed point, with iterations limited
272       // by MAX_VISITS_PER_NODE.
273       if (visit_count.at(nid) >= MAX_VISITS_PER_NODE) {
274         VLOG(3) << "  closing " << n->name() << " - visit limit reached";
275         closed.emplace(nid);
276       } else if (all_sources_closed(*n, closed, forward, reverse)) {
277         VLOG(3) << "  closing " << n->name() << " - all sources closed";
278         closed.emplace(nid);
279       }
280 
281       for (const auto& out_edge : n->out_edges()) {
282         if (out_edge->IsControlEdge()) {
283           continue;
284         }
285         Node* c = out_edge->dst();
286         int cid = c->id();
287         if (closed.contains(cid) || in_queue.contains(cid)) {
288           continue;
289         }
290         if (updated || all_sources_closed(*c, closed, forward, reverse)) {
291           queue.emplace_back(cid);
292           in_queue.emplace(cid);
293         }
294       }
295       if (updated && reverse.contains(nid)) {
296         const Edge* e;
297         TF_RETURN_IF_ERROR(n->input_edge(reverse.at(nid).first, &e));
298         Node* c = e->src();
299         int cid = c->id();
300         if (!closed.contains(cid) && !in_queue.contains(cid)) {
301           queue.emplace_back(cid);
302           in_queue.emplace(cid);
303         }
304       }
305     }
306 
307     VLOG(2) << "Done iteration " << i << ", " << closed.size()
308             << " nodes closed";
309 
310     if (open.empty()) {
311       VLOG(1) << "Finished after " << i + 1 << " iterations; done "
312               << closed.size() << " of " << g->num_nodes() << " nodes in "
313               << visits << " visits";
314       break;
315     } else {
316       queue.emplace_back(*(open.begin()));
317     }
318   }
319 
320   if (VLOG_IS_ON(1)) {
321     DumpGraphToFile("forward_type_inference_after", *g, flib_def);
322   }
323 
324   return OkStatus();
325 }
326 
Run(const GraphOptimizationPassOptions & options)327 Status WeakForwardTypeInferencePass::Run(
328     const GraphOptimizationPassOptions& options) {
329   ForwardTypeInferencePass pass;
330   const auto& pass_status = pass.Run(options);
331   if (!pass_status.ok()) {
332     LOG_FIRST_N(WARNING, 1)
333         << "Type inference failed. This indicates an "
334            "invalid graph that escaped type checking. Error message: "
335         << pass_status.ToString();
336   }
337   return OkStatus();
338 }
339 
340 // Note: This needs to run last because Placer needs it.
341 REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 99999,
342                       WeakForwardTypeInferencePass);
343 
344 }  // namespace tensorflow
345