xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/tensorflow/checkpoint_writer_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/tensorflow/checkpoint_writer.h"
18 
19 #include <cstdint>
20 #include <utility>
21 #include <vector>
22 
23 #include "gmock/gmock.h"
24 #include "gtest/gtest.h"
25 #include "fcp/aggregation/core/datatype.h"
26 #include "fcp/aggregation/core/tensor.h"
27 #include "fcp/aggregation/core/tensor_shape.h"
28 #include "fcp/aggregation/tensorflow/checkpoint_reader.h"
29 #include "fcp/aggregation/testing/test_data.h"
30 #include "fcp/aggregation/testing/testing.h"
31 #include "fcp/testing/testing.h"
32 
33 namespace fcp::aggregation::tensorflow {
34 namespace {
35 
36 using ::testing::Key;
37 using ::testing::UnorderedElementsAre;
38 
// Round-trip test: writes a float, an int32 and a string tensor to a temporary
// checkpoint file with CheckpointWriter, then reads the file back with
// CheckpointReader and compares the entries against what was written.
TEST(CheckpointWriterTest, WriteTensors) {
  // Write the checkpoint using Aggregation Core checkpoint writer.
  auto temp_filename = TemporaryTestFile(".ckpt");

  auto t1 = Tensor::Create(DT_FLOAT, TensorShape({4}),
                           CreateTestData<float>({1.0, 2.0, 3.0, 4.0}))
                .value();
  auto t2 = Tensor::Create(DT_INT32, TensorShape({2, 3}),
                           CreateTestData<int32_t>({11, 12, 13, 14, 15, 16}))
                .value();
  auto t3 =
      Tensor::Create(
          DT_STRING, TensorShape({3}),
          CreateTestData<string_view>({"foo", "bar", "bazzzzzzzzzzzzzzzzzzz"}))
          .value();

  CheckpointWriter checkpoint_writer(temp_filename);
  EXPECT_OK(checkpoint_writer.Add("a", t1));
  EXPECT_OK(checkpoint_writer.Add("b", t2));
  EXPECT_OK(checkpoint_writer.Add("c", t3));
  // Fatal on failure: everything below reads the file back, which is
  // pointless if the writer could not finish it.
  ASSERT_OK(checkpoint_writer.Finish());

  // Read the checkpoint using the Aggregation Core checkpoint reader.
  auto checkpoint_reader_or_status = CheckpointReader::Create(temp_filename);
  // Fatal on failure: the value is extracted unconditionally on the next
  // statement, so the test must stop here instead of touching a failed result.
  ASSERT_OK(checkpoint_reader_or_status.status());

  auto checkpoint_reader = std::move(checkpoint_reader_or_status).value();
  // Both maps must contain exactly the three names that were added.
  EXPECT_THAT(checkpoint_reader->GetDataTypeMap(),
              UnorderedElementsAre(Key("a"), Key("b"), Key("c")));
  EXPECT_THAT(checkpoint_reader->GetTensorShapeMap(),
              UnorderedElementsAre(Key("a"), Key("b"), Key("c")));

  // Read and verify the tensors.
  EXPECT_THAT(*checkpoint_reader->GetTensor("a"),
              IsTensor<float>({4}, {1.0, 2.0, 3.0, 4.0}));
  EXPECT_THAT(*checkpoint_reader->GetTensor("b"),
              IsTensor<int32_t>({2, 3}, {11, 12, 13, 14, 15, 16}));
  EXPECT_THAT(
      *checkpoint_reader->GetTensor("c"),
      IsTensor<string_view>({3}, {"foo", "bar", "bazzzzzzzzzzzzzzzzzzz"}));
}
80 
81 }  // namespace
82 }  // namespace fcp::aggregation::tensorflow
83