/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace {

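// Returns a serialized FunctionDef containing a single _Send op that sends a
// float tensor named 'dummy' from `send_device` to `recv_device`.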
std::string SendFunction(const std::string& send_device,
                         const std::string& recv_device,
                         const tensorflow::int64 send_device_incarnation) {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      absl::StrCat(" signature {"
                   " name: 'SendFunction'"
                   " input_arg {"
                   " name: 'in'"
                   " type: DT_FLOAT"
                   " }"
                   " control_output: 'send_tensor'"
                   " }"
                   " node_def {"
                   " name: 'send'"
                   " op: '_Send'"
                   " input: 'in'"
                   " device: '",
                   send_device, "'",
                   " attr {"
                   " key: 'T'"
                   " value {"
                   " type: DT_FLOAT"
                   " }"
                   " }"
                   " attr {"
                   " key: 'tensor_name'"
                   " value {"
                   " s: 'dummy'"
                   " }"
                   " }"
                   " attr {"
                   " key: 'send_device'"
                   " value {"
                   " s: '",
                   send_device, "'",
                   " }"
                   " }"
                   " attr {"
                   " key: 'recv_device'"
                   " value {"
                   " s: '",
                   recv_device, "'",
                   " }"
                   " }"
                   " attr {"
                   " key: 'send_device_incarnation'"
                   " value {"
                   " i: ",
                   absl::StrCat(send_device_incarnation),
                   " }"
                   " }"
                   " }"
                   " control_ret {"
                   " key: 'send_tensor'"
                   " value: 'send'"
                   " }"),
      &def));
  return def.SerializeAsString();
}

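// Returns a serialized FunctionDef containing a single _Recv op that receives
// the 'dummy' float tensor sent by SendFunction and returns it as output
// 'out'.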
std::string RecvFunction(const std::string& send_device,
                         const std::string& recv_device,
                         const tensorflow::int64 send_device_incarnation) {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      absl::StrCat(" signature {"
                   " name: 'RecvFunction'"
                   " output_arg {"
                   " name: 'out'"
                   " type: DT_FLOAT"
                   " }"
                   " }"
                   " node_def {"
                   " name: 'recv'"
                   " op: '_Recv'"
                   " device: '",
                   recv_device, "'",
                   " attr {"
                   " key: 'tensor_type'"
                   " value {"
                   " type: DT_FLOAT"
                   " }"
                   " }"
                   " attr {"
                   " key: 'tensor_name'"
                   " value {"
                   " s: 'dummy'"
                   " }"
                   " }"
                   " attr {"
                   " key: 'send_device'"
                   " value {"
                   " s: '",
                   send_device, "'",
                   " }"
                   " }"
                   " attr {"
                   " key: 'recv_device'"
                   " value {"
                   " s: '",
                   recv_device, "'",
                   " }"
                   " }"
                   " attr {"
                   " key: 'send_device_incarnation'"
                   " value {"
                   " i: ",
                   absl::StrCat(send_device_incarnation),
                   " }"
                   " }"
                   " }"
                   " ret {"
                   " key: 'out'"
                   " value: 'recv:tensor'"
                   " }"),
      &def));
  return def.SerializeAsString();
}

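// Returns a 2x2 float tensor handle whose elements are v, 2v, 3v, and 4v.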
TFE_TensorHandle* DummyTensorHandleWithValue(TFE_Context* ctx, float v) {
  // Initialize matrix values.
  int64_t dims[] = {2, 2};
  float data[4];
  for (int i = 0; i < 4; i++) {
    data[i] = v * (i + 1);
  }

  return TestTensorHandleWithDimsFloat(ctx, data, &dims[0],
                                       sizeof(dims) / sizeof(int64_t));
}

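// Test parameters: whether to run on TFRT, how many send/recv steps to run,
// and optional delays (in seconds) applied before the receiver or the sender
// starts.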
struct MultiClientSendRecvTestParams {
  std::string test_name;
  bool use_tfrt = false;
  uint num_steps = 1;
  uint delay_recv_sec = 0;
  uint delay_send_sec = 0;
};

using MultiClientSendRecvTest =
    testing::TestWithParam<MultiClientSendRecvTestParams>;

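// Spawns two worker threads: worker/0 repeatedly executes SendFunction and
// worker/1 repeatedly executes RecvFunction, with the receiver verifying the
// tensor values produced by the sender at each step.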
TEST_P(MultiClientSendRecvTest, TestMultiClientSendRecv) {
  const MultiClientSendRecvTestParams& params = GetParam();
  // Use a mutex to serialize operations between the two worker threads, since
  // some of their operations update global singleton instances (in the TFRT
  // scenario), which would otherwise cause a data race.
  tensorflow::mutex mu;

  const int cluster_size = 2;
  tensorflow::ServerDef server_def =
      GetMultiClientServerDef("worker", cluster_size);

  // Enable coordination service for propagating remote device attributes.
  auto* coord_config = server_def.mutable_default_session_config()
                           ->mutable_experimental()
                           ->mutable_coordination_config();
  coord_config->set_service_type("standalone");
  coord_config->set_service_leader("/job:worker/replica:0/task:0");

  // The blocking counter makes sure that the worker/0 thread (the leader that
  // starts the coordination service) does not exit early while other workers
  // are still interacting with the coordination service.
  tensorflow::BlockingCounter counter(cluster_size);

  auto worker_thread_fn = [&](int worker_id) {
    tensorflow::ServerDef server_def_copy = server_def;
    server_def_copy.set_task_index(worker_id);
    std::string serialized = server_def_copy.SerializeAsString();

    TF_Status* status = TF_NewStatus();
    TFE_ContextOptions* context_opts = TFE_NewContextOptions();
    TFE_ContextOptionsSetAsync(context_opts,
                               static_cast<unsigned char>(/*enable=*/true));
    TFE_ContextOptionsSetDevicePlacementPolicy(context_opts,
                                               TFE_DEVICE_PLACEMENT_SILENT);
    // Set the use_tfrt flag from the test parameter.
    context_opts->use_tfrt = params.use_tfrt;
    tensorflow::SessionOptions session_opts;
    session_opts.config = server_def_copy.default_session_config();
    context_opts->session_options.options = session_opts;

    TFE_Context* ctx;
    {
      tensorflow::mutex_lock l(mu);
      ctx = TFE_NewContext(context_opts, status);
    }
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteContextOptions(context_opts);

    TFE_EnableCollectiveOps(ctx, serialized.data(), serialized.size(), status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    const std::string& send_device =
        "/job:worker/replica:0/task:0/device:CPU:0";
    const std::string& recv_device =
        "/job:worker/replica:0/task:1/device:CPU:0";

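    // Look up the incarnation of the send device so that the generated
    // FunctionDefs on both workers reference the same device instance.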
    std::vector<tensorflow::DeviceAttributes> device_attrs;
    tensorflow::unwrap(ctx)->ListDevices(&device_attrs);
    tensorflow::uint64 send_device_incarnation = 0;
    for (const auto& device_attr : device_attrs) {
      if (device_attr.name() == send_device) {
        send_device_incarnation = device_attr.incarnation();
        break;
      }
    }

    if (worker_id == 0) {
      // Sender worker.
      tensorflow::Env::Default()->SleepForMicroseconds(params.delay_send_sec *
                                                       1000);

      const std::string& fdef =
          SendFunction(send_device, recv_device, send_device_incarnation);
      TFE_ContextAddFunctionDef(ctx, fdef.data(), fdef.size(), status);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

      // Run multiple steps.
      for (int s = 1; s <= params.num_steps; s++) {
        TFE_Op* send_func = TFE_NewOp(ctx, "SendFunction", status);
        EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

        if (params.use_tfrt) {
          // TODO(chienchunh): Add support for step id configuration in TFRT.
          EXPECT_TRUE(tensorflow::unwrap(send_func)
                          ->Reset("SendFunction", send_device.c_str())
                          .ok());
        } else {
          tensorflow::EagerOperation* op =
              tensorflow::OperationFromInterface(tensorflow::unwrap(send_func));
          EXPECT_TRUE(op->Reset("SendFunction", send_device.c_str(),
                                /*remote=*/false, /*executor=*/nullptr,
                                tensorflow::EagerFunctionParams{
                                    /*op_id=*/s, /*is_component_function=*/true,
                                    /*step_id=*/s})
                          .ok());
        }

        TFE_TensorHandle* in = DummyTensorHandleWithValue(ctx, 1.0f * s);
        TFE_OpAddInput(send_func, in, status);
        EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
        int num_retvals = 0;
        {
          tensorflow::mutex_lock l(mu);
          TFE_Execute(send_func, nullptr, &num_retvals, status);
        }
        EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
        TFE_DeleteOp(send_func);
        TFE_DeleteTensorHandle(in);
      }
    } else {
      // Receiver worker.
      tensorflow::Env::Default()->SleepForMicroseconds(params.delay_recv_sec *
                                                       1000);

      const std::string& fdef =
          RecvFunction(send_device, recv_device, send_device_incarnation);
      TFE_ContextAddFunctionDef(ctx, fdef.data(), fdef.size(), status);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

      // Run multiple steps.
      for (int s = 1; s <= params.num_steps; s++) {
        TFE_Op* recv_func = TFE_NewOp(ctx, "RecvFunction", status);
        EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

        if (params.use_tfrt) {
          // TODO(chienchunh): Add support for step id configuration in TFRT.
          EXPECT_TRUE(tensorflow::unwrap(recv_func)
                          ->Reset("RecvFunction", recv_device.c_str())
                          .ok());
        } else {
          tensorflow::EagerOperation* op =
              tensorflow::OperationFromInterface(tensorflow::unwrap(recv_func));
          EXPECT_TRUE(op->Reset("RecvFunction", recv_device.c_str(),
                                /*remote=*/false, /*executor=*/nullptr,
                                tensorflow::EagerFunctionParams{
                                    /*op_id=*/s,
                                    /*is_component_function=*/true,
                                    /*step_id=*/s})
                          .ok());
        }

        TFE_TensorHandle* retvals[1];
        int num_retvals = 1;
        {
          tensorflow::mutex_lock l(mu);
          TFE_Execute(recv_func, &retvals[0], &num_retvals, status);
        }
        EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
        TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
        TFE_DeleteOp(recv_func);
        TFE_DeleteTensorHandle(retvals[0]);

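        // The sender fills the 2x2 tensor with 1.0 * s * (i + 1) at step s;
        // verify that the received values match.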
        float result[4] = {0};
        EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
        memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
        TF_DeleteTensor(t);
        for (int i = 0; i < 4; i++) {
          EXPECT_EQ(result[i], 1.0 * s * (i + 1));
        }
      }
    }

    // To make sure the sender does not delete the data it sent before the
    // receiver retrieves it, we need to do the following:
    // 1. Since we created an async EagerContext, force each worker to wait
    //    until all pending operations finish before deleting the context.
    // 2. In addition, use the blocking counter to notify the two workers when
    //    it is safe to clean up all the data.
    TFE_ContextAsyncWait(ctx, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    counter.DecrementCount();
    counter.Wait();

    {
      tensorflow::mutex_lock l(mu);
      TFE_DeleteContext(ctx);
    }
    TF_DeleteStatus(status);
  };

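  // Run both workers concurrently: worker 0 sends, worker 1 receives.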
  std::thread thread_worker1([&] { worker_thread_fn(0); });
  std::thread thread_worker2([&] { worker_thread_fn(1); });

  thread_worker1.join();
  thread_worker2.join();
}

INSTANTIATE_TEST_SUITE_P(
    MultiClientSendRecvTests, MultiClientSendRecvTest,
    testing::ValuesIn<MultiClientSendRecvTestParams>({
        {"MultiClientSingleStepFunction", false, 1, 0, 0},
        {"MultiClientMultiStepFunction", false, 3, 0, 0},
        {"MultiClientMultiStepFunctionWithRecvDelay", false, 5, 2, 0},
        {"MultiClientMultiStepFunctionWithSendDelay", false, 5, 0, 2},
        {"MultiClientSingleStepFunctionTfrt", true, 1, 0, 0},
        {"MultiClientMultiStepFunctionTfrt", true, 3, 0, 0},
        {"MultiClientMultiStepFunctionWithRecvDelayTfrt", true, 5, 2, 0},
        {"MultiClientMultiStepFunctionWithSendDelayTfrt", true, 5, 0, 2},
    }),
    [](const testing::TestParamInfo<MultiClientSendRecvTest::ParamType>& info) {
      return info.param.test_name;
    });

}  // namespace