/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api_test_util.h"

#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/util/port.h"

using tensorflow::string;
using tensorflow::tstring;

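// Each TestScalarTensorHandle overload below returns a handle to a
// host-resident scalar tensor of the corresponding dtype (TF_FLOAT, TF_STRING,
// TF_INT32, or TF_BOOL) holding `value`. The caller owns the returned handle.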
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
  float data[] = {value};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}
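
// Example usage (sketch): `ctx` is assumed to come from TFE_NewContext
// elsewhere in the test, and the handle must be released by the caller.
//   TFE_TensorHandle* h = TestScalarTensorHandle(ctx, 3.0f);
//   ... use h as an op input ...
//   TFE_DeleteTensorHandle(h);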

TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
                                         const tensorflow::tstring& value) {
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status);
  tstring* data = static_cast<tstring*>(TF_TensorData(t));
  *data = value;
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
  int data[] = {value};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, nullptr, 0, status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value) {
  bool data[] = {value};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_BOOL, nullptr, 0, status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

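// Returns a handle to a 2x2 TF_DOUBLE matrix [[1, 2], [3, 4]].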
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx) {
  int64_t dims[] = {2, 2};
  double data[] = {1.0, 2.0, 3.0, 4.0};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

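// Returns a handle to a 2x2 TF_FLOAT matrix [[1, 2], [3, 4]].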
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

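// Builds a TF_FLOAT tensor handle of shape `dims` (rank `num_dims`) from the
// caller-provided `data`.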
TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
                                                  float data[], int64_t dims[],
                                                  int num_dims) {
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
                                                int64_t dims[], int num_dims) {
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
                                              int64_t dims[], int num_dims) {
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

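// Returns a handle to a 100x100 TF_FLOAT matrix filled with 1.0f.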
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
  constexpr int64_t dims[] = {100, 100};
  constexpr int num_elements = dims[0] * dims[1];
  float data[num_elements];
  for (int i = 0; i < num_elements; ++i) {
    data[i] = 1.0f;
  }
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx) {
  int64_t dims[] = {3, 2};
  double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
  TF_Status* status = TF_NewStatus();
  // Allocate as TF_DOUBLE so the tensor's byte size matches the double data.
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
  int64_t dims[] = {3, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

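// Creates a float resource variable (VarHandleOp), optionally placed on
// `device_name`, assigns `value` to it via AssignVariableOp, and returns the
// variable handle. Returns nullptr if any step fails.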
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
                               const tensorflow::string& device_name) {
  TF_Status* status = TF_NewStatus();
  // Create the variable handle.
  TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(op, "shape", {}, 0, status);
  TFE_OpSetAttrString(op, "container", "localhost", 0);
  TFE_OpSetAttrString(op, "shared_name", "", 0);
  if (!device_name.empty()) {
    TFE_OpSetDevice(op, device_name.c_str(), status);
  }
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_TensorHandle* var_handle = nullptr;
  int num_retvals = 1;
  TFE_Execute(op, &var_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_DeleteOp(op);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  CHECK_EQ(1, num_retvals);

  // Assign 'value' to it.
  op = TFE_NewOp(ctx, "AssignVariableOp", status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
  TFE_OpAddInput(op, var_handle, status);

  // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
      TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
  memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));

  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      value_handle(TFE_NewTensorHandle(t.get(), status),
                   TFE_DeleteTensorHandle);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_OpAddInput(op, value_handle.get(), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  num_retvals = 0;
  TFE_Execute(op, nullptr, &num_retvals, status);
  TFE_DeleteOp(op);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  CHECK_EQ(0, num_retvals);

  TF_DeleteStatus(status);

  return var_handle;
}

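// The op builders below (AddOp, MatMulOp, IdentityOp, ShapeOp) return a
// ready-to-execute TFE_Op with its inputs attached and the "T" attribute set
// from the first input's dtype. The caller owns the returned op.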
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
  TF_Status* status = TF_NewStatus();

  TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, a, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, b, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));

  return op;
}
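
// Example usage (sketch): `ctx`, `a`, `b`, and a TF_Status `status` are
// assumed to exist already, and error handling is elided.
//   TFE_Op* add = AddOp(ctx, a, b);
//   TFE_TensorHandle* retvals[1];
//   int num_retvals = 1;
//   TFE_Execute(add, &retvals[0], &num_retvals, status);
//   TFE_DeleteOp(add);
//   TFE_DeleteTensorHandle(retvals[0]);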

TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
  TF_Status* status = TF_NewStatus();

  TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, a, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, b, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));

  return op;
}

TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a) {
  TF_Status* status = TF_NewStatus();

  TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, a, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));

  return op;
}

TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
  TF_Status* status = TF_NewStatus();

  TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, a, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));

  return op;
}

TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx) {
  int64_t dims[] = {1};
  int data[] = {1};
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0],
                                        sizeof(dims) / sizeof(int64_t), status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

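// Returns a Min reduction op over `input` along `axis`, with keep_dims=true.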
TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
              TFE_TensorHandle* axis) {
  TF_Status* status = TF_NewStatus();

  TFE_Op* op = TFE_NewOp(ctx, "Min", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, input, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, axis, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrBool(op, "keep_dims", 1);
  TFE_OpSetAttrType(op, "Tidx", TF_INT32);
  TF_DeleteStatus(status);
  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));

  return op;
}

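// Returns a CollectiveReduce op over `in` for a group of `group_size`
// participants, using fixed group/instance keys and Add/Id as the merge and
// final ops.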
TFE_Op* AllReduceOp(TFE_Context* ctx, TFE_TensorHandle* in, int group_size) {
  TF_Status* status = TF_NewStatus();

  TFE_Op* op = TFE_NewOp(ctx, "CollectiveReduce", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, in, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(in));
  TFE_OpSetAttrInt(op, "group_size", group_size);
  TFE_OpSetAttrInt(op, "group_key", 123);
  TFE_OpSetAttrInt(op, "instance_key", 456);
  TFE_OpSetAttrString(op, "merge_op", "Add", 3);
  TFE_OpSetAttrString(op, "final_op", "Id", 2);
  std::vector<int64_t> subdiv_offsets;
  TFE_OpSetAttrIntList(op, "subdiv_offsets", subdiv_offsets.data(),
                       subdiv_offsets.size());

  return op;
}

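// Returns a send-style op named `op_name` (e.g. "_Send") that sends `in` from
// `send_device` to `recv_device` under the tensor name "dummy".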
TFE_Op* SendOp(TFE_Context* ctx, TFE_TensorHandle* in,
               const std::string& op_name, const std::string& send_device,
               const std::string& recv_device,
               tensorflow::uint64 send_device_incarnation) {
  TF_Status* status = TF_NewStatus();
  TFE_Op* op = TFE_NewOp(ctx, op_name.c_str(), status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, in, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(in));
  TFE_OpSetAttrString(op, "tensor_name", "dummy", 5);
  TFE_OpSetAttrString(op, "send_device", send_device.c_str(),
                      send_device.size());
  TFE_OpSetAttrString(op, "recv_device", recv_device.c_str(),
                      recv_device.size());
  TFE_OpSetAttrInt(op, "send_device_incarnation", send_device_incarnation);

  return op;
}

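// Returns the matching receive-style op named `op_name` (e.g. "_Recv") that
// expects an int32 tensor named "dummy" sent from `send_device` to
// `recv_device`.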
TFE_Op* RecvOp(TFE_Context* ctx, const std::string& op_name,
               const std::string& send_device, const std::string& recv_device,
               tensorflow::uint64 send_device_incarnation) {
  TF_Status* status = TF_NewStatus();
  TFE_Op* op = TFE_NewOp(ctx, op_name.c_str(), status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  TFE_OpSetAttrType(op, "tensor_type", TF_INT32);
  TFE_OpSetAttrString(op, "tensor_name", "dummy", 5);
  TFE_OpSetAttrString(op, "send_device", send_device.c_str(),
                      send_device.size());
  TFE_OpSetAttrString(op, "recv_device", recv_device.c_str(),
                      recv_device.size());
  TFE_OpSetAttrInt(op, "send_device_incarnation", send_device_incarnation);

  return op;
}

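// Scans the context's device list for a device whose type matches
// `device_type` (e.g. "GPU"). If found, stores its full name in *device_name
// and returns true; otherwise returns false.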
bool GetDeviceName(TFE_Context* ctx, string* device_name,
                   const char* device_type) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
  CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  const int num_devices = TF_DeviceListCount(devices);
  for (int i = 0; i < num_devices; ++i) {
    const string dev_type(TF_DeviceListType(devices, i, status.get()));
    CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    const string dev_name(TF_DeviceListName(devices, i, status.get()));
    CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    if (dev_type == device_type) {
      *device_name = dev_name;
      LOG(INFO) << "Found " << device_type << " device " << *device_name;
      TF_DeleteDeviceList(devices);
      return true;
    }
  }
  TF_DeleteDeviceList(devices);
  return false;
}

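// Returns a ServerDef for a single-job gRPC cluster named `job_name` with
// `num_tasks` tasks, each bound to a freshly picked local port. The returned
// def has task_index set to 0.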
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name(job_name);
  server_def.set_task_index(0);
  tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
  tensorflow::JobDef* job_def = cluster_def->add_job();
  job_def->set_name(job_name);
  for (int i = 0; i < num_tasks; i++) {
    int port = tensorflow::testing::PickUnusedPortOrDie();
    job_def->mutable_tasks()->insert(
        {i, tensorflow::strings::StrCat("localhost:", port)});
  }
  return server_def;
}

tensorflow::ServerDef GetServerDef(int num_tasks) {
  return GetServerDef("localhost", num_tasks);
}

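// Like GetServerDef, but also configures the default session for multi-client
// collectives: task 0 is the collective group leader, the scoped allocator
// optimization is enabled for CollectiveReduce, and on CUDA/ROCm builds the
// first visible GPU is split into `num_virtual_gpus` virtual devices with a
// 200 MB memory limit each (when `num_virtual_gpus` > 0).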
tensorflow::ServerDef GetMultiClientServerDef(const std::string& job_name,
                                              int num_tasks,
                                              int num_virtual_gpus) {
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name(job_name);
  server_def.set_task_index(0);
  tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
  tensorflow::JobDef* job_def = cluster_def->add_job();
  job_def->set_name(job_name);
  for (int i = 0; i < num_tasks; i++) {
    int port = tensorflow::testing::PickUnusedPortOrDie();
    job_def->mutable_tasks()->insert(
        {i, tensorflow::strings::StrCat("localhost:", port)});
  }
  auto* config = server_def.mutable_default_session_config();
  config->mutable_experimental()->set_collective_group_leader(
      tensorflow::strings::StrCat("/job:", job_name, "/replica:0/task:", 0));
  auto* rewrite_options =
      config->mutable_graph_options()->mutable_rewrite_options();
  rewrite_options->set_scoped_allocator_optimization(
      tensorflow::RewriterConfig::ON);
  rewrite_options->mutable_scoped_allocator_opts()->add_enable_op(
      "CollectiveReduce");

  if ((tensorflow::IsGoogleCudaEnabled() || tensorflow::IsBuiltWithROCm()) &&
      num_virtual_gpus > 0) {
    tensorflow::GPUOptions* gpu_options =
        server_def.mutable_default_session_config()->mutable_gpu_options();
    auto virtual_devices =
        gpu_options->mutable_experimental()->add_virtual_devices();
    for (int i = 0; i < num_virtual_gpus; ++i) {
      virtual_devices->add_memory_limit_mb(200);
    }
  }
  return server_def;
}