xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/server_impl.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/server_impl.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/memory/memory.h"
26 #include "private_join_and_compute/crypto/ec_commutative_cipher.h"
27 #include "private_join_and_compute/crypto/paillier.h"
28 #include "private_join_and_compute/util/status.inc"
29 
30 using ::private_join_and_compute::BigNum;
31 using ::private_join_and_compute::ECCommutativeCipher;
32 
33 namespace private_join_and_compute {
34 
35 StatusOr<PrivateIntersectionSumServerMessage::ServerRoundOne>
EncryptSet()36 PrivateIntersectionSumProtocolServerImpl::EncryptSet() {
37   if (ec_cipher_ != nullptr) {
38     return InvalidArgumentError("Attempted to call EncryptSet twice.");
39   }
40   StatusOr<std::unique_ptr<ECCommutativeCipher>> ec_cipher =
41       ECCommutativeCipher::CreateWithNewKey(
42           NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SHA256);
43   if (!ec_cipher.ok()) {
44     return ec_cipher.status();
45   }
46   ec_cipher_ = std::move(ec_cipher.value());
47 
48   PrivateIntersectionSumServerMessage::ServerRoundOne result;
49   for (const std::string& input : inputs_) {
50     EncryptedElement* encrypted =
51         result.mutable_encrypted_set()->add_elements();
52     StatusOr<std::string> encrypted_element = ec_cipher_->Encrypt(input);
53     if (!encrypted_element.ok()) {
54       return encrypted_element.status();
55     }
56     *encrypted->mutable_element() = encrypted_element.value();
57   }
58 
59   return result;
60 }
61 
62 StatusOr<PrivateIntersectionSumServerMessage::ServerRoundTwo>
ComputeIntersection(const PrivateIntersectionSumClientMessage::ClientRoundOne & client_message)63 PrivateIntersectionSumProtocolServerImpl::ComputeIntersection(
64     const PrivateIntersectionSumClientMessage::ClientRoundOne& client_message) {
65   if (ec_cipher_ == nullptr) {
66     return InvalidArgumentError(
67         "Called ComputeIntersection before EncryptSet.");
68   }
69   PrivateIntersectionSumServerMessage::ServerRoundTwo result;
70   BigNum N = ctx_->CreateBigNum(client_message.public_key());
71   PublicPaillier public_paillier(ctx_, N, 2);
72 
73   std::vector<EncryptedElement> server_set, client_set, intersection;
74 
75   // First, we re-encrypt the client party's set, so that we can compare with
76   // the re-encrypted set received from the client.
77   for (const EncryptedElement& element :
78        client_message.encrypted_set().elements()) {
79     EncryptedElement reencrypted;
80     *reencrypted.mutable_associated_data() = element.associated_data();
81     StatusOr<std::string> reenc = ec_cipher_->ReEncrypt(element.element());
82     if (!reenc.ok()) {
83       return reenc.status();
84     }
85     *reencrypted.mutable_element() = reenc.value();
86     client_set.push_back(reencrypted);
87   }
88   for (const EncryptedElement& element :
89        client_message.reencrypted_set().elements()) {
90     server_set.push_back(element);
91   }
92 
93   // std::set_intersection requires sorted inputs.
94   std::sort(client_set.begin(), client_set.end(),
95             [](const EncryptedElement& a, const EncryptedElement& b) {
96               return a.element() < b.element();
97             });
98   std::sort(server_set.begin(), server_set.end(),
99             [](const EncryptedElement& a, const EncryptedElement& b) {
100               return a.element() < b.element();
101             });
102   std::set_intersection(
103       client_set.begin(), client_set.end(), server_set.begin(),
104       server_set.end(), std::back_inserter(intersection),
105       [](const EncryptedElement& a, const EncryptedElement& b) {
106         return a.element() < b.element();
107       });
108 
109   // From the intersection we compute the sum of the associated values, which is
110   // the result we return to the client.
111   StatusOr<BigNum> encrypted_zero =
112       public_paillier.Encrypt(ctx_->CreateBigNum(0));
113   if (!encrypted_zero.ok()) {
114     return encrypted_zero.status();
115   }
116   BigNum sum = encrypted_zero.value();
117   for (const EncryptedElement& element : intersection) {
118     sum =
119         public_paillier.Add(sum, ctx_->CreateBigNum(element.associated_data()));
120   }
121 
122   *result.mutable_encrypted_sum() = sum.ToBytes();
123   result.set_intersection_size(intersection.size());
124   return result;
125 }
126 
Handle(const ClientMessage & request,MessageSink<ServerMessage> * server_message_sink)127 Status PrivateIntersectionSumProtocolServerImpl::Handle(
128     const ClientMessage& request,
129     MessageSink<ServerMessage>* server_message_sink) {
130   if (protocol_finished()) {
131     return InvalidArgumentError(
132         "PrivateIntersectionSumProtocolServerImpl: Protocol is already "
133         "complete.");
134   }
135 
136   // Check that the message is a PrivateIntersectionSum protocol message.
137   if (!request.has_private_intersection_sum_client_message()) {
138     return InvalidArgumentError(
139         "PrivateIntersectionSumProtocolServerImpl: Received a message for the "
140         "wrong protocol type");
141   }
142   const PrivateIntersectionSumClientMessage& client_message =
143       request.private_intersection_sum_client_message();
144 
145   ServerMessage server_message;
146 
147   if (client_message.has_start_protocol_request()) {
148     // Handle a protocol start message.
149     auto maybe_server_round_one = EncryptSet();
150     if (!maybe_server_round_one.ok()) {
151       return maybe_server_round_one.status();
152     }
153     *(server_message.mutable_private_intersection_sum_server_message()
154           ->mutable_server_round_one()) =
155         std::move(maybe_server_round_one.value());
156   } else if (client_message.has_client_round_one()) {
157     // Handle the client round 1 message.
158     auto maybe_server_round_two =
159         ComputeIntersection(client_message.client_round_one());
160     if (!maybe_server_round_two.ok()) {
161       return maybe_server_round_two.status();
162     }
163     *(server_message.mutable_private_intersection_sum_server_message()
164           ->mutable_server_round_two()) =
165         std::move(maybe_server_round_two.value());
166     // Mark the protocol as finished here.
167     protocol_finished_ = true;
168   } else {
169     return InvalidArgumentError(
170         "PrivateIntersectionSumProtocolServerImpl: Received a client message "
171         "of an unknown type.");
172   }
173 
174   return server_message_sink->Send(server_message);
175 }
176 
177 }  // namespace private_join_and_compute
178