1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/cc/experimental/libtf/function.h"
16
17 #include "tensorflow/c/eager/abstract_context.h"
18 #include "tensorflow/c/eager/abstract_function.h"
19 #include "tensorflow/c/eager/abstract_tensor_handle.h"
20 #include "tensorflow/c/eager/graph_function.h"
21 #include "tensorflow/c/eager/unified_api_testutil.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/cc/experimental/libtf/object.h"
24 #include "tensorflow/cc/experimental/libtf/value.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/statusor.h"
29 #include "tensorflow/core/platform/test.h"
30
31 namespace tf {
32 namespace libtf {
33 using tensorflow::AbstractContext;
34 using tensorflow::AbstractContextPtr;
35 using tensorflow::AbstractFunctionPtr;
36 using tensorflow::AbstractTensorHandle;
37 using tensorflow::DT_FLOAT;
38 using tensorflow::FunctionDef;
39 using tensorflow::FunctionDefHelper;
40 using tensorflow::PartialTensorShape;
41 using tensorflow::Status;
42 using tensorflow::StatusOr;
43 using tensorflow::TF_StatusPtr;
44 using tensorflow::tracing::graph::GraphFunction;
45
46 class FunctionTest
47 : public ::testing::TestWithParam<std::tuple<const char*, bool>> {
48 public:
49 template <class T, TF_DataType datatype>
CreateScalarTensor(T val)50 impl::TaggedValueTensor CreateScalarTensor(T val) {
51 AbstractTensorHandle* raw = nullptr;
52 Status s = TestScalarTensorHandle<T, datatype>(ctx_.get(), val, &raw);
53 CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
54 return impl::TaggedValueTensor(raw, /*add_ref=*/false);
55 }
56
UseTfrt()57 bool UseTfrt() { return std::get<1>(GetParam()); }
58
59 AbstractContextPtr ctx_;
60
61 protected:
SetUp()62 void SetUp() override {
63 // Set the tracing impl, GraphDef vs MLIR.
64 TF_StatusPtr status(TF_NewStatus());
65 TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
66 Status s = tensorflow::StatusFromTF_Status(status.get());
67 CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
68
69 // Set the runtime impl, Core RT vs TFRT.
70 AbstractContext* ctx_raw = nullptr;
71 s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw);
72 CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
73 ctx_.reset(ctx_raw);
74 }
75 };
76
77 // TODO(b/191361582): Use Abstract* APIs for building functions so that we can
78 // test with MLIR.
SquareFunc()79 FunctionDef SquareFunc() {
80 return FunctionDefHelper::Define(
81 // Function Name
82 "SquareFunc",
83 // Args
84 {"x: float"},
85 // Returns
86 {"y: float"},
87 // Attr def
88 {},
89 // Nodes
90 {{/*ret=*/{"y"},
91 /*op=*/"Square",
92 /*arg=*/{"x"},
93 /*attr=*/{{"T", DT_FLOAT}},
94 /*dep=*/{},
95 /*device=*/"",
96 /*name=*/"square"}});
97 }
98
AddFunc()99 FunctionDef AddFunc() {
100 return FunctionDefHelper::Define(
101 // Function Name
102 "AddFunc",
103 // Args
104 {"x: float", "y: float"},
105 // Returns
106 {"z: float"},
107 // Attr def
108 {},
109 // Nodes
110 {{/*ret=*/{"z"},
111 /*op=*/"Add",
112 /*arg=*/{"x", "y"},
113 /*attr=*/{{"T", DT_FLOAT}},
114 /*dep=*/{},
115 /*device=*/"",
116 /*name=*/"add"}});
117 }
118
IdentityNFunc()119 FunctionDef IdentityNFunc() {
120 return FunctionDefHelper::Define(
121 // Function Name
122 "IdentityNFunc",
123 // Args
124 {"x: float", "y: float"},
125 // Returns
126 {"u: float", "v: float"},
127 // Attr def
128 {},
129 // Nodes
130 {{/*ret=*/{"u", "v"},
131 /*op=*/"IdentityN",
132 /*arg=*/{"x", "y"},
133 /*attr=*/{{"T", tensorflow::DataTypeSlice({DT_FLOAT, DT_FLOAT})}},
134 /*dep=*/{},
135 /*device=*/""}});
136 }
137
138 template <typename T>
ExpectEquals(AbstractTensorHandle * t,T expected)139 void ExpectEquals(AbstractTensorHandle* t, T expected) {
140 TF_Tensor* result_t;
141 Status s = tensorflow::GetValue(t, &result_t);
142 ASSERT_TRUE(s.ok()) << s.error_message();
143 auto value = static_cast<T*>(TF_TensorData(result_t));
144 EXPECT_EQ(*value, expected);
145 TF_DeleteTensor(result_t);
146 }
147
148 // TODO(srbs): Add tests for captures.
149 // TODO(srbs): Add tests for polymorphism (different shapes and dtypes).
// Registers the SquareFunc trace with a single float tensor spec as both input
// and output signature, then checks that executing it on 2.0 yields 4.0.
TEST_P(FunctionTest, Square) {
  // Wrap the function definition in a trace.
  const FunctionDef square_def = SquareFunc();
  AbstractFunctionPtr square_trace(new GraphFunction(square_def),
                                   /*add_ref=*/false);

  // A default-constructed shape paired with DT_FLOAT serves as both
  // signatures.
  PartialTensorShape default_shape;
  TaggedValue float_spec(default_shape, DT_FLOAT);

  Function square_fn;
  const Status register_status =
      square_fn.RegisterTrace(std::move(square_trace), float_spec, float_spec);
  ASSERT_TRUE(register_status.ok()) << register_status.error_message();

  // Construct the scalar input and run the function on it.
  impl::TaggedValueTensor two = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  TaggedValue call_args(std::move(two));
  StatusOr<TaggedValue> outcome = square_fn.Execute(ctx_.get(), call_args);
  ASSERT_TRUE(outcome.ok()) << outcome.status().error_message();

  const TaggedValue& squared = outcome.ValueOrDie();
  AbstractTensorHandle* squared_handle = squared.tensor().get();
  ExpectEquals(squared_handle, 4.0f);
}
167
// Registers the AddFunc trace with a two-element input signature and a single
// tensor output signature, then checks that 2.0 + 2.0 evaluates to 4.0.
TEST_P(FunctionTest, Add) {
  // Wrap the function definition in a trace.
  const FunctionDef add_def = AddFunc();
  AbstractFunctionPtr add_trace(new GraphFunction(add_def),
                                /*add_ref=*/false);

  // Inputs: a tuple of two float specs. Output: a single float spec.
  PartialTensorShape default_shape;
  TaggedValue float_spec(default_shape, DT_FLOAT);
  TaggedValue input_specs = TaggedValue::Tuple();
  auto& spec_elems = input_specs.tuple();
  spec_elems.emplace_back(float_spec);
  spec_elems.emplace_back(float_spec);

  Function add_fn;
  const Status register_status =
      add_fn.RegisterTrace(std::move(add_trace), input_specs, float_spec);
  ASSERT_TRUE(register_status.ok()) << register_status.error_message();

  // Pass the same scalar as both operands.
  impl::TaggedValueTensor two = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  TaggedValue call_args = TaggedValue::Tuple();
  auto& arg_elems = call_args.tuple();
  arg_elems.emplace_back(TaggedValue(two));
  arg_elems.emplace_back(TaggedValue(two));

  StatusOr<TaggedValue> outcome = add_fn.Execute(ctx_.get(), call_args);
  ASSERT_TRUE(outcome.ok()) << outcome.status().error_message();
  const TaggedValue& sum = outcome.ValueOrDie();
  ExpectEquals(sum.tensor().get(), 4.0f);
}
190
// Registers the IdentityNFunc trace with a two-element tuple as both input and
// output signature, then checks that the inputs 2.0 and 4.0 come back
// unchanged and in order.
TEST_P(FunctionTest, IdentityN) {
  // Wrap the function definition in a trace.
  const FunctionDef identity_def = IdentityNFunc();
  AbstractFunctionPtr identity_trace(new GraphFunction(identity_def),
                                     /*add_ref=*/false);

  // The same tuple of two float specs describes both inputs and outputs.
  PartialTensorShape default_shape;
  TaggedValue float_spec(default_shape, DT_FLOAT);
  TaggedValue pair_spec = TaggedValue::Tuple();
  auto& spec_elems = pair_spec.tuple();
  spec_elems.emplace_back(float_spec);
  spec_elems.emplace_back(float_spec);

  Function identity_fn;
  const Status register_status = identity_fn.RegisterTrace(
      std::move(identity_trace), pair_spec, pair_spec);
  ASSERT_TRUE(register_status.ok()) << register_status.error_message();

  impl::TaggedValueTensor first = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  impl::TaggedValueTensor second = CreateScalarTensor<float, TF_FLOAT>(4.0f);
  TaggedValue call_args = TaggedValue::Tuple();
  auto& arg_elems = call_args.tuple();
  arg_elems.emplace_back(TaggedValue(first));
  arg_elems.emplace_back(TaggedValue(second));

  StatusOr<TaggedValue> outcome = identity_fn.Execute(ctx_.get(), call_args);
  ASSERT_TRUE(outcome.ok()) << outcome.status().error_message();
  const TaggedValue& outputs = outcome.ValueOrDie();
  ExpectEquals(outputs.tuple()[0].tensor().get(), 2.0f);
  ExpectEquals(outputs.tuple()[1].tensor().get(), 4.0f);
}
213
// Registers SquareFunc with a single-tensor signature, calls it with a
// two-element tuple, and expects an InvalidArgument error whose message
// contains "No match".
TEST_P(FunctionTest, UnaryFuncCalledWithMultipleArgsFails) {
  // Construct a scalar.
  impl::TaggedValueTensor x = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  FunctionDef fdef = SquareFunc();
  AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false);
  Function tf_function;
  PartialTensorShape unknown_shape;
  TaggedValue signature(unknown_shape, DT_FLOAT);
  Status s = tf_function.RegisterTrace(std::move(trace), signature, signature);
  ASSERT_TRUE(s.ok()) << s.error_message();
  // Deliberately pass two arguments to a function registered with one.
  TaggedValue args = TaggedValue::Tuple();
  args.tuple().emplace_back(TaggedValue(x));
  args.tuple().emplace_back(TaggedValue(x));
  StatusOr<TaggedValue> v = tf_function.Execute(ctx_.get(), args);
  // Stream the status into both assertions so a failure shows the actual
  // error, matching the other negative tests in this file.
  ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(v.status())) << v.status();
  ASSERT_TRUE(absl::StrContains(v.status().error_message(), "No match"))
      << v.status();
}
231
// Registers the two-output IdentityNFunc with an output signature describing
// only one tensor and expects execution to fail with InvalidArgument. Skipped
// when the TFRT parameter is set.
TEST_P(FunctionTest, IncorrectArityOfOutputSignatureFails) {
  if (UseTfrt()) {
    GTEST_SKIP() << "TFRT crashes if expected number of output tensors does not"
                    " match actual.";
  }

  // Wrap the function definition in a trace.
  const FunctionDef identity_def = IdentityNFunc();
  AbstractFunctionPtr identity_trace(new GraphFunction(identity_def),
                                     /*add_ref=*/false);

  // The input signature matches the function: a tuple of two float specs.
  PartialTensorShape default_shape;
  TaggedValue float_spec(default_shape, DT_FLOAT);
  TaggedValue input_specs = TaggedValue::Tuple();
  auto& spec_elems = input_specs.tuple();
  spec_elems.emplace_back(float_spec);
  spec_elems.emplace_back(float_spec);

  // Deliberately wrong: one tensor spec, although the function returns two.
  TaggedValue single_output_spec(default_shape, DT_FLOAT);

  Function identity_fn;
  const Status register_status = identity_fn.RegisterTrace(
      std::move(identity_trace), input_specs, single_output_spec);
  ASSERT_TRUE(register_status.ok()) << register_status.error_message();

  impl::TaggedValueTensor first = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  impl::TaggedValueTensor second = CreateScalarTensor<float, TF_FLOAT>(4.0f);
  TaggedValue call_args = TaggedValue::Tuple();
  auto& arg_elems = call_args.tuple();
  arg_elems.emplace_back(TaggedValue(first));
  arg_elems.emplace_back(TaggedValue(second));

  StatusOr<TaggedValue> outcome = identity_fn.Execute(ctx_.get(), call_args);
  ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(outcome.status()))
      << outcome.status();
  ASSERT_TRUE(absl::StrContains(outcome.status().error_message(),
                                "Expecting 2 outputs, but *num_retvals is 1"));
}
260
// Registers the float-valued AddFunc with an output signature that declares
// DT_INT64 and expects execution to fail with an Internal error describing the
// mismatch against the signature.
TEST_P(FunctionTest, IncorrectDtypeInOutputSignatureFails) {
  // Wrap the function definition in a trace.
  const FunctionDef add_def = AddFunc();
  AbstractFunctionPtr add_trace(new GraphFunction(add_def),
                                /*add_ref=*/false);

  // The input signature matches the function: a tuple of two float specs.
  PartialTensorShape default_shape;
  TaggedValue float_spec(default_shape, DT_FLOAT);
  TaggedValue input_specs = TaggedValue::Tuple();
  auto& spec_elems = input_specs.tuple();
  spec_elems.emplace_back(float_spec);
  spec_elems.emplace_back(float_spec);

  // Deliberately wrong dtype for the output.
  TaggedValue int64_output_spec(default_shape, tensorflow::DT_INT64);

  Function add_fn;
  const Status register_status = add_fn.RegisterTrace(
      std::move(add_trace), input_specs, int64_output_spec);
  ASSERT_TRUE(register_status.ok()) << register_status.error_message();

  // Pass the same scalar as both operands.
  impl::TaggedValueTensor two = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  TaggedValue call_args = TaggedValue::Tuple();
  auto& arg_elems = call_args.tuple();
  arg_elems.emplace_back(TaggedValue(two));
  arg_elems.emplace_back(TaggedValue(two));

  StatusOr<TaggedValue> outcome = add_fn.Execute(ctx_.get(), call_args);
  ASSERT_TRUE(tensorflow::errors::IsInternal(outcome.status()))
      << outcome.status();
  const auto& failure_text = outcome.status().error_message();
  ASSERT_TRUE(absl::StrContains(failure_text, "Shape and dtype of tensor"));
  ASSERT_TRUE(
      absl::StrContains(failure_text, "does not match that in signature"));
}
287
288 INSTANTIATE_TEST_SUITE_P(TF2CAPI, FunctionTest,
289 ::testing::Combine(::testing::Values("graphdef",
290 "mlir"),
291 ::testing::Values(false, true)));
292
293 } // namespace libtf
294 } // namespace tf
295