1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api_experimental.h"
17 #include "tensorflow/c/eager/c_api.h"
18 #include "tensorflow/c/eager/c_api_experimental.h"
19 #include "tensorflow/c/eager/c_api_internal.h"
20 #include "tensorflow/c/eager/c_api_test_util.h"
21 #include "tensorflow/c/eager/tfe_context_internal.h"
22 #include "tensorflow/c/eager/tfe_op_internal.h"
23 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
24 #include "tensorflow/core/common_runtime/eager/context.h"
25 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
26 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
27 #include "tensorflow/core/framework/device_attributes.pb.h"
28 #include "tensorflow/core/platform/blocking_counter.h"
29 #include "tensorflow/core/platform/casts.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/strcat.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/protobuf/cluster.pb.h"
35 #include "tensorflow/core/protobuf/coordination_config.pb.h"
36 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
37 
38 namespace {
39 
// Builds a FunctionDef named "SendFunction" with one DT_FLOAT input that runs
// the `_Send` op on `send_device`, addressed to `recv_device` under rendezvous
// key "dummy" with the given `send_device_incarnation`. The send node is
// exposed via the control output `send_tensor` so callers can depend on its
// completion. Returns the FunctionDef serialized as a binary proto string,
// suitable for TFE_ContextAddFunctionDef.
std::string SendFunction(const std::string& send_device,
                         const std::string& recv_device,
                         const tensorflow::int64 send_device_incarnation) {
  tensorflow::FunctionDef def;
  // The FunctionDef is assembled as textproto; device names and the
  // incarnation are spliced in via StrCat.
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      absl::StrCat("    signature {"
                   "      name: 'SendFunction'"
                   "      input_arg {"
                   "        name: 'in'"
                   "        type: DT_FLOAT"
                   "      }"
                   "      control_output: 'send_tensor'"
                   "    }"
                   "    node_def {"
                   "      name: 'send'"
                   "      op: '_Send'"
                   "      input: 'in'"
                   "      device: '",
                   send_device, "'",
                   "      attr {"
                   "        key: 'T'"
                   "        value {"
                   "          type: DT_FLOAT"
                   "        }"
                   "      }"
                   "      attr {"
                   "        key: 'tensor_name'"
                   "        value {"
                   "          s: 'dummy'"
                   "        }"
                   "      }"
                   "      attr {"
                   "        key: 'send_device'"
                   "        value {"
                   "          s: '",
                   send_device, "'",
                   "        }"
                   "      }"
                   "      attr {"
                   "        key: 'recv_device'"
                   "        value {"
                   "          s: '",
                   recv_device, "'",
                   "        }"
                   "      }"
                   "      attr {"
                   "        key: 'send_device_incarnation'"
                   "        value {"
                   "          i: ",
                   absl::StrCat(send_device_incarnation),
                   "        }"
                   "      }"
                   "    }"
                   "    control_ret {"
                   "      key: 'send_tensor'"
                   "      value: 'send'"
                   "    }"),
      &def));
  return def.SerializeAsString();
}
100 
// Builds a FunctionDef named "RecvFunction" with one DT_FLOAT output that runs
// the `_Recv` op on `recv_device`, receiving the tensor sent from
// `send_device` under rendezvous key "dummy" with the given
// `send_device_incarnation` (must match the values used by SendFunction).
// Returns the FunctionDef serialized as a binary proto string, suitable for
// TFE_ContextAddFunctionDef.
std::string RecvFunction(const std::string& send_device,
                         const std::string& recv_device,
                         const tensorflow::int64 send_device_incarnation) {
  tensorflow::FunctionDef def;
  // The FunctionDef is assembled as textproto; device names and the
  // incarnation are spliced in via StrCat.
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      absl::StrCat("    signature {"
                   "      name: 'RecvFunction'"
                   "      output_arg {"
                   "        name: 'out'"
                   "        type: DT_FLOAT"
                   "      }"
                   "    }"
                   "    node_def {"
                   "      name: 'recv'"
                   "      op: '_Recv'"
                   "      device: '",
                   recv_device, "'",
                   "      attr {"
                   "        key: 'tensor_type'"
                   "        value {"
                   "          type: DT_FLOAT"
                   "        }"
                   "      }"
                   "      attr {"
                   "        key: 'tensor_name'"
                   "        value {"
                   "          s: 'dummy'"
                   "        }"
                   "      }"
                   "      attr {"
                   "        key: 'send_device'"
                   "        value {"
                   "          s: '",
                   send_device, "'",
                   "        }"
                   "      }"
                   "      attr {"
                   "        key: 'recv_device'"
                   "        value {"
                   "          s: '",
                   recv_device, "'",
                   "        }"
                   "      }"
                   "      attr {"
                   "        key: 'send_device_incarnation'"
                   "        value {"
                   "          i: ",
                   absl::StrCat(send_device_incarnation),
                   "        }"
                   "      }"
                   "    }"
                   "    ret {"
                   "      key: 'out'"
                   "      value: 'recv:tensor'"
                   "    }"),
      &def));
  return def.SerializeAsString();
}
159 
DummyTensorHandleWithValue(TFE_Context * ctx,float v)160 TFE_TensorHandle* DummyTensorHandleWithValue(TFE_Context* ctx, float v) {
161   // Initialize matrix values.
162   int64_t dims[] = {2, 2};
163   float data[4];
164   for (int i = 0; i < 4; i++) {
165     data[i] = v * (i + 1);
166   }
167 
168   return TestTensorHandleWithDimsFloat(ctx, data, &dims[0],
169                                        sizeof(dims) / sizeof(int64_t));
170 }
171 
// Parameters for one TestMultiClientSendRecv case. NOTE: field order is
// load-bearing — the INSTANTIATE_TEST_SUITE_P table below uses positional
// aggregate initialization.
struct MultiClientSendRecvTestParams {
  std::string test_name;    // Used as the gtest parameterized test name.
  bool use_tfrt = false;    // Whether to run the test under the TFRT runtime.
  uint num_steps = 1;       // Number of send/recv function steps to execute.
  uint delay_recv_sec = 0;  // Seconds the receiver sleeps before starting.
  uint delay_send_sec = 0;  // Seconds the sender sleeps before starting.
};

