1 // Copyright 2021 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <grpc/support/port_platform.h>
16 
17 #include "src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h"
18 
19 #ifndef GRPC_NO_BINDER
20 
21 #include <functional>
22 #include <string>
23 #include <utility>
24 
25 #include <grpc/support/log.h>
26 
27 #include "src/core/lib/gprpp/crash.h"
28 
29 namespace grpc_binder {
30 
31 const absl::string_view
32     TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully =
33         "grpc-binder-transport: cancelled gracefully";
34 
RegisterRecvInitialMetadata(StreamIdentifier id,InitialMetadataCallbackType cb)35 void TransportStreamReceiverImpl::RegisterRecvInitialMetadata(
36     StreamIdentifier id, InitialMetadataCallbackType cb) {
37   gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
38   absl::StatusOr<Metadata> initial_metadata{};
39   {
40     grpc_core::MutexLock l(&m_);
41     GPR_ASSERT(initial_metadata_cbs_.count(id) == 0);
42     auto iter = pending_initial_metadata_.find(id);
43     if (iter == pending_initial_metadata_.end()) {
44       if (trailing_metadata_recvd_.count(id)) {
45         cb(absl::CancelledError(""));
46       } else {
47         initial_metadata_cbs_[id] = std::move(cb);
48       }
49       cb = nullptr;
50     } else {
51       initial_metadata = std::move(iter->second.front());
52       iter->second.pop();
53       if (iter->second.empty()) {
54         pending_initial_metadata_.erase(iter);
55       }
56     }
57   }
58   if (cb != nullptr) {
59     cb(std::move(initial_metadata));
60   }
61 }
62 
RegisterRecvMessage(StreamIdentifier id,MessageDataCallbackType cb)63 void TransportStreamReceiverImpl::RegisterRecvMessage(
64     StreamIdentifier id, MessageDataCallbackType cb) {
65   gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
66   absl::StatusOr<std::string> message{};
67   {
68     grpc_core::MutexLock l(&m_);
69     GPR_ASSERT(message_cbs_.count(id) == 0);
70     auto iter = pending_message_.find(id);
71     if (iter == pending_message_.end()) {
72       // If we'd already received trailing-metadata and there's no pending
73       // messages, cancel the callback.
74       if (trailing_metadata_recvd_.count(id)) {
75         cb(absl::CancelledError(
76             TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully));
77       } else {
78         message_cbs_[id] = std::move(cb);
79       }
80       cb = nullptr;
81     } else {
82       // We'll still keep all pending messages received before the trailing
83       // metadata since they're issued before the end of stream, as promised by
84       // WireReader which keeps transactions commit in-order.
85       message = std::move(iter->second.front());
86       iter->second.pop();
87       if (iter->second.empty()) {
88         pending_message_.erase(iter);
89       }
90     }
91   }
92   if (cb != nullptr) {
93     cb(std::move(message));
94   }
95 }
96 
RegisterRecvTrailingMetadata(StreamIdentifier id,TrailingMetadataCallbackType cb)97 void TransportStreamReceiverImpl::RegisterRecvTrailingMetadata(
98     StreamIdentifier id, TrailingMetadataCallbackType cb) {
99   gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
100   std::pair<absl::StatusOr<Metadata>, int> trailing_metadata{};
101   {
102     grpc_core::MutexLock l(&m_);
103     GPR_ASSERT(trailing_metadata_cbs_.count(id) == 0);
104     auto iter = pending_trailing_metadata_.find(id);
105     if (iter == pending_trailing_metadata_.end()) {
106       trailing_metadata_cbs_[id] = std::move(cb);
107       cb = nullptr;
108     } else {
109       trailing_metadata = std::move(iter->second.front());
110       iter->second.pop();
111       if (iter->second.empty()) {
112         pending_trailing_metadata_.erase(iter);
113       }
114     }
115   }
116   if (cb != nullptr) {
117     cb(std::move(trailing_metadata.first), trailing_metadata.second);
118   }
119 }
120 
NotifyRecvInitialMetadata(StreamIdentifier id,absl::StatusOr<Metadata> initial_metadata)121 void TransportStreamReceiverImpl::NotifyRecvInitialMetadata(
122     StreamIdentifier id, absl::StatusOr<Metadata> initial_metadata) {
123   gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
124   if (!is_client_ && accept_stream_callback_ && initial_metadata.ok()) {
125     accept_stream_callback_();
126   }
127   InitialMetadataCallbackType cb;
128   {
129     grpc_core::MutexLock l(&m_);
130     auto iter = initial_metadata_cbs_.find(id);
131     if (iter != initial_metadata_cbs_.end()) {
132       cb = iter->second;
133       initial_metadata_cbs_.erase(iter);
134     } else {
135       pending_initial_metadata_[id].push(std::move(initial_metadata));
136       return;
137     }
138   }
139   cb(std::move(initial_metadata));
140 }
141 
NotifyRecvMessage(StreamIdentifier id,absl::StatusOr<std::string> message)142 void TransportStreamReceiverImpl::NotifyRecvMessage(
143     StreamIdentifier id, absl::StatusOr<std::string> message) {
144   gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
145   MessageDataCallbackType cb;
146   {
147     grpc_core::MutexLock l(&m_);
148     auto iter = message_cbs_.find(id);
149     if (iter != message_cbs_.end()) {
150       cb = iter->second;
151       message_cbs_.erase(iter);
152     } else {
153       pending_message_[id].push(std::move(message));
154       return;
155     }
156   }
157   cb(std::move(message));
158 }
159 
NotifyRecvTrailingMetadata(StreamIdentifier id,absl::StatusOr<Metadata> trailing_metadata,int status)160 void TransportStreamReceiverImpl::NotifyRecvTrailingMetadata(
161     StreamIdentifier id, absl::StatusOr<Metadata> trailing_metadata,
162     int status) {
163   // Trailing metadata mark the end of the stream. Since TransportStreamReceiver
164   // assumes in-order commitments of transactions and that trailing metadata is
165   // parsed after message data, we can safely cancel all upcoming callbacks of
166   // recv_message.
167   gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
168   OnRecvTrailingMetadata(id);
169   TrailingMetadataCallbackType cb;
170   {
171     grpc_core::MutexLock l(&m_);
172     auto iter = trailing_metadata_cbs_.find(id);
173     if (iter != trailing_metadata_cbs_.end()) {
174       cb = iter->second;
175       trailing_metadata_cbs_.erase(iter);
176     } else {
177       pending_trailing_metadata_[id].emplace(std::move(trailing_metadata),
178                                              status);
179       return;
180     }
181   }
182   cb(std::move(trailing_metadata), status);
183 }
184 
CancelInitialMetadataCallback(StreamIdentifier id,absl::Status error)185 void TransportStreamReceiverImpl::CancelInitialMetadataCallback(
186     StreamIdentifier id, absl::Status error) {
187   InitialMetadataCallbackType callback = nullptr;
188   {
189     grpc_core::MutexLock l(&m_);
190     auto iter = initial_metadata_cbs_.find(id);
191     if (iter != initial_metadata_cbs_.end()) {
192       callback = std::move(iter->second);
193       initial_metadata_cbs_.erase(iter);
194     }
195   }
196   if (callback != nullptr) {
197     std::move(callback)(error);
198   }
199 }
200 
CancelMessageCallback(StreamIdentifier id,absl::Status error)201 void TransportStreamReceiverImpl::CancelMessageCallback(StreamIdentifier id,
202                                                         absl::Status error) {
203   MessageDataCallbackType callback = nullptr;
204   {
205     grpc_core::MutexLock l(&m_);
206     auto iter = message_cbs_.find(id);
207     if (iter != message_cbs_.end()) {
208       callback = std::move(iter->second);
209       message_cbs_.erase(iter);
210     }
211   }
212   if (callback != nullptr) {
213     std::move(callback)(error);
214   }
215 }
216 
CancelTrailingMetadataCallback(StreamIdentifier id,absl::Status error)217 void TransportStreamReceiverImpl::CancelTrailingMetadataCallback(
218     StreamIdentifier id, absl::Status error) {
219   TrailingMetadataCallbackType callback = nullptr;
220   {
221     grpc_core::MutexLock l(&m_);
222     auto iter = trailing_metadata_cbs_.find(id);
223     if (iter != trailing_metadata_cbs_.end()) {
224       callback = std::move(iter->second);
225       trailing_metadata_cbs_.erase(iter);
226     }
227   }
228   if (callback != nullptr) {
229     std::move(callback)(error, 0);
230   }
231 }
232 
OnRecvTrailingMetadata(StreamIdentifier id)233 void TransportStreamReceiverImpl::OnRecvTrailingMetadata(StreamIdentifier id) {
234   gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
235   m_.Lock();
236   trailing_metadata_recvd_.insert(id);
237   m_.Unlock();
238   CancelInitialMetadataCallback(id, absl::CancelledError(""));
239   CancelMessageCallback(
240       id,
241       absl::CancelledError(
242           TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully));
243 }
244 
CancelStream(StreamIdentifier id)245 void TransportStreamReceiverImpl::CancelStream(StreamIdentifier id) {
246   gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
247   CancelInitialMetadataCallback(id, absl::CancelledError("Stream cancelled"));
248   CancelMessageCallback(id, absl::CancelledError("Stream cancelled"));
249   CancelTrailingMetadataCallback(id, absl::CancelledError("Stream cancelled"));
250   grpc_core::MutexLock l(&m_);
251   trailing_metadata_recvd_.erase(id);
252   pending_initial_metadata_.erase(id);
253   pending_message_.erase(id);
254   pending_trailing_metadata_.erase(id);
255 }
256 }  // namespace grpc_binder
257 #endif
258