xref: /aosp_15_r20/external/tensorflow/tensorflow/c/eager/c_api_test_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
16 #define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
17 
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
25 
26 // Return a tensor handle containing a float scalar
27 TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
28 
29 // Return a tensor handle containing a int scalar
30 TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
31 
32 // Return a tensor handle containing a bool scalar
33 TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);
34 
35 // Return a tensor handle containing a tstring scalar
36 TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
37                                          const tensorflow::tstring& value);
38 
39 // Return a tensor handle containing a 2x2 matrix of doubles
40 TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
41 
42 // Return a tensor handle containing a 2x2 matrix of floats
43 TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx);
44 
45 // Return a tensor handle containing 2D matrix containing given data and
46 // dimensions
47 TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
48                                                   float data[], int64_t dims[],
49                                                   int num_dims);
50 
51 // Get a Matrix TensorHandle with given float values and dimensions
52 TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
53                                                 int64_t dims[], int num_dims);
54 
55 // Get a Matrix TensorHandle with given int values and dimensions
56 TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
57                                               int64_t dims[], int num_dims);
58 
59 // Return a tensor handle with given type, values and dimensions.
60 template <class T, TF_DataType datatype>
TestTensorHandleWithDims(TFE_Context * ctx,const T * data,const int64_t * dims,int num_dims)61 TFE_TensorHandle* TestTensorHandleWithDims(TFE_Context* ctx, const T* data,
62                                            const int64_t* dims, int num_dims) {
63   TF_Status* status = TF_NewStatus();
64   TF_Tensor* t = TFE_AllocateHostTensor(ctx, datatype, dims, num_dims, status);
65   memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
66   TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
67   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
68   TF_DeleteTensor(t);
69   TF_DeleteStatus(status);
70   return th;
71 }
72 
73 // Return a scalar tensor handle with given values.
74 template <class T, TF_DataType datatype>
TestScalarTensorHandle(TFE_Context * ctx,const T value)75 TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, const T value) {
76   T data[] = {value};
77   return TestTensorHandleWithDims<T, datatype>(ctx, data, nullptr, 0);
78 }
79 
80 // Return a tensor handle containing a 100x100 matrix of floats
81 TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
82 
83 // Return a tensor handle containing a 3x2 matrix of doubles
84 TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
85 
86 // Return a tensor handle containing a 3x2 matrix of floats
87 TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
88 
89 // Return a variable handle referring to a variable with the given initial value
90 // on the given device.
91 TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
92                                const tensorflow::string& device_name = "");
93 
94 // Return an add op multiplying `a` by `b`.
95 TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
96 
97 // Return a matmul op multiplying `a` by `b`.
98 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
99 
100 // Return an identity op.
101 TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a);
102 
103 // Return a shape op fetching the shape of `a`.
104 TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
105 
106 // Return an allreduce op adding up input tensor `in` from `group_size` workers.
107 TFE_Op* AllReduceOp(TFE_Context* ctx, TFE_TensorHandle* in, int group_size);
108 
109 // Return a SendOp op `op_name` with send input tensor `in` and attributes
110 // `send_device`, `recv_device`, and `send_device_incarnation` set.
111 TFE_Op* SendOp(TFE_Context* ctx, TFE_TensorHandle* in,
112                const std::string& op_name, const std::string& send_device,
113                const std::string& recv_device,
114                tensorflow::uint64 send_device_incarnation);
115 
116 // Return a RecvOp op `op_name` with the attributes `send_device`,
117 // `recv_device`, and `send_device_incarnation` set.
118 TFE_Op* RecvOp(TFE_Context* ctx, const std::string& op_name,
119                const std::string& send_device, const std::string& recv_device,
120                tensorflow::uint64 send_device_incarnation);
121 
122 // Return a 1-D INT32 tensor containing a single value 1.
123 TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx);
124 
125 // Return an op taking minimum of `input` long `axis` dimension.
126 TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
127               TFE_TensorHandle* axis);
128 
129 // If there is a device of type `device_type`, returns true
130 // and sets 'device_name' accordingly.
131 // `device_type` must be either "GPU" or "TPU".
132 bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
133                    const char* device_type);
134 
135 // Create a ServerDef with the given `job_name` and add `num_tasks` tasks in it.
136 tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name,
137                                    int num_tasks);
138 
139 // Create a ServerDef with job name "localhost" and add `num_tasks` tasks in it.
140 tensorflow::ServerDef GetServerDef(int num_tasks);
141 
142 // Create a multi-client ServerDef with the given `job_name`, add `num_tasks`
143 // tasks and `num_virtual_gpus` virtual GPUs in it.
144 tensorflow::ServerDef GetMultiClientServerDef(const std::string& job_name,
145                                               int num_tasks,
146                                               int num_virtual_gpus = 0);
147 
148 #endif  // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
149