1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/c/eager/c_api_test_util.h"

#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/util/port.h"
30
31 using tensorflow::string;
32 using tensorflow::tstring;
33
TestScalarTensorHandle(TFE_Context * ctx,float value)34 TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
35 float data[] = {value};
36 TF_Status* status = TF_NewStatus();
37 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
38 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
39 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
40 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
41 TF_DeleteTensor(t);
42 TF_DeleteStatus(status);
43 return th;
44 }
45
TestScalarTensorHandle(TFE_Context * ctx,const tensorflow::tstring & value)46 TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
47 const tensorflow::tstring& value) {
48 TF_Status* status = TF_NewStatus();
49 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status);
50 tstring* data = static_cast<tstring*>(TF_TensorData(t));
51 *data = value;
52 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
53 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
54 TF_DeleteTensor(t);
55 TF_DeleteStatus(status);
56 return th;
57 }
58
TestScalarTensorHandle(TFE_Context * ctx,int value)59 TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
60 int data[] = {value};
61 TF_Status* status = TF_NewStatus();
62 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, nullptr, 0, status);
63 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
64 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
65 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
66 TF_DeleteTensor(t);
67 TF_DeleteStatus(status);
68 return th;
69 }
70
TestScalarTensorHandle(TFE_Context * ctx,bool value)71 TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value) {
72 bool data[] = {value};
73 TF_Status* status = TF_NewStatus();
74 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_BOOL, nullptr, 0, status);
75 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
76 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
77 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
78 TF_DeleteTensor(t);
79 TF_DeleteStatus(status);
80 return th;
81 }
82
DoubleTestMatrixTensorHandle(TFE_Context * ctx)83 TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx) {
84 int64_t dims[] = {2, 2};
85 double data[] = {1.0, 2.0, 3.0, 4.0};
86 TF_Status* status = TF_NewStatus();
87 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0],
88 sizeof(dims) / sizeof(int64_t), status);
89 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
90 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
91 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
92 TF_DeleteTensor(t);
93 TF_DeleteStatus(status);
94 return th;
95 }
96
TestMatrixTensorHandle(TFE_Context * ctx)97 TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
98 int64_t dims[] = {2, 2};
99 float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
100 TF_Status* status = TF_NewStatus();
101 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
102 sizeof(dims) / sizeof(int64_t), status);
103 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
104 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
105 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
106 TF_DeleteTensor(t);
107 TF_DeleteStatus(status);
108 return th;
109 }
110
TestMatrixTensorHandleWithInput(TFE_Context * ctx,float data[],int64_t dims[],int num_dims)111 TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
112 float data[], int64_t dims[],
113 int num_dims) {
114 TF_Status* status = TF_NewStatus();
115 TF_Tensor* t =
116 TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
117 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
118 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
119 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
120 TF_DeleteTensor(t);
121 TF_DeleteStatus(status);
122 return th;
123 }
124
TestTensorHandleWithDimsFloat(TFE_Context * ctx,float data[],int64_t dims[],int num_dims)125 TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
126 int64_t dims[], int num_dims) {
127 TF_Status* status = TF_NewStatus();
128 TF_Tensor* t =
129 TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
130 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
131 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
132 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
133 TF_DeleteTensor(t);
134 TF_DeleteStatus(status);
135 return th;
136 }
137
TestTensorHandleWithDimsInt(TFE_Context * ctx,int data[],int64_t dims[],int num_dims)138 TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
139 int64_t dims[], int num_dims) {
140 TF_Status* status = TF_NewStatus();
141 TF_Tensor* t =
142 TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
143 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
144 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
145 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
146 TF_DeleteTensor(t);
147 TF_DeleteStatus(status);
148 return th;
149 }
150
TestMatrixTensorHandle100x100(TFE_Context * ctx)151 TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
152 constexpr int64_t dims[] = {100, 100};
153 constexpr int num_elements = dims[0] * dims[1];
154 float data[num_elements];
155 for (int i = 0; i < num_elements; ++i) {
156 data[i] = 1.0f;
157 }
158 TF_Status* status = TF_NewStatus();
159 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
160 sizeof(dims) / sizeof(int64_t), status);
161 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
162 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
163 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
164 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
165 TF_DeleteTensor(t);
166 TF_DeleteStatus(status);
167 return th;
168 }
169
DoubleTestMatrixTensorHandle3X2(TFE_Context * ctx)170 TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx) {
171 int64_t dims[] = {3, 2};
172 double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
173 TF_Status* status = TF_NewStatus();
174 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
175 sizeof(dims) / sizeof(int64_t), status);
176 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
177 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
178 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
179 TF_DeleteTensor(t);
180 TF_DeleteStatus(status);
181 return th;
182 }
183
TestMatrixTensorHandle3X2(TFE_Context * ctx)184 TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
185 int64_t dims[] = {3, 2};
186 float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
187 TF_Status* status = TF_NewStatus();
188 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
189 sizeof(dims) / sizeof(int64_t), status);
190 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
191 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
192 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
193 TF_DeleteTensor(t);
194 TF_DeleteStatus(status);
195 return th;
196 }
197
TestVariable(TFE_Context * ctx,float value,const tensorflow::string & device_name)198 TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
199 const tensorflow::string& device_name) {
200 TF_Status* status = TF_NewStatus();
201 // Create the variable handle.
202 TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
203 if (TF_GetCode(status) != TF_OK) return nullptr;
204 TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
205 TFE_OpSetAttrShape(op, "shape", {}, 0, status);
206 TFE_OpSetAttrString(op, "container", "localhost", 0);
207 TFE_OpSetAttrString(op, "shared_name", "", 0);
208 if (!device_name.empty()) {
209 TFE_OpSetDevice(op, device_name.c_str(), status);
210 }
211 if (TF_GetCode(status) != TF_OK) return nullptr;
212 TFE_TensorHandle* var_handle = nullptr;
213 int num_retvals = 1;
214 TFE_Execute(op, &var_handle, &num_retvals, status);
215 if (TF_GetCode(status) != TF_OK) return nullptr;
216 TFE_DeleteOp(op);
217 if (TF_GetCode(status) != TF_OK) return nullptr;
218 CHECK_EQ(1, num_retvals);
219
220 // Assign 'value' to it.
221 op = TFE_NewOp(ctx, "AssignVariableOp", status);
222 if (TF_GetCode(status) != TF_OK) return nullptr;
223 TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
224 TFE_OpAddInput(op, var_handle, status);
225
226 // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
227 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
228 TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
229 memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
230
231 std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
232 value_handle(TFE_NewTensorHandle(t.get(), status),
233 TFE_DeleteTensorHandle);
234 if (TF_GetCode(status) != TF_OK) return nullptr;
235
236 TFE_OpAddInput(op, value_handle.get(), status);
237 if (TF_GetCode(status) != TF_OK) return nullptr;
238
239 num_retvals = 0;
240 TFE_Execute(op, nullptr, &num_retvals, status);
241 TFE_DeleteOp(op);
242 if (TF_GetCode(status) != TF_OK) return nullptr;
243 CHECK_EQ(0, num_retvals);
244
245 TF_DeleteStatus(status);
246
247 return var_handle;
248 }
249
AddOp(TFE_Context * ctx,TFE_TensorHandle * a,TFE_TensorHandle * b)250 TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
251 TF_Status* status = TF_NewStatus();
252
253 TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
254 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
255 TFE_OpAddInput(op, a, status);
256 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
257 TFE_OpAddInput(op, b, status);
258 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
259 TF_DeleteStatus(status);
260 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
261
262 return op;
263 }
264
MatMulOp(TFE_Context * ctx,TFE_TensorHandle * a,TFE_TensorHandle * b)265 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
266 TF_Status* status = TF_NewStatus();
267
268 TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
269 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
270 TFE_OpAddInput(op, a, status);
271 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
272 TFE_OpAddInput(op, b, status);
273 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
274 TF_DeleteStatus(status);
275 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
276
277 return op;
278 }
279
IdentityOp(TFE_Context * ctx,TFE_TensorHandle * a)280 TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a) {
281 TF_Status* status = TF_NewStatus();
282
283 TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
284 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
285 TFE_OpAddInput(op, a, status);
286 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
287 TF_DeleteStatus(status);
288 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
289
290 return op;
291 }
292
ShapeOp(TFE_Context * ctx,TFE_TensorHandle * a)293 TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
294 TF_Status* status = TF_NewStatus();
295
296 TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
297 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
298 TFE_OpAddInput(op, a, status);
299 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
300 TF_DeleteStatus(status);
301 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
302
303 return op;
304 }
305
TestAxisTensorHandle(TFE_Context * ctx)306 TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx) {
307 int64_t dims[] = {1};
308 int data[] = {1};
309 TF_Status* status = TF_NewStatus();
310 TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0],
311 sizeof(dims) / sizeof(int64_t), status);
312 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
313 TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
314 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
315 TF_DeleteTensor(t);
316 TF_DeleteStatus(status);
317 return th;
318 }
319
MinOp(TFE_Context * ctx,TFE_TensorHandle * input,TFE_TensorHandle * axis)320 TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
321 TFE_TensorHandle* axis) {
322 TF_Status* status = TF_NewStatus();
323
324 TFE_Op* op = TFE_NewOp(ctx, "Min", status);
325 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
326 TFE_OpAddInput(op, input, status);
327 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
328 TFE_OpAddInput(op, axis, status);
329 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
330 TFE_OpSetAttrBool(op, "keep_dims", 1);
331 TFE_OpSetAttrType(op, "Tidx", TF_INT32);
332 TF_DeleteStatus(status);
333 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
334
335 return op;
336 }
337
AllReduceOp(TFE_Context * ctx,TFE_TensorHandle * in,int group_size)338 TFE_Op* AllReduceOp(TFE_Context* ctx, TFE_TensorHandle* in, int group_size) {
339 TF_Status* status = TF_NewStatus();
340
341 TFE_Op* op = TFE_NewOp(ctx, "CollectiveReduce", status);
342 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
343 TFE_OpAddInput(op, in, status);
344 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
345 TF_DeleteStatus(status);
346
347 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(in));
348 TFE_OpSetAttrInt(op, "group_size", group_size);
349 TFE_OpSetAttrInt(op, "group_key", 123);
350 TFE_OpSetAttrInt(op, "instance_key", 456);
351 TFE_OpSetAttrString(op, "merge_op", "Add", 3);
352 TFE_OpSetAttrString(op, "final_op", "Id", 2);
353 std::vector<int64_t> subdiv_offsets;
354 TFE_OpSetAttrIntList(op, "subdiv_offsets", subdiv_offsets.data(),
355 subdiv_offsets.size());
356
357 return op;
358 }
359
SendOp(TFE_Context * ctx,TFE_TensorHandle * in,const std::string & op_name,const std::string & send_device,const std::string & recv_device,tensorflow::uint64 send_device_incarnation)360 TFE_Op* SendOp(TFE_Context* ctx, TFE_TensorHandle* in,
361 const std::string& op_name, const std::string& send_device,
362 const std::string& recv_device,
363 tensorflow::uint64 send_device_incarnation) {
364 TF_Status* status = TF_NewStatus();
365 TFE_Op* op = TFE_NewOp(ctx, op_name.c_str(), status);
366 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
367 TFE_OpAddInput(op, in, status);
368 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
369 TF_DeleteStatus(status);
370
371 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(in));
372 TFE_OpSetAttrString(op, "tensor_name", "dummy", 5);
373 TFE_OpSetAttrString(op, "send_device", send_device.c_str(),
374 send_device.size());
375 TFE_OpSetAttrString(op, "recv_device", recv_device.c_str(),
376 recv_device.size());
377 TFE_OpSetAttrInt(op, "send_device_incarnation", send_device_incarnation);
378
379 return op;
380 }
381
RecvOp(TFE_Context * ctx,const std::string & op_name,const std::string & send_device,const std::string & recv_device,tensorflow::uint64 send_device_incarnation)382 TFE_Op* RecvOp(TFE_Context* ctx, const std::string& op_name,
383 const std::string& send_device, const std::string& recv_device,
384 tensorflow::uint64 send_device_incarnation) {
385 TF_Status* status = TF_NewStatus();
386 TFE_Op* op = TFE_NewOp(ctx, op_name.c_str(), status);
387 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
388 TF_DeleteStatus(status);
389
390 TFE_OpSetAttrType(op, "tensor_type", TF_INT32);
391 TFE_OpSetAttrString(op, "tensor_name", "dummy", 5);
392 TFE_OpSetAttrString(op, "send_device", send_device.c_str(),
393 send_device.size());
394 TFE_OpSetAttrString(op, "recv_device", recv_device.c_str(),
395 recv_device.size());
396 TFE_OpSetAttrInt(op, "send_device_incarnation", send_device_incarnation);
397
398 return op;
399 }
400
GetDeviceName(TFE_Context * ctx,string * device_name,const char * device_type)401 bool GetDeviceName(TFE_Context* ctx, string* device_name,
402 const char* device_type) {
403 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
404 TF_NewStatus(), TF_DeleteStatus);
405 TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
406 CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
407
408 const int num_devices = TF_DeviceListCount(devices);
409 for (int i = 0; i < num_devices; ++i) {
410 const string dev_type(TF_DeviceListType(devices, i, status.get()));
411 CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
412 const string dev_name(TF_DeviceListName(devices, i, status.get()));
413 CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
414 if (dev_type == device_type) {
415 *device_name = dev_name;
416 LOG(INFO) << "Found " << device_type << " device " << *device_name;
417 TF_DeleteDeviceList(devices);
418 return true;
419 }
420 }
421 TF_DeleteDeviceList(devices);
422 return false;
423 }
424
GetServerDef(const string & job_name,int num_tasks)425 tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
426 tensorflow::ServerDef server_def;
427 server_def.set_protocol("grpc");
428 server_def.set_job_name(job_name);
429 server_def.set_task_index(0);
430 tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
431 tensorflow::JobDef* job_def = cluster_def->add_job();
432 job_def->set_name(job_name);
433 for (int i = 0; i < num_tasks; i++) {
434 int port = tensorflow::testing::PickUnusedPortOrDie();
435 job_def->mutable_tasks()->insert(
436 {i, tensorflow::strings::StrCat("localhost:", port)});
437 }
438 return server_def;
439 }
440
GetServerDef(int num_tasks)441 tensorflow::ServerDef GetServerDef(int num_tasks) {
442 return GetServerDef("localhost", num_tasks);
443 }
444
GetMultiClientServerDef(const std::string & job_name,int num_tasks,int num_virtual_gpus)445 tensorflow::ServerDef GetMultiClientServerDef(const std::string& job_name,
446 int num_tasks,
447 int num_virtual_gpus) {
448 tensorflow::ServerDef server_def;
449 server_def.set_protocol("grpc");
450 server_def.set_job_name(job_name);
451 server_def.set_task_index(0);
452 tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
453 tensorflow::JobDef* job_def = cluster_def->add_job();
454 job_def->set_name(job_name);
455 for (int i = 0; i < num_tasks; i++) {
456 int port = tensorflow::testing::PickUnusedPortOrDie();
457 job_def->mutable_tasks()->insert(
458 {i, tensorflow::strings::StrCat("localhost:", port)});
459 }
460 auto* config = server_def.mutable_default_session_config();
461 config->mutable_experimental()->set_collective_group_leader(
462 tensorflow::strings::StrCat("/job:", job_name, "/replica:0/task:", 0));
463 auto* rewrite_options =
464 config->mutable_graph_options()->mutable_rewrite_options();
465 rewrite_options->set_scoped_allocator_optimization(
466 tensorflow::RewriterConfig::ON);
467 rewrite_options->mutable_scoped_allocator_opts()->add_enable_op(
468 "CollectiveReduce");
469
470 if ((tensorflow::IsGoogleCudaEnabled() || tensorflow::IsBuiltWithROCm()) &&
471 num_virtual_gpus > 0) {
472 tensorflow::GPUOptions* gpu_options =
473 server_def.mutable_default_session_config()->mutable_gpu_options();
474 auto virtual_devices =
475 gpu_options->mutable_experimental()->add_virtual_devices();
476 for (int i = 0; i < num_virtual_gpus; ++i) {
477 virtual_devices->add_memory_limit_mb(200);
478 }
479 }
480 return server_def;
481 }
482