/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"

#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

using ::tensorflow::testing::FindNodeByName;

namespace tensorflow {
namespace {

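// Runs at static-initialization time to turn on the XLA_* devices, which
// several tests below assign nodes to.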
static bool Initialized = [] {
  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
  return true;
}();

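// Test-only ops registered without any kernel (and hence no XLA kernel);
// nodes using them can never be auto-clustered and stand in for uncompilable
// computations in the tests below.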
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

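// Returns a map from node name to the name of the cluster the node was
// assigned to (the value of its kXlaClusterAttr attribute).  Unclustered
// nodes do not appear in the map.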
std::unordered_map<string, string> GetClusters(const Graph& graph) {
  std::unordered_map<string, string> ids;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) {
      CHECK(!cluster.empty());
      ids[node->name()] = cluster;
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Clusters:";
    for (const auto& p : ids) {
      VLOG(2) << " " << p.first << " -> " << p.second;
    }
  }
  return ids;
}

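// Returns the set of distinct cluster names that appear in `graph`.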
std::set<string> GetClusterNames(const Graph& graph) {
  std::set<string> names;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) {
      CHECK(!cluster.empty());
      names.insert(cluster);
    }
  }
  return names;
}

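// Inverse of GetClusters: maps each cluster name to the sorted list of node
// names in that cluster.  If `cluster_names` is non-null, it is filled with
// the sorted list of cluster names.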
absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
    const Graph& g, std::vector<string>* cluster_names = nullptr) {
  CHECK(cluster_names == nullptr || cluster_names->empty());
  absl::flat_hash_map<string, std::vector<string>> cluster_sets;
  for (const auto& p : GetClusters(g)) {
    cluster_sets[p.second].push_back(p.first);
  }
  for (auto& p : cluster_sets) {
    if (cluster_names != nullptr) {
      cluster_names->push_back(p.first);
    }
    std::sort(p.second.begin(), p.second.end());
  }
  if (cluster_names != nullptr) {
    std::sort(cluster_names->begin(), cluster_names->end());
  }
  return cluster_sets;
}

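// A -> B -> C -> D -> E -> F, where A and D are uncompilable.  The compilable
// runs {B, C} and {E, F} should form two separate clusters, and A and D should
// stay unclustered.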
TEST(XlaCompilationTest, Chains) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    Node* d =
        ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
    Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_EQ(clusters["E"], clusters["F"]);
  EXPECT_NE(clusters["B"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

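// A (Const) feeds both B (uncompilable) and C = MatMul(A, B).  Clustering A
// with C would create a cycle through the non-clustered node B, so nothing is
// clustered.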
TEST(XlaCompilationTest, UncompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b =
        ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

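// Same shape as UncompilableCycles, but B is a compilable Relu, so A, B and C
// can all be merged into one cluster.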
TEST(XlaCompilationTest, CompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, StringUnsupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp(
        "Const", builder.opts()
                     .WithName("A")
                     .WithAttr("dtype", DT_STRING)
                     .WithAttr("value", Tensor(DT_STRING, TensorShape())));
    Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B"));
    ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, WhereUnsupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_INT32)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Where", a, builder.opts().WithName("B"));
    ops::BinaryOp("Gather", b, a, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(!clusters.empty());
}

TEST(XlaCompilationTest, HalfSupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Tensor t(DT_HALF, TensorShape());
    t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_HALF)
                                         .WithAttr("value", t));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_FALSE(clusters.empty());
}

// Tests that PartitionedCalls are only marked for compilation if every node
// inside the function can be compiled.
TEST(XlaCompilationTest, PartitionedCallUnsupported) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);

  NameAttrList b_name_attr;
  b_name_attr.set_name("CompilableFn");
  ops::PartitionedCall b(root.WithOpName("B"), {a, a}, {DT_FLOAT}, b_name_attr);
  NameAttrList c_name_attr;
  c_name_attr.set_name("UncompilableFn");

  ops::PartitionedCall c(root.WithOpName("C"), {a}, {DT_FLOAT}, c_name_attr);
  Output d = ops::Add(root.WithOpName("D"), b.output.front(), c.output.front());

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_TRUE(clusters["C"].empty());
  EXPECT_EQ(clusters["B"], clusters["D"]);
}

TEST(XlaCompilationTest, FunctionCalls) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
  FunctionDef noinline = compilable;
  noinline.mutable_signature()->set_name("NoInlineFn");
  AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  *flib.add_function() = noinline;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["C"].empty());
  EXPECT_EQ(clusters["C"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("B") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "FnWithResourceOp", {"var:resource", "val:float"}, {"retval:float"}, {},
      {{{"assign_op"},
        "AssignVariableOp",
        {"var", "val"},
        {{"dtype", DT_FLOAT}}},
       {{"retval"}, "Identity", {"val"}, {{"T", DT_FLOAT}}, {"assign_op"}}});

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* resource =
        ops::SourceOp("VarHandleOp", builder.opts()
                                         .WithName("varhandle")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("shape", TensorShape({})));

    Tensor const_tensor(DT_FLOAT, TensorShape({}));
    const_tensor.scalar<float>()() = 42.0f;
    Node* value = ops::SourceOp("Const", builder.opts()
                                             .WithName("const")
                                             .WithAttr("value", const_tensor)
                                             .WithAttr("dtype", DT_FLOAT));

    Node* call = ops::BinaryOp("FnWithResourceOp", resource, value,
                               builder.opts().WithName("A"));
    Node* tanh0 = ops::UnaryOp("Tanh", call, builder.opts().WithName("tanh0"));
    Node* tanh1 = ops::UnaryOp("Tanh", tanh0, builder.opts().WithName("tanh1"));
    ops::UnaryOp("Tanh", tanh1, builder.opts().WithName("tanh2"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  testing::FindNodeByName(graph.get(), "A")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh0")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh1")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh2")
      ->set_assigned_device_name(xla_cpu_device);

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["A"], "");
}

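// Builds a gradient function for a unary cwise op out of `nodes`, defaulting
// each node's "T" attribute to DT_FLOAT.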
static Status GradForUnaryCwise(FunctionDef* g,
                                std::vector<FunctionDefHelper::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", DT_FLOAT}};
    }
  }
  *g = FunctionDefHelper::Define(
      // Arg defs
      {"x: float", "dy: float"},
      // Ret val defs
      {"dx: float"},
      // Attr defs
      {},
      // Nodes
      nodes);
  return OkStatus();
}

// A gradient containing only supported operators
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Supported", SupportedGrad);

// A gradient containing an unsupported operator.
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);

TEST(XlaCompilationTest, SymbolicGradients) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));

    // Builds a Symbolic gradient for Supported
    NodeBuilder b_builder("B", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList b_name_attr;
    b_name_attr.set_name("Supported");
    b_builder.Attr("f", b_name_attr);
    b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    b_builder.Attr("Tout", {DT_FLOAT});
    b_builder.Input({a, a});
    Node* b = builder.opts().FinalizeBuilder(&b_builder);

    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));

    // Builds a Symbolic gradient for Unsupported
    NodeBuilder d_builder("D", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList d_name_attr;
    d_name_attr.set_name("Unsupported");
    d_builder.Attr("f", d_name_attr);
    d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    d_builder.Attr("Tout", {DT_FLOAT});
    d_builder.Input({c, c});
    builder.opts().FinalizeBuilder(&d_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, Loops) {
  // Regression test for b/32350199, where the autoclustering code introduced a
  // deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
  auto c = ops::Add(root.WithOpName("C"), a, b);
  auto enter = ops::internal::Enter(root, c, "aframe");
  auto next_iter = ops::NextIteration(root, enter);
  auto exit = ops::internal::Exit(root, next_iter);
  auto d = ops::Add(root.WithOpName("D"), c, exit);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled. In particular, 'd' and 'c' must not be
  // compiled.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, the global jit level overrides the scope annotations, so all
  // three nodes are clustered together despite their differing scopes.
  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, we cannot fuse anything, and there are no clusters.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* c = ops::BinaryOp("MatMul", a, b,
                            builder.opts()
                                .WithName("C")
                                .WithAttr(kXlaCompileAttr, true)
                                .WithAttr(kXlaScopeAttr, "Scope2"));
    ops::BinaryOp("Add", b, c,
                  builder.opts()
                      .WithName("D")
                      .WithAttr(kXlaCompileAttr, true)
                      .WithAttr(kXlaScopeAttr, "Scope2"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: D = relu(A) + (A @ relu(A))
  // where A and relu(A) are in Scope1, and the @ and + ops are in Scope2.
  // In this case, we can fuse A with relu(A), and we can fuse the second half
  // of the operations; there are two clusters.
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_NE(clusters["A"], clusters["C"]);
  EXPECT_EQ(clusters["C"], clusters["D"]);
}

TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C has no scope.
  // A cannot join B's cluster because their scopes differ, but B and C can be
  // clustered together.
  EXPECT_EQ(3, clusters.size());
  EXPECT_NE(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["B"], clusters["C"]);
}

TEST(XlaCompilationTest, DontClusterNodesWithMismatchingDeadness) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
  Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);

  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
  ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);

  Output tanh_a0 = ops::Tanh(root.WithOpName("tan_a0"), switch_a.output_true);
  Output tanh_a1 = ops::Tanh(root.WithOpName("tan_a1"), tanh_a0);

  Output tanh_b0 = ops::Tanh(root.WithOpName("tan_b0"), switch_b.output_true);
  Output tanh_b1 = ops::Tanh(root.WithOpName("tan_b1"), tanh_b0);

  Output add = ops::Add(root.WithOpName("add"), tanh_a1, tanh_b1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["tan_a0"], "");
  EXPECT_NE(clusters["tan_a1"], "");
  EXPECT_NE(clusters["tan_b0"], "");
  EXPECT_NE(clusters["tan_b1"], "");

  EXPECT_EQ(clusters["tan_a0"], clusters["tan_a1"]);
  EXPECT_EQ(clusters["tan_b0"], clusters["tan_b1"]);

  EXPECT_NE(clusters["tan_a0"], clusters["tan_b0"]);
}

TEST(XlaCompilationTest, ClusterNodesWithMismatchingInputDeadness) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
  Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);

  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
  ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);

  Output add_a = ops::Add(root.WithOpName("add_a"), switch_a.output_true,
                          switch_b.output_true);
  Output add_b = ops::Add(root.WithOpName("add_b"), switch_a.output_true,
                          switch_b.output_true);
  Output add = ops::Add(root.WithOpName("add_c"), add_a, add_b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["add_a"], "");
  EXPECT_NE(clusters["add_b"], "");
  EXPECT_NE(clusters["add_c"], "");

  EXPECT_EQ(clusters["add_a"], clusters["add_b"]);
  EXPECT_EQ(clusters["add_b"], clusters["add_c"]);
}

namespace {
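// Adds a resource variable "Var<id>" and returns the node that reads it.  The
// VarHandleOp node is optionally returned through `var_handle_op`.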
Node* MakeRead(const Scope& scope, const string& id,
               Node** var_handle_op = nullptr) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output read =
      ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
  if (var_handle_op) {
    *var_handle_op = var_handle.node();
  }
  return read.node();
}

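// Adds a resource variable "Var<id>" and returns the AssignVariableOp node
// that writes a constant to it.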
Node* MakeWrite(const Scope& scope, const string& id) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output value_to_write =
      ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
  ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
                                  var_handle, value_to_write);
  return assign_op.operation.node();
}

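// Returns a constant node that neither reads nor writes a resource variable.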
Node* MakeNeutral(const Scope& scope, const string& id) {
  return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
}
}  // namespace

TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(read, write);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 1);
  std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
                                                  "ValueToAssignW"};
  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
}

TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(write, read);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 0);
}

TEST(XlaCompilationTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  root.graph()->AddControlEdge(write_0, neutral_0);
  root.graph()->AddControlEdge(neutral_0, read_0);
  root.graph()->AddControlEdge(read_0, write_1);
  root.graph()->AddControlEdge(write_1, neutral_1);
  root.graph()->AddControlEdge(neutral_1, read_1);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::vector<string> cluster_names;
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph, &cluster_names);

  ASSERT_EQ(cluster_sets.size(), 1);

  std::vector<string> expected_clustered_nodes_a = {
      "AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"};
  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}

TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
      NodeDefBuilder builder(name, "NoOp");
      NodeDef def;
      TF_CHECK_OK(builder.Finalize(&def));

      Status status;
      Node* node = graph->AddNode(def, &status);
      TF_CHECK_OK(status);
      return node;
    };

    Node* a = BuildNoopNode("a", graph.get());
    Node* b = BuildNoopNode("b", graph.get());
    Node* c = BuildNoopNode("c", graph.get());
    graph->AddControlEdge(a, b);
    graph->AddControlEdge(b, c);
    graph->AddControlEdge(c, a);
  }

  TF_EXPECT_OK(root.ToGraph(graph.get()));

  Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Edge from c to a would create a cycle.\n"
                                "+-> a\n"
                                "|   b\n"
                                "+-- c\n"));
}

TEST(XlaCompilationTest, Retval) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::UnaryOp("_Retval", b,
                 builder.opts()
                     .WithName("R")
                     .WithAttr("T", DT_FLOAT)
                     .WithAttr("index", 0));

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, DontCountIdentityOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Identity(root.WithOpName("B"), a);
    auto c = ops::Identity(root.WithOpName("C"), b);
    auto r = ops::_Retval(root.WithOpName("R"), c, 0);
  }
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, ConstOp) {
  // valid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), 0.5f);
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_EQ(1, GetClusters(*graph).size());
  }

  // invalid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), string("string"));
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_TRUE(GetClusters(*graph).empty());
  }
}

TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output add = ops::Add(root.WithOpName("add"), neg, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name}, {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output identity = ops::Negate(root.WithOpName("identity"), neg);
  Output add = ops::Add(root.WithOpName("add"), identity, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name},
       {"identity", cluster_name},
       {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b",
                             "sender", 0, "receiver");
  Output const_a = ops::Const(root.WithOpName("const_a"), 42);

  ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a"));
  ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b"));
  root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
  // it won't be clustered.  ctrl_trigger_b is okay to cluster but we don't
  // cluster it because of b/118970344.
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, RandomShape) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
                            ops::Const(root.WithOpName("minval"), 1),
                            ops::Const(root.WithOpName("maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["shape"], "");
}

TEST(XlaCompilationTest, RandomShapeWithFunc) {
  Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();

  FunctionDefLibrary flib_def;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/"Stateful_func", /*in_def=*/{},
      /*out_def=*/{"out: int32"},
      /*attr_def*/
      {}, /*node_def=*/
      {FunctionDefHelper::Const("shape_shape", 2),
       FunctionDefHelper::Const("minval", 1),
       FunctionDefHelper::Const("maxval", 20),
       {{"shape"},
        "RandomUniformInt",
        {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
        {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
      /*ret_def=*/{{"out", "shape:output:0"}});

  func.mutable_signature()->set_is_stateful(true);
  *flib_def.add_function() = std::move(func);
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
  NodeDef call_node;
  call_node.set_name("fn_call");
  call_node.set_op("Stateful_func");
  Status status;
  Node* call = root.graph()->AddNode(call_node, &status);
  TF_ASSERT_OK(status);

  Output shape = Output(call, 0);
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  auto fld = std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
                                                          flib_def);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["fn_call"], "");
}

TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape =
      ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
                            ops::Const(root.WithOpName("test/minval"), 1),
                            ops::Const(root.WithOpName("test/maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/shape_rng"], "");
  EXPECT_EQ(clusters["test/reshape"], "");
}

TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
                                DT_INT32);
  Output zero = ops::Const(root.WithOpName("test/zero"), 0);
  ops::TensorArrayWrite tensor_array_write(
      root.WithOpName("test/write"), tensor_array.handle, zero,
      ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
  Output tensor_array_read =
      ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
                           zero, tensor_array_write.flow_out, DT_INT32);
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"),
                   ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
                   tensor_array_read);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/read"], "");
  EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodes) {
  // MatMulCombined below takes data from nodes on GPU0 and GPU1 and is placed
  // on GPU1. However, it should not be clustered with the previous node on
  // GPU1, because that will serialize production of its inputs that should be
  // done in parallel.
  //
  // This graph is:
  // (Const0, Const0) -> MatMul0
  // (Const1, Const1) -> MatMul1
  // (MatMul0, MatMul1) -> MatMulCombined
  //
  // Device0: [Const0, Const0, MatMul0]
  // Device1: [Const1, Const1, MatMul1, MatMulCombined]
  //
  // Cluster0: [Const0, Const0, MatMul0]
  // Cluster1: [Const1, Const1, MatMul1]
  // Cluster2: [MatMulCombined]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_dev1"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) {
  // This is similar to the 'DontClusterMergingNodes' above, except
  // MatMulCombined is placed on the CPU.
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 = "/job:worker/replica:0/task:0/device:GPU:0";
  absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:GPU:1";
  absl::string_view xla_cpu_dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_cpu"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) {
      n->set_assigned_device_name(string(xla_cpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

// TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
  // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
  // on GPU0. However, it should not be clustered with the next node on
  // GPU0, because that will prevent the node on GPU1 from beginning its work as
  // soon as the data has been produced.
  //
  // This graph is:
  // (Const0, Const0) -> MatMulSource
  // MatMulSource -> (MatMul0, MatMul1)
  //
  // Device0: [Const0, Const1, MatMulSource, MatMul0]
  // Device1: [MatMul1]
  //
  // Cluster0: [Const0, Const1, MatMulSource]
  // Cluster1: [MatMul0]
  // Cluster2: [MatMul1]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
  Output matmul_source =
      ops::MatMul(root.WithOpName("MatMulSource_dev0"), a, a);

  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), matmul_source,
                               matmul_source);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), matmul_source,
                               matmul_source);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]);

1231   // Improved heuristics should probably prevent this.
1232   EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
1233 }
1234 
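// Stateful random ops (RandomUniform) may be clustered when the nodes are
// explicitly placed on an XLA device.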
TEST(XlaCompilationTest,ClusterStatefulRandomOpOnXlaDevice)1235 TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
1236   absl::string_view xla_cpu_device =
1237       "/job:worker/replica:0/task:0/device:XLA_CPU:0";
1238 
1239   Scope root = Scope::NewRootScope().ExitOnError();
1240   Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
1241   Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
1242   Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
1243   Output c = ops::Add(root.WithOpName("test/c"), a, b);
1244 
1245   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1246   TF_ASSERT_OK(root.ToGraph(graph.get()));
1247 
1248   for (Node* n : graph->nodes()) {
1249     if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
1250       n->set_assigned_device_name(string(xla_cpu_device));
1251     }
1252   }
1253   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1254 
1255   std::unordered_map<string, string> clusters = GetClusters(*graph);
1256   EXPECT_NE(clusters["test/a"], "");
1257   EXPECT_NE(clusters["test/b"], "");
1258   EXPECT_NE(clusters["test/c"], "");
1259 }
1260 
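// Without an explicit XLA device assignment, stateful random ops must not be
// auto-clustered.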
TEST(XlaCompilationTest,DontAutoClusterStatefulRandomOp)1261 TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
1262   Scope root = Scope::NewRootScope().ExitOnError();
1263   Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
1264   Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
1265   Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
1266   Output c = ops::Add(root.WithOpName("test/c"), a, b);
1267 
1268   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1269   TF_ASSERT_OK(root.ToGraph(graph.get()));
1270 
1271   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1272 
1273   std::unordered_map<string, string> clusters = GetClusters(*graph);
1274   EXPECT_EQ(clusters["test/a"], "");
1275   EXPECT_EQ(clusters["test/b"], "");
1276 }
1277 
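// "Dummy" ops such as CheckNumerics and Assert are clustered when the nodes
// are placed on an XLA device.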
TEST(XlaCompilationTest,ClusterDummyOpsOnXlaDevice)1278 TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
1279   absl::string_view xla_cpu_device =
1280       "/job:worker/replica:0/task:0/device:XLA_CPU:0";
1281 
1282   Scope root = Scope::NewRootScope().ExitOnError();
1283   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1284   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
1285   Output check =
1286       ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
1287   Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
1288   Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});
1289 
1290   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1291   TF_ASSERT_OK(root.ToGraph(graph.get()));
1292 
1293   for (Node* n : graph->nodes()) {
1294     if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
1295       n->set_assigned_device_name(string(xla_cpu_device));
1296     }
1297   }
1298   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1299 
1300   std::unordered_map<string, string> clusters = GetClusters(*graph);
1301   EXPECT_NE(clusters["test/check"], "");
1302   EXPECT_NE(clusters["test/greaterequal"], "");
1303   EXPECT_NE(clusters["test/assert"], "");
1304 }
1305 
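// Without an XLA device assignment, CheckNumerics and Assert must not be
// auto-clustered.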
TEST(XlaCompilationTest,DontAutoClusterDummyOps)1306 TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
1307   Scope root = Scope::NewRootScope().ExitOnError();
1308   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1309   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
1310   Output check =
1311       ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
1312   Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
1313   Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});
1314 
1315   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1316   TF_ASSERT_OK(root.ToGraph(graph.get()));
1317 
1318   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1319 
1320   std::unordered_map<string, string> clusters = GetClusters(*graph);
1321   EXPECT_EQ(clusters["test/assert"], "");
1322   EXPECT_EQ(clusters["test/check"], "");
1323 }
1324 
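// Ops producing DT_VARIANT outputs (e.g. TensorListReserve) must not be
// auto-clustered.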
TEST(XlaCompilationTest,DontAutoClusterOpsProducingVariant)1325 TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
1326   Scope root = Scope::NewRootScope().ExitOnError();
1327   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
1328   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
1329 
1330   Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
1331   Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
1332 
1333   Output tensor_list_reserve = ops::TensorListReserve(
1334       root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
1335 
1336   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1337   TF_ASSERT_OK(root.ToGraph(graph.get()));
1338 
1339   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1340 
1341   std::unordered_map<string, string> clusters = GetClusters(*graph);
1342   EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
1343 }
1344 
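// Ops consuming DT_VARIANT inputs (e.g. TensorListElementShape) must not be
// auto-clustered.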
TEST(XlaCompilationTest,DontAutoClusterOpsConsumingVariant)1345 TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
1346   Scope root = Scope::NewRootScope().ExitOnError();
1347   Output dummy_input =
1348       ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
1349   Output variant_input =
1350       ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);
1351 
1352   // Create one more node so that we don't avoid creating a cluster solely
1353   // because it would be trivial.
1354   Output dummy_cast =
1355       ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);
1356 
1357   Output tensor_list_element_shape = ops::TensorListElementShape(
1358       root.WithOpName("test/tensor_list_element_shape"), variant_input,
1359       DT_INT32);
1360 
1361   root.graph()->AddControlEdge(dummy_cast.node(),
1362                                tensor_list_element_shape.node());
1363 
1364   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1365   TF_ASSERT_OK(root.ToGraph(graph.get()));
1366 
1367   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1368 
1369   std::unordered_map<string, string> clusters = GetClusters(*graph);
1370   EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
1371 }
1372 
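// Variant-producing ops may still be clustered when the nodes are explicitly
// placed on an XLA device.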
TEST(XlaCompilationTest,ClusterOpsProducingVariantIfOnXlaDevice)1373 TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
1374   Scope root = Scope::NewRootScope().ExitOnError();
1375   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
1376   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
1377 
1378   Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
1379   Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
1380 
1381   Output tensor_list_reserve = ops::TensorListReserve(
1382       root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
1383 
1384   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1385   TF_ASSERT_OK(root.ToGraph(graph.get()));
1386 
1387   string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
1388   for (Node* n : graph->nodes()) {
1389     if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
1390       n->set_assigned_device_name(xla_cpu_device);
1391     }
1392   }
1393 
1394   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1395 
1396   std::unordered_map<string, string> clusters = GetClusters(*graph);
1397   EXPECT_NE(clusters["test/tensor_list_reserve"], "");
1398 }
1399 
1400 const char* kCPU0 = "/job:worker/replica:0/task:0/device:CPU:0";
1401 const char* kGPU0 = "/job:worker/replica:0/task:0/device:GPU:0";
1402 const char* kXLA_GPU0 = "/job:worker/replica:0/task:0/device:XLA_GPU:0";
1403 const char* kGPU1 = "/job:worker/replica:0/task:0/device:GPU:1";
1404 
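// Nodes assigned to the CPU and a single GPU may be merged into one combined
// cluster.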
TEST(XlaCompilationTest,CreateCombinedCpuGpuClusters)1405 TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) {
1406   Scope root = Scope::NewRootScope().ExitOnError();
1407   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1408   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
1409 
1410   Output x = ops::Add(root.WithOpName("test/x"), a, b);
1411   Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
1412   Output z = ops::Add(root.WithOpName("test/z"), x, y);
1413 
1414   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1415   TF_ASSERT_OK(root.ToGraph(graph.get()));
1416 
1417   FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
1418   FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
1419   FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);
1420 
1421   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1422 
1423   std::unordered_map<string, string> clusters = GetClusters(*graph);
1424 
1425   EXPECT_NE(clusters["test/x"], "");
1426 
1427   EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
1428   EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
1429 }
1430 
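// Nodes assigned to two different GPUs must not be merged into one cluster.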
TEST(XlaCompilationTest,DontCreateGpu0AndGpu1Clusters)1431 TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) {
1432   Scope root = Scope::NewRootScope().ExitOnError();
1433   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1434   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
1435 
1436   Output x = ops::Add(root.WithOpName("test/x"), a, b);
1437   Output y = ops::Add(root.WithOpName("test/y"), x, x);
1438 
1439   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1440   TF_ASSERT_OK(root.ToGraph(graph.get()));
1441 
1442   FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
1443   FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU1);
1444 
1445   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1446 
1447   std::unordered_map<string, string> clusters = GetClusters(*graph);
1448 
1449   EXPECT_EQ(clusters["test/x"], "");
1450   EXPECT_EQ(clusters["test/y"], "");
1451 }
1452 
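// A node on CPU and a node on an XLA_GPU device must not be combined into one
// cluster.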
TEST(XlaCompilationTest,DontCreateCombinedCpuUnknownClusters)1453 TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) {
1454   Scope root = Scope::NewRootScope().ExitOnError();
1455   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1456   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
1457 
1458   Output x = ops::Add(root.WithOpName("test/x"), a, b);
1459   Output y = ops::Add(root.WithOpName("test/y"), x, x);
1460 
1461   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1462   TF_ASSERT_OK(root.ToGraph(graph.get()));
1463 
1464   FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kCPU0);
1465   FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kXLA_GPU0);
1466 
1467   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1468 
1469   std::unordered_map<string, string> clusters = GetClusters(*graph);
1470 
1471   EXPECT_EQ(clusters["test/x"], "");
1472   EXPECT_EQ(clusters["test/y"], "");
1473 }
1474 
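// A resource read (on GPU) and its consumer (on CPU) end up in the same
// cluster when that combination is safe.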
TEST(XlaCompilationTest,ClusterResourceOpsWhenSafe)1475 TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) {
1476   Scope root = Scope::NewRootScope().ExitOnError();
1477   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1478   Node* var_handle;
1479   Node* resource_read = MakeRead(root, "read", &var_handle);
1480   Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);
1481 
1482   string resource_read_name = resource_read->name();
1483   string var_handle_name = var_handle->name();
1484 
1485   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1486   TF_ASSERT_OK(root.ToGraph(graph.get()));
1487 
1488   FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kCPU0);
1489   FindNodeByName(graph.get(), resource_read_name)
1490       ->set_assigned_device_name(kGPU0);
1491   FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kGPU0);
1492 
1493   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1494 
1495   std::unordered_map<string, string> clusters = GetClusters(*graph);
1496 
1497   EXPECT_NE(clusters["test/b"], "");
1498   EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]);
1499 }
1500 
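// With the placements reversed, the combination is unsafe and neither the
// resource read nor its consumer is clustered.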
TEST(XlaCompilationTest,DontClusterResourceOpsWhenUnsafe)1501 TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
1502   Scope root = Scope::NewRootScope().ExitOnError();
1503   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1504   Node* var_handle;
1505   Node* resource_read = MakeRead(root, "read", &var_handle);
1506   Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);
1507 
1508   string resource_read_name = resource_read->name();
1509   string var_handle_name = var_handle->name();
1510 
1511   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1512   TF_ASSERT_OK(root.ToGraph(graph.get()));
1513 
1514   FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kGPU0);
1515   FindNodeByName(graph.get(), resource_read_name)
1516       ->set_assigned_device_name(kCPU0);
1517   FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kCPU0);
1518 
1519   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1520 
1521   std::unordered_map<string, string> clusters = GetClusters(*graph);
1522 
1523   EXPECT_EQ(clusters["test/b"], "");
1524   EXPECT_EQ(clusters[resource_read_name], "");
1525 }
1526 
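// Nodes carrying a _scoped_allocator attribute are excluded from clustering.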
TEST(XlaCompilationTest,DontClusterNodesWithScopedAllocatorAttr)1527 TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) {
1528   Scope root = Scope::NewRootScope().ExitOnError();
1529   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1530   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
1531 
1532   Output x = ops::Add(root.WithOpName("test/x"), a, b);
1533   Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
1534   Output z = ops::Add(root.WithOpName("test/z"), x, y);
1535 
1536   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1537   TF_ASSERT_OK(root.ToGraph(graph.get()));
1538 
1539   FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
1540   FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
1541   FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);
1542 
1543   std::vector<int> scoped_allocator_value;
1544   scoped_allocator_value.push_back(0);
1545   scoped_allocator_value.push_back(155);
1546   FindNodeByName(graph.get(), "test/z")
1547       ->AddAttr("_scoped_allocator", scoped_allocator_value);
1548 
1549   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1550 
1551   std::unordered_map<string, string> clusters = GetClusters(*graph);
1552 
1553   EXPECT_EQ(clusters["test/z"], "");
1554 }
1555 
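// Nodes carrying a _forward_from attribute are excluded from clustering.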
TEST(XlaCompilationTest,DontClusterNodesWithForwardFromAttr)1556 TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) {
1557   Scope root = Scope::NewRootScope().ExitOnError();
1558   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1559   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
1560 
1561   Output x = ops::Add(root.WithOpName("test/x"), a, b);
1562   Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
1563   Output z = ops::Add(root.WithOpName("test/z"), x, y);
1564 
1565   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1566   TF_ASSERT_OK(root.ToGraph(graph.get()));
1567 
1568   FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
1569   FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
1570   FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);
1571 
1572   FindNodeByName(graph.get(), "test/z")->AddAttr("_forward_from", 0);
1573 
1574   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1575 
1576   std::unordered_map<string, string> clusters = GetClusters(*graph);
1577 
1578   EXPECT_EQ(clusters["test/z"], "");
1579 }
1580 
1581 // Note: this test relies on other implementation details to exercise the
1582 // specific heuristic we care about here, so unrelated changes might be at
1583 // fault if this test breaks. What we care about is that if a shape-consuming
1584 // op can be connected with a producer or a consumer but cannot be clustered
1585 // with both, it should be clustered with the producer.
TEST(XlaCompilationTest,ClusterShapeConsumerWithProducer)1586 TEST(XlaCompilationTest, ClusterShapeConsumerWithProducer) {
1587   Scope root = Scope::NewRootScope().ExitOnError();
1588   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1589   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
1590 
1591   Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
1592   Output y = ops::Size(root.WithOpName("test/y"), x);
1593   Output z = ops::Add(root.WithOpName("test/z"), y, y);
1594 
1595   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1596   TF_ASSERT_OK(root.ToGraph(graph.get()));
1597 
1598   // Ensure that the "Size" op can only be clustered with either the producer or
1599   // consumer by putting them on different devices.
1600   FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
1601   FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
1602   FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU1);
1603 
1604   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1605 
1606   std::unordered_map<string, string> clusters = GetClusters(*graph);
1607 
1608   EXPECT_NE(clusters["test/y"], "");
1609   EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
1610   EXPECT_NE(clusters["test/z"], clusters["test/y"]);
1611 }
1612 
1613 // Test that shape-consuming ops are still fully clustered whenever possible.
TEST(XlaCompilationTest,ClusterShapeConsumerWithProducerAndConsumer)1614 TEST(XlaCompilationTest, ClusterShapeConsumerWithProducerAndConsumer) {
1615   Scope root = Scope::NewRootScope().ExitOnError();
1616   Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
1617   Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
1618 
1619   Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
1620   Output y = ops::Size(root.WithOpName("test/y"), x);
1621   Output z = ops::Add(root.WithOpName("test/z"), y, y);
1622 
1623   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1624   TF_ASSERT_OK(root.ToGraph(graph.get()));
1625 
1626   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1627 
1628   std::unordered_map<string, string> clusters = GetClusters(*graph);
1629 
1630   EXPECT_NE(clusters["test/y"], "");
1631   EXPECT_EQ(clusters["test/y"], clusters["test/x"]);
1632   EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
1633 }
1634 
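// Convenience overloads for adding a control edge between two nodes in
// `scope`'s graph.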
AddCtrlEdge(const Scope & scope,Operation a,Operation b)1635 void AddCtrlEdge(const Scope& scope, Operation a, Operation b) {
1636   scope.graph()->AddControlEdge(a.node(), b.node());
1637 }
1638 
AddCtrlEdge(const Scope & scope,Output a,Operation b)1639 void AddCtrlEdge(const Scope& scope, Output a, Operation b) {
1640   AddCtrlEdge(scope, a.op(), b);
1641 }
1642 
AddCtrlEdge(const Scope & scope,Operation a,Output b)1643 void AddCtrlEdge(const Scope& scope, Operation a, Output b) {
1644   AddCtrlEdge(scope, a, b.op());
1645 }
1646 
1647 // Tests that we pick a good clustering for graphs that have an integer
1648 // increment operation control dependent on gradient update operations.
TEST(XlaCompilationTest,IterationIncrementAndGroupDeps)1649 TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) {
1650   Scope scope = Scope::NewRootScope().ExitOnError();
1651 
1652   Output iter =
1653       ops::VarHandleOp(scope.WithOpName("iter"), DT_INT64, TensorShape({}));
1654   Output weights_0 = ops::VarHandleOp(scope.WithOpName("weights_0"), DT_FLOAT,
1655                                       TensorShape({1000}));
1656   Output weights_1 = ops::VarHandleOp(scope.WithOpName("weights_1"), DT_FLOAT,
1657                                       TensorShape({1000}));
1658 
1659   // We update the weights by adding delta to them (to "simulate" a
1660   // ResourceApplyGradientDescent and similar things).
1661   Output delta = ops::Placeholder(scope.WithOpName("delta"), DT_FLOAT);
1662 
1663   ops::AssignAddVariableOp increment_op(
1664       scope.WithOpName("IncrementIteration"), iter,
1665       ops::Const(scope.WithOpName("one"), static_cast<int64_t>(1)));
1666 
1667   ops::AssignAddVariableOp weights_0_update_op(
1668       scope.WithOpName("weights_0_update"), weights_0, delta);
1669   ops::AssignAddVariableOp weights_1_update_op(
1670       scope.WithOpName("weights_1_update"), weights_1, delta);
1671 
1672   ops::NoOp group_deps(scope.WithOpName("group_deps"));
1673 
1674   ops::NoOp some_ctrl_input(scope.WithOpName("some_ctrl_input"));
1675 
1676   Output matmul_input =
1677       ops::Placeholder(scope.WithOpName("matmul_input"), DT_FLOAT);
1678   Output matmul_0 =
1679       ops::MatMul(scope.WithOpName("matmul_0"), matmul_input, matmul_input);
1680   Output matmul_1 =
1681       ops::MatMul(scope.WithOpName("matmul_1"), matmul_input, matmul_input);
1682 
1683   AddCtrlEdge(scope, increment_op, group_deps);
1684   AddCtrlEdge(scope, weights_0_update_op, increment_op);
1685   AddCtrlEdge(scope, weights_1_update_op, increment_op);
1686 
1687   AddCtrlEdge(scope, some_ctrl_input, weights_0_update_op);
1688   AddCtrlEdge(scope, some_ctrl_input, weights_1_update_op);
1689 
1690   AddCtrlEdge(scope, matmul_0, group_deps);
1691   AddCtrlEdge(scope, matmul_1, group_deps);
1692 
1693   AddCtrlEdge(scope, weights_0_update_op, matmul_0);
1694   AddCtrlEdge(scope, weights_1_update_op, matmul_1);
1695 
1696   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1697   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1698 
1699   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1700 
1701   std::unordered_map<string, string> clusters = GetClusters(*graph);
1702 
1703   EXPECT_NE(clusters["some_ctrl_input"], "");
1704   EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_0_update"]);
1705   EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_1_update"]);
1706   EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]);
1707   EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_1"]);
1708 }
1709 
1710 // Test a pattern where a special Identity node is driving consts in a loop.
1711 // Expect that the Identity node will not go into any cluster.  Note that we
1712 // create an incomplete graph here (e.g., lacking Enter/Exit/NextIteration
1713 // nodes) that is just enough to test the pattern; a complete graph would be
1714 // too cumbersome and unnecessary.
TEST(XlaCompilationTest,DontClusterTheSpecialIdentityDrivingConstsInLoop)1715 TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) {
1716   Scope root = Scope::NewRootScope().ExitOnError();
1717 
1718   Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
1719   Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
1720   Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond);
1721   ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond);
1722 
1723   Output identity =
1724       ops::Identity(root.WithOpName("identity"), switch_node.output_true);
1725   Output const_node = ops::Const(root.WithOpName("const"), 1.0f);
1726   root.graph()->AddControlEdge(identity.node(), const_node.node());
1727   Output tanh0 = ops::Tanh(root.WithOpName("tanh0"), const_node);
1728   Output tanh1 = ops::Tanh(root.WithOpName("tanh1"), tanh0);
1729   Output add = ops::Add(root.WithOpName("add"), const_node, tanh1);
1730 
1731   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1732   TF_EXPECT_OK(root.ToGraph(graph.get()));
1733 
1734   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
1735       &graph,
1736       MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
1737   auto clusters = GetClusters(*graph);
1738 
1739   EXPECT_EQ(clusters["identity"], "");
1740 }
1741 
TEST(XlaCompilationTest,UnsupportedEnterExitPattern)1742 TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
1743   // Regression test for b/32350199, where the autoclustering code introduced a
1744   // deadlock in a graph containing a while loop.
1745   Scope root = Scope::NewRootScope().ExitOnError();
1746   auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
1747   auto enter_0 = ops::internal::Enter(root.WithOpName("enter_a"), a, "frame");
1748   auto exit_0 = ops::internal::Exit(root.WithOpName("exit_a"), enter_0);
1749   auto tanh = ops::Tanh(root.WithOpName("tanh"), exit_0);
1750   auto enter_1 =
1751       ops::internal::Enter(root.WithOpName("enter_1"), tanh, "frame");
1752   auto exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);
1753 
1754   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1755   TF_EXPECT_OK(root.ToGraph(graph.get()));
1756 
1757   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1758   auto clusters = GetClusters(*graph);
1759 
1760   // Nothing should be compiled.
1761   EXPECT_EQ(0, clusters.size());
1762 }
1763 
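// With deterministic cluster names enabled, structurally identical graphs get
// cluster names with the same fingerprint; only the trailing sequence number
// differs.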
TEST(XlaCompilationTest,DeterministicClusterNames)1764 TEST(XlaCompilationTest, DeterministicClusterNames) {
1765   auto create_graph =
1766       [](absl::string_view output_name) -> std::unique_ptr<Graph> {
1767     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1768     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
1769     Tensor t(DT_FLOAT, TensorShape());
1770     t.scalar<float>()() = 0.0f;
1771     Node* a = ops::SourceOp("Const", builder.opts()
1772                                          .WithName("A")
1773                                          .WithAttr("dtype", DT_FLOAT)
1774                                          .WithAttr("value", t));
1775     Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
1776     ops::BinaryOp("MatMul", a, b, builder.opts().WithName(output_name));
1777     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
1778     return graph;
1779   };
1780 
1781   // Checks that two cluster names match in all parts except their sequence
1782   // number, which must differ. Names are expected as: cluster_<fp>_<seq#>
1783   auto cluster_names_match = [](absl::string_view lhs_cluster_name,
1784                                 absl::string_view rhs_cluster_name) {
1785     std::vector<absl::string_view> lhs_cluster_name_parts =
1786         absl::StrSplit(lhs_cluster_name, '_');
1787     std::vector<absl::string_view> rhs_cluster_name_parts =
1788         absl::StrSplit(rhs_cluster_name, '_');
1789 
1790     if (lhs_cluster_name_parts.size() != 3) {
1791       return errors::FailedPrecondition("unexpected lhs cluster name: ",
1792                                         lhs_cluster_name);
1793     }
1794 
1795     if (rhs_cluster_name_parts.size() != 3) {
1796       return errors::FailedPrecondition("unexpected rhs cluster name: ",
1797                                         rhs_cluster_name);
1798     }
1799 
1800     if (lhs_cluster_name_parts[0] != rhs_cluster_name_parts[0] ||
1801         lhs_cluster_name_parts[1] != rhs_cluster_name_parts[1]) {
1802       return errors::FailedPrecondition(
1803           "Cluster names mismatch: lhs: ", lhs_cluster_name,
1804           " rhs: ", rhs_cluster_name);
1805     }
1806 
1807     if (lhs_cluster_name_parts[2] == rhs_cluster_name_parts[2]) {
1808       return errors::FailedPrecondition(
1809           "cluster sequence numbers are the same: lhs: ", lhs_cluster_name,
1810           " rhs: ", rhs_cluster_name);
1811     }
1812 
1813     return OkStatus();
1814   };
1815 
1816   testing::ResetClusterSequenceNumber();
1817   auto options = MarkForCompilationPassTestHelper::Options()
1818                      .WithDeterministicClusterNames();
1819 
1820   // Cluster the same graphs twice so we can observe that the prefix contains
1821   // the stable fingerprint.
1822   auto graph0 = create_graph("out");
1823   auto graph1 = create_graph("differs");
1824   auto graph2 = create_graph("out");      // same as graph0
1825   auto graph3 = create_graph("differs");  // same as graph1
1826 
1827   TF_ASSERT_OK(
1828       MarkForCompilationPassTestHelper::MarkForCompilation(&graph0, options));
1829   auto clusters0 = GetClusterNames(*graph0);
1830   ASSERT_EQ(clusters0.size(), 1);
1831 
1832   TF_ASSERT_OK(
1833       MarkForCompilationPassTestHelper::MarkForCompilation(&graph1, options));
1834   auto clusters1 = GetClusterNames(*graph1);
1835   ASSERT_EQ(clusters1.size(), 1);
1836 
1837   TF_ASSERT_OK(
1838       MarkForCompilationPassTestHelper::MarkForCompilation(&graph2, options));
1839   auto clusters2 = GetClusterNames(*graph2);
1840   ASSERT_EQ(clusters2.size(), 1);
1841 
1842   TF_ASSERT_OK(
1843       MarkForCompilationPassTestHelper::MarkForCompilation(&graph3, options));
1844   auto clusters3 = GetClusterNames(*graph3);
1845   ASSERT_EQ(clusters3.size(), 1);
1846 
1847   // clusters0 and clusters2 should be the same
1848   TF_EXPECT_OK(cluster_names_match(*clusters0.begin(), *clusters2.begin()));
1849 
1850   // clusters1 and clusters3 should also be the same
1851   TF_EXPECT_OK(cluster_names_match(*clusters1.begin(), *clusters3.begin()));
1852 
1853   // clusters0/2 should differ from clusters1/3
1854 }
1855 
1856 namespace {
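// Builds a "Stage" node named `name` that stages `values` with the given
// `dtypes`, or returns nullptr on error.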
MakeStageNode(GraphDefBuilder & builder,string name,std::initializer_list<DataType> dtypes,absl::Span<const ops::NodeOut> values)1857 Node* MakeStageNode(GraphDefBuilder& builder, string name,
1858                     std::initializer_list<DataType> dtypes,
1859                     absl::Span<const ops::NodeOut> values) {
1860   auto opts = builder.opts()
1861                   .WithName(std::move(name))
1862                   .WithAttr("dtypes", std::move(dtypes));
1863   if (opts.HaveError()) {
1864     return nullptr;
1865   }
1866 
1867   NodeBuilder node_builder(name, "Stage", opts.op_registry());
1868   node_builder.Input(values);
1869   return opts.FinalizeBuilder(&node_builder);
1870 }
1871 }  // namespace
1872 
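// Nodes in different pipeline stages (separated by Stage/Unstage) are kept in
// separate clusters unless cluster scoping is disabled.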
TEST(XlaCompilationTest,StagePipelinePreservedByClusterScopingPass)1873 TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
1874   auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
1875     // Construct a graph with two pipeline stages (shown below) and test that
1876     // nodes in different stages are not merged when ClusterScopingPass is on.
1877     //
1878     //       b
1879     //       |
1880     //       v
1881     // a -> add0 -> relu0 -> stage
1882     //
1883     //             b
1884     //             |
1885     //             v
1886     // unstage -> add1 -> relu1
1887     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
1888     Node* a = ops::SourceOp("Const", builder.opts()
1889                                          .WithName("a")
1890                                          .WithAttr("dtype", DT_FLOAT)
1891                                          .WithAttr("value", Tensor()));
1892     Node* b = ops::SourceOp("Const", builder.opts()
1893                                          .WithName("b")
1894                                          .WithAttr("dtype", DT_FLOAT)
1895                                          .WithAttr("value", Tensor()));
1896     Node* unstage = ops::SourceOp(
1897         "Unstage",
1898         builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));
1899 
1900     Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
1901     Node* add1 =
1902         ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
1903     Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
1904     ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
1905     MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});
1906 
1907     return GraphDefBuilderToGraph(builder, graph->get());
1908   };
1909 
1910   // All nodes go into the same cluster if ClusterScopingPass is off.
1911   {
1912     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1913     TF_ASSERT_OK(build_staged_graph(&graph));
1914 
1915     TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
1916         &graph,
1917         MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));
1918 
1919     std::unordered_map<string, string> clusters = GetClusters(*graph);
1920     EXPECT_EQ(clusters["add0"], clusters["add1"]);
1921     EXPECT_EQ(clusters["add0"], clusters["relu1"]);
1922     EXPECT_EQ(clusters["relu0"], clusters["add1"]);
1923     EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
1924   }
1925 
1926   // By default, ClusterScopingPass is on and different pipeline stages should
1927   // not be merged.
1928   {
1929     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1930     TF_ASSERT_OK(build_staged_graph(&graph));
1931 
1932     TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1933 
1934     std::unordered_map<string, string> clusters = GetClusters(*graph);
1935     EXPECT_NE(clusters["add0"], clusters["add1"]);
1936     EXPECT_NE(clusters["add0"], clusters["relu1"]);
1937     EXPECT_NE(clusters["relu0"], clusters["add1"]);
1938     EXPECT_NE(clusters["relu0"], clusters["relu1"]);
1939   }
1940 }
TEST(XlaCompilationTest,XLALiteAllowlist)1941 TEST(XlaCompilationTest, XLALiteAllowlist) {
1942   auto* allowlist_table = tensorflow::GetAllowlistTable();
1943   absl::flat_hash_set<string> allowlist;
1944   std::vector<string> all_ops_vec = XlaOpRegistry::GetAllRegisteredOps();
1945   absl::flat_hash_set<string> all_ops(all_ops_vec.begin(), all_ops_vec.end());
1946 
1947   // Check that all the operations in the table are existing TF operations.
1948   for (auto pair : *allowlist_table) {
1949     allowlist.insert(pair.second.begin(), pair.second.end());
1950     for (auto op : pair.second) {
1951       ASSERT_TRUE(all_ops.contains(op));
1952     }
1953   }
1954 
1955   // Check that all registered XLA operations are in the allowlist table or
1956   // are known not to be in it.
1957 
1958   absl::flat_hash_set<string> known_not_in_list =
1959       tensorflow::testing::GetKnownXLAAllowlistOp();
1960   std::vector<string> unknown_ops;
1961   for (string op : all_ops_vec) {
1962     if (!allowlist.contains(op) && !known_not_in_list.contains(op)) {
1963       unknown_ops.push_back(op);
1964     }
1965   }
1966   EXPECT_TRUE(unknown_ops.empty())
1967       << "Someone added support for new TF operations inside XLA. They must "
1968          "be included in the XLALite allowlist or denylist:\n"
1969       << absl::StrJoin(unknown_ops, "\n");
1970 }
1971 }  // namespace
1972 }  // namespace tensorflow
1973