1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h"
18
19 #include <memory>
20 #include <string>
21
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "absl/status/statusor.h"
25 #include "absl/strings/cord.h"
26 #include "absl/strings/str_cat.h"
27 #include "fcp/aggregation/core/mutable_vector_data.h"
28 #include "fcp/aggregation/core/tensor.h"
29 #include "fcp/aggregation/core/tensor_shape.h"
30 #include "fcp/aggregation/testing/test_data.h"
31 #include "fcp/aggregation/testing/testing.h"
32 #include "fcp/testing/testing.h"
33
34 namespace fcp::aggregation::tensorflow {
35 namespace {
36
37 using ::testing::AllOf;
38 using ::testing::Each;
39 using ::testing::Pair;
40 using ::testing::SizeIs;
41 using ::testing::StartsWith;
42 using ::testing::UnorderedElementsAre;
43
// Builds a single checkpoint from two float tensors and verifies that the
// summarized checkpoint lists exactly those two tensors with their values.
TEST(TensorflowCheckpointBuilderFactoryTest, BuildCheckpoint) {
  TensorflowCheckpointBuilderFactory builder_factory;
  std::unique_ptr<CheckpointBuilder> checkpoint_builder =
      builder_factory.Create();

  // Input tensors: a 4-element and a 2-element float vector.
  absl::StatusOr<Tensor> four_floats = Tensor::Create(
      DT_FLOAT, TensorShape({4}), CreateTestData<float>({1.0, 2.0, 3.0, 4.0}));
  ASSERT_OK(four_floats.status());
  absl::StatusOr<Tensor> two_floats = Tensor::Create(
      DT_FLOAT, TensorShape({2}), CreateTestData<float>({5.0, 6.0}));
  ASSERT_OK(two_floats.status());

  // Register both tensors under distinct names, then serialize.
  EXPECT_OK(checkpoint_builder->Add("t1", *four_floats));
  EXPECT_OK(checkpoint_builder->Add("t2", *two_floats));
  absl::StatusOr<absl::Cord> serialized = checkpoint_builder->Build();
  ASSERT_OK(serialized.status());

  // The summary must contain both entries, in any order.
  auto contents = SummarizeCheckpoint(*serialized);
  ASSERT_OK(contents.status());
  const auto has_both_tensors =
      UnorderedElementsAre(Pair("t1", "1 2 3 4"), Pair("t2", "5 6"));
  EXPECT_THAT(*contents, has_both_tensors);
}
64
// Check that multiple checkpoints can be built simultaneously.
TEST(TensorflowCheckpointBuilderFactoryTest, SimultaneousWrites) {
  TensorflowCheckpointBuilderFactory builder_factory;

  // One tensor per builder: a 4-element and a 2-element float vector.
  absl::StatusOr<Tensor> four_floats = Tensor::Create(
      DT_FLOAT, TensorShape({4}), CreateTestData<float>({1.0, 2.0, 3.0, 4.0}));
  ASSERT_OK(four_floats.status());
  absl::StatusOr<Tensor> two_floats = Tensor::Create(
      DT_FLOAT, TensorShape({2}), CreateTestData<float>({5.0, 6.0}));
  ASSERT_OK(two_floats.status());

  // Both builders are alive at the same time; each receives a different
  // tensor before either checkpoint is built.
  std::unique_ptr<CheckpointBuilder> first_builder = builder_factory.Create();
  std::unique_ptr<CheckpointBuilder> second_builder = builder_factory.Create();
  EXPECT_OK(first_builder->Add("t1", *four_floats));
  EXPECT_OK(second_builder->Add("t2", *two_floats));

  absl::StatusOr<absl::Cord> first_serialized = first_builder->Build();
  ASSERT_OK(first_serialized.status());
  absl::StatusOr<absl::Cord> second_serialized = second_builder->Build();
  ASSERT_OK(second_serialized.status());

  // Each checkpoint must contain only the tensor added to its own builder.
  auto first_contents = SummarizeCheckpoint(*first_serialized);
  ASSERT_OK(first_contents.status());
  EXPECT_THAT(*first_contents, UnorderedElementsAre(Pair("t1", "1 2 3 4")));

  auto second_contents = SummarizeCheckpoint(*second_serialized);
  ASSERT_OK(second_contents.status());
  EXPECT_THAT(*second_contents, UnorderedElementsAre(Pair("t2", "5 6")));
}
91
// Builds a checkpoint from several large int64 tensors and verifies that
// every one of them appears in the summarized result.
TEST(TensorflowCheckpointBuilderFactoryTest, LargeCheckpoint) {
  TensorflowCheckpointBuilderFactory builder_factory;
  std::unique_ptr<CheckpointBuilder> checkpoint_builder =
      builder_factory.Create();

  // Each tensor holds kTensorSize int64 values (1024 * 8 bytes = 8kB), and
  // kNumTensors of them are added, to exercise reading and writing in
  // multiple chunks.
  static constexpr int kTensorSize = 1024;
  static constexpr int kNumTensors = 10;
  absl::StatusOr<Tensor> large_tensor =
      Tensor::Create(DT_INT64, TensorShape({kTensorSize}),
                     std::make_unique<MutableVectorData<int64_t>>(kTensorSize));
  ASSERT_OK(large_tensor.status());

  // The same tensor is added repeatedly under the names "t0" .. "t9".
  for (int index = 0; index < kNumTensors; ++index) {
    EXPECT_OK(checkpoint_builder->Add(absl::StrCat("t", index), *large_tensor));
  }

  absl::StatusOr<absl::Cord> serialized = checkpoint_builder->Build();
  ASSERT_OK(serialized.status());
  auto contents = SummarizeCheckpoint(*serialized);
  ASSERT_OK(contents.status());

  // Every entry is named "t<something>" and its summary begins with zeros.
  const auto is_zero_filled_entry =
      Pair(StartsWith("t"), StartsWith("0 0 0 0 0 0 0 0 0"));
  EXPECT_THAT(*contents, AllOf(SizeIs(kNumTensors), Each(is_zero_filled_entry)));
}
114
115 } // namespace
116 } // namespace fcp::aggregation::tensorflow
117