1 /*
2 * Copyright 2019 Google LLC.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * https://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "private_join_and_compute/server_impl.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/memory/memory.h"
26 #include "private_join_and_compute/crypto/ec_commutative_cipher.h"
27 #include "private_join_and_compute/crypto/paillier.h"
28 #include "private_join_and_compute/util/status.inc"
29
30 using ::private_join_and_compute::BigNum;
31 using ::private_join_and_compute::ECCommutativeCipher;
32
33 namespace private_join_and_compute {
34
35 StatusOr<PrivateIntersectionSumServerMessage::ServerRoundOne>
EncryptSet()36 PrivateIntersectionSumProtocolServerImpl::EncryptSet() {
37 if (ec_cipher_ != nullptr) {
38 return InvalidArgumentError("Attempted to call EncryptSet twice.");
39 }
40 StatusOr<std::unique_ptr<ECCommutativeCipher>> ec_cipher =
41 ECCommutativeCipher::CreateWithNewKey(
42 NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SHA256);
43 if (!ec_cipher.ok()) {
44 return ec_cipher.status();
45 }
46 ec_cipher_ = std::move(ec_cipher.value());
47
48 PrivateIntersectionSumServerMessage::ServerRoundOne result;
49 for (const std::string& input : inputs_) {
50 EncryptedElement* encrypted =
51 result.mutable_encrypted_set()->add_elements();
52 StatusOr<std::string> encrypted_element = ec_cipher_->Encrypt(input);
53 if (!encrypted_element.ok()) {
54 return encrypted_element.status();
55 }
56 *encrypted->mutable_element() = encrypted_element.value();
57 }
58
59 return result;
60 }
61
62 StatusOr<PrivateIntersectionSumServerMessage::ServerRoundTwo>
ComputeIntersection(const PrivateIntersectionSumClientMessage::ClientRoundOne & client_message)63 PrivateIntersectionSumProtocolServerImpl::ComputeIntersection(
64 const PrivateIntersectionSumClientMessage::ClientRoundOne& client_message) {
65 if (ec_cipher_ == nullptr) {
66 return InvalidArgumentError(
67 "Called ComputeIntersection before EncryptSet.");
68 }
69 PrivateIntersectionSumServerMessage::ServerRoundTwo result;
70 BigNum N = ctx_->CreateBigNum(client_message.public_key());
71 PublicPaillier public_paillier(ctx_, N, 2);
72
73 std::vector<EncryptedElement> server_set, client_set, intersection;
74
75 // First, we re-encrypt the client party's set, so that we can compare with
76 // the re-encrypted set received from the client.
77 for (const EncryptedElement& element :
78 client_message.encrypted_set().elements()) {
79 EncryptedElement reencrypted;
80 *reencrypted.mutable_associated_data() = element.associated_data();
81 StatusOr<std::string> reenc = ec_cipher_->ReEncrypt(element.element());
82 if (!reenc.ok()) {
83 return reenc.status();
84 }
85 *reencrypted.mutable_element() = reenc.value();
86 client_set.push_back(reencrypted);
87 }
88 for (const EncryptedElement& element :
89 client_message.reencrypted_set().elements()) {
90 server_set.push_back(element);
91 }
92
93 // std::set_intersection requires sorted inputs.
94 std::sort(client_set.begin(), client_set.end(),
95 [](const EncryptedElement& a, const EncryptedElement& b) {
96 return a.element() < b.element();
97 });
98 std::sort(server_set.begin(), server_set.end(),
99 [](const EncryptedElement& a, const EncryptedElement& b) {
100 return a.element() < b.element();
101 });
102 std::set_intersection(
103 client_set.begin(), client_set.end(), server_set.begin(),
104 server_set.end(), std::back_inserter(intersection),
105 [](const EncryptedElement& a, const EncryptedElement& b) {
106 return a.element() < b.element();
107 });
108
109 // From the intersection we compute the sum of the associated values, which is
110 // the result we return to the client.
111 StatusOr<BigNum> encrypted_zero =
112 public_paillier.Encrypt(ctx_->CreateBigNum(0));
113 if (!encrypted_zero.ok()) {
114 return encrypted_zero.status();
115 }
116 BigNum sum = encrypted_zero.value();
117 for (const EncryptedElement& element : intersection) {
118 sum =
119 public_paillier.Add(sum, ctx_->CreateBigNum(element.associated_data()));
120 }
121
122 *result.mutable_encrypted_sum() = sum.ToBytes();
123 result.set_intersection_size(intersection.size());
124 return result;
125 }
126
Handle(const ClientMessage & request,MessageSink<ServerMessage> * server_message_sink)127 Status PrivateIntersectionSumProtocolServerImpl::Handle(
128 const ClientMessage& request,
129 MessageSink<ServerMessage>* server_message_sink) {
130 if (protocol_finished()) {
131 return InvalidArgumentError(
132 "PrivateIntersectionSumProtocolServerImpl: Protocol is already "
133 "complete.");
134 }
135
136 // Check that the message is a PrivateIntersectionSum protocol message.
137 if (!request.has_private_intersection_sum_client_message()) {
138 return InvalidArgumentError(
139 "PrivateIntersectionSumProtocolServerImpl: Received a message for the "
140 "wrong protocol type");
141 }
142 const PrivateIntersectionSumClientMessage& client_message =
143 request.private_intersection_sum_client_message();
144
145 ServerMessage server_message;
146
147 if (client_message.has_start_protocol_request()) {
148 // Handle a protocol start message.
149 auto maybe_server_round_one = EncryptSet();
150 if (!maybe_server_round_one.ok()) {
151 return maybe_server_round_one.status();
152 }
153 *(server_message.mutable_private_intersection_sum_server_message()
154 ->mutable_server_round_one()) =
155 std::move(maybe_server_round_one.value());
156 } else if (client_message.has_client_round_one()) {
157 // Handle the client round 1 message.
158 auto maybe_server_round_two =
159 ComputeIntersection(client_message.client_round_one());
160 if (!maybe_server_round_two.ok()) {
161 return maybe_server_round_two.status();
162 }
163 *(server_message.mutable_private_intersection_sum_server_message()
164 ->mutable_server_round_two()) =
165 std::move(maybe_server_round_two.value());
166 // Mark the protocol as finished here.
167 protocol_finished_ = true;
168 } else {
169 return InvalidArgumentError(
170 "PrivateIntersectionSumProtocolServerImpl: Received a client message "
171 "of an unknown type.");
172 }
173
174 return server_message_sink->Send(server_message);
175 }
176
177 } // namespace private_join_and_compute
178