xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/tensor_aggregator_registry.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 
19 #include "fcp/aggregation/core/tensor_aggregator_factory.h"
20 
21 #ifdef FCP_BAREMETAL
22 #include <unordered_map>
23 #else
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/synchronization/mutex.h"
26 #endif
27 
28 namespace fcp {
29 namespace aggregation {
30 
31 namespace internal {
32 
33 class Registry final {
34  public:
RegisterAggregatorFactory(const std::string & intrinsic_uri,const TensorAggregatorFactory * factory)35   void RegisterAggregatorFactory(const std::string& intrinsic_uri,
36                                  const TensorAggregatorFactory* factory) {
37     FCP_CHECK(factory != nullptr);
38 
39 #ifndef FCP_BAREMETAL
40     absl::MutexLock lock(&mutex_);
41 #endif
42     FCP_CHECK(map_.find(intrinsic_uri) == map_.end())
43         << "A factory for intrinsic_uri '" << intrinsic_uri
44         << "' is already registered.";
45     map_[intrinsic_uri] = factory;
46     FCP_LOG(INFO) << "TensorAggregatorFactory for intrinsic_uri '"
47                   << intrinsic_uri << "' is registered.";
48   }
49 
GetAggregatorFactory(const std::string & intrinsic_uri)50   StatusOr<const TensorAggregatorFactory*> GetAggregatorFactory(
51       const std::string& intrinsic_uri) {
52 #ifndef FCP_BAREMETAL
53     absl::MutexLock lock(&mutex_);
54 #endif
55     auto it = map_.find(intrinsic_uri);
56     if (it == map_.end()) {
57       return FCP_STATUS(NOT_FOUND)
58              << "Unknown factory for intrinsic_uri '" << intrinsic_uri << "'.";
59     }
60     return it->second;
61   }
62 
63  private:
64 #ifdef FCP_BAREMETAL
65   std::unordered_map<std::string, const TensorAggregatorFactory*> map_;
66 #else
67   // Synchronization of potentially concurrent registry calls is done only in
68   // the non-baremetal environment. In the baremetal environment, since there is
69   // no OS, a single thread execution environment is expected and the
70   // synchronization primitives aren't available.
71   absl::Mutex mutex_;
72   absl::flat_hash_map<std::string, const TensorAggregatorFactory*> map_
73       ABSL_GUARDED_BY(mutex_);
74 #endif
75 };
76 
#ifdef FCP_BAREMETAL
// TODO(team): Revise the registration mechanism below.
// Baremetal builds cannot rely on static initializers to register the
// aggregation intrinsics, so every intrinsic's registration entry point is
// declared here and invoked explicitly from RegisterAll().
extern "C" void RegisterFederatedSum();

// Explicitly runs the registration entry point of every aggregation intrinsic
// known to the baremetal build. Add new intrinsics here.
void RegisterAll() {
  RegisterFederatedSum();
}
#endif  // FCP_BAREMETAL
86 
GetRegistry()87 Registry* GetRegistry() {
88   static Registry* global_registry = new Registry();
89 #ifdef FCP_BAREMETAL
90   // TODO(team): Revise the registration mechanism below.
91   static bool registration_done = false;
92   if (!registration_done) {
93     registration_done = true;
94     RegisterAll();
95   }
96 #endif
97   return global_registry;
98 }
99 
100 }  // namespace internal
101 
102 // Registers a factory instance for the given intrinsic type.
RegisterAggregatorFactory(const std::string & intrinsic_uri,const TensorAggregatorFactory * factory)103 void RegisterAggregatorFactory(const std::string& intrinsic_uri,
104                                const TensorAggregatorFactory* factory) {
105   internal::GetRegistry()->RegisterAggregatorFactory(intrinsic_uri, factory);
106 }
107 
108 // Looks up a factory instance for the given intrinsic type.
GetAggregatorFactory(const std::string & intrinsic_uri)109 StatusOr<const TensorAggregatorFactory*> GetAggregatorFactory(
110     const std::string& intrinsic_uri) {
111   return internal::GetRegistry()->GetAggregatorFactory(intrinsic_uri);
112 }
113 
114 }  // namespace aggregation
115 }  // namespace fcp
116