/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"

#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

using ::tensorflow::testing::FindNodeByName;

namespace tensorflow {
namespace {

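// Static initializer that enables XLA devices (XLA_CPU/XLA_GPU) before the
// tests run; several tests below assign nodes to such devices explicitly.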
static bool Initialized = [] {
  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
  return true;
}();

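// Test-only ops with no XLA kernel; nodes using them cannot be clustered.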
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

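// Returns a map from node name to the cluster assigned to that node; nodes
// without a kXlaClusterAttr attribute are omitted.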
std::unordered_map<string, string> GetClusters(const Graph& graph) {
  std::unordered_map<string, string> ids;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) {
      CHECK(!cluster.empty());
      ids[node->name()] = cluster;
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Clusters:";
    for (const auto& p : ids) {
      VLOG(2) << " " << p.first << " -> " << p.second;
    }
  }
  return ids;
}

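// Returns the set of distinct cluster names appearing in the graph.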
std::set<string> GetClusterNames(const Graph& graph) {
  std::set<string> names;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) {
      CHECK(!cluster.empty());
      names.insert(cluster);
    }
  }
  return names;
}

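// Groups clustered node names by cluster, sorting each group; if
// `cluster_names` is non-null it also receives the sorted cluster names.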
absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
    const Graph& g, std::vector<string>* cluster_names = nullptr) {
  CHECK(cluster_names == nullptr || cluster_names->empty());
  absl::flat_hash_map<string, std::vector<string>> cluster_sets;
  for (const auto& p : GetClusters(g)) {
    cluster_sets[p.second].push_back(p.first);
  }
  for (auto& p : cluster_sets) {
    if (cluster_names != nullptr) {
      cluster_names->push_back(p.first);
    }
    std::sort(p.second.begin(), p.second.end());
  }
  if (cluster_names != nullptr) {
    std::sort(cluster_names->begin(), cluster_names->end());
  }
  return cluster_sets;
}

TEST(XlaCompilationTest, Chains) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    Node* d =
        ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
    Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_EQ(clusters["E"], clusters["F"]);
  EXPECT_NE(clusters["B"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, UncompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b =
        ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, CompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, StringUnsupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp(
        "Const", builder.opts()
                     .WithName("A")
                     .WithAttr("dtype", DT_STRING)
                     .WithAttr("value", Tensor(DT_STRING, TensorShape())));
    Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B"));
    ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, WhereUnsupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_INT32)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Where", a, builder.opts().WithName("B"));
    ops::BinaryOp("Gather", b, a, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(!clusters.empty());
}

TEST(XlaCompilationTest, HalfSupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Tensor t(DT_HALF, TensorShape());
    t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_HALF)
                                         .WithAttr("value", t));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_FALSE(clusters.empty());
}

// Tests that PartitionedCalls are only marked for compilation if every node
// inside the function can be compiled.
TEST(XlaCompilationTest, PartitionedCallUnsupported) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);

  NameAttrList b_name_attr;
  b_name_attr.set_name("CompilableFn");
  ops::PartitionedCall b(root.WithOpName("B"), {a, a}, {DT_FLOAT}, b_name_attr);
  NameAttrList c_name_attr;
  c_name_attr.set_name("UncompilableFn");

  ops::PartitionedCall c(root.WithOpName("C"), {a}, {DT_FLOAT}, c_name_attr);
  Output d = ops::Add(root.WithOpName("D"), b.output.front(), c.output.front());

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_TRUE(clusters["C"].empty());
  EXPECT_EQ(clusters["B"], clusters["D"]);
}

TEST(XlaCompilationTest, FunctionCalls) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
  FunctionDef noinline = compilable;
  noinline.mutable_signature()->set_name("NoInlineFn");
  AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  *flib.add_function() = noinline;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["C"].empty());
  EXPECT_EQ(clusters["C"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("B") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "FnWithResourceOp", {"var:resource", "val:float"}, {"retval:float"}, {},
      {{{"assign_op"},
        "AssignVariableOp",
        {"var", "val"},
        {{"dtype", DT_FLOAT}}},
       {{"retval"}, "Identity", {"val"}, {{"T", DT_FLOAT}}, {"assign_op"}}});

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* resource =
        ops::SourceOp("VarHandleOp", builder.opts()
                                         .WithName("varhandle")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("shape", TensorShape({})));

    Tensor const_tensor(DT_FLOAT, TensorShape({}));
    const_tensor.scalar<float>()() = 42.0f;
    Node* value = ops::SourceOp("Const", builder.opts()
                                             .WithName("const")
                                             .WithAttr("value", const_tensor)
                                             .WithAttr("dtype", DT_FLOAT));

    Node* call = ops::BinaryOp("FnWithResourceOp", resource, value,
                               builder.opts().WithName("A"));
    Node* tanh0 = ops::UnaryOp("Tanh", call, builder.opts().WithName("tanh0"));
    Node* tanh1 = ops::UnaryOp("Tanh", tanh0, builder.opts().WithName("tanh1"));
    ops::UnaryOp("Tanh", tanh1, builder.opts().WithName("tanh2"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  testing::FindNodeByName(graph.get(), "A")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh0")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh1")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh2")
      ->set_assigned_device_name(xla_cpu_device);

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["A"], "");
}

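// Builds a FunctionDef for a unary cwise gradient from `nodes`, defaulting
// each node's T attribute to DT_FLOAT when it is unset.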
static Status GradForUnaryCwise(FunctionDef* g,
                                std::vector<FunctionDefHelper::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", DT_FLOAT}};
    }
  }
  *g = FunctionDefHelper::Define(
      // Arg defs
      {"x: float", "dy: float"},
      // Ret val defs
      {"dx: float"},
      // Attr defs
      {},
      // Nodes
      nodes);
  return OkStatus();
}

// A gradient containing only supported operators
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Supported", SupportedGrad);

// A gradient containing an unsupported operator.
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);

TEST(XlaCompilationTest, SymbolicGradients) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));

    // Builds a Symbolic gradient for Supported
    NodeBuilder b_builder("B", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList b_name_attr;
    b_name_attr.set_name("Supported");
    b_builder.Attr("f", b_name_attr);
    b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    b_builder.Attr("Tout", {DT_FLOAT});
    b_builder.Input({a, a});
    Node* b = builder.opts().FinalizeBuilder(&b_builder);

    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));

    // Builds a Symbolic gradient for Unsupported
    NodeBuilder d_builder("D", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList d_name_attr;
    d_name_attr.set_name("Unsupported");
    d_builder.Attr("f", d_name_attr);
    d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    d_builder.Attr("Tout", {DT_FLOAT});
    d_builder.Input({c, c});
    builder.opts().FinalizeBuilder(&d_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, Loops) {
  // Regression test for b/32350199, where the autoclustering code introduced a
  // deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
  auto c = ops::Add(root.WithOpName("C"), a, b);
  auto enter = ops::internal::Enter(root, c, "aframe");
  auto next_iter = ops::NextIteration(root, enter);
  auto exit = ops::internal::Exit(root, next_iter);
  auto d = ops::Add(root.WithOpName("D"), c, exit);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled. In particular, 'd' and 'c' must not be
  // compiled.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, the GlobalJitLevel overrides the scope annotations, so all
  // three nodes are clustered together.
  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, we cannot fuse anything, and there are no clusters.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* c = ops::BinaryOp("MatMul", a, b,
                            builder.opts()
                                .WithName("C")
                                .WithAttr(kXlaCompileAttr, true)
                                .WithAttr(kXlaScopeAttr, "Scope2"));
    ops::BinaryOp("Add", b, c,
                  builder.opts()
                      .WithName("D")
                      .WithAttr(kXlaCompileAttr, true)
                      .WithAttr(kXlaScopeAttr, "Scope2"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: D = relu(A) + (A @ relu(A))
  // where A and relu(A) are in Scope1, and the @ and + ops are in Scope2.
  // In this case, we can fuse A with relu(A), and we can fuse the second half
  // of the operations; there are two clusters.
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_NE(clusters["A"], clusters["C"]);
  EXPECT_EQ(clusters["C"], clusters["D"]);
}

TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C has no scope.
  // In this case A cannot be fused with B, but B and C end up in the same
  // cluster.
  EXPECT_EQ(3, clusters.size());
  EXPECT_NE(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["B"], clusters["C"]);
}

TEST(XlaCompilationTest, DontClusterNodesWithMismatchingDeadness) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
  Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);

  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
  ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);

  Output tanh_a0 = ops::Tanh(root.WithOpName("tan_a0"), switch_a.output_true);
  Output tanh_a1 = ops::Tanh(root.WithOpName("tan_a1"), tanh_a0);

  Output tanh_b0 = ops::Tanh(root.WithOpName("tan_b0"), switch_b.output_true);
  Output tanh_b1 = ops::Tanh(root.WithOpName("tan_b1"), tanh_b0);

  Output add = ops::Add(root.WithOpName("add"), tanh_a1, tanh_b1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["tan_a0"], "");
  EXPECT_NE(clusters["tan_a1"], "");
  EXPECT_NE(clusters["tan_b0"], "");
  EXPECT_NE(clusters["tan_b1"], "");

  EXPECT_EQ(clusters["tan_a0"], clusters["tan_a1"]);
  EXPECT_EQ(clusters["tan_b0"], clusters["tan_b1"]);

  EXPECT_NE(clusters["tan_a0"], clusters["tan_b0"]);
}

TEST(XlaCompilationTest, ClusterNodesWithMismatchingInputDeadness) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
  Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);

  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
  ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);

  Output add_a = ops::Add(root.WithOpName("add_a"), switch_a.output_true,
                          switch_b.output_true);
  Output add_b = ops::Add(root.WithOpName("add_b"), switch_a.output_true,
                          switch_b.output_true);
  Output add = ops::Add(root.WithOpName("add_c"), add_a, add_b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["add_a"], "");
  EXPECT_NE(clusters["add_b"], "");
  EXPECT_NE(clusters["add_c"], "");

  EXPECT_EQ(clusters["add_a"], clusters["add_b"]);
  EXPECT_EQ(clusters["add_b"], clusters["add_c"]);
}

namespace {
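// Helpers that create resource-variable read, write, and resource-neutral
// (Const) nodes for the resource-operation clustering tests below.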
Node* MakeRead(const Scope& scope, const string& id,
               Node** var_handle_op = nullptr) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output read =
      ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
  if (var_handle_op) {
    *var_handle_op = var_handle.node();
  }
  return read.node();
}

Node* MakeWrite(const Scope& scope, const string& id) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output value_to_write =
      ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
  ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
                                  var_handle, value_to_write);
  return assign_op.operation.node();
}

Node* MakeNeutral(const Scope& scope, const string& id) {
  return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
}
}  // namespace

TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(read, write);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 1);
  std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
                                                  "ValueToAssignW"};
  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
}

TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(write, read);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 0);
}

TEST(XlaCompilationTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  root.graph()->AddControlEdge(write_0, neutral_0);
  root.graph()->AddControlEdge(neutral_0, read_0);
  root.graph()->AddControlEdge(read_0, write_1);
  root.graph()->AddControlEdge(write_1, neutral_1);
  root.graph()->AddControlEdge(neutral_1, read_1);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::vector<string> cluster_names;
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph, &cluster_names);

  ASSERT_EQ(cluster_sets.size(), 1);

  std::vector<string> expected_clustered_nodes_a = {
      "AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"};
  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}

TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
      NodeDefBuilder builder(name, "NoOp");
      NodeDef def;
      TF_CHECK_OK(builder.Finalize(&def));

      Status status;
      Node* node = graph->AddNode(def, &status);
      TF_CHECK_OK(status);
      return node;
    };

    Node* a = BuildNoopNode("a", graph.get());
    Node* b = BuildNoopNode("b", graph.get());
    Node* c = BuildNoopNode("c", graph.get());
    graph->AddControlEdge(a, b);
    graph->AddControlEdge(b, c);
    graph->AddControlEdge(c, a);
  }

  TF_EXPECT_OK(root.ToGraph(graph.get()));

  Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Edge from c to a would create a cycle.\n"
                                "+-> a\n"
                                "|   b\n"
                                "+-- c\n"));
}

TEST(XlaCompilationTest, Retval) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::UnaryOp("_Retval", b,
                 builder.opts()
                     .WithName("R")
                     .WithAttr("T", DT_FLOAT)
                     .WithAttr("index", 0));

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, DontCountIdentityOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Identity(root.WithOpName("B"), a);
    auto c = ops::Identity(root.WithOpName("C"), b);
    auto r = ops::_Retval(root.WithOpName("R"), c, 0);
  }
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, ConstOp) {
  // valid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), 0.5f);
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_EQ(1, GetClusters(*graph).size());
  }

  // invalid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), string("string"));
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_TRUE(GetClusters(*graph).empty());
  }
}

TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output add = ops::Add(root.WithOpName("add"), neg, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name}, {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output identity = ops::Negate(root.WithOpName("identity"), neg);
  Output add = ops::Add(root.WithOpName("add"), identity, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name},
       {"identity", cluster_name},
       {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b",
                             "sender", 0, "receiver");
  Output const_a = ops::Const(root.WithOpName("const_a"), 42);

  ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a"));
  ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b"));
  root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
  // it won't be clustered. ctrl_trigger_b is okay to cluster but we don't
  // cluster it because of b/118970344.
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, RandomShape) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
                            ops::Const(root.WithOpName("minval"), 1),
                            ops::Const(root.WithOpName("maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["shape"], "");
}

TEST(XlaCompilationTest, RandomShapeWithFunc) {
  Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();

  FunctionDefLibrary flib_def;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/"Stateful_func", /*in_def=*/{},
      /*out_def=*/{"out: int32"},
      /*attr_def*/
      {}, /*node_def=*/
      {FunctionDefHelper::Const("shape_shape", 2),
       FunctionDefHelper::Const("minval", 1),
       FunctionDefHelper::Const("maxval", 20),
       {{"shape"},
        "RandomUniformInt",
        {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
        {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
      /*ret_def=*/{{"out", "shape:output:0"}});

  func.mutable_signature()->set_is_stateful(true);
  *flib_def.add_function() = std::move(func);
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
  NodeDef call_node;
  call_node.set_name("fn_call");
  call_node.set_op("Stateful_func");
  Status status;
  Node* call = root.graph()->AddNode(call_node, &status);
  TF_ASSERT_OK(status);

  Output shape = Output(call, 0);
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  auto fld = std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
                                                         flib_def);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["fn_call"], "");
}

TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape =
      ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
                            ops::Const(root.WithOpName("test/minval"), 1),
                            ops::Const(root.WithOpName("test/maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/shape_rng"], "");
  EXPECT_EQ(clusters["test/reshape"], "");
}

TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
                                DT_INT32);
  Output zero = ops::Const(root.WithOpName("test/zero"), 0);
  ops::TensorArrayWrite tensor_array_write(
      root.WithOpName("test/write"), tensor_array.handle, zero,
      ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
  Output tensor_array_read =
      ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
                           zero, tensor_array_write.flow_out, DT_INT32);
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"),
                   ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
                   tensor_array_read);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/read"], "");
  EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodes) {
  // MatMulCombined below takes data from nodes on GPU0 and GPU1 and is placed
  // on GPU1. However, it should not be clustered with the previous node on
  // GPU1, because that will serialize production of its inputs that should be
  // done in parallel.
  //
  // This graph is:
  // (Const0, Const0) -> MatMul0
  // (Const1, Const1) -> MatMul1
  // (MatMul0, MatMul1) -> MatMulCombined
  //
  // Device0: [Const0, Const0, MatMul0]
  // Device1: [Const1, Const1, MatMul1, MatMulCombined]
  //
  // Cluster0: [Const0, Const0, MatMul0]
  // Cluster1: [Const1, Const1, MatMul1]
  // Cluster2: [MatMulCombined]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_dev1"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) {
  // This is similar to the 'DontClusterMergingNodes' above, except
  // MatMulCombined is placed on the CPU.
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 = "/job:worker/replica:0/task:0/device:GPU:0";
  absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:GPU:1";
  absl::string_view xla_cpu_dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_cpu"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) {
      n->set_assigned_device_name(string(xla_cpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

// TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
  // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
  // on GPU0. However, it should not be clustered with the next node on
  // GPU0, because that will prevent the node on GPU1 from beginning its work as
  // soon as the data has been produced.
  //
  // This graph is:
  // (Const0, Const0) -> MatMulSource
  // MatMulSource -> (MatMul0, MatMul1)
  //
  // Device0: [Const0, Const1, MatMulSource, MatMul0]
  // Device1: [MatMul1]
  //
  // Cluster0: [Const0, Const1, MatMulSource]
  // Cluster1: [MatMul0]
  // Cluster2: [MatMul1]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
  Output matmul_source =
      ops::MatMul(root.WithOpName("MatMulSource_dev0"), a, a);

  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), matmul_source,
                               matmul_source);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), matmul_source,
                               matmul_source);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]);

  // Improved heuristics should probably prevent this clustering.
  EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
}

TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/a"], "");
  EXPECT_NE(clusters["test/b"], "");
  EXPECT_NE(clusters["test/c"], "");
}

TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/a"], "");
  EXPECT_EQ(clusters["test/b"], "");
}

TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/check"], "");
  EXPECT_NE(clusters["test/greaterequal"], "");
  EXPECT_NE(clusters["test/assert"], "");
}

TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/assert"], "");
  EXPECT_EQ(clusters["test/check"], "");
}

TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
}

TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output dummy_input =
      ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
  Output variant_input =
      ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);

  // Create one more node so that we don't avoid creating a cluster solely
  // because it would be trivial.
  Output dummy_cast =
      ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);

  Output tensor_list_element_shape = ops::TensorListElementShape(
      root.WithOpName("test/tensor_list_element_shape"), variant_input,
      DT_INT32);

  root.graph()->AddControlEdge(dummy_cast.node(),
                               tensor_list_element_shape.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
}
1372
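// In contrast, ops producing DT_VARIANT values may be clustered when they are
// explicitly placed on an XLA device.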
TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(xla_cpu_device);
    }
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/tensor_list_reserve"], "");
}

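// Device names shared by the placement-sensitive tests below.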
const char* kCPU0 = "/job:worker/replica:0/task:0/device:CPU:0";
const char* kGPU0 = "/job:worker/replica:0/task:0/device:GPU:0";
const char* kXLA_GPU0 = "/job:worker/replica:0/task:0/device:XLA_GPU:0";
const char* kGPU1 = "/job:worker/replica:0/task:0/device:GPU:1";

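// Ops placed on the CPU may be pulled into a cluster anchored on a GPU device,
// forming a single combined CPU/GPU cluster.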
TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/x"], "");

  EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
  EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}

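// Ops assigned to two different GPU devices must not be clustered together;
// here neither node ends up in a cluster.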
TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

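// An op on a regular CPU device and an op on an XLA_GPU device must not be
// combined into one cluster.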
TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kXLA_GPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

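// Reading a resource variable on the GPU and consuming the value on the CPU is
// safe, so the read and its consumer may share a cluster.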
TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/b"], "");
  EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]);
}

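// The reverse placement (a CPU resource read feeding a GPU op) is unsafe, so
// neither node should be clustered.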
TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kCPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/b"], "");
  EXPECT_EQ(clusters[resource_read_name], "");
}

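// Nodes carrying a _scoped_allocator attribute must not be placed in a
// cluster.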
TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  std::vector<int> scoped_allocator_value;
  scoped_allocator_value.push_back(0);
  scoped_allocator_value.push_back(155);
  FindNodeByName(graph.get(), "test/z")
      ->AddAttr("_scoped_allocator", scoped_allocator_value);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/z"], "");
}

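// Nodes carrying a _forward_from attribute must not be placed in a cluster.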
TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  FindNodeByName(graph.get(), "test/z")->AddAttr("_forward_from", 0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/z"], "");
}

// Note: this relies on other implementation details to exercise the specific
// heuristic we care about here, so unrelated changes may be at fault if this
// test breaks. What we care about is that if a shape-consuming op (e.g. Size)
// can be clustered with either its producer or its consumer but not with both,
// it should be clustered with the producer.
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducer) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
  Output y = ops::Size(root.WithOpName("test/y"), x);
  Output z = ops::Add(root.WithOpName("test/z"), y, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Ensure that the "Size" op can only be clustered with either the producer
  // or consumer by putting them on different devices.
  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/y"], "");
  EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
  EXPECT_NE(clusters["test/z"], clusters["test/y"]);
}

// Test that ShapeConsuming ops are still fully clustered whenever possible.
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducerAndConsumer) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
  Output y = ops::Size(root.WithOpName("test/y"), x);
  Output z = ops::Add(root.WithOpName("test/z"), y, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/y"], "");
  EXPECT_EQ(clusters["test/y"], clusters["test/x"]);
  EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}

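// Helper overloads for adding a control edge between two nodes in the scope's
// graph, accepting either Operation or Output handles on each side.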
void AddCtrlEdge(const Scope& scope, Operation a, Operation b) {
  scope.graph()->AddControlEdge(a.node(), b.node());
}

void AddCtrlEdge(const Scope& scope, Output a, Operation b) {
  AddCtrlEdge(scope, a.op(), b);
}

void AddCtrlEdge(const Scope& scope, Operation a, Output b) {
  AddCtrlEdge(scope, a, b.op());
}

// Tests that we pick a good clustering for graphs that have an integer
// increment operation control dependent on gradient update operations.
TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output iter =
      ops::VarHandleOp(scope.WithOpName("iter"), DT_INT64, TensorShape({}));
  Output weights_0 = ops::VarHandleOp(scope.WithOpName("weights_0"), DT_FLOAT,
                                      TensorShape({1000}));
  Output weights_1 = ops::VarHandleOp(scope.WithOpName("weights_1"), DT_FLOAT,
                                      TensorShape({1000}));

  // We update the weights by adding delta to them (to "simulate" a
  // ResourceApplyGradientDescent and similar things).
  Output delta = ops::Placeholder(scope.WithOpName("delta"), DT_FLOAT);

  ops::AssignAddVariableOp increment_op(
      scope.WithOpName("IncrementIteration"), iter,
      ops::Const(scope.WithOpName("one"), static_cast<int64_t>(1)));

  ops::AssignAddVariableOp weights_0_update_op(
      scope.WithOpName("weights_0_update"), weights_0, delta);
  ops::AssignAddVariableOp weights_1_update_op(
      scope.WithOpName("weights_1_update"), weights_1, delta);

  ops::NoOp group_deps(scope.WithOpName("group_deps"));

  ops::NoOp some_ctrl_input(scope.WithOpName("some_ctrl_input"));

  Output matmul_input =
      ops::Placeholder(scope.WithOpName("matmul_input"), DT_FLOAT);
  Output matmul_0 =
      ops::MatMul(scope.WithOpName("matmul_0"), matmul_input, matmul_input);
  Output matmul_1 =
      ops::MatMul(scope.WithOpName("matmul_1"), matmul_input, matmul_input);

  AddCtrlEdge(scope, increment_op, group_deps);
  AddCtrlEdge(scope, weights_0_update_op, increment_op);
  AddCtrlEdge(scope, weights_1_update_op, increment_op);

  AddCtrlEdge(scope, some_ctrl_input, weights_0_update_op);
  AddCtrlEdge(scope, some_ctrl_input, weights_1_update_op);

  AddCtrlEdge(scope, matmul_0, group_deps);
  AddCtrlEdge(scope, matmul_1, group_deps);

  AddCtrlEdge(scope, weights_0_update_op, matmul_0);
  AddCtrlEdge(scope, weights_1_update_op, matmul_1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["some_ctrl_input"], "");
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_0_update"]);
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_1_update"]);
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]);
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_1"]);
}

// Test a pattern where a special Identity node drives Consts in a loop.
// Expect that the Identity node will not go into any cluster. Note that we
// create an incomplete graph here (e.g., lacking Enter/Exit/NextIteration),
// just enough to exercise the pattern, since a complete graph would be
// cumbersome and unnecessary.
TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond);
  ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond);

  Output identity =
      ops::Identity(root.WithOpName("identity"), switch_node.output_true);
  Output const_node = ops::Const(root.WithOpName("const"), 1.0f);
  root.graph()->AddControlEdge(identity.node(), const_node.node());
  Output tanh0 = ops::Tanh(root.WithOpName("tanh0"), const_node);
  Output tanh1 = ops::Tanh(root.WithOpName("tanh1"), tanh0);
  Output add = ops::Add(root.WithOpName("add"), const_node, tanh1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["identity"], "");
}

TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
  // Regression test for b/32350199, where the autoclustering code introduced
  // a deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto enter_0 = ops::internal::Enter(root.WithOpName("enter_a"), a, "frame");
  auto exit_0 = ops::internal::Exit(root.WithOpName("exit_a"), enter_0);
  auto tanh = ops::Tanh(root.WithOpName("tanh"), exit_0);
  auto enter_1 =
      ops::internal::Enter(root.WithOpName("enter_1"), tanh, "frame");
  auto exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled.
  EXPECT_EQ(0, clusters.size());
}

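// With deterministic cluster names enabled, clustering the same graph twice
// should yield names of the form cluster_<fingerprint>_<sequence#> whose
// fingerprint part is stable, while the sequence number differs between runs.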
TEST(XlaCompilationTest, DeterministicClusterNames) {
  auto create_graph =
      [](absl::string_view output_name) -> std::unique_ptr<Graph> {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Tensor t(DT_FLOAT, TensorShape());
    t.scalar<float>()() = 0.0f;
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", t));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName(output_name));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
    return graph;
  };

  // Checks whether two cluster names match in all parts except their sequence
  // number. Names are expected in the form cluster_<fingerprint>_<sequence#>.
  auto cluster_names_match = [](absl::string_view lhs_cluster_name,
                                absl::string_view rhs_cluster_name) {
    std::vector<absl::string_view> lhs_cluster_name_parts =
        absl::StrSplit(lhs_cluster_name, '_');
    std::vector<absl::string_view> rhs_cluster_name_parts =
        absl::StrSplit(rhs_cluster_name, '_');

    if (lhs_cluster_name_parts.size() != 3) {
      return errors::FailedPrecondition("unexpected lhs cluster name: ",
                                        lhs_cluster_name);
    }

    if (rhs_cluster_name_parts.size() != 3) {
      return errors::FailedPrecondition("unexpected rhs cluster name: ",
                                        rhs_cluster_name);
    }

    if (lhs_cluster_name_parts[0] != rhs_cluster_name_parts[0] ||
        lhs_cluster_name_parts[1] != rhs_cluster_name_parts[1]) {
      return errors::FailedPrecondition(
          "Cluster names mismatch: lhs: ", lhs_cluster_name,
          " rhs: ", rhs_cluster_name);
    }

    if (lhs_cluster_name_parts[2] == rhs_cluster_name_parts[2]) {
      return errors::FailedPrecondition(
          "cluster sequence numbers are the same: lhs: ", lhs_cluster_name,
          " rhs: ", rhs_cluster_name);
    }

    return OkStatus();
  };

  testing::ResetClusterSequenceNumber();
  auto options = MarkForCompilationPassTestHelper::Options()
                     .WithDeterministicClusterNames();

  // Cluster the same graphs twice so we can observe that the prefix contains
  // the stable fingerprint.
  auto graph0 = create_graph("out");
  auto graph1 = create_graph("differs");
  auto graph2 = create_graph("out");      // same as graph0
  auto graph3 = create_graph("differs");  // same as graph1

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph0, options));
  auto clusters0 = GetClusterNames(*graph0);
  ASSERT_EQ(clusters0.size(), 1);

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph1, options));
  auto clusters1 = GetClusterNames(*graph1);
  ASSERT_EQ(clusters1.size(), 1);

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph2, options));
  auto clusters2 = GetClusterNames(*graph2);
  ASSERT_EQ(clusters2.size(), 1);

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph3, options));
  auto clusters3 = GetClusterNames(*graph3);
  ASSERT_EQ(clusters3.size(), 1);

  // clusters0 and clusters2 should be the same.
  TF_EXPECT_OK(cluster_names_match(*clusters0.begin(), *clusters2.begin()));

  // clusters1 and clusters3 should also be the same.
  TF_EXPECT_OK(cluster_names_match(*clusters1.begin(), *clusters3.begin()));

  // clusters0/2 should differ from clusters1/3.
}

namespace {
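// Builds a "Stage" node named `name` that stages `values` with the given
// `dtypes`, or returns nullptr if the builder is already in an error state.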
Node* MakeStageNode(GraphDefBuilder& builder, string name,
                    std::initializer_list<DataType> dtypes,
                    absl::Span<const ops::NodeOut> values) {
  // `name` is used again below for the NodeBuilder, so do not move from it.
  auto opts = builder.opts()
                  .WithName(name)
                  .WithAttr("dtypes", std::move(dtypes));
  if (opts.HaveError()) {
    return nullptr;
  }

  NodeBuilder node_builder(name, "Stage", opts.op_registry());
  node_builder.Input(values);
  return opts.FinalizeBuilder(&node_builder);
}
}  // namespace

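// Nodes on opposite sides of a Stage/Unstage boundary form separate pipeline
// stages and should only be merged into a single cluster when cluster scoping
// is explicitly disabled.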
TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
  auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
    // Construct a graph as below with two pipeline stages and test that nodes
    // in different stages will not be merged if ClusterScopingPass is on.
    //
    //       b
    //       |
    //       v
    // a -> add0 -> relu0 -> stage
    //
    //             b
    //             |
    //             v
    // unstage -> add1 -> relu1
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("a")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::SourceOp("Const", builder.opts()
                                         .WithName("b")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* unstage = ops::SourceOp(
        "Unstage",
        builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));

    Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
    Node* add1 =
        ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
    Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
    ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
    MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});

    return GraphDefBuilderToGraph(builder, graph->get());
  };

  // All nodes go into the same cluster if ClusterScopingPass is off.
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(build_staged_graph(&graph));

    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
        &graph,
        MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));

    std::unordered_map<string, string> clusters = GetClusters(*graph);
    EXPECT_EQ(clusters["add0"], clusters["add1"]);
    EXPECT_EQ(clusters["add0"], clusters["relu1"]);
    EXPECT_EQ(clusters["relu0"], clusters["add1"]);
    EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
  }

  // By default, ClusterScopingPass is on and different pipeline stages should
  // not be merged.
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(build_staged_graph(&graph));

    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

    std::unordered_map<string, string> clusters = GetClusters(*graph);
    EXPECT_NE(clusters["add0"], clusters["add1"]);
    EXPECT_NE(clusters["add0"], clusters["relu1"]);
    EXPECT_NE(clusters["relu0"], clusters["add1"]);
    EXPECT_NE(clusters["relu0"], clusters["relu1"]);
  }
}
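
// Consistency check between the XLALite allowlist table and the ops actually
// registered with XLA: every allowlisted op must exist, and every registered
// op must be either allowlisted or explicitly known not to be.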
TEST(XlaCompilationTest, XLALiteAllowlist) {
  auto* allowlist_table = tensorflow::GetAllowlistTable();
  absl::flat_hash_set<string> allowlisted_ops;
  std::vector<string> registered_ops = XlaOpRegistry::GetAllRegisteredOps();
  absl::flat_hash_set<string> all_ops(registered_ops.begin(),
                                      registered_ops.end());

  // Check that every operation in the allowlist table is a registered TF
  // operation.
  for (const auto& pair : *allowlist_table) {
    allowlisted_ops.insert(pair.second.begin(), pair.second.end());
    for (const auto& op : pair.second) {
      ASSERT_TRUE(all_ops.contains(op));
    }
  }

  // Check that every registered XLA operation is either in the allowlist
  // table or known not to be in it.
  absl::flat_hash_set<string> known_not_in_list =
      tensorflow::testing::GetKnownXLAAllowlistOp();
  std::vector<string> unknown_ops;
  for (const string& op : registered_ops) {
    if (!allowlisted_ops.contains(op) && !known_not_in_list.contains(op)) {
      unknown_ops.push_back(op);
    }
  }
  EXPECT_TRUE(unknown_ops.empty())
      << "Someone added support for new TF operations inside XLA. They must "
         "be included in the XLALite allowlist or denylist:\n"
      << absl::StrJoin(unknown_ops, "\n");
}
}  // namespace
}  // namespace tensorflow