1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/compilability_check_util.h"
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/cc/framework/scope.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/functional_ops.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
25 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph_to_functiondef.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/graph/graph_def_builder.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/platform/test.h"
33
34 namespace tensorflow {
35 namespace {
36
FuncListAttr(const absl::Span<const char * const> names)37 AttrValue FuncListAttr(const absl::Span<const char* const> names) {
38 AttrValue attr;
39 for (const char* name : names) {
40 attr.mutable_list()->add_func()->set_name(name);
41 }
42 return attr;
43 }
44
// Names given to the functional control-flow nodes built by the tests below;
// the tests later locate those nodes in the graph by these names.
constexpr char kFunctionalIfNodeName[] = "If";
constexpr char kFunctionalCaseNodeName[] = "Case";
constexpr char kFunctionalWhileNodeName[] = "While";
// Function-library entry that the While test defines with a single "Add"
// node, followed by the name of that node.
constexpr char kCompilableFunctionName[] = "CompilableFn";
constexpr char kCompilableFunctionNodeName[] = "n_c";
// Function-library entries that the tests define with a single
// "MissingKernel" node each, paired with the name of that node.
constexpr char kUncompilableFunctionName[] = "UncompilableFn";
constexpr char kUncompilableFunctionNodeName[] = "n_c_uncompilable";
constexpr char kUncompilableFunctionTwoName[] = "UncompilableFnTwo";
constexpr char kUncompilableFunctionNodeTwoName[] = "n_d_uncompilable";
54
55 // A dummy OpKernel for testing.
56 class DummyCompilableOp : public XlaOpKernel {
57 public:
DummyCompilableOp(OpKernelConstruction * ctx)58 explicit DummyCompilableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)59 void Compile(XlaOpKernelContext* ctx) override {
60 ctx->SetOutput(0, ctx->Input(0));
61 }
62 };
63
// Op definitions used by the tests: "InputFloatOp" has no inputs and a single
// float output; "CompilableOp" takes one float input and yields one float
// output.
REGISTER_OP("InputFloatOp").Output("o: float");
REGISTER_OP("CompilableOp").Input("i: float").Output("o: float");
// Register DummyCompilableOp as the XLA kernel for "CompilableOp" on the
// DEVICE_CPU_XLA_JIT device.
REGISTER_XLA_OP(Name("CompilableOp").Device(DEVICE_CPU_XLA_JIT),
                DummyCompilableOp);

// Dummy op for which this file registers no XLA kernel; the tests use it as
// the op that is uncompilable on the CPU JIT device.
REGISTER_OP("MissingKernel").Input("i: float").Output("o: float");
72
73 class CompilabilityCheckUtilTest : public ::testing::Test {
74 protected:
SetUp()75 void SetUp() override {
76 XlaOpRegistry::RegisterCompilationKernels();
77
78 op_filter_.allow_resource_ops_in_called_functions = false;
79 op_filter_.allow_stack_ops = false;
80 op_filter_.allow_tensor_array_ops = false;
81 op_filter_.allow_stateful_rng_ops = false;
82 op_filter_.allow_control_trigger = false;
83 op_filter_.allow_eliding_assert_and_checknumerics_ops = false;
84 op_filter_.allow_ops_producing_or_consuming_variant = false;
85 op_filter_.allow_inaccurate_ops = false;
86 op_filter_.allow_slow_ops = false;
87 op_filter_.allow_outside_compiled = false;
88
89 checker_ = CreateCompilabilityChecker();
90 }
91
CreateCompilabilityChecker()92 std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
93 return std::make_unique<RecursiveCompilabilityChecker>(op_filter_,
94 device_type_);
95 }
96
GetFunctionLibraryRuntime()97 FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
98 OptimizerOptions opts;
99 pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
100 nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
101 flib_def_.get(), opts);
102
103 return pflr_->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
104 }
105
106 RecursiveCompilabilityChecker::OperationFilter op_filter_;
107 DeviceType device_type_ = DeviceType(DEVICE_CPU_XLA_JIT);
108 std::unique_ptr<FunctionDefLibrary> func_library_ =
109 std::make_unique<FunctionDefLibrary>();
110 std::unique_ptr<FunctionLibraryDefinition> flib_def_ =
111 std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
112 *func_library_);
113 std::unique_ptr<RecursiveCompilabilityChecker> checker_;
114 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
115 };
116
// Chains InputFloatOp -> CompilableOp -> MissingKernel and checks each node's
// compilability, then inspects the report produced for the MissingKernel node.
TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) {
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
  const auto& options = builder.opts();
  Node* source = ops::SourceOp("InputFloatOp", options);
  Node* supported = ops::UnaryOp("CompilableOp", source, options);
  Node* unsupported = ops::UnaryOp("MissingKernel", supported, options);
  GraphDef graph_def;
  TF_EXPECT_OK(builder.ToGraphDef(&graph_def));

  FunctionLibraryRuntime* flr = GetFunctionLibraryRuntime();

  // The source node is not compilable, the op with a registered XLA kernel
  // is, and the op without one is not (only the CPU JIT device is checked).
  EXPECT_FALSE(checker_->IsCompilableNode(*source, flr));
  EXPECT_TRUE(checker_->IsCompilableNode(*supported, flr));
  EXPECT_FALSE(checker_->IsCompilableNode(*unsupported, flr));

  const auto report = checker_->FindUncompilableNodes(*unsupported, flr);
  ASSERT_EQ(1, report.size());

  // A node outside any function is keyed by a default NameAttrList.
  const NameAttrList no_function;
  auto entry = report.find(no_function.ShortDebugString());
  ASSERT_NE(report.end(), entry);

  const auto& offenders = entry->second.second;
  ASSERT_EQ(1, offenders.size());
  const auto& offender = offenders.front();
  EXPECT_TRUE(
      absl::StrContains(offender.uncompilable_reason, "unsupported op"));
  ASSERT_EQ(1, offender.stack_trace.size());
  ASSERT_EQ("", offender.stack_trace.front().function_name);
}
149
// Checks that a node carrying the "_xla_outside_compilation" attribute is
// reported as uncompilable under the fixture's filter, and is accepted once
// `allow_outside_compiled` is enabled.
TEST_F(CompilabilityCheckUtilTest, CheckOutsideCompiledNode) {
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
  const auto& options = builder.opts();
  Node* source = ops::SourceOp("InputFloatOp", options);
  Node* outside_compiled_op = ops::UnaryOp("MissingKernel", source, options);
  outside_compiled_op->AddAttr("_xla_outside_compilation", "0");
  GraphDef graph_def;
  TF_EXPECT_OK(builder.ToGraphDef(&graph_def));

  FunctionLibraryRuntime* flr = GetFunctionLibraryRuntime();

  // With the fixture's filter, outside-compiled ops are still checked.
  EXPECT_FALSE(checker_->IsCompilableNode(*outside_compiled_op, flr));
  const auto rejected_by_default =
      checker_->FindUncompilableNodes(*outside_compiled_op, flr);
  ASSERT_EQ(1, rejected_by_default.size());

  // Rebuild the checker with the allowance on: the same node is now skipped
  // and treated as compilable.
  op_filter_.allow_outside_compiled = true;
  checker_ = CreateCompilabilityChecker();
  EXPECT_TRUE(checker_->IsCompilableNode(*outside_compiled_op, flr));
  const auto rejected_when_allowed =
      checker_->FindUncompilableNodes(*outside_compiled_op, flr);
  ASSERT_EQ(0, rejected_when_allowed.size());
}
178
// Calls UncompilableFn (a function wrapping a single MissingKernel node) from
// a node named "D" and verifies the report names the inner node with a
// two-frame stack trace: the call site first, then the node in the function.
TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_c_uncompilable:float"},
      /*Attributes*/ {},
      // Node info
      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
  flib_def_ =
      std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), flib);

  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, flib_def_.get());
  auto graph = std::make_unique<Graph>(flib_def_.get());
  Node* source = ops::SourceOp("InputFloatOp", builder.opts());
  Node* call_site = ops::UnaryOp(kUncompilableFunctionName, source,
                                 builder.opts().WithName("D"));
  TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));

  FunctionLibraryRuntime* flr = GetFunctionLibraryRuntime();
  EXPECT_FALSE(checker_->IsCompilableNode(*call_site, flr));
  const auto report = checker_->FindUncompilableNodes(*call_site, flr);
  EXPECT_EQ(1, report.size());

  // The report is keyed by the debug string of the enclosing function.
  NameAttrList enclosing_function;
  enclosing_function.set_name(kUncompilableFunctionName);
  const auto entry = report.find(enclosing_function.ShortDebugString());
  ASSERT_NE(report.end(), entry);

  const auto& offenders = entry->second.second;
  ASSERT_EQ(1, offenders.size());
  const auto& offender = offenders.front();

  const auto& frames = offender.stack_trace;
  ASSERT_EQ(2, frames.size());
  EXPECT_EQ("D", frames[0].name);
  EXPECT_EQ(kUncompilableFunctionNodeName, frames[1].name);

  EXPECT_EQ(kUncompilableFunctionNodeName, offender.name);
  EXPECT_TRUE(
      absl::StrContains(offender.uncompilable_reason, "unsupported op"));
}
219
// Builds a "While" node whose cond is CompilableFn and whose body is
// UncompilableFn, and expects the checker to reject it while reporting only
// the MissingKernel node of the body, with a two-frame stack trace (the While
// node, then the node inside UncompilableFn).
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kCompilableFunctionName,
      /*Inputs*/ {"n_a:float", "n_b:float"},
      /*Outputs*/ {"n_c:float"},
      /*Attribute*/ {},
      // Node info
      {{{kCompilableFunctionNodeName},
        "Add",
        {"n_a", "n_b"},
        {{"T", DT_FLOAT}}}});
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_c_uncompilable:float"},
      /*Attributes*/ {},
      // Node info
      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});

  flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, flib_def_.get());

  Node* const0 = ops::SourceOp("InputFloatOp", builder.opts());
  Node* input_node = ops::UnaryOp("CompilableOp", const0, builder.opts());

  NameAttrList compilable;
  compilable.set_name(kCompilableFunctionName);
  NameAttrList uncompilable;
  uncompilable.set_name(kUncompilableFunctionName);

  NodeBuilder while_builder(kFunctionalWhileNodeName, "While",
                            builder.opts().op_registry());
  while_builder.Input({input_node, input_node})
      .Attr("cond", compilable)
      .Attr("body", uncompilable);
  builder.opts().FinalizeBuilder(&while_builder);

  GraphDef graph_def;
  TF_EXPECT_OK(builder.ToGraphDef(&graph_def));
  std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
  TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));

  auto while_node_it = std::find_if(
      graph->nodes().begin(), graph->nodes().end(),
      [&](const Node* n) { return n->name() == kFunctionalWhileNodeName; });
  // Must be fatal: the iterator is dereferenced below, which would be invalid
  // if the While node had not been found.
  ASSERT_NE(while_node_it, graph->nodes().end());

  auto* flib_runtime = GetFunctionLibraryRuntime();

  EXPECT_FALSE(checker_->IsCompilableNode(**while_node_it, flib_runtime));
  const auto uncompilable_nodes =
      checker_->FindUncompilableNodes(**while_node_it, flib_runtime);
  ASSERT_EQ(1, uncompilable_nodes.size());

  // The report is keyed by the debug string of the enclosing function.
  NameAttrList function;
  function.set_name(kUncompilableFunctionName);
  const auto node_info_it =
      uncompilable_nodes.find(function.ShortDebugString());
  ASSERT_NE(uncompilable_nodes.end(), node_info_it);
  const auto& uncompilable_node_list = node_info_it->second.second;
  ASSERT_EQ(1, uncompilable_node_list.size());
  const auto& node_info = uncompilable_node_list.at(0);

  const auto& node_stack = node_info.stack_trace;
  ASSERT_EQ(2, node_stack.size());
  // First frame: the While node itself, which lives outside any function.
  const auto& stacktrace_first_node_info = node_stack.at(0);
  EXPECT_EQ(kFunctionalWhileNodeName, stacktrace_first_node_info.name);
  EXPECT_EQ("", stacktrace_first_node_info.function_name);

  // Second frame: the MissingKernel node inside the body function.
  const auto& stacktrace_second_node_info = node_stack.at(1);
  EXPECT_EQ(kUncompilableFunctionNodeName, stacktrace_second_node_info.name);
  EXPECT_EQ(kUncompilableFunctionName,
            stacktrace_second_node_info.function_name);

  EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
  EXPECT_TRUE(
      absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}
