xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/compilability_check_util_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/compilability_check_util.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/cc/framework/scope.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/functional_ops.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
25 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph_to_functiondef.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/graph/graph_def_builder.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/platform/test.h"
33 
34 namespace tensorflow {
35 namespace {
36 
FuncListAttr(const absl::Span<const char * const> names)37 AttrValue FuncListAttr(const absl::Span<const char* const> names) {
38   AttrValue attr;
39   for (const char* name : names) {
40     attr.mutable_list()->add_func()->set_name(name);
41   }
42   return attr;
43 }
44 
// Node names given to the functional control-flow nodes the tests build.
constexpr char kFunctionalIfNodeName[] = "If";
constexpr char kFunctionalCaseNodeName[] = "Case";
constexpr char kFunctionalWhileNodeName[] = "While";

// Name of a function the tests define with a single "Add" node, plus the
// name of that inner node. The While test uses it as the (compilable) cond.
constexpr char kCompilableFunctionName[] = "CompilableFn";
constexpr char kCompilableFunctionNodeName[] = "n_c";

// Names of two functions the tests define with a single "MissingKernel" node
// each, plus the names of those inner nodes (expected to be reported as the
// uncompilable nodes).
constexpr char kUncompilableFunctionName[] = "UncompilableFn";
constexpr char kUncompilableFunctionNodeName[] = "n_c_uncompilable";
constexpr char kUncompilableFunctionTwoName[] = "UncompilableFnTwo";
constexpr char kUncompilableFunctionNodeTwoName[] = "n_d_uncompilable";
54 
55 // A dummy OpKernel for testing.
56 class DummyCompilableOp : public XlaOpKernel {
57  public:
DummyCompilableOp(OpKernelConstruction * ctx)58   explicit DummyCompilableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)59   void Compile(XlaOpKernelContext* ctx) override {
60     ctx->SetOutput(0, ctx->Input(0));
61   }
62 };
63 
64 // Register the DummyCompilableOp kernel for CPU.
65 REGISTER_OP("InputFloatOp").Output("o: float");
66 REGISTER_OP("CompilableOp").Input("i: float").Output("o: float");
67 REGISTER_XLA_OP(Name("CompilableOp").Device(DEVICE_CPU_XLA_JIT),
68                 DummyCompilableOp);
69 
70 // Dummy op that is uncompilable in CPU.
71 REGISTER_OP("MissingKernel").Input("i: float").Output("o: float");
72 
73 class CompilabilityCheckUtilTest : public ::testing::Test {
74  protected:
SetUp()75   void SetUp() override {
76     XlaOpRegistry::RegisterCompilationKernels();
77 
78     op_filter_.allow_resource_ops_in_called_functions = false;
79     op_filter_.allow_stack_ops = false;
80     op_filter_.allow_tensor_array_ops = false;
81     op_filter_.allow_stateful_rng_ops = false;
82     op_filter_.allow_control_trigger = false;
83     op_filter_.allow_eliding_assert_and_checknumerics_ops = false;
84     op_filter_.allow_ops_producing_or_consuming_variant = false;
85     op_filter_.allow_inaccurate_ops = false;
86     op_filter_.allow_slow_ops = false;
87     op_filter_.allow_outside_compiled = false;
88 
89     checker_ = CreateCompilabilityChecker();
90   }
91 
CreateCompilabilityChecker()92   std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
93     return std::make_unique<RecursiveCompilabilityChecker>(op_filter_,
94                                                             device_type_);
95   }
96 
GetFunctionLibraryRuntime()97   FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
98     OptimizerOptions opts;
99     pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
100         nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
101         flib_def_.get(), opts);
102 
103     return pflr_->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
104   }
105 
106   RecursiveCompilabilityChecker::OperationFilter op_filter_;
107   DeviceType device_type_ = DeviceType(DEVICE_CPU_XLA_JIT);
108   std::unique_ptr<FunctionDefLibrary> func_library_ =
109       std::make_unique<FunctionDefLibrary>();
110   std::unique_ptr<FunctionLibraryDefinition> flib_def_ =
111       std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
112                                                    *func_library_);
113   std::unique_ptr<RecursiveCompilabilityChecker> checker_;
114   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
115 };
116 
// Checks plain (non-function) nodes in a chain
// InputFloatOp -> CompilableOp -> MissingKernel. Only CompilableOp has an XLA
// kernel for DEVICE_CPU_XLA_JIT, and FindUncompilableNodes on MissingKernel
// reports exactly that node at the top level.
TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) {
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
  auto opts = builder.opts();
  Node* input = ops::SourceOp("InputFloatOp", opts);
  Node* compilable = ops::UnaryOp("CompilableOp", input, opts);
  Node* missing_kernel = ops::UnaryOp("MissingKernel", compilable, opts);
  GraphDef graph_def;
  TF_EXPECT_OK(builder.ToGraphDef(&graph_def));

  FunctionLibraryRuntime* flr = GetFunctionLibraryRuntime();

  // No XLA kernel is registered for the source op.
  EXPECT_FALSE(checker_->IsCompilableNode(*input, flr));
  EXPECT_TRUE(checker_->IsCompilableNode(*compilable, flr));
  // MissingKernel has no kernel for the CPU XLA JIT device being checked.
  EXPECT_FALSE(checker_->IsCompilableNode(*missing_kernel, flr));

  const auto uncompilable_nodes =
      checker_->FindUncompilableNodes(*missing_kernel, flr);
  ASSERT_EQ(1, uncompilable_nodes.size());
  // Nodes outside any function are keyed by an empty NameAttrList.
  const auto entry = uncompilable_nodes.find(NameAttrList().ShortDebugString());
  ASSERT_NE(uncompilable_nodes.end(), entry);
  const auto& infos = entry->second.second;
  ASSERT_EQ(1, infos.size());
  const auto& info = infos.at(0);
  EXPECT_TRUE(absl::StrContains(info.uncompilable_reason, "unsupported op"));
  ASSERT_EQ(1, info.stack_trace.size());
  ASSERT_EQ("", info.stack_trace.front().function_name);
}
149 
// A kernel-less node tagged with `_xla_outside_compilation` is reported as
// uncompilable by the strict default filter, and is accepted once
// `allow_outside_compiled` is set and the checker is rebuilt.
TEST_F(CompilabilityCheckUtilTest, CheckOutsideCompiledNode) {
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
  auto opts = builder.opts();
  Node* input = ops::SourceOp("InputFloatOp", opts);
  Node* outside_compiled = ops::UnaryOp("MissingKernel", input, opts);
  outside_compiled->AddAttr("_xla_outside_compilation", "0");
  GraphDef graph_def;
  TF_EXPECT_OK(builder.ToGraphDef(&graph_def));

  FunctionLibraryRuntime* flr = GetFunctionLibraryRuntime();

  // By default, outside-compiled ops are still checked, so the op is reported.
  EXPECT_FALSE(checker_->IsCompilableNode(*outside_compiled, flr));
  const auto strict_result =
      checker_->FindUncompilableNodes(*outside_compiled, flr);
  ASSERT_EQ(1, strict_result.size());

  // With the filter option enabled, outside-compiled ops are ignored and
  // treated as compilable.
  op_filter_.allow_outside_compiled = true;
  checker_ = CreateCompilabilityChecker();
  EXPECT_TRUE(checker_->IsCompilableNode(*outside_compiled, flr));
  const auto relaxed_result =
      checker_->FindUncompilableNodes(*outside_compiled, flr);
  ASSERT_EQ(0, relaxed_result.size());
}
178 
// A direct call (node "D") to a function whose body is a single MissingKernel
// node is uncompilable. The reported node carries a two-frame stack trace:
// the call node, then the offending node inside the function.
TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_c_uncompilable:float"},
      /*Attributes*/ {},
      /*Node info*/
      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
  flib_def_ =
      std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), flib);

  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, flib_def_.get());
  auto graph = std::make_unique<Graph>(flib_def_.get());
  Node* input = ops::SourceOp("InputFloatOp", builder.opts());
  Node* call_node = ops::UnaryOp(kUncompilableFunctionName, input,
                                 builder.opts().WithName("D"));
  TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));

  FunctionLibraryRuntime* flr = GetFunctionLibraryRuntime();
  EXPECT_FALSE(checker_->IsCompilableNode(*call_node, flr));
  const auto uncompilable_nodes =
      checker_->FindUncompilableNodes(*call_node, flr);
  EXPECT_EQ(1, uncompilable_nodes.size());

  // Results are keyed by the ShortDebugString of the enclosing function.
  NameAttrList function;
  function.set_name(kUncompilableFunctionName);
  const auto found = uncompilable_nodes.find(function.ShortDebugString());
  ASSERT_NE(uncompilable_nodes.end(), found);
  const auto& node_list = found->second.second;
  ASSERT_EQ(1, node_list.size());
  const auto& info = node_list.at(0);
  const auto& frames = info.stack_trace;
  ASSERT_EQ(2, frames.size());
  EXPECT_EQ("D", frames.at(0).name);
  EXPECT_EQ(kUncompilableFunctionNodeName, frames.at(1).name);
  EXPECT_EQ(kUncompilableFunctionNodeName, info.name);
  EXPECT_TRUE(absl::StrContains(info.uncompilable_reason, "unsupported op"));
}
219 
TEST_F(CompilabilityCheckUtilTest,CheckFunctionalWhileNode)220 TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
221   FunctionDefLibrary flib;
222   *flib.add_function() = FunctionDefHelper::Define(
223       /*Function*/ kCompilableFunctionName,
224       /*Inputs*/ {"n_a:float", "n_b:float"},
225       /*Outputs*/ {"n_c:float"},
226       /*Attribute*/ {},
227       // Node info
228       {{{kCompilableFunctionNodeName},
229         "Add",
230         {"n_a", "n_b"},
231         {{"T", DT_FLOAT}}}});
232   *flib.add_function() = FunctionDefHelper::Define(
233       /*Function*/ kUncompilableFunctionName,
234       /*Inputs*/ {"n_a:float"},
235       /*Outputs*/ {"n_c_uncompilable:float"},
236       /*Attributes*/ {},
237       // Node info
238       {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
239 
240   flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
241   GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, flib_def_.get());
242 
243   Node* const0 = ops::SourceOp("InputFloatOp", builder.opts());
244   Node* input_node = ops::UnaryOp("CompilableOp", const0, builder.opts());
245 
246   NameAttrList compilable;
247   compilable.set_name(kCompilableFunctionName);
248   NameAttrList uncompilable;
249   uncompilable.set_name(kUncompilableFunctionName);
250 
251   NodeBuilder while_builder(kFunctionalWhileNodeName, "While",
252                             builder.opts().op_registry());
253   while_builder.Input({input_node, input_node})
254       .Attr("cond", compilable)
255       .Attr("body", uncompilable);
256   builder.opts().FinalizeBuilder(&while_builder);
257 
258   GraphDef graph_def;
259   TF_EXPECT_OK(builder.ToGraphDef(&graph_def));
260   std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
261   TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
262 
263   auto while_node_it = std::find_if(
264       graph->nodes().begin(), graph->nodes().end(),
265       [&](const Node* n) { return n->name() == kFunctionalWhileNodeName; });
266   EXPECT_NE(while_node_it, graph->nodes().end());
267 
268   auto* flib_runtime = GetFunctionLibraryRuntime();
269 
270   EXPECT_FALSE(checker_->IsCompilableNode(**while_node_it, flib_runtime));
271   const auto uncompilable_nodes =
272       checker_->FindUncompilableNodes(**while_node_it, flib_runtime);
273   ASSERT_EQ(1, uncompilable_nodes.size());
274 
275   NameAttrList function;
276   function.set_name(kUncompilableFunctionName);
277   const auto node_info_it =
278       uncompilable_nodes.find(function.ShortDebugString());
279   ASSERT_NE(uncompilable_nodes.end(), node_info_it);
280   const auto& uncompilable_node_list = node_info_it->second.second;
281   ASSERT_EQ(1, uncompilable_node_list.size());
282   const auto& node_info = uncompilable_node_list.at(0);
283 
284   const auto& node_stack = node_info.stack_trace;
285   ASSERT_EQ(2, node_stack.size());
286   const auto& stacktrace_first_node_info = node_stack.at(0);
287   EXPECT_EQ(kFunctionalWhileNodeName, stacktrace_first_node_info.name);
288   EXPECT_EQ("", stacktrace_first_node_info.function_name);
289 
290   const auto& stacktrace_second_node_info = node_stack.at(1);
291   EXPECT_EQ(kUncompilableFunctionNodeName, stacktrace_second_node_info.name);
292   EXPECT_EQ(kUncompilableFunctionName,
293             stacktrace_second_node_info.function_name);
294 
295   EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
296   EXPECT_TRUE(
297       absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
298 }
299 
TEST_F(CompilabilityCheckUtilTest,CheckFunctionalIfNode)300 TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
301   FunctionDefLibrary flib;
302   *flib.add_function() = FunctionDefHelper::Define(
303       /*Function*/ kUncompilableFunctionName,
304       /*Inputs*/ {"n_a:float"},
305       /*Outputs*/ {"n_c_uncompilable:float"},
306       /*Attributes*/ {},
307       // Node info
308       {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
309   *flib.add_function() = FunctionDefHelper::Define(
310       /*Function*/ kUncompilableFunctionTwoName,
311       /*Inputs*/ {"n_a:float"},
312       /*Outputs*/ {"n_d_uncompilable:float"},
313       /*Attribute*/ {},
314       // Node info
315       {{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
316   NameAttrList uncompilable_fn1_attr;
317   uncompilable_fn1_attr.set_name(kUncompilableFunctionName);
318   NameAttrList uncompilable_fn2_attr;
319   uncompilable_fn2_attr.set_name(kUncompilableFunctionTwoName);
320 
321   Scope root = Scope::NewRootScope().ExitOnError();
322   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
323   auto predicate = ops::Placeholder(root.WithOpName("pred"), DT_BOOL);
324   auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
325   std::vector<NodeBuilder::NodeOut> if_inputs(
326       {NodeBuilder::NodeOut(placeholder.node())});
327   Node* if_node;
328   TF_ASSERT_OK(
329       NodeBuilder(kFunctionalIfNodeName, "If", &root.graph()->flib_def())
330           .Input(predicate.node())
331           .Input(if_inputs)
332           .Attr("then_branch", uncompilable_fn1_attr)
333           .Attr("else_branch", uncompilable_fn2_attr)
334           .Attr("Tout", {DT_INT32})
335           .Finalize(root.graph(), &if_node));
336   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
337   TF_ASSERT_OK(root.ToGraph(graph.get()));
338 
339   flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
340 
341   auto if_node_it = std::find_if(
342       graph->nodes().begin(), graph->nodes().end(),
343       [&](const Node* n) { return n->name() == kFunctionalIfNodeName; });
344   EXPECT_NE(if_node_it, graph->nodes().end());
345   auto* flib_runtime = GetFunctionLibraryRuntime();
346 
347   EXPECT_FALSE(checker_->IsCompilableNode(**if_node_it, flib_runtime));
348   const auto uncompilable_nodes =
349       checker_->FindUncompilableNodes(**if_node_it, flib_runtime);
350   ASSERT_EQ(2, uncompilable_nodes.size());
351 
352   NameAttrList function_one;
353   function_one.set_name(kUncompilableFunctionName);
354   auto it = uncompilable_nodes.find(function_one.ShortDebugString());
355   ASSERT_NE(uncompilable_nodes.end(), it);
356 
357   const auto& uncompilable_node_list = it->second.second;
358   ASSERT_EQ(1, uncompilable_node_list.size());
359   const auto& uncompilable_node_one = uncompilable_node_list.at(0);
360   const auto& node_one_stack = uncompilable_node_one.stack_trace;
361 
362   ASSERT_EQ(2, node_one_stack.size());
363   const auto& node_one_stacktrace_first_node = node_one_stack.at(0);
364   EXPECT_EQ(kFunctionalIfNodeName, node_one_stacktrace_first_node.name);
365   EXPECT_EQ("", node_one_stacktrace_first_node.function_name);
366 
367   const auto& stacktrace_second_node_info = node_one_stack.at(1);
368   EXPECT_EQ(kUncompilableFunctionNodeName, stacktrace_second_node_info.name);
369   EXPECT_EQ(kUncompilableFunctionName,
370             stacktrace_second_node_info.function_name);
371 
372   EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name);
373   EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason,
374                                 "unsupported op"));
375 
376   NameAttrList function_two;
377   function_two.set_name(kUncompilableFunctionTwoName);
378   it = uncompilable_nodes.find(function_two.ShortDebugString());
379   ASSERT_NE(uncompilable_nodes.end(), it);
380 
381   const auto& uncompilable_node_two_list = it->second.second;
382   ASSERT_EQ(1, uncompilable_node_two_list.size());
383   const auto& uncompilable_node_two = uncompilable_node_two_list.at(0);
384   const auto& node_two_stack = uncompilable_node_two.stack_trace;
385   ASSERT_EQ(2, node_two_stack.size());
386   const auto& node_two_stacktrace_first_node = node_two_stack.at(0);
387   EXPECT_EQ(kFunctionalIfNodeName, node_two_stacktrace_first_node.name);
388   EXPECT_EQ("", node_two_stacktrace_first_node.function_name);
389 
390   const auto& node_two_stacktrace_second_node = node_two_stack.at(1);
391   EXPECT_EQ(kUncompilableFunctionNodeTwoName,
392             node_two_stacktrace_second_node.name);
393   EXPECT_EQ(kUncompilableFunctionTwoName,
394             node_two_stacktrace_second_node.function_name);
395 
396   EXPECT_EQ(kUncompilableFunctionNodeTwoName, uncompilable_node_two.name);
397   EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason,
398                                 "unsupported op"));
399 }
400 
// A Case node whose branches are all uncompilable functions is accepted when
// `require_always_compilable` is false and rejected when it is true.
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_c_uncompilable:float"},
      /*Attributes*/ {},
      // Node info
      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionTwoName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_d_uncompilable:float"},
      /*Attribute*/ {},
      // Node info
      {{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
  auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
  auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  std::vector<NodeBuilder::NodeOut> inputs(
      {NodeBuilder::NodeOut(placeholder.node())});
  Node* case_node;
  TF_ASSERT_OK(
      NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
          .Input(branch_index.node())
          .Input(inputs)
          .Attr("branches", FuncListAttr({kUncompilableFunctionName,
                                          kUncompilableFunctionTwoName}))
          .Attr("Tout", {DT_INT32})
          .Finalize(root.graph(), &case_node));
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));

  auto case_node_it = std::find_if(
      graph->nodes().begin(), graph->nodes().end(),
      [&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
  // Must be fatal: the iterator is dereferenced below.
  ASSERT_NE(case_node_it, graph->nodes().end());
  auto* flib_runtime = GetFunctionLibraryRuntime();

  op_filter_.require_always_compilable = false;
  checker_ = CreateCompilabilityChecker();
  EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
  op_filter_.require_always_compilable = true;
  checker_ = CreateCompilabilityChecker();
  EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
}
451 
// A graph whose only function call targets a plain Identity function, with no
// kXlaMustCompileAttr anywhere, is expected not to trigger XLA compilation.
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;

  FunctionDef identity_func = FunctionDefHelper::Create(
      "IdentityFunc",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"res:float"},
      /*attr_def=*/{},
      /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"res", "t0:output"}});

  *library.add_function() = identity_func;

  Output in = ops::Placeholder(root, DT_FLOAT);
  NameAttrList b_name_attr;
  b_name_attr.set_name("IdentityFunc");
  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
                            b_name_attr);

  GraphDef graph_def;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));

  EXPECT_FALSE(CanTriggerXlaCompilation(graph_def));
}
479 
// Same shape as TestCanNotTriggerXlaCompilation, but the called function uses
// the XLA-specific XlaSort op instead of Identity. With no kXlaMustCompileAttr
// present, the graph is still expected to trigger XLA compilation.
TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;

  FunctionDef sort_func = FunctionDefHelper::Create(
      "SortFunc",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"res:float"},
      /*attr_def=*/{},
      /*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"res", "t0:output"}});

  *library.add_function() = sort_func;

  Output in = ops::Placeholder(root, DT_FLOAT);
  NameAttrList b_name_attr;
  b_name_attr.set_name("SortFunc");
  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
                            b_name_attr);

  GraphDef graph_def;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));

  EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
507 
// A graph calling "CallIdentity", which in turn PartitionedCalls
// "IdentityFunc" with kXlaMustCompileAttr=true, is expected to trigger XLA
// compilation. The attr is set on both IdentityFunc itself and the inner call
// node.
TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  FunctionDefLibrary library;

  AttrValue true_attribute;
  true_attribute.set_b(true);

  FunctionDef identity_func = FunctionDefHelper::Create(
      "IdentityFunc",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"res:float"},
      /*attr_def=*/{},
      /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"res", "t0:output"}});

  (*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;

  // The inner call targets the IdentityFunc defined above. IdentityFunc
  // declares no attrs, so the function reference carries none.
  FunctionDef call_identity = FunctionDefHelper::Create(
      "CallIdentity",
      /*in_def=*/{"x:float"},
      /*out_def=*/{"z:float"}, /*attr_def=*/{},
      /*node_def=*/
      {{{"func_call"},
        "PartitionedCall",
        {"x"},
        {{"Tin", DataTypeSlice({DT_FLOAT})},
         {"Tout", DataTypeSlice({DT_FLOAT})},
         {"f", FunctionDefHelper::FunctionRef("IdentityFunc", {})},
         {kXlaMustCompileAttr, true}}}},
      /*ret_def=*/{{"z", "func_call:output:0"}});

  *library.add_function() = identity_func;
  *library.add_function() = call_identity;

  Output in = ops::Placeholder(root, DT_FLOAT);
  NameAttrList b_name_attr;
  b_name_attr.set_name("CallIdentity");
  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
                            b_name_attr);

  GraphDef graph_def;
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));

  EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
556 
557 }  // namespace
558 }  // namespace tensorflow
559