xref: /aosp_15_r20/external/grpc-grpc/test/cpp/end2end/client_callback_end2end_test.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <algorithm>
20 #include <condition_variable>
21 #include <functional>
22 #include <mutex>
23 #include <sstream>
24 #include <thread>
25 
26 #include <gtest/gtest.h>
27 
28 #include "absl/memory/memory.h"
29 
30 #include <grpcpp/channel.h>
31 #include <grpcpp/client_context.h>
32 #include <grpcpp/create_channel.h>
33 #include <grpcpp/generic/generic_stub.h>
34 #include <grpcpp/impl/proto_utils.h>
35 #include <grpcpp/server.h>
36 #include <grpcpp/server_builder.h>
37 #include <grpcpp/server_context.h>
38 #include <grpcpp/support/client_callback.h>
39 
40 #include "src/core/lib/gprpp/env.h"
41 #include "src/core/lib/iomgr/iomgr.h"
42 #include "src/proto/grpc/testing/echo.grpc.pb.h"
43 #include "test/core/util/port.h"
44 #include "test/core/util/test_config.h"
45 #include "test/cpp/end2end/interceptors_util.h"
46 #include "test/cpp/end2end/test_service_impl.h"
47 #include "test/cpp/util/byte_buffer_proto_helper.h"
48 #include "test/cpp/util/string_ref_helper.h"
49 #include "test/cpp/util/test_credentials_provider.h"
50 
51 namespace grpc {
52 namespace testing {
53 namespace {
54 
// Transport connecting the test client to the test server.
enum class Protocol { INPROC, TCP };

// One parameterization of ClientCallbackEnd2endTest: which server API,
// transport, interceptor setup and credentials type a run exercises.
class TestScenario {
 public:
  TestScenario(bool serve_callback, Protocol protocol, bool intercept,
               const std::string& creds_type)
      : callback_server{serve_callback},
        protocol{protocol},
        use_interceptors{intercept},
        credentials_type{creds_type} {}

  // Writes a one-line description of this scenario to the log.
  void Log() const;

  // true: register the callback-API service; false: the sync service.
  bool callback_server;
  // Transport used between client and server.
  Protocol protocol;
  // Whether phony client/server interceptors are installed.
  bool use_interceptors;
  // Credentials type handed to the test credentials provider.
  const std::string credentials_type;
};
71 
operator <<(std::ostream & out,const TestScenario & scenario)72 std::ostream& operator<<(std::ostream& out, const TestScenario& scenario) {
73   return out << "TestScenario{callback_server="
74              << (scenario.callback_server ? "true" : "false") << ",protocol="
75              << (scenario.protocol == Protocol::INPROC ? "INPROC" : "TCP")
76              << ",intercept=" << (scenario.use_interceptors ? "true" : "false")
77              << ",creds=" << scenario.credentials_type << "}";
78 }
79 
Log() const80 void TestScenario::Log() const {
81   std::ostringstream out;
82   out << *this;
83   gpr_log(GPR_DEBUG, "%s", out.str().c_str());
84 }
85 
86 class ClientCallbackEnd2endTest
87     : public ::testing::TestWithParam<TestScenario> {
88  protected:
ClientCallbackEnd2endTest()89   ClientCallbackEnd2endTest() { GetParam().Log(); }
90 
SetUp()91   void SetUp() override {
92     ServerBuilder builder;
93 
94     auto server_creds = GetCredentialsProvider()->GetServerCredentials(
95         GetParam().credentials_type);
96     // TODO(vjpai): Support testing of AuthMetadataProcessor
97 
98     if (GetParam().protocol == Protocol::TCP) {
99       picked_port_ = grpc_pick_unused_port_or_die();
100       server_address_ << "localhost:" << picked_port_;
101       builder.AddListeningPort(server_address_.str(), server_creds);
102     }
103     if (!GetParam().callback_server) {
104       builder.RegisterService(&service_);
105     } else {
106       builder.RegisterService(&callback_service_);
107     }
108 
109     if (GetParam().use_interceptors) {
110       std::vector<
111           std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
112           creators;
113       // Add 20 phony server interceptors
114       creators.reserve(20);
115       for (auto i = 0; i < 20; i++) {
116         creators.push_back(std::make_unique<PhonyInterceptorFactory>());
117       }
118       builder.experimental().SetInterceptorCreators(std::move(creators));
119     }
120 
121     server_ = builder.BuildAndStart();
122     is_server_started_ = true;
123   }
124 
ResetStub(std::unique_ptr<experimental::ClientInterceptorFactoryInterface> interceptor=nullptr)125   void ResetStub(
126       std::unique_ptr<experimental::ClientInterceptorFactoryInterface>
127           interceptor = nullptr) {
128     ChannelArguments args;
129     auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
130         GetParam().credentials_type, &args);
131     auto interceptors = CreatePhonyClientInterceptors();
132     if (interceptor != nullptr) interceptors.push_back(std::move(interceptor));
133     switch (GetParam().protocol) {
134       case Protocol::TCP:
135         if (!GetParam().use_interceptors) {
136           channel_ = grpc::CreateCustomChannel(server_address_.str(),
137                                                channel_creds, args);
138         } else {
139           channel_ = CreateCustomChannelWithInterceptors(
140               server_address_.str(), channel_creds, args,
141               std::move(interceptors));
142         }
143         break;
144       case Protocol::INPROC:
145         if (!GetParam().use_interceptors) {
146           channel_ = server_->InProcessChannel(args);
147         } else {
148           channel_ = server_->experimental().InProcessChannelWithInterceptors(
149               args, std::move(interceptors));
150         }
151         break;
152       default:
153         assert(false);
154     }
155     stub_ = grpc::testing::EchoTestService::NewStub(channel_);
156     generic_stub_ = std::make_unique<GenericStub>(channel_);
157     PhonyInterceptor::Reset();
158   }
159 
TearDown()160   void TearDown() override {
161     if (is_server_started_) {
162       // Although we would normally do an explicit shutdown, the server
163       // should also work correctly with just a destructor call. The regular
164       // end2end test uses explicit shutdown, so let this one just do reset.
165       server_.reset();
166     }
167     if (picked_port_ > 0) {
168       grpc_recycle_unused_port(picked_port_);
169     }
170   }
171 
SendRpcs(int num_rpcs,bool with_binary_metadata)172   void SendRpcs(int num_rpcs, bool with_binary_metadata) {
173     std::string test_string;
174     for (int i = 0; i < num_rpcs; i++) {
175       EchoRequest request;
176       EchoResponse response;
177       ClientContext cli_ctx;
178 
179       test_string += "Hello world. ";
180       request.set_message(test_string);
181       std::string val;
182       if (with_binary_metadata) {
183         request.mutable_param()->set_echo_metadata(true);
184         char bytes[8] = {'\0', '\1', '\2', '\3',
185                          '\4', '\5', '\6', static_cast<char>(i)};
186         val = std::string(bytes, 8);
187         cli_ctx.AddMetadata("custom-bin", val);
188       }
189 
190       cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
191 
192       std::mutex mu;
193       std::condition_variable cv;
194       bool done = false;
195       stub_->async()->Echo(
196           &cli_ctx, &request, &response,
197           [&cli_ctx, &request, &response, &done, &mu, &cv, val,
198            with_binary_metadata](Status s) {
199             GPR_ASSERT(s.ok());
200 
201             EXPECT_EQ(request.message(), response.message());
202             if (with_binary_metadata) {
203               EXPECT_EQ(
204                   1u, cli_ctx.GetServerTrailingMetadata().count("custom-bin"));
205               EXPECT_EQ(val, ToString(cli_ctx.GetServerTrailingMetadata()
206                                           .find("custom-bin")
207                                           ->second));
208             }
209             std::lock_guard<std::mutex> l(mu);
210             done = true;
211             cv.notify_one();
212           });
213       std::unique_lock<std::mutex> l(mu);
214       while (!done) {
215         cv.wait(l);
216       }
217     }
218   }
219 
SendRpcsGeneric(int num_rpcs,bool maybe_except,const char * suffix_for_stats)220   void SendRpcsGeneric(int num_rpcs, bool maybe_except,
221                        const char* suffix_for_stats) {
222     const std::string kMethodName("/grpc.testing.EchoTestService/Echo");
223     std::string test_string;
224     for (int i = 0; i < num_rpcs; i++) {
225       EchoRequest request;
226       std::unique_ptr<ByteBuffer> send_buf;
227       ByteBuffer recv_buf;
228       ClientContext cli_ctx;
229 
230       test_string += "Hello world. ";
231       request.set_message(test_string);
232       send_buf = SerializeToByteBuffer(&request);
233 
234       std::mutex mu;
235       std::condition_variable cv;
236       bool done = false;
237       StubOptions options(suffix_for_stats);
238       generic_stub_->UnaryCall(
239           &cli_ctx, kMethodName, options, send_buf.get(), &recv_buf,
240           [&request, &recv_buf, &done, &mu, &cv, maybe_except](Status s) {
241             GPR_ASSERT(s.ok());
242 
243             EchoResponse response;
244             EXPECT_TRUE(ParseFromByteBuffer(&recv_buf, &response));
245             EXPECT_EQ(request.message(), response.message());
246             std::lock_guard<std::mutex> l(mu);
247             done = true;
248             cv.notify_one();
249 #if GRPC_ALLOW_EXCEPTIONS
250             if (maybe_except) {
251               throw -1;
252             }
253 #else
254             GPR_ASSERT(!maybe_except);
255 #endif
256           });
257       std::unique_lock<std::mutex> l(mu);
258       while (!done) {
259         cv.wait(l);
260       }
261     }
262   }
263 
SendGenericEchoAsBidi(int num_rpcs,int reuses,bool do_writes_done,const char * suffix_for_stats)264   void SendGenericEchoAsBidi(int num_rpcs, int reuses, bool do_writes_done,
265                              const char* suffix_for_stats) {
266     const std::string kMethodName("/grpc.testing.EchoTestService/Echo");
267     std::string test_string;
268     for (int i = 0; i < num_rpcs; i++) {
269       test_string += "Hello world. ";
270       class Client : public grpc::ClientBidiReactor<ByteBuffer, ByteBuffer> {
271        public:
272         Client(ClientCallbackEnd2endTest* test, const std::string& method_name,
273                const char* suffix_for_stats, const std::string& test_str,
274                int reuses, bool do_writes_done)
275             : reuses_remaining_(reuses), do_writes_done_(do_writes_done) {
276           activate_ = [this, test, method_name, suffix_for_stats, test_str] {
277             if (reuses_remaining_ > 0) {
278               cli_ctx_ = std::make_unique<ClientContext>();
279               reuses_remaining_--;
280               StubOptions options(suffix_for_stats);
281               test->generic_stub_->PrepareBidiStreamingCall(
282                   cli_ctx_.get(), method_name, options, this);
283               request_.set_message(test_str);
284               send_buf_ = SerializeToByteBuffer(&request_);
285               StartWrite(send_buf_.get());
286               StartRead(&recv_buf_);
287               StartCall();
288             } else {
289               std::unique_lock<std::mutex> l(mu_);
290               done_ = true;
291               cv_.notify_one();
292             }
293           };
294           activate_();
295         }
296         void OnWriteDone(bool /*ok*/) override {
297           if (do_writes_done_) {
298             StartWritesDone();
299           }
300         }
301         void OnReadDone(bool /*ok*/) override {
302           EchoResponse response;
303           EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response));
304           EXPECT_EQ(request_.message(), response.message());
305         };
306         void OnDone(const Status& s) override {
307           EXPECT_TRUE(s.ok());
308           activate_();
309         }
310         void Await() {
311           std::unique_lock<std::mutex> l(mu_);
312           while (!done_) {
313             cv_.wait(l);
314           }
315         }
316 
317         EchoRequest request_;
318         std::unique_ptr<ByteBuffer> send_buf_;
319         ByteBuffer recv_buf_;
320         std::unique_ptr<ClientContext> cli_ctx_;
321         int reuses_remaining_;
322         std::function<void()> activate_;
323         std::mutex mu_;
324         std::condition_variable cv_;
325         bool done_ = false;
326         const bool do_writes_done_;
327       };
328 
329       Client rpc(this, kMethodName, suffix_for_stats, test_string, reuses,
330                  do_writes_done);
331 
332       rpc.Await();
333     }
334   }
335   bool is_server_started_{false};
336   int picked_port_{0};
337   std::shared_ptr<Channel> channel_;
338   std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
339   std::unique_ptr<grpc::GenericStub> generic_stub_;
340   TestServiceImpl service_;
341   CallbackTestServiceImpl callback_service_;
342   std::unique_ptr<Server> server_;
343   std::ostringstream server_address_;
344 };
345 
// A single unary Echo through the callback API, without binary metadata.
TEST_P(ClientCallbackEnd2endTest, SimpleRpc) {
  ResetStub();
  SendRpcs(/*num_rpcs=*/1, /*with_binary_metadata=*/false);
}
350 
// Asks the server (via the request's expected_error param) to fail with a
// specific status; the callback must see that exact code and message and
// an empty response.
TEST_P(ClientCallbackEnd2endTest, SimpleRpcExpectedError) {
  ResetStub();

  EchoRequest request;
  EchoResponse response;
  ClientContext cli_ctx;

  ErrorStatus wanted;
  wanted.set_code(1);  // CANCELLED
  wanted.set_error_message("cancel error message");
  request.set_message("Hello failure");
  *request.mutable_param()->mutable_expected_error() = wanted;

  std::mutex mu;
  std::condition_variable cv;
  bool finished = false;

  auto on_done = [&response, &finished, &mu, &cv, &wanted](Status status) {
    EXPECT_EQ("", response.message());
    EXPECT_EQ(wanted.code(), status.error_code());
    EXPECT_EQ(wanted.error_message(), status.error_message());
    std::lock_guard<std::mutex> signal_lock(mu);
    finished = true;
    cv.notify_one();
  };
  stub_->async()->Echo(&cli_ctx, &request, &response, on_done);

  std::unique_lock<std::mutex> wait_lock(mu);
  cv.wait(wait_lock, [&finished] { return finished; });
}
384 
TEST_P(ClientCallbackEnd2endTest,SimpleRpcUnderLockNested)385 TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLockNested) {
386   ResetStub();
387 
388   // The request/response state associated with an RPC and the synchronization
389   // variables needed to notify its completion.
390   struct RpcState {
391     std::mutex mu;
392     std::condition_variable cv;
393     bool done = false;
394     EchoRequest request;
395     EchoResponse response;
396     ClientContext cli_ctx;
397 
398     RpcState() = default;
399     ~RpcState() {
400       // Grab the lock to prevent destruction while another is still holding
401       // lock
402       std::lock_guard<std::mutex> lock(mu);
403     }
404   };
405   std::vector<RpcState> rpc_state(3);
406   for (size_t i = 0; i < rpc_state.size(); i++) {
407     std::string message = "Hello locked world";
408     message += std::to_string(i);
409     rpc_state[i].request.set_message(message);
410   }
411 
412   // Grab a lock and then start an RPC whose callback grabs the same lock and
413   // then calls this function to start the next RPC under lock (up to a limit of
414   // the size of the rpc_state vector).
415   std::function<void(int)> nested_call = [this, &nested_call,
416                                           &rpc_state](int index) {
417     std::lock_guard<std::mutex> l(rpc_state[index].mu);
418     stub_->async()->Echo(&rpc_state[index].cli_ctx, &rpc_state[index].request,
419                          &rpc_state[index].response,
420                          [index, &nested_call, &rpc_state](Status s) {
421                            std::lock_guard<std::mutex> l1(rpc_state[index].mu);
422                            EXPECT_TRUE(s.ok());
423                            rpc_state[index].done = true;
424                            rpc_state[index].cv.notify_all();
425                            // Call the next level of nesting if possible
426                            if (index + 1 < static_cast<int>(rpc_state.size())) {
427                              nested_call(index + 1);
428                            }
429                          });
430   };
431 
432   nested_call(0);
433 
434   // Wait for completion notifications from all RPCs. Order doesn't matter.
435   for (RpcState& state : rpc_state) {
436     std::unique_lock<std::mutex> l(state.mu);
437     while (!state.done) {
438       state.cv.wait(l);
439     }
440     EXPECT_EQ(state.request.message(), state.response.message());
441   }
442 }
443 
// Starts a unary RPC while holding the mutex that its completion callback
// locks, then waits for that callback to signal completion.
TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLock) {
  ResetStub();
  std::mutex mu;
  std::condition_variable cv;
  bool finished = false;
  EchoRequest request;
  request.set_message("Hello locked world.");
  EchoResponse response;
  ClientContext cli_ctx;

  auto on_done = [&mu, &cv, &finished, &request, &response](Status status) {
    std::lock_guard<std::mutex> callback_lock(mu);
    EXPECT_TRUE(status.ok());
    EXPECT_EQ(request.message(), response.message());
    finished = true;
    cv.notify_one();
  };
  {
    std::lock_guard<std::mutex> starter_lock(mu);
    stub_->async()->Echo(&cli_ctx, &request, &response, on_done);
  }
  std::unique_lock<std::mutex> wait_lock(mu);
  cv.wait(wait_lock, [&finished] { return finished; });
}
469 
// Ten back-to-back unary Echo RPCs on a single stub.
TEST_P(ClientCallbackEnd2endTest, SequentialRpcs) {
  ResetStub();
  SendRpcs(/*num_rpcs=*/10, /*with_binary_metadata=*/false);
}
474 
// Sends kCheckClientInitialMetadataKey/Val as client initial metadata to the
// CheckClientInitialMetadata method (presumably verified server-side) and
// requires the RPC to succeed.
TEST_P(ClientCallbackEnd2endTest, SendClientInitialMetadata) {
  ResetStub();
  SimpleRequest request;
  SimpleResponse response;
  ClientContext cli_ctx;
  cli_ctx.AddMetadata(kCheckClientInitialMetadataKey,
                      kCheckClientInitialMetadataVal);

  std::mutex mu;
  std::condition_variable cv;
  bool finished = false;
  auto on_done = [&finished, &mu, &cv](Status status) {
    GPR_ASSERT(status.ok());
    std::lock_guard<std::mutex> signal_lock(mu);
    finished = true;
    cv.notify_one();
  };
  stub_->async()->CheckClientInitialMetadata(&cli_ctx, &request, &response,
                                             on_done);

  std::unique_lock<std::mutex> wait_lock(mu);
  cv.wait(wait_lock, [&finished] { return finished; });
}
500 
// One unary Echo carrying "custom-bin" binary metadata that must be echoed
// back in the trailing metadata.
TEST_P(ClientCallbackEnd2endTest, SimpleRpcWithBinaryMetadata) {
  ResetStub();
  SendRpcs(/*num_rpcs=*/1, /*with_binary_metadata=*/true);
}

