1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/crypto/simultaneous_fixed_bases_exp.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "private_join_and_compute/crypto/big_num.h"
27 #include "private_join_and_compute/crypto/mont_mul.h"
28 #include "private_join_and_compute/util/status.inc"
29 
30 namespace private_join_and_compute {
31 
32 namespace internal {
33 
34 template <typename Element>
35 StatusOr<Element> Clone(const Element& element);
36 
37 template <typename Element, typename Context>
38 StatusOr<Element> Mul(const Element& e1, const Element& e2,
39                       const Context& context);
40 
41 template <typename Element>
42 bool IsZero(const Element& c);
43 
44 template <>
Clone(const private_join_and_compute::BigNum & element)45 StatusOr<private_join_and_compute::BigNum> Clone(
46     const private_join_and_compute::BigNum& element) {
47   return element;
48 }
49 
50 template <>
IsZero(const private_join_and_compute::BigNum & c)51 bool IsZero(const private_join_and_compute::BigNum& c) {
52   return c.IsOne();
53 }
54 
55 template <>
Mul(const ZnElement & e1,const ZnElement & e2,const ZnContext & context)56 StatusOr<ZnElement> Mul(const ZnElement& e1, const ZnElement& e2,
57                         const ZnContext& context) {
58   return e1.ModMul(e2, context.modulus);
59 }
60 
61 template <>
Clone(const private_join_and_compute::MontBigNum & element)62 StatusOr<private_join_and_compute::MontBigNum> Clone(
63     const private_join_and_compute::MontBigNum& element) {
64   return element;
65 }
66 
67 template <>
Mul(const private_join_and_compute::MontBigNum & e1,const private_join_and_compute::MontBigNum & e2,const private_join_and_compute::MontContext & context)68 StatusOr<private_join_and_compute::MontBigNum> Mul(
69     const private_join_and_compute::MontBigNum& e1,
70     const private_join_and_compute::MontBigNum& e2,
71     const private_join_and_compute::MontContext& context) {
72   return e1.Mul(e2);
73 }
74 
75 template <>
IsZero(const private_join_and_compute::MontBigNum & c)76 bool IsZero(const private_join_and_compute::MontBigNum& c) {
77   return c.ToBigNum().IsOne();
78 }
79 
80 }  // namespace internal
81 
82 template <typename Element, typename Context>
SimultaneousFixedBasesExp(size_t num_bases,size_t num_simultaneous,size_t num_batches,std::unique_ptr<Element> zero,std::unique_ptr<Context> context,std::vector<std::vector<std::unique_ptr<Element>>> table)83 SimultaneousFixedBasesExp<Element, Context>::SimultaneousFixedBasesExp(
84     size_t num_bases, size_t num_simultaneous, size_t num_batches,
85     std::unique_ptr<Element> zero, std::unique_ptr<Context> context,
86     std::vector<std::vector<std::unique_ptr<Element>>> table)
87     : num_bases_(num_bases),
88       num_simultaneous_(num_simultaneous),
89       num_batches_(num_batches),
90       zero_(std::move(zero)),
91       context_(std::move(context)),
92       precomputed_table_(std::move(table)) {}
93 
94 template <typename Element, typename Context>
95 StatusOr<std::unique_ptr<SimultaneousFixedBasesExp<Element, Context>>>
Create(const std::vector<Element> & bases,const Element & zero,size_t num_simultaneous,std::unique_ptr<Context> context)96 SimultaneousFixedBasesExp<Element, Context>::Create(
97     const std::vector<Element>& bases, const Element& zero,
98     size_t num_simultaneous, std::unique_ptr<Context> context) {
99   if (num_simultaneous == 0) {
100     return absl::InvalidArgumentError(
101         absl::StrCat("The num_simultaneous parameter, ", num_simultaneous,
102                      ", should be positive."));
103   }
104   if (num_simultaneous > bases.size()) {
105     return absl::InvalidArgumentError(absl::StrCat(
106         "The num_simultaneous parameter, ", num_simultaneous,
107         ", can be at most the number of bases", bases.size(), "."));
108   }
109   size_t num_batches = (bases.size() + num_simultaneous - 1) / num_simultaneous;
110   ASSIGN_OR_RETURN(auto zero_clone, internal::Clone(zero));
111   std::unique_ptr<Element> zero_ptr =
112       std::make_unique<Element>(std::move(zero_clone));
113   ASSIGN_OR_RETURN(std::vector<std::vector<std::unique_ptr<Element>>> table,
114                    SimultaneousFixedBasesExp::Precompute(
115                        bases, zero, *context, num_simultaneous, num_batches));
116   return absl::WrapUnique<SimultaneousFixedBasesExp>(
117       new SimultaneousFixedBasesExp(bases.size(), num_simultaneous, num_batches,
118                                     std::move(zero_ptr), std::move(context),
119                                     std::move(table)));
120 }
121 
122 template <typename Element, typename Context>
123 StatusOr<std::vector<std::vector<std::unique_ptr<Element>>>>
Precompute(const std::vector<Element> & bases,const Element & zero,const Context & context,size_t num_simultaneous,size_t num_batches)124 SimultaneousFixedBasesExp<Element, Context>::Precompute(
125     const std::vector<Element>& bases, const Element& zero,
126     const Context& context, size_t num_simultaneous, size_t num_batches) {
127   std::vector<std::vector<std::unique_ptr<Element>>> table;
128   for (size_t i = 0; i < num_batches; ++i) {
129     table.push_back({});
130     ASSIGN_OR_RETURN(Element zero_clone, internal::Clone(zero));
131     table[i].push_back(std::make_unique<Element>(std::move(zero_clone)));
132     const size_t start = i * num_simultaneous;
133     const size_t num_items_in_batch =
134         std::min(bases.size() - start, num_simultaneous);
135     int highest_one_bit = 0;
136     // Generate all values (c1, ..., ck) in {0, 1}^k using the binary
137     // representation of integers between [0, 2^k - 1].
138     for (int j = 1; j < (1 << num_items_in_batch); ++j) {
139       if (j & (1 << (highest_one_bit + 1))) {
140         ++highest_one_bit;
141       }
142       size_t prev = j - (1 << highest_one_bit);
143       if (prev == 0) {
144         ASSIGN_OR_RETURN(Element clone,
145                          internal::Clone(bases[start + highest_one_bit]));
146         table[i].push_back(std::make_unique<Element>(std::move(clone)));
147       } else {
148         ASSIGN_OR_RETURN(
149             Element add,
150             internal::Mul(*(table[i][prev]), bases[start + highest_one_bit],
151                           context));
152         table[i].push_back(std::make_unique<Element>(std::move(add)));
153       }
154     }
155   }
156   return std::move(table);
157 }
158 
159 template <typename Element, typename Context>
SimultaneousExp(const std::vector<private_join_and_compute::BigNum> & exponents) const160 StatusOr<Element> SimultaneousFixedBasesExp<Element, Context>::SimultaneousExp(
161     const std::vector<private_join_and_compute::BigNum>& exponents) const {
162   if (exponents.size() != num_bases_) {
163     return absl::InvalidArgumentError(
164         absl::StrCat("Number of exponents, ", exponents.size(), ", and bases,",
165                      num_bases_, ", are not equal."));
166   }
167   int max_bit_length = 0;
168   for (const auto& exponent : exponents) {
169     if (exponent.BitLength() > max_bit_length) {
170       max_bit_length = exponent.BitLength();
171     }
172   }
173   ASSIGN_OR_RETURN(Element result, internal::Clone(*zero_));
174   for (int i = max_bit_length - 1; i >= 0; --i) {
175     if (!internal::IsZero(result)) {
176       ASSIGN_OR_RETURN(result, internal::Mul(result, result, *context_));
177     }
178     for (size_t j = 0; j < num_batches_; ++j) {
179       size_t precompute_idx = 0;
180       size_t batch_size = num_simultaneous_;
181       if (batch_size > num_bases_ - (j * num_simultaneous_)) {
182         batch_size = num_bases_ - (j * num_simultaneous_);
183       }
184       for (size_t k = 0; k < batch_size; ++k) {
185         size_t data_idx = (j * num_simultaneous_) + k;
186         if (exponents[data_idx].IsBitSet(i)) {
187           precompute_idx += (1 << k);
188         }
189       }
190       if (precompute_idx) {
191         ASSIGN_OR_RETURN(
192             result,
193             internal::Mul(result, *(precomputed_table_[j][precompute_idx]),
194                           *context_));
195       }
196     }
197   }
198   return std::move(result);
199 }
200 
201 template class SimultaneousFixedBasesExp<private_join_and_compute::MontBigNum,
202                                          private_join_and_compute::MontContext>;
203 template class SimultaneousFixedBasesExp<ZnElement, ZnContext>;
204 
205 }  // namespace private_join_and_compute
206