1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h"
18 
19 #include <memory>
20 #include <string>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "absl/status/statusor.h"
25 #include "absl/strings/cord.h"
26 #include "absl/strings/str_cat.h"
27 #include "fcp/aggregation/core/mutable_vector_data.h"
28 #include "fcp/aggregation/core/tensor.h"
29 #include "fcp/aggregation/core/tensor_shape.h"
30 #include "fcp/aggregation/testing/test_data.h"
31 #include "fcp/aggregation/testing/testing.h"
32 #include "fcp/testing/testing.h"
33 
34 namespace fcp::aggregation::tensorflow {
35 namespace {
36 
37 using ::testing::AllOf;
38 using ::testing::Each;
39 using ::testing::Pair;
40 using ::testing::SizeIs;
41 using ::testing::StartsWith;
42 using ::testing::UnorderedElementsAre;
43 
// Writes two small float tensors into a single checkpoint and verifies that
// summarizing the resulting bytes yields exactly those two named tensors.
TEST(TensorflowCheckpointBuilderFactoryTest, BuildCheckpoint) {
  TensorflowCheckpointBuilderFactory factory;

  absl::StatusOr<Tensor> first_tensor = Tensor::Create(
      DT_FLOAT, TensorShape({4}), CreateTestData<float>({1.0, 2.0, 3.0, 4.0}));
  ASSERT_OK(first_tensor.status());
  absl::StatusOr<Tensor> second_tensor = Tensor::Create(
      DT_FLOAT, TensorShape({2}), CreateTestData<float>({5.0, 6.0}));
  ASSERT_OK(second_tensor.status());

  std::unique_ptr<CheckpointBuilder> checkpoint_builder = factory.Create();
  EXPECT_OK(checkpoint_builder->Add("t1", *first_tensor));
  EXPECT_OK(checkpoint_builder->Add("t2", *second_tensor));

  absl::StatusOr<absl::Cord> serialized = checkpoint_builder->Build();
  ASSERT_OK(serialized.status());

  auto contents = SummarizeCheckpoint(*serialized);
  ASSERT_OK(contents.status());
  EXPECT_THAT(*contents,
              UnorderedElementsAre(Pair("t1", "1 2 3 4"), Pair("t2", "5 6")));
}
64 
// Check that multiple checkpoints can be built simultaneously.
TEST(TensorflowCheckpointBuilderFactoryTest,SimultaneousWrites)66 TEST(TensorflowCheckpointBuilderFactoryTest, SimultaneousWrites) {
67   TensorflowCheckpointBuilderFactory factory;
68 
69   absl::StatusOr<Tensor> t1 = Tensor::Create(
70       DT_FLOAT, TensorShape({4}), CreateTestData<float>({1.0, 2.0, 3.0, 4.0}));
71   ASSERT_OK(t1.status());
72   absl::StatusOr<Tensor> t2 = Tensor::Create(DT_FLOAT, TensorShape({2}),
73                                              CreateTestData<float>({5.0, 6.0}));
74   ASSERT_OK(t2.status());
75 
76   std::unique_ptr<CheckpointBuilder> builder1 = factory.Create();
77   std::unique_ptr<CheckpointBuilder> builder2 = factory.Create();
78   EXPECT_OK(builder1->Add("t1", *t1));
79   EXPECT_OK(builder2->Add("t2", *t2));
80   absl::StatusOr<absl::Cord> checkpoint1 = builder1->Build();
81   ASSERT_OK(checkpoint1.status());
82   absl::StatusOr<absl::Cord> checkpoint2 = builder2->Build();
83   ASSERT_OK(checkpoint2.status());
84   auto summary1 = SummarizeCheckpoint(*checkpoint1);
85   ASSERT_OK(summary1.status());
86   EXPECT_THAT(*summary1, UnorderedElementsAre(Pair("t1", "1 2 3 4")));
87   auto summary2 = SummarizeCheckpoint(*checkpoint2);
88   ASSERT_OK(summary2.status());
89   EXPECT_THAT(*summary2, UnorderedElementsAre(Pair("t2", "5 6")));
90 }
91 
// Builds a checkpoint containing several multi-kilobyte tensors and verifies
// that every tensor is present in the summarized result.
TEST(TensorflowCheckpointBuilderFactoryTest, LargeCheckpoint) {
  TensorflowCheckpointBuilderFactory factory;
  std::unique_ptr<CheckpointBuilder> builder = factory.Create();

  // Add kNumTensors tensors that each require at least 8kB
  // (kTensorSize * sizeof(int64_t) = 8192 bytes) to exercise reading and
  // writing in multiple chunks.
  static constexpr int kNumTensors = 10;
  static constexpr int kTensorSize = 1024;
  absl::StatusOr<Tensor> t =
      Tensor::Create(DT_INT64, TensorShape({kTensorSize}),
                     std::make_unique<MutableVectorData<int64_t>>(kTensorSize));
  ASSERT_OK(t.status());
  for (int i = 0; i < kNumTensors; ++i) {
    // The size check below assumes every Add succeeded, so stop at the first
    // failure instead of cascading into confusing follow-up errors.
    ASSERT_OK(builder->Add(absl::StrCat("t", i), *t));
  }
  absl::StatusOr<absl::Cord> checkpoint = builder->Build();
  ASSERT_OK(checkpoint.status());
  auto summary = SummarizeCheckpoint(*checkpoint);
  ASSERT_OK(summary.status());
  // The expected "0 0 0 ..." prefix relies on MutableVectorData(kTensorSize)
  // producing zero-valued elements.
  EXPECT_THAT(*summary,
              AllOf(SizeIs(kNumTensors),
                    Each(Pair(StartsWith("t"),
                              StartsWith("0 0 0 0 0 0 0 0 0")))));
}
114 
115 }  // namespace
116 }  // namespace fcp::aggregation::tensorflow
117