xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/placer_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/placer.h"
17 
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/device_factory.h"
26 #include "tensorflow/core/common_runtime/device_set.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
29 #include "tensorflow/core/common_runtime/optimization_registry.h"
30 #include "tensorflow/core/framework/device_attributes.pb.h"
31 #include "tensorflow/core/framework/function.h"
32 #include "tensorflow/core/framework/function_testlib.h"
33 #include "tensorflow/core/framework/kernel_def_builder.h"
34 #include "tensorflow/core/framework/op.h"
35 #include "tensorflow/core/framework/op_def_builder.h"
36 #include "tensorflow/core/framework/op_kernel.h"
37 #include "tensorflow/core/framework/types.pb.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/graph_def_builder.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/core/status_test_util.h"
42 #include "tensorflow/core/lib/strings/str_util.h"
43 #include "tensorflow/core/lib/strings/strcat.h"
44 #include "tensorflow/core/platform/test.h"
45 #include "tensorflow/core/protobuf/config.pb.h"
46 #include "tensorflow/core/protobuf/error_codes.pb.h"
47 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
48 
49 namespace tensorflow {
50 
51 using ::tensorflow::test::function::GDef;
52 using ::tensorflow::test::function::NDef;
53 using FDH = ::tensorflow::FunctionDefHelper;
54 
55 constexpr char kCPU[] = "/device:FakeCPU:0";
56 constexpr char kGPU[] = "/device:FakeGPU:0";
57 
58 constexpr char kFullCPU[] = "/job:a/replica:0/task:0/device:FakeCPU:0";
59 constexpr char kFullGPU[] = "/job:a/replica:0/task:0/device:FakeGPU:0";
60 
61 namespace {
62 
63 ////////////////////////////////////////////////////////////////////////////////
64 //
65 // Op, kernel, and device registrations to set up the environment.
66 //
67 // The Placer uses information about the op (input types),
68 // kernel (device constraints), and available devices to make
69 // placement decisions. To avoid depending on the full runtime, we
70 // define dummy implementations of these, and register them with the
71 // runtime.
72 //
73 ////////////////////////////////////////////////////////////////////////////////
74 
75 // A dummy OpKernel that is used to register ops on different devices.
76 class DummyOp : public OpKernel {
77  public:
DummyOp(OpKernelConstruction * context)78   explicit DummyOp(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)79   void Compute(OpKernelContext* context) override {}
80 };
81 
82 // A fake device that has specific device attributes, used to simulate
83 // the presence of a CPU or a GPU (without depending on that part of
84 // the runtime.
85 class FakeDevice : public Device {
86  private:
FakeDevice(const DeviceAttributes & device_attributes)87   explicit FakeDevice(const DeviceAttributes& device_attributes)
88       : Device(nullptr, device_attributes) {}
89 
90  public:
Sync()91   Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
92 
GetAllocator(AllocatorAttributes attr)93   Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
94 
MakeDevice(const string & name,const string & device_type)95   static std::unique_ptr<Device> MakeDevice(const string& name,
96                                             const string& device_type) {
97     DeviceAttributes device_attributes;
98     device_attributes.set_name(name);
99     device_attributes.set_device_type(device_type);
100     return std::unique_ptr<Device>(new FakeDevice(device_attributes));
101   }
102 
MakeCPU(const string & name)103   static std::unique_ptr<Device> MakeCPU(const string& name) {
104     return MakeDevice(name, "FakeCPU");
105   }
106 
MakeGPU(const string & name)107   static std::unique_ptr<Device> MakeGPU(const string& name) {
108     return MakeDevice(name, "FakeGPU");
109   }
110 };
111 
112 class DummyFactory : public DeviceFactory {
113  public:
ListPhysicalDevices(std::vector<string> * devices)114   Status ListPhysicalDevices(std::vector<string>* devices) override {
115     return OkStatus();
116   }
CreateDevices(const SessionOptions & options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)117   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
118                        std::vector<std::unique_ptr<Device>>* devices) override {
119     return OkStatus();
120   }
121 };
122 
123 // Device order now depends on the registration of devices, not a fixed
124 // value in device_set.cc.  To avoid the need to link in the real CPU and GPU
125 // devices into this test, we create fake devices and registrations that
126 // can stand-in for the real devices for the purposes of testing placement
127 // and ordering.
128 REGISTER_LOCAL_DEVICE_FACTORY("FakeCPU", DummyFactory);
129 REGISTER_LOCAL_DEVICE_FACTORY("FakeGPU", DummyFactory, 51);
130 
131 // Register the following ops so they can be added to a Graph, and
132 // kernels so that they can be placed on particular device types.
133 REGISTER_OP("TestVariable").Output("o: Ref(float)");
134 REGISTER_KERNEL_BUILDER(Name("TestVariable").Device("FakeCPU"), DummyOp);
135 REGISTER_KERNEL_BUILDER(Name("TestVariable").Device("FakeGPU"), DummyOp);
136 
137 REGISTER_OP("VariableCPU").Output("o: Ref(float)");
138 REGISTER_KERNEL_BUILDER(Name("VariableCPU").Device("FakeCPU"), DummyOp);
139 
140 REGISTER_OP("VariableGPU").Output("o: Ref(float)");
141 REGISTER_KERNEL_BUILDER(Name("VariableGPU").Device("FakeGPU"), DummyOp);
142 
143 REGISTER_OP("VariableNoKernels").Output("o: Ref(float)");
144 
145 REGISTER_OP("TestAdd").Input("a: float").Input("b: float").Output("o: float");
146 REGISTER_KERNEL_BUILDER(Name("TestAdd").Device("FakeCPU"), DummyOp);
147 REGISTER_KERNEL_BUILDER(Name("TestAdd").Device("FakeGPU"), DummyOp);
148 
149 REGISTER_OP("TestRelu").Input("i: float").Output("o: float");
150 REGISTER_KERNEL_BUILDER(Name("TestRelu").Device("FakeCPU"), DummyOp);
151 REGISTER_KERNEL_BUILDER(Name("TestRelu").Device("FakeGPU"), DummyOp);
152 
153 REGISTER_OP("ReluCPU").Input("i: float").Output("o: float");
154 REGISTER_KERNEL_BUILDER(Name("ReluCPU").Device("FakeCPU"), DummyOp);
155 
156 REGISTER_OP("ReluGPU").Input("i: float").Output("o: float");
157 REGISTER_KERNEL_BUILDER(Name("ReluGPU").Device("FakeGPU"), DummyOp);
158 
159 REGISTER_OP("TestAssign").Input("i: Ref(float)").Input("v: float");
160 REGISTER_KERNEL_BUILDER(Name("TestAssign").Device("FakeCPU"), DummyOp);
161 REGISTER_KERNEL_BUILDER(Name("TestAssign").Device("FakeGPU"), DummyOp);
162 
163 REGISTER_OP("AssignCPU").Input("i: Ref(float)").Input("v: float");
164 REGISTER_KERNEL_BUILDER(Name("AssignCPU").Device("FakeCPU"), DummyOp);
165 
166 REGISTER_OP("AssignGPU").Input("i: Ref(float)").Input("v: float");
167 REGISTER_KERNEL_BUILDER(Name("AssignGPU").Device("FakeGPU"), DummyOp);
168 
169 REGISTER_OP("TestInput").Output("a: float").Output("b: float");
170 REGISTER_KERNEL_BUILDER(Name("TestInput").Device("FakeCPU"), DummyOp);
171 
172 // Op producing an output that can be placed on CPU or GPU.
173 REGISTER_OP("TestCPUGPUOutput").Output("a: float");
174 REGISTER_KERNEL_BUILDER(Name("TestCPUGPUOutput").Device("FakeCPU"), DummyOp);
175 REGISTER_KERNEL_BUILDER(Name("TestCPUGPUOutput").Device("FakeGPU"), DummyOp);
176 
177 REGISTER_OP("TestGPUOutput").Output("a: float");
178 REGISTER_KERNEL_BUILDER(Name("TestGPUOutput").Device("FakeGPU"), DummyOp);
179 
180 REGISTER_OP("TestDevice").Output("a: float").Output("b: float");
181 REGISTER_KERNEL_BUILDER(Name("TestDevice").Device("FakeGPU"), DummyOp);
182 
183 REGISTER_OP("TestDeviceEnforce").Input("a: Ref(float)").Output("b: float");
184 REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeCPU"), DummyOp);
185 REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeGPU"), DummyOp);
186 
187 REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeCPU"), DummyOp);
188 REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp);
189 
190 // Op that has kernels with device priorities specified.
191 REGISTER_OP("TestDatasetOp").Input("a: float").Output("b: float");
192 REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeCPU").Priority(2),
193                         DummyOp);
194 REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeGPU").Priority(1),
195                         DummyOp);
196 
197 // Op that has kernels with XLA device priority higher than FakeCPU.
198 REGISTER_OP("TestXlaOp").Input("a: float").Output("b: float");
199 REGISTER_KERNEL_BUILDER(Name("TestXlaOp").Device("XLA_CPU").Priority(2),
200                         DummyOp);
201 REGISTER_KERNEL_BUILDER(Name("TestXlaOp").Device("FakeCPU").Priority(1),
202                         DummyOp);
203 
204 // Op with no-copy type definition.
205 REGISTER_OP("TestUncopiableTypeGeneratorCPU")
206     .Output("d: variant")
207     .SetTypeConstructor(full_type::UnaryGeneric(TFT_DATASET));
208 REGISTER_KERNEL_BUILDER(
209     Name("TestUncopiableTypeGeneratorCPU").Device("FakeCPU"), DummyOp);
210 
211 // Op consuming a typed input.
212 REGISTER_OP("TestTypedConsumer").Input("i: variant");
213 REGISTER_KERNEL_BUILDER(Name("TestTypedConsumer").Device("FakeCPU"), DummyOp);
214 REGISTER_KERNEL_BUILDER(Name("TestTypedConsumer").Device("FakeGPU"), DummyOp);
215 
216 ////////////////////////////////////////////////////////////////////////////////
217 //
218 // A PlacerTest method has three phases:
219 //
220 // 1. Build a TensorFlow graph, with no (or partial) device assignments.
221 // 2. Attempt to compute a placement using the Placer.
222 // 3. EITHER: test that the constraints implied by the graph are respected;
223 //    or that an appropriate error was reported.
224 //
225 ////////////////////////////////////////////////////////////////////////////////
226 class PlacerTest : public ::testing::Test {
227  protected:
PlacerTest()228   PlacerTest() : PlacerTest(10) {}
229 
PlacerTest(int num_devices)230   explicit PlacerTest(int num_devices) {
231     // Build a set of num_devices GPU, num_devices CPU devices, and one XLA_CPU
232     // device.
233     // NOTE: this->local_devices_ owns the device objects;
234     // this->devices_ contains borrowed pointers to the device
235     // objects.
236     for (int i = 0; i < num_devices; ++i) {
237       local_devices_.emplace_back(FakeDevice::MakeCPU(
238           strings::StrCat("/job:a/replica:0/task:0/device:FakeCPU:", i)));
239       devices_.AddDevice(local_devices_.back().get());
240       // Insert the GPUs in reverse order.
241       local_devices_.emplace_back(FakeDevice::MakeGPU(strings::StrCat(
242           "/job:a/replica:0/task:0/device:FakeGPU:", num_devices - 1 - i)));
243       devices_.AddDevice(local_devices_.back().get());
244     }
245     local_devices_.emplace_back(FakeDevice::MakeDevice(
246         "/job:a/replica:0/task:0/device:XLA_CPU:0", "XLA_CPU"));
247     devices_.AddDevice(local_devices_.back().get());
248     local_devices_.emplace_back(FakeDevice::MakeDevice(
249         "/job:a/replica:0/task:0/device:COMPOSITE:0", "COMPOSITE"));
250     devices_.AddDevice(local_devices_.back().get());
251   }
252 
253   // Builds the given graph, and (if successful) indexes the node
254   // names for use in placement, and later lookup.
BuildGraph(const GraphDefBuilder & builder,Graph * out_graph)255   Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) {
256     TF_RETURN_IF_ERROR(GraphDefBuilderToGraph(builder, out_graph));
257     RebuildNodeNameMap(*out_graph);
258     return OkStatus();
259   }
260 
BuildGraph(const GraphDef & graph_def,Graph * out_graph)261   Status BuildGraph(const GraphDef& graph_def, Graph* out_graph) {
262     GraphConstructorOptions opts;
263     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, out_graph));
264     RebuildNodeNameMap(*out_graph);
265     return OkStatus();
266   }
267 
268   // Invokes the Placer on "graph". If no DeviceSet is specified, the
269   // placement will use the default DeviceSet (of 10 CPU and 10 GPU devices).
270   //
271   // REQUIRES: "*graph" was produced by the most recent call to BuildGraph.
Place(Graph * graph,DeviceSet * devices,Device * default_local_device,bool allow_soft_placement,bool log_device_placement)272   Status Place(Graph* graph, DeviceSet* devices, Device* default_local_device,
273                bool allow_soft_placement, bool log_device_placement) {
274     Placer placer(graph, "", &graph->flib_def(), devices, default_local_device,
275                   allow_soft_placement, log_device_placement);
276     return placer.Run();
277   }
278 
CallOptPassesAndPlace(Graph * graph,DeviceSet * devices,bool allow_soft_placement,bool log_device_placement)279   Status CallOptPassesAndPlace(Graph* graph, DeviceSet* devices,
280                                bool allow_soft_placement,
281                                bool log_device_placement) {
282     // Disable all real optimizations (i.e. Grappler and GraphOptimizer)
283     // to make sure functions are not inlined and not constant folded
284     SessionOptions session_options;
285     GraphOptions* graph_opts = session_options.config.mutable_graph_options();
286     OptimizerOptions* optimizer_opts = graph_opts->mutable_optimizer_options();
287     optimizer_opts->set_opt_level(OptimizerOptions::L0);
288     optimizer_opts->set_global_jit_level(OptimizerOptions::OFF);
289     RewriterConfig* rewriter_config = graph_opts->mutable_rewrite_options();
290     rewriter_config->set_disable_meta_optimizer(true);
291 
292     // Placing nested functions requires go through some PRE_PLACEMENT passes.
293     // Currently, just the IsolateDeepOpsPass.
294     GraphOptimizationPassOptions optimization_options;
295     std::unique_ptr<Graph> graph_ptr(graph);
296     optimization_options.graph = &graph_ptr;
297     FunctionLibraryDefinition flib_def(graph->flib_def());
298     optimization_options.flib_def = &flib_def;
299     optimization_options.device_set = &devices_;
300     optimization_options.session_options = &session_options;
301     Status s = OptimizationPassRegistry::Global()->RunGrouping(
302         OptimizationPassRegistry::PRE_PLACEMENT, optimization_options);
303     if (!s.ok()) {
304       graph_ptr.release();
305       return s;
306     }
307     graph = graph_ptr.release();
308 
309     RebuildNodeNameMap(*graph);
310 
311     Placer placer(graph, "", &graph->flib_def(), devices, nullptr,
312                   allow_soft_placement, log_device_placement);
313     return placer.Run();
314   }
315 
Place(Graph * graph,DeviceSet * devices)316   Status Place(Graph* graph, DeviceSet* devices) {
317     return Place(graph, devices, nullptr, true, false);
318   }
319 
Place(Graph * graph,bool allow_soft_placement,bool log_device_placement)320   Status Place(Graph* graph, bool allow_soft_placement,
321                bool log_device_placement) {
322     return Place(graph, &devices_, nullptr, allow_soft_placement,
323                  log_device_placement);
324   }
325 
Place(Graph * graph)326   Status Place(Graph* graph) {
327     return Place(graph, &devices_, nullptr, true, false);
328   }
329 
CallOptPassesAndPlace(Graph * graph,bool allow_soft_placement,bool log_device_placement)330   Status CallOptPassesAndPlace(Graph* graph, bool allow_soft_placement,
331                                bool log_device_placement) {
332     return CallOptPassesAndPlace(graph, &devices_, allow_soft_placement,
333                                  log_device_placement);
334   }
335 
CallOptPassesAndPlace(Graph * graph)336   Status CallOptPassesAndPlace(Graph* graph) {
337     return CallOptPassesAndPlace(graph, &devices_, true, false);
338   }
339 
340   // Returns the node in "graph" with the given name.
341   //
342   // REQUIRES: "graph" was produced by the most recent call to BuildGraph.
GetNodeByName(const Graph & graph,const string & name)343   Node* GetNodeByName(const Graph& graph, const string& name) {
344     const auto search = nodes_by_name_.find(name);
345     CHECK(search != nodes_by_name_.end()) << "Unknown node name: " << name;
346     return graph.FindNodeId(search->second);
347   }
348 
349  protected:
350   std::vector<std::unique_ptr<Device>> local_devices_;
351   DeviceSet devices_;
352   std::unordered_map<string, int> nodes_by_name_;
353 
354   Status ReferenceTestHelper(const string& variable_op_type,
355                              const string& assign_op_type,
356                              const DeviceType& expected_device_type);
357 
358  private:
RebuildNodeNameMap(const Graph & graph)359   void RebuildNodeNameMap(const Graph& graph) {
360     nodes_by_name_.clear();
361     for (Node* node : graph.nodes()) {
362       nodes_by_name_[node->name()] = node->id();
363     }
364   }
365 };
366 
367 // Fixture that add a parameter for allow_soft_placement.
368 // Test cases that want to test behavior with and without soft placement
369 // can use this fixture instead of PlacerTest.
370 class SoftPlacementPlacerTest : public PlacerTest,
371                                 public ::testing::WithParamInterface<bool> {};
372 
373 INSTANTIATE_TEST_SUITE_P(All, SoftPlacementPlacerTest,
374                          ::testing::Values(false, true),
375                          ::testing::PrintToStringParamName());
376 
377 #define EXPECT_COLOCATED(g, name_a, name_b)                         \
378   do {                                                              \
379     Graph& g_ = (g);                                                \
380     EXPECT_EQ(GetNodeByName(g_, (name_a))->assigned_device_name(),  \
381               GetNodeByName(g_, (name_b))->assigned_device_name()); \
382   } while (0)
383 
384 #define EXPECT_NOT_COLOCATED(g, name_a, name_b)                     \
385   do {                                                              \
386     Graph& g_ = (g);                                                \
387     EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(),  \
388               GetNodeByName(g_, (name_b))->assigned_device_name()); \
389   } while (0)
390 
391 #define EXPECT_DEVICE_TYPE(g, name, expected_device_type)               \
392   EXPECT_EQ(DeviceType(expected_device_type).type(),                    \
393             devices_                                                    \
394                 .FindDeviceByName(                                      \
395                     GetNodeByName((g), (name))->assigned_device_name()) \
396                 ->attributes()                                          \
397                 .device_type())
398 
399 #define EXPECT_SAME_TYPE(g, node1, node2)                                \
400   EXPECT_EQ(devices_                                                     \
401                 .FindDeviceByName(                                       \
402                     GetNodeByName((g), (node1))->assigned_device_name()) \
403                 ->attributes()                                           \
404                 .device_type(),                                          \
405             devices_                                                     \
406                 .FindDeviceByName(                                       \
407                     GetNodeByName((g), (node2))->assigned_device_name()) \
408                 ->attributes()                                           \
409                 .device_type())
410 
411 #define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \
412   EXPECT_TRUE(absl::StrContains(                       \
413       GetNodeByName((g), (name))->assigned_device_name(), device_substr))
414 
415 // Test that a graph with no constraints will successfully assign nodes to the
416 // "best available" device (i.e. prefer GPU over CPU).
TEST_F(PlacerTest,TestNoConstraints)417 TEST_F(PlacerTest, TestNoConstraints) {
418   Graph g(OpRegistry::Global());
419   {  // Scope for temporary variables used to construct g.
420     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
421     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
422     ops::UnaryOp("TestRelu", ops::NodeOut(input, 0), b.opts().WithName("n1"));
423     ops::UnaryOp("TestRelu", ops::NodeOut(input, 1), b.opts().WithName("n2"));
424     TF_EXPECT_OK(BuildGraph(b, &g));
425   }
426 
427   TF_EXPECT_OK(Place(&g));
428   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
429   EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU");
430   EXPECT_DEVICE_TYPE(g, "n2", "FakeGPU");
431 }
432 
433 // Test that a graph with no constraints but using kernels that have a specified
434 // device priority will successfully assign nodes to the device with higher
435 // priority
TEST_F(PlacerTest,TestNoConstraintsWithPrioritizedKernels)436 TEST_F(PlacerTest, TestNoConstraintsWithPrioritizedKernels) {
437   Graph g(OpRegistry::Global());
438   {  // Scope for temporary variables used to construct g.
439     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
440     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
441     ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
442                  b.opts().WithName("n1"));
443     ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 1),
444                  b.opts().WithName("n2"));
445     TF_EXPECT_OK(BuildGraph(b, &g));
446   }
447 
448   TF_EXPECT_OK(Place(&g));
449   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
450   EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU");
451   EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU");
452 }
453 
454 // Test that if the node supports XLA_CPU and FakeCPU, it will be placed on
455 // XLA_CPU if and only if the node is assigned to the XLA_CPU device.
TEST_F(PlacerTest,TestXlaOpPlacement)456 TEST_F(PlacerTest, TestXlaOpPlacement) {
457   Graph g(OpRegistry::Global());
458   {  // Scope for temporary variables used to construct g.
459     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
460     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
461     ops::UnaryOp("TestXlaOp", ops::NodeOut(input, 0), b.opts().WithName("n1"));
462     ops::UnaryOp("TestXlaOp", ops::NodeOut(input, 1), b.opts().WithName("n2"));
463     TF_EXPECT_OK(BuildGraph(b, &g));
464   }
465 
466   GetNodeByName(g, "n2")->set_assigned_device_name(
467       "/job:a/replica:0/task:0/device:XLA_CPU:0");
468 
469   TF_EXPECT_OK(Place(&g));
470   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
471   // n1 should be placed on FakeCPU even if the op supports XLA_CPU with higher
472   // priority than FakeCPU.
473   EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU");
474   // n2 should be placed on XLA_CPU because it supports XLA_CPU and it is
475   // assigned to a XLA_CPU device.
476   EXPECT_DEVICE_TYPE(g, "n2", "XLA_CPU");
477 }
478 
TEST_F(PlacerTest,TestGPUInputIntoPrioritizedKernel)479 TEST_F(PlacerTest, TestGPUInputIntoPrioritizedKernel) {
480   Graph g(OpRegistry::Global());
481   {
482     // Scope for temp variables used to construct g.
483     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
484     Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));
485     ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
486                  b.opts().WithName("n1"));
487     TF_EXPECT_OK(BuildGraph(b, &g));
488   }
489 
490   TF_EXPECT_OK(Place(&g));
491   EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
492   EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU");
493 }
494 
495 // Tests that a GPU kernel colocated with prioritized kernel respects it.
TEST_F(PlacerTest,TestGPUInputColocatedWithPrioritizedKernel)496 TEST_F(PlacerTest, TestGPUInputColocatedWithPrioritizedKernel) {
497   Graph g(OpRegistry::Global());
498   {
499     // Scope for temp variables used to construct g.
500     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
501     Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));
502     // We colocate n1 with in.
503     ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
504                  b.opts().WithName("n1").WithAttr("_class", {"loc:@in"}));
505     // We don't colocate n2 with in.
506     ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
507                  b.opts().WithName("n2"));
508     TF_EXPECT_OK(BuildGraph(b, &g));
509   }
510 
511   TF_EXPECT_OK(Place(&g));
512   EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
513   EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU");
514   EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU");
515 }
516 
517 REGISTER_OP("CreateDatasetCPU").Output("o: resource");
518 REGISTER_KERNEL_BUILDER(Name("CreateDatasetCPU").Device("FakeCPU"), DummyOp);
519 REGISTER_OP("CreateDatasetGPU").Output("o: resource");
520 REGISTER_KERNEL_BUILDER(Name("CreateDatasetGPU").Device("FakeGPU"), DummyOp);
521 
522 REGISTER_OP("CreateDatasetSP").Output("o: resource");
523 REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeCPU").Priority(2),
524                         DummyOp);
525 REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeGPU").Priority(1),
526                         DummyOp);
527 
528 REGISTER_OP("CreateDatasetRP").Output("o: resource");
529 REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeCPU").Priority(1),
530                         DummyOp);
531 REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeGPU").Priority(2),
532                         DummyOp);
533 
534 REGISTER_OP("CreateDatasetNP").Output("o: resource");
535 REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeCPU"), DummyOp);
536 REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeGPU"), DummyOp);
537 
538 REGISTER_OP("IteratorNP").Input("i: resource").Output("o: float");
539 REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeCPU"), DummyOp);
540 REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeGPU"), DummyOp);
541 
542 REGISTER_OP("IteratorSP").Input("i: resource").Output("o: float");
543 REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeCPU").Priority(2),
544                         DummyOp);
545 REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeGPU").Priority(1),
546                         DummyOp);
547 
548 REGISTER_OP("IteratorRP").Input("i: resource").Output("o: float");
549 REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeCPU").Priority(1),
550                         DummyOp);
551 REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeGPU").Priority(2),
552                         DummyOp);
553 
554 REGISTER_OP("IteratorGPU").Input("i: resource").Output("o: float");
555 REGISTER_KERNEL_BUILDER(Name("IteratorGPU").Device("FakeGPU"), DummyOp);
556 
557 // Test reference edges with one node having prioritized kernels and the other
558 // has no preference. We should respect priority here.
TEST_F(PlacerTest,TestDSWithPriority)559 TEST_F(PlacerTest, TestDSWithPriority) {
560   Graph g(OpRegistry::Global());
561   {
562     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
563     Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
564     ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
565     TF_EXPECT_OK(BuildGraph(b, &g));
566   }
567   TF_EXPECT_OK(Place(&g));
568   EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
569   EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
570 }
571 
572 // Test reference edges with one node having kernels with regular priority and
573 // the other has no preference. We should respect priority here.
TEST_F(PlacerTest,TestDSWithGPUPriority)574 TEST_F(PlacerTest, TestDSWithGPUPriority) {
575   Graph g(OpRegistry::Global());
576   {
577     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
578     Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
579     ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
580     TF_EXPECT_OK(BuildGraph(b, &g));
581   }
582   TF_EXPECT_OK(Place(&g));
583   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
584   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
585 }
586 
587 // Test reference edges with one node having prioritized kernels and the other
588 // has no preference. We should respect priority here.
TEST_F(PlacerTest,TestITWithPriority)589 TEST_F(PlacerTest, TestITWithPriority) {
590   Graph g(OpRegistry::Global());
591   {
592     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
593     Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds"));
594     ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
595     TF_EXPECT_OK(BuildGraph(b, &g));
596   }
597   TF_EXPECT_OK(Place(&g));
598   EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
599   EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
600 }
601 
602 // Test reference edges with one node having kernels with regular priority and
603 // the other has no preference. We should respect priority here.
TEST_F(PlacerTest,TestITWithGPUPriority)604 TEST_F(PlacerTest, TestITWithGPUPriority) {
605   Graph g(OpRegistry::Global());
606   {
607     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
608     Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds"));
609     ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
610     TF_EXPECT_OK(BuildGraph(b, &g));
611   }
612   TF_EXPECT_OK(Place(&g));
613   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
614   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
615 }
616 
617 // Test reference edges with one node having prioritized kernels and other node
618 // can only be placed on GPU. We should respect the constraint then.
TEST_F(PlacerTest,TestITGPU)619 TEST_F(PlacerTest, TestITGPU) {
620   Graph g(OpRegistry::Global());
621   {
622     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
623     Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
624     ops::UnaryOp("IteratorGPU", ops::NodeOut(ds, 0), b.opts().WithName("it"));
625     TF_EXPECT_OK(BuildGraph(b, &g));
626   }
627   TF_EXPECT_OK(Place(&g));
628   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
629   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
630 }
631 
632 // Test reference edges with one node having prioritized kernels and other node
633 // can only be placed on CPU. We should respect the constraint then.
TEST_F(PlacerTest,TestSimpleIteratorOnlyGPU)634 TEST_F(PlacerTest, TestSimpleIteratorOnlyGPU) {
635   Graph g(OpRegistry::Global());
636   {
637     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
638     Node* ds = ops::SourceOp("CreateDatasetCPU", b.opts().WithName("ds"));
639     ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
640     TF_EXPECT_OK(BuildGraph(b, &g));
641   }
642   TF_EXPECT_OK(Place(&g));
643   EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
644   EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
645 }
646 
647 // Test constraints with agreeing priorities.
TEST_F(PlacerTest,TestAgreeingPriorities)648 TEST_F(PlacerTest, TestAgreeingPriorities) {
649   Graph g(OpRegistry::Global());
650   {
651     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
652     Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
653     ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
654     TF_EXPECT_OK(BuildGraph(b, &g));
655   }
656   TF_EXPECT_OK(Place(&g));
657   EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
658   EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
659 }
660 
661 // Test constraints with agreeing regular priorities.
TEST_F(PlacerTest,TestAgreeingRegularPriorities)662 TEST_F(PlacerTest, TestAgreeingRegularPriorities) {
663   Graph g(OpRegistry::Global());
664   {
665     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
666     Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
667     ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
668     TF_EXPECT_OK(BuildGraph(b, &g));
669   }
670   TF_EXPECT_OK(Place(&g));
671   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
672   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
673 }
674 
675 // Test constraints with different priorities. In this case, we should bail
676 // and just revert to default.
TEST_F(PlacerTest,TestConflictingPriorities)677 TEST_F(PlacerTest, TestConflictingPriorities) {
678   Graph g(OpRegistry::Global());
679   {
680     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
681     Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
682     ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
683     TF_EXPECT_OK(BuildGraph(b, &g));
684   }
685   TF_EXPECT_OK(Place(&g));
686   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
687   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
688 }
689 
690 // Test constraints with different priorities. In this case, we should bail
691 // and just revert to default.
TEST_F(PlacerTest,TestConflictingPrioritiesReversed)692 TEST_F(PlacerTest, TestConflictingPrioritiesReversed) {
693   Graph g(OpRegistry::Global());
694   {
695     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
696     Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
697     ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
698     TF_EXPECT_OK(BuildGraph(b, &g));
699   }
700   TF_EXPECT_OK(Place(&g));
701   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
702   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
703 }
704 
705 // Test that a graph with device type and reference constraints on
706 // some of the ops will successfully assign nodes to the constrained
707 // device, and colocate nodes with reference connections.
TEST_F(PlacerTest,TestDeviceTypeConstraints)708 TEST_F(PlacerTest, TestDeviceTypeConstraints) {
709   Graph g(OpRegistry::Global());
710   {  // Scope for temporary variables used to construct g.
711     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
712     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
713     Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
714     ops::BinaryOp("AssignCPU", var_cpu, input, b.opts().WithName("assign_cpu"));
715     Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu"));
716     ops::BinaryOp("AssignGPU", var_gpu, input, b.opts().WithName("assign_gpu"));
717     TF_EXPECT_OK(BuildGraph(b, &g));
718   }
719 
720   TF_EXPECT_OK(Place(&g));
721   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
722   EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
723   EXPECT_DEVICE_TYPE(g, "assign_cpu", "FakeCPU");
724   EXPECT_COLOCATED(g, "var_cpu", "assign_cpu");
725   EXPECT_DEVICE_TYPE(g, "var_gpu", "FakeGPU");
726   EXPECT_DEVICE_TYPE(g, "assign_gpu", "FakeGPU");
727   EXPECT_COLOCATED(g, "var_gpu", "assign_gpu");
728 }
729 
TEST_F(PlacerTest,TestMetadataColocatedWithInput)730 TEST_F(PlacerTest, TestMetadataColocatedWithInput) {
731   Graph g(OpRegistry::Global());
732   {  // Scope for temporary variables used to construct g.
733     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
734     Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
735 
736     // Normally, shape has a GPU implementation and would be placed
737     // on GPU.  However, because it is a metadata operation, it is
738     // placed on CPU to avoid transferring the data from CPU to GPU.
739     ops::UnaryOp("Shape", var_cpu, b.opts().WithName("shape_op"));
740     TF_EXPECT_OK(BuildGraph(b, &g));
741   }
742 
743   TF_EXPECT_OK(Place(&g));
744   EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
745   EXPECT_DEVICE_TYPE(g, "shape_op", "FakeCPU");
746   EXPECT_COLOCATED(g, "var_cpu", "shape_op");
747 }
748 
749 // Heuristic A implements "Island fusing": if a node only generates
750 // an output and it has only one consumer, we place the node
751 // with its consumer.
TEST_F(PlacerTest,TestHeuristicGeneratorFollowsSingleConsumer)752 TEST_F(PlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) {
753   Graph g(OpRegistry::Global());
754   {  // Scope for temporary variables used to construct g.
755     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
756 
757     // A variable is only on CPU
758     Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
759 
760     // The constant to be assigned can be on both GPU or CPU.
761     //
762     // Because of the heuristic, it gets placed on CPU to avoid a
763     // copy.
764     Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in"));
765 
766     // The assign is bound to CPU by the reference edge.
767     ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign"));
768 
769     TF_EXPECT_OK(BuildGraph(b, &g));
770   }
771 
772   TF_EXPECT_OK(Place(&g));
773   EXPECT_COLOCATED(g, "var_cpu", "in");
774   EXPECT_COLOCATED(g, "assign", "in");
775 }
776 
TEST_F(PlacerTest,TestIgnoreGeneratorHeuristicIfWrongDevice)777 TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) {
778   Graph g(OpRegistry::Global());
779   {  // Scope for temporary variables used to construct g.
780     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
781 
782     // A variable is only on CPU
783     Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
784 
785     // The constant to be assigned can only be on GPU.
786     //
787     // The heuristic to place the generator with its consumer does
788     // not apply since the consumer's device is not in the list
789     // of valid devices for the generator.
790     Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));
791 
792     // The assign is bound to CPU by the reference edge.
793     ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign"));
794 
795     TF_EXPECT_OK(BuildGraph(b, &g));
796   }
797 
798   TF_EXPECT_OK(Place(&g));
799   EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
800   EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
801   EXPECT_COLOCATED(g, "var_cpu", "assign");
802 }
803 
TEST_F(PlacerTest,TestIgnoreGeneratorHeuristicIfWrongPartialDevice)804 TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) {
805   Graph g(OpRegistry::Global());
806   {  // Scope for temporary variables used to construct g.
807     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
808 
809     // A variable is only on CPU
810     Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
811 
812     // The constant to be assigned can be on CPU or GPU, but is explicitly
813     // placed on CPU:1.
814     //
815     // The heuristic to place the generator with its consumer does
816     // not apply since the consumer's device is not in the list
817     // of valid devices for the generator.
818     Node* input =
819         ops::SourceOp("TestCPUGPUOutput",
820                       b.opts().WithName("in").WithDevice("/device:FakeCPU:1"));
821 
822     // The assign is bound to CPU by the reference edge.
823     ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign"));
824 
825     TF_EXPECT_OK(BuildGraph(b, &g));
826   }
827 
828   TF_EXPECT_OK(Place(&g));
829   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
830   EXPECT_DEVICE_CONTAINS(g, "in", "/device:FakeCPU:1");
831   EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
832   EXPECT_COLOCATED(g, "var_cpu", "assign");
833   EXPECT_DEVICE_CONTAINS(g, "var_cpu", "/device:FakeCPU:0");
834 }
835 
836 // Test that a graph with partial device specifications on the ops
837 // will successfully
TEST_F(PlacerTest,TestPartialSpec)838 TEST_F(PlacerTest, TestPartialSpec) {
839   Graph g(OpRegistry::Global());
840   {  // Scope for temporary variables used to construct g.
841     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
842     ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:a"));
843     ops::SourceOp("TestVariable",
844                   b.opts().WithName("var").WithDevice("/job:a"));
845     TF_EXPECT_OK(BuildGraph(b, &g));
846   }
847 
848   TF_EXPECT_OK(Place(&g));
849   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
850   EXPECT_DEVICE_CONTAINS(g, "in", "/job:a");
851   EXPECT_DEVICE_TYPE(g, "var", "FakeGPU");
852   EXPECT_DEVICE_CONTAINS(g, "var", "/job:a");
853 }
854 
855 // Test that a node with a pre-assigned device is not relocated.
TEST_F(PlacerTest,TestAssignedDevicePreserved)856 TEST_F(PlacerTest, TestAssignedDevicePreserved) {
857   Graph g(OpRegistry::Global());
858   {  // Scope for temporary variables used to construct g.
859     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
860     ops::SourceOp("TestInput", b.opts().WithName("in"));
861     TF_EXPECT_OK(BuildGraph(b, &g));
862   }
863 
864   GetNodeByName(g, "in")->set_assigned_device_name(
865       "/job:a/replica:0/task:0/device:FakeCPU:7");
866 
867   TF_EXPECT_OK(Place(&g));
868   EXPECT_EQ("/job:a/replica:0/task:0/device:FakeCPU:7",
869             GetNodeByName(g, "in")->assigned_device_name());
870 }
871 
872 // Test that a graph with partial device specifications for CPU-only ops
873 // will be relocated to CPU.
TEST_F(PlacerTest,TestPartialSpecGpuToCpu)874 TEST_F(PlacerTest, TestPartialSpecGpuToCpu) {
875   Graph g(OpRegistry::Global());
876   {  // Scope for temporary variables used to construct g.
877     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
878     ops::SourceOp("TestInput",
879                   b.opts().WithName("in").WithDevice("/device:FakeGPU:0"));
880     ops::SourceOp("TestVariable",
881                   b.opts().WithName("var").WithDevice("/device:FakeGPU:0"));
882     TF_EXPECT_OK(BuildGraph(b, &g));
883   }
884 
885   TF_EXPECT_OK(Place(&g, true, false));
886   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
887   EXPECT_DEVICE_CONTAINS(g, "in", "/device:FakeCPU");
888   EXPECT_DEVICE_TYPE(g, "var", "FakeGPU");
889   EXPECT_DEVICE_CONTAINS(g, "var", "/device:FakeGPU:0");
890 }
891 
892 // Test that a resource with requested device will be moved to another
893 // device if it is processed by an op that is not supported on requested device.
TEST_F(PlacerTest,TestResourceMove)894 TEST_F(PlacerTest, TestResourceMove) {
895   Graph g(OpRegistry::Global());
896   {
897     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
898     Node* ds =
899         ops::SourceOp("CreateDatasetSP",
900                       b.opts().WithName("ds").WithDevice("/device:FakeCPU:0"));
901     ops::UnaryOp("IteratorGPU", ops::NodeOut(ds, 0), b.opts().WithName("it"));
902     TF_EXPECT_OK(BuildGraph(b, &g));
903   }
904   TF_EXPECT_OK(Place(&g));
905   EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
906   EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
907 }
908 
909 // Test that a node with an assigned GPU device but has not registered
910 // OpKernel will fail.
TEST_F(PlacerTest,TestAssignedGpuDeviceToCpuDevice)911 TEST_F(PlacerTest, TestAssignedGpuDeviceToCpuDevice) {
912   Graph g(OpRegistry::Global());
913   {  // Scope for temporary variables used to construct g.
914     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
915     ops::SourceOp("TestInput", b.opts().WithName("in"));
916     TF_EXPECT_OK(BuildGraph(b, &g));
917   }
918 
919   GetNodeByName(g, "in")->set_assigned_device_name(
920       "/job:a/replica:0/task:0/device:FakeGPU:0");
921 
922   Status s = Place(&g);
923   EXPECT_EQ(error::INTERNAL, s.code()) << s.ToString();
924   EXPECT_TRUE(absl::StrContains(
925       s.error_message(),
926       "Assigned device '/job:a/replica:0/task:0/device:FakeGPU:0' "
927       "does not have registered OpKernel support for TestInput"))
928       << s.ToString();
929 }
930 
931 // Test that graphs with reference connections are correctly placed.
932 
933 // Build a graph containing a Variable op of "variable_op_type" and an
934 // Assign op of "assign_op_type", and expect all of the ops to be
935 // placed on a device of type "expected_device_type".
ReferenceTestHelper(const string & variable_op_type,const string & assign_op_type,const DeviceType & expected_device_type)936 Status PlacerTest::ReferenceTestHelper(const string& variable_op_type,
937                                        const string& assign_op_type,
938                                        const DeviceType& expected_device_type) {
939   Graph g(OpRegistry::Global());
940   {  // Scope for temporary variables used to construct g.
941     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
942     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
943     // Build ten variable-and-assignment pairs.
944     for (int i = 0; i < 10; ++i) {
945       Node* var = ops::SourceOp(variable_op_type,
946                                 b.opts().WithName(strings::StrCat("var_", i)));
947       ops::BinaryOp(assign_op_type, var, input,
948                     b.opts().WithName(strings::StrCat("assign_", i)));
949     }
950     TF_EXPECT_OK(BuildGraph(b, &g));
951   }
952 
953   TF_RETURN_IF_ERROR(Place(&g));
954 
955   for (int i = 0; i < 10; ++i) {
956     EXPECT_COLOCATED(g, strings::StrCat("var_", i),
957                      strings::StrCat("assign_", i));
958     EXPECT_DEVICE_TYPE(g, strings::StrCat("var_", i), expected_device_type);
959     EXPECT_DEVICE_TYPE(g, strings::StrCat("assign_", i), expected_device_type);
960   }
961 
962   return OkStatus();
963 }
964 
965 // Test all 2^3 combinations of Variable and Assignment op types
966 // (unconstrained, CPU-only, and GPU-only).
TEST_F(PlacerTest,TestReferenceConnection)967 TEST_F(PlacerTest, TestReferenceConnection) {
968   Status s;
969   TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "TestAssign", "FakeGPU"));
970   TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignCPU", "FakeCPU"));
971   TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignGPU", "FakeGPU"));
972   TF_EXPECT_OK(ReferenceTestHelper("VariableCPU", "TestAssign", "FakeCPU"));
973   TF_EXPECT_OK(ReferenceTestHelper("VariableCPU", "AssignCPU", "FakeCPU"));
974   {
975     Status s = ReferenceTestHelper("VariableCPU", "AssignGPU", "FakeCPU");
976     EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
977     EXPECT_TRUE(absl::StrContains(
978         s.error_message(), "no device type supports both of those nodes"));
979   }
980   TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "TestAssign", "FakeGPU"));
981   {
982     Status s = ReferenceTestHelper("VariableGPU", "AssignCPU", "FakeCPU");
983     EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
984     EXPECT_TRUE(absl::StrContains(
985         s.error_message(), "no device type supports both of those nodes"));
986   }
987   TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", "FakeGPU"));
988 }
989 
990 // Handle-using dummy variable ops.
991 REGISTER_OP("TestHandleVariable").Output("o: resource");
992 REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device("FakeCPU"), DummyOp);
993 REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device("FakeGPU"), DummyOp);
994 
995 REGISTER_OP("HandleVariableCPU").Output("o: resource");
996 REGISTER_KERNEL_BUILDER(Name("HandleVariableCPU").Device("FakeCPU"), DummyOp);
997 
998 REGISTER_OP("HandleVariableGPU").Output("o: resource");
999 REGISTER_KERNEL_BUILDER(Name("HandleVariableGPU").Device("FakeGPU"), DummyOp);
1000 
1001 REGISTER_OP("TestHandleAssign").Input("i: resource").Input("v: float");
1002 REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device("FakeCPU"), DummyOp);
1003 REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device("FakeGPU"), DummyOp);
1004 
1005 REGISTER_OP("HandleAssignCPU").Input("i: resource").Input("v: float");
1006 REGISTER_KERNEL_BUILDER(Name("HandleAssignCPU").Device("FakeCPU"), DummyOp);
1007 
1008 REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float");
1009 REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device("FakeGPU"), DummyOp);
1010 
1011 REGISTER_OP("TestTwoHandlesIn").Input("i: resource").Input("j: resource");
1012 REGISTER_KERNEL_BUILDER(Name("TestTwoHandlesIn").Device("FakeCPU"), DummyOp);
1013 REGISTER_KERNEL_BUILDER(Name("TestTwoHandlesIn").Device("FakeGPU"), DummyOp);
1014 
1015 // Tests all combinations of resource handles and ops using them.
TEST_F(PlacerTest,TestResourceHandle)1016 TEST_F(PlacerTest, TestResourceHandle) {
1017   auto handle_test = [this](const string& var_op_name,
1018                             const string& use_op_name, DeviceType device) {
1019     Graph g(OpRegistry::Global());
1020     {  // Scope for temporary variables used to construct g.
1021       GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1022       Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1023       Node* var = ops::SourceOp(var_op_name, b.opts().WithName("var"));
1024       ops::BinaryOp(use_op_name, var, input, b.opts().WithName("assign"));
1025       TF_EXPECT_OK(BuildGraph(b, &g));
1026     }
1027 
1028     TF_RETURN_IF_ERROR(Place(&g));
1029 
1030     EXPECT_COLOCATED(g, "var", "assign");
1031     EXPECT_DEVICE_TYPE(g, "var", device);
1032     EXPECT_DEVICE_TYPE(g, "assign", device);
1033     return OkStatus();
1034   };
1035   TF_EXPECT_OK(
1036       handle_test("TestHandleVariable", "TestHandleAssign", "FakeGPU"));
1037   TF_EXPECT_OK(handle_test("TestHandleVariable", "HandleAssignCPU", "FakeCPU"));
1038   TF_EXPECT_OK(handle_test("TestHandleVariable", "HandleAssignGPU", "FakeGPU"));
1039   TF_EXPECT_OK(handle_test("HandleVariableCPU", "TestHandleAssign", "FakeCPU"));
1040   TF_EXPECT_OK(handle_test("HandleVariableCPU", "HandleAssignCPU", "FakeCPU"));
1041   TF_EXPECT_OK(handle_test("HandleVariableGPU", "HandleAssignGPU", "FakeGPU"));
1042   TF_EXPECT_OK(handle_test("HandleVariableGPU", "TestHandleAssign", "FakeGPU"));
1043   EXPECT_FALSE(
1044       handle_test("HandleVariableGPU", "HandleAssignCPU", "FakeCPU").ok());
1045   EXPECT_FALSE(
1046       handle_test("HandleVariableCPU", "HandleAssignGPU", "FakeCPU").ok());
1047 }
1048 
TEST_F(PlacerTest,TestResourceHandlesOnDifferentDevicesFails)1049 TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) {
1050   auto handle_test = [this](bool allow_soft_placement, bool set_assigned) {
1051     Graph g(OpRegistry::Global());
1052     {  // Scope for temporary variables used to construct g.
1053       GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1054       Node* var_cpu =
1055           ops::SourceOp("TestHandleVariable", b.opts().WithName("var_cpu"));
1056       Node* var_gpu =
1057           ops::SourceOp("TestHandleVariable", b.opts().WithName("var_gpu"));
1058       ops::BinaryOp("TestTwoHandlesIn", var_cpu, var_gpu,
1059                     b.opts().WithName("two_handles_in"));
1060       TF_EXPECT_OK(BuildGraph(b, &g));
1061 
1062       if (set_assigned) {
1063         GetNodeByName(g, "var_cpu")
1064             ->set_assigned_device_name(
1065                 "/job:a/replica:0/task:0/device:FakeCPU:0");
1066         GetNodeByName(g, "var_gpu")
1067             ->set_assigned_device_name(
1068                 "/job:a/replica:0/task:0/device:FakeGPU:0");
1069       } else {
1070         GetNodeByName(g, "var_cpu")
1071             ->set_requested_device("/job:a/replica:0/task:0/device:FakeCPU:0");
1072         GetNodeByName(g, "var_gpu")
1073             ->set_requested_device("/job:a/replica:0/task:0/device:FakeGPU:0");
1074       }
1075     }
1076 
1077     Status s = Place(&g, allow_soft_placement, true);
1078     EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
1079     if (set_assigned) {
1080       EXPECT_TRUE(absl::StrContains(
1081           s.error_message(),
1082           "Cannot place the graph because a reference or resource edge "
1083           "connects "
1084           "colocation groups with incompatible assigned devices: "
1085           "/job:a/replica:0/task:0/device:FakeGPU:0 vs "
1086           "/job:a/replica:0/task:0/device:FakeCPU:0"))
1087           << s.ToString();
1088     } else {
1089       EXPECT_TRUE(absl::StrContains(
1090           s.error_message(),
1091           "Cannot place the graph because a reference or resource edge "
1092           "connects "
1093           "colocation groups with incompatible resource devices: "
1094           "/job:a/replica:0/task:0/device:FakeGPU:0 vs "
1095           "/job:a/replica:0/task:0/device:FakeCPU:0"))
1096           << s.ToString();
1097     }
1098 
1099     return OkStatus();
1100   };
1101 
1102   TF_EXPECT_OK(handle_test(false, false));
1103   TF_EXPECT_OK(handle_test(false, true));
1104   TF_EXPECT_OK(handle_test(true, false));
1105   TF_EXPECT_OK(handle_test(true, true));
1106 }
1107 
1108 // Test that an assignment of an operator to the wrong device
1109 // is ignored when it could never be satisfied (due to reference
1110 // edges, for example).
TEST_F(PlacerTest,TestReferenceConnectionIgnoreInfeasible)1111 TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) {
1112   Status s;
1113   Graph g(OpRegistry::Global());
1114   {
1115     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1116     Node* input = ops::SourceOp(
1117         "TestDevice",
1118         b.opts().WithName("in").WithDevice("/job:a/task:0/device:FakeGPU:0"));
1119     Node* var =
1120         ops::SourceOp("TestVariable", b.opts().WithName("var_0").WithDevice(
1121                                           "/job:a/task:0/device:FakeGPU:0"));
1122 
1123     // This op is specified on CPU, but in practice will be ignored,
1124     // because the reference edges forces it on GPU.
1125     ops::BinaryOp("TestAssign", var, input,
1126                   b.opts().WithName("assign").WithDevice(
1127                       "/job:a/task:0/device:FakeCPU:0"));
1128     TF_EXPECT_OK(BuildGraph(b, &g));
1129   }
1130 
1131   s = Place(&g, false, false);
1132   TF_EXPECT_OK(s);
1133   EXPECT_DEVICE_TYPE(g, "var_0", "FakeGPU");
1134   EXPECT_DEVICE_TYPE(g, "assign", "FakeGPU");
1135 }
1136 
1137 // Test that an assignment of an operator to the a more specified device
1138 // causes the device to maintain its more specific placement.
TEST_F(PlacerTest,TestReferenceConnectionMoreSpecificDestinationSourceWins)1139 TEST_F(PlacerTest, TestReferenceConnectionMoreSpecificDestinationSourceWins) {
1140   Status s;
1141   Graph g(OpRegistry::Global());
1142   {
1143     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1144     // Input can be on either device
1145     Node* input =
1146         ops::SourceOp("TestCPUGPUOutput",
1147                       b.opts().WithName("in").WithDevice("/job:a/task:0"));
1148 
1149     // Variable can be on either device
1150     Node* var = ops::SourceOp(
1151         "TestVariable", b.opts().WithName("var_0").WithDevice("/job:a/task:0"));
1152 
1153     // This op is specified on CPU and is more specific than the variable.
1154     // Because the variable is less specified, the variable will be
1155     // assigned to CPU.
1156     ops::BinaryOp("TestAssign", var, input,
1157                   b.opts().WithName("assign").WithDevice(
1158                       "/job:a/task:0/device:FakeCPU:0"));
1159     TF_EXPECT_OK(BuildGraph(b, &g));
1160   }
1161 
1162   s = Place(&g, false, false);
1163   TF_EXPECT_OK(s);
1164   EXPECT_DEVICE_TYPE(g, "var_0", "FakeCPU");
1165   EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU");
1166 }
1167 
1168 // A reference connection exists between a variable and an assign,
1169 // where the assign has a device but the variable does not.  In this
1170 // case, the variable gets placed on the location of the assign
1171 // operation.
TEST_F(PlacerTest,TestReferenceConnectionNoSourceDevice)1172 TEST_F(PlacerTest, TestReferenceConnectionNoSourceDevice) {
1173   Status s;
1174   Graph g(OpRegistry::Global());
1175   {
1176     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1177     Node* input = ops::SourceOp(
1178         "TestDevice",
1179         b.opts().WithName("in").WithDevice("/job:a/task:0/device:FakeGPU:0"));
1180     Node* var = ops::SourceOp("TestVariable", b.opts().WithName("var_0"));
1181     ops::BinaryOp("TestAssign", var, input,
1182                   b.opts().WithName("assign").WithDevice(
1183                       "/job:a/task:0/device:FakeCPU:0"));
1184     TF_EXPECT_OK(BuildGraph(b, &g));
1185   }
1186 
1187   s = Place(&g, false, false);
1188   TF_EXPECT_OK(s);
1189   EXPECT_DEVICE_TYPE(g, "var_0", "FakeCPU");
1190   EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU");
1191 }
1192 
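// Test that an op consuming a resource handle follows the variable that was
// assigned to the COMPOSITE device when the op has no assigned device of its
// own, but keeps its own device when one has already been assigned.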
1193 TEST_F(PlacerTest, TestResourceHandleOnCompositeDevice) {
1194   auto build_graph = [this](Graph* g) -> Status {
1195     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1196     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1197     // Build a variable and an assignment that updates it.
1198     Node* var = ops::SourceOp("HandleVariableCPU", b.opts().WithName("var"));
1199     ops::BinaryOp("TestHandleAssign", var, input, b.opts().WithName("assign"));
1200     TF_RETURN_IF_ERROR(BuildGraph(b, g));
1201     // `var` is assigned to COMPOSITE.
1202     GetNodeByName(*g, "var")->set_assigned_device_name(
1203         "/job:a/replica:0/task:0/device:COMPOSITE:0");
1204     return OkStatus();
1205   };
1206 
1207   {
1208     // `assign` is not assigned to any device.
1209     Graph g(OpRegistry::Global());
1210     TF_ASSERT_OK(build_graph(&g));
1211     TF_ASSERT_OK(Place(&g));
1212     EXPECT_DEVICE_TYPE(g, "var", "COMPOSITE");
1213     EXPECT_DEVICE_TYPE(g, "assign", "COMPOSITE");
1214   }
1215   {
1216     // `assign` is assigned to FakeCPU.
1217     Graph g(OpRegistry::Global());
1218     TF_ASSERT_OK(build_graph(&g));
1219     GetNodeByName(g, "assign")
1220         ->set_assigned_device_name("/job:a/replica:0/task:0/device:FakeCPU:0");
1221     TF_ASSERT_OK(Place(&g));
1222     EXPECT_DEVICE_TYPE(g, "var", "COMPOSITE");
1223     EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU");
1224   }
1225 }
1226 
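// Test that a node carrying a "_class" colocation attribute is placed on the
// same device as the node it names, while an unconstrained node is placed
// independently.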
1227 TEST_F(PlacerTest, TestColocationGroup) {
1228   Graph g(OpRegistry::Global());
1229   {  // Scope for temporary variables used to construct g.
1230     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1231     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1232     Node* colocated_with_input = ops::UnaryOp(
1233         "TestRelu", input,
1234         b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"}));
1235 
1236     // This will not be colocated with the input because TestInput is
1237     // only available on CPU and TestRelu will default to GPU.
1238     Node* not_colocated_with_input =
1239         ops::UnaryOp("TestRelu", input, b.opts().WithName("foo"));
1240     CHECK(colocated_with_input);
1241     CHECK(not_colocated_with_input);
1242     TF_EXPECT_OK(BuildGraph(b, &g));
1243   }
1244 
1245   TF_EXPECT_OK(Place(&g));
1246   EXPECT_COLOCATED(g, "in", "colocated_1");
1247   EXPECT_NOT_COLOCATED(g, "in", "foo");
1248 }
1249 
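// Test that a node may name several colocation groups in its "_class"
// attribute; all of the named groups are merged and placed together.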
1250 TEST_F(PlacerTest, TestMultipleColocationGroups) {
1251   Graph g(OpRegistry::Global());
1252   {  // Scope for temporary variables used to construct g.
1253     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1254     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1255     Node* colocated_with_input = ops::UnaryOp(
1256         "TestRelu", input,
1257         b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"}));
1258     Node* colocated_with_input_and_other =
1259         ops::UnaryOp("TestRelu", input,
1260                      b.opts().WithName("foo").WithAttr(
1261                          "_class", {"loc:@in", "loc:@colocated_1"}));
1262     CHECK(colocated_with_input);
1263     CHECK(colocated_with_input_and_other);
1264     TF_EXPECT_OK(BuildGraph(b, &g));
1265   }
1266 
1267   TF_EXPECT_OK(Place(&g));
1268   EXPECT_COLOCATED(g, "in", "colocated_1");
1269   EXPECT_COLOCATED(g, "in", "foo");
1270 }
1271 
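// Test that colocation constraints compose transitively: "foo" is colocated
// with "colocated_1", which is colocated with "in", so all three end up
// together.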
1272 TEST_F(PlacerTest, TestChainColocation) {
1273   Graph g(OpRegistry::Global());
1274   {  // Scope for temporary variables used to construct g.
1275     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1276     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1277     Node* colocated_with_input = ops::UnaryOp(
1278         "TestRelu", input,
1279         b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"}));
1280     Node* colocated_with_input_and_other = ops::UnaryOp(
1281         "TestRelu", input,
1282         b.opts().WithName("foo").WithAttr("_class", {"loc:@colocated_1"}));
1283     CHECK(colocated_with_input);
1284     CHECK(colocated_with_input_and_other);
1285     TF_EXPECT_OK(BuildGraph(b, &g));
1286   }
1287 
1288   TF_EXPECT_OK(Place(&g));
1289   EXPECT_COLOCATED(g, "in", "colocated_1");
1290   EXPECT_COLOCATED(g, "in", "foo");
1291 }
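
// Illustrative sketch (not part of the original test suite): a longer chain of
// "_class" constraints should resolve the same way. This assumes the fake ops
// ("TestInput", "TestRelu") and helper macros (EXPECT_COLOCATED) defined
// earlier in this file, and that "TestRelu" produces the same (float) type it
// consumes so the two relus can be chained; the test name is hypothetical.
TEST_F(PlacerTest, TestChainColocationDeeperSketch) {
  Graph g(OpRegistry::Global());
  {  // Scope for temporary variables used to construct g.
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
    // "c1" is colocated with "in"; "c2" is colocated with "c1".
    Node* first = ops::UnaryOp(
        "TestRelu", input,
        b.opts().WithName("c1").WithAttr("_class", {"loc:@in"}));
    Node* second = ops::UnaryOp(
        "TestRelu", first,
        b.opts().WithName("c2").WithAttr("_class", {"loc:@c1"}));
    CHECK(second);
    TF_EXPECT_OK(BuildGraph(b, &g));
  }

  TF_EXPECT_OK(Place(&g));
  // The whole chain should resolve to the device of "in".
  EXPECT_COLOCATED(g, "in", "c1");
  EXPECT_COLOCATED(g, "in", "c2");
}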
1292 
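// Test the behavior when a node's multiple colocation groups cannot all be
// placed on one device type: soft placement relaxes the constraint, while
// strict placement reports an error.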
1293 TEST_P(SoftPlacementPlacerTest, TestInvalidMultipleColocationGroups) {
1294   Graph g(OpRegistry::Global());
1295   {  // Scope for temporary variables used to construct g.
1296     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1297     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1298     Node* colocated_with_input = ops::UnaryOp(
1299         "ReluCPU", input,
1300         b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"}));
1301     Node* colocated_with_input_and_other =
1302         ops::UnaryOp("ReluGPU", input,
1303                      b.opts().WithName("foo").WithAttr(
1304                          "_class", {"loc:@in", "loc:@colocated_1"}));
1305     CHECK(colocated_with_input);
1306     CHECK(colocated_with_input_and_other);
1307     TF_EXPECT_OK(BuildGraph(b, &g));
1308   }
1309 
1310   bool allow_soft_placement = GetParam();
1311   Status s = Place(&g, allow_soft_placement, true);
1312   if (allow_soft_placement) {
1313     EXPECT_EQ(error::OK, s.code()) << s.ToString();
1314     EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
1315     EXPECT_DEVICE_TYPE(g, "colocated_1", "FakeCPU");
1316     EXPECT_DEVICE_TYPE(g, "foo", "FakeGPU");
1317   } else {
1318     EXPECT_TRUE(absl::StrContains(
1319         s.error_message(),
1320         "Cannot colocate nodes {{colocation_node foo}} and "
1321         "{{colocation_node in}} because no device type supports both of those "
1322         "nodes and the other nodes colocated with them"))
1323         << s.ToString();
1324   }
1325 }
1326 
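// Test that explicit colocation groups combine with the implicit colocation
// imposed by reference connections (TestAssign takes a ref input).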
1327 TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
1328   Graph g(OpRegistry::Global());
1329   {  // Scope for temporary variables used to construct g.
1330     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1331     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1332     Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
1333     Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
1334     Node* var3 = ops::SourceOp(
1335         "VariableCPU",
1336         b.opts().WithName("var3").WithDevice("/device:COMPOSITE:0"));
1337 
1338     // Three assigns (reference connections) with three different
1339     // colocation groups. Because each group maps to a single compatible
1340     // device, this is a valid placement.
1341     ops::BinaryOp(
1342         "TestAssign", var1, input,
1343         b.opts().WithName("assign1").WithAttr("_class", {"loc:@var1"}));
1344     ops::BinaryOp(
1345         "TestAssign", var2, input,
1346         b.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"}));
1347     ops::BinaryOp(
1348         "TestAssign", var3, input,
1349         b.opts().WithName("assign3").WithAttr("_class", {"loc:@var3"}));
1350     TF_EXPECT_OK(BuildGraph(b, &g));
1351   }
1352 
1353   TF_EXPECT_OK(Place(&g));
1354   EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
1355   EXPECT_COLOCATED(g, "in", "var1");
1356   EXPECT_COLOCATED(g, "in", "var2");
1357   EXPECT_COLOCATED(g, "var1", "assign2");
1358   EXPECT_COLOCATED(g, "var2", "assign1");
1359   EXPECT_DEVICE_TYPE(g, "var3", "COMPOSITE");
1360   EXPECT_COLOCATED(g, "var3", "assign3");
1361 }
1362 
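// Test the behavior when an explicit colocation group conflicts with the
// device constraints imposed by reference connections: soft placement
// succeeds, while strict placement fails with a colocation error.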
1363 TEST_P(SoftPlacementPlacerTest,
1364        TestColocationGroupWithUnsatisfiableReferenceConnections) {
1365   Graph g(OpRegistry::Global());
1366   {  // Scope for temporary variables used to construct g.
1367     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1368     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1369 
1370     Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
1371     Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
1372     // Var 3 is on GPU
1373     Node* var3 = ops::SourceOp("VariableGPU", b.opts().WithName("var3"));
1374 
1375     // Two assigns (reference connections) with two different
1376     // colocation groups. Because their colocation groups all map to the
1377     // same device, this is a valid assignment.
1378     ops::BinaryOp(
1379         "TestAssign", var1, input,
1380         b.opts().WithName("assign1").WithAttr("_class", {"loc:@var1"}));
1381     ops::BinaryOp(
1382         "TestAssign", var2, input,
1383         b.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"}));
1384     // Assign to var3, but try to use a colocation group that matches
1385     // the assign of var2.  This should fail because assign2 must be on CPU
1386     // (it has a reference edge on var2), and assign3 must be on GPU,
1387     // hence the conflict.
1388     ops::BinaryOp(
1389         "TestAssign", var3, input,
1390         b.opts().WithName("assign3").WithAttr("_class", {"loc:@var2"}));
1391     TF_EXPECT_OK(BuildGraph(b, &g));
1392   }
1393 
1394   bool allow_soft_placement = GetParam();
1395   Status s = Place(&g, allow_soft_placement, true);
1396   if (allow_soft_placement) {
1397     EXPECT_EQ(error::OK, s.code()) << s.ToString();
1398   } else {
1399     EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
1400     EXPECT_TRUE(absl::StrContains(
1401         s.error_message(),
1402         "Cannot colocate nodes {{colocation_node assign3}} and "
1403         "{{colocation_node var2}} because no device type supports both of "
1404         "those nodes and the other nodes colocated with them."))
1405         << s.ToString();
1406   }
1407 }
1408 
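// Test a larger mix of reference connections and explicit colocation groups:
// each assign must end up colocated with its variable and with the nodes its
// "_class" attribute names.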
1409 TEST_F(PlacerTest, TestColocationAndReferenceConnections) {
1410   Graph g(OpRegistry::Global());
1411   {  // Scope for temporary variables used to construct g.
1412     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1413     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1414     for (int i = 0; i < 10; ++i) {
1415       // Declare ten variable and assignment pairs.
1416       Node* var = ops::SourceOp("TestVariable",
1417                                 b.opts().WithName(strings::StrCat("var_", i)));
1418       ops::BinaryOp("TestAssign", var, input,
1419                     b.opts().WithName(strings::StrCat("assign_", i)));
1420     }
1421     for (int i = 10; i < 100; ++i) {
1422       // Create a variable colocated with some existing variable, and
1423       // an assignment colocated with a possibly-different assignment.
1424       Node* var = ops::SourceOp(
1425           "TestVariable",
1426           b.opts()
1427               .WithName(strings::StrCat("var_", i))
1428               .WithAttr("_class", {strings::StrCat("loc:@var_", i % 6)}));
1429       ops::BinaryOp(
1430           "TestAssign", var, input,
1431           b.opts()
1432               .WithName(strings::StrCat("assign_", i))
1433               .WithAttr("_class", {strings::StrCat("loc:@assign_", i % 3)}));
1434     }
1435     TF_EXPECT_OK(BuildGraph(b, &g));
1436   }
1437 
1438   TF_EXPECT_OK(Place(&g));
1439   for (int i = 0; i < 10; ++i) {
1440     EXPECT_COLOCATED(g, strings::StrCat("var_", i),
1441                      strings::StrCat("assign_", i));
1442   }
1443   for (int i = 10; i < 100; ++i) {
1444     EXPECT_COLOCATED(g, strings::StrCat("var_", i),
1445                      strings::StrCat("assign_", i));
1446     EXPECT_COLOCATED(g, strings::StrCat("var_", i),
1447                      strings::StrCat("var_", i % 6));
1448     EXPECT_COLOCATED(g, strings::StrCat("assign_", i),
1449                      strings::StrCat("assign_", i % 3));
1450   }
1451 }
1452 
1453 // Test that placement fails when no devices are registered.
1454 TEST_F(PlacerTest, TestEmptyDeviceSet) {
1455   Graph g(OpRegistry::Global());
1456   {  // Scope for temporary variables used to construct g.
1457     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1458     ops::SourceOp("TestInput", b.opts().WithName("in"));
1459     TF_EXPECT_OK(BuildGraph(b, &g));
1460   }
1461 
1462   DeviceSet empty;
1463 
1464   Status s = Place(&g, &empty);
1465   EXPECT_TRUE(
1466       absl::StrContains(s.error_message(), "No devices are registered"));
1467 }
1468 
1469 // Test that placement fails when the requested device forces an
1470 // indirect constraint to be violated.
1471 TEST_F(PlacerTest, TestHeterogeneousDeviceSetFailure) {
1472   Graph g(OpRegistry::Global());
1473   {  // Scope for temporary variables used to construct g.
1474     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1475     Node* in = ops::SourceOp("TestInput", b.opts().WithName("in"));
1476     Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var"));
1477     ops::BinaryOp("TestAssign", var, in,
1478                   b.opts().WithName("assign").WithDevice("/job:b/task:1"));
1479     TF_EXPECT_OK(BuildGraph(b, &g));
1480   }
1481 
1482   DeviceSet heterogeneous;
1483   std::unique_ptr<Device> gpu(
1484       FakeDevice::MakeGPU("/job:b/replica:0/task:0/device:FakeGPU:0"));
1485   heterogeneous.AddDevice(gpu.get());
1486   std::unique_ptr<Device> cpu(
1487       FakeDevice::MakeCPU("/job:b/replica:0/task:1/device:FakeCPU:0"));
1488   heterogeneous.AddDevice(cpu.get());
1489   Status s = Place(&g, &heterogeneous);
1490   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1491   EXPECT_TRUE(absl::StrContains(s.error_message(),
1492                                 "colocated with a group of nodes that required "
1493                                 "incompatible device"));
1494 
1495   // The error message should contain information that indicates which
1496   // op types have which registered device types.
1497   EXPECT_TRUE(absl::StrContains(s.error_message(), "VariableGPU: FakeGPU"))
1498       << s;
1499   EXPECT_TRUE(
1500       absl::StrContains(s.error_message(), "TestAssign: FakeGPU FakeCPU"))
1501       << s;
1502 }
1503 
1504 // Test that placement fails when an unknown device is requested.
1505 TEST_F(PlacerTest, TestUnknownDevice) {
1506   Graph g(OpRegistry::Global());
1507   {  // Scope for temporary variables used to construct g.
1508     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1509     ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
1510     TF_EXPECT_OK(BuildGraph(b, &g));
1511   }
1512 
1513   Status s = Place(&g);
1514   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1515   EXPECT_TRUE(absl::StrContains(s.error_message(), "/job:foo"));
1516 }
1517 
1518 // Test that placement fails when the combination of partial
1519 // constraints leads to an unknown device.
1520 TEST_F(PlacerTest, TestUnknownMergedDevice) {
1521   Graph g(OpRegistry::Global());
1522   {  // Scope for temporary variables used to construct g.
1523     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1524     ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
1525     TF_EXPECT_OK(BuildGraph(b, &g));
1526   }
1527 
1528   Status s = Place(&g);
1529   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1530   EXPECT_TRUE(absl::StrContains(s.error_message(), "/job:foo"));
1531 }
1532 
1533 // Test that placement fails when the previously-assigned device for a
1534 // node is unknown.
1535 TEST_F(PlacerTest, TestUnknownAssignedDevice) {
1536   Graph g(OpRegistry::Global());
1537   {  // Scope for temporary variables used to construct g.
1538     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1539     ops::SourceOp("TestInput", b.opts().WithName("in"));
1540     TF_EXPECT_OK(BuildGraph(b, &g));
1541   }
1542 
1543   GetNodeByName(g, "in")->set_assigned_device_name("/job:foo");
1544 
1545   Status s = Place(&g);
1546   EXPECT_EQ(error::INTERNAL, s.code());
1547   EXPECT_TRUE(absl::StrContains(
1548       s.error_message(),
1549       "Assigned device '/job:foo' does not match any device"));
1550 }
1551 
1552 // Test that placement fails when an op with no registered kernels is
1553 // requested and no device is requested for the node
1554 TEST_F(PlacerTest, TestNoKernelsRegisteredWithNoRequestedDevice) {
1555   Graph g(OpRegistry::Global());
1556   {  // Scope for temporary variables used to construct g.
1557     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1558     ops::SourceOp("VariableNoKernels", b.opts().WithName("var"));
1559     TF_EXPECT_OK(BuildGraph(b, &g));
1560   }
1561 
1562   Status s = Place(&g);
1563   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1564   EXPECT_TRUE(absl::StrContains(s.error_message(),
1565                                 "No OpKernel was registered to support Op "
1566                                 "'VariableNoKernels' used by {{node var}}"));
1567   EXPECT_TRUE(absl::StrContains(s.error_message(), "<no registered kernels>"));
1568 }
1569 
1570 // Test that placement fails when an op does not have a registered kernel
1571 // and the requested device has the same (job, replica, task) as the placer's
1572 // local device
1573 TEST_F(PlacerTest, TestNoKernelsRegisteredWithRequestedDeviceLocal) {
1574   const string cpu_device = "/job:b/replica:0/task:0/device:FakeCPU:0";
1575   const string gpu_device = "/job:b/replica:0/task:0/device:FakeGPU:0";
1576 
1577   Graph g(OpRegistry::Global());
1578   {  // Scope for temporary variables used to construct g.
1579     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1580     ops::SourceOp("VariableNoKernels", b.opts().WithName("var"));
1581     TF_EXPECT_OK(BuildGraph(b, &g));
1582   }
1583   GetNodeByName(g, "var")->set_requested_device(gpu_device);
1584 
1585   DeviceSet devices;
1586   std::unique_ptr<Device> gpu(FakeDevice::MakeGPU(gpu_device));
1587   devices.AddDevice(gpu.get());
1588   std::unique_ptr<Device> cpu(FakeDevice::MakeCPU(cpu_device));
1589   devices.AddDevice(cpu.get());
1590   Status s = Place(&g, &devices, cpu.get(), false, false);
1591   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1592   EXPECT_TRUE(absl::StrContains(s.error_message(),
1593                                 "No OpKernel was registered to support Op "
1594                                 "'VariableNoKernels' used by {{node var}}"));
1595   EXPECT_TRUE(absl::StrContains(s.error_message(), "<no registered kernels>"));
1596 }
1597 
1598 // Test that placement succeeds when an op does not have a registered kernel
1599 // and the requested device has a different (job, replica, task) than the
1600 // placer's local device.
1601 TEST_F(PlacerTest, TestNoKernelsRegisteredWithRequestedDeviceRemote) {
1602   const string local_device = "/job:b/replica:0/task:0/device:FakeCPU:0";
1603   const string remote_device = "/job:b/replica:0/task:1/device:FakeGPU:0";
1604 
1605   Graph g(OpRegistry::Global());
1606   {  // Scope for temporary variables used to construct g.
1607     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1608     ops::SourceOp("VariableNoKernels", b.opts().WithName("var"));
1609     TF_EXPECT_OK(BuildGraph(b, &g));
1610   }
1611   GetNodeByName(g, "var")->set_requested_device(remote_device);
1612 
1613   DeviceSet heterogeneous;
1614   std::unique_ptr<Device> gpu(FakeDevice::MakeGPU(remote_device));
1615   heterogeneous.AddDevice(gpu.get());
1616   std::unique_ptr<Device> cpu(FakeDevice::MakeCPU(local_device));
1617   heterogeneous.AddDevice(cpu.get());
1618   TF_EXPECT_OK(Place(&g, &heterogeneous, cpu.get(), false, false));
1619   EXPECT_DEVICE_CONTAINS(g, "var", remote_device);
1620 }
1621 
1622 // Test that placement fails when a kernel is registered but no known
1623 // device supports it.
1624 TEST_F(PlacerTest, TestNoDevicesRegistered) {
1625   Graph g(OpRegistry::Global());
1626   {  // Scope for temporary variables used to construct g.
1627     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1628     ops::SourceOp("VariableGPU", b.opts().WithName("var"));
1629     TF_EXPECT_OK(BuildGraph(b, &g));
1630   }
1631 
1632   DeviceSet cpu_only;
1633   std::unique_ptr<Device> cpu(
1634       FakeDevice::MakeCPU("/job:a/replica:0/task:0/device:FakeCPU:0"));
1635   cpu_only.AddDevice(cpu.get());
1636 
1637   Status s = Place(&g, &cpu_only);
1638   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1639   EXPECT_TRUE(absl::StrContains(s.error_message(),
1640                                 "No OpKernel was registered to support Op "
1641                                 "'VariableGPU' used by {{node var}}"));
1642   EXPECT_TRUE(absl::StrContains(s.error_message(), "device='FakeGPU'"));
1643 }
1644 
1645 // Test that placement fails when a requested device is malformed.
1646 TEST_F(PlacerTest, TestMalformedDeviceSpecification) {
1647   Graph g(OpRegistry::Global());
1648   {  // Scope for temporary variables used to construct g.
1649     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1650     ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/foo:bar"));
1651     TF_EXPECT_OK(BuildGraph(b, &g));
1652   }
1653 
1654   Status s = Place(&g);
1655   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1656   EXPECT_TRUE(absl::StrContains(s.error_message(),
1657                                 "Malformed device specification '/foo:bar'"));
1658 }
1659 
1660 // Test that placement fails when a previously-assigned device is malformed.
1661 TEST_F(PlacerTest, TestMalformedAssignedDevice) {
1662   Graph g(OpRegistry::Global());
1663   {  // Scope for temporary variables used to construct g.
1664     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1665     ops::SourceOp("TestInput", b.opts().WithName("in"));
1666     TF_EXPECT_OK(BuildGraph(b, &g));
1667   }
1668 
1669   GetNodeByName(g, "in")->set_assigned_device_name("/foo:bar");
1670 
1671   Status s = Place(&g);
1672   EXPECT_EQ(error::INTERNAL, s.code());
1673   EXPECT_TRUE(absl::StrContains(s.error_message(),
1674                                 "Malformed assigned device '/foo:bar'"));
1675 }
1676 
1677 // Test that placement fails when a device was previously assigned to
1678 // a node, but it does not uniquely identify a particular device.
1679 TEST_F(PlacerTest, TestNonUniqueAssignedDevice) {
1680   Graph g(OpRegistry::Global());
1681   {  // Scope for temporary variables used to construct g.
1682     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1683     ops::SourceOp("TestInput", b.opts().WithName("in"));
1684     TF_EXPECT_OK(BuildGraph(b, &g));
1685   }
1686 
1687   GetNodeByName(g, "in")->set_assigned_device_name("/job:a");
1688 
1689   Status s = Place(&g);
1690   EXPECT_EQ(error::INTERNAL, s.code());
1691   EXPECT_TRUE(absl::StrContains(
1692       s.error_message(), "Assigned device '/job:a' does not match any device"));
1693 }
1694 
1695 // Test that ops requested to be placed on non-existent devices will be
1696 // relocated to an existing device of the same type if allow_soft_placement is set.
1697 TEST_F(PlacerTest, TestNonexistentGpuAllowSoftPlacement) {
1698   Graph g(OpRegistry::Global());
1699   {  // Scope for temporary variables used to construct g.
1700     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1701     ops::SourceOp("TestDevice",
1702                   b.opts().WithName("in").WithDevice("/device:FakeGPU:11"));
1703     TF_EXPECT_OK(BuildGraph(b, &g));
1704   }
1705 
1706   TF_EXPECT_OK(Place(&g, true, false));
1707   EXPECT_DEVICE_CONTAINS(g, "in", "/device:FakeGPU:0");
1708 }
1709 
1710 // Test that ops requested to be placed on non-existent devices will fail if
1711 // allow_soft_placement is not set.
1712 TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacement) {
1713   Graph g(OpRegistry::Global());
1714   {  // Scope for temporary variables used to construct g.
1715     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1716     ops::SourceOp("TestDevice",
1717                   b.opts().WithName("in").WithDevice("/device:FakeGPU:11"));
1718     TF_EXPECT_OK(BuildGraph(b, &g));
1719   }
1720 
1721   Status s = Place(&g, false, false);
1722   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1723   EXPECT_TRUE(absl::StrContains(s.error_message(), "/device:FakeGPU:11"));
1724 }
1725 
1726 // Test that the "Cannot assign a device" error message contains a format tag
1727 // when requested.
1728 TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) {
1729   Graph g(OpRegistry::Global());
1730   {  // Scope for temporary variables used to construct g.
1731     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1732     ops::SourceOp("TestDevice",
1733                   b.opts().WithName("in").WithDevice("/device:FakeGPU:11"));
1734     TF_EXPECT_OK(BuildGraph(b, &g));
1735   }
1736 
1737   Status s = Place(&g, false, false);
1738   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1739   LOG(WARNING) << s.error_message();
1740   EXPECT_TRUE(absl::StrContains(s.error_message(),
1741                                 "Cannot assign a device for operation in"));
1742   EXPECT_TRUE(absl::StrContains(s.error_message(), "{{node in}}"));
1743 }
1744 
1745 // Test that placement fails when a node requests an explicit device that is not
1746 // supported by the registered kernels if allow_soft_placement is not set.
1747 TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
1748   Graph g(OpRegistry::Global());
1749   {  // Scope for temporary variables used to construct g.
1750     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1751     ops::SourceOp("VariableGPU",
1752                   b.opts().WithName("var").WithDevice("/device:FakeCPU:0"));
1753     TF_EXPECT_OK(BuildGraph(b, &g));
1754   }
1755 
1756   Status s = Place(&g, false, false);
1757   EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
1758   EXPECT_TRUE(absl::StrContains(s.error_message(), "/device:FakeCPU:0"))
1759       << s.ToString();
1760   EXPECT_TRUE(
1761       absl::StrContains(s.error_message(),
1762                         "no supported kernel for FakeCPU devices is available"))
1763       << s.ToString();
1764 }
1765 
1766 // Test that placement fails when a node requests an explicit device that does
1767 // not match any available device and allow_soft_placement is not set.
1768 TEST_F(PlacerTest, TestNonExistentDevice) {
1769   Graph g(OpRegistry::Global());
1770   {  // Scope for temporary variables used to construct g.
1771     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1772     ops::SourceOp("VariableGPU",
1773                   b.opts().WithName("var").WithDevice("/job:foo/replica:17"));
1774     TF_EXPECT_OK(BuildGraph(b, &g));
1775   }
1776 
1777   Status s = Place(&g, false, false);
1778   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1779   LOG(WARNING) << s.error_message();
1780   EXPECT_TRUE(absl::StrContains(
1781       s.error_message(), "was explicitly assigned to /job:foo/replica:17"));
1782   EXPECT_TRUE(absl::StrContains(s.error_message(), "but available devices"));
1783 }
1784 
1785 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
1786 // Test that we inform the user if they appear to be explicitly placing nodes
1787 // on a GPU when CUDA is not available
1788 TEST_F(PlacerTest, TestUseGpuWithNoCuda) {
1789   Graph g(OpRegistry::Global());
1790   {  // Scope for temporary variables used to construct g.
1791     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1792     ops::SourceOp("VariableGPU",
1793                   b.opts().WithName("var").WithDevice("/device:gpu:0"));
1794     TF_EXPECT_OK(BuildGraph(b, &g));
1795   }
1796 
1797   Status s = Place(&g, false, false);
1798   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1799   LOG(WARNING) << s.error_message();
1800   EXPECT_TRUE(absl::StrContains(
1801       s.error_message(),
1802       "The requested device appears to be a GPU, but CUDA is not enabled."));
1803 }
1804 #endif
1805 
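// Test that placement succeeds when a node requests a device type with no
// registered kernel for its op, as long as allow_soft_placement is set.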
1806 TEST_F(PlacerTest, TestUnsupportedDeviceAllowSoftPlacement) {
1807   Graph g(OpRegistry::Global());
1808   {  // Scope for temporary variables used to construct g.
1809     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1810     ops::SourceOp("TestInput",  // has only CPU kernel
1811                   b.opts().WithName("a").WithDevice("/device:FakeGPU:0"));
1812     TF_EXPECT_OK(BuildGraph(b, &g));
1813   }
1814 
1815   TF_EXPECT_OK(Place(&g, true, false));
1816 }
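
// Illustrative sketch (not part of the original test suite): under soft
// placement the node above should fall back to the only device type that has
// a kernel for "TestInput", i.e. FakeCPU. This reuses the helpers defined
// earlier in this file; the test name and the asserted fallback device type
// are assumptions based on the soft-placement behavior exercised above.
TEST_F(PlacerTest, TestUnsupportedDeviceAllowSoftPlacementFallbackSketch) {
  Graph g(OpRegistry::Global());
  {  // Scope for temporary variables used to construct g.
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    ops::SourceOp("TestInput",  // has only CPU kernel
                  b.opts().WithName("a").WithDevice("/device:FakeGPU:0"));
    TF_EXPECT_OK(BuildGraph(b, &g));
  }

  TF_EXPECT_OK(Place(&g, true, false));
  // Soft placement should move the node to the supported device type.
  EXPECT_DEVICE_TYPE(g, "a", "FakeCPU");
}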
1817 
1818 // Test that a graph with device type and reference constraints on
1819 // some of the ops will successfully assign nodes to the constrained
1820 // device, and colocate nodes with reference connections.
1821 TEST_F(PlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) {
1822   Graph g(OpRegistry::Global());
1823   {  // Scope for temporary variables used to construct g.
1824     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1825     // var_gpu has ref output and runs on GPU.
1826     // force_gpu takes var_gpu and requests CPU.
1827     // Verify that both are placed on GPU.
1828     Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu"));
1829     ops::UnaryOp(
1830         "TestDeviceEnforce", var_gpu,
1831         b.opts().WithName("force_gpu").WithDevice("/device:FakeCPU:0"));
1832     // var_cpu has ref output and runs on CPU.
1833     // force_cpu takes var_cpu and requests GPU.
1834     // Verify that both are placed on CPU.
1835     Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
1836     ops::UnaryOp(
1837         "TestDeviceEnforce", var_cpu,
1838         b.opts().WithName("force_cpu").WithDevice("/device:FakeGPU:0"));
1839     TF_EXPECT_OK(BuildGraph(b, &g));
1840   }
1841 
1842   TF_EXPECT_OK(Place(&g, true, false));
1843   EXPECT_DEVICE_TYPE(g, "var_gpu", "FakeGPU");
1844   EXPECT_DEVICE_TYPE(g, "force_gpu", "FakeGPU");
1845   EXPECT_COLOCATED(g, "var_gpu", "force_gpu");
1846   EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU");
1847   EXPECT_DEVICE_TYPE(g, "force_cpu", "FakeCPU");
1848   EXPECT_COLOCATED(g, "var_cpu", "force_cpu");
1849 }
1850 
1851 // Test that placement fails when two nodes have a reference connection
1852 // constraint, and each node requires a mutually incompatible device.
1853 TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
1854   Graph g(OpRegistry::Global());
1855   {  // Scope for temporary variables used to construct g.
1856     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1857     Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var"));
1858     Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
1859     ops::BinaryOp("AssignCPU", var, input, b.opts().WithName("assign"));
1860     TF_EXPECT_OK(BuildGraph(b, &g));
1861   }
1862 
1863   Status s = Place(&g);
1864   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1865   EXPECT_TRUE(absl::StrContains(s.error_message(),
1866                                 "Cannot colocate nodes {{colocation_node "
1867                                 "var}} and {{colocation_node assign}}"));
1868 }
1869 
1870 // Test that a generator node follows its consumers (where there are several
1871 // consumer nodes on the same device).
1872 TEST_F(PlacerTest, TestGeneratorNodeFollowsConsumerNode) {
1873   Graph g(OpRegistry::Global());
1874   {  // Scope for temporary variables used to construct g.
1875     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1876 
1877     // The variables are only available on CPU.
1878     Node* var1_cpu =
1879         ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu"));
1880     Node* var2_cpu =
1881         ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu"));
1882 
1883     // The constant to be assigned can be on either GPU or CPU.
1884     //
1885     // Because of the heuristic, it gets placed on CPU to avoid a
1886     // copy.
1887     Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in"));
1888 
1889     // The assigns are bound to CPU by the reference edge.
1890     ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1"));
1891     ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2"));
1892 
1893     TF_EXPECT_OK(BuildGraph(b, &g));
1894   }
1895 
1896   TF_EXPECT_OK(Place(&g));
1897   EXPECT_COLOCATED(g, "var1_cpu", "in");
1898   EXPECT_COLOCATED(g, "assign1", "in");
1899   EXPECT_COLOCATED(g, "var2_cpu", "in");
1900   EXPECT_COLOCATED(g, "assign2", "in");
1901 }
1902 
1903 // Test that a generator node does not follow its consumers (where there are
1904 // several consumers on different devices).
1905 TEST_F(PlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) {
1906   Graph g(OpRegistry::Global());
1907   {  // Scope for temporary variables used to construct g.
1908     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
1909 
1910     // The variables are only available on CPU.
1911     Node* var1_cpu =
1912         ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu"));
1913     Node* var2_cpu =
1914         ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu"));
1915 
1916     // The constant to be assigned can be on either GPU or CPU.
1917     //
1918     // Because of the heuristic, it ought to be on the GPU (cannot be
1919     // co-located with both consumers, so it goes to the 'standard' place).
1920     Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in"));
1921 
1922     // The assigns are bound to CPU by the reference edge.
1923     ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1"));
1924     ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2"));
1925 
1926     TF_EXPECT_OK(BuildGraph(b, &g));
1927 
1928     GetNodeByName(g, "var1_cpu")
1929         ->set_assigned_device_name("/job:a/replica:0/task:0/device:FakeCPU:1");
1930 
1931     GetNodeByName(g, "var2_cpu")
1932         ->set_assigned_device_name("/job:a/replica:0/task:0/device:FakeCPU:2");
1933   }
1934 
1935   TF_EXPECT_OK(Place(&g));
1936   EXPECT_COLOCATED(g, "assign1", "var1_cpu");
1937   EXPECT_COLOCATED(g, "assign2", "var2_cpu");
1938   EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
1939 }
1940 
1941 REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeCPU"), DummyOp);
1942 REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeGPU"), DummyOp);
1943 REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeCPU"), DummyOp);
1944 REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeGPU"), DummyOp);
1945 REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeCPU"), DummyOp);
1946 REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeGPU"), DummyOp);
1947 REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeCPU"), DummyOp);
1948 REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeGPU"), DummyOp);
1949 REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeCPU"), DummyOp);
1950 REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeGPU"), DummyOp);
1951 REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeCPU"), DummyOp);
1952 REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeGPU"), DummyOp);
1953 REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device("FakeCPU"), DummyOp);
1954 REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device("FakeGPU"), DummyOp);
1955 
1956 TEST_P(SoftPlacementPlacerTest,
1957        RequestedDeviceOnResourceGeneratorIsTreatedAsAssigned) {
1958   /*
1959    *    a:RES:GPU  b:RES:CPU
1960    *       |         |
1961    *       |         |
1962    *       v         v
1963    *      id1       id2
1964    *     @loc:id2
1965    */
1966   FunctionDef func = test::function::ResourceOutput();
1967   GraphDef graph = GDef(
1968       {
1969           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
1970           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
1971           NDef("id1", "Identity", {"a"},
1972                {{"T", DT_RESOURCE},
1973                 {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
1974           NDef("id2", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
1975       },
1976       // FunctionLib
1977       {func});
1978 
1979   Graph g(OpRegistry::Global());
1980   TF_ASSERT_OK(BuildGraph(graph, &g));
1981 
1982   bool allow_soft_placement = GetParam();
1983   Status s = Place(&g, allow_soft_placement, true);
1984   if (allow_soft_placement) {
1985     EXPECT_EQ(error::OK, s.code()) << s.ToString();
1986     EXPECT_DEVICE_TYPE(g, "a", "FakeGPU");
1987     EXPECT_DEVICE_TYPE(g, "id1", "FakeGPU");
1988     EXPECT_DEVICE_TYPE(g, "b", "FakeCPU");
1989     EXPECT_DEVICE_TYPE(g, "id2", "FakeCPU");
1990   } else {
1991     EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
1992     EXPECT_TRUE(absl::StrContains(
1993         s.error_message(),
1994         "Cannot colocate nodes {{colocation_node id2}} and {{colocation_node "
1995         "id1}}: Cannot merge devices with incompatible types: "
1996         "'/device:FakeCPU:0' and '/device:FakeGPU:0'"))
1997         << s.ToString();
1998   }
1999 }
2000 
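// Test that a device that is merely requested (not assigned) on a
// resource-producing Identity can be overridden so that the whole colocation
// group is placed together.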
2001 TEST_F(PlacerTest, RequestedDeviceCanBeOverridden) {
2002   /*
2003    *     a:RES      b:RES
2004    *       |         |
2005    *     id_a:GPU   id_b:CPU
2006    *       |         |
2007    *       v         v
2008    *      id1       id2
2009    *     @loc:id2
2010    */
2011   FunctionDef func = test::function::ResourceOutput();
2012   GraphDef graph = GDef(
2013       {
2014           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
2015           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
2016           NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}, kGPU),
2017           NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}, kCPU),
2018           NDef("id1", "Identity", {"id_a"},
2019                {{"T", DT_RESOURCE},
2020                 {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
2021           NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
2022       },
2023       // FunctionLib
2024       {func});
2025 
2026   Graph g(OpRegistry::Global());
2027   TF_ASSERT_OK(BuildGraph(graph, &g));
2028   TF_ASSERT_OK(Place(&g));
2029 
2030   // All should be colocated
2031   EXPECT_COLOCATED(g, "a", "b");
2032   EXPECT_COLOCATED(g, "id_a", "id_b");
2033   EXPECT_COLOCATED(g, "id1", "id2");
2034   EXPECT_COLOCATED(g, "a", "id_a");
2035   EXPECT_COLOCATED(g, "a", "id1");
2036 }
2037 
2038 TEST_F(PlacerTest, AssignedDeviceOfColocatedNodeIsRespected) {
2039   /*
2040    *     a:float (assigned to CPU)
2041    *       |
2042    *       v
2043    *     iter (has only GPU kernel)
2044    */
2045   GraphDef graph = GDef({
2046       NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
2047       NDef("iter", "IteratorGPU", {"a"}),
2048   });
2049 
2050   Graph g(OpRegistry::Global());
2051   TF_ASSERT_OK(BuildGraph(graph, &g));
2052   GetNodeByName(g, "a")->set_assigned_device_name(kFullCPU);
2053   Status s = Place(&g);
2054   EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
2055   EXPECT_TRUE(
2056       absl::StrContains(s.error_message(),
2057                         "{{colocation_node iter}} was colocated with a "
2058                         "group of nodes that required incompatible device "
2059                         "'/job:a/replica:0/task:0/device:FakeCPU:0'"))
2060       << s.ToString();
2061 }
2062 
2063 TEST_P(SoftPlacementPlacerTest,
2064        AssignedDevicesAreNotOverriddenDueToResourcesAndColocation) {
2065   /*
2066    *     a:RES      b:RES
2067    *       |         |
2068    *     id_a:GPU   id_b:CPU
2069    *       |         |
2070    *       v         v
2071    *      id1       id2
2072    *     @loc:id2
2073    */
2074   FunctionDef func = test::function::ResourceOutput();
2075   GraphDef graph = GDef(
2076       {
2077           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
2078           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
2079           NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}),
2080           NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
2081           NDef("id1", "Identity", {"id_a"},
2082                {{"T", DT_RESOURCE},
2083                 {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
2084           NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
2085       },
2086       // FunctionLib
2087       {func});
2088 
2089   Graph g(OpRegistry::Global());
2090   TF_ASSERT_OK(BuildGraph(graph, &g));
2091   GetNodeByName(g, "id_a")->set_assigned_device_name(kFullGPU);
2092   GetNodeByName(g, "id_b")->set_assigned_device_name(kFullCPU);
2093 
2094   bool allow_soft_placement = GetParam();
2095 
2096   Status s = Place(&g, allow_soft_placement, false);
2097   if (allow_soft_placement) {
2098     EXPECT_EQ(error::OK, s.code()) << s.ToString();
2099     EXPECT_DEVICE_TYPE(g, "a", "FakeGPU");
2100     EXPECT_DEVICE_TYPE(g, "id_a", "FakeGPU");
2101     EXPECT_DEVICE_TYPE(g, "id1", "FakeGPU");
2102     EXPECT_DEVICE_TYPE(g, "b", "FakeCPU");
2103     EXPECT_DEVICE_TYPE(g, "id_b", "FakeCPU");
2104     EXPECT_DEVICE_TYPE(g, "id2", "FakeCPU");
2105   } else {
2106     EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
2107     EXPECT_TRUE(absl::StrContains(
2108         s.error_message(),
2109         "Cannot colocate nodes {{colocation_node id2}} and {{colocation_node "
2110         "id1}}: Cannot merge devices with incompatible types: "
2111         "'/job:a/replica:0/task:0/device:FakeCPU:0' and "
2112         "'/job:a/replica:0/task:0/device:FakeGPU:0'"))
2113         << s.ToString();
2114   }
2115 }
2116 
2117 // Fixture for tests that place graphs containing function calls.
2118 // Particularly the case where internal functions return resources.
2119 class NestedPlacerTest : public PlacerTest {
2120  public:
2121   // Create one FakeCPU and one FakeGPU. These tests don't need multiple devices
2122   // of the same type.
2123   NestedPlacerTest() : PlacerTest(1) {}
2124 };
2125 
2126 TEST_F(NestedPlacerTest, OutputOneResource) {
2127   /*
2128    *                a:FLOAT:GPU
2129    *                 |  b:RESOURCE:CPU
2130    *                 |   |
2131    *                 v   v
2132    *                  PCO
2133    *                 |   \
2134    *                 |   v
2135    *                 v   r2:FLOAT
2136    *                 r1:RESOURCE
2137    *
2138    * PartitionedCallOp (PCO) should be placed on GPU even though it
2139    * takes a CPU resource as input. The resource output should be placed
2140    * on CPU since it is the same resource as the input one.
2141    */
2142   FunctionDef func = test::function::ResourceOutput();
2143   GraphDef graph = GDef(
2144       {
2145           NDef("a", "_Arg", {}, {{"T", DT_FLOAT}}, kGPU),
2146           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2147           NDef("y", "PartitionedCall", {"a", "b"},
2148                {{"Tin", DataTypeSlice{DT_FLOAT, DT_RESOURCE}},
2149                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_FLOAT}},
2150                 {"f", FDH::FunctionRef("ResourceOutput", {})}}),
2151           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2152           NDef("r2", "Identity", {"y:1"}, {{"T", DT_FLOAT}}),
2153       },
2154       // FunctionLib
2155       {func});
2156 
2157   Graph g(OpRegistry::Global());
2158   TF_ASSERT_OK(BuildGraph(graph, &g));
2159   TF_ASSERT_OK(CallOptPassesAndPlace(&g));
2160 
2161   EXPECT_DEVICE_TYPE(g, "y", "FakeGPU");
2162   EXPECT_DEVICE_TYPE(g, "r1", "FakeCPU");
2163   EXPECT_DEVICE_TYPE(g, "r2", "FakeGPU");
2164 }
2165 
2166 TEST_F(NestedPlacerTest, OutputOneResource_ExtraIdentities) {
2167   /*
2168    *                a:FLOAT
2169    *                 |  b:RESOURCE
2170    *                 |   |
2171    *              ai:GPU |
2172    *                 |  bi:CPU
2173    *                 |   |
2174    *                 v   v
2175    *                  PCO
2176    *                 |   \
2177    *                 |   v
2178    *                 v   r2:FLOAT
2179    *                 r1:RESOURCE
2180    *
2181    * Same as above except that devices are requested on identities, not on
2182    * resource generating ops.
2183    */
2184   FunctionDef func = test::function::ResourceOutput();
2185   GraphDef graph = GDef(
2186       {
2187           NDef("a", "_Arg", {}, {{"T", DT_FLOAT}}, kGPU),
2188           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2189           NDef("ai", "Identity", {"a"}, {{"T", DT_FLOAT}}),
2190           NDef("bi", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
2191           NDef("y", "PartitionedCall", {"ai", "bi"},
2192                {{"Tin", DataTypeSlice{DT_FLOAT, DT_RESOURCE}},
2193                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_FLOAT}},
2194                 {"f", FDH::FunctionRef("ResourceOutput", {})}}),
2195           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2196           NDef("r2", "Identity", {"y:1"}, {{"T", DT_FLOAT}}),
2197       },
2198       // FunctionLib
2199       {func});
2200 
2201   Graph g(OpRegistry::Global());
2202   TF_ASSERT_OK(BuildGraph(graph, &g));
2203   TF_ASSERT_OK(CallOptPassesAndPlace(&g));
2204 
2205   EXPECT_DEVICE_TYPE(g, "a", "FakeGPU");
2206   EXPECT_DEVICE_TYPE(g, "b", "FakeCPU");
2207   EXPECT_DEVICE_TYPE(g, "ai", "FakeGPU");
2208   EXPECT_DEVICE_TYPE(g, "bi", "FakeCPU");
2209   EXPECT_DEVICE_TYPE(g, "y", "FakeGPU");
2210   EXPECT_DEVICE_TYPE(g, "r1", "FakeCPU");
2211   EXPECT_DEVICE_TYPE(g, "r2", "FakeGPU");
2212 }
2213 
2214 TEST_F(NestedPlacerTest, OutputOneResource_OverrideOutputResourceDevice) {
2215   /*
2216    *                a:FLOAT:GPU
2217    *                 |  b:RESOURCE:CPU
2218    *                 |   |
2219    *                 v   v
2220    *                  PCO
2221    *                 |   \
2222    *                 |   v
2223    *                 v   r2:FLOAT
2224    *                 r1:RESOURCE:GPU
2225    *
2226    * Same as above except r1 wrongly requests GPU. The Placer will override
2227    * this requested device.
2228    */
2229   FunctionDef func = test::function::ResourceOutput();
2230   GraphDef graph = GDef(
2231       {
2232           NDef("a", "_Arg", {}, {{"T", DT_FLOAT}}, kGPU),
2233           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2234           NDef("y", "PartitionedCall", {"a", "b"},
2235                {{"Tin", DataTypeSlice{DT_FLOAT, DT_RESOURCE}},
2236                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_FLOAT}},
2237                 {"f", FDH::FunctionRef("ResourceOutput", {})}}),
2238           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}, kGPU),
2239           NDef("r2", "Identity", {"y:1"}, {{"T", DT_FLOAT}}),
2240       },
2241       // FunctionLib
2242       {func});
2243 
2244   Graph g(OpRegistry::Global());
2245   TF_ASSERT_OK(BuildGraph(graph, &g));
2246   TF_ASSERT_OK(CallOptPassesAndPlace(&g, false, true));
2247 
2248   EXPECT_DEVICE_TYPE(g, "y", "FakeGPU");
2249   EXPECT_DEVICE_TYPE(g, "r1", "FakeCPU");
2250   EXPECT_DEVICE_TYPE(g, "r2", "FakeGPU");
2251 }
2252 
2253 TEST_F(NestedPlacerTest, OutputTwoResources) {
2254   /*
2255    *                a:RESOURCE:CPU
2256    *                 |  b:RESOURCE:GPU
2257    *                 |   |
2258    *                 v   v
2259    *                  PCO (simple swap)
2260    *                 |   \
2261    *                 |   v
2262    *                 v   r2:RESOURCE
2263    *                 r1:RESOURCE
2264    *
2265    * Ops consuming output resources should be placed on correct devices.
2266    */
2267   FunctionDef func = test::function::Swap();
2268   GraphDef graph = GDef(
2269       {
2270           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2271           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
2272           NDef("y", "PartitionedCall", {"a", "b"},
2273                {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2274                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2275                 {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}),
2276           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2277           NDef("r2", "Identity", {"y:1"}, {{"T", DT_RESOURCE}}),
2278       },
2279       // FunctionLib
2280       {func});
2281 
2282   Graph g(OpRegistry::Global());
2283   TF_EXPECT_OK(BuildGraph(graph, &g));
2284   TF_EXPECT_OK(CallOptPassesAndPlace(&g));
2285 
2286   EXPECT_DEVICE_TYPE(g, "y", "FakeGPU");
2287   EXPECT_DEVICE_TYPE(g, "r1", "FakeGPU");
2288   EXPECT_DEVICE_TYPE(g, "r2", "FakeCPU");
2289 }
2290 
2291 TEST_F(NestedPlacerTest, OutputTwoResources_PCOOnCPU) {
2292   /*
2293    *                a:RESOURCE:CPU
2294    *                 |  b:RESOURCE:GPU
2295    *                 |   |
2296    *                 v   v
2297    *                  PCO:CPU (simple swap)
2298    *                 |   \
2299    *                 |   v
2300    *                 v   r2:RESOURCE
2301    *                 r1:RESOURCE
2302    *
2303    * Ops consuming output resources should be placed on correct devices, even
2304    * when PCO is explicitly placed.
2305    */
2306   FunctionDef func = test::function::Swap();
2307   GraphDef graph = GDef(
2308       {
2309           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2310           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
2311           NDef("y", "PartitionedCall", {"a", "b"},
2312                {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2313                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2314                 {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}},
2315                kCPU),
2316           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2317           NDef("r2", "Identity", {"y:1"}, {{"T", DT_RESOURCE}}),
2318       },
2319       // FunctionLib
2320       {func});
2321 
2322   Graph g(OpRegistry::Global());
2323   TF_EXPECT_OK(BuildGraph(graph, &g));
2324   TF_EXPECT_OK(CallOptPassesAndPlace(&g));
2325 
2326   EXPECT_DEVICE_TYPE(g, "y", "FakeCPU");
2327   EXPECT_DEVICE_TYPE(g, "r1", "FakeGPU");
2328   EXPECT_DEVICE_TYPE(g, "r2", "FakeCPU");
2329 }
2330 
2331 TEST_F(NestedPlacerTest, OutputTwoResources_UnassignedResource) {
2332   /*
2333    *                a:RESOURCE
2334    *                 |  b:RESOURCE:GPU
2335    *                 |   |
2336    *                 v   v
2337    *                  PCO:CPU (simple swap)
2338    *                 |   \
2339    *                 |   v
2340    *                 v   r2:RESOURCE
2341    *                 r1:RESOURCE
2342    *
2343    * Resource input `a` is not explicitly assigned. Placer leaves `a` and `b` to
2344    * the "second pass" as they are "sources". It assigns `r1` to GPU because it
2345    * is in the same group as `b`. It assigns `r2` to GPU because GPU has a
2346    * higher device preference. Finally, `a` is assigned to GPU because `r2` is
2347    * on GPU - this tests that the "second pass" heuristics respect colocation
2348    * groups (even when the consumer of the source, i.e. the PCO, is on a
2349    * different device).
2350    */
2351   FunctionDef func = test::function::Swap();
2352   GraphDef graph = GDef(
2353       {
2354           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
2355           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
2356           NDef("y", "PartitionedCall", {"a", "b"},
2357                {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2358                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2359                 {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}},
2360                kCPU),
2361           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2362           NDef("r2", "Identity", {"y:1"}, {{"T", DT_RESOURCE}}),
2363       },
2364       // FunctionLib
2365       {func});
2366 
2367   Graph g(OpRegistry::Global());
2368   TF_EXPECT_OK(BuildGraph(graph, &g));
2369   TF_ASSERT_OK(CallOptPassesAndPlace(&g, false, true));
2370 
2371   EXPECT_DEVICE_TYPE(g, "a", "FakeGPU");
2372   EXPECT_DEVICE_TYPE(g, "b", "FakeGPU");
2373   EXPECT_DEVICE_TYPE(g, "y", "FakeCPU");
2374   EXPECT_DEVICE_TYPE(g, "r1", "FakeGPU");
2375   EXPECT_DEVICE_TYPE(g, "r2", "FakeGPU");
2376 }
2377 
2378 TEST_F(NestedPlacerTest, OutputTwoResources_UnassignedResource_CPU) {
2379   /*
2380    *                a:RESOURCE
2381    *                 |  b:RESOURCE:CPU
2382    *                 |   |
2383    *                 v   v
2384    *                  PCO:CPU (simple swap)
2385    *                 |   \
2386    *                 |   v
2387    *                 v   r2:RESOURCE
2388    *                 r1:RESOURCE
2389    *
2390    * Same as above except `b` is on CPU.
2391    */
2392   FunctionDef func = test::function::Swap();
2393   GraphDef graph = GDef(
2394       {
2395           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
2396           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2397           NDef("y", "PartitionedCall", {"a", "b"},
2398                {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2399                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2400                 {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}},
2401                kCPU),
2402           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2403           NDef("r2", "Identity", {"y:1"}, {{"T", DT_RESOURCE}}),
2404       },
2405       // FunctionLib
2406       {func});
2407 
2408   Graph g(OpRegistry::Global());
2409   TF_EXPECT_OK(BuildGraph(graph, &g));
2410   TF_ASSERT_OK(CallOptPassesAndPlace(&g, false, true));
2411 
2412   EXPECT_DEVICE_TYPE(g, "a", "FakeGPU");
2413   EXPECT_DEVICE_TYPE(g, "b", "FakeCPU");
2414   EXPECT_DEVICE_TYPE(g, "y", "FakeCPU");
2415   EXPECT_DEVICE_TYPE(g, "r1", "FakeCPU");
2416   EXPECT_DEVICE_TYPE(g, "r2", "FakeGPU");
2417 }
2418 
2419 TEST_F(NestedPlacerTest, OutputResourceConsumedByMultipleOps) {
2420   /*
2421    *                a:RESOURCE
2422    *                 |  b:RESOURCE:CPU
2423    *                 |   |
2424    *                 v   v
2425    *                  PCO:CPU (simple swap)
2426    *                 |   \
2427    *                 |   v
2428    *                 |  r3:RESOURCE:GPU
2429    *                 |
2430    *              ---+---
2431    *             |       |
2432    *             |   r2:RESOURCE
2433    *         r1:RESOURCE
2434    */
2435   FunctionDef func = test::function::Swap();
2436   GraphDef graph = GDef(
2437       {
2438           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
2439           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2440           NDef("y", "PartitionedCall", {"a", "b"},
2441                {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2442                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2443                 {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}),
2444           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2445           NDef("r2", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2446           NDef("r3", "Identity", {"y:1"}, {{"T", DT_RESOURCE}}, kGPU),
2447       },
2448       // FunctionLib
2449       {func});
2450 
2451   Graph g(OpRegistry::Global());
2452   TF_EXPECT_OK(BuildGraph(graph, &g));
2453   TF_ASSERT_OK(CallOptPassesAndPlace(&g, false, true));
2454 
2455   EXPECT_DEVICE_TYPE(g, "a", "FakeGPU");
2456   EXPECT_DEVICE_TYPE(g, "b", "FakeCPU");
2457   EXPECT_DEVICE_TYPE(g, "r1", "FakeCPU");
2458   EXPECT_DEVICE_TYPE(g, "r2", "FakeCPU");
2459   EXPECT_DEVICE_TYPE(g, "r3", "FakeGPU");
2460 }
2461 
2462 TEST_F(NestedPlacerTest, DuplicateInputResource) {
2463   /*
2464    *                a:RESOURCE
2465    *                  / \
2466    *                 |   |
2467    *                 v   v
2468    *                  PCO:GPU (simple swap)
2469    *                 |   \
2470    *                 |   v
2471    *                 v   r2:RESOURCE:CPU
2472    *                 r1:RESOURCE
2473    */
2474   FunctionDef func = test::function::Swap();
2475   GraphDef graph = GDef(
2476       {
2477           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
2478           NDef("y", "PartitionedCall", {"a", "a"},
2479                {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2480                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2481                 {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}},
2482                kGPU),
2483           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2484           NDef("r2", "Identity", {"y:1"}, {{"T", DT_RESOURCE}}, kCPU),
2485       },
2486       // FunctionLib
2487       {func});
2488 
2489   Graph g(OpRegistry::Global());
2490   TF_EXPECT_OK(BuildGraph(graph, &g));
2491   TF_ASSERT_OK(CallOptPassesAndPlace(&g, false, true));
2492 
2493   EXPECT_DEVICE_TYPE(g, "a", "FakeCPU");
2494   EXPECT_DEVICE_TYPE(g, "y", "FakeGPU");
2495   EXPECT_DEVICE_TYPE(g, "r1", "FakeCPU");
2496   EXPECT_DEVICE_TYPE(g, "r2", "FakeCPU");
2497 }
2498 
2499 TEST_F(NestedPlacerTest, DuplicateInputs_OutputResourceConsumedByMultipleOps) {
2500   /*
2501    *                a:RESOURCE
2502    *                  /  \
2503    *                 |   |
2504    *                 v   v
2505    *                  PCO:GPU (simple swap)
2506    *                 |   \
2507    *                 |   v
2508    *                 |  r3:RESOURCE
2509    *                 |
2510    *              ---+---
2511    *             |       |
2512    *             |   r2:RESOURCE:CPU
2513    *         r1:RESOURCE
2514    */
2515   FunctionDef func = test::function::Swap();
2516   GraphDef graph = GDef(
2517       {
2518           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
2519           NDef("y", "PartitionedCall", {"a", "a"},
2520                {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2521                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2522                 {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}},
2523                kGPU),
2524           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2525           NDef("r2", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}, kCPU),
2526           NDef("r3", "Identity", {"y:1"}, {{"T", DT_RESOURCE}}),
2527       },
2528       // FunctionLib
2529       {func});
2530 
2531   Graph g(OpRegistry::Global());
2532   TF_EXPECT_OK(BuildGraph(graph, &g));
2533   TF_ASSERT_OK(CallOptPassesAndPlace(&g, false, true));
2534 
2535   EXPECT_DEVICE_TYPE(g, "a", "FakeCPU");
2536   EXPECT_DEVICE_TYPE(g, "y", "FakeGPU");
2537   EXPECT_DEVICE_TYPE(g, "r1", "FakeCPU");
2538   EXPECT_DEVICE_TYPE(g, "r2", "FakeCPU");
2539   EXPECT_DEVICE_TYPE(g, "r3", "FakeCPU");
2540 }
2541 
2542 TEST_F(NestedPlacerTest, DuplicateInputResource_Conflict) {
2543   /*
2544    *                a:RESOURCE
2545    *                  / \
2546    *                 |   |
2547    *                 v   v
2548    *                  PCO:GPU (simple swap)
2549    *                 |   \
2550    *                 |   v
2551    *                 v   r2:RESOURCE:CPU
2552    *                 r1:RESOURCE:GPU
2553    *
2554    * There is a conflict but Placer always overrides requested devices
2555    * when they result in conflict due to resource edges. Which device
2556    * is picked for a/r1/r2 is indeterministic.
2557    */
2558   FunctionDef func = test::function::Swap();
2559   GraphDef graph = GDef(
2560       {
2561           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
2562           NDef("y", "PartitionedCall", {"a", "a"},
2563                {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2564                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}},
2565                 {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}},
2566                kGPU),
2567           NDef("r1", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}, kGPU),
2568           NDef("r2", "Identity", {"y:1"}, {{"T", DT_RESOURCE}}, kCPU),
2569       },
2570       // FunctionLib
2571       {func});
2572 
2573   Graph g(OpRegistry::Global());
2574   TF_EXPECT_OK(BuildGraph(graph, &g));
2575   TF_ASSERT_OK(CallOptPassesAndPlace(&g, false, true));
2576 
2577   EXPECT_SAME_TYPE(g, "a", "r1");
2578   EXPECT_SAME_TYPE(g, "a", "r2");
2579 }
2580 
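// EXPECT_SAME_TYPE is used above instead of EXPECT_DEVICE_TYPE because the
// Placer only guarantees that the colocated nodes end up on the same device
// type, not which one. Conceptually (GetDeviceTypeOf is a hypothetical helper
// used for illustration, not something defined in this file), the check
// amounts to:
//
//   EXPECT_EQ(GetDeviceTypeOf(g, "a"), GetDeviceTypeOf(g, "r1"));
//   EXPECT_EQ(GetDeviceTypeOf(g, "a"), GetDeviceTypeOf(g, "r2"));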
2581 TEST_F(NestedPlacerTest, TestDstDeviceIsIgnoredWhenConstrainedByResourceEdge) {
2582   /*
2583    *                a:RESOURCE:CPU
2584    *                   |
2585    *                   |
2586    *                   v
2587    *                  PCO (identity)
2588    *                   |
2589    *                   |
2590    *                   v
2591    *                r1:RESOURCE:GPU
2592    *
2593    * r1's device will be overridden.
2594    */
2595   FunctionDef func = test::function::ResourceIdentity();
2596   GraphDef graph = GDef(
2597       {
2598           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2599           NDef("y", "PartitionedCall", {"a"},
2600                {{"Tin", DataTypeSlice{DT_RESOURCE}},
2601                 {"Tout", DataTypeSlice{DT_RESOURCE}},
2602                 {"f", FDH::FunctionRef("ResourceIdentity", {})}}),
2603           NDef("r1", "_Retval", {"y:0"}, {{"T", DT_RESOURCE}},
2604                kGPU  // This device specification will be overridden
2605                ),
2606       },
2607       // FunctionLib
2608       {func});
2609 
2610   Graph g(OpRegistry::Global());
2611   TF_EXPECT_OK(BuildGraph(graph, &g));
2612   TF_EXPECT_OK(CallOptPassesAndPlace(&g));
2613 
2614   EXPECT_DEVICE_TYPE(g, "a", "FakeCPU");
2615   EXPECT_DEVICE_TYPE(g, "r1", "FakeCPU");
2616 }
2617 
2618 TEST_F(
2619     NestedPlacerTest,
2620     TestDstDeviceIsIgnoredWhenConstrainedByResourceEdge_EvenWhenPCOIsPlaced) {
2621   /*
2622    *                a:RESOURCE:CPU
2623    *                   |
2624    *                   |
2625    *                   v
2626    *                  PCO:GPU (identity)
2627    *                   |
2628    *                   |
2629    *                   v
2630    *                r1:RESOURCE:GPU
2631    *
2632    * r1'th device will be overridden.
2633    */
2634   FunctionDef func = test::function::ResourceIdentity();
2635   GraphDef graph = GDef(
2636       {
2637           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2638           NDef("y", "PartitionedCall", {"a"},
2639                {{"Tin", DataTypeSlice{DT_RESOURCE}},
2640                 {"Tout", DataTypeSlice{DT_RESOURCE}},
2641                 {"f", FDH::FunctionRef("ResourceIdentity", {})}},
2642                kGPU),
2643           NDef("r1", "_Retval", {"y:0"}, {{"T", DT_RESOURCE}},
2644                kGPU  // This device specification will be overridden
2645                ),
2646       },
2647       // FunctionLib
2648       {func});
2649 
2650   Graph g(OpRegistry::Global());
2651   TF_EXPECT_OK(BuildGraph(graph, &g));
2652   TF_EXPECT_OK(CallOptPassesAndPlace(&g));
2653 
2654   EXPECT_DEVICE_TYPE(g, "r1", "FakeCPU");
2655   EXPECT_DEVICE_TYPE(g, "y", "FakeGPU");
2656 }
2657 
2658 TEST_F(NestedPlacerTest, ResourceConflictInvolvingPCO) {
2659   /*
2660    *                a:RESOURCE:CPU
2661    *                   |
2662    *                   |
2663    *                   v
2664    *                  PCO (identity)
2665    *                   |
2666    *                   |   b:RESOURCE:GPU
2667    *                   |    |
2668    *                   v    v
2669    *                Add:RESOURCE
2670    *
2671    * The Add op cannot be placed because the requested devices on the
2672    * resource-generating ops feeding it conflict.
2673    */
2674   FunctionDef func = test::function::ResourceIdentity();
2675   GraphDef graph = GDef(
2676       {
2677           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2678           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
2679           NDef("y", "PartitionedCall", {"a"},
2680                {{"Tin", DataTypeSlice{DT_RESOURCE}},
2681                 {"Tout", DataTypeSlice{DT_RESOURCE}},
2682                 {"f", FDH::FunctionRef("ResourceIdentity", {})}}),
2683           NDef("add", "Add", {"y:0", "b"}, {{"T", DT_RESOURCE}}),
2684       },
2685       // FunctionLib
2686       {func});
2687 
2688   Graph g(OpRegistry::Global());
2689   TF_EXPECT_OK(BuildGraph(graph, &g));
2690   Status s = CallOptPassesAndPlace(&g);
2691   EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
2692   EXPECT_TRUE(absl::StrContains(
2693       s.error_message(),
2694       "Cannot place the graph because a reference or resource edge connects "
2695       "colocation groups with incompatible resource devices: /device:FakeCPU:0 "
2696       "vs /device:FakeGPU:0"))
2697       << s.ToString();
2698 }
2699 
2700 TEST_F(NestedPlacerTest, ResourceConflictInvolvingTwoPCOs) {
2701   /*
2702    *            a:RESOURCE:CPU
2703    *               |
2704    *               |          b:RESOURCE:GPU
2705    *               |              |
2706    *               v              |
2707    *            y:PCO (identity)  |
2708    *               |              v
2709    *                \          z:PCO (identity)
2710    *                 \           /
2711    *                  \         /
2712    *                   v       v
2713    *                 Add:RESOURCE
2714    *
2715    * Add op cannot be placed.
2716    */
2717   FunctionDef func = test::function::ResourceIdentity();
2718   GraphDef graph = GDef(
2719       {
2720           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2721           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
2722           NDef("y", "PartitionedCall", {"a"},
2723                {{"Tin", DataTypeSlice{DT_RESOURCE}},
2724                 {"Tout", DataTypeSlice{DT_RESOURCE}},
2725                 {"f", FDH::FunctionRef("ResourceIdentity", {})}}),
2726           NDef("z", "PartitionedCall", {"b"},
2727                {{"Tin", DataTypeSlice{DT_RESOURCE}},
2728                 {"Tout", DataTypeSlice{DT_RESOURCE}},
2729                 {"f", FDH::FunctionRef("ResourceIdentity", {})}}),
2730           NDef("add", "Add", {"y:0", "z:0"}, {{"T", DT_RESOURCE}}),
2731       },
2732       // FunctionLib
2733       {func});
2734 
2735   Graph g(OpRegistry::Global());
2736   TF_EXPECT_OK(BuildGraph(graph, &g));
2737 
2738   Status s = CallOptPassesAndPlace(&g);
2739   EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
2740   EXPECT_TRUE(absl::StrContains(
2741       s.error_message(),
2742       "Cannot place the graph because a reference or resource edge connects "
2743       "colocation groups with incompatible resource devices: /device:FakeCPU:0 "
2744       "vs /device:FakeGPU:0"))
2745       << s.ToString();
2746 }
2747 
2748 // Function that returns a resource that can be produced on CPU only.
2749 FunctionDef CPUResourceOutput() {
2750   return FDH::Create(
2751       // Name
2752       "CPUResourceOutput",
2753       // Args
2754       {"x: float"},
2755       // Return values
2756       {"ds: resource", "x_out: float"},
2757       // Attr def
2758       {},
2759       // Nodes
2760       {
2761           {{"make_ds"}, "CreateDatasetCPU", {}},
2762       },
2763       {{"ds", "make_ds:o:0"}, {"x_out", "x"}});
2764 }
2765 
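// "CreateDatasetCPU" above is assumed to be one of the fake ops registered
// earlier in this file with a kernel only on the fake CPU device, which is
// what restricts where the resource can be produced. Such a registration
// looks roughly like this (a sketch, not the actual registration):
//
//   REGISTER_OP("CreateDatasetCPU").Output("o: resource");
//   REGISTER_KERNEL_BUILDER(Name("CreateDatasetCPU").Device("FakeCPU"),
//                           DummyOp);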
2766 TEST_F(NestedPlacerTest, DeepDeviceConstraintsPropagated) {
2767   /*
2768    *            a:FLOAT
2769    *               |
2770    *               v
2771    *          PCO (CPUResourceOutput)
2772    *               |    |
2773    *               |    v
2774    *               |  (ignored)
2775    *               |
2776    *               v
2777    *          id:Identity:GPU (assigned)
2778    *
2779    * The graph cannot be placed because the PCO can produce the resource
2780    * on CPU only.
2781    */
2782   FunctionDef func = CPUResourceOutput();
2783   GraphDef graph = GDef(
2784       {
2785           NDef("a", "_Arg", {}, {{"T", DT_FLOAT}}),
2786           NDef("y", "PartitionedCall", {"a"},
2787                {{"Tin", DataTypeSlice{DT_FLOAT}},
2788                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_FLOAT}},
2789                 {"f", FDH::FunctionRef("CPUResourceOutput", {})}}),
2790           NDef("id", "Identity", {"y:0"}, {{"T", DT_RESOURCE}}),
2791       },
2792       // FunctionLib
2793       {func});
2794 
2795   Graph g(OpRegistry::Global());
2796   TF_EXPECT_OK(BuildGraph(graph, &g));
2797   GetNodeByName(g, "id")->set_assigned_device_name(kFullGPU);
2798 
2799   Status s = CallOptPassesAndPlace(&g);
2800   EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
2801   // TODO(b/129057603): When better error messages are implemented, this should
2802   // change.
2803   EXPECT_TRUE(absl::StrContains(
2804       s.error_message(), "Could not satisfy explicit device specification"))
2805       << s.ToString();
2806 }
2807 
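// Note the distinction exercised above: a device string passed to NDef only
// sets the *requested* device, which the Placer may override (as the resource
// edge tests earlier show), whereas set_assigned_device_name() marks the node
// as already placed, so the Placer must honor it and fail if it cannot. A
// minimal usage sketch (illustrative only):
//
//   Node* id = GetNodeByName(g, "id");
//   id->set_requested_device(kGPU);          // soft: may be overridden
//   id->set_assigned_device_name(kFullGPU);  // hard: must be satisfied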
2808 FunctionDef NestedCPUResourceOutput() {
2809   return FDH::Create(
2810       // Name
2811       "NestedCPUResourceOutput",
2812       // Args
2813       {"x: float"},
2814       // Return values
2815       {"ds: resource", "x_out: float"},
2816       // Attr def
2817       {},
2818       // Nodes
2819       {
2820           {{"y"},
2821            "PartitionedCall",
2822            {"x"},
2823            {{"Tin", DataTypeSlice{DT_FLOAT}},
2824             {"Tout", DataTypeSlice{DT_RESOURCE, DT_FLOAT}},
2825             {"f", FDH::FunctionRef("CPUResourceOutput", {})}}},
2826       },
2827       {{"ds", "y:output:0"}, {"x_out", "y:output:1"}});
2828 }
2829 
2830 TEST_F(NestedPlacerTest, NestedDeepDeviceConstraintsPropagated) {
2831   /*
2832    *            a:FLOAT
2833    *               |
2834    *               v
2835    *          PCO (NestedCPUResourceOutput)
2836    *               |    |
2837    *               |    v
2838    *               |  (ignored)
2839    *               |
2840    *               v
2841    *          id:_Retval:GPU (assigned)
2842    *
2843    * The graph cannot be placed because the PCO can produce the resource
2844    * on CPU only.
2845    */
2846   GraphDef graph = GDef(
2847       {
2848           NDef("a", "_Arg", {}, {{"T", DT_FLOAT}}),
2849           NDef("y", "PartitionedCall", {"a"},
2850                {{"Tin", DataTypeSlice{DT_FLOAT}},
2851                 {"Tout", DataTypeSlice{DT_RESOURCE, DT_FLOAT}},
2852                 {"f", FDH::FunctionRef("NestedCPUResourceOutput", {})}}),
2853           NDef("id", "_Retval", {"y:0"}, {{"T", DT_RESOURCE}}),
2854       },
2855       // FunctionLib
2856       {CPUResourceOutput(), NestedCPUResourceOutput()});
2857 
2858   Graph g(OpRegistry::Global());
2859   TF_EXPECT_OK(BuildGraph(graph, &g));
2860   GetNodeByName(g, "id")->set_assigned_device_name(kFullGPU);
2861 
2862   Status s = CallOptPassesAndPlace(&g);
2863   EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
2864   // TODO(b/129057603): When better error messages are implemented, this should
2865   // change.
2866   EXPECT_TRUE(absl::StrContains(
2867       s.error_message(), "Could not satisfy explicit device specification"))
2868       << s.ToString();
2869 }
2870 
2871 TEST_F(NestedPlacerTest, TwoFunctionsBackToBack) {
2872   /*
2873    *            a:RESOURCE:CPU
2874    *               |
2875    *               |          b:RESOURCE:GPU
2876    *               v              |
2877    *            y:PCO (identity)  |
2878    *               |              |
2879    *            w:PCO (identity)  |
2880    *               |              v
2881    *                \          z:PCO (identity)
2882    *                 \           /
2883    *                  \         /
2884    *                   v       v
2885    *                 Add:RESOURCE
2886    *
2887    * Add op cannot be placed.
2888    * Placing two PCOs back to back is a challenging case that required
2889    * adding the IsolateDeepOpsPass.
2890    */
2891   FunctionDef func = test::function::ResourceIdentity();
2892   GraphDef graph = GDef(
2893       {
2894           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
2895           NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
2896           NDef("y", "PartitionedCall", {"a"},
2897                {{"Tin", DataTypeSlice{DT_RESOURCE}},
2898                 {"Tout", DataTypeSlice{DT_RESOURCE}},
2899                 {"f", FDH::FunctionRef("ResourceIdentity", {})}}),
2900           NDef("w", "PartitionedCall", {"y:0"},
2901                {{"Tin", DataTypeSlice{DT_RESOURCE}},
2902                 {"Tout", DataTypeSlice{DT_RESOURCE}},
2903                 {"f", FDH::FunctionRef("ResourceIdentity", {})}}),
2904           NDef("z", "PartitionedCall", {"b"},
2905                {{"Tin", DataTypeSlice{DT_RESOURCE}},
2906                 {"Tout", DataTypeSlice{DT_RESOURCE}},
2907                 {"f", FDH::FunctionRef("ResourceIdentity", {})}}),
2908           NDef("add", "Add", {"w:0", "z:0"}, {{"T", DT_RESOURCE}}),
2909       },
2910       // FunctionLib
2911       {func});
2912 
2913   Graph g(OpRegistry::Global());
2914   TF_EXPECT_OK(BuildGraph(graph, &g));
2915 
2916   Status s = CallOptPassesAndPlace(&g);
2917   EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
2918   EXPECT_TRUE(absl::StrContains(
2919       s.error_message(),
2920       "Cannot place the graph because a reference or resource edge connects "
2921       "colocation groups with incompatible resource devices: /device:FakeCPU:0 "
2922       "vs /device:FakeGPU:0"))
2923       << s.ToString();
2924 }
2925 
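// "Isolating" a deep op (an op whose placement requires inspecting a function
// body, such as PartitionedCall) roughly means inserting Identity nodes
// between the op and its neighbors before placement, so that colocation
// constraints coming from different call sites are not merged into a single
// unsatisfiable group. Sketched on the graph above (an assumption about the
// pass, not an exact description of its rewrite):
//
//   y:PCO ----> w:PCO        becomes        y:PCO --> Identity --> w:PCO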
2926 FunctionDef NestedCallFunctionsBackToBack() {
2927   return FDH::Create(
2928       // Name
2929       "NestedCallFunctionsBackToBack",
2930       // Args
2931       {},
2932       // Return values
2933       {"output: resource"},
2934       // Attr def
2935       {},
2936       // Nodes
2937       {
2938           {{"cpu_ds"}, "CreateDatasetCPU", {}},
2939           {{"y"},
2940            "PartitionedCall",
2941            {"cpu_ds:o:0"},
2942            {{"Tin", DataTypeSlice{DT_RESOURCE}},
2943             {"Tout", DataTypeSlice{DT_RESOURCE}},
2944             {"f", FDH::FunctionRef("ResourceIdentity", {})}}},
2945           {{"w"},
2946            "PartitionedCall",
2947            {"y:output:0"},
2948            {{"Tin", DataTypeSlice{DT_RESOURCE}},
2949             {"Tout", DataTypeSlice{DT_RESOURCE}},
2950             {"f", FDH::FunctionRef("ResourceIdentity", {})}}},
2951           {{"gpu_ds"}, "CreateDatasetGPU", {}},
2952           {{"z"},
2953            "PartitionedCall",
2954            {"gpu_ds:o:0"},
2955            {{"Tin", DataTypeSlice{DT_RESOURCE}},
2956             {"Tout", DataTypeSlice{DT_RESOURCE}},
2957             {"f", FDH::FunctionRef("ResourceIdentity", {})}}},
2958           {{"add"}, "Add", {"w:output:0", "z:output:0"}, {{"T", DT_RESOURCE}}},
2959       },
2960       {{"output", "add:z:0"}});
2961 }
2962 
2963 TEST_F(NestedPlacerTest, NestedTwoFunctionsBackToBack) {
2964   /*
2965    * Same as TwoFunctionsBackToBack above, but the functions are invoked from
2966    * another function instead of the top-level graph. This tests that the
2967    * Placer isolates deep ops in nested function bodies.
2968    */
2969   FunctionDef func = NestedCallFunctionsBackToBack();
2970   GraphDef graph = GDef(
2971       {
2972           NDef("y", "PartitionedCall", {},
2973                {{"Tin", {}},
2974                 {"Tout", DataTypeSlice{DT_FLOAT}},
2975                 {"f", FDH::FunctionRef("NestedCallFunctionsBackToBack", {})}}),
2976       },
2977       // FunctionLib
2978       {NestedCallFunctionsBackToBack(), test::function::ResourceIdentity()});
2979 
2980   Graph g(OpRegistry::Global());
2981   TF_EXPECT_OK(BuildGraph(graph, &g));
2982 
2983   Status s = CallOptPassesAndPlace(&g);
2984   EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
2985   EXPECT_TRUE(absl::StrContains(
2986       s.error_message(),
2987       "Nodes were connected by a reference or resource connection (requiring "
2988       "them to be on the same device), but the two nodes were assigned two "
2989       "different devices"))
2990       << s.ToString();
2991 }
2992 
2993 FunctionDef RecursiveResourceIdentity() {
2994   return FDH::Create(
2995       // Name
2996       "RecursiveResourceIdentity",
2997       // Args
2998       {"x: resource"},
2999       // Return values
3000       {"y: resource"},
3001       // Attr def
3002       {},
3003       // Nodes
3004       {
3005           {{"out"},
3006            "PartitionedCall",
3007            {"x"},
3008            {{"Tin", DataTypeSlice{DT_RESOURCE}},
3009             {"Tout", DataTypeSlice{DT_RESOURCE}},
3010             {"f", FDH::FunctionRef("RecursiveResourceIdentity", {})}}},
3011       },
3012       // Output mapping
3013       {{"y", "out:output:0"}});
3014 }
3015 
3016 TEST_F(NestedPlacerTest, DirectRecursion) {
3017   GraphDef graph = GDef(
3018       {
3019           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
3020           NDef("y", "PartitionedCall", {"a"},
3021                {{"Tin", DataTypeSlice{DT_RESOURCE}},
3022                 {"Tout", DataTypeSlice{DT_RESOURCE}},
3023                 {"f", FDH::FunctionRef("RecursiveResourceIdentity", {})}}),
3024           NDef("r1", "_Retval", {"y:0"}, {{"T", DT_RESOURCE}}),
3025       },
3026       // FunctionLib
3027       {RecursiveResourceIdentity()});
3028 
3029   Graph g(OpRegistry::Global());
3030   TF_EXPECT_OK(BuildGraph(graph, &g));
3031 
3032   Status s = CallOptPassesAndPlace(&g);
3033   EXPECT_EQ(error::UNIMPLEMENTED, s.code()) << s.ToString();
3034   EXPECT_TRUE(absl::StrContains(
3035       s.error_message(),
3036       "Recursive function calls are not supported. Node {{node out}} inside "
3037       "the body of {{function_node RecursiveResourceIdentity}} calls function "
3038       "{{function_node RecursiveResourceIdentity}}"))
3039       << s.ToString();
3040 }
3041 
3042 FunctionDef RecursiveF1() {
3043   return FDH::Create(
3044       // Name
3045       "RecursiveF1",
3046       // Args
3047       {"x: resource"},
3048       // Return values
3049       {"y: resource"},
3050       // Attr def
3051       {},
3052       // Nodes
3053       {
3054           {{"out"},
3055            "PartitionedCall",
3056            {"x"},
3057            {{"Tin", DataTypeSlice{DT_RESOURCE}},
3058             {"Tout", DataTypeSlice{DT_RESOURCE}},
3059             {"f", FDH::FunctionRef("RecursiveF2", {})}}},
3060       },
3061       // Output mapping
3062       {{"y", "out:output:0"}});
3063 }
3064 
3065 FunctionDef RecursiveF2() {
3066   return FDH::Create(
3067       // Name
3068       "RecursiveF2",
3069       // Args
3070       {"x: resource"},
3071       // Return values
3072       {"y: resource"},
3073       // Attr def
3074       {},
3075       // Nodes
3076       {
3077           {{"out"},
3078            "PartitionedCall",
3079            {"x"},
3080            {{"Tin", DataTypeSlice{DT_RESOURCE}},
3081             {"Tout", DataTypeSlice{DT_RESOURCE}},
3082             {"f", FDH::FunctionRef("RecursiveF1", {})}}},
3083       },
3084       // Output mapping
3085       {{"y", "out:output:0"}});
3086 }
3087 
3088 TEST_F(NestedPlacerTest, IndirectRecursion) {
3089   GraphDef graph = GDef(
3090       {
3091           NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
3092           NDef("y", "PartitionedCall", {"a"},
3093                {{"Tin", DataTypeSlice{DT_RESOURCE}},
3094                 {"Tout", DataTypeSlice{DT_RESOURCE}},
3095                 {"f", FDH::FunctionRef("RecursiveF1", {})}}),
3096           NDef("r1", "_Retval", {"y:0"}, {{"T", DT_RESOURCE}}),
3097       },
3098       // FunctionLib
3099       {RecursiveF1(), RecursiveF2()});
3100 
3101   Graph g(OpRegistry::Global());
3102   TF_EXPECT_OK(BuildGraph(graph, &g));
3103 
3104   Status s = CallOptPassesAndPlace(&g);
3105   EXPECT_EQ(error::UNIMPLEMENTED, s.code()) << s.ToString();
3106   EXPECT_TRUE(absl::StrContains(
3107       s.error_message(),
3108       "Recursive function calls are not supported. Node {{node out}} inside "
3109       "the body of {{function_node RecursiveF2}} calls function "
3110       "{{function_node RecursiveF1}} which is already present in the call "
3111       "stack"))
3112       << s.ToString();
3113 }
3114 
3115 }  // namespace
3116 }  // namespace tensorflow
3117