/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"

#include <functional>
#include <memory>
#include <optional>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace tensorflow {
namespace eager {
namespace {

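// Exposes EagerServiceImpl's protected context lookup so tests can inspect
// the EagerContext and remote tensor handles owned by a given context_id.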
class TestEagerServiceImpl : public EagerServiceImpl {
 public:
  explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {}
  Status GetEagerContext(const uint64 context_id, EagerContext** ctx) {
    ServerContext* context = nullptr;
    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
    core::ScopedUnref context_unref(context);
    *ctx = context->Context();
    return OkStatus();
  }
  Status GetTensorHandle(const uint64 context_id,
                         const RemoteTensorHandleInternal& remote_handle,
                         tensorflow::TensorHandle** handle) {
    ServerContext* context = nullptr;
    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
    core::ScopedUnref context_unref(context);

    return context->Context()->RemoteMgr()->GetTensorHandle(remote_handle,
                                                            handle);
  }
};

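// An EagerClient that bypasses RPC and forwards every request directly to a
// local TestEagerServiceImpl.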
class FakeEagerClient : public EagerClient {
 public:
  FakeEagerClient() {}
  ~FakeEagerClient() override {}

  void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; }

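  // Each *Async method below calls the service implementation synchronously
  // and reports the resulting status through `done`.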
#define CLIENT_METHOD(method)                                         \
  void method##Async(const method##Request* request,                  \
                     method##Response* response, StatusCallback done) \
      override {                                                      \
    done(impl_->method(request, response));                           \
  }

  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(UpdateContext);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);
  CLIENT_METHOD(CloseContext);
#undef CLIENT_METHOD

  void EnqueueAsync(CallOptions* call_opts, const EnqueueRequest* request,
                    EnqueueResponse* response, StatusCallback done) override {
    done(impl_->Enqueue(call_opts, request, response));
  }

  void RunComponentFunctionAsync(CallOptions* call_opts,
                                 const RunComponentFunctionRequest* request,
                                 RunComponentFunctionResponse* response,
                                 StatusCallback done) override {
    impl_->RunComponentFunction(call_opts, request, response, std::move(done));
  }

  void StreamingEnqueueAsync(bool enable_streaming_enqueue,
                             CallOptions* call_opts,
                             const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    done(impl_->Enqueue(nullptr, request, response));
  }

  bool allow_multiple_pending_requests() const override { return false; }

 private:
  TestEagerServiceImpl* impl_;
};

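// An EagerClientCache that returns the same FakeEagerClient regardless of the
// requested target.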
class DummyEagerClientCache : public EagerClientCache {
 public:
  DummyEagerClientCache() : client_(new FakeEagerClient) {}
  Status GetClient(const string& target,
                   core::RefCountPtr<EagerClient>* client) override {
    client->reset(client_.get());
    client_->Ref();
    return OkStatus();
  }

 private:
  core::RefCountPtr<EagerClient> client_;
};

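// A worker cache that hands out DummyEagerClientCaches and lists a single
// local worker.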
class FakeCache : public TestWorkerCache {
  Status GetEagerClientCache(
      std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
    *eager_client_cache = std::make_unique<DummyEagerClientCache>();
    return OkStatus();
  }

  void ListWorkers(std::vector<string>* workers) const override {
    workers->push_back("/job:localhost/replica:0/task:0");
  }
};

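// Test fixture that assembles a minimal WorkerEnv: one local CPU device, an
// RPC rendezvous manager, and a SessionMgr backed by FakeCache.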
class EagerServiceImplTest : public ::testing::Test {
 public:
  EagerServiceImplTest()
      : rendezvous_mgr_(&worker_env_),
        session_mgr_(new SessionMgr(
            &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
            std::unique_ptr<WorkerCacheInterface>(new FakeCache),
            [](const ServerDef& server_def,
               WorkerCacheInterface** worker_cache) {
              *worker_cache = new FakeCache;
              return OkStatus();
            })) {
    worker_env_.env = Env::Default();

    worker_env_.rendezvous_mgr = &rendezvous_mgr_;
    worker_env_.session_mgr = session_mgr_.get();

    device_mgr_ = std::make_unique<StaticDeviceMgr>(
        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
    worker_env_.local_devices = device_mgr_->ListDevices();
    worker_env_.device_mgr = device_mgr_.get();
  }

 protected:
  WorkerEnv worker_env_;
  tensorflow::RpcRendezvousMgr rendezvous_mgr_;
  std::unique_ptr<SessionMgr> session_mgr_;
  std::unique_ptr<DeviceMgr> device_mgr_;
};

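// Fills `tensor_proto` with a 2x2 float tensor holding {1, 2, 3, 4}; the
// expected MatMul results checked in the tests below are derived from it.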
void SetTensorProto(TensorProto* tensor_proto) {
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Tensor* t = TF_AllocateTensor(
      TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  tensorflow::Tensor tensor;
  TF_ASSERT_OK(tensorflow::TF_TensorToTensor(t, &tensor));
  tensor.AsProtoTensorContent(tensor_proto);
  TF_DeleteTensor(t);
}

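// Populates `operation` with the given id, name, attributes, and device. Each
// input is either a literal TensorProto or an (op_id, output_num) pair that is
// encoded as a remote tensor handle on `device`.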
void BuildOperation(
    Operation* operation, int64_t id, const string& name,
    const std::vector<std::variant<TensorProto, std::pair<int64_t, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device) {
  operation->set_id(id);
  operation->set_name(name);
  operation->set_device(device);

  for (const auto& input : inputs) {
    if (input.index() == 0) {
      *operation->add_op_inputs()->mutable_tensor() =
          std::get<TensorProto>(input);
    } else {
      const auto& tensor_handle_pair =
          std::get<std::pair<int64_t, int32>>(input);
      auto* input_handle = operation->add_op_inputs()->mutable_remote_handle();
      input_handle->set_op_id(tensor_handle_pair.first);
      input_handle->set_output_num(tensor_handle_pair.second);
      input_handle->set_op_device(device);
      input_handle->set_device(device);
    }
  }

  for (const auto& attr_entry : attrs) {
    (*operation->mutable_attrs())[attr_entry.first] = attr_entry.second;
  }
}

void AddOperationToEnqueueRequest(
    int64_t id, const string& name,
    const std::vector<std::variant<TensorProto, std::pair<int64_t, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device,
    EnqueueRequest* request) {
  auto* operation = request->add_queue()->mutable_operation();
  BuildOperation(operation, id, name, inputs, attrs, device);
}

void AddOperationToRunComponentFunctionRequest(
    int64_t id, const string& name,
    const std::vector<std::variant<TensorProto, std::pair<int64_t, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device,
    const int output_num, RunComponentFunctionRequest* request) {
  auto* operation = request->mutable_operation();
  operation->set_is_function(true);
  operation->set_is_component_function(true);
  request->add_output_num(output_num);
  BuildOperation(operation, id, name, inputs, attrs, device);
}

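// Returns a NodeDef that calls the MatMulFunction op on input 'a'.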
tensorflow::NodeDef MatMulFunctionNodeDef() {
  tensorflow::NodeDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    name: 'matmul_func'"
      "    op: 'MatMulFunction'"
      "    input: 'a'"
      "    input: 'a'"
      "    attr {"
      "      key: 'T'"
      "      value {"
      "        type: DT_FLOAT"
      "      }"
      "    }",
      &def));
  return def;
}

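// A FunctionDef that computes m = MatMul(a, a) for a single float input 'a'.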
tensorflow::FunctionDef MatMulFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'MatMulFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'm'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'matmul'"
      "      op: 'MatMul'"
      "      input: 'a'"
      "      input: 'a'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "      attr {"
      "        key: 'transpose_a'"
      "        value {"
      "          b: false"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'm'"
      "      value: 'matmul:product'"
      "    }",
      &def));
  return def;
}

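// A FunctionDef whose body calls MatMulFunction on its input; used by the
// nested-function test below.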
tensorflow::FunctionDef MatMulNestedFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'MatMulNestedFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'matmul_nested'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'matmul_nested'"
      "      op: 'MatMulFunction'"
      "      input: 'a'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'matmul_nested'"
      "      value: 'matmul_nested:m:0'"
      "    }",
      &def));
  return def;
}

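// A FunctionDef whose body is a single client-terminated _Recv op. The tests
// never send the tensor 't0', so executing it blocks until cancelled, which
// makes it useful for the cancellation tests.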
tensorflow::FunctionDef SingleRecvNodeFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'SingleRecvNodeFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'recv_tensor'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'recv_node'"
      "      op: '_Recv'"
      "      device: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "      attr {"
      "        key: 'client_terminated'"
      "        value {"
      "          b: true"
      "        }"
      "      }"
      "      attr {"
      "        key: 'recv_device'"
      "        value {"
      "          s: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "        }"
      "      }"
      "      attr {"
      "        key: 'send_device'"
      "        value {"
      "          s: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "        }"
      "      }"
      "      attr {"
      "        key: 'send_device_incarnation'"
      "        value {"
      "          i: 1"
      "        }"
      "      }"
      "      attr {"
      "        key: 'tensor_name'"
      "        value {"
      "          s: 't0'"
      "        }"
      "      }"
      "      attr {"
      "        key: 'tensor_type'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'recv_tensor'"
      "      value: 'recv_node:tensor:0'"
      "    }",
      &def));
  return def;
}

// Test creates a context and attempts to execute some ops.
TEST_F(EagerServiceImplTest, BasicTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();

  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  std::unordered_map<string, AttrValue> const_attrs;
  AttrValue val;
  val.set_type(tensorflow::DataType::DT_FLOAT);
  const_attrs.insert({"dtype", val});
  val.Clear();
  SetTensorProto(val.mutable_tensor());
  const_attrs.insert({"value", val});

  AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
                               "/job:localhost/replica:0/task:0/device:CPU:0",
                               &remote_enqueue_request);

  std::unordered_map<string, AttrValue> attrs;
  val.Clear();
  val.set_type(tensorflow::DataType::DT_FLOAT);
  attrs.insert({"T", val});
  val.Clear();
  val.set_b(false);
  attrs.insert({"transpose_a", val});
  attrs.insert({"transpose_b", val});

  AddOperationToEnqueueRequest(
      2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
      "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  auto& matmul_result_shape =
      remote_enqueue_response.queue_response(1).shape(0);
  EXPECT_EQ(matmul_result_shape.dim(0).size(), 2);
  EXPECT_EQ(matmul_result_shape.dim(1).size(), 2);

  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));

  // This should be OK to do since we've placed all computation on the CPU
  // device.
  const tensorflow::Tensor* t = nullptr;
  TF_ASSERT_OK(tensor_handle->Tensor(&t));

  auto actual = t->flat<float>();

  EXPECT_EQ(4, actual.size());

  EXPECT_EQ(7, actual(0));
  EXPECT_EQ(10, actual(1));
  EXPECT_EQ(15, actual(2));
  EXPECT_EQ(22, actual(3));

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

class EagerServiceImplFunctionTest : public EagerServiceImplTest {
 public:
  EagerServiceImplFunctionTest() : EagerServiceImplTest() {}

  // Creates a context and attempts to execute a function.
  void TestFunction(const RegisterFunctionOp& register_op,
                    const string& function_name,
                    const bool local_inputs = false,
                    const bool test_cancel = false) {
    TestEagerServiceImpl eager_service_impl(&worker_env_);

    uint64 context_id = random::New64();

    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id);
    CreateContextResponse response;

    TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

    EnqueueRequest enqueue_request;
    enqueue_request.set_context_id(context_id);
    *enqueue_request.add_queue()->mutable_register_function() = register_op;
    EnqueueResponse enqueue_response;

    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
                                            &enqueue_response));

    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id);
    EnqueueResponse remote_enqueue_response;

    if (local_inputs) {
      TensorProto tensor_proto;
      SetTensorProto(&tensor_proto);
      AddOperationToEnqueueRequest(
          2, function_name, {tensor_proto},
          std::unordered_map<string, AttrValue>(),
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);

    } else {
      std::unordered_map<string, AttrValue> const_attrs;
      AttrValue val;
      val.set_type(tensorflow::DataType::DT_FLOAT);
      const_attrs.insert({"dtype", val});
      val.Clear();

      SetTensorProto(val.mutable_tensor());
      const_attrs.insert({"value", val});

      AddOperationToEnqueueRequest(
          1, "Const", {}, const_attrs,
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);
      AddOperationToEnqueueRequest(
          2, function_name, {std::make_pair(1, 0)},
          std::unordered_map<string, AttrValue>(),
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);
    }

    CallOptions call_opts;
    Status status;
    Notification n;
    Env::Default()->SchedClosure([&] {
      status = eager_service_impl.Enqueue(&call_opts, &remote_enqueue_request,
                                          &remote_enqueue_response);
      n.Notify();
    });

    if (test_cancel) {
      // Wait to let the Enqueue thread start running.
      Env::Default()->SleepForMicroseconds(500000);
      call_opts.StartCancel();
      n.WaitForNotification();
      EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
    } else {
      n.WaitForNotification();
      TF_ASSERT_OK(status);
      const tensorflow::Tensor* t = nullptr;
      tensorflow::TensorHandle* tensor_handle;
      TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
          context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
      TF_ASSERT_OK(tensor_handle->Tensor(&t));

      auto actual = t->flat<float>();
      EXPECT_EQ(4, actual.size());

      EXPECT_EQ(7, actual(0));
      EXPECT_EQ(10, actual(1));
      EXPECT_EQ(15, actual(2));
      EXPECT_EQ(22, actual(3));
    }

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                                 &close_context_response));
  }

  // Creates a context and attempts to execute a component function.
  void TestComponentFunction(const RegisterFunctionOp& register_op,
                             const string& function_name,
                             const bool test_cancel) {
    TestEagerServiceImpl eager_service_impl(&worker_env_);
    uint64 context_id = random::New64();

    // Create context.
    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id);
    CreateContextResponse response;
    TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

    // Register function.
    EnqueueRequest enqueue_request;
    enqueue_request.set_context_id(context_id);
    *enqueue_request.add_queue()->mutable_register_function() = register_op;
    EnqueueResponse enqueue_response;
    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
                                            &enqueue_response));

    // First run an op to generate input for function.
    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id);
    EnqueueResponse remote_enqueue_response;

    std::unordered_map<string, AttrValue> const_attrs;
    AttrValue val;
    val.set_type(tensorflow::DataType::DT_FLOAT);
    const_attrs.insert({"dtype", val});
    val.Clear();
    SetTensorProto(val.mutable_tensor());
    const_attrs.insert({"value", val});
    AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
                                 "/job:localhost/replica:0/task:0/device:CPU:0",
                                 &remote_enqueue_request);
    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                            &remote_enqueue_response));

    // Run function with input from the previous op.
    RunComponentFunctionRequest run_comp_func_request;
    run_comp_func_request.set_context_id(context_id);
    RunComponentFunctionResponse run_comp_func_response;
    const int output_num = 5;
    AddOperationToRunComponentFunctionRequest(
        2, function_name, {std::make_pair(1, 0)},
        std::unordered_map<string, AttrValue>(),
        "/job:localhost/replica:0/task:0/device:CPU:0", output_num,
        &run_comp_func_request);

    CallOptions call_opts;
    Notification n;
    Status status;
    eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request,
                                            &run_comp_func_response,
                                            [&status, &n](const Status& s) {
                                              status.Update(s);
                                              n.Notify();
                                            });
    if (test_cancel) {
      call_opts.StartCancel();
    }
    n.WaitForNotification();
    if (test_cancel) {
      EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
    } else {
      TF_ASSERT_OK(status);
      // Retrieve the output.
      const tensorflow::Tensor* t = nullptr;
      tensorflow::TensorHandle* tensor_handle;
      TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
          context_id, RemoteTensorHandleInternal(2, output_num),
          &tensor_handle));
      TF_ASSERT_OK(tensor_handle->Tensor(&t));

      auto actual = t->flat<float>();
      EXPECT_EQ(4, actual.size());

      EXPECT_EQ(7, actual(0));
      EXPECT_EQ(10, actual(1));
      EXPECT_EQ(15, actual(2));
      EXPECT_EQ(22, actual(3));
    }

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                                 &close_context_response));
  }
};

TEST_F(EagerServiceImplFunctionTest, BasicFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestFunction(register_op, "MatMulFunction");
}

TEST_F(EagerServiceImplFunctionTest, FunctionWithLocalInputsTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestFunction(register_op, "MatMulFunction", /*local_inputs=*/true);
}

TEST_F(EagerServiceImplFunctionTest, NestedFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulNestedFunction();
  *register_op.mutable_library()->add_function() = MatMulFunction();
  TestFunction(register_op, "MatMulNestedFunction");
}

TEST_F(EagerServiceImplFunctionTest, FunctionCancellationTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = SingleRecvNodeFunction();
  TestFunction(register_op, "SingleRecvNodeFunction", /*local_inputs=*/false,
               /*test_cancel=*/true);
}

TEST_F(EagerServiceImplFunctionTest, ComponentFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestComponentFunction(register_op, "MatMulFunction", false);
}

TEST_F(EagerServiceImplFunctionTest, ComponentFunctionCancellationTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = SingleRecvNodeFunction();
  TestComponentFunction(register_op, "SingleRecvNodeFunction", true);
}

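// Fixture for executing MatMulFunction on a (simulated) remote task with an
// input that lives on the local task. Init() creates a context, points the
// FakeEagerClient at the local service, enqueues a Const op to produce the
// function input, and builds the cluster and process function library
// runtimes.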
class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
 public:
  FunctionWithRemoteInputsTest()
      : EagerServiceImplTest(), eager_service_impl_(&worker_env_) {
    remote_device_mgr_ = std::make_unique<StaticDeviceMgr>(
        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
    context_id_ = random::New64();
  }

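  // EagerKernelArgs whose remote inputs are resolved through a caller-supplied
  // callback that serializes a RemoteTensorHandle for a given input index.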
  class TestExecuteNodeArgs : public EagerKernelArgs {
   public:
    TestExecuteNodeArgs(
        gtl::InlinedVector<TensorValue, 4>&& tensor_args,
        std::function<Status(const int, eager::RemoteTensorHandle*)>
            serialize_remote_handle)
        : EagerKernelArgs(std::move(tensor_args)),
          serialize_remote_handle_(std::move(serialize_remote_handle)) {}

    bool HasRemoteOrPackedInputs() const override { return true; }

    Status GetRemoteArg(const FunctionArgIndex& index,
                        eager::RemoteTensorHandle* val) const override {
      return serialize_remote_handle_(index.index, val);
    }

   private:
    std::function<Status(const int, eager::RemoteTensorHandle*)>
        serialize_remote_handle_;
  };

  bool MatMulHasAttrWithDefaultValue(const tensorflow::FunctionDef& fdef) {
    for (const auto& node : fdef.node_def()) {
      if (node.op() == "MatMul") {
        return node.attr().find("transpose_a") != node.attr().end();
      }
    }
    return false;
  }

  void Init() {
    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id_);
    CreateContextResponse response;
    TF_ASSERT_OK(eager_service_impl_.CreateContext(&request, &response));

    // Make the fake EagerClient use the local eager_service_impl.
    EagerContext* ctx = nullptr;
    TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
    Device* device;
    TF_ASSERT_OK(ctx->FindDeviceFromName(local_device_.c_str(), &device));
    core::RefCountPtr<EagerClient> client;
    TF_ASSERT_OK(ctx->GetClient(device, &client));
    FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client.get());
    fake_client->SetServiceImpl(&eager_service_impl_);

    // Create an input on local_device for MatMulFunction.
    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id_);
    EnqueueResponse remote_enqueue_response;
    std::unordered_map<string, AttrValue> const_attrs;
    AttrValue val;
    val.set_type(tensorflow::DataType::DT_FLOAT);
    const_attrs.insert({"dtype", val});
    val.Clear();
    SetTensorProto(val.mutable_tensor());
    const_attrs.insert({"value", val});
    AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, local_device_,
                                 &remote_enqueue_request);
    TF_EXPECT_OK(eager_service_impl_.Enqueue(nullptr, &remote_enqueue_request,
                                             &remote_enqueue_response));
    eager_cluster_flr_ = std::make_unique<EagerClusterFunctionLibraryRuntime>(
        context_id_, ctx, device_mgr_.get());

    fdef_ = MatMulFunction();
    TF_ASSERT_OK(func_lib_def_.AddFunctionDef(fdef_));
    eager_pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
        remote_device_mgr_.get(), Env::Default(), /*config=*/
        nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
        /*thread_pool=*/nullptr, eager_cluster_flr_.get(),
        /*session_metadata=*/nullptr,
        Rendezvous::Factory{[this](const int64_t step_id,
                                   const DeviceMgr* device_mgr,
                                   Rendezvous** r) {
          *r = worker_env_.rendezvous_mgr->Find(step_id);
          return OkStatus();
        }});
  }

  void CheckOutputTensorAndClose(const Tensor& tensor) {
    auto actual = tensor.flat<float>();
    EXPECT_EQ(4, actual.size());
    EXPECT_EQ(7, actual(0));
    EXPECT_EQ(10, actual(1));
    EXPECT_EQ(15, actual(2));
    EXPECT_EQ(22, actual(3));

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id_);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl_.CloseContext(&close_context_request,
                                                  &close_context_response));
  }

  void CheckOutputsAndClose(const std::vector<FunctionRet>& outputs,
                            const int64_t op_id) {
    const tensorflow::Tensor* t = nullptr;
    tensorflow::TensorHandle* tensor_handle;
    TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
        context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle));
    TF_ASSERT_OK(tensor_handle->Tensor(&t));
    EXPECT_EQ(outputs.size(), 1);
    EXPECT_EQ(outputs.at(0).index(), 1);
    const TensorShape& shape = std::get<TensorShape>(outputs.at(0));
    EXPECT_EQ(shape, t->shape());
    CheckOutputTensorAndClose(*t);
  }

 protected:
  const string local_device_ = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string remote_device_ = "/job:localhost/replica:0/task:1/device:CPU:0";
  TestEagerServiceImpl eager_service_impl_;
  std::unique_ptr<DeviceMgr> remote_device_mgr_;
  uint64 context_id_;
  tensorflow::FunctionDef fdef_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> eager_pflr_;
  std::unique_ptr<EagerClusterFunctionLibraryRuntime> eager_cluster_flr_;
  FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
};

// Test executes a remote function through
// ProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
  Init();
  // Instantiate MatMulFunction on remote_device.
  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = remote_device_;
  options.is_multi_device_function = true;
  options.input_devices.push_back(local_device_);
  FunctionLibraryRuntime::Handle handle;
  EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
  TF_ASSERT_OK(eager_pflr_->Instantiate(
      fdef_.signature().name(), AttrSlice(&fdef_.attr()), options, &handle));
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
    const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
    EXPECT_TRUE(fdef != nullptr);
    if (absl::StartsWith(func_name, "MatMulFunction")) {
      EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
    }
  }
  bool is_cross_process = false;
  TF_CHECK_OK(eager_pflr_->IsCrossProcess(handle, &is_cross_process));
  EXPECT_TRUE(is_cross_process);

  // Run MatMulFunction on remote_device.
  FunctionLibraryRuntime::Options opts;
  const uint64 op_id = 2;
  opts.op_id = op_id;
  Notification done;
  Status status;
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> inputs = {input};
  std::vector<FunctionRet> outputs;
  gtl::InlinedVector<TensorValue, 4> tensor_args = {TensorValue()};
  TestExecuteNodeArgs args(
      std::move(tensor_args),
      [&inputs](const int i, RemoteTensorHandle* handle) -> Status {
        *handle = inputs.at(i);
        return OkStatus();
      });
  eager_pflr_->Run(opts, handle, args, &outputs,
                   [&status, &done](const Status& s) {
                     status = s;
                     done.Notify();
                   });
  done.WaitForNotification();
  TF_ASSERT_OK(status);
  CheckOutputsAndClose(outputs, op_id);
}

// Test executes a remote function with local input and output tensors.
TEST_F(FunctionWithRemoteInputsTest,
       EagerClusterFLRTestWithLocalInputAndOutput) {
  Init();
  // Instantiate MatMulFunction on remote_device.
  FunctionLibraryRuntime::Handle handle;
  EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
  Status status;
  Notification instantiate_done;
  eager_cluster_flr_->Instantiate(
      fdef_.signature().name(), func_lib_def_, AttrSlice(&fdef_.attr()),
      FunctionLibraryRuntime::InstantiateOptions(), &handle,
      [&status, &instantiate_done](const Status& s) {
        status = s;
        instantiate_done.Notify();
      });
  instantiate_done.WaitForNotification();
  TF_ASSERT_OK(status);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
    const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
    EXPECT_TRUE(fdef != nullptr);
    if (absl::StartsWith(func_name, "MatMulFunction")) {
      EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
    }
  }
  const tensorflow::Tensor* input_tensor = nullptr;
  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
      context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor));

  // Send input_tensor to the remote device, execute MatMulFunction on the
  // remote device, and send the output back.
  FunctionLibraryRuntime::Options opts;
  Notification execute_done;
  std::vector<Tensor> inputs = {*input_tensor};
  std::vector<Tensor> outputs;
  eager_cluster_flr_->Run(opts, handle, inputs, &outputs,
                          [&status, &execute_done](const Status& s) {
                            status = s;
                            execute_done.Notify();
                          });
  execute_done.WaitForNotification();
  TF_ASSERT_OK(status);
  EXPECT_EQ(outputs.size(), 1);
  CheckOutputTensorAndClose(outputs.at(0));
}

// Test executes a remote function through KernelAndDeviceFunc::Run.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
  Init();
  Device* local_device;
  TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
  std::vector<Device*> input_dev_ptrs;
  input_dev_ptrs.push_back(local_device);
  FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
  const int64_t op_id = 2;
  kernel.reset(new KernelAndDeviceFunc(
      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
      /*outputs_on_op_device=*/false,
      /*allow_small_function_optimizations=*/false,
      /*allow_control_flow_sync_execution=*/false,
      /*shape_inference_on_tfe_dialect_import=*/true,
      /*int_args_and_retvals_on_device=*/false,
      /*xla_compile_device_type=*/std::nullopt, ctx->RendezvousCreator(),
      [=]() { return op_id; }));

  // Instantiate MatMulFunction on remote_device.
  const NodeDef node_def = MatMulFunctionNodeDef();
  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));

  // Run MatMulFunction on remote_device.
  gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> remote_handles = {input};
  TestExecuteNodeArgs inputs(
      std::move(input_tensors),
      [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
        *handle = remote_handles.at(index);
        return OkStatus();
      });
  std::vector<FunctionRet> outputs;

  TF_ASSERT_OK(kernel->Run(/*step_container=*/nullptr, inputs, &outputs,
                           /*cancellation_manager=*/nullptr,
                           /*eager_func_params=*/std::nullopt,
                           /*stack_trace=*/std::nullopt,
                           /*coordination_service_agent=*/nullptr));

  CheckOutputsAndClose(outputs, op_id);
}

// Test executes a remote function through KernelAndDeviceFunc::RunAsync.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
  Init();
  Device* local_device;
  TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
  std::vector<Device*> input_dev_ptrs;
  input_dev_ptrs.push_back(local_device);
  FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
  const int64_t op_id = 2;
  kernel.reset(new KernelAndDeviceFunc(
      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
      /*outputs_on_op_device=*/false,
      /*allow_small_function_optimizations=*/false,
      /*allow_control_flow_sync_execution=*/false,
      /*shape_inference_on_tfe_dialect_import=*/true,
      /*int_args_and_retvals_on_device=*/false,
      /*xla_compile_device_type=*/std::nullopt, ctx->RendezvousCreator(),
      [=]() { return op_id; }));

  // Instantiate MatMulFunction on remote_device.
  const NodeDef node_def = MatMulFunctionNodeDef();
  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));

  // Run MatMulFunction on remote_device.
  gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> remote_handles = {input};
  TestExecuteNodeArgs inputs(
      std::move(input_tensors),
      [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
        *handle = remote_handles.at(index);
        return OkStatus();
      });
  std::vector<FunctionRet> outputs;

  Status status;
  Notification n;
  kernel->RunAsync(/*step_container=*/nullptr, inputs, &outputs,
                   /*cancellation_manager=*/nullptr,
                   /*eager_func_params=*/std::nullopt,
                   /*coordination_service_agent=*/nullptr,
                   [&status, &n](const Status& s) {
                     status = s;
                     n.Notify();
                   });
  n.WaitForNotification();
  TF_ASSERT_OK(status);
  CheckOutputsAndClose(outputs, op_id);
}

// Test creates a context and attempts to send a tensor (using the RPC), and
// then uses the tensor.
TEST_F(EagerServiceImplTest, SendTensorTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();

  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  std::unordered_map<string, AttrValue> attrs;
  AttrValue val;
  val.Clear();
  val.set_type(tensorflow::DataType::DT_FLOAT);
  attrs.insert({"T", val});
  val.Clear();
  val.set_b(false);
  attrs.insert({"transpose_a", val});
  attrs.insert({"transpose_b", val});

  AddOperationToEnqueueRequest(
      2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
      "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  const tensorflow::Tensor* t = nullptr;
  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&t));

  EXPECT_EQ(tensor_handle->device(), nullptr);

  auto actual = t->flat<float>();
  EXPECT_EQ(4, actual.size());

  EXPECT_EQ(7, actual(0));
  EXPECT_EQ(10, actual(1));
  EXPECT_EQ(15, actual(2));
  EXPECT_EQ(22, actual(3));

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

// Test serializes and sends a packed TensorHandle.
TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
  const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
  const string composite_device =
      "/job:localhost/replica:0/task:0/device:COMPOSITE:0";

  uint64 context_id = random::New64();
  CreateContextRequest request;
  auto* server_def = request.mutable_server_def();
  server_def->set_job_name("localhost");
  server_def->set_task_index(0);
  request.add_cluster_device_attributes()->set_name(device0);
  request.add_cluster_device_attributes()->set_name(device1);
  request.add_cluster_device_attributes()->set_name(device2);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  // Copy a tensor to device0
  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  // Copy a packed handle to device0
  auto* send_packed_handle =
      remote_enqueue_request.add_queue()->mutable_send_packed_handle();
  send_packed_handle->set_op_id(3);
  RemoteTensorHandle* remote_handle =
      send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(send_tensor->op_id());
  remote_handle->set_output_num(0);
  remote_handle->set_op_device(device0);
  remote_handle->set_device(device0);
  SendPackedHandleOp::LocalTensorHandle* local_handle =
      send_packed_handle->add_handles()->mutable_local_handle();
  SetTensorProto(local_handle->mutable_tensor());
  local_handle->set_device(device1);

  remote_handle = send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(2);
  remote_handle->set_output_num(5);
  remote_handle->set_op_device(device2);
  remote_handle->set_device(device2);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  tensorflow::TensorHandle* packed_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(3, 0), &packed_handle));

  EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
  EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
  EXPECT_EQ(packed_handle->device()->name(), composite_device);

  TensorHandle* handle0 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
  EXPECT_EQ(handle0->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle0->op_device()->name(), device0);
  const Tensor* t0 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t0));
  auto actual = t0->flat<float>();
  EXPECT_EQ(4, actual.size());
  EXPECT_EQ(1.0, actual(0));
  EXPECT_EQ(2.0, actual(1));
  EXPECT_EQ(3.0, actual(2));
  EXPECT_EQ(4.0, actual(3));

  TensorHandle* handle1 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(1, &handle1));
  EXPECT_EQ(handle1->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle1->op_device()->name(), device1);
  const Tensor* t1 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t1));
  EXPECT_EQ(t1, t0);

  TensorHandle* handle2 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(2, &handle2));
  EXPECT_EQ(handle2->Type(), TensorHandle::REMOTE);
  EXPECT_EQ(handle2->op_device()->name(), device2);
  int64_t op_id;
  int32_t output_num;
  TF_ASSERT_OK(handle2->RemoteAddress(handle2->device(),
                                      /*wait_until_ready=*/true, &op_id,
                                      &output_num));
  EXPECT_EQ(op_id, 2);
  EXPECT_EQ(output_num, 5);

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

// Test requests sent to the eager service on master.
TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
  tensorflow::Rendezvous* rendezvous =
      new tensorflow::IntraProcessRendezvous(device_mgr_.get());
  // Create a master eager context.
  tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      /*async=*/false, device_mgr_.get(), false, rendezvous);
  const uint64 context_id = random::New64();

  // Set RemoteMgr to ctx.
  auto remote_mgr =
      std::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true, ctx);
  TF_ASSERT_OK(ctx->InitializeRemoteWorker(
      /*remote_eager_workers=*/nullptr, /*remote_device_mgr=*/nullptr,
      /*remote_contexts=*/{}, context_id, /*context_view_id=*/0,
      /*rendezvous_creator=*/nullptr,
      /*cluster_flr=*/nullptr, std::move(remote_mgr),
      /*resource_deallocator=*/nullptr));

  TestEagerServiceImpl eager_service_impl(&worker_env_);

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  // Unable to handle the request since there is no eager context.
  Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                             &remote_enqueue_response);
  EXPECT_EQ(error::ABORTED, status.code());
  EXPECT_TRUE(absl::StrContains(
      status.error_message(),
      "Unable to find a context_id matching the specified one"));
  // The request can be handled after adding the master eager context to the
  // service.
  TF_ASSERT_OK(eager_service_impl.CreateMasterContext(context_id, ctx));
  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));
  ctx->Unref();
}

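// A context created with keep_alive_secs=3 should be cleaned up once its
// deadline passes, while a fresh context polled within the deadline stays
// alive.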
TEST_F(EagerServiceImplTest, KeepAliveTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();
  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  request.set_keep_alive_secs(3);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  worker_env_.env->SleepForMicroseconds(5 *
                                        tensorflow::EnvTime::kSecondsToMicros);

  KeepAliveRequest keep_alive_request;
  KeepAliveResponse keep_alive_response;

  keep_alive_request.set_context_id(context_id);

  Status status =
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);

  EXPECT_EQ(status.code(), error::ABORTED);
  EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
                      status.error_message());

  uint64 new_context_id = random::New64();
  // Create a new context.
  request.set_context_id(new_context_id);
  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  // The context should not be GC'd.
  worker_env_.env->SleepForMicroseconds(1 *
                                        tensorflow::EnvTime::kSecondsToMicros);

  keep_alive_request.set_context_id(new_context_id);

  TF_ASSERT_OK(
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response));
}

}  // namespace
}  // namespace eager
}  // namespace tensorflow