/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"

#include <string>
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_shape.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"

namespace {

using tensorflow::tstring;

constexpr char kTestData[] = "cc/saved_model/testdata";
const char* kServeTag[] = {"serve"};
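// MetaGraphDef tag set used by the LoadsSavedModelWithTags test below.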

std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
  return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
                                  kTestData, saved_model_dir);
}

// This value-parameterized test allows us to test both TFRT
// and non-TFRT runtimes.
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
class CSavedModelAPITest : public ::testing::TestWithParam<bool> {};
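// GetParam() is true when the test should run against the TFRT runtime and
// false for the default TensorFlow eager runtime.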

TEST_P(CSavedModelAPITest, LoadsSavedModelWithTags) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");

  TF_SavedModel* saved_model =
      TF_LoadSavedModelWithTags(model_dir.c_str(), ctx, kServeTag, 1, status);

  // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
  // That unblocks writing other tests that require a TF_SavedModel*,
  // like loading a ConcreteFunction. This test at least checks that the
  // C API builds and can be minimally run.
  EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);

  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadsSavedModel) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);

  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TF_ConcreteFunction* compute_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  std::vector<TFE_TensorHandle*> compute_fn_inputs;
  TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
  TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
  compute_fn_inputs.push_back(input_a);
  compute_fn_inputs.push_back(input_b);

  TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp(
      compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
  // inputs + outputs a function has.
  TFE_TensorHandle* compute_fn_outputs[1] = {nullptr};
  int num_retvals = 1;

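  // TFE_Execute treats num_retvals as in/out: it is passed the capacity of
  // the output array and set to the number of outputs actually produced.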
  TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  EXPECT_EQ(TF_NumDims(result), 0);
  float output_value = *static_cast<float*>(TF_TensorData(result));
  // (1 + 2) * (2 + 1) / 3 + 5 should be 8
  EXPECT_FLOAT_EQ(output_value, 8.0);

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(compute_fn_outputs[0]);
  TFE_DeleteTensorHandle(input_a);
  TFE_DeleteTensorHandle(input_b);
  TFE_DeleteOp(compute_fn_op);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

// This tests running the "serving_default" SignatureDefFunction from the
// VarsAndArithmeticObjectGraph SavedModel. Here's what the signature_defs
// protobuf in the metagraph looks like:
// signature_def: {
//   key  : "serving_default"
//   value: {
//     inputs: {
//       key  : "a"
//       value: {
//         name : "serving_default_a:0"
//         dtype: DT_FLOAT
//         tensor_shape: {
//         }
//       }
//     }
//     inputs: {
//       key  : "b"
//       value: {
//         name : "serving_default_b:0"
//         dtype: DT_FLOAT
//         tensor_shape: {
//         }
//       }
//     }
//     outputs: {
//       key  : "output_0"
//       value: {
//         name : "StatefulPartitionedCall:0"
//         dtype: DT_FLOAT
//         tensor_shape: {
//         }
//       }
//     }
//     method_name: "tensorflow/serving/predict"
//   }
// }
TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);

  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TF_SignatureDefFunction* serving_default =
      TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default",
                                           status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_SignatureDefFunctionMetadata* metadata =
      TF_SignatureDefFunctionGetMetadata(serving_default);

  const TF_SignatureDefParamList* args =
      TF_SignatureDefFunctionMetadataArgs(metadata);
  const TF_SignatureDefParamList* returns =
      TF_SignatureDefFunctionMetadataReturns(metadata);

  EXPECT_EQ(TF_SignatureDefParamListSize(args), 2);
  const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0);
  const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a);
  const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a);

  // Input "a" is a scalar, float32 tensor
  EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a)));
  EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a));
  EXPECT_EQ(0, TF_ShapeDims(shape_a));

  const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1);
  const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b);
  const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b);

  // Input "b" is a scalar, float32 tensor
  EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b)));
  EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b));
  EXPECT_EQ(0, TF_ShapeDims(shape_b));

  EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1);

  const TF_SignatureDefParam* param_out =
      TF_SignatureDefParamListGet(returns, 0);
  const TF_TensorSpec* tensor_spec_out =
      TF_SignatureDefParamTensorSpec(param_out);
  const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out);

  // Output "output_0" is a scalar, float32 tensor
  EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out)));
  EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out));
  EXPECT_EQ(0, TF_ShapeDims(shape_out));

  std::vector<TFE_TensorHandle*> compute_fn_inputs;
  TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
  TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
  compute_fn_inputs.push_back(input_a);
  compute_fn_inputs.push_back(input_b);

  TFE_Op* serving_default_op = TF_SignatureDefFunctionMakeCallOp(
      serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(),
      status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

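  // Size the output buffer from the signature's declared returns instead of
  // hard-coding the number of outputs.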
  std::vector<TFE_TensorHandle*> compute_fn_outputs(
      TF_SignatureDefParamListSize(returns));
  int num_retvals = TF_SignatureDefParamListSize(returns);

  TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals,
              status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  EXPECT_EQ(TF_NumDims(result), 0);
  float output_value = *static_cast<float*>(TF_TensorData(result));
  // (1 + 2) * (2 + 1) / 3 + 5 should be 8
  EXPECT_FLOAT_EQ(output_value, 8.0);

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(compute_fn_outputs[0]);
  TFE_DeleteTensorHandle(input_a);
  TFE_DeleteTensorHandle(input_b);
  TFE_DeleteOp(serving_default_op);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("AssetModule");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);

  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TF_ConcreteFunction* read_file_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "read_file", status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* read_file_op =
      TF_ConcreteFunctionMakeCallOp(read_file_fn, nullptr, 0, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
  // inputs + outputs a function has.
  TFE_TensorHandle* read_file_fn_outputs[1] = {nullptr};
  int num_retvals = 1;

  TFE_Execute(read_file_op, &read_file_fn_outputs[0], &num_retvals, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(read_file_fn_outputs[0], status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  EXPECT_EQ(TF_NumDims(result), 0);
  tensorflow::tstring* output_value =
      static_cast<tensorflow::tstring*>(TF_TensorData(result));
  std::string file_contents(*output_value);
  EXPECT_NE(file_contents.find("TEST ASSET FILE CONTENTS"), std::string::npos);

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(read_file_fn_outputs[0]);
  TFE_DeleteOp(read_file_op);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadsStaticHashtableSavedModel) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("StaticHashTableModule");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);

  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TF_ConcreteFunction* lookup_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "lookup", status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // Note(bmzhao): Based on static_hashtable_asset.txt, we expect the following
  // mapping:
  // "foo" -> 0
  // "bar" -> 1
  // "baz" -> 2
  // "wombat" -> 3
  // all other strings -> -1

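  // Each block below builds a call op for the "lookup" ConcreteFunction, runs
  // it on a single scalar string key, and checks the resulting int64 value.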
  // Call lookup function with input "foo", expecting an output of 0
  {
    std::vector<TFE_TensorHandle*> lookup_fn_inputs;
    TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("foo"));
    lookup_fn_inputs.push_back(input_foo);

    TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
        lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
    // inputs + outputs a function has.
    TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
    int num_retvals = 1;

    TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    EXPECT_EQ(TF_NumDims(result), 0);
    int64_t* output_value = static_cast<int64_t*>(TF_TensorData(result));
    EXPECT_EQ(*output_value, 0);

    TF_DeleteTensor(result);
    TFE_DeleteTensorHandle(input_foo);
    TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
    TFE_DeleteOp(lookup_op);
  }

  // Call lookup function with input "baz", expecting an output of 2
  {
    std::vector<TFE_TensorHandle*> lookup_fn_inputs;
    TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("baz"));
    lookup_fn_inputs.push_back(input_foo);

    TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
        lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
    // inputs + outputs a function has.
    TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
    int num_retvals = 1;

    TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    EXPECT_EQ(TF_NumDims(result), 0);
    int64_t* output_value = static_cast<int64_t*>(TF_TensorData(result));
    EXPECT_EQ(*output_value, 2);

    TF_DeleteTensor(result);
    TFE_DeleteTensorHandle(input_foo);
    TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
    TFE_DeleteOp(lookup_op);
  }

  // Call lookup function with input "NON-EXISTENT-KEY", expecting output -1
  {
    std::vector<TFE_TensorHandle*> lookup_fn_inputs;
    TFE_TensorHandle* input_foo =
        TestScalarTensorHandle(ctx, tstring("NON-EXISTENT-KEY"));
    lookup_fn_inputs.push_back(input_foo);

    TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
        lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
    // inputs + outputs a function has.
    TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
    int num_retvals = 1;

    TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    EXPECT_EQ(TF_NumDims(result), 0);
    int64_t* output_value = static_cast<int64_t*>(TF_TensorData(result));
    EXPECT_EQ(*output_value, -1);

    TF_DeleteTensor(result);
    TFE_DeleteTensorHandle(input_foo);
    TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
    TFE_DeleteOp(lookup_op);
  }

  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(),
      "c/experimental/saved_model/internal/testdata/UninitializedVariable");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

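  // Reach into the internal TFSavedModelAPI (not part of the public C API) to
  // inspect the restored variables directly.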
  tensorflow::TFSavedModelAPI* model_api =
      tensorflow::down_cast<tensorflow::TFSavedModelAPI*>(
          tensorflow::unwrap(saved_model));
  tensorflow::Variable* uninitialized_variable;
  ASSERT_EQ(::tensorflow::OkStatus(),
            model_api->GetVariable("uninitialized_variable",
                                   &uninitialized_variable));
  ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype());

  ASSERT_EQ(::tensorflow::OkStatus(),
            model_api->GetVariable("sub_module.uninitialized_variable",
                                   &uninitialized_variable));
  ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype());

  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadSavedModelWithWhileLoop) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(),
      "c/experimental/saved_model/internal/testdata/SimpleWhileLoop");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_ConcreteFunction* while_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  std::vector<TFE_TensorHandle*> while_fn_inputs;
  while_fn_inputs.push_back(TestScalarTensorHandle(ctx, 10.0f));

  TFE_Op* while_fn_op = TF_ConcreteFunctionMakeCallOp(
      while_fn, while_fn_inputs.data(), while_fn_inputs.size(), status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_TensorHandle* while_fn_outputs[1] = {nullptr};
  int num_retvals = 1;

  TFE_Execute(while_fn_op, &while_fn_outputs[0], &num_retvals, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(while_fn_outputs[0], status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  ASSERT_EQ(TF_NumDims(result), 0);
  float output_value = *static_cast<float*>(TF_TensorData(result));
  ASSERT_FLOAT_EQ(output_value, 55);  // 10+9+...+1

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(while_fn_outputs[0]);
  TFE_DeleteOp(while_fn_op);
  TFE_DeleteTensorHandle(while_fn_inputs[0]);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

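// ::testing::Bool() instantiates every TEST_P above twice: once with
// GetParam() == false (default runtime) and once with GetParam() == true
// (TFRT, currently skipped).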
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
                         ::testing::Bool());

}  // namespace