// Ten sequential Echo RPCs whose "custom-bin" values differ in the last byte.
TEST_P(ClientCallbackEnd2endTest, SequentialRpcsWithVariedBinaryMetadataValue) {
  ResetStub();
  SendRpcs(/*num_rpcs=*/10, /*with_binary_metadata=*/true);
}
510 
// Generic-stub unary calls, observed by a TestInterceptorFactory configured
// for the Echo method and no stats suffix.
TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcs) {
  auto factory = std::make_unique<TestInterceptorFactory>(
      "/grpc.testing.EchoTestService/Echo", nullptr);
  ResetStub(std::move(factory));
  SendRpcsGeneric(/*num_rpcs=*/10, /*maybe_except=*/false,
                  /*suffix_for_stats=*/nullptr);
}

// As above, but the calls carry the "TestSuffix" stats suffix.
TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsWithSuffix) {
  auto factory = std::make_unique<TestInterceptorFactory>(
      "/grpc.testing.EchoTestService/Echo", "TestSuffix");
  ResetStub(std::move(factory));
  SendRpcsGeneric(/*num_rpcs=*/10, /*maybe_except=*/false, "TestSuffix");
}

// The unary Echo method driven as a generic bidi stream, one call per reactor.
TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidi) {
  auto factory = std::make_unique<TestInterceptorFactory>(
      "/grpc.testing.EchoTestService/Echo", nullptr);
  ResetStub(std::move(factory));
  SendGenericEchoAsBidi(/*num_rpcs=*/10, /*reuses=*/1,
                        /*do_writes_done=*/true,
                        /*suffix_for_stats=*/nullptr);
}

