1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 // fcp:google3-internal-file
#include "fcp/client/engine/tflite_plan_engine.h"

#include <fcntl.h>

#include <algorithm>
#include <filesystem>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "fcp/client/client_runner.h"
#include "fcp/client/diag_codes.pb.h"
#include "fcp/client/opstats/opstats_example_store.h"
#include "fcp/client/test_helpers.h"
#include "fcp/testing/testing.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
36
37 namespace fcp {
38 namespace client {
39 namespace engine {
40 namespace {
41 using ::fcp::client::opstats::OpStatsSequence;
42 using ::google::internal::federated::plan::ClientOnlyPlan;
43 using ::google::internal::federated::plan::Dataset;
44 using ::google::internal::federated::plan::FederatedComputeEligibilityIORouter;
45 using ::google::internal::federated::plan::FederatedComputeIORouter;
46 using ::google::internal::federated::plan::LocalComputeIORouter;
47 using ::testing::Gt;
48 using ::testing::InSequence;
49 using ::testing::Invoke;
50 using ::testing::IsEmpty;
51 using ::testing::NiceMock;
52 using ::testing::Return;
53 using ::testing::StrictMock;
54
55 // We turn formatting off to prevent line breaks, which ensures that these paths
56 // are more easily code searchable.
57 // clang-format off
58 constexpr absl::string_view kArtifactPrefix =
59 "intelligence/brella/testing/tasks/mnist/simpleagg_mnist_training_tflite_task_artifacts"; // NOLINT
60 constexpr absl::string_view kEligibilityPlanArtifactPrefix =
61 "intelligence/brella/testing/tasks/eligibility_eval/eligibility_eval_tflite_task_artifacts"; // NOLINT
62 constexpr absl::string_view kSecaggArtifactPrefix =
63 "intelligence/brella/testing/tasks/secagg_only_tflite_task_artifacts";
64 constexpr absl::string_view kLcArtifactPrefix =
65 "intelligence/brella/testing/local_computation/mnist_tflite_personalization_artifacts"; // NOLINT
66 constexpr absl::string_view kLcInitialCheckpoint =
67 "intelligence/brella/testing/local_computation/initial.ckpt";
68 constexpr absl::string_view kConstantInputsArtifactPrefix =
69 "intelligence/brella/testing/tasks/mnist/simpleagg_constant_tflite_inputs_task_artifacts"; // NOLINT
70 // clang-format on
71
// Example collection URIs. kCollectionUri is the collection for which the
// SimpleAgg test expects dataset stats to be reported.
constexpr const char* kCollectionUri = "app:/test_collection";
constexpr const char* kEligibilityEvalCollectionUri =
    "app:/test_eligibility_eval_collection";
constexpr const char* kLcTrainCollectionUri = "app:/p13n_train_collection";
constexpr const char* kLcTestCollectionUri = "app:/p13n_test_collection";
77
78 // Parameterized with whether per_phase_logs should be used.
79 class TfLitePlanEngineTest : public testing::Test {
80 protected:
SetUp()81 void SetUp() override {
82 EXPECT_CALL(mock_opstats_logger_, IsOpStatsEnabled())
83 .WillRepeatedly(Return(true));
84 EXPECT_CALL(mock_opstats_logger_, GetOpStatsDb())
85 .WillRepeatedly(Return(&mock_opstats_db_));
86 EXPECT_CALL(mock_opstats_db_, Read())
87 .WillRepeatedly(Return(OpStatsSequence::default_instance()));
88 EXPECT_CALL(mock_flags_, ensure_dynamic_tensors_are_released())
89 .WillRepeatedly(Return(true));
90 EXPECT_CALL(mock_flags_, large_tensor_threshold_for_dynamic_allocation())
91 .WillRepeatedly(Return(1000));
92 EXPECT_CALL(mock_flags_, num_threads_for_tflite())
93 .WillRepeatedly(Return(4));
94 EXPECT_CALL(mock_flags_, disable_tflite_delegate_clustering())
95 .WillRepeatedly(Return(false));
96 EXPECT_CALL(mock_flags_, support_constant_tf_inputs())
97 .WillRepeatedly(Return(false));
98 }
99
InitializeFlTask(absl::string_view prefix)100 void InitializeFlTask(absl::string_view prefix) {
101 LoadArtifacts();
102
103 example_iterator_factory_ =
104 std::make_unique<FunctionalExampleIteratorFactory>(
105 [&dataset = dataset_](
106 const google::internal::federated::plan::ExampleSelector&
107 selector) {
108 return std::make_unique<::fcp::client::SimpleExampleIterator>(
109 dataset);
110 });
111
112 // Compute dataset stats.
113 for (const Dataset::ClientDataset& client_dataset :
114 dataset_.client_data()) {
115 num_examples_ += client_dataset.example_size();
116 for (const std::string& example : client_dataset.example()) {
117 example_bytes_ += example.size();
118 }
119 }
120 // The single session FL plan specifies both input and output filepaths in
121 // its FederatedComputeIORouter.
122 FederatedComputeIORouter io_router =
123 client_only_plan_.phase().federated_compute();
124 if (!io_router.input_filepath_tensor_name().empty()) {
125 (*inputs_)[io_router.input_filepath_tensor_name()] = checkpoint_input_fd_;
126 }
127 checkpoint_output_filename_ =
128 files_impl_.CreateTempFile("output", ".ckp").value();
129 ASSERT_EQ(std::filesystem::file_size(checkpoint_output_filename_), 0);
130 int fd = open(checkpoint_output_filename_.c_str(), O_WRONLY);
131 ASSERT_NE(-1, fd);
132 checkpoint_output_fd_ = absl::StrCat("fd:///", fd);
133 if (!io_router.output_filepath_tensor_name().empty()) {
134 (*inputs_)[io_router.output_filepath_tensor_name()] =
135 checkpoint_output_fd_;
136 }
137
138 for (const auto& tensor_spec :
139 client_only_plan_.phase().tensorflow_spec().output_tensor_specs()) {
140 output_names_.push_back(tensor_spec.name());
141 }
142 }
143
LoadArtifacts()144 void LoadArtifacts() {
145 absl::StatusOr<::fcp::client::ComputationArtifacts> artifacts =
146 ::fcp::client::LoadFlArtifacts();
147 EXPECT_TRUE(artifacts.ok());
148 client_only_plan_ = std::move(artifacts->plan);
149 dataset_ = std::move(artifacts->dataset);
150 checkpoint_input_filename_ = artifacts->checkpoint_filepath;
151 int fd = open(checkpoint_input_filename_.c_str(), O_RDONLY);
152 ASSERT_NE(-1, fd);
153 checkpoint_input_fd_ = absl::StrCat("fd:///", fd);
154 }
155
ComputeDatasetStats(const std::string & collection_uri)156 void ComputeDatasetStats(const std::string& collection_uri) {
157 for (const Dataset::ClientDataset& client_dataset :
158 dataset_.client_data()) {
159 for (const Dataset::ClientDataset::SelectedExample& selected_example :
160 client_dataset.selected_example()) {
161 if (selected_example.selector().collection_uri() != collection_uri) {
162 continue;
163 }
164 num_examples_ += selected_example.example_size();
165 for (const auto& example : selected_example.example()) {
166 example_bytes_ += example.size();
167 }
168 }
169 }
170 }
171
172 fcp::client::FilesImpl files_impl_;
173 StrictMock<MockLogManager> mock_log_manager_;
174 StrictMock<MockOpStatsLogger> mock_opstats_logger_;
175 StrictMock<MockOpStatsDb> mock_opstats_db_;
176 StrictMock<MockFlags> mock_flags_;
177 std::unique_ptr<ExampleIteratorFactory> example_iterator_factory_;
178 // Never abort, by default.
__anon1c077f190302() 179 std::function<bool()> should_abort_ = []() { return false; };
180
181 ClientOnlyPlan client_only_plan_;
182 Dataset dataset_;
183 std::string checkpoint_input_filename_;
184 std::string checkpoint_output_filename_;
185 std::string checkpoint_input_fd_;
186 std::string checkpoint_output_fd_;
187
188 int num_examples_ = 0;
189 int example_bytes_ = 0;
190 std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs_ =
191 std::make_unique<absl::flat_hash_map<std::string, std::string>>();
192 std::vector<std::string> output_names_;
193
194 fcp::client::InterruptibleRunner::TimingConfig timing_config_ = {
195 // Use 10 ms to make the polling faster, otherwise the Abort test might
196 // fail because the plan finishes before interruption.
197 .polling_period = absl::Milliseconds(10),
198 .graceful_shutdown_period = absl::Milliseconds(1000),
199 .extended_shutdown_period = absl::Milliseconds(2000),
200 };
201 };
202
// Runs the simple-aggregation plan through TfLitePlanEngine and checks that
// it succeeds, reports the dataset stats computed by the fixture, returns no
// output tensors or names, and writes a non-empty output checkpoint.
TEST_F(TfLitePlanEngineTest, SimpleAggPlanSucceeds) {
  // InitializeFlTask() signals setup problems with ASSERT_*, which only
  // returns from the helper. Bail out here too rather than running the plan
  // on a partially initialized fixture.
  ASSERT_NO_FATAL_FAILURE(InitializeFlTask(kArtifactPrefix));

  EXPECT_CALL(mock_log_manager_,
              LogDiag(ProdDiagCode::BACKGROUND_TRAINING_TFLITE_ENGINE_USED));

  // The expected stats are the totals accumulated by InitializeFlTask().
  EXPECT_CALL(
      mock_opstats_logger_,
      UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_));

  TfLitePlanEngine plan_engine({example_iterator_factory_.get()}, should_abort_,
                               &mock_log_manager_, &mock_opstats_logger_,
                               &mock_flags_, &timing_config_);
  engine::PlanResult result = plan_engine.RunPlan(
      client_only_plan_.phase().tensorflow_spec(),
      client_only_plan_.tflite_graph(), std::move(inputs_), output_names_);

  EXPECT_THAT(result.outcome, PlanOutcome::kSuccess);
  EXPECT_THAT(result.output_tensors.size(), 0);
  EXPECT_THAT(result.output_names.size(), 0);
  EXPECT_EQ(result.example_stats.example_count, num_examples_);
  EXPECT_EQ(result.example_stats.example_size_bytes, example_bytes_);
  EXPECT_GT(std::filesystem::file_size(checkpoint_output_filename_), 0);
}
227 } // namespace
228 } // namespace engine
229 } // namespace client
230 } // namespace fcp
231