1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/tensorflow/tf_session.h"
18
19 #include <string>
20
21 #include "gtest/gtest.h"
22 #include "fcp/base/tracing_schema.h"
23 #include "fcp/protos/plan.pb.h"
24 #include "fcp/tensorflow/testing/tf_helper.h"
25 #include "fcp/tensorflow/tracing_schema.h"
26 #include "fcp/testing/result_matchers.h"
27 #include "fcp/tracing/test_tracing_recorder.h"
28 #include "tensorflow/cc/framework/scope.h"
29 #include "tensorflow/cc/ops/math_ops.h"
30 #include "tensorflow/cc/ops/standard_ops.h"
31 #include "tensorflow/core/framework/tensor_testutil.h"
32 #include "tensorflow/core/protobuf/saver.pb.h"
33
34 namespace fcp {
35
36 using google::internal::federated::plan::CheckpointOp;
37 using tensorflow::Tensor;
38 using tensorflow::TensorShape;
39 using tensorflow::ops::Add;
40 using tensorflow::ops::Assign;
41 using tensorflow::ops::Const;
42 using tensorflow::ops::Mul;
43 using tensorflow::ops::Placeholder;
44 using tensorflow::ops::Restore;
45 using tensorflow::ops::Save;
46 using tensorflow::ops::Variable;
47 using tensorflow::test::AsTensor;
48 using tensorflow::test::ExpectTensorEqual;
49 using testing::_;
50 using testing::Not;
51
52 template <typename T>
CheckOutput(TfSession * sess,const std::string & output_op,Tensor expected)53 void CheckOutput(TfSession* sess, const std::string& output_op,
54 Tensor expected) {
55 Result<std::unique_ptr<TfSession::NamedTensorMap>> outputs =
56 sess->GetOutputs(std::make_unique<std::vector<std::string>>(
57 std::initializer_list<std::string>{output_op}));
58 EXPECT_THAT(outputs, Not(IsError()));
59 ExpectTensorEqual<T>((*outputs.GetValueOrDie())[output_op], expected);
60 }
61
// Checks that a session built from an empty graph becomes ready, accepts an
// empty op name and an empty output list without error, and returns an error
// (with the matching trace) when asked to run an op the graph does not have.
TEST(TfSessionTest, InitializeWithEmptyGraph) {
  auto scope = tensorflow::Scope::NewRootScope();
  TestTracingRecorder recorder;
  TfSession session("foo/bar", CreateGraph(&scope));
  ASSERT_THAT(session.Ready(), Not(IsError()));

  // An empty op name means there is nothing to run.
  EXPECT_THAT(session.RunOp(""), Not(IsError()));
  // Requesting zero outputs is likewise expected to succeed.
  EXPECT_THAT(
      session.GetOutputs(std::make_unique<std::vector<std::string>>()),
      Not(IsError()));

  // The GraphDef has no ops registered, so running "sum" must fail.
  recorder.ExpectError<ResultExpectStatusError>();
  EXPECT_THAT(session.RunOp("sum"), IsError());

  // Inspect the recorded span hierarchy. Exactly one RunTfOp span is
  // expected, since the empty-op call above should not be traced at all.
  EXPECT_THAT(recorder.root(),
              ElementsAre(AllOf(
                  IsSpan<RunTfOp>(),
                  ElementsAre(IsEvent<ResultExpectStatusError>(
                      static_cast<int>(fcp::OK),
                      static_cast<int>(fcp::INVALID_ARGUMENT), _, _, _)))));
}
86
// Checks that constructing a session from bytes that are not a valid
// serialized graph leaves the session in an error state, and that a
// ResultExpectStatusError event is recorded at the root of the trace.
TEST(TfSessionTest, InvalidGraphBytes) {
  // No tensorflow::Scope is created here: the graph bytes are deliberately
  // garbage, so there is no graph to build.
  TestTracingRecorder tracing_recorder;
  tracing_recorder.ExpectError<ResultExpectStatusError>();
  TfSession sess("foo/bar", "garbage");
  ASSERT_THAT(sess.Ready(), IsError());
  EXPECT_THAT(tracing_recorder.root(),
              ElementsAre(IsEvent<ResultExpectStatusError>(
                  static_cast<int>(fcp::OK),
                  static_cast<int>(fcp::INVALID_ARGUMENT), _, _, _)));
}
98
// Checks that running a named op updates session state and that the result
// can be read back as an output.
TEST(TfSessionTest, RunGraphOp) {
  // Build the graph: "assign_c" stores the product of two constants in the
  // variable "c", which can then be fetched as an output.
  auto scope = tensorflow::Scope::NewRootScope();
  auto lhs = Const<int32_t>(scope, {{1, 2}, {3, 4}});
  auto rhs = Const<int32_t>(scope, {{2}});
  auto var_c = Variable(scope.WithOpName("c"), {2, 2}, tensorflow::DT_INT32);
  auto product = Mul(scope, lhs, rhs);
  auto assign_op = Assign(scope.WithOpName("assign_c"), var_c, product);

  // Start a session on that graph.
  TestTracingRecorder recorder;
  TfSession session("foo/bar", CreateGraph(&scope));
  ASSERT_THAT(session.Ready(), Not(IsError()));

  // Execute the assignment, then confirm "c" holds every element doubled.
  EXPECT_THAT(session.RunOp("assign_c"), Not(IsError()));
  CheckOutput<int32_t>(&session, "c",
                       AsTensor<int32_t>({2, 4, 6, 8}, TensorShape({2, 2})));
}
119
// Checks that RestoreState can take its input from an in-memory tensor: the
// tensor is supplied under the placeholder name "i", "restore_a" assigns it to
// variable "a", and "double_a" then multiplies "a" by two.
TEST(TfSessionTest, RestoreFromTensor) {
  // Build the graph with the placeholder, the variable and both assign ops.
  auto scope = tensorflow::Scope::NewRootScope();
  auto placeholder_i =
      Placeholder(scope.WithOpName("i"), tensorflow::DT_INT32);
  auto var_a = Variable(scope.WithOpName("a"), {2, 2}, tensorflow::DT_INT32);
  auto restore_op =
      Assign(scope.WithOpName("restore_a"), var_a, placeholder_i);
  auto two = Const<int32_t>(scope, {{2}});
  auto double_op =
      Assign(scope.WithOpName("double_a"), var_a, Mul(scope, var_a, two));

  // Start a session on that graph.
  TestTracingRecorder recorder;
  TfSession session(testing::TempDir(), CreateGraph(&scope));
  ASSERT_THAT(session.Ready(), Not(IsError()));

  // Only before/after ops are configured; no SaverDef is set.
  CheckpointOp restore_config;
  restore_config.set_before_restore_op("restore_a");
  restore_config.set_after_restore_op("double_a");

  tensorflow::Input::Initializer initial_value({{1, 2}, {3, 4}});
  EXPECT_THAT(
      session.RestoreState(restore_config, {{"i", initial_value.tensor}}),
      Not(IsError()));

  // "a" should now hold the supplied values, doubled.
  CheckOutput<int32_t>(&session, "a",
                       AsTensor<int32_t>({2, 4, 6, 8}, TensorShape({2, 2})));
}
145
// Checks that restoring from an in-memory tensor fails when the CheckpointOp
// also carries a SaverDef, and that an InvalidCheckpointOp error is traced.
TEST(TfSessionTest, RestoreFromTensorNoSaverDefAllowed) {
  // Build the graph with the placeholder, the variable and both assign ops.
  auto scope = tensorflow::Scope::NewRootScope();
  auto placeholder_i =
      Placeholder(scope.WithOpName("i"), tensorflow::DT_INT32);
  auto var_a = Variable(scope, {2, 2}, tensorflow::DT_INT32);
  auto restore_op =
      Assign(scope.WithOpName("restore_a"), var_a, placeholder_i);
  auto two = Const<int32_t>(scope, {{2}});
  auto double_op =
      Assign(scope.WithOpName("double_a"), var_a, Mul(scope, var_a, two));

  // Start a session on that graph; the restore below is expected to trace an
  // InvalidCheckpointOp error.
  TestTracingRecorder recorder;
  recorder.ExpectError<InvalidCheckpointOp>();
  TfSession session(testing::TempDir(), CreateGraph(&scope));
  ASSERT_THAT(session.Ready(), Not(IsError()));

  // Unlike a plain tensor restore, this config also populates the SaverDef.
  CheckpointOp restore_config;
  restore_config.set_before_restore_op("restore_a");
  restore_config.mutable_saver_def()->set_restore_op_name("restore");
  restore_config.mutable_saver_def()->set_filename_tensor_name("filename");
  restore_config.set_after_restore_op("double_a");

  tensorflow::Input::Initializer initial_value({{1, 2}, {3, 4}});
  EXPECT_THAT(
      session.RestoreState(restore_config, {{"i", initial_value.tensor}}),
      IsError());
}
172
// Checks a full round trip through serialized checkpoint bytes: the constant
// "a" is saved with SaveState, and the resulting bytes are passed to
// RestoreState to populate variable "c".
TEST(TfSessionTest, SaveAndRestoreCheckpointBytes) {
  // Construct a TensorFlow graph with all desired operations.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto a = Const<int32_t>(root, {{1, 2}, {3, 4}});
  // Save the current value of constant "a" in a serialized checkpoint.
  auto filename =
      Placeholder(root.WithOpName("filename"), tensorflow::DT_STRING);
  auto save_a = Save(root.WithOpName("save"), filename, {"a"},
                     std::initializer_list<tensorflow::Input>{a});
  // Restore the value saved in the serialized checkpoint to variable "c".
  auto c = Variable(root.WithOpName("c"), {2, 2}, tensorflow::DT_INT32);
  auto restore = Assign(root.WithOpName("restore"), c,
                        Restore(root, filename, "a", tensorflow::DT_INT32));

  // Run a session using the graph constructed above.
  TestTracingRecorder tracing_recorder;
  TfSession sess(testing::TempDir(), CreateGraph(&root));
  ASSERT_THAT(sess.Ready(), Not(IsError()));

  // Save to a checkpoint.
  CheckpointOp save_checkpoint_op;
  save_checkpoint_op.mutable_saver_def()->set_save_tensor_name("save");
  save_checkpoint_op.mutable_saver_def()->set_filename_tensor_name("filename");
  Result<absl::Cord> save_res = sess.SaveState(save_checkpoint_op);
  // ASSERT rather than EXPECT: save_res.GetValueOrDie() is called below and
  // must not be reached if the save failed.
  ASSERT_THAT(save_res, Not(IsError()));

  // Restore from that checkpoint.
  CheckpointOp restore_checkpoint_op;
  restore_checkpoint_op.mutable_saver_def()->set_restore_op_name("restore");
  restore_checkpoint_op.mutable_saver_def()->set_filename_tensor_name(
      "filename");
  EXPECT_THAT(
      sess.RestoreState(restore_checkpoint_op, save_res.GetValueOrDie()),
      Not(IsError()));

  // Verify the value of variable "c" was loaded properly from the checkpoint.
  CheckOutput<int32_t>(&sess, "c",
                       AsTensor<int32_t>({1, 2, 3, 4}, TensorShape({2, 2})));
}
212
// Checks that SaveState accepts a save op given in tensor format, i.e. the
// op name followed by ":0".
TEST(TfSessionTest, SaveCheckpointBytesSaveOpInTensorFormat) {
  // Build a graph whose "save" op writes the constant "a" to the checkpoint
  // file named by the "filename" placeholder.
  auto scope = tensorflow::Scope::NewRootScope();
  auto const_a = Const<int32_t>(scope, {{1, 2}, {3, 4}});
  auto filename_input =
      Placeholder(scope.WithOpName("filename"), tensorflow::DT_STRING);
  auto save_op = Save(scope.WithOpName("save"), filename_input, {"a"},
                      std::initializer_list<tensorflow::Input>{const_a});

  // Start a session on that graph.
  TestTracingRecorder recorder;
  TfSession session(testing::TempDir(), CreateGraph(&scope));
  ASSERT_THAT(session.Ready(), Not(IsError()));

  // Refer to the save op as "save:0" rather than "save"; saving should still
  // succeed.
  CheckpointOp save_config;
  save_config.mutable_saver_def()->set_save_tensor_name("save:0");
  save_config.mutable_saver_def()->set_filename_tensor_name("filename");
  Result<absl::Cord> saved_bytes = session.SaveState(save_config);
  EXPECT_THAT(saved_bytes, Not(IsError()));
}
236
// Checks that the before/after ops configured on a CheckpointOp are run
// around both SaveState and RestoreState, by observing their combined effect
// on variable "a".
TEST(TfSessionTest, SaveAndRestoreWithBeforeAndAfterOps) {
  // Construct a TensorFlow graph with all desired operations.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto a = Variable(root.WithOpName("a"), {2, 2}, tensorflow::DT_INT32);
  auto b = Variable(root, {1, 1}, tensorflow::DT_INT32);
  auto init_a = Assign(root.WithOpName("init_a"), a,
                       Const<int32_t>(root, {{1, 2}, {3, 4}}));
  auto init_b =
      Assign(root.WithOpName("init_b"), b, Const<int32_t>(root, {{2}}));
  auto mul_a = Assign(root.WithOpName("mul_a"), a, Mul(root, a, b));
  auto inc_b = Assign(root.WithOpName("inc_b"), b,
                      Add(root, b, Const<int32_t>(root, {{1}})));
  // Save the current value of variable "a" in a serialized checkpoint.
  auto filename =
      Placeholder(root.WithOpName("filename"), tensorflow::DT_STRING);
  auto save_a = Save(root.WithOpName("save"), filename, {"a"},
                     std::initializer_list<tensorflow::Input>{a});
  // Restore the value saved in the serialized checkpoint to variable "a".
  auto restore = Assign(root.WithOpName("restore"), a,
                        Restore(root, filename, "a", tensorflow::DT_INT32));

  // Run a session using the graph constructed above.
  TestTracingRecorder tracing_recorder;
  TfSession sess(testing::TempDir(), CreateGraph(&root));
  ASSERT_THAT(sess.Ready(), Not(IsError()));
  EXPECT_THAT(sess.RunOp("init_a"), Not(IsError()));
  EXPECT_THAT(sess.RunOp("init_b"), Not(IsError()));

  // Set "a = a * b" and save that value to a checkpoint, then reset "a" to its
  // initial state.
  CheckpointOp save_checkpoint_op;
  save_checkpoint_op.set_before_save_op("mul_a");
  save_checkpoint_op.mutable_saver_def()->set_save_tensor_name("save");
  save_checkpoint_op.mutable_saver_def()->set_filename_tensor_name("filename");
  save_checkpoint_op.set_after_save_op("init_a");
  Result<absl::Cord> save_res = sess.SaveState(save_checkpoint_op);
  // ASSERT rather than EXPECT: save_res.GetValueOrDie() is called below and
  // must not be reached if the save failed.
  ASSERT_THAT(save_res, Not(IsError()));
  // Check that the value of variable "a" has been reset to the initial value by
  // the after_save_op.
  CheckOutput<int32_t>(&sess, "a",
                       AsTensor<int32_t>({1, 2, 3, 4}, TensorShape({2, 2})));

  // Increment "b" to 3 in the before_restore_op, set "a" to the value from the
  // checkpoint, then set "a = a * b".
  CheckpointOp restore_checkpoint_op;
  restore_checkpoint_op.set_before_restore_op("inc_b");
  restore_checkpoint_op.mutable_saver_def()->set_restore_op_name("restore");
  restore_checkpoint_op.mutable_saver_def()->set_filename_tensor_name(
      "filename");
  restore_checkpoint_op.set_after_restore_op("mul_a");
  EXPECT_THAT(
      sess.RestoreState(restore_checkpoint_op, save_res.GetValueOrDie()),
      Not(IsError()));
  // The initial value of "a" should have been multiplied by 2 in the
  // before_save_op and multiplied by 3 in the after_restore_op.
  CheckOutput<int32_t>(&sess, "a",
                       AsTensor<int32_t>({6, 12, 18, 24}, TensorShape({2, 2})));
}
295
296 } // namespace fcp
297