xref: /aosp_15_r20/external/grpc-grpc/test/cpp/qps/client_async.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <forward_list>
20 #include <functional>
21 #include <list>
22 #include <memory>
23 #include <mutex>
24 #include <sstream>
25 #include <string>
26 #include <thread>
27 #include <utility>
28 #include <vector>
29 
30 #include "absl/memory/memory.h"
31 
32 #include <grpc/grpc.h>
33 #include <grpc/support/cpu.h>
34 #include <grpc/support/log.h>
35 #include <grpcpp/alarm.h>
36 #include <grpcpp/channel.h>
37 #include <grpcpp/client_context.h>
38 #include <grpcpp/generic/generic_stub.h>
39 
40 #include "src/core/lib/gprpp/crash.h"
41 #include "src/core/lib/surface/completion_queue.h"
42 #include "src/proto/grpc/testing/benchmark_service.grpc.pb.h"
43 #include "test/cpp/qps/client.h"
44 #include "test/cpp/qps/usage_timer.h"
45 #include "test/cpp/util/create_test_channel.h"
46 
47 namespace grpc {
48 namespace testing {
49 
50 class ClientRpcContext {
51  public:
ClientRpcContext()52   ClientRpcContext() {}
~ClientRpcContext()53   virtual ~ClientRpcContext() {}
54   // next state, return false if done. Collect stats when appropriate
55   virtual bool RunNextState(bool, HistogramEntry* entry) = 0;
56   virtual void StartNewClone(CompletionQueue* cq) = 0;
tag(ClientRpcContext * c)57   static void* tag(ClientRpcContext* c) { return static_cast<void*>(c); }
detag(void * t)58   static ClientRpcContext* detag(void* t) {
59     return static_cast<ClientRpcContext*>(t);
60   }
61 
62   virtual void Start(CompletionQueue* cq, const ClientConfig& config) = 0;
63   virtual void TryCancel() = 0;
64 };
65 
66 template <class RequestType, class ResponseType>
67 class ClientRpcContextUnaryImpl : public ClientRpcContext {
68  public:
ClientRpcContextUnaryImpl(BenchmarkService::Stub * stub,const RequestType & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>> (BenchmarkService::Stub *,grpc::ClientContext *,const RequestType &,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ResponseType *,HistogramEntry *)> on_done)69   ClientRpcContextUnaryImpl(
70       BenchmarkService::Stub* stub, const RequestType& req,
71       std::function<gpr_timespec()> next_issue,
72       std::function<
73           std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>(
74               BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
75               CompletionQueue*)>
76           prepare_req,
77       std::function<void(grpc::Status, ResponseType*, HistogramEntry*)> on_done)
78       : context_(),
79         stub_(stub),
80         cq_(nullptr),
81         req_(req),
82         response_(),
83         next_state_(State::READY),
84         callback_(on_done),
85         next_issue_(std::move(next_issue)),
86         prepare_req_(prepare_req) {}
~ClientRpcContextUnaryImpl()87   ~ClientRpcContextUnaryImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)88   void Start(CompletionQueue* cq, const ClientConfig& config) override {
89     GPR_ASSERT(!config.use_coalesce_api());  // not supported.
90     StartInternal(cq);
91   }
RunNextState(bool,HistogramEntry * entry)92   bool RunNextState(bool /*ok*/, HistogramEntry* entry) override {
93     switch (next_state_) {
94       case State::READY:
95         start_ = UsageTimer::Now();
96         response_reader_ = prepare_req_(stub_, &context_, req_, cq_);
97         response_reader_->StartCall();
98         next_state_ = State::RESP_DONE;
99         response_reader_->Finish(&response_, &status_,
100                                  ClientRpcContext::tag(this));
101         return true;
102       case State::RESP_DONE:
103         if (status_.ok()) {
104           entry->set_value((UsageTimer::Now() - start_) * 1e9);
105         }
106         callback_(status_, &response_, entry);
107         next_state_ = State::INVALID;
108         return false;
109       default:
110         grpc_core::Crash("unreachable");
111         return false;
112     }
113   }
StartNewClone(CompletionQueue * cq)114   void StartNewClone(CompletionQueue* cq) override {
115     auto* clone = new ClientRpcContextUnaryImpl(stub_, req_, next_issue_,
116                                                 prepare_req_, callback_);
117     clone->StartInternal(cq);
118   }
TryCancel()119   void TryCancel() override { context_.TryCancel(); }
120 
121  private:
122   grpc::ClientContext context_;
123   BenchmarkService::Stub* stub_;
124   CompletionQueue* cq_;
125   std::unique_ptr<Alarm> alarm_;
126   const RequestType& req_;
127   ResponseType response_;
128   enum State { INVALID, READY, RESP_DONE };
129   State next_state_;
130   std::function<void(grpc::Status, ResponseType*, HistogramEntry*)> callback_;
131   std::function<gpr_timespec()> next_issue_;
132   std::function<std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>(
133       BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
134       CompletionQueue*)>
135       prepare_req_;
136   grpc::Status status_;
137   double start_;
138   std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>
139       response_reader_;
140 
StartInternal(CompletionQueue * cq)141   void StartInternal(CompletionQueue* cq) {
142     cq_ = cq;
143     if (!next_issue_) {  // ready to issue
144       RunNextState(true, nullptr);
145     } else {  // wait for the issue time
146       alarm_ = std::make_unique<Alarm>();
147       alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
148     }
149   }
150 };
151 
// Shared machinery for all async benchmark clients: owns the completion
// queues, maps worker threads onto them, seeds the initial RPC contexts, and
// coordinates shutdown between worker threads and context teardown.
template <class StubType, class RequestType>
class AsyncClient : public ClientImpl<StubType, RequestType> {
  // Specify which protected members we are using since there is no
  // member name resolution until the template types are fully resolved
 public:
  using Client::closed_loop_;
  using Client::NextIssuer;
  using Client::SetupLoadTest;
  using ClientImpl<StubType, RequestType>::cores_;
  using ClientImpl<StubType, RequestType>::channels_;
  using ClientImpl<StubType, RequestType>::request_;
  // Builds the CQs (threads_per_cq workers share each queue), one issuer and
  // shutdown flag per thread, then creates and starts
  // outstanding_rpcs_per_channel contexts on every channel.
  AsyncClient(const ClientConfig& config,
              std::function<ClientRpcContext*(
                  StubType*, std::function<gpr_timespec()> next_issue,
                  const RequestType&)>
                  setup_ctx,
              std::function<std::unique_ptr<StubType>(std::shared_ptr<Channel>)>
                  create_stub)
      : ClientImpl<StubType, RequestType>(config, create_stub),
        num_async_threads_(NumThreads(config)) {
    SetupLoadTest(config, num_async_threads_);

    int tpc = std::max(1, config.threads_per_cq());      // 1 if unspecified
    int num_cqs = (num_async_threads_ + tpc - 1) / tpc;  // ceiling operator
    for (int i = 0; i < num_cqs; i++) {
      cli_cqs_.emplace_back(new CompletionQueue);
    }

    // Round-robin worker threads across completion queues.
    for (int i = 0; i < num_async_threads_; i++) {
      cq_.emplace_back(i % cli_cqs_.size());
      next_issuers_.emplace_back(NextIssuer(i));
      shutdown_state_.emplace_back(new PerThreadShutdownState());
    }

    // Spread each channel's outstanding RPCs over the CQs, advancing the CQ
    // index once per channel.
    int t = 0;
    for (int ch = 0; ch < config.client_channels(); ch++) {
      for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) {
        auto* cq = cli_cqs_[t].get();
        auto ctx =
            setup_ctx(channels_[ch].get_stub(), next_issuers_[t], request_);
        ctx->Start(cq, config);
      }
      t = (t + 1) % cli_cqs_.size();
    }
  }
  ~AsyncClient() override {
    // Drain every CQ and delete any contexts still in flight.
    // NOTE(review): this only terminates if the queues were already shut
    // down (see DestroyMultithreading) — otherwise Next would block.
    for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
      void* got_tag;
      bool ok;
      while ((*cq)->Next(&got_tag, &ok)) {
        delete ClientRpcContext::detag(got_tag);
      }
    }
  }

  // Total poll count summed across all completion queues.
  int GetPollCount() override {
    int count = 0;
    for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
      count += grpc_get_cq_poll_num((*cq)->cq());
    }
    return count;
  }

 protected:
  const int num_async_threads_;

 private:
  // Per-thread shutdown flag; its mutex serializes the worker's state-machine
  // execution against DestroyMultithreading flipping the flag.
  struct PerThreadShutdownState {
    mutable std::mutex mutex;
    bool shutdown;
    PerThreadShutdownState() : shutdown(false) {}
  };

  // Configured thread count, or one thread per core when the config asks for
  // dynamic sizing (async_client_threads <= 0).
  int NumThreads(const ClientConfig& config) {
    int num_threads = config.async_client_threads();
    if (num_threads <= 0) {  // Use dynamic sizing
      num_threads = cores_;
      gpr_log(GPR_INFO, "Sizing async client to %d threads", num_threads);
    }
    return num_threads;
  }
  // Flag every worker for shutdown, shut the CQs down so blocked Next calls
  // return, then join the worker threads.
  void DestroyMultithreading() final {
    for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) {
      std::lock_guard<std::mutex> lock((*ss)->mutex);
      (*ss)->shutdown = true;
    }
    for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
      (*cq)->Shutdown();
    }
    this->EndThreads();  // this needed for resolution
  }

  // Returns the context for |tag|, or nullptr if this thread is shutting
  // down — in that case the tagged context and everything else left on the
  // thread's CQ are cancelled and deleted here.
  // Caller must hold shutdown_state_[thread_idx]->mutex.
  ClientRpcContext* ProcessTag(size_t thread_idx, void* tag) {
    ClientRpcContext* ctx = ClientRpcContext::detag(tag);
    if (shutdown_state_[thread_idx]->shutdown) {
      ctx->TryCancel();
      delete ctx;
      bool ok;
      while (cli_cqs_[cq_[thread_idx]]->Next(&tag, &ok)) {
        ctx = ClientRpcContext::detag(tag);
        ctx->TryCancel();
        delete ctx;
      }
      return nullptr;
    }
    return ctx;
  }

  // Worker loop: pull completions off this thread's CQ and drive each
  // context's state machine, recording one histogram entry per completion.
  void ThreadFunc(size_t thread_idx, Client::Thread* t) final {
    void* got_tag;
    bool ok;

    HistogramEntry entry;
    HistogramEntry* entry_ptr = &entry;
    if (!cli_cqs_[cq_[thread_idx]]->Next(&got_tag, &ok)) {
      return;  // CQ already shut down and drained
    }
    // The shutdown mutex is held whenever a context is being driven, and
    // released inside the DoThenAsyncNext callback — so shutdown can never
    // delete a context mid-transition.
    std::mutex* shutdown_mu = &shutdown_state_[thread_idx]->mutex;
    shutdown_mu->lock();
    ClientRpcContext* ctx = ProcessTag(thread_idx, got_tag);
    if (ctx == nullptr) {
      shutdown_mu->unlock();
      return;
    }
    while (cli_cqs_[cq_[thread_idx]]->DoThenAsyncNext(
        [&, ctx, ok, entry_ptr, shutdown_mu]() {
          if (!ctx->RunNextState(ok, entry_ptr)) {
            // The RPC and callback are done, so clone the ctx
            // and kickstart the new one
            ctx->StartNewClone(cli_cqs_[cq_[thread_idx]].get());
            delete ctx;
          }
          shutdown_mu->unlock();
        },
        &got_tag, &ok, gpr_inf_future(GPR_CLOCK_REALTIME))) {
      t->UpdateHistogram(entry_ptr);
      entry = HistogramEntry();  // reset the sample for the next completion
      shutdown_mu->lock();
      ctx = ProcessTag(thread_idx, got_tag);
      if (ctx == nullptr) {
        shutdown_mu->unlock();
        return;
      }
    }
  }

  std::vector<std::unique_ptr<CompletionQueue>> cli_cqs_;
  std::vector<int> cq_;  // per-thread index into cli_cqs_
  std::vector<std::function<gpr_timespec()>> next_issuers_;
  std::vector<std::unique_ptr<PerThreadShutdownState>> shutdown_state_;
};
303 
BenchmarkStubCreator(const std::shared_ptr<Channel> & ch)304 static std::unique_ptr<BenchmarkService::Stub> BenchmarkStubCreator(
305     const std::shared_ptr<Channel>& ch) {
306   return BenchmarkService::NewStub(ch);
307 }
308 
309 class AsyncUnaryClient final
310     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
311  public:
AsyncUnaryClient(const ClientConfig & config)312   explicit AsyncUnaryClient(const ClientConfig& config)
313       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
314             config, SetupCtx, BenchmarkStubCreator) {
315     StartThreads(num_async_threads_);
316   }
~AsyncUnaryClient()317   ~AsyncUnaryClient() override {}
318 
319  private:
CheckDone(const grpc::Status & s,SimpleResponse *,HistogramEntry * entry)320   static void CheckDone(const grpc::Status& s, SimpleResponse* /*response*/,
321                         HistogramEntry* entry) {
322     entry->set_status(s.error_code());
323   }
324   static std::unique_ptr<grpc::ClientAsyncResponseReader<SimpleResponse>>
PrepareReq(BenchmarkService::Stub * stub,grpc::ClientContext * ctx,const SimpleRequest & request,CompletionQueue * cq)325   PrepareReq(BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
326              const SimpleRequest& request, CompletionQueue* cq) {
327     return stub->PrepareAsyncUnaryCall(ctx, request, cq);
328   };
SetupCtx(BenchmarkService::Stub * stub,std::function<gpr_timespec ()> next_issue,const SimpleRequest & req)329   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
330                                     std::function<gpr_timespec()> next_issue,
331                                     const SimpleRequest& req) {
332     return new ClientRpcContextUnaryImpl<SimpleRequest, SimpleResponse>(
333         stub, req, std::move(next_issue), AsyncUnaryClient::PrepareReq,
334         AsyncUnaryClient::CheckDone);
335   }
336 };
337 
338 template <class RequestType, class ResponseType>
339 class ClientRpcContextStreamingPingPongImpl : public ClientRpcContext {
340  public:
ClientRpcContextStreamingPingPongImpl(BenchmarkService::Stub * stub,const RequestType & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::ClientAsyncReaderWriter<RequestType,ResponseType>> (BenchmarkService::Stub *,grpc::ClientContext *,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ResponseType *)> on_done)341   ClientRpcContextStreamingPingPongImpl(
342       BenchmarkService::Stub* stub, const RequestType& req,
343       std::function<gpr_timespec()> next_issue,
344       std::function<std::unique_ptr<
345           grpc::ClientAsyncReaderWriter<RequestType, ResponseType>>(
346           BenchmarkService::Stub*, grpc::ClientContext*, CompletionQueue*)>
347           prepare_req,
348       std::function<void(grpc::Status, ResponseType*)> on_done)
349       : context_(),
350         stub_(stub),
351         cq_(nullptr),
352         req_(req),
353         response_(),
354         next_state_(State::INVALID),
355         callback_(on_done),
356         next_issue_(std::move(next_issue)),
357         prepare_req_(prepare_req),
358         coalesce_(false) {}
~ClientRpcContextStreamingPingPongImpl()359   ~ClientRpcContextStreamingPingPongImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)360   void Start(CompletionQueue* cq, const ClientConfig& config) override {
361     StartInternal(cq, config.messages_per_stream(), config.use_coalesce_api());
362   }
RunNextState(bool ok,HistogramEntry * entry)363   bool RunNextState(bool ok, HistogramEntry* entry) override {
364     while (true) {
365       switch (next_state_) {
366         case State::STREAM_IDLE:
367           if (!next_issue_) {  // ready to issue
368             next_state_ = State::READY_TO_WRITE;
369           } else {
370             next_state_ = State::WAIT;
371           }
372           break;  // loop around, don't return
373         case State::WAIT:
374           next_state_ = State::READY_TO_WRITE;
375           alarm_ = std::make_unique<Alarm>();
376           alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
377           return true;
378         case State::READY_TO_WRITE:
379           if (!ok) {
380             return false;
381           }
382           start_ = UsageTimer::Now();
383           next_state_ = State::WRITE_DONE;
384           if (coalesce_ && messages_issued_ == messages_per_stream_ - 1) {
385             stream_->WriteLast(req_, WriteOptions(),
386                                ClientRpcContext::tag(this));
387           } else {
388             stream_->Write(req_, ClientRpcContext::tag(this));
389           }
390           return true;
391         case State::WRITE_DONE:
392           if (!ok) {
393             return false;
394           }
395           next_state_ = State::READ_DONE;
396           stream_->Read(&response_, ClientRpcContext::tag(this));
397           return true;
398           break;
399         case State::READ_DONE:
400           entry->set_value((UsageTimer::Now() - start_) * 1e9);
401           callback_(status_, &response_);
402           if ((messages_per_stream_ != 0) &&
403               (++messages_issued_ >= messages_per_stream_)) {
404             next_state_ = State::WRITES_DONE_DONE;
405             if (coalesce_) {
406               // WritesDone should have been called on the last Write.
407               // loop around to call Finish.
408               break;
409             }
410             stream_->WritesDone(ClientRpcContext::tag(this));
411             return true;
412           }
413           next_state_ = State::STREAM_IDLE;
414           break;  // loop around
415         case State::WRITES_DONE_DONE:
416           next_state_ = State::FINISH_DONE;
417           stream_->Finish(&status_, ClientRpcContext::tag(this));
418           return true;
419         case State::FINISH_DONE:
420           next_state_ = State::INVALID;
421           return false;
422           break;
423         default:
424           grpc_core::Crash("unreachable");
425           return false;
426       }
427     }
428   }
StartNewClone(CompletionQueue * cq)429   void StartNewClone(CompletionQueue* cq) override {
430     auto* clone = new ClientRpcContextStreamingPingPongImpl(
431         stub_, req_, next_issue_, prepare_req_, callback_);
432     clone->StartInternal(cq, messages_per_stream_, coalesce_);
433   }
TryCancel()434   void TryCancel() override { context_.TryCancel(); }
435 
436  private:
437   grpc::ClientContext context_;
438   BenchmarkService::Stub* stub_;
439   CompletionQueue* cq_;
440   std::unique_ptr<Alarm> alarm_;
441   const RequestType& req_;
442   ResponseType response_;
443   enum State {
444     INVALID,
445     STREAM_IDLE,
446     WAIT,
447     READY_TO_WRITE,
448     WRITE_DONE,
449     READ_DONE,
450     WRITES_DONE_DONE,
451     FINISH_DONE
452   };
453   State next_state_;
454   std::function<void(grpc::Status, ResponseType*)> callback_;
455   std::function<gpr_timespec()> next_issue_;
456   std::function<
457       std::unique_ptr<grpc::ClientAsyncReaderWriter<RequestType, ResponseType>>(
458           BenchmarkService::Stub*, grpc::ClientContext*, CompletionQueue*)>
459       prepare_req_;
460   grpc::Status status_;
461   double start_;
462   std::unique_ptr<grpc::ClientAsyncReaderWriter<RequestType, ResponseType>>
463       stream_;
464 
465   // Allow a limit on number of messages in a stream
466   int messages_per_stream_;
467   int messages_issued_;
468   // Whether to use coalescing API.
469   bool coalesce_;
470 
StartInternal(CompletionQueue * cq,int messages_per_stream,bool coalesce)471   void StartInternal(CompletionQueue* cq, int messages_per_stream,
472                      bool coalesce) {
473     cq_ = cq;
474     messages_per_stream_ = messages_per_stream;
475     messages_issued_ = 0;
476     coalesce_ = coalesce;
477     if (coalesce_) {
478       GPR_ASSERT(messages_per_stream_ != 0);
479       context_.set_initial_metadata_corked(true);
480     }
481     stream_ = prepare_req_(stub_, &context_, cq);
482     next_state_ = State::STREAM_IDLE;
483     stream_->StartCall(ClientRpcContext::tag(this));
484     if (coalesce_) {
485       // When the initial metadata is corked, the tag will not come back and we
486       // need to manually drive the state machine.
487       RunNextState(true, nullptr);
488     }
489   }
490 };
491 
492 class AsyncStreamingPingPongClient final
493     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
494  public:
AsyncStreamingPingPongClient(const ClientConfig & config)495   explicit AsyncStreamingPingPongClient(const ClientConfig& config)
496       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
497             config, SetupCtx, BenchmarkStubCreator) {
498     StartThreads(num_async_threads_);
499   }
500 
~AsyncStreamingPingPongClient()501   ~AsyncStreamingPingPongClient() override {}
502 
503  private:
CheckDone(const grpc::Status &,SimpleResponse *)504   static void CheckDone(const grpc::Status& /*s*/,
505                         SimpleResponse* /*response*/) {}
506   static std::unique_ptr<
507       grpc::ClientAsyncReaderWriter<SimpleRequest, SimpleResponse>>
PrepareReq(BenchmarkService::Stub * stub,grpc::ClientContext * ctx,CompletionQueue * cq)508   PrepareReq(BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
509              CompletionQueue* cq) {
510     auto stream = stub->PrepareAsyncStreamingCall(ctx, cq);
511     return stream;
512   };
SetupCtx(BenchmarkService::Stub * stub,std::function<gpr_timespec ()> next_issue,const SimpleRequest & req)513   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
514                                     std::function<gpr_timespec()> next_issue,
515                                     const SimpleRequest& req) {
516     return new ClientRpcContextStreamingPingPongImpl<SimpleRequest,
517                                                      SimpleResponse>(
518         stub, req, std::move(next_issue),
519         AsyncStreamingPingPongClient::PrepareReq,
520         AsyncStreamingPingPongClient::CheckDone);
521   }
522 };
523 
524 template <class RequestType, class ResponseType>
525 class ClientRpcContextStreamingFromClientImpl : public ClientRpcContext {
526  public:
ClientRpcContextStreamingFromClientImpl(BenchmarkService::Stub * stub,const RequestType & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>> (BenchmarkService::Stub *,grpc::ClientContext *,ResponseType *,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ResponseType *)> on_done)527   ClientRpcContextStreamingFromClientImpl(
528       BenchmarkService::Stub* stub, const RequestType& req,
529       std::function<gpr_timespec()> next_issue,
530       std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>>(
531           BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*,
532           CompletionQueue*)>
533           prepare_req,
534       std::function<void(grpc::Status, ResponseType*)> on_done)
535       : context_(),
536         stub_(stub),
537         cq_(nullptr),
538         req_(req),
539         response_(),
540         next_state_(State::INVALID),
541         callback_(on_done),
542         next_issue_(std::move(next_issue)),
543         prepare_req_(prepare_req) {}
~ClientRpcContextStreamingFromClientImpl()544   ~ClientRpcContextStreamingFromClientImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)545   void Start(CompletionQueue* cq, const ClientConfig& config) override {
546     GPR_ASSERT(!config.use_coalesce_api());  // not supported yet.
547     StartInternal(cq);
548   }
RunNextState(bool ok,HistogramEntry * entry)549   bool RunNextState(bool ok, HistogramEntry* entry) override {
550     while (true) {
551       switch (next_state_) {
552         case State::STREAM_IDLE:
553           if (!next_issue_) {  // ready to issue
554             next_state_ = State::READY_TO_WRITE;
555           } else {
556             next_state_ = State::WAIT;
557           }
558           break;  // loop around, don't return
559         case State::WAIT:
560           alarm_ = std::make_unique<Alarm>();
561           alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
562           next_state_ = State::READY_TO_WRITE;
563           return true;
564         case State::READY_TO_WRITE:
565           if (!ok) {
566             return false;
567           }
568           start_ = UsageTimer::Now();
569           next_state_ = State::WRITE_DONE;
570           stream_->Write(req_, ClientRpcContext::tag(this));
571           return true;
572         case State::WRITE_DONE:
573           if (!ok) {
574             return false;
575           }
576           entry->set_value((UsageTimer::Now() - start_) * 1e9);
577           next_state_ = State::STREAM_IDLE;
578           break;  // loop around
579         default:
580           grpc_core::Crash("unreachable");
581           return false;
582       }
583     }
584   }
StartNewClone(CompletionQueue * cq)585   void StartNewClone(CompletionQueue* cq) override {
586     auto* clone = new ClientRpcContextStreamingFromClientImpl(
587         stub_, req_, next_issue_, prepare_req_, callback_);
588     clone->StartInternal(cq);
589   }
TryCancel()590   void TryCancel() override { context_.TryCancel(); }
591 
592  private:
593   grpc::ClientContext context_;
594   BenchmarkService::Stub* stub_;
595   CompletionQueue* cq_;
596   std::unique_ptr<Alarm> alarm_;
597   const RequestType& req_;
598   ResponseType response_;
599   enum State {
600     INVALID,
601     STREAM_IDLE,
602     WAIT,
603     READY_TO_WRITE,
604     WRITE_DONE,
605   };
606   State next_state_;
607   std::function<void(grpc::Status, ResponseType*)> callback_;
608   std::function<gpr_timespec()> next_issue_;
609   std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>>(
610       BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*,
611       CompletionQueue*)>
612       prepare_req_;
613   grpc::Status status_;
614   double start_;
615   std::unique_ptr<grpc::ClientAsyncWriter<RequestType>> stream_;
616 
StartInternal(CompletionQueue * cq)617   void StartInternal(CompletionQueue* cq) {
618     cq_ = cq;
619     stream_ = prepare_req_(stub_, &context_, &response_, cq);
620     next_state_ = State::STREAM_IDLE;
621     stream_->StartCall(ClientRpcContext::tag(this));
622   }
623 };
624 
625 class AsyncStreamingFromClientClient final
626     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
627  public:
AsyncStreamingFromClientClient(const ClientConfig & config)628   explicit AsyncStreamingFromClientClient(const ClientConfig& config)
629       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
630             config, SetupCtx, BenchmarkStubCreator) {
631     StartThreads(num_async_threads_);
632   }
633 
~AsyncStreamingFromClientClient()634   ~AsyncStreamingFromClientClient() override {}
635 
636  private:
CheckDone(const grpc::Status &,SimpleResponse *)637   static void CheckDone(const grpc::Status& /*s*/,
638                         SimpleResponse* /*response*/) {}
PrepareReq(BenchmarkService::Stub * stub,grpc::ClientContext * ctx,SimpleResponse * resp,CompletionQueue * cq)639   static std::unique_ptr<grpc::ClientAsyncWriter<SimpleRequest>> PrepareReq(
640       BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
641       SimpleResponse* resp, CompletionQueue* cq) {
642     auto stream = stub->PrepareAsyncStreamingFromClient(ctx, resp, cq);
643     return stream;
644   };
SetupCtx(BenchmarkService::Stub * stub,std::function<gpr_timespec ()> next_issue,const SimpleRequest & req)645   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
646                                     std::function<gpr_timespec()> next_issue,
647                                     const SimpleRequest& req) {
648     return new ClientRpcContextStreamingFromClientImpl<SimpleRequest,
649                                                        SimpleResponse>(
650         stub, req, std::move(next_issue),
651         AsyncStreamingFromClientClient::PrepareReq,
652         AsyncStreamingFromClientClient::CheckDone);
653   }
654 };
655 
656 template <class RequestType, class ResponseType>
657 class ClientRpcContextStreamingFromServerImpl : public ClientRpcContext {
658  public:
ClientRpcContextStreamingFromServerImpl(BenchmarkService::Stub * stub,const RequestType & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>> (BenchmarkService::Stub *,grpc::ClientContext *,const RequestType &,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ResponseType *)> on_done)659   ClientRpcContextStreamingFromServerImpl(
660       BenchmarkService::Stub* stub, const RequestType& req,
661       std::function<gpr_timespec()> next_issue,
662       std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>>(
663           BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
664           CompletionQueue*)>
665           prepare_req,
666       std::function<void(grpc::Status, ResponseType*)> on_done)
667       : context_(),
668         stub_(stub),
669         cq_(nullptr),
670         req_(req),
671         response_(),
672         next_state_(State::INVALID),
673         callback_(on_done),
674         next_issue_(std::move(next_issue)),
675         prepare_req_(prepare_req) {}
~ClientRpcContextStreamingFromServerImpl()676   ~ClientRpcContextStreamingFromServerImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)677   void Start(CompletionQueue* cq, const ClientConfig& config) override {
678     GPR_ASSERT(!config.use_coalesce_api());  // not supported
679     StartInternal(cq);
680   }
RunNextState(bool ok,HistogramEntry * entry)681   bool RunNextState(bool ok, HistogramEntry* entry) override {
682     while (true) {
683       switch (next_state_) {
684         case State::STREAM_IDLE:
685           if (!ok) {
686             return false;
687           }
688           start_ = UsageTimer::Now();
689           next_state_ = State::READ_DONE;
690           stream_->Read(&response_, ClientRpcContext::tag(this));
691           return true;
692         case State::READ_DONE:
693           if (!ok) {
694             return false;
695           }
696           entry->set_value((UsageTimer::Now() - start_) * 1e9);
697           callback_(status_, &response_);
698           next_state_ = State::STREAM_IDLE;
699           break;  // loop around
700         default:
701           grpc_core::Crash("unreachable");
702           return false;
703       }
704     }
705   }
StartNewClone(CompletionQueue * cq)706   void StartNewClone(CompletionQueue* cq) override {
707     auto* clone = new ClientRpcContextStreamingFromServerImpl(
708         stub_, req_, next_issue_, prepare_req_, callback_);
709     clone->StartInternal(cq);
710   }
  // Best-effort cancellation of the in-flight RPC.
  void TryCancel() override { context_.TryCancel(); }

 private:
  grpc::ClientContext context_;
  BenchmarkService::Stub* stub_;  // not owned
  CompletionQueue* cq_;           // not owned; set in StartInternal()
  std::unique_ptr<Alarm> alarm_;
  const RequestType& req_;  // NOTE(review): reference member — caller must
                            // keep the request alive for the context's life
  ResponseType response_;
  // READ_DONE loops back to STREAM_IDLE until the stream fails.
  enum State { INVALID, STREAM_IDLE, READ_DONE };
  State next_state_;
  std::function<void(grpc::Status, ResponseType*)> callback_;
  std::function<gpr_timespec()> next_issue_;
  // Factory that creates (without starting) the server-streaming call.
  std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>>(
      BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
      CompletionQueue*)>
      prepare_req_;
  grpc::Status status_;
  double start_;  // UsageTimer::Now() when the current read was issued
  std::unique_ptr<grpc::ClientAsyncReader<ResponseType>> stream_;
  // Prepares and starts the streaming call on |cq|; the StartCall
  // completion is delivered to RunNextState() with this context as tag.
  void StartInternal(CompletionQueue* cq) {
    // TODO(vjpai): Add support to rate-pace this
    cq_ = cq;
    stream_ = prepare_req_(stub_, &context_, req_, cq);
    next_state_ = State::STREAM_IDLE;
    stream_->StartCall(ClientRpcContext::tag(this));
  }
};
740 
741 class AsyncStreamingFromServerClient final
742     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
743  public:
AsyncStreamingFromServerClient(const ClientConfig & config)744   explicit AsyncStreamingFromServerClient(const ClientConfig& config)
745       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
746             config, SetupCtx, BenchmarkStubCreator) {
747     StartThreads(num_async_threads_);
748   }
749 
~AsyncStreamingFromServerClient()750   ~AsyncStreamingFromServerClient() override {}
751 
752  private:
CheckDone(const grpc::Status &,SimpleResponse *)753   static void CheckDone(const grpc::Status& /*s*/,
754                         SimpleResponse* /*response*/) {}
PrepareReq(BenchmarkService::Stub * stub,grpc::ClientContext * ctx,const SimpleRequest & req,CompletionQueue * cq)755   static std::unique_ptr<grpc::ClientAsyncReader<SimpleResponse>> PrepareReq(
756       BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
757       const SimpleRequest& req, CompletionQueue* cq) {
758     auto stream = stub->PrepareAsyncStreamingFromServer(ctx, req, cq);
759     return stream;
760   };
SetupCtx(BenchmarkService::Stub * stub,std::function<gpr_timespec ()> next_issue,const SimpleRequest & req)761   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
762                                     std::function<gpr_timespec()> next_issue,
763                                     const SimpleRequest& req) {
764     return new ClientRpcContextStreamingFromServerImpl<SimpleRequest,
765                                                        SimpleResponse>(
766         stub, req, std::move(next_issue),
767         AsyncStreamingFromServerClient::PrepareReq,
768         AsyncStreamingFromServerClient::CheckDone);
769   }
770 };
771 
772 class ClientRpcContextGenericStreamingImpl : public ClientRpcContext {
773  public:
ClientRpcContextGenericStreamingImpl(grpc::GenericStub * stub,const ByteBuffer & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::GenericClientAsyncReaderWriter> (grpc::GenericStub *,grpc::ClientContext *,const std::string & method_name,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ByteBuffer *)> on_done)774   ClientRpcContextGenericStreamingImpl(
775       grpc::GenericStub* stub, const ByteBuffer& req,
776       std::function<gpr_timespec()> next_issue,
777       std::function<std::unique_ptr<grpc::GenericClientAsyncReaderWriter>(
778           grpc::GenericStub*, grpc::ClientContext*,
779           const std::string& method_name, CompletionQueue*)>
780           prepare_req,
781       std::function<void(grpc::Status, ByteBuffer*)> on_done)
782       : context_(),
783         stub_(stub),
784         cq_(nullptr),
785         req_(req),
786         response_(),
787         next_state_(State::INVALID),
788         callback_(std::move(on_done)),
789         next_issue_(std::move(next_issue)),
790         prepare_req_(std::move(prepare_req)) {}
~ClientRpcContextGenericStreamingImpl()791   ~ClientRpcContextGenericStreamingImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)792   void Start(CompletionQueue* cq, const ClientConfig& config) override {
793     GPR_ASSERT(!config.use_coalesce_api());  // not supported yet.
794     StartInternal(cq, config.messages_per_stream());
795   }
RunNextState(bool ok,HistogramEntry * entry)796   bool RunNextState(bool ok, HistogramEntry* entry) override {
797     while (true) {
798       switch (next_state_) {
799         case State::STREAM_IDLE:
800           if (!next_issue_) {  // ready to issue
801             next_state_ = State::READY_TO_WRITE;
802           } else {
803             next_state_ = State::WAIT;
804           }
805           break;  // loop around, don't return
806         case State::WAIT:
807           next_state_ = State::READY_TO_WRITE;
808           alarm_ = std::make_unique<Alarm>();
809           alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
810           return true;
811         case State::READY_TO_WRITE:
812           if (!ok) {
813             return false;
814           }
815           start_ = UsageTimer::Now();
816           next_state_ = State::WRITE_DONE;
817           stream_->Write(req_, ClientRpcContext::tag(this));
818           return true;
819         case State::WRITE_DONE:
820           if (!ok) {
821             return false;
822           }
823           next_state_ = State::READ_DONE;
824           stream_->Read(&response_, ClientRpcContext::tag(this));
825           return true;
826         case State::READ_DONE:
827           entry->set_value((UsageTimer::Now() - start_) * 1e9);
828           callback_(status_, &response_);
829           if ((messages_per_stream_ != 0) &&
830               (++messages_issued_ >= messages_per_stream_)) {
831             next_state_ = State::WRITES_DONE_DONE;
832             stream_->WritesDone(ClientRpcContext::tag(this));
833             return true;
834           }
835           next_state_ = State::STREAM_IDLE;
836           break;  // loop around
837         case State::WRITES_DONE_DONE:
838           next_state_ = State::FINISH_DONE;
839           stream_->Finish(&status_, ClientRpcContext::tag(this));
840           return true;
841         case State::FINISH_DONE:
842           next_state_ = State::INVALID;
843           return false;
844         default:
845           grpc_core::Crash("unreachable");
846           return false;
847       }
848     }
849   }
StartNewClone(CompletionQueue * cq)850   void StartNewClone(CompletionQueue* cq) override {
851     auto* clone = new ClientRpcContextGenericStreamingImpl(
852         stub_, req_, next_issue_, prepare_req_, callback_);
853     clone->StartInternal(cq, messages_per_stream_);
854   }
TryCancel()855   void TryCancel() override { context_.TryCancel(); }
856 
857  private:
858   grpc::ClientContext context_;
859   grpc::GenericStub* stub_;
860   CompletionQueue* cq_;
861   std::unique_ptr<Alarm> alarm_;
862   ByteBuffer req_;
863   ByteBuffer response_;
864   enum State {
865     INVALID,
866     STREAM_IDLE,
867     WAIT,
868     READY_TO_WRITE,
869     WRITE_DONE,
870     READ_DONE,
871     WRITES_DONE_DONE,
872     FINISH_DONE
873   };
874   State next_state_;
875   std::function<void(grpc::Status, ByteBuffer*)> callback_;
876   std::function<gpr_timespec()> next_issue_;
877   std::function<std::unique_ptr<grpc::GenericClientAsyncReaderWriter>(
878       grpc::GenericStub*, grpc::ClientContext*, const std::string&,
879       CompletionQueue*)>
880       prepare_req_;
881   grpc::Status status_;
882   double start_;
883   std::unique_ptr<grpc::GenericClientAsyncReaderWriter> stream_;
884 
885   // Allow a limit on number of messages in a stream
886   int messages_per_stream_;
887   int messages_issued_;
888 
StartInternal(CompletionQueue * cq,int messages_per_stream)889   void StartInternal(CompletionQueue* cq, int messages_per_stream) {
890     cq_ = cq;
891     const std::string kMethodName(
892         "/grpc.testing.BenchmarkService/StreamingCall");
893     messages_per_stream_ = messages_per_stream;
894     messages_issued_ = 0;
895     stream_ = prepare_req_(stub_, &context_, kMethodName, cq);
896     next_state_ = State::STREAM_IDLE;
897     stream_->StartCall(ClientRpcContext::tag(this));
898   }
899 };
900 
// Stub factory handed to AsyncClient: wraps a channel in a GenericStub so
// calls can be issued by method name rather than through codegen stubs.
static std::unique_ptr<grpc::GenericStub> GenericStubCreator(
    const std::shared_ptr<Channel>& ch) {
  return std::make_unique<grpc::GenericStub>(ch);
}
905 
906 class GenericAsyncStreamingClient final
907     : public AsyncClient<grpc::GenericStub, ByteBuffer> {
908  public:
GenericAsyncStreamingClient(const ClientConfig & config)909   explicit GenericAsyncStreamingClient(const ClientConfig& config)
910       : AsyncClient<grpc::GenericStub, ByteBuffer>(config, SetupCtx,
911                                                    GenericStubCreator) {
912     StartThreads(num_async_threads_);
913   }
914 
~GenericAsyncStreamingClient()915   ~GenericAsyncStreamingClient() override {}
916 
917  private:
CheckDone(const grpc::Status &,ByteBuffer *)918   static void CheckDone(const grpc::Status& /*s*/, ByteBuffer* /*response*/) {}
PrepareReq(grpc::GenericStub * stub,grpc::ClientContext * ctx,const std::string & method_name,CompletionQueue * cq)919   static std::unique_ptr<grpc::GenericClientAsyncReaderWriter> PrepareReq(
920       grpc::GenericStub* stub, grpc::ClientContext* ctx,
921       const std::string& method_name, CompletionQueue* cq) {
922     auto stream = stub->PrepareCall(ctx, method_name, cq);
923     return stream;
924   };
SetupCtx(grpc::GenericStub * stub,std::function<gpr_timespec ()> next_issue,const ByteBuffer & req)925   static ClientRpcContext* SetupCtx(grpc::GenericStub* stub,
926                                     std::function<gpr_timespec()> next_issue,
927                                     const ByteBuffer& req) {
928     return new ClientRpcContextGenericStreamingImpl(
929         stub, req, std::move(next_issue),
930         GenericAsyncStreamingClient::PrepareReq,
931         GenericAsyncStreamingClient::CheckDone);
932   }
933 };
934 
CreateAsyncClient(const ClientConfig & config)935 std::unique_ptr<Client> CreateAsyncClient(const ClientConfig& config) {
936   switch (config.rpc_type()) {
937     case UNARY:
938       return std::unique_ptr<Client>(new AsyncUnaryClient(config));
939     case STREAMING:
940       return std::unique_ptr<Client>(new AsyncStreamingPingPongClient(config));
941     case STREAMING_FROM_CLIENT:
942       return std::unique_ptr<Client>(
943           new AsyncStreamingFromClientClient(config));
944     case STREAMING_FROM_SERVER:
945       return std::unique_ptr<Client>(
946           new AsyncStreamingFromServerClient(config));
947     case STREAMING_BOTH_WAYS:
948       // TODO(vjpai): Implement this
949       assert(false);
950       return nullptr;
951     default:
952       assert(false);
953       return nullptr;
954   }
955 }
CreateGenericAsyncStreamingClient(const ClientConfig & config)956 std::unique_ptr<Client> CreateGenericAsyncStreamingClient(
957     const ClientConfig& config) {
958   return std::unique_ptr<Client>(new GenericAsyncStreamingClient(config));
959 }
960 
961 }  // namespace testing
962 }  // namespace grpc
963