// Bidi-as-unary with the "TestSuffix" stats suffix.
TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidiWithSuffix) {
  auto factory = std::make_unique<TestInterceptorFactory>(
      "/grpc.testing.EchoTestService/Echo", "TestSuffix");
  ResetStub(std::move(factory));
  SendGenericEchoAsBidi(/*num_rpcs=*/10, /*reuses=*/1,
                        /*do_writes_done=*/true, "TestSuffix");
}

// Each reactor is restarted from OnDone for ten calls in a row.
TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidiWithReactorReuse) {
  ResetStub();
  SendGenericEchoAsBidi(/*num_rpcs=*/10, /*reuses=*/10,
                        /*do_writes_done=*/true,
                        /*suffix_for_stats=*/nullptr);
}

// A bidi-as-unary call that never sends WritesDone.
TEST_P(ClientCallbackEnd2endTest, GenericRpcNoWritesDone) {
  ResetStub();
  SendGenericEchoAsBidi(/*num_rpcs=*/1, /*reuses=*/1,
                        /*do_writes_done=*/false,
                        /*suffix_for_stats=*/nullptr);
}
547 
#if GRPC_ALLOW_EXCEPTIONS
// Generic unary RPCs whose completion callbacks throw after signalling
// completion; built only when exceptions are allowed.
TEST_P(ClientCallbackEnd2endTest, ExceptingRpc) {
  ResetStub();
  SendRpcsGeneric(/*num_rpcs=*/10, /*maybe_except=*/true,
                  /*suffix_for_stats=*/nullptr);
}
#endif  // GRPC_ALLOW_EXCEPTIONS
554 
// Ten threads each run ten sequential Echo RPCs with binary metadata,
// all sharing one stub.
TEST_P(ClientCallbackEnd2endTest, MultipleRpcsWithVariedBinaryMetadataValue) {
  ResetStub();
  constexpr int kNumThreads = 10;
  std::vector<std::thread> workers;
  workers.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    workers.emplace_back([this] { SendRpcs(10, true); });
  }
  for (std::thread& worker : workers) {
    worker.join();
  }
}

