xref: /aosp_15_r20/external/federated-compute/fcp/secagg/shared/map_of_masks_bench.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

#include "absl/numeric/bits.h"
#include "absl/strings/str_cat.h"
#include "benchmark/benchmark.h"
#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
#include "fcp/secagg/shared/input_vector_specification.h"
#include "fcp/secagg/shared/map_of_masks.h"
#include "fcp/secagg/shared/secagg_vector.h"
27 
28 namespace fcp {
29 namespace secagg {
30 namespace {
31 
// Length (element count) of the single input vector masked by every benchmark.
constexpr int kVectorSize = 1 << 20;
// Number of AES keys passed as `prng_keys_to_add` to each masking call.
constexpr int kNumKeys = 128;
34 
BM_MapOfMasks_Impl(benchmark::State & state,uint64_t modulus)35 inline void BM_MapOfMasks_Impl(benchmark::State& state, uint64_t modulus) {
36   state.PauseTiming();
37   std::vector<AesKey> prng_keys_to_add;
38   uint8_t key[AesKey::kSize];
39   memset(key, 'A', AesKey::kSize);
40   prng_keys_to_add.reserve(kNumKeys);
41   for (int i = 0; i < kNumKeys; i++) {
42     prng_keys_to_add.emplace_back(key);
43   }
44   std::vector<AesKey> prng_keys_to_subtract;
45   SessionId session_id = {std::string(32, 'Z')};
46 
47   std::vector<InputVectorSpecification> vector_specs;
48   vector_specs.emplace_back("unused", kVectorSize, modulus);
49 
50   state.ResumeTiming();
51   benchmark::DoNotOptimize(MapOfMasks(
52       prng_keys_to_add, prng_keys_to_subtract, vector_specs, session_id,
53       static_cast<const AesPrngFactory&>(AesCtrPrngFactory())));
54 
55   state.SetItemsProcessed(kVectorSize);
56 }
57 
BM_MapOfMasksV3_Impl(benchmark::State & state,uint64_t modulus)58 inline void BM_MapOfMasksV3_Impl(benchmark::State& state, uint64_t modulus) {
59   state.PauseTiming();
60   std::vector<AesKey> prng_keys_to_add;
61   uint8_t key[AesKey::kSize];
62   memset(key, 'A', AesKey::kSize);
63   prng_keys_to_add.reserve(kNumKeys);
64   for (int i = 0; i < kNumKeys; i++) {
65     prng_keys_to_add.emplace_back(key);
66   }
67   std::vector<AesKey> prng_keys_to_subtract;
68   SessionId session_id = {std::string(32, 'Z')};
69 
70   std::vector<InputVectorSpecification> vector_specs;
71   vector_specs.emplace_back("unused", kVectorSize, modulus);
72 
73   state.ResumeTiming();
74   benchmark::DoNotOptimize(MapOfMasksV3(
75       prng_keys_to_add, prng_keys_to_subtract, vector_specs, session_id,
76       static_cast<const AesPrngFactory&>(AesCtrPrngFactory())));
77 
78   state.SetItemsProcessed(kVectorSize);
79 }
80 
BM_MapOfMasks_PowerOfTwo(benchmark::State & state)81 void BM_MapOfMasks_PowerOfTwo(benchmark::State& state) {
82   for (auto s : state) {
83     int bitwidth = static_cast<int>(state.range(0));
84     BM_MapOfMasks_Impl(state, 1ULL << bitwidth);
85   }
86 }
87 
BM_MapOfMasks_Arbitrary(benchmark::State & state)88 void BM_MapOfMasks_Arbitrary(benchmark::State& state) {
89   for (auto s : state) {
90     uint64_t modulus = static_cast<uint64_t>(state.range(0));
91     BM_MapOfMasks_Impl(state, modulus);
92   }
93 }
94 
BM_MapOfMasksV3_PowerOfTwo(benchmark::State & state)95 void BM_MapOfMasksV3_PowerOfTwo(benchmark::State& state) {
96   for (auto s : state) {
97     int bitwidth = static_cast<int>(state.range(0));
98     BM_MapOfMasksV3_Impl(state, 1ULL << bitwidth);
99   }
100 }
101 
BM_MapOfMasksV3_Arbitrary(benchmark::State & state)102 void BM_MapOfMasksV3_Arbitrary(benchmark::State& state) {
103   for (auto s : state) {
104     uint64_t modulus = static_cast<uint64_t>(state.range(0));
105     BM_MapOfMasksV3_Impl(state, modulus);
106   }
107 }
108 
109 BENCHMARK(BM_MapOfMasks_PowerOfTwo)
110     ->Arg(9)
111     ->Arg(25)
112     ->Arg(41)
113     ->Arg(53)
114     ->Arg(absl::bit_width(SecAggVector::kMaxModulus - 1));
115 
116 BENCHMARK(BM_MapOfMasks_Arbitrary)
117     ->Arg(5)
118     ->Arg(39)
119     ->Arg(485)
120     ->Arg(2400)
121     ->Arg(14901)
122     ->Arg(51813)
123     ->Arg(532021)
124     ->Arg(13916946)
125     ->Arg(39549497)
126     ->Arg(548811945)
127     ->Arg(590549014)
128     ->Arg(48296031686)
129     ->Arg(156712951284)
130     ->Arg(2636861836189)
131     ->Arg(14673852658160)
132     ->Arg(92971495438615)
133     ->Arg(304436005557271)
134     ->Arg(14046234330484262)
135     ->Arg(38067457113486645)
136     ->Arg(175631339105057682);
137 
138 BENCHMARK(BM_MapOfMasksV3_PowerOfTwo)
139     ->Arg(9)
140     ->Arg(25)
141     ->Arg(41)
142     ->Arg(53)
143     ->Arg(absl::bit_width(SecAggVector::kMaxModulus - 1));
144 
145 BENCHMARK(BM_MapOfMasksV3_Arbitrary)
146     ->Arg(5)
147     ->Arg(39)
148     ->Arg(485)
149     ->Arg(2400)
150     ->Arg(14901)
151     ->Arg(51813)
152     ->Arg(532021)
153     ->Arg(13916946)
154     ->Arg(39549497)
155     ->Arg(548811945)
156     ->Arg(590549014)
157     ->Arg(48296031686)
158     ->Arg(156712951284)
159     ->Arg(2636861836189)
160     ->Arg(14673852658160)
161     ->Arg(92971495438615)
162     ->Arg(304436005557271)
163     ->Arg(14046234330484262)
164     ->Arg(38067457113486645)
165     ->Arg(175631339105057682);
166 
167 }  // namespace
168 }  // namespace secagg
169 }  // namespace fcp
170