xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/tf_session_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/tensorflow/tf_session.h"
18 
19 #include <string>
20 
21 #include "gtest/gtest.h"
22 #include "fcp/base/tracing_schema.h"
23 #include "fcp/protos/plan.pb.h"
24 #include "fcp/tensorflow/testing/tf_helper.h"
25 #include "fcp/tensorflow/tracing_schema.h"
26 #include "fcp/testing/result_matchers.h"
27 #include "fcp/tracing/test_tracing_recorder.h"
28 #include "tensorflow/cc/framework/scope.h"
29 #include "tensorflow/cc/ops/math_ops.h"
30 #include "tensorflow/cc/ops/standard_ops.h"
31 #include "tensorflow/core/framework/tensor_testutil.h"
32 #include "tensorflow/core/protobuf/saver.pb.h"
33 
34 namespace fcp {
35 
36 using google::internal::federated::plan::CheckpointOp;
37 using tensorflow::Tensor;
38 using tensorflow::TensorShape;
39 using tensorflow::ops::Add;
40 using tensorflow::ops::Assign;
41 using tensorflow::ops::Const;
42 using tensorflow::ops::Mul;
43 using tensorflow::ops::Placeholder;
44 using tensorflow::ops::Restore;
45 using tensorflow::ops::Save;
46 using tensorflow::ops::Variable;
47 using tensorflow::test::AsTensor;
48 using tensorflow::test::ExpectTensorEqual;
49 using testing::_;
50 using testing::Not;
51 
52 template <typename T>
CheckOutput(TfSession * sess,const std::string & output_op,Tensor expected)53 void CheckOutput(TfSession* sess, const std::string& output_op,
54                  Tensor expected) {
55   Result<std::unique_ptr<TfSession::NamedTensorMap>> outputs =
56       sess->GetOutputs(std::make_unique<std::vector<std::string>>(
57           std::initializer_list<std::string>{output_op}));
58   EXPECT_THAT(outputs, Not(IsError()));
59   ExpectTensorEqual<T>((*outputs.GetValueOrDie())[output_op], expected);
60 }
61 
// Exercises a session whose graph contains no nodes:
//  * the session still becomes Ready();
//  * RunOp("") and GetOutputs() with an empty name list succeed as no-ops;
//  * RunOp() on a name the graph does not define fails, and that failing call
//    is the only one that shows up as a RunTfOp tracing span.
TEST(TfSessionTest, InitializeWithEmptyGraph) {
  tensorflow::Scope empty_scope = tensorflow::Scope::NewRootScope();
  TestTracingRecorder recorder;
  TfSession session("foo/bar", CreateGraph(&empty_scope));
  ASSERT_THAT(session.Ready(), Not(IsError()));

  // Nothing to run and nothing to fetch: both calls are expected to succeed.
  EXPECT_THAT(session.RunOp(""), Not(IsError()));
  EXPECT_THAT(
      session.GetOutputs(std::make_unique<std::vector<std::string>>()),
      Not(IsError()));

  // The graph has no "sum" node, so this call must fail. The resulting error
  // event is registered with the recorder as expected before the call is made.
  recorder.ExpectError<ResultExpectStatusError>();
  EXPECT_THAT(session.RunOp("sum"), IsError());

  // Exactly one top-level RunTfOp span is expected, belonging to the failing
  // "sum" call and holding the error event. The empty-op call is not traced.
  EXPECT_THAT(recorder.root(),
              ElementsAre(AllOf(
                  IsSpan<RunTfOp>(),
                  ElementsAre(IsEvent<ResultExpectStatusError>(
                      static_cast<int>(fcp::OK),
                      static_cast<int>(fcp::INVALID_ARGUMENT), _, _, _)))));
}
86 
// Building a session from bytes that are not a serialized graph leaves it in an
// error state. The test also expects exactly one top-level
// ResultExpectStatusError event whose first two matched fields are fcp::OK and
// fcp::INVALID_ARGUMENT.
TEST(TfSessionTest, InvalidGraphBytes) {
  TestTracingRecorder tracing_recorder;
  // Registered before the session is constructed, since the error may already
  // be emitted while the constructor handles the graph bytes.
  tracing_recorder.ExpectError<ResultExpectStatusError>();
  TfSession sess("foo/bar", "garbage");
  ASSERT_THAT(sess.Ready(), IsError());
  EXPECT_THAT(tracing_recorder.root(),
              ElementsAre(IsEvent<ResultExpectStatusError>(
                  static_cast<int>(fcp::OK),
                  static_cast<int>(fcp::INVALID_ARGUMENT), _, _, _)));
}
98 
// Builds a graph that stores a * b into variable "c":
//  * a is the 2x2 constant [[1, 2], [3, 4]];
//  * b is the 1x1 constant [[2]].
// It then runs "assign_c" and reads "c" back as [[2, 4], [6, 8]].
TEST(TfSessionTest, RunGraphOp) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto matrix = Const<int32_t>(scope, {{1, 2}, {3, 4}});
  auto factor = Const<int32_t>(scope, {{2}});
  auto c = Variable(scope.WithOpName("c"), {2, 2}, tensorflow::DT_INT32);
  auto product = Mul(scope, matrix, factor);
  auto assign_c = Assign(scope.WithOpName("assign_c"), c, product);

  // The recorder must be alive while the session runs.
  TestTracingRecorder recorder;
  TfSession session("foo/bar", CreateGraph(&scope));
  ASSERT_THAT(session.Ready(), Not(IsError()));

  EXPECT_THAT(session.RunOp("assign_c"), Not(IsError()));
  const Tensor expected_c =
      AsTensor<int32_t>({2, 4, 6, 8}, TensorShape({2, 2}));
  CheckOutput<int32_t>(&session, "c", expected_c);
}
119 
// Restores variable "a" from an in-memory tensor fed to placeholder "i":
//  * "restore_a" is the before_restore_op that copies "i" into "a";
//  * "double_a" is the after_restore_op that doubles "a" afterwards.
TEST(TfSessionTest, RestoreFromTensor) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto feed = Placeholder(scope.WithOpName("i"), tensorflow::DT_INT32);
  auto a = Variable(scope.WithOpName("a"), {2, 2}, tensorflow::DT_INT32);
  auto restore = Assign(scope.WithOpName("restore_a"), a, feed);
  auto two = Const<int32_t>(scope, {{2}});
  auto twice_a = Mul(scope, a, two);
  auto double_a = Assign(scope.WithOpName("double_a"), a, twice_a);

  TestTracingRecorder recorder;
  TfSession session(testing::TempDir(), CreateGraph(&scope));
  ASSERT_THAT(session.Ready(), Not(IsError()));

  CheckpointOp checkpoint_op;
  checkpoint_op.set_before_restore_op("restore_a");
  checkpoint_op.set_after_restore_op("double_a");

  // Fed under the placeholder's name, so "restore_a" sees it as its input.
  tensorflow::Input::Initializer initial_a({{1, 2}, {3, 4}});
  EXPECT_THAT(session.RestoreState(checkpoint_op, {{"i", initial_a.tensor}}),
              Not(IsError()));

  CheckOutput<int32_t>(&session, "a",
                       AsTensor<int32_t>({2, 4, 6, 8}, TensorShape({2, 2})));
}
145 
// RestoreState() with tensor inputs is expected to fail when the CheckpointOp
// also carries a saver_def. The recorder is told to expect an
// InvalidCheckpointOp error.
TEST(TfSessionTest, RestoreFromTensorNoSaverDefAllowed) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto feed = Placeholder(scope.WithOpName("i"), tensorflow::DT_INT32);
  auto a = Variable(scope, {2, 2}, tensorflow::DT_INT32);
  auto restore = Assign(scope.WithOpName("restore_a"), a, feed);
  auto two = Const<int32_t>(scope, {{2}});
  auto double_a =
      Assign(scope.WithOpName("double_a"), a, Mul(scope, a, two));

  TestTracingRecorder recorder;
  recorder.ExpectError<InvalidCheckpointOp>();
  TfSession session(testing::TempDir(), CreateGraph(&scope));
  ASSERT_THAT(session.Ready(), Not(IsError()));

  // The before/after ops exist in the graph. The saver_def is what the test
  // name says is disallowed here. Its names ("restore", "filename") are not
  // defined in this graph.
  CheckpointOp checkpoint_op;
  checkpoint_op.set_before_restore_op("restore_a");
  auto* saver_def = checkpoint_op.mutable_saver_def();
  saver_def->set_restore_op_name("restore");
  saver_def->set_filename_tensor_name("filename");
  checkpoint_op.set_after_restore_op("double_a");

  tensorflow::Input::Initializer initial_a({{1, 2}, {3, 4}});
  EXPECT_THAT(session.RestoreState(checkpoint_op, {{"i", initial_a.tensor}}),
              IsError());
}
172 
// Saves constant "a" into an in-memory checkpoint through a saver_def-based
// CheckpointOp. It then restores that checkpoint into variable "c" and checks
// that "c" holds the saved values.
TEST(TfSessionTest, SaveAndRestoreCheckpointBytes) {
  // Construct a TensorFlow graph with all desired operations.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto a = Const<int32_t>(root, {{1, 2}, {3, 4}});
  // Save the current value of constant "a" in a serialized checkpoint.
  auto filename =
      Placeholder(root.WithOpName("filename"), tensorflow::DT_STRING);
  auto save_a = Save(root.WithOpName("save"), filename, {"a"},
                     std::initializer_list<tensorflow::Input>{a});
  // Restore the value saved in the serialized checkpoint to variable "c".
  auto c = Variable(root.WithOpName("c"), {2, 2}, tensorflow::DT_INT32);
  auto restore = Assign(root.WithOpName("restore"), c,
                        Restore(root, filename, "a", tensorflow::DT_INT32));

  // Run a session using the graph constructed above.
  TestTracingRecorder tracing_recorder;
  TfSession sess(testing::TempDir(), CreateGraph(&root));
  ASSERT_THAT(sess.Ready(), Not(IsError()));

  // Save to a checkpoint.
  CheckpointOp save_checkpoint_op;
  save_checkpoint_op.mutable_saver_def()->set_save_tensor_name("save");
  save_checkpoint_op.mutable_saver_def()->set_filename_tensor_name("filename");
  Result<absl::Cord> save_res = sess.SaveState(save_checkpoint_op);
  // Fatal, so that GetValueOrDie() below is only reached with a successful
  // Result; a failed save is reported instead of being dereferenced.
  ASSERT_THAT(save_res, Not(IsError()));

  // Restore from that checkpoint.
  CheckpointOp restore_checkpoint_op;
  restore_checkpoint_op.mutable_saver_def()->set_restore_op_name("restore");
  restore_checkpoint_op.mutable_saver_def()->set_filename_tensor_name(
      "filename");
  EXPECT_THAT(
      sess.RestoreState(restore_checkpoint_op, save_res.GetValueOrDie()),
      Not(IsError()));

  // Verify the value of variable "c" was loaded properly from the checkpoint.
  CheckOutput<int32_t>(&sess, "c",
                       AsTensor<int32_t>({1, 2, 3, 4}, TensorShape({2, 2})));
}
212 
// SaveState() must not fail when the saver_def names the save target in tensor
// form ("save:0", with an output index) rather than as a bare op name
// ("save").
TEST(TfSessionTest, SaveCheckpointBytesSaveOpInTensorFormat) {
  // Construct a TensorFlow graph with all desired operations.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto a = Const<int32_t>(root, {{1, 2}, {3, 4}});
  // Save the current value of constant "a" in a serialized checkpoint.
  auto filename =
      Placeholder(root.WithOpName("filename"), tensorflow::DT_STRING);
  auto save_a = Save(root.WithOpName("save"), filename, {"a"},
                     std::initializer_list<tensorflow::Input>{a});

  // Run a session using the graph constructed above.
  TestTracingRecorder tracing_recorder;
  TfSession sess(testing::TempDir(), CreateGraph(&root));
  ASSERT_THAT(sess.Ready(), Not(IsError()));

  // Ensure that attempting to save doesn't return an error even if the save op
  // is provided in tensor format (with a trailing ":0").
  CheckpointOp save_checkpoint_op;
  save_checkpoint_op.mutable_saver_def()->set_save_tensor_name("save:0");
  save_checkpoint_op.mutable_saver_def()->set_filename_tensor_name("filename");
  Result<absl::Cord> save_res = sess.SaveState(save_checkpoint_op);
  EXPECT_THAT(save_res, Not(IsError()));
}
236 
TEST(TfSessionTest,SaveAndRestoreWithBeforeAndAfterOps)237 TEST(TfSessionTest, SaveAndRestoreWithBeforeAndAfterOps) {
238   // Construct a TensorFlow graph with all desired operations.
239   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
240   auto a = Variable(root.WithOpName("a"), {2, 2}, tensorflow::DT_INT32);
241   auto b = Variable(root, {1, 1}, tensorflow::DT_INT32);
242   auto init_a = Assign(root.WithOpName("init_a"), a,
243                        Const<int32_t>(root, {{1, 2}, {3, 4}}));
244   auto init_b =
245       Assign(root.WithOpName("init_b"), b, Const<int32_t>(root, {{2}}));
246   auto mul_a = Assign(root.WithOpName("mul_a"), a, Mul(root, a, b));
247   auto inc_b = Assign(root.WithOpName("inc_b"), b,
248                       Add(root, b, Const<int32_t>(root, {{1}})));
249   // Save the current value of variable "a" in a serialized checkpoint.
250   auto filename =
251       Placeholder(root.WithOpName("filename"), tensorflow::DT_STRING);
252   auto save_a = Save(root.WithOpName("save"), filename, {"a"},
253                      std::initializer_list<tensorflow::Input>{a});
254   // Restore the value saved in the serialized checkpoint to variable "a".
255   auto restore = Assign(root.WithOpName("restore"), a,
256                         Restore(root, filename, "a", tensorflow::DT_INT32));
257 
258   // Run a session using the graph constructed above.
259   TestTracingRecorder tracing_recorder;
260   TfSession sess(testing::TempDir(), CreateGraph(&root));
261   ASSERT_THAT(sess.Ready(), Not(IsError()));
262   EXPECT_THAT(sess.RunOp("init_a"), Not(IsError()));
263   EXPECT_THAT(sess.RunOp("init_b"), Not(IsError()));
264 
265   // Set "a = a * b" and save that value to a checkpoint, then reset "a" to its
266   // initial state.
267   CheckpointOp save_checkpoint_op;
268   save_checkpoint_op.set_before_save_op("mul_a");
269   save_checkpoint_op.mutable_saver_def()->set_save_tensor_name("save");
270   save_checkpoint_op.mutable_saver_def()->set_filename_tensor_name("filename");
271   save_checkpoint_op.set_after_save_op("init_a");
272   Result<absl::Cord> save_res = sess.SaveState(save_checkpoint_op);
273   EXPECT_THAT(save_res, Not(IsError()));
274   // Check that the value of variable "a" has been reset to the initial value by
275   // the after_save_op.
276   CheckOutput<int32_t>(&sess, "a",
277                        AsTensor<int32_t>({1, 2, 3, 4}, TensorShape({2, 2})));
278 
279   // Increment "b" to 3 in the before_restore_op, set "a" to the value from the
280   // checkpoint, then set "a = a * b".
281   CheckpointOp restore_checkpoint_op;
282   restore_checkpoint_op.set_before_restore_op("inc_b");
283   restore_checkpoint_op.mutable_saver_def()->set_restore_op_name("restore");
284   restore_checkpoint_op.mutable_saver_def()->set_filename_tensor_name(
285       "filename");
286   restore_checkpoint_op.set_after_restore_op("mul_a");
287   EXPECT_THAT(
288       sess.RestoreState(restore_checkpoint_op, save_res.GetValueOrDie()),
289       Not(IsError()));
290   // The initial value of "a" should have been multiplied by 2 in the
291   // before_save_op and multiplied by 3 in the after_restore_op.
292   CheckOutput<int32_t>(&sess, "a",
293                        AsTensor<int32_t>({6, 12, 18, 24}, TensorShape({2, 2})));
294 }
295 
296 }  // namespace fcp
297