299
// Builds an "If" node whose then/else branches are UncompilableFn and
// UncompilableFnTwo, and expects the checker to reject it with one report
// entry per branch function, each holding that branch's MissingKernel node and
// a two-frame stack trace (the If node, then the node inside the branch).
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_c_uncompilable:float"},
      /*Attributes*/ {},
      // Node info
      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionTwoName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_d_uncompilable:float"},
      /*Attribute*/ {},
      // Node info
      {{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
  NameAttrList uncompilable_fn1_attr;
  uncompilable_fn1_attr.set_name(kUncompilableFunctionName);
  NameAttrList uncompilable_fn2_attr;
  uncompilable_fn2_attr.set_name(kUncompilableFunctionTwoName);

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
  auto predicate = ops::Placeholder(root.WithOpName("pred"), DT_BOOL);
  auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  std::vector<NodeBuilder::NodeOut> if_inputs(
      {NodeBuilder::NodeOut(placeholder.node())});
  Node* if_node = nullptr;
  TF_ASSERT_OK(
      NodeBuilder(kFunctionalIfNodeName, "If", &root.graph()->flib_def())
          .Input(predicate.node())
          .Input(if_inputs)
          .Attr("then_branch", uncompilable_fn1_attr)
          .Attr("else_branch", uncompilable_fn2_attr)
          .Attr("Tout", {DT_INT32})
          .Finalize(root.graph(), &if_node));
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));

  auto if_node_it = std::find_if(
      graph->nodes().begin(), graph->nodes().end(),
      [&](const Node* n) { return n->name() == kFunctionalIfNodeName; });
  // Must be fatal: the iterator is dereferenced below, which would be invalid
  // if the If node had not been found.
  ASSERT_NE(if_node_it, graph->nodes().end());
  auto* flib_runtime = GetFunctionLibraryRuntime();

  EXPECT_FALSE(checker_->IsCompilableNode(**if_node_it, flib_runtime));
  const auto uncompilable_nodes =
      checker_->FindUncompilableNodes(**if_node_it, flib_runtime);
  ASSERT_EQ(2, uncompilable_nodes.size());

  // Entry for the then-branch function.
  NameAttrList function_one;
  function_one.set_name(kUncompilableFunctionName);
  auto it = uncompilable_nodes.find(function_one.ShortDebugString());
  ASSERT_NE(uncompilable_nodes.end(), it);

  const auto& uncompilable_node_list = it->second.second;
  ASSERT_EQ(1, uncompilable_node_list.size());
  const auto& uncompilable_node_one = uncompilable_node_list.at(0);
  const auto& node_one_stack = uncompilable_node_one.stack_trace;

  ASSERT_EQ(2, node_one_stack.size());
  const auto& node_one_stacktrace_first_node = node_one_stack.at(0);
  EXPECT_EQ(kFunctionalIfNodeName, node_one_stacktrace_first_node.name);
  EXPECT_EQ("", node_one_stacktrace_first_node.function_name);

  const auto& stacktrace_second_node_info = node_one_stack.at(1);
  EXPECT_EQ(kUncompilableFunctionNodeName, stacktrace_second_node_info.name);
  EXPECT_EQ(kUncompilableFunctionName,
            stacktrace_second_node_info.function_name);

  EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name);
  EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason,
                                "unsupported op"));

  // Entry for the else-branch function.
  NameAttrList function_two;
  function_two.set_name(kUncompilableFunctionTwoName);
  it = uncompilable_nodes.find(function_two.ShortDebugString());
  ASSERT_NE(uncompilable_nodes.end(), it);

  const auto& uncompilable_node_two_list = it->second.second;
  ASSERT_EQ(1, uncompilable_node_two_list.size());
  const auto& uncompilable_node_two = uncompilable_node_two_list.at(0);
  const auto& node_two_stack = uncompilable_node_two.stack_trace;
  ASSERT_EQ(2, node_two_stack.size());
  const auto& node_two_stacktrace_first_node = node_two_stack.at(0);
  EXPECT_EQ(kFunctionalIfNodeName, node_two_stacktrace_first_node.name);
  EXPECT_EQ("", node_two_stacktrace_first_node.function_name);

  const auto& node_two_stacktrace_second_node = node_two_stack.at(1);
  EXPECT_EQ(kUncompilableFunctionNodeTwoName,
            node_two_stacktrace_second_node.name);
  EXPECT_EQ(kUncompilableFunctionTwoName,
            node_two_stacktrace_second_node.function_name);

  EXPECT_EQ(kUncompilableFunctionNodeTwoName, uncompilable_node_two.name);
  // Check the second node's own reason (not the first node's again).
  EXPECT_TRUE(absl::StrContains(uncompilable_node_two.uncompilable_reason,
                                "unsupported op"));
}
400
// Builds a "Case" node whose branches are UncompilableFn and
// UncompilableFnTwo, then checks that the node is accepted when
// `require_always_compilable` is false and rejected when it is true.
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_c_uncompilable:float"},
      /*Attributes*/ {},
      // Node info
      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionTwoName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_d_uncompilable:float"},
      /*Attribute*/ {},
      // Node info
      {{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
  auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
  auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  std::vector<NodeBuilder::NodeOut> case_inputs(
      {NodeBuilder::NodeOut(placeholder.node())});
  Node* case_node = nullptr;
  TF_ASSERT_OK(
      NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
          .Input(branch_index.node())
          .Input(case_inputs)
          .Attr("branches", FuncListAttr({kUncompilableFunctionName,
                                          kUncompilableFunctionTwoName}))
          .Attr("Tout", {DT_INT32})
          .Finalize(root.graph(), &case_node));
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));

  auto case_node_it = std::find_if(
      graph->nodes().begin(), graph->nodes().end(),
      [&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
  // Must be fatal: the iterator is dereferenced below, which would be invalid
  // if the Case node had not been found.
  ASSERT_NE(case_node_it, graph->nodes().end());
  auto* flib_runtime = GetFunctionLibraryRuntime();

  // The checker copies the filter at construction, so it is rebuilt after
  // each change to `op_filter_`.
  op_filter_.require_always_compilable = false;
  checker_ = CreateCompilabilityChecker();
  EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
  op_filter_.require_always_compilable = true;
  checker_ = CreateCompilabilityChecker();
  EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
}
451
// A graph that only calls a plain Identity function through PartitionedCall
// is expected not to trigger XLA compilation.
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;

  FunctionDef identity_func = FunctionDefHelper::Create(
      "IdentityFunc",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"res:float"},
      /*attr_def=*/{},
      /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"res", "t0:output"}});

  *library.add_function() = identity_func;

  Output in = ops::Placeholder(root, DT_FLOAT);
  NameAttrList callee;
  callee.set_name("IdentityFunc");
  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, callee);

  GraphDef graph_def;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));

  EXPECT_FALSE(CanTriggerXlaCompilation(graph_def));
}
479
// A graph that calls a function containing an "XlaSort" node through
// PartitionedCall is expected to trigger XLA compilation.
TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;

  FunctionDef sort_func = FunctionDefHelper::Create(
      "SortFunc",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"res:float"},
      /*attr_def=*/{},
      /*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"res", "t0:output"}});

  *library.add_function() = sort_func;

  Output in = ops::Placeholder(root, DT_FLOAT);
  NameAttrList callee;
  callee.set_name("SortFunc");
  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, callee);

  GraphDef graph_def;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));

  EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
507
// A graph whose function library carries kXlaMustCompileAttr (both as a
// function attribute on IdentityFunc and as a node attribute on the
// PartitionedCall inside CallIdentity) is expected to trigger XLA compilation.
TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;

  AttrValue true_attribute;
  true_attribute.set_b(true);

  FunctionDef identity_func = FunctionDefHelper::Create(
      "IdentityFunc",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"res:float"},
      /*attr_def=*/{},
      /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"res", "t0:output"}});

  (*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;

  // NOTE(review): the call below references "IdentityRef", while the function
  // defined above is named "IdentityFunc" -- confirm whether that is intended.
  FunctionDef call_identity = FunctionDefHelper::Create(
      "CallIdentity",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"z:float"}, /*attr_def=*/{},
      /*node_def=*/
      {{{"func_call"},
        "PartitionedCall",
        {"x"},
        {{"Tin", DataTypeSlice({DT_FLOAT})},
         {"Tout", DataTypeSlice({DT_FLOAT})},
         {"f",
          FunctionDefHelper::FunctionRef("IdentityRef", {{"T", DT_FLOAT}})},
         {kXlaMustCompileAttr, true}}}},
      /*ret_def=*/{{"z", "func_call:output:0"}});

  *library.add_function() = identity_func;
  *library.add_function() = call_identity;

  Output in = ops::Placeholder(root, DT_FLOAT);
  NameAttrList callee;
  callee.set_name("CallIdentity");
  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT}, callee);

  GraphDef graph_def;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));

  EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
556
557 } // namespace
558 } // namespace tensorflow
559