xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/client.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1*a6aa18fbSYabin Cui /*
2*a6aa18fbSYabin Cui  * Copyright 2019 Google LLC.
3*a6aa18fbSYabin Cui  * Licensed under the Apache License, Version 2.0 (the "License");
4*a6aa18fbSYabin Cui  * you may not use this file except in compliance with the License.
5*a6aa18fbSYabin Cui  * You may obtain a copy of the License at
6*a6aa18fbSYabin Cui  *
7*a6aa18fbSYabin Cui  *     https://www.apache.org/licenses/LICENSE-2.0
8*a6aa18fbSYabin Cui  *
9*a6aa18fbSYabin Cui  * Unless required by applicable law or agreed to in writing, software
10*a6aa18fbSYabin Cui  * distributed under the License is distributed on an "AS IS" BASIS,
11*a6aa18fbSYabin Cui  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*a6aa18fbSYabin Cui  * See the License for the specific language governing permissions and
13*a6aa18fbSYabin Cui  * limitations under the License.
14*a6aa18fbSYabin Cui  */
15*a6aa18fbSYabin Cui 
16*a6aa18fbSYabin Cui #include <iostream>
17*a6aa18fbSYabin Cui #include <memory>
18*a6aa18fbSYabin Cui #include <ostream>
19*a6aa18fbSYabin Cui #include <string>
20*a6aa18fbSYabin Cui #include <utility>
21*a6aa18fbSYabin Cui 
22*a6aa18fbSYabin Cui #include "absl/flags/flag.h"
23*a6aa18fbSYabin Cui #include "absl/flags/parse.h"
24*a6aa18fbSYabin Cui #include "absl/strings/str_cat.h"
25*a6aa18fbSYabin Cui #include "include/grpc/grpc_security_constants.h"
26*a6aa18fbSYabin Cui #include "include/grpcpp/channel.h"
27*a6aa18fbSYabin Cui #include "include/grpcpp/client_context.h"
28*a6aa18fbSYabin Cui #include "include/grpcpp/create_channel.h"
29*a6aa18fbSYabin Cui #include "include/grpcpp/grpcpp.h"
30*a6aa18fbSYabin Cui #include "include/grpcpp/security/credentials.h"
31*a6aa18fbSYabin Cui #include "include/grpcpp/support/status.h"
32*a6aa18fbSYabin Cui #include "private_join_and_compute/client_impl.h"
33*a6aa18fbSYabin Cui #include "private_join_and_compute/data_util.h"
34*a6aa18fbSYabin Cui #include "private_join_and_compute/private_join_and_compute.grpc.pb.h"
35*a6aa18fbSYabin Cui #include "private_join_and_compute/private_join_and_compute.pb.h"
36*a6aa18fbSYabin Cui #include "private_join_and_compute/protocol_client.h"
37*a6aa18fbSYabin Cui #include "private_join_and_compute/util/status.inc"
38*a6aa18fbSYabin Cui 
39*a6aa18fbSYabin Cui ABSL_FLAG(std::string, port, "0.0.0.0:10501",
40*a6aa18fbSYabin Cui           "Port on which to contact server");
41*a6aa18fbSYabin Cui ABSL_FLAG(std::string, client_data_file, "",
42*a6aa18fbSYabin Cui           "The file from which to read the client database.");
43*a6aa18fbSYabin Cui ABSL_FLAG(
44*a6aa18fbSYabin Cui     int32_t, paillier_modulus_size, 1536,
45*a6aa18fbSYabin Cui     "The bit-length of the modulus to use for Paillier encryption. The modulus "
46*a6aa18fbSYabin Cui     "will be the product of two safe primes, each of size "
47*a6aa18fbSYabin Cui     "paillier_modulus_size/2.");
48*a6aa18fbSYabin Cui 
49*a6aa18fbSYabin Cui namespace private_join_and_compute {
50*a6aa18fbSYabin Cui namespace {
51*a6aa18fbSYabin Cui 
52*a6aa18fbSYabin Cui class InvokeServerHandleClientMessageSink : public MessageSink<ClientMessage> {
53*a6aa18fbSYabin Cui  public:
InvokeServerHandleClientMessageSink(std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub)54*a6aa18fbSYabin Cui   explicit InvokeServerHandleClientMessageSink(
55*a6aa18fbSYabin Cui       std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub)
56*a6aa18fbSYabin Cui       : stub_(std::move(stub)) {}
57*a6aa18fbSYabin Cui 
58*a6aa18fbSYabin Cui   ~InvokeServerHandleClientMessageSink() override = default;
59*a6aa18fbSYabin Cui 
Send(const ClientMessage & message)60*a6aa18fbSYabin Cui   Status Send(const ClientMessage& message) override {
61*a6aa18fbSYabin Cui     ::grpc::ClientContext client_context;
62*a6aa18fbSYabin Cui     ::grpc::Status grpc_status =
63*a6aa18fbSYabin Cui         stub_->Handle(&client_context, message, &last_server_response_);
64*a6aa18fbSYabin Cui     if (grpc_status.ok()) {
65*a6aa18fbSYabin Cui       return OkStatus();
66*a6aa18fbSYabin Cui     } else {
67*a6aa18fbSYabin Cui       return InternalError(absl::StrCat(
68*a6aa18fbSYabin Cui           "GrpcClientMessageSink: Failed to send message, error code: ",
69*a6aa18fbSYabin Cui           grpc_status.error_code(),
70*a6aa18fbSYabin Cui           ", error_message: ", grpc_status.error_message()));
71*a6aa18fbSYabin Cui     }
72*a6aa18fbSYabin Cui   }
73*a6aa18fbSYabin Cui 
last_server_response()74*a6aa18fbSYabin Cui   const ServerMessage& last_server_response() { return last_server_response_; }
75*a6aa18fbSYabin Cui 
76*a6aa18fbSYabin Cui  private:
77*a6aa18fbSYabin Cui   std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub_;
78*a6aa18fbSYabin Cui   ServerMessage last_server_response_;
79*a6aa18fbSYabin Cui };
80*a6aa18fbSYabin Cui 
ExecuteProtocol()81*a6aa18fbSYabin Cui int ExecuteProtocol() {
82*a6aa18fbSYabin Cui   ::private_join_and_compute::Context context;
83*a6aa18fbSYabin Cui 
84*a6aa18fbSYabin Cui   std::cout << "Client: Loading data..." << std::endl;
85*a6aa18fbSYabin Cui   auto maybe_client_identifiers_and_associated_values =
86*a6aa18fbSYabin Cui       ::private_join_and_compute::ReadClientDatasetFromFile(
87*a6aa18fbSYabin Cui           absl::GetFlag(FLAGS_client_data_file), &context);
88*a6aa18fbSYabin Cui   if (!maybe_client_identifiers_and_associated_values.ok()) {
89*a6aa18fbSYabin Cui     std::cerr << "Client::ExecuteProtocol: failed "
90*a6aa18fbSYabin Cui               << maybe_client_identifiers_and_associated_values.status()
91*a6aa18fbSYabin Cui               << std::endl;
92*a6aa18fbSYabin Cui     return 1;
93*a6aa18fbSYabin Cui   }
94*a6aa18fbSYabin Cui   auto client_identifiers_and_associated_values =
95*a6aa18fbSYabin Cui       std::move(maybe_client_identifiers_and_associated_values.value());
96*a6aa18fbSYabin Cui 
97*a6aa18fbSYabin Cui   std::cout << "Client: Generating keys..." << std::endl;
98*a6aa18fbSYabin Cui   std::unique_ptr<::private_join_and_compute::ProtocolClient> client =
99*a6aa18fbSYabin Cui       std::make_unique<
100*a6aa18fbSYabin Cui           ::private_join_and_compute::PrivateIntersectionSumProtocolClientImpl>(
101*a6aa18fbSYabin Cui           &context, std::move(client_identifiers_and_associated_values.first),
102*a6aa18fbSYabin Cui           std::move(client_identifiers_and_associated_values.second),
103*a6aa18fbSYabin Cui           absl::GetFlag(FLAGS_paillier_modulus_size));
104*a6aa18fbSYabin Cui 
105*a6aa18fbSYabin Cui   // Consider grpc::SslServerCredentials if not running locally.
106*a6aa18fbSYabin Cui   std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub =
107*a6aa18fbSYabin Cui       PrivateJoinAndComputeRpc::NewStub(::grpc::CreateChannel(
108*a6aa18fbSYabin Cui           absl::GetFlag(FLAGS_port), ::grpc::experimental::LocalCredentials(
109*a6aa18fbSYabin Cui                                          grpc_local_connect_type::LOCAL_TCP)));
110*a6aa18fbSYabin Cui   InvokeServerHandleClientMessageSink invoke_server_handle_message_sink(
111*a6aa18fbSYabin Cui       std::move(stub));
112*a6aa18fbSYabin Cui 
113*a6aa18fbSYabin Cui   // Execute StartProtocol and wait for response from ServerRoundOne.
114*a6aa18fbSYabin Cui   std::cout
115*a6aa18fbSYabin Cui       << "Client: Starting the protocol." << std::endl
116*a6aa18fbSYabin Cui       << "Client: Waiting for response and encrypted set from the server..."
117*a6aa18fbSYabin Cui       << std::endl;
118*a6aa18fbSYabin Cui   auto start_protocol_status =
119*a6aa18fbSYabin Cui       client->StartProtocol(&invoke_server_handle_message_sink);
120*a6aa18fbSYabin Cui   if (!start_protocol_status.ok()) {
121*a6aa18fbSYabin Cui     std::cerr << "Client::ExecuteProtocol: failed to StartProtocol: "
122*a6aa18fbSYabin Cui               << start_protocol_status << std::endl;
123*a6aa18fbSYabin Cui     return 1;
124*a6aa18fbSYabin Cui   }
125*a6aa18fbSYabin Cui   ServerMessage server_round_one =
126*a6aa18fbSYabin Cui       invoke_server_handle_message_sink.last_server_response();
127*a6aa18fbSYabin Cui 
128*a6aa18fbSYabin Cui   // Execute ClientRoundOne, and wait for response from ServerRoundTwo.
129*a6aa18fbSYabin Cui   std::cout
130*a6aa18fbSYabin Cui       << "Client: Received encrypted set from the server, double encrypting..."
131*a6aa18fbSYabin Cui       << std::endl;
132*a6aa18fbSYabin Cui   std::cout << "Client: Sending double encrypted server data and "
133*a6aa18fbSYabin Cui                "single-encrypted client data to the server."
134*a6aa18fbSYabin Cui             << std::endl
135*a6aa18fbSYabin Cui             << "Client: Waiting for encrypted intersection sum..." << std::endl;
136*a6aa18fbSYabin Cui   auto client_round_one_status =
137*a6aa18fbSYabin Cui       client->Handle(server_round_one, &invoke_server_handle_message_sink);
138*a6aa18fbSYabin Cui   if (!client_round_one_status.ok()) {
139*a6aa18fbSYabin Cui     std::cerr << "Client::ExecuteProtocol: failed to ReEncryptSet: "
140*a6aa18fbSYabin Cui               << client_round_one_status << std::endl;
141*a6aa18fbSYabin Cui     return 1;
142*a6aa18fbSYabin Cui   }
143*a6aa18fbSYabin Cui 
144*a6aa18fbSYabin Cui   // Execute ServerRoundTwo.
145*a6aa18fbSYabin Cui   std::cout << "Client: Sending double encrypted server data and "
146*a6aa18fbSYabin Cui                "single-encrypted client data to the server."
147*a6aa18fbSYabin Cui             << std::endl
148*a6aa18fbSYabin Cui             << "Client: Waiting for encrypted intersection sum..." << std::endl;
149*a6aa18fbSYabin Cui   ServerMessage server_round_two =
150*a6aa18fbSYabin Cui       invoke_server_handle_message_sink.last_server_response();
151*a6aa18fbSYabin Cui 
152*a6aa18fbSYabin Cui   // Compute the intersection size and sum.
153*a6aa18fbSYabin Cui   std::cout << "Client: Received response from the server. Decrypting the "
154*a6aa18fbSYabin Cui                "intersection-sum."
155*a6aa18fbSYabin Cui             << std::endl;
156*a6aa18fbSYabin Cui   auto intersection_size_and_sum_status =
157*a6aa18fbSYabin Cui       client->Handle(server_round_two, &invoke_server_handle_message_sink);
158*a6aa18fbSYabin Cui   if (!intersection_size_and_sum_status.ok()) {
159*a6aa18fbSYabin Cui     std::cerr << "Client::ExecuteProtocol: failed to DecryptSum: "
160*a6aa18fbSYabin Cui               << intersection_size_and_sum_status << std::endl;
161*a6aa18fbSYabin Cui     return 1;
162*a6aa18fbSYabin Cui   }
163*a6aa18fbSYabin Cui 
164*a6aa18fbSYabin Cui   // Output the result.
165*a6aa18fbSYabin Cui   auto client_print_output_status = client->PrintOutput();
166*a6aa18fbSYabin Cui   if (!client_print_output_status.ok()) {
167*a6aa18fbSYabin Cui     std::cerr << "Client::ExecuteProtocol: failed to PrintOutput: "
168*a6aa18fbSYabin Cui               << client_print_output_status << std::endl;
169*a6aa18fbSYabin Cui     return 1;
170*a6aa18fbSYabin Cui   }
171*a6aa18fbSYabin Cui 
172*a6aa18fbSYabin Cui   return 0;
173*a6aa18fbSYabin Cui }
174*a6aa18fbSYabin Cui 
175*a6aa18fbSYabin Cui }  // namespace
176*a6aa18fbSYabin Cui }  // namespace private_join_and_compute
177*a6aa18fbSYabin Cui 
main(int argc,char ** argv)178*a6aa18fbSYabin Cui int main(int argc, char** argv) {
179*a6aa18fbSYabin Cui   absl::ParseCommandLine(argc, argv);
180*a6aa18fbSYabin Cui 
181*a6aa18fbSYabin Cui   return private_join_and_compute::ExecuteProtocol();
182*a6aa18fbSYabin Cui }
183