// Ten threads each run ten sequential Echo RPCs, all sharing one stub.
TEST_P(ClientCallbackEnd2endTest, MultipleRpcs) {
  ResetStub();
  constexpr int kNumThreads = 10;
  std::vector<std::thread> workers;
  workers.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    workers.emplace_back([this] { SendRpcs(10, false); });
  }
  for (std::thread& worker : workers) {
    worker.join();
  }
}
578 
// Cancels the context before the RPC is issued. The callback must report
// CANCELLED with an empty response; with interceptors, PhonyInterceptor must
// have counted 20 cancellations.
TEST_P(ClientCallbackEnd2endTest, CancelRpcBeforeStart) {
  ResetStub();
  EchoRequest request;
  EchoResponse response;
  ClientContext context;
  request.set_message("hello");
  context.TryCancel();

  std::mutex mu;
  std::condition_variable cv;
  bool finished = false;
  auto on_done = [&response, &finished, &mu, &cv](Status status) {
    EXPECT_EQ("", response.message());
    EXPECT_EQ(grpc::StatusCode::CANCELLED, status.error_code());
    std::lock_guard<std::mutex> signal_lock(mu);
    finished = true;
    cv.notify_one();
  };
  stub_->async()->Echo(&context, &request, &response, on_done);

  std::unique_lock<std::mutex> wait_lock(mu);
  cv.wait(wait_lock, [&finished] { return finished; });
  if (GetParam().use_interceptors) {
    EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
  }
}
606 
// Asks the server (via kServerTryCancelRequest metadata) to cancel before
// processing; the client must observe a CANCELLED status.
TEST_P(ClientCallbackEnd2endTest, RequestEchoServerCancel) {
  ResetStub();
  EchoRequest request;
  EchoResponse response;
  ClientContext context;
  request.set_message("hello");
  const std::string phase = std::to_string(CANCEL_BEFORE_PROCESSING);
  context.AddMetadata(kServerTryCancelRequest, phase);

  std::mutex mu;
  std::condition_variable cv;
  bool finished = false;
  auto on_done = [&finished, &mu, &cv](Status status) {
    EXPECT_FALSE(status.ok());
    EXPECT_EQ(grpc::StatusCode::CANCELLED, status.error_code());
    std::lock_guard<std::mutex> signal_lock(mu);
    finished = true;
    cv.notify_one();
  };
  stub_->async()->Echo(&context, &request, &response, on_done);

  std::unique_lock<std::mutex> wait_lock(mu);
  cv.wait(wait_lock, [&finished] { return finished; });
}
632 
// Describes whether, and after how many writes, WriteClient cancels its own
// RPC. A default-constructed value means "never cancel".
struct ClientCancelInfo {
  // True when the client should call TryCancel() itself.
  bool cancel{false};
  // Completed writes after which the client cancels; only meaningful when
  // `cancel` is true. Zero-initialized so a default-constructed value never
  // holds an indeterminate int, which would otherwise be copied into
  // WriteClient.
  int ops_before_cancel{0};

  ClientCancelInfo() = default;
  explicit ClientCancelInfo(int ops) : cancel{true}, ops_before_cancel{ops} {}
};
640 
641 class WriteClient : public grpc::ClientWriteReactor<EchoRequest> {
642  public:
WriteClient(grpc::testing::EchoTestService::Stub * stub,ServerTryCancelRequestPhase server_try_cancel,int num_msgs_to_send,ClientCancelInfo client_cancel={})643   WriteClient(grpc::testing::EchoTestService::Stub* stub,
644               ServerTryCancelRequestPhase server_try_cancel,
645               int num_msgs_to_send, ClientCancelInfo client_cancel = {})
646       : server_try_cancel_(server_try_cancel),
647         num_msgs_to_send_(num_msgs_to_send),
648         client_cancel_{client_cancel} {
649     std::string msg{"Hello server."};
650     for (int i = 0; i < num_msgs_to_send; i++) {
651       desired_ += msg;
652     }
653     if (server_try_cancel != DO_NOT_CANCEL) {
654       // Send server_try_cancel value in the client metadata
655       context_.AddMetadata(kServerTryCancelRequest,
656                            std::to_string(server_try_cancel));
657     }
658     context_.set_initial_metadata_corked(true);
659     stub->async()->RequestStream(&context_, &response_, this);
660     StartCall();
661     request_.set_message(msg);
662     MaybeWrite();
663   }
OnWriteDone(bool ok)664   void OnWriteDone(bool ok) override {
665     if (ok) {
666       num_msgs_sent_++;
667       MaybeWrite();
668     }
669   }
OnDone(const Status & s)670   void OnDone(const Status& s) override {
671     gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent_);
672     int num_to_send =
673         (client_cancel_.cancel)
674             ? std::min(num_msgs_to_send_, client_cancel_.ops_before_cancel)
675             : num_msgs_to_send_;
676     switch (server_try_cancel_) {
677       case CANCEL_BEFORE_PROCESSING:
678       case CANCEL_DURING_PROCESSING:
679         // If the RPC is canceled by server before / during messages from the
680         // client, it means that the client most likely did not get a chance to
681         // send all the messages it wanted to send. i.e num_msgs_sent <=
682         // num_msgs_to_send
683         EXPECT_LE(num_msgs_sent_, num_to_send);
684         break;
685       case DO_NOT_CANCEL:
686       case CANCEL_AFTER_PROCESSING:
687         // If the RPC was not canceled or canceled after all messages were read
688         // by the server, the client did get a chance to send all its messages
689         EXPECT_EQ(num_msgs_sent_, num_to_send);
690         break;
691       default:
692         assert(false);
693         break;
694     }
695     if ((server_try_cancel_ == DO_NOT_CANCEL) && !client_cancel_.cancel) {
696       EXPECT_TRUE(s.ok());
697       EXPECT_EQ(response_.message(), desired_);
698     } else {
699       EXPECT_FALSE(s.ok());
700       EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
701     }
702     std::unique_lock<std::mutex> l(mu_);
703     done_ = true;
704     cv_.notify_one();
705   }
Await()706   void Await() {
707     std::unique_lock<std::mutex> l(mu_);
708     while (!done_) {
709       cv_.wait(l);
710     }
711   }
712 
713  private:
MaybeWrite()714   void MaybeWrite() {
715     if (client_cancel_.cancel &&
716         num_msgs_sent_ == client_cancel_.ops_before_cancel) {
717       context_.TryCancel();
718     } else if (num_msgs_to_send_ > num_msgs_sent_ + 1) {
719       StartWrite(&request_);
720     } else if (num_msgs_to_send_ == num_msgs_sent_ + 1) {
721       StartWriteLast(&request_, WriteOptions());
722     }
723   }
724   EchoRequest request_;
725   EchoResponse response_;
726   ClientContext context_;
727   const ServerTryCancelRequestPhase server_try_cancel_;
728   int num_msgs_sent_{0};
729   const int num_msgs_to_send_;
730   std::string desired_;
731   const ClientCancelInfo client_cancel_;
732   std::mutex mu_;
733   std::condition_variable cv_;
734   bool done_ = false;
735 };
736 
// Client streams three messages without any cancellation; server
// interceptors must not be notified of a cancel.
TEST_P(ClientCallbackEnd2endTest, RequestStream) {
  ResetStub();
  WriteClient client{stub_.get(), DO_NOT_CANCEL, /*num_msgs_to_send=*/3};
  client.Await();
  if (GetParam().use_interceptors) {
    EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel());
  }
}

// Client cancels its own stream after two of three writes; server
// interceptors must see the cancel.
TEST_P(ClientCallbackEnd2endTest, ClientCancelsRequestStream) {
  ResetStub();
  WriteClient client{stub_.get(), DO_NOT_CANCEL, /*num_msgs_to_send=*/3,
                     ClientCancelInfo{2}};
  client.Await();
  if (GetParam().use_interceptors) {
    EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
  }
}

// Server cancels before reading any request.
TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelBeforeReads) {
  ResetStub();
  WriteClient client{stub_.get(), CANCEL_BEFORE_PROCESSING,
                     /*num_msgs_to_send=*/1};
  client.Await();
  if (GetParam().use_interceptors) {
    EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
  }
}

// Server cancels while reading requests from the stream in parallel.
TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelDuringRead) {
  ResetStub();
  WriteClient client{stub_.get(), CANCEL_DURING_PROCESSING,
                     /*num_msgs_to_send=*/10};
  client.Await();
  if (GetParam().use_interceptors) {
    EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
  }
}