using MultiClientSendRecvTest =
    testing::TestWithParam<MultiClientSendRecvTestParams>;
182 
TEST_P(MultiClientSendRecvTest,TestMultiClientSendRecv)183 TEST_P(MultiClientSendRecvTest, TestMultiClientSendRecv) {
184   const MultiClientSendRecvTestParams& params = GetParam();
185   // Use a mutex to enforce a serialized operation between the two
186   // worker-threads since some of their operations involve updating the global
187   // singleton instances (in TFRT scenario), which otherwise would cause a data
188   // race.
189   tensorflow::mutex mu;
190 
191   const int cluster_size = 2;
192   tensorflow::ServerDef server_def =
193       GetMultiClientServerDef("worker", cluster_size);
194 
195   // Enable coordination service for propagating remote device attributess
196   auto* coord_config = server_def.mutable_default_session_config()
197                            ->mutable_experimental()
198                            ->mutable_coordination_config();
199   coord_config->set_service_type("standalone");
200   coord_config->set_service_leader("/job:worker/replica:0/task:0");
201 
202   // The blocking counter makes sure that worker/0 thread (leader that starts
203   // the coordination service) does not exit early while other workers are still
204   // interacting with the coordination service.
205   tensorflow::BlockingCounter counter(cluster_size);
206 
207   auto worker_thread_fn = [&](int worker_id) {
208     tensorflow::ServerDef server_def_copy = server_def;
209     server_def_copy.set_task_index(worker_id);
210     std::string serialized = server_def_copy.SerializeAsString();
211 
212     TF_Status* status = TF_NewStatus();
213     TFE_ContextOptions* context_opts = TFE_NewContextOptions();
214     TFE_ContextOptionsSetAsync(context_opts,
215                                static_cast<unsigned char>(/*enable=*/true));
216     TFE_ContextOptionsSetDevicePlacementPolicy(context_opts,
217                                                TFE_DEVICE_PLACEMENT_SILENT);
218     // use-tfrt flag.
219     context_opts->use_tfrt = params.use_tfrt;
220     tensorflow::SessionOptions session_opts;
221     session_opts.config = server_def_copy.default_session_config();
222     context_opts->session_options.options = session_opts;
223 
224     TFE_Context* ctx;
225     {
226       tensorflow::mutex_lock l(mu);
227       ctx = TFE_NewContext(context_opts, status);
228     }
229     EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
230     TFE_DeleteContextOptions(context_opts);
231 
232     TFE_EnableCollectiveOps(ctx, serialized.data(), serialized.size(), status);
233     EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
234 
235     const std::string& send_device =
236         "/job:worker/replica:0/task:0/device:CPU:0";
237     const std::string& recv_device =
238         "/job:worker/replica:0/task:1/device:CPU:0";
239 
240     std::vector<tensorflow::DeviceAttributes> device_attrs;
241     tensorflow::unwrap(ctx)->ListDevices(&device_attrs);
242     tensorflow::uint64 send_device_incarnation = 0;
243     for (const auto& device_attr : device_attrs) {
244       if (device_attr.name() == send_device) {
245         send_device_incarnation = device_attr.incarnation();
246         break;
247       }
248     }
249 
250     if (worker_id == 0) {
251       // Sender worker.
252       tensorflow::Env::Default()->SleepForMicroseconds(params.delay_send_sec *
253                                                        1000);
254 
255       const std::string& fdef =
256           SendFunction(send_device, recv_device, send_device_incarnation);
257       TFE_ContextAddFunctionDef(ctx, fdef.data(), fdef.size(), status);
258       EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
259 
260       // Run multiple steps.
261       for (int s = 1; s <= params.num_steps; s++) {
262         TFE_Op* send_func = TFE_NewOp(ctx, "SendFunction", status);
263         EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
264 
265         if (params.use_tfrt) {
266           // TODO (@chienchunh): Add support for step id configuration in TFRT.
267           EXPECT_TRUE(tensorflow::unwrap(send_func)
268                           ->Reset("SendFunction", send_device.c_str())
269                           .ok());
270         } else {
271           tensorflow::EagerOperation* op =
272               tensorflow::OperationFromInterface(tensorflow::unwrap(send_func));
273           EXPECT_TRUE(op->Reset("SendFunction", send_device.c_str(),
274                                 /*remote=*/false, /*executor=*/nullptr,
275                                 tensorflow::EagerFunctionParams{
276                                     /*op_id=*/s, /*is_component_function=*/true,
277                                     /*step_id=*/s})
278                           .ok());
279         }
280 
281         TFE_TensorHandle* in = DummyTensorHandleWithValue(ctx, 1.0f * s);
282         TFE_OpAddInput(send_func, in, status);
283         EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
284         int num_retvals = 0;
285         {
286           tensorflow::mutex_lock l(mu);
287           TFE_Execute(send_func, nullptr, &num_retvals, status);
288         }
289         EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
290         TFE_DeleteOp(send_func);
291         TFE_DeleteTensorHandle(in);
292       }
293     } else {
294       // Receiver worker.
295       tensorflow::Env::Default()->SleepForMicroseconds(params.delay_recv_sec *
296                                                        1000);
297 
298       const std::string& fdef =
299           RecvFunction(send_device, recv_device, send_device_incarnation);
300       TFE_ContextAddFunctionDef(ctx, fdef.data(), fdef.size(), status);
301       EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
302 
303       // Run multiple steps.
304       for (int s = 1; s <= params.num_steps; s++) {
305         TFE_Op* recv_func = TFE_NewOp(ctx, "RecvFunction", status);
306         EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
307 
308         if (params.use_tfrt) {
309           // TODO (@chienchunh): Add support for step id configuration in TFRT.
310           EXPECT_TRUE(tensorflow::unwrap(recv_func)
311                           ->Reset("RecvFunction", recv_device.c_str())
312                           .ok());
313         } else {
314           tensorflow::EagerOperation* op =
315               tensorflow::OperationFromInterface(tensorflow::unwrap(recv_func));
316           EXPECT_TRUE(op->Reset("RecvFunction", recv_device.c_str(),
317                                 /*remote=*/false, /*executor=*/nullptr,
318                                 tensorflow::EagerFunctionParams{
319                                     /*op_id=*/s,
320                                     /*is_component_function=*/true,
321                                     /*step_id=*/s})
322                           .ok());
323         }
324 
325         TFE_TensorHandle* retvals[1];
326         int num_retvals = 1;
327         {
328           tensorflow::mutex_lock l(mu);
329           TFE_Execute(recv_func, &retvals[0], &num_retvals, status);
330         }
331         EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
332         TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
333         TFE_DeleteOp(recv_func);
334         TFE_DeleteTensorHandle(retvals[0]);
335 
336         float result[4] = {0};
337         EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
338         memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
339         TF_DeleteTensor(t);
340         for (int i = 0; i < 4; i++) {
341           EXPECT_EQ(result[i], 1.0 * s * (i + 1));
342         }
343       }
344     }
345 
346     // To make sure the sender won't delete the data it sent before the receiver
347     // retrieves it, we need to do the following steps:
348     // 1. Since we created async EagerContext, we need to force each worker to
349     //    wait until all pening operations finish before deleting the context.
350     // 2. In addition, use the blocking counter to notify the 2 workers when
351     //    it is safe to clean up all the data.
352     TFE_ContextAsyncWait(ctx, status);
353     EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
354     counter.DecrementCount();
355     counter.Wait();
356 
357     {
358       tensorflow::mutex_lock l(mu);
359       TFE_DeleteContext(ctx);
360     }
361     TF_DeleteStatus(status);
362   };
363 
364   std::thread thread_worker1([&] { worker_thread_fn(0); });
365   std::thread thread_worker2([&] { worker_thread_fn(1); });
366 
367   thread_worker1.join();
368   thread_worker2.join();
369 }
370 
// Instantiates the parameterized suite. Fields are positional:
// {test_name, use_tfrt, num_steps, delay_recv_sec, delay_send_sec}.
// Each configuration is run both with the default runtime and with TFRT.
INSTANTIATE_TEST_SUITE_P(
    MultiClientSendRecvTests, MultiClientSendRecvTest,
    testing::ValuesIn<MultiClientSendRecvTestParams>({
        {"MultiClientSingleStepFunction", false, 1, 0, 0},
        {"MultiClientMultiStepFunction", false, 3, 0, 0},
        {"MultiClientMultiStepFunctionWithRecvDelay", false, 5, 2, 0},
        {"MultiClientMultiStepFunctionWithSendDelay", false, 5, 0, 2},
        {"MultiClientSingleStepFunctionTfrt", true, 1, 0, 0},
        {"MultiClientMultiStepFunctionTfrt", true, 3, 0, 0},
        {"MultiClientMultiStepFunctionWithRecvDelayTfrt", true, 5, 2, 0},
        {"MultiClientMultiStepFunctionWithSendDelayTfrt", true, 5, 0, 2},
    }),
    // Name each test instance after the params' test_name field.
    [](const testing::TestParamInfo<MultiClientSendRecvTest::ParamType>& info) {
      return info.param.test_name;
    });
386 
387 }  // namespace
388