xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tflite_plan_engine_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 // fcp:google3-internal-file
17 #include "fcp/client/engine/tflite_plan_engine.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "absl/status/statusor.h"
26 #include "fcp/client/client_runner.h"
27 #include "fcp/client/diag_codes.pb.h"
28 #include "fcp/client/opstats/opstats_example_store.h"
29 #include "fcp/client/test_helpers.h"
30 #include "fcp/testing/testing.h"
31 #include "gmock/gmock.h"
32 #include "gtest/gtest.h"
33 #include "tensorflow/core/framework/tensor.pb.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/protobuf/struct.pb.h"
36 
37 namespace fcp {
38 namespace client {
39 namespace engine {
40 namespace {
41 using ::fcp::client::opstats::OpStatsSequence;
42 using ::google::internal::federated::plan::ClientOnlyPlan;
43 using ::google::internal::federated::plan::Dataset;
44 using ::google::internal::federated::plan::FederatedComputeEligibilityIORouter;
45 using ::google::internal::federated::plan::FederatedComputeIORouter;
46 using ::google::internal::federated::plan::LocalComputeIORouter;
47 using ::testing::Gt;
48 using ::testing::InSequence;
49 using ::testing::Invoke;
50 using ::testing::IsEmpty;
51 using ::testing::NiceMock;
52 using ::testing::Return;
53 using ::testing::StrictMock;
54 
55 // We turn formatting off to prevent line breaks, which ensures that these paths
56 // are more easily code searchable.
57 // clang-format off
58 constexpr absl::string_view kArtifactPrefix =
59     "intelligence/brella/testing/tasks/mnist/simpleagg_mnist_training_tflite_task_artifacts"; // NOLINT
60 constexpr absl::string_view kEligibilityPlanArtifactPrefix =
61     "intelligence/brella/testing/tasks/eligibility_eval/eligibility_eval_tflite_task_artifacts"; // NOLINT
62 constexpr absl::string_view kSecaggArtifactPrefix =
63     "intelligence/brella/testing/tasks/secagg_only_tflite_task_artifacts";
64 constexpr absl::string_view kLcArtifactPrefix =
65     "intelligence/brella/testing/local_computation/mnist_tflite_personalization_artifacts"; // NOLINT
66 constexpr absl::string_view kLcInitialCheckpoint =
67     "intelligence/brella/testing/local_computation/initial.ckpt";
68 constexpr absl::string_view kConstantInputsArtifactPrefix =
69     "intelligence/brella/testing/tasks/mnist/simpleagg_constant_tflite_inputs_task_artifacts"; // NOLINT
70 // clang-format on
71 
72 const char* const kCollectionUri = "app:/test_collection";
73 const char* const kEligibilityEvalCollectionUri =
74     "app:/test_eligibility_eval_collection";
75 const char* const kLcTrainCollectionUri = "app:/p13n_train_collection";
76 const char* const kLcTestCollectionUri = "app:/p13n_test_collection";
77 
78 // Parameterized with whether per_phase_logs should be used.
79 class TfLitePlanEngineTest : public testing::Test {
80  protected:
SetUp()81   void SetUp() override {
82     EXPECT_CALL(mock_opstats_logger_, IsOpStatsEnabled())
83         .WillRepeatedly(Return(true));
84     EXPECT_CALL(mock_opstats_logger_, GetOpStatsDb())
85         .WillRepeatedly(Return(&mock_opstats_db_));
86     EXPECT_CALL(mock_opstats_db_, Read())
87         .WillRepeatedly(Return(OpStatsSequence::default_instance()));
88     EXPECT_CALL(mock_flags_, ensure_dynamic_tensors_are_released())
89         .WillRepeatedly(Return(true));
90     EXPECT_CALL(mock_flags_, large_tensor_threshold_for_dynamic_allocation())
91         .WillRepeatedly(Return(1000));
92     EXPECT_CALL(mock_flags_, num_threads_for_tflite())
93         .WillRepeatedly(Return(4));
94     EXPECT_CALL(mock_flags_, disable_tflite_delegate_clustering())
95         .WillRepeatedly(Return(false));
96     EXPECT_CALL(mock_flags_, support_constant_tf_inputs())
97         .WillRepeatedly(Return(false));
98   }
99 
InitializeFlTask(absl::string_view prefix)100   void InitializeFlTask(absl::string_view prefix) {
101     LoadArtifacts();
102 
103     example_iterator_factory_ =
104         std::make_unique<FunctionalExampleIteratorFactory>(
105             [&dataset = dataset_](
106                 const google::internal::federated::plan::ExampleSelector&
107                     selector) {
108               return std::make_unique<::fcp::client::SimpleExampleIterator>(
109                   dataset);
110             });
111 
112     // Compute dataset stats.
113     for (const Dataset::ClientDataset& client_dataset :
114          dataset_.client_data()) {
115       num_examples_ += client_dataset.example_size();
116       for (const std::string& example : client_dataset.example()) {
117         example_bytes_ += example.size();
118       }
119     }
120     // The single session FL plan specifies both input and output filepaths in
121     // its FederatedComputeIORouter.
122     FederatedComputeIORouter io_router =
123         client_only_plan_.phase().federated_compute();
124     if (!io_router.input_filepath_tensor_name().empty()) {
125       (*inputs_)[io_router.input_filepath_tensor_name()] = checkpoint_input_fd_;
126     }
127     checkpoint_output_filename_ =
128         files_impl_.CreateTempFile("output", ".ckp").value();
129     ASSERT_EQ(std::filesystem::file_size(checkpoint_output_filename_), 0);
130     int fd = open(checkpoint_output_filename_.c_str(), O_WRONLY);
131     ASSERT_NE(-1, fd);
132     checkpoint_output_fd_ = absl::StrCat("fd:///", fd);
133     if (!io_router.output_filepath_tensor_name().empty()) {
134       (*inputs_)[io_router.output_filepath_tensor_name()] =
135           checkpoint_output_fd_;
136     }
137 
138     for (const auto& tensor_spec :
139          client_only_plan_.phase().tensorflow_spec().output_tensor_specs()) {
140       output_names_.push_back(tensor_spec.name());
141     }
142   }
143 
LoadArtifacts()144   void LoadArtifacts() {
145     absl::StatusOr<::fcp::client::ComputationArtifacts> artifacts =
146         ::fcp::client::LoadFlArtifacts();
147     EXPECT_TRUE(artifacts.ok());
148     client_only_plan_ = std::move(artifacts->plan);
149     dataset_ = std::move(artifacts->dataset);
150     checkpoint_input_filename_ = artifacts->checkpoint_filepath;
151     int fd = open(checkpoint_input_filename_.c_str(), O_RDONLY);
152     ASSERT_NE(-1, fd);
153     checkpoint_input_fd_ = absl::StrCat("fd:///", fd);
154   }
155 
ComputeDatasetStats(const std::string & collection_uri)156   void ComputeDatasetStats(const std::string& collection_uri) {
157     for (const Dataset::ClientDataset& client_dataset :
158          dataset_.client_data()) {
159       for (const Dataset::ClientDataset::SelectedExample& selected_example :
160            client_dataset.selected_example()) {
161         if (selected_example.selector().collection_uri() != collection_uri) {
162           continue;
163         }
164         num_examples_ += selected_example.example_size();
165         for (const auto& example : selected_example.example()) {
166           example_bytes_ += example.size();
167         }
168       }
169     }
170   }
171 
172   fcp::client::FilesImpl files_impl_;
173   StrictMock<MockLogManager> mock_log_manager_;
174   StrictMock<MockOpStatsLogger> mock_opstats_logger_;
175   StrictMock<MockOpStatsDb> mock_opstats_db_;
176   StrictMock<MockFlags> mock_flags_;
177   std::unique_ptr<ExampleIteratorFactory> example_iterator_factory_;
178   // Never abort, by default.
__anon1c077f190302() 179   std::function<bool()> should_abort_ = []() { return false; };
180 
181   ClientOnlyPlan client_only_plan_;
182   Dataset dataset_;
183   std::string checkpoint_input_filename_;
184   std::string checkpoint_output_filename_;
185   std::string checkpoint_input_fd_;
186   std::string checkpoint_output_fd_;
187 
188   int num_examples_ = 0;
189   int example_bytes_ = 0;
190   std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs_ =
191       std::make_unique<absl::flat_hash_map<std::string, std::string>>();
192   std::vector<std::string> output_names_;
193 
194   fcp::client::InterruptibleRunner::TimingConfig timing_config_ = {
195       // Use 10 ms to make the polling faster, otherwise the Abort test might
196       // fail because the plan finishes before interruption.
197       .polling_period = absl::Milliseconds(10),
198       .graceful_shutdown_period = absl::Milliseconds(1000),
199       .extended_shutdown_period = absl::Milliseconds(2000),
200   };
201 };
202 
// Runs the simple-aggregation plan through TfLitePlanEngine and checks that it
// succeeds, returns no output tensors, reports the dataset's example stats and
// leaves a non-empty output checkpoint file behind.
TEST_F(TfLitePlanEngineTest, SimpleAggPlanSucceeds) {
  // InitializeFlTask() uses ASSERT_* internally; a fatal failure there only
  // returns from the helper, so stop here rather than run the plan against a
  // half-initialized fixture.
  ASSERT_NO_FATAL_FAILURE(InitializeFlTask(kArtifactPrefix));

  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::BACKGROUND_TRAINING_TFLITE_ENGINE_USED));

  EXPECT_CALL(
      mock_opstats_logger_,
      UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_));

  TfLitePlanEngine plan_engine({example_iterator_factory_.get()}, should_abort_,
                               &mock_log_manager_, &mock_opstats_logger_,
                               &mock_flags_, &timing_config_);
  engine::PlanResult result = plan_engine.RunPlan(
      client_only_plan_.phase().tensorflow_spec(),
      client_only_plan_.tflite_graph(), std::move(inputs_), output_names_);

  EXPECT_EQ(result.outcome, PlanOutcome::kSuccess);
  // The plan is expected to produce no output tensors or names; the only
  // artifact checked is the checkpoint file below.
  EXPECT_EQ(result.output_tensors.size(), 0);
  EXPECT_EQ(result.output_names.size(), 0);
  EXPECT_EQ(result.example_stats.example_count, num_examples_);
  EXPECT_EQ(result.example_stats.example_size_bytes, example_bytes_);
  EXPECT_GT(std::filesystem::file_size(checkpoint_output_filename_), 0);
}
227 }  // namespace
228 }  // namespace engine
229 }  // namespace client
230 }  // namespace fcp
231