xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/experimental/libtf/tests/function_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/cc/experimental/libtf/function.h"

#include <memory>
#include <tuple>

#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/graph_function.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/cc/experimental/libtf/object.h"
#include "tensorflow/cc/experimental/libtf/value.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/test.h"
29 #include "tensorflow/core/platform/test.h"
30 
31 namespace tf {
32 namespace libtf {
33 using tensorflow::AbstractContext;
34 using tensorflow::AbstractContextPtr;
35 using tensorflow::AbstractFunctionPtr;
36 using tensorflow::AbstractTensorHandle;
37 using tensorflow::DT_FLOAT;
38 using tensorflow::FunctionDef;
39 using tensorflow::FunctionDefHelper;
40 using tensorflow::PartialTensorShape;
41 using tensorflow::Status;
42 using tensorflow::StatusOr;
43 using tensorflow::TF_StatusPtr;
44 using tensorflow::tracing::graph::GraphFunction;
45 
46 class FunctionTest
47     : public ::testing::TestWithParam<std::tuple<const char*, bool>> {
48  public:
49   template <class T, TF_DataType datatype>
CreateScalarTensor(T val)50   impl::TaggedValueTensor CreateScalarTensor(T val) {
51     AbstractTensorHandle* raw = nullptr;
52     Status s = TestScalarTensorHandle<T, datatype>(ctx_.get(), val, &raw);
53     CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
54     return impl::TaggedValueTensor(raw, /*add_ref=*/false);
55   }
56 
UseTfrt()57   bool UseTfrt() { return std::get<1>(GetParam()); }
58 
59   AbstractContextPtr ctx_;
60 
61  protected:
SetUp()62   void SetUp() override {
63     // Set the tracing impl, GraphDef vs MLIR.
64     TF_StatusPtr status(TF_NewStatus());
65     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
66     Status s = tensorflow::StatusFromTF_Status(status.get());
67     CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
68 
69     // Set the runtime impl, Core RT vs TFRT.
70     AbstractContext* ctx_raw = nullptr;
71     s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw);
72     CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
73     ctx_.reset(ctx_raw);
74   }
75 };
76 
77 // TODO(b/191361582): Use Abstract* APIs for building functions so that we can
78 // test with MLIR.
SquareFunc()79 FunctionDef SquareFunc() {
80   return FunctionDefHelper::Define(
81       // Function Name
82       "SquareFunc",
83       // Args
84       {"x: float"},
85       // Returns
86       {"y: float"},
87       // Attr def
88       {},
89       // Nodes
90       {{/*ret=*/{"y"},
91         /*op=*/"Square",
92         /*arg=*/{"x"},
93         /*attr=*/{{"T", DT_FLOAT}},
94         /*dep=*/{},
95         /*device=*/"",
96         /*name=*/"square"}});
97 }
98 
AddFunc()99 FunctionDef AddFunc() {
100   return FunctionDefHelper::Define(
101       // Function Name
102       "AddFunc",
103       // Args
104       {"x: float", "y: float"},
105       // Returns
106       {"z: float"},
107       // Attr def
108       {},
109       // Nodes
110       {{/*ret=*/{"z"},
111         /*op=*/"Add",
112         /*arg=*/{"x", "y"},
113         /*attr=*/{{"T", DT_FLOAT}},
114         /*dep=*/{},
115         /*device=*/"",
116         /*name=*/"add"}});
117 }
118 
IdentityNFunc()119 FunctionDef IdentityNFunc() {
120   return FunctionDefHelper::Define(
121       // Function Name
122       "IdentityNFunc",
123       // Args
124       {"x: float", "y: float"},
125       // Returns
126       {"u: float", "v: float"},
127       // Attr def
128       {},
129       // Nodes
130       {{/*ret=*/{"u", "v"},
131         /*op=*/"IdentityN",
132         /*arg=*/{"x", "y"},
133         /*attr=*/{{"T", tensorflow::DataTypeSlice({DT_FLOAT, DT_FLOAT})}},
134         /*dep=*/{},
135         /*device=*/""}});
136 }
137 
138 template <typename T>
ExpectEquals(AbstractTensorHandle * t,T expected)139 void ExpectEquals(AbstractTensorHandle* t, T expected) {
140   TF_Tensor* result_t;
141   Status s = tensorflow::GetValue(t, &result_t);
142   ASSERT_TRUE(s.ok()) << s.error_message();
143   auto value = static_cast<T*>(TF_TensorData(result_t));
144   EXPECT_EQ(*value, expected);
145   TF_DeleteTensor(result_t);
146 }
147 
148 // TODO(srbs): Add tests for captures.
149 // TODO(srbs): Add tests for polymorphism (different shapes and dtypes).
// Registers a trace of SquareFunc whose input and output signatures are both
// a float tensor of unknown shape, then checks that Square(2) == 4.
TEST_P(FunctionTest, Square) {
  impl::TaggedValueTensor two = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  FunctionDef square_def = SquareFunc();
  AbstractFunctionPtr traced(new GraphFunction(square_def),
                             /*add_ref=*/false);

  // One float tensor of unknown shape serves as both input and output spec.
  PartialTensorShape any_shape;
  TaggedValue float_spec(any_shape, DT_FLOAT);
  Function square;
  Status registered =
      square.RegisterTrace(std::move(traced), float_spec, float_spec);
  ASSERT_TRUE(registered.ok()) << registered.error_message();

  TaggedValue input(std::move(two));
  StatusOr<TaggedValue> output = square.Execute(ctx_.get(), input);
  ASSERT_TRUE(output.ok()) << output.status().error_message();
  const TaggedValue& squared = output.ValueOrDie();
  ExpectEquals(squared.tensor().get(), 4.0f);
}
167 
// Registers a trace of AddFunc taking a 2-tuple of float tensors and returning
// one float tensor, then checks that 2 + 2 == 4.
TEST_P(FunctionTest, Add) {
  impl::TaggedValueTensor two = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  FunctionDef add_def = AddFunc();
  AbstractFunctionPtr traced(new GraphFunction(add_def), /*add_ref=*/false);

  PartialTensorShape any_shape;
  TaggedValue float_spec(any_shape, DT_FLOAT);
  // The function takes two float tensors packed in a tuple.
  TaggedValue two_floats = TaggedValue::Tuple();
  for (int i = 0; i < 2; ++i) {
    two_floats.tuple().emplace_back(float_spec);
  }
  Function add;
  Status registered =
      add.RegisterTrace(std::move(traced), two_floats, float_spec);
  ASSERT_TRUE(registered.ok()) << registered.error_message();

  // Feed the same scalar as both operands.
  TaggedValue operands = TaggedValue::Tuple();
  operands.tuple().emplace_back(TaggedValue(two));
  operands.tuple().emplace_back(TaggedValue(two));
  StatusOr<TaggedValue> output = add.Execute(ctx_.get(), operands);
  ASSERT_TRUE(output.ok()) << output.status().error_message();
  const TaggedValue& sum = output.ValueOrDie();
  ExpectEquals(sum.tensor().get(), 4.0f);
}
190 
// Registers a trace of IdentityNFunc with a 2-tuple of float tensors as both
// input and output signature, then checks both values pass through unchanged.
TEST_P(FunctionTest, IdentityN) {
  impl::TaggedValueTensor first_in = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  impl::TaggedValueTensor second_in = CreateScalarTensor<float, TF_FLOAT>(4.0f);
  FunctionDef identity_def = IdentityNFunc();
  AbstractFunctionPtr traced(new GraphFunction(identity_def),
                             /*add_ref=*/false);

  PartialTensorShape any_shape;
  TaggedValue float_spec(any_shape, DT_FLOAT);
  TaggedValue pair_spec = TaggedValue::Tuple();
  pair_spec.tuple().emplace_back(float_spec);
  pair_spec.tuple().emplace_back(float_spec);
  Function identity_n;
  Status registered =
      identity_n.RegisterTrace(std::move(traced), pair_spec, pair_spec);
  ASSERT_TRUE(registered.ok()) << registered.error_message();

  TaggedValue inputs = TaggedValue::Tuple();
  inputs.tuple().emplace_back(TaggedValue(first_in));
  inputs.tuple().emplace_back(TaggedValue(second_in));
  StatusOr<TaggedValue> output = identity_n.Execute(ctx_.get(), inputs);
  ASSERT_TRUE(output.ok()) << output.status().error_message();
  const TaggedValue& outputs = output.ValueOrDie();
  const TaggedValue& first_out = outputs.tuple()[0];
  const TaggedValue& second_out = outputs.tuple()[1];
  ExpectEquals(first_out.tensor().get(), 2.0f);
  ExpectEquals(second_out.tensor().get(), 4.0f);
}
213 
// A function traced with a single-tensor input signature must reject a call
// whose arguments are a 2-tuple: Execute should return InvalidArgument with a
// "No match" message. On failure, the actual status is printed to ease
// debugging, matching the other failure-path tests in this file.
TEST_P(FunctionTest, UnaryFuncCalledWithMultipleArgsFails) {
  // Construct a scalar.
  impl::TaggedValueTensor x = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  FunctionDef fdef = SquareFunc();
  AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false);
  Function tf_function;
  PartialTensorShape unknown_shape;
  TaggedValue signature(unknown_shape, DT_FLOAT);
  Status s = tf_function.RegisterTrace(std::move(trace), signature, signature);
  ASSERT_TRUE(s.ok()) << s.error_message();
  // Two arguments for a one-argument signature: no registered trace matches.
  TaggedValue args = TaggedValue::Tuple();
  args.tuple().emplace_back(TaggedValue(x));
  args.tuple().emplace_back(TaggedValue(x));
  StatusOr<TaggedValue> v = tf_function.Execute(ctx_.get(), args);
  ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(v.status())) << v.status();
  ASSERT_TRUE(absl::StrContains(v.status().error_message(), "No match"))
      << v.status();
}
231 
// IdentityNFunc yields two outputs; registering it with a single-tensor output
// signature must make Execute fail with InvalidArgument reporting the output
// count mismatch. Skipped on TFRT, which crashes on this mismatch instead.
TEST_P(FunctionTest, IncorrectArityOfOutputSignatureFails) {
  if (UseTfrt()) {
    GTEST_SKIP() << "TFRT crashes if expected number of output tensors does not"
                    " match actual.";
  }
  impl::TaggedValueTensor first_in = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  impl::TaggedValueTensor second_in = CreateScalarTensor<float, TF_FLOAT>(4.0f);
  FunctionDef identity_def = IdentityNFunc();
  AbstractFunctionPtr traced(new GraphFunction(identity_def),
                             /*add_ref=*/false);

  PartialTensorShape any_shape;
  TaggedValue float_spec(any_shape, DT_FLOAT);
  TaggedValue pair_spec = TaggedValue::Tuple();
  pair_spec.tuple().emplace_back(float_spec);
  pair_spec.tuple().emplace_back(float_spec);
  // Deliberately wrong: only one output tensor is declared.
  TaggedValue single_output_spec(any_shape, DT_FLOAT);
  Function identity_n;
  Status registered = identity_n.RegisterTrace(std::move(traced), pair_spec,
                                               single_output_spec);
  ASSERT_TRUE(registered.ok()) << registered.error_message();

  TaggedValue inputs = TaggedValue::Tuple();
  inputs.tuple().emplace_back(TaggedValue(first_in));
  inputs.tuple().emplace_back(TaggedValue(second_in));
  StatusOr<TaggedValue> output = identity_n.Execute(ctx_.get(), inputs);
  const auto& failure = output.status();
  ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(failure)) << failure;
  ASSERT_TRUE(absl::StrContains(failure.error_message(),
                                "Expecting 2 outputs, but *num_retvals is 1"));
}
260 
// AddFunc returns a float tensor; registering it with a DT_INT64 output
// signature must make Execute fail with an Internal error stating that the
// tensor's shape and dtype do not match the signature.
TEST_P(FunctionTest, IncorrectDtypeInOutputSignatureFails) {
  impl::TaggedValueTensor two = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  FunctionDef add_def = AddFunc();
  AbstractFunctionPtr traced(new GraphFunction(add_def), /*add_ref=*/false);

  PartialTensorShape any_shape;
  TaggedValue float_spec(any_shape, tensorflow::DT_FLOAT);
  TaggedValue two_floats = TaggedValue::Tuple();
  two_floats.tuple().emplace_back(float_spec);
  two_floats.tuple().emplace_back(float_spec);
  // Deliberately wrong: the function actually produces a float.
  TaggedValue int64_spec(any_shape, tensorflow::DT_INT64);
  Function add;
  Status registered =
      add.RegisterTrace(std::move(traced), two_floats, int64_spec);
  ASSERT_TRUE(registered.ok()) << registered.error_message();

  TaggedValue operands = TaggedValue::Tuple();
  operands.tuple().emplace_back(TaggedValue(two));
  operands.tuple().emplace_back(TaggedValue(two));
  StatusOr<TaggedValue> output = add.Execute(ctx_.get(), operands);
  const auto& failure = output.status();
  ASSERT_TRUE(tensorflow::errors::IsInternal(failure)) << failure;
  ASSERT_TRUE(absl::StrContains(failure.error_message(),
                                "Shape and dtype of tensor"));
  ASSERT_TRUE(absl::StrContains(failure.error_message(),
                                "does not match that in signature"));
}
287 
288 INSTANTIATE_TEST_SUITE_P(TF2CAPI, FunctionTest,
289                          ::testing::Combine(::testing::Values("graphdef",
290                                                               "mlir"),
291                                             ::testing::Values(false, true)));
292 
293 }  // namespace libtf
294 }  // namespace tf
295