1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <string>
18
19 #include "fcp/aggregation/core/tensor_aggregator_factory.h"
20
21 #ifdef FCP_BAREMETAL
22 #include <unordered_map>
23 #else
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/synchronization/mutex.h"
26 #endif
27
28 namespace fcp {
29 namespace aggregation {
30
31 namespace internal {
32
33 class Registry final {
34 public:
RegisterAggregatorFactory(const std::string & intrinsic_uri,const TensorAggregatorFactory * factory)35 void RegisterAggregatorFactory(const std::string& intrinsic_uri,
36 const TensorAggregatorFactory* factory) {
37 FCP_CHECK(factory != nullptr);
38
39 #ifndef FCP_BAREMETAL
40 absl::MutexLock lock(&mutex_);
41 #endif
42 FCP_CHECK(map_.find(intrinsic_uri) == map_.end())
43 << "A factory for intrinsic_uri '" << intrinsic_uri
44 << "' is already registered.";
45 map_[intrinsic_uri] = factory;
46 FCP_LOG(INFO) << "TensorAggregatorFactory for intrinsic_uri '"
47 << intrinsic_uri << "' is registered.";
48 }
49
GetAggregatorFactory(const std::string & intrinsic_uri)50 StatusOr<const TensorAggregatorFactory*> GetAggregatorFactory(
51 const std::string& intrinsic_uri) {
52 #ifndef FCP_BAREMETAL
53 absl::MutexLock lock(&mutex_);
54 #endif
55 auto it = map_.find(intrinsic_uri);
56 if (it == map_.end()) {
57 return FCP_STATUS(NOT_FOUND)
58 << "Unknown factory for intrinsic_uri '" << intrinsic_uri << "'.";
59 }
60 return it->second;
61 }
62
63 private:
64 #ifdef FCP_BAREMETAL
65 std::unordered_map<std::string, const TensorAggregatorFactory*> map_;
66 #else
67 // Synchronization of potentially concurrent registry calls is done only in
68 // the non-baremetal environment. In the baremetal environment, since there is
69 // no OS, a single thread execution environment is expected and the
70 // synchronization primitives aren't available.
71 absl::Mutex mutex_;
72 absl::flat_hash_map<std::string, const TensorAggregatorFactory*> map_
73 ABSL_GUARDED_BY(mutex_);
74 #endif
75 };
76
#ifdef FCP_BAREMETAL
// TODO(team): Revise the registration mechanism below.
// Baremetal builds cannot rely on static initialization to register the
// aggregation intrinsics, so each intrinsic has to be registered explicitly
// from here.
extern "C" void RegisterFederatedSum();

// Explicitly registers every aggregation intrinsic that is available in the
// baremetal build.
void RegisterAll() {
  RegisterFederatedSum();
}
#endif  // FCP_BAREMETAL
86
GetRegistry()87 Registry* GetRegistry() {
88 static Registry* global_registry = new Registry();
89 #ifdef FCP_BAREMETAL
90 // TODO(team): Revise the registration mechanism below.
91 static bool registration_done = false;
92 if (!registration_done) {
93 registration_done = true;
94 RegisterAll();
95 }
96 #endif
97 return global_registry;
98 }
99
100 } // namespace internal
101
102 // Registers a factory instance for the given intrinsic type.
RegisterAggregatorFactory(const std::string & intrinsic_uri,const TensorAggregatorFactory * factory)103 void RegisterAggregatorFactory(const std::string& intrinsic_uri,
104 const TensorAggregatorFactory* factory) {
105 internal::GetRegistry()->RegisterAggregatorFactory(intrinsic_uri, factory);
106 }
107
108 // Looks up a factory instance for the given intrinsic type.
GetAggregatorFactory(const std::string & intrinsic_uri)109 StatusOr<const TensorAggregatorFactory*> GetAggregatorFactory(
110 const std::string& intrinsic_uri) {
111 return internal::GetRegistry()->GetAggregatorFactory(intrinsic_uri);
112 }
113
114 } // namespace aggregation
115 } // namespace fcp
116