// Server cancels after reading all requests but before replying.
TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelAfterReads) {
  ResetStub();
  WriteClient client{stub_.get(), CANCEL_AFTER_PROCESSING,
                     /*num_msgs_to_send=*/4};
  client.Await();
  if (GetParam().use_interceptors) {
    EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
  }
}
790 
// Unary Echo driven by a ClientUnaryReactor. The request asks the server to
// echo client metadata as initial metadata, which must arrive before OnDone,
// with no trailing metadata and a successful echo.
TEST_P(ClientCallbackEnd2endTest, UnaryReactor) {
  ResetStub();
  class UnaryClient : public grpc::ClientUnaryReactor {
   public:
    explicit UnaryClient(grpc::testing::EchoTestService::Stub* stub) {
      cli_ctx_.AddMetadata("key1", "val1");
      cli_ctx_.AddMetadata("key2", "val2");
      request_.mutable_param()->set_echo_metadata_initially(true);
      request_.set_message("Hello metadata");
      stub->async()->Echo(&cli_ctx_, &request_, &response_, this);
      StartCall();
    }
    void OnReadInitialMetadataDone(bool ok) override {
      EXPECT_TRUE(ok);
      ExpectInitialMetadata("key1", "val1");
      ExpectInitialMetadata("key2", "val2");
      initial_metadata_done_ = true;
    }
    void OnDone(const Status& s) override {
      EXPECT_TRUE(initial_metadata_done_);
      EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size());
      EXPECT_TRUE(s.ok());
      EXPECT_EQ(request_.message(), response_.message());
      std::lock_guard<std::mutex> signal_lock(mu_);
      done_ = true;
      cv_.notify_one();
    }
    // Blocks until OnDone() has run.
    void Await() {
      std::unique_lock<std::mutex> wait_lock(mu_);
      cv_.wait(wait_lock, [this] { return done_; });
    }

   private:
    // Expects `key` exactly once in the server's initial metadata, with
    // value `val`.
    void ExpectInitialMetadata(const char* key, const char* val) {
      const auto& initial = cli_ctx_.GetServerInitialMetadata();
      EXPECT_EQ(1u, initial.count(key));
      EXPECT_EQ(val, ToString(initial.find(key)->second));
    }

    EchoRequest request_;
    EchoResponse response_;
    ClientContext cli_ctx_;
    std::mutex mu_;
    std::condition_variable cv_;
    bool done_{false};
    bool initial_metadata_done_{false};
  };

  UnaryClient client{stub_.get()};
  client.Await();
  // Server interceptors must not have been notified of a cancel.
  if (GetParam().use_interceptors) {
    EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel());
  }
}
848 
TEST_P(ClientCallbackEnd2endTest,GenericUnaryReactor)849 TEST_P(ClientCallbackEnd2endTest, GenericUnaryReactor) {
850   const std::string kMethodName("/grpc.testing.EchoTestService/Echo");
851   constexpr char kSuffixForStats[] = "TestSuffixForStats";
852   ResetStub(
853       std::make_unique<TestInterceptorFactory>(kMethodName, kSuffixForStats));
854   class UnaryClient : public grpc::ClientUnaryReactor {
855    public:
856     UnaryClient(grpc::GenericStub* stub, const std::string& method_name,
857                 const char* suffix_for_stats) {
858       cli_ctx_.AddMetadata("key1", "val1");
859       cli_ctx_.AddMetadata("key2", "val2");
860       request_.mutable_param()->set_echo_metadata_initially(true);
861       request_.set_message("Hello metadata");
862       send_buf_ = SerializeToByteBuffer(&request_);
863 
864       StubOptions options(suffix_for_stats);
865       stub->PrepareUnaryCall(&cli_ctx_, method_name, options, send_buf_.get(),
866                              &recv_buf_, this);
867       StartCall();
868     }
869     void OnReadInitialMetadataDone(bool ok) override {
870       EXPECT_TRUE(ok);
871       EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key1"));
872       EXPECT_EQ(
873           "val1",
874           ToString(cli_ctx_.GetServerInitialMetadata().find("key1")->second));
875       EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key2"));
876       EXPECT_EQ(
877           "val2",
878           ToString(cli_ctx_.GetServerInitialMetadata().find("key2")->second));
879       initial_metadata_done_ = true;
880     }
881     void OnDone(const Status& s) override {
882       EXPECT_TRUE(initial_metadata_done_);
883       EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size());
884       EXPECT_TRUE(s.ok());
885       EchoResponse response;
886       EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response));
887       EXPECT_EQ(request_.message(), response.message());
888       std::unique_lock<std::mutex> l(mu_);
889       done_ = true;
890       cv_.notify_one();
891     }
892     void Await() {
893       std::unique_lock<std::mutex> l(mu_);
894       while (!done_) {
895         cv_.wait(l);
896       }
897     }
898 
899    private:
900     EchoRequest request_;
901     std::unique_ptr<ByteBuffer> send_buf_;
902     ByteBuffer recv_buf_;
903     ClientContext cli_ctx_;
904     std::mutex mu_;
905     std::condition_variable cv_;
906     bool done_{false};
907     bool initial_metadata_done_{false};
908   };
909 
910   UnaryClient test{generic_stub_.get(), kMethodName, kSuffixForStats};
911   test.Await();
912   // Make sure that the server interceptors were not notified of a cancel
913   if (GetParam().use_interceptors) {
914     EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel());
915   }
916 }
917 
918 class ReadClient : public grpc::ClientReadReactor<EchoResponse> {
919  public:
ReadClient(grpc::testing::EchoTestService::Stub * stub,ServerTryCancelRequestPhase server_try_cancel,ClientCancelInfo client_cancel={})920   ReadClient(grpc::testing::EchoTestService::Stub* stub,
921              ServerTryCancelRequestPhase server_try_cancel,
922              ClientCancelInfo client_cancel = {})
923       : server_try_cancel_(server_try_cancel), client_cancel_{client_cancel} {
924     if (server_try_cancel_ != DO_NOT_CANCEL) {
925       // Send server_try_cancel value in the client metadata
926       context_.AddMetadata(kServerTryCancelRequest,
927                            std::to_string(server_try_cancel));
928     }
929     request_.set_message("Hello client ");
930     stub->async()->ResponseStream(&context_, &request_, this);
931     if (client_cancel_.cancel &&
932         reads_complete_ == client_cancel_.ops_before_cancel) {
933       context_.TryCancel();
934     }
935     // Even if we cancel, read until failure because there might be responses
936     // pending
937     StartRead(&response_);
938     StartCall();
939   }
OnReadDone(bool ok)940   void OnReadDone(bool ok) override {
941     if (!ok) {
942       if (server_try_cancel_ == DO_NOT_CANCEL && !client_cancel_.cancel) {
943         EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend);
944       }
945     } else {
946       EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend);
947       EXPECT_EQ(response_.message(),
948                 request_.message() + std::to_string(reads_complete_));
949       reads_complete_++;
950       if (client_cancel_.cancel &&
951           reads_complete_ == client_cancel_.ops_before_cancel) {
952         context_.TryCancel();
953       }
954       // Even if we cancel, read until failure because there might be responses
955       // pending
956       StartRead(&response_);
957     }
958   }
OnDone(const Status & s)959   void OnDone(const Status& s) override {
960     gpr_log(GPR_INFO, "Read %d messages", reads_complete_);
961     switch (server_try_cancel_) {
962       case DO_NOT_CANCEL:
963         if (!client_cancel_.cancel || client_cancel_.ops_before_cancel >
964                                           kServerDefaultResponseStreamsToSend) {
965           EXPECT_TRUE(s.ok());
966           EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend);
967         } else {
968           EXPECT_GE(reads_complete_, client_cancel_.ops_before_cancel);
969           EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend);
970           // Status might be ok or cancelled depending on whether server
971           // sent status before client cancel went through
972           if (!s.ok()) {
973             EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
974           }
975         }
976         break;
977       case CANCEL_BEFORE_PROCESSING:
978         EXPECT_FALSE(s.ok());
979         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
980         EXPECT_EQ(reads_complete_, 0);
981         break;
982       case CANCEL_DURING_PROCESSING:
983       case CANCEL_AFTER_PROCESSING:
984         // If server canceled while writing messages, client must have read
985         // less than or equal to the expected number of messages. Even if the
986         // server canceled after writing all messages, the RPC may be canceled
987         // before the Client got a chance to read all the messages.
988         EXPECT_FALSE(s.ok());
989         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
990         EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend);
991         break;
992       default:
993         assert(false);
994     }
995     std::unique_lock<std::mutex> l(mu_);
996     done_ = true;
997     cv_.notify_one();
998   }
Await()999   void Await() {
1000     std::unique_lock<std::mutex> l(mu_);
1001     while (!done_) {
1002       cv_.wait(l);
1003     }
1004   }
1005 
1006  private:
1007   EchoRequest request_;
1008   EchoResponse response_;
1009   ClientContext context_;
1010   const ServerTryCancelRequestPhase server_try_cancel_;
1011   int reads_complete_{0};
1012   const ClientCancelInfo client_cancel_;
1013   std::mutex mu_;
1014   std::condition_variable cv_;
1015   bool done_ = false;
1016 };
1017 
// Plain server-streaming RPC with no cancellation from either side.
TEST_P(ClientCallbackEnd2endTest, ResponseStream) {
  ResetStub();
  ReadClient reader{stub_.get(), DO_NOT_CANCEL};
  reader.Await();
  if (!GetParam().use_interceptors) {
    return;
  }
  // Nothing was cancelled, so no server interceptor may report a cancel.
  EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel());
}
1026 }
1027 
// The client cancels once it has completed two reads.
TEST_P(ClientCallbackEnd2endTest, ClientCancelsResponseStream) {
  ResetStub();
  const ClientCancelInfo cancel_after_two_reads{2};
  ReadClient reader{stub_.get(), DO_NOT_CANCEL, cancel_after_two_reads};
  reader.Await();
  // No interceptor check here: the client cancel races with the server
  // finishing, so the server interceptors may or may not observe it.
}
1035 
1036 // Server to cancel before sending any response messages
TEST_P(ClientCallbackEnd2endTest,ResponseStreamServerCancelBefore)1037 TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelBefore) {
1038   ResetStub();
1039   ReadClient test{stub_.get(), CANCEL_BEFORE_PROCESSING};
1040   test.Await();
1041   // Make sure that the server interceptors were notified
1042   if (GetParam().use_interceptors) {
1043     EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
1044   }
1045 }
1046 
1047 // Server to cancel while writing a response to the stream in parallel
TEST_P(ClientCallbackEnd2endTest,ResponseStreamServerCancelDuring)1048 TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelDuring) {
1049   ResetStub();
1050   ReadClient test{stub_.get(), CANCEL_DURING_PROCESSING};
1051   test.Await();
1052   // Make sure that the server interceptors were notified
1053   if (GetParam().use_interceptors) {
1054     EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
1055   }
1056 }
1057 
1058 // Server to cancel after writing all the respones to the stream but before
1059 // returning to the client
TEST_P(ClientCallbackEnd2endTest,ResponseStreamServerCancelAfter)1060 TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelAfter) {
1061   ResetStub();
1062   ReadClient test{stub_.get(), CANCEL_AFTER_PROCESSING};
1063   test.Await();
1064   // Make sure that the server interceptors were notified
1065   if (GetParam().use_interceptors) {
1066     EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
1067   }
1068 }
1069 
1070 class BidiClient : public grpc::ClientBidiReactor<EchoRequest, EchoResponse> {
1071  public:
BidiClient(grpc::testing::EchoTestService::Stub * stub,ServerTryCancelRequestPhase server_try_cancel,int num_msgs_to_send,bool cork_metadata,bool first_write_async,ClientCancelInfo client_cancel={})1072   BidiClient(grpc::testing::EchoTestService::Stub* stub,
1073              ServerTryCancelRequestPhase server_try_cancel,
1074              int num_msgs_to_send, bool cork_metadata, bool first_write_async,
1075              ClientCancelInfo client_cancel = {})
1076       : server_try_cancel_(server_try_cancel),
1077         msgs_to_send_{num_msgs_to_send},
1078         client_cancel_{client_cancel} {
1079     if (server_try_cancel_ != DO_NOT_CANCEL) {
1080       // Send server_try_cancel value in the client metadata
1081       context_.AddMetadata(kServerTryCancelRequest,
1082                            std::to_string(server_try_cancel));
1083     }
1084     request_.set_message("Hello fren ");
1085     context_.set_initial_metadata_corked(cork_metadata);
1086     stub->async()->BidiStream(&context_, this);
1087     MaybeAsyncWrite(first_write_async);
1088     StartRead(&response_);
1089     StartCall();
1090   }
OnReadDone(bool ok)1091   void OnReadDone(bool ok) override {
1092     if (!ok) {
1093       if (server_try_cancel_ == DO_NOT_CANCEL) {
1094         if (!client_cancel_.cancel) {
1095           EXPECT_EQ(reads_complete_, msgs_to_send_);
1096         } else {
1097           EXPECT_LE(reads_complete_, writes_complete_);
1098         }
1099       }
1100     } else {
1101       EXPECT_LE(reads_complete_, msgs_to_send_);
1102       EXPECT_EQ(response_.message(), request_.message());
1103       reads_complete_++;
1104       StartRead(&response_);
1105     }
1106   }
OnWriteDone(bool ok)1107   void OnWriteDone(bool ok) override {
1108     if (async_write_thread_.joinable()) {
1109       async_write_thread_.join();
1110       RemoveHold();
1111     }
1112     if (server_try_cancel_ == DO_NOT_CANCEL) {
1113       EXPECT_TRUE(ok);
1114     } else if (!ok) {
1115       return;
1116     }
1117     writes_complete_++;
1118     MaybeWrite();
1119   }
OnDone(const Status & s)1120   void OnDone(const Status& s) override {
1121     gpr_log(GPR_INFO, "Sent %d messages", writes_complete_);
1122     gpr_log(GPR_INFO, "Read %d messages", reads_complete_);
1123     switch (server_try_cancel_) {
1124       case DO_NOT_CANCEL:
1125         if (!client_cancel_.cancel ||
1126             client_cancel_.ops_before_cancel > msgs_to_send_) {
1127           EXPECT_TRUE(s.ok());
1128           EXPECT_EQ(writes_complete_, msgs_to_send_);
1129           EXPECT_EQ(reads_complete_, writes_complete_);
1130         } else {
1131           EXPECT_FALSE(s.ok());
1132           EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
1133           EXPECT_EQ(writes_complete_, client_cancel_.ops_before_cancel);
1134           EXPECT_LE(reads_complete_, writes_complete_);
1135         }
1136         break;
1137       case CANCEL_BEFORE_PROCESSING:
1138         EXPECT_FALSE(s.ok());
1139         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
1140         // The RPC is canceled before the server did any work or returned any
1141         // reads, but it's possible that some writes took place first from the
1142         // client
1143         EXPECT_LE(writes_complete_, msgs_to_send_);
1144         EXPECT_EQ(reads_complete_, 0);
1145         break;
1146       case CANCEL_DURING_PROCESSING:
1147         EXPECT_FALSE(s.ok());
1148         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
1149         EXPECT_LE(writes_complete_, msgs_to_send_);
1150         EXPECT_LE(reads_complete_, writes_complete_);
1151         break;
1152       case CANCEL_AFTER_PROCESSING:
1153         EXPECT_FALSE(s.ok());
1154         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
1155         EXPECT_EQ(writes_complete_, msgs_to_send_);
1156         // The Server canceled after reading the last message and after writing
1157         // the message to the client. However, the RPC cancellation might have
1158         // taken effect before the client actually read the response.
1159         EXPECT_LE(reads_complete_, writes_complete_);
1160         break;
1161       default:
1162         assert(false);
1163     }
1164     std::unique_lock<std::mutex> l(mu_);
1165     done_ = true;
1166     cv_.notify_one();
1167   }
Await()1168   void Await() {
1169     std::unique_lock<std::mutex> l(mu_);
1170     while (!done_) {
1171       cv_.wait(l);
1172     }
1173   }
1174 
1175  private:
MaybeAsyncWrite(bool first_write_async)1176   void MaybeAsyncWrite(bool first_write_async) {
1177     if (first_write_async) {
1178       // Make sure that we have a write to issue.
1179       // TODO(vjpai): Make this work with 0 writes case as well.
1180       assert(msgs_to_send_ >= 1);
1181 
1182       AddHold();
1183       async_write_thread_ = std::thread([this] {
1184         std::unique_lock<std::mutex> lock(async_write_thread_mu_);
1185         async_write_thread_cv_.wait(
1186             lock, [this] { return async_write_thread_start_; });
1187         MaybeWrite();
1188       });
1189       std::lock_guard<std::mutex> lock(async_write_thread_mu_);
1190       async_write_thread_start_ = true;
1191       async_write_thread_cv_.notify_one();
1192       return;
1193     }
1194     MaybeWrite();
1195   }
MaybeWrite()1196   void MaybeWrite() {
1197     if (client_cancel_.cancel &&
1198         writes_complete_ == client_cancel_.ops_before_cancel) {
1199       context_.TryCancel();
1200     } else if (writes_complete_ == msgs_to_send_) {
1201       StartWritesDone();
1202     } else {
1203       StartWrite(&request_);
1204     }
1205   }
1206   EchoRequest request_;
1207   EchoResponse response_;
1208   ClientContext context_;
1209   const ServerTryCancelRequestPhase server_try_cancel_;
1210   int reads_complete_{0};
1211   int writes_complete_{0};
1212   const int msgs_to_send_;
1213   const ClientCancelInfo client_cancel_;
1214   std::mutex mu_;
1215   std::condition_variable cv_;
1216   bool done_ = false;
1217   std::thread async_write_thread_;
1218   bool async_write_thread_start_ = false;
1219   std::mutex async_write_thread_mu_;
1220   std::condition_variable async_write_thread_cv_;
1221 };
1222 
// Plain bidi stream: synchronous first write, uncorked metadata, no cancel.
TEST_P(ClientCallbackEnd2endTest, BidiStream) {
  ResetStub();
  BidiClient bidi(stub_.get(), DO_NOT_CANCEL,
                  /*num_msgs_to_send=*/kServerDefaultResponseStreamsToSend,
                  /*cork_metadata=*/false, /*first_write_async=*/false);
  bidi.Await();
  if (!GetParam().use_interceptors) {
    return;
  }
  // Nothing was cancelled, so no server interceptor may report a cancel.
  EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel());
}
1234 
// Bidi stream whose first write is issued from a separate thread.
TEST_P(ClientCallbackEnd2endTest, BidiStreamFirstWriteAsync) {
  ResetStub();
  BidiClient bidi(stub_.get(), DO_NOT_CANCEL,
                  /*num_msgs_to_send=*/kServerDefaultResponseStreamsToSend,
                  /*cork_metadata=*/false, /*first_write_async=*/true);
  bidi.Await();
  const bool interceptors_enabled = GetParam().use_interceptors;
  if (interceptors_enabled) {
    // Nothing was cancelled, so no server interceptor may report a cancel.
    EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel());
  }
}
1246 
// Bidi stream with corked initial metadata and a synchronous first write.
TEST_P(ClientCallbackEnd2endTest, BidiStreamCorked) {
  ResetStub();
  BidiClient bidi(stub_.get(), DO_NOT_CANCEL,
                  /*num_msgs_to_send=*/kServerDefaultResponseStreamsToSend,
                  /*cork_metadata=*/true, /*first_write_async=*/false);
  bidi.Await();
  if (!GetParam().use_interceptors) {
    return;
  }
  // Nothing was cancelled, so no server interceptor may report a cancel.
  EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel());
}
1258 
// Bidi stream with corked initial metadata and a first write issued from a
// separate thread.
TEST_P(ClientCallbackEnd2endTest, BidiStreamCorkedFirstWriteAsync) {
  ResetStub();
  BidiClient bidi(stub_.get(), DO_NOT_CANCEL,
                  /*num_msgs_to_send=*/kServerDefaultResponseStreamsToSend,
                  /*cork_metadata=*/true, /*first_write_async=*/true);
  bidi.Await();
  const bool interceptors_enabled = GetParam().use_interceptors;
  if (interceptors_enabled) {
    // Nothing was cancelled, so no server interceptor may report a cancel.
    EXPECT_EQ(0, PhonyInterceptor::GetNumTimesCancel());
  }
}
1270 
// The client cancels the bidi stream once two writes have completed.
TEST_P(ClientCallbackEnd2endTest, ClientCancelsBidiStream) {
  ResetStub();
  const ClientCancelInfo cancel_after_two_writes{2};
  BidiClient bidi(stub_.get(), DO_NOT_CANCEL,
                  /*num_msgs_to_send=*/kServerDefaultResponseStreamsToSend,
                  /*cork_metadata=*/false, /*first_write_async=*/false,
                  cancel_after_two_writes);
  bidi.Await();
  if (!GetParam().use_interceptors) {
    return;
  }
  // The client cancel must have reached the server interceptors.
  EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
}
1283 
1284 // Server to cancel before reading/writing any requests/responses on the stream
TEST_P(ClientCallbackEnd2endTest,BidiStreamServerCancelBefore)1285 TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelBefore) {
1286   ResetStub();
1287   BidiClient test(stub_.get(), CANCEL_BEFORE_PROCESSING, /*num_msgs_to_send=*/2,
1288                   /*cork_metadata=*/false, /*first_write_async=*/false);
1289   test.Await();
1290   // Make sure that the server interceptors were notified
1291   if (GetParam().use_interceptors) {
1292     EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
1293   }
1294 }
1295 
1296 // Server to cancel while reading/writing requests/responses on the stream in
1297 // parallel
TEST_P(ClientCallbackEnd2endTest,BidiStreamServerCancelDuring)1298 TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelDuring) {
1299   ResetStub();
1300   BidiClient test(stub_.get(), CANCEL_DURING_PROCESSING,
1301                   /*num_msgs_to_send=*/10, /*cork_metadata=*/false,
1302                   /*first_write_async=*/false);
1303   test.Await();
1304   // Make sure that the server interceptors were notified
1305   if (GetParam().use_interceptors) {
1306     EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
1307   }
1308 }
1309 
1310 // Server to cancel after reading/writing all requests/responses on the stream
1311 // but before returning to the client
TEST_P(ClientCallbackEnd2endTest,BidiStreamServerCancelAfter)1312 TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelAfter) {
1313   ResetStub();
1314   BidiClient test(stub_.get(), CANCEL_AFTER_PROCESSING, /*num_msgs_to_send=*/5,
1315                   /*cork_metadata=*/false, /*first_write_async=*/false);
1316   test.Await();
1317   // Make sure that the server interceptors were notified
1318   if (GetParam().use_interceptors) {
1319     EXPECT_EQ(20, PhonyInterceptor::GetNumTimesCancel());
1320   }
1321 }
1322 
// After its single write completes, the client starts WritesDone and a Read
// back-to-back, so both operations are outstanding at the same time.
TEST_P(ClientCallbackEnd2endTest, SimultaneousReadAndWritesDone) {
  ResetStub();
  class OneShotBidiClient
      : public grpc::ClientBidiReactor<EchoRequest, EchoResponse> {
   public:
    explicit OneShotBidiClient(grpc::testing::EchoTestService::Stub* stub) {
      request_.set_message("Hello bidi ");
      stub->async()->BidiStream(&context_, this);
      StartWrite(&request_);
      StartCall();
    }
    void OnWriteDone(bool ok) override {
      EXPECT_TRUE(ok);
      // Issue the half-close and the read together.
      StartWritesDone();
      StartRead(&response_);
    }
    void OnReadDone(bool ok) override {
      EXPECT_TRUE(ok);
      EXPECT_EQ(response_.message(), request_.message());
    }
    void OnDone(const Status& s) override {
      EXPECT_TRUE(s.ok());
      EXPECT_EQ(response_.message(), request_.message());
      std::lock_guard<std::mutex> lock(mu_);
      done_ = true;
      cv_.notify_one();
    }
    // Returns once OnDone has fired.
    void Await() {
      std::unique_lock<std::mutex> lock(mu_);
      cv_.wait(lock, [this] { return done_; });
    }

   private:
    EchoRequest request_;
    EchoResponse response_;
    ClientContext context_;
    std::mutex mu_;
    std::condition_variable cv_;
    bool done_ = false;
  } client{stub_.get()};

  client.Await();
}
1368 
// Calling a method that the server does not implement must fail with
// UNIMPLEMENTED and an empty error message.
TEST_P(ClientCallbackEnd2endTest, UnimplementedRpc) {
  ChannelArguments args;
  const auto& channel_creds = GetCredentialsProvider()->GetChannelCredentials(
      GetParam().credentials_type, &args);
  std::shared_ptr<Channel> channel;
  if (GetParam().protocol == Protocol::TCP) {
    channel =
        grpc::CreateCustomChannel(server_address_.str(), channel_creds, args);
  } else {
    channel = server_->InProcessChannel(args);
  }
  auto stub = grpc::testing::UnimplementedEchoService::NewStub(channel);
  EchoRequest request;
  EchoResponse response;
  ClientContext cli_ctx;
  request.set_message("Hello world.");

  std::mutex mu;
  std::condition_variable cv;
  bool done = false;
  auto on_done = [&done, &mu, &cv](Status s) {
    EXPECT_EQ(StatusCode::UNIMPLEMENTED, s.error_code());
    EXPECT_EQ("", s.error_message());

    std::lock_guard<std::mutex> lock(mu);
    done = true;
    cv.notify_one();
  };
  stub->async()->Unimplemented(&cli_ctx, &request, &response, on_done);

  std::unique_lock<std::mutex> lock(mu);
  cv.wait(lock, [&done] { return done; });
}
1401 
// An unimplemented bidi method is answered with a trailers-only response: the
// initial-metadata read fails and the final status is UNIMPLEMENTED with an
// empty message.
TEST_P(ClientCallbackEnd2endTest, TestTrailersOnlyOnError) {
  // Trailers-only is an HTTP/2 concept, so inproc (or any non-TCP transport)
  // is skipped.
  const bool is_tcp = GetParam().protocol == Protocol::TCP;
  if (!is_tcp) {
    return;
  }

  ResetStub();
  class UnimplementedBidiReactor
      : public grpc::ClientBidiReactor<EchoRequest, EchoResponse> {
   public:
    explicit UnimplementedBidiReactor(
        grpc::testing::EchoTestService::Stub* stub) {
      stub->async()->UnimplementedBidi(&context_, this);
      StartCall();
    }
    // Returns once OnDone has fired.
    void Await() {
      std::unique_lock<std::mutex> lock(mu_);
      cv_.wait(lock, [this] { return done_; });
    }

   private:
    void OnReadInitialMetadataDone(bool ok) override { EXPECT_FALSE(ok); }
    void OnDone(const Status& s) override {
      EXPECT_EQ(s.error_code(), grpc::StatusCode::UNIMPLEMENTED);
      EXPECT_EQ(s.error_message(), "");
      std::lock_guard<std::mutex> lock(mu_);
      done_ = true;
      cv_.notify_one();
    }

    ClientContext context_;
    std::mutex mu_;
    std::condition_variable cv_;
    bool done_ = false;
  } reactor(stub_.get());

  reactor.Await();
}
1441 
TEST_P(ClientCallbackEnd2endTest,ResponseStreamExtraReactionFlowReadsUntilDone)1442 TEST_P(ClientCallbackEnd2endTest,
1443        ResponseStreamExtraReactionFlowReadsUntilDone) {
1444   ResetStub();
1445   class ReadAllIncomingDataClient
1446       : public grpc::ClientReadReactor<EchoResponse> {
1447    public:
1448     explicit ReadAllIncomingDataClient(
1449         grpc::testing::EchoTestService::Stub* stub) {
1450       request_.set_message("Hello client ");
1451       stub->async()->ResponseStream(&context_, &request_, this);
1452     }
1453     bool WaitForReadDone() {
1454       std::unique_lock<std::mutex> l(mu_);
1455       while (!read_done_) {
1456         read_cv_.wait(l);
1457       }
1458       read_done_ = false;
1459       return read_ok_;
1460     }
1461     void Await() {
1462       std::unique_lock<std::mutex> l(mu_);
1463       while (!done_) {
1464         done_cv_.wait(l);
1465       }
1466     }
1467     // RemoveHold under the same lock used for OnDone to make sure that we don't
1468     // call OnDone directly or indirectly from the RemoveHold function.
1469     void RemoveHoldUnderLock() {
1470       std::unique_lock<std::mutex> l(mu_);
1471       RemoveHold();
1472     }
1473     const Status& status() {
1474       std::unique_lock<std::mutex> l(mu_);
1475       return status_;
1476     }
1477 
1478    private:
1479     void OnReadDone(bool ok) override {
1480       std::unique_lock<std::mutex> l(mu_);
1481       read_ok_ = ok;
1482       read_done_ = true;
1483       read_cv_.notify_one();
1484     }
1485     void OnDone(const Status& s) override {
1486       std::unique_lock<std::mutex> l(mu_);
1487       done_ = true;
1488       status_ = s;
1489       done_cv_.notify_one();
1490     }
1491 
1492     EchoRequest request_;
1493     EchoResponse response_;
1494     ClientContext context_;
1495     bool read_ok_ = false;
1496     bool read_done_ = false;
1497     std::mutex mu_;
1498     std::condition_variable read_cv_;
1499     std::condition_variable done_cv_;
1500     bool done_ = false;
1501     Status status_;
1502   } client{stub_.get()};
1503 
1504   int reads_complete = 0;
1505   client.AddHold();
1506   client.StartCall();
1507 
1508   EchoResponse response;
1509   bool read_ok = true;
1510   while (read_ok) {
1511     client.StartRead(&response);
1512     read_ok = client.WaitForReadDone();
1513     if (read_ok) {
1514       ++reads_complete;
1515     }
1516   }
1517   client.RemoveHoldUnderLock();
1518   client.Await();
1519 
1520   EXPECT_EQ(kServerDefaultResponseStreamsToSend, reads_complete);
1521   EXPECT_EQ(client.status().error_code(), grpc::StatusCode::OK);
1522 }
1523 
CreateTestScenarios(bool test_insecure)1524 std::vector<TestScenario> CreateTestScenarios(bool test_insecure) {
1525 #if TARGET_OS_IPHONE
1526   // Workaround Apple CFStream bug
1527   grpc_core::SetEnv("grpc_cfstream", "0");
1528 #endif
1529 
1530   std::vector<TestScenario> scenarios;
1531   std::vector<std::string> credentials_types{
1532       GetCredentialsProvider()->GetSecureCredentialsTypeList()};
1533   auto insec_ok = [] {
1534     // Only allow insecure credentials type when it is registered with the
1535     // provider. User may create providers that do not have insecure.
1536     return GetCredentialsProvider()->GetChannelCredentials(
1537                kInsecureCredentialsType, nullptr) != nullptr;
1538   };
1539   if (test_insecure && insec_ok()) {
1540     credentials_types.push_back(kInsecureCredentialsType);
1541   }
1542   GPR_ASSERT(!credentials_types.empty());
1543 
1544   bool barr[]{false, true};
1545   Protocol parr[]{Protocol::INPROC, Protocol::TCP};
1546   for (Protocol p : parr) {
1547     for (const auto& cred : credentials_types) {
1548       // TODO(vjpai): Test inproc with secure credentials when feasible
1549       if (p == Protocol::INPROC &&
1550           (cred != kInsecureCredentialsType || !insec_ok())) {
1551         continue;
1552       }
1553       for (bool callback_server : barr) {
1554         for (bool use_interceptors : barr) {
1555           scenarios.emplace_back(callback_server, p, use_interceptors, cred);
1556         }
1557       }
1558     }
1559   }
1560   return scenarios;
1561 }
1562 
1563 INSTANTIATE_TEST_SUITE_P(ClientCallbackEnd2endTest, ClientCallbackEnd2endTest,
1564                          ::testing::ValuesIn(CreateTestScenarios(true)));
1565 
1566 }  // namespace
1567 }  // namespace testing
1568 }  // namespace grpc
1569 
main(int argc,char ** argv)1570 int main(int argc, char** argv) {
1571   ::testing::InitGoogleTest(&argc, argv);
1572   grpc::testing::TestEnvironment env(&argc, argv);
1573   grpc_init();
1574   int ret = RUN_ALL_TESTS();
1575   grpc_shutdown();
1576   return ret;
1577 }
1578