xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tflite_plan_engine.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_CLIENT_ENGINE_TFLITE_PLAN_ENGINE_H_
17 #define FCP_CLIENT_ENGINE_TFLITE_PLAN_ENGINE_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "fcp/client/engine/common.h"
#include "fcp/client/engine/example_iterator_factory.h"
#include "fcp/client/event_publisher.h"
#include "fcp/client/flags.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/log_manager.h"
#include "fcp/client/opstats/opstats_logger.h"
#include "fcp/client/simple_task_environment.h"
31 
32 namespace fcp {
33 namespace client {
34 namespace engine {
35 
36 // A class used to "run" (interpret) a TensorflowSpec-based plan with TfLite.
37 // Each instance should generally only be used once to run a plan.
38 class TfLitePlanEngine {
39  public:
40   // For each example query issued by the plan at runtime, the given
41   // `example_iterator_factories` parameter will be iterated and the first
42   // iterator factory that can handle the given query will be used to create the
43   // example iterator for that query.
TfLitePlanEngine(std::vector<ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,::fcp::client::opstats::OpStatsLogger * opstats_logger,const Flags * flags,const InterruptibleRunner::TimingConfig * timing_config)44   TfLitePlanEngine(
45       std::vector<ExampleIteratorFactory*> example_iterator_factories,
46       std::function<bool()> should_abort, LogManager* log_manager,
47       ::fcp::client::opstats::OpStatsLogger* opstats_logger, const Flags* flags,
48       const InterruptibleRunner::TimingConfig* timing_config)
49       : example_iterator_factories_(example_iterator_factories),
50         should_abort_(should_abort),
51         log_manager_(log_manager),
52         opstats_logger_(opstats_logger),
53         flags_(*flags),
54         timing_config_(timing_config) {}
55 
56   // Runs the plan, and takes care of logging TfLite errors and external
57   // interruptions via event_publisher. If the TfLite call fails because it got
58   // aborted externally, returns CANCELLED. If the TfLite call fails because of
59   // other reasons, publishes an event, then returns INVALID_ARGUMENT. If the
60   // TfLite call is successful, returns OK, and the output tensors.
61   PlanResult RunPlan(
62       const google::internal::federated::plan::TensorflowSpec& tensorflow_spec,
63       const std::string& model,
64       std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
65       const std::vector<std::string>& output_names);
66 
67  private:
68   std::vector<ExampleIteratorFactory*> example_iterator_factories_;
69   std::function<bool()> should_abort_;
70   LogManager* log_manager_;
71   ::fcp::client::opstats::OpStatsLogger* opstats_logger_;
72   const Flags& flags_;
73   const InterruptibleRunner::TimingConfig* timing_config_;
74 };
75 
76 }  // namespace engine
77 }  // namespace client
78 }  // namespace fcp
79 
80 #endif  // FCP_CLIENT_ENGINE_TFLITE_PLAN_ENGINE_H_
81