xref: /aosp_15_r20/external/openscreen/osp/public/request_response_handler.h (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1*3f982cf4SFabien Sanglard // Copyright 2019 The Chromium Authors. All rights reserved.
2*3f982cf4SFabien Sanglard // Use of this source code is governed by a BSD-style license that can be
3*3f982cf4SFabien Sanglard // found in the LICENSE file.
4*3f982cf4SFabien Sanglard 
5*3f982cf4SFabien Sanglard #ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
6*3f982cf4SFabien Sanglard #define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
7*3f982cf4SFabien Sanglard 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>
13*3f982cf4SFabien Sanglard 
14*3f982cf4SFabien Sanglard #include "absl/types/optional.h"
15*3f982cf4SFabien Sanglard #include "osp/public/message_demuxer.h"
16*3f982cf4SFabien Sanglard #include "osp/public/network_service_manager.h"
17*3f982cf4SFabien Sanglard #include "osp/public/protocol_connection.h"
18*3f982cf4SFabien Sanglard #include "platform/base/error.h"
19*3f982cf4SFabien Sanglard #include "platform/base/macros.h"
20*3f982cf4SFabien Sanglard #include "util/osp_logging.h"
21*3f982cf4SFabien Sanglard 
22*3f982cf4SFabien Sanglard namespace openscreen {
23*3f982cf4SFabien Sanglard namespace osp {
24*3f982cf4SFabien Sanglard 
// Function-pointer type for a decoder that reads a |T| out of a byte buffer of
// the given size.  The result is a signed count (ssize_t); the handler in this
// header treats a negative result as a decode failure.
template <typename T>
using MessageDecodingFunction = auto (*)(const uint8_t*, size_t, T*) -> ssize_t;
27*3f982cf4SFabien Sanglard 
28*3f982cf4SFabien Sanglard // Provides a uniform way of accessing import properties of a request/response
29*3f982cf4SFabien Sanglard // message pair from a template: request encode function, response decode
30*3f982cf4SFabien Sanglard // function, request serializable data member.
31*3f982cf4SFabien Sanglard template <typename T>
32*3f982cf4SFabien Sanglard struct DefaultRequestCoderTraits {
33*3f982cf4SFabien Sanglard  public:
34*3f982cf4SFabien Sanglard   using RequestMsgType = typename T::RequestMsgType;
35*3f982cf4SFabien Sanglard   static constexpr MessageEncodingFunction<RequestMsgType> kEncoder =
36*3f982cf4SFabien Sanglard       T::kEncoder;
37*3f982cf4SFabien Sanglard   static constexpr MessageDecodingFunction<typename T::ResponseMsgType>
38*3f982cf4SFabien Sanglard       kDecoder = T::kDecoder;
39*3f982cf4SFabien Sanglard 
serial_requestDefaultRequestCoderTraits40*3f982cf4SFabien Sanglard   static const RequestMsgType* serial_request(const T& data) {
41*3f982cf4SFabien Sanglard     return &data.request;
42*3f982cf4SFabien Sanglard   }
serial_requestDefaultRequestCoderTraits43*3f982cf4SFabien Sanglard   static RequestMsgType* serial_request(T& data) { return &data.request; }
44*3f982cf4SFabien Sanglard };
45*3f982cf4SFabien Sanglard 
46*3f982cf4SFabien Sanglard // Provides a wrapper for the common pattern of sending a request message and
47*3f982cf4SFabien Sanglard // waiting for a response message with a matching |request_id| field.  It also
48*3f982cf4SFabien Sanglard // handles the business of queueing messages to be sent until a protocol
49*3f982cf4SFabien Sanglard // connection is available.
50*3f982cf4SFabien Sanglard //
51*3f982cf4SFabien Sanglard // Messages are written using WriteMessage.  This will queue messages if there
52*3f982cf4SFabien Sanglard // is no protocol connection or write them immediately if there is.  When a
53*3f982cf4SFabien Sanglard // matching response is received via the MessageDemuxer (taken from the global
54*3f982cf4SFabien Sanglard // ProtocolConnectionClient), OnMatchedResponse is called on the provided
55*3f982cf4SFabien Sanglard // Delegate object along with the original request that it matches.
56*3f982cf4SFabien Sanglard template <typename RequestT,
57*3f982cf4SFabien Sanglard           typename RequestCoderTraits = DefaultRequestCoderTraits<RequestT>>
58*3f982cf4SFabien Sanglard class RequestResponseHandler : public MessageDemuxer::MessageCallback {
59*3f982cf4SFabien Sanglard  public:
60*3f982cf4SFabien Sanglard   class Delegate {
61*3f982cf4SFabien Sanglard    public:
62*3f982cf4SFabien Sanglard 
63*3f982cf4SFabien Sanglard     virtual void OnMatchedResponse(RequestT* request,
64*3f982cf4SFabien Sanglard                                    typename RequestT::ResponseMsgType* response,
65*3f982cf4SFabien Sanglard                                    uint64_t endpoint_id) = 0;
66*3f982cf4SFabien Sanglard     virtual void OnError(RequestT* request, Error error) = 0;
67*3f982cf4SFabien Sanglard 
68*3f982cf4SFabien Sanglard    protected:
69*3f982cf4SFabien Sanglard     virtual ~Delegate() = default;
70*3f982cf4SFabien Sanglard   };
71*3f982cf4SFabien Sanglard 
RequestResponseHandler(Delegate * delegate)72*3f982cf4SFabien Sanglard   explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {}
~RequestResponseHandler()73*3f982cf4SFabien Sanglard   ~RequestResponseHandler() { Reset(); }
74*3f982cf4SFabien Sanglard 
Reset()75*3f982cf4SFabien Sanglard   void Reset() {
76*3f982cf4SFabien Sanglard     connection_ = nullptr;
77*3f982cf4SFabien Sanglard     for (auto& message : to_send_) {
78*3f982cf4SFabien Sanglard       delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
79*3f982cf4SFabien Sanglard     }
80*3f982cf4SFabien Sanglard     to_send_.clear();
81*3f982cf4SFabien Sanglard     for (auto& message : sent_) {
82*3f982cf4SFabien Sanglard       delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
83*3f982cf4SFabien Sanglard     }
84*3f982cf4SFabien Sanglard     sent_.clear();
85*3f982cf4SFabien Sanglard     response_watch_ = MessageDemuxer::MessageWatch();
86*3f982cf4SFabien Sanglard   }
87*3f982cf4SFabien Sanglard 
88*3f982cf4SFabien Sanglard   // Write a message to the underlying protocol connection, or queue it until
89*3f982cf4SFabien Sanglard   // one is provided via SetConnection.  If |id| is provided, it can be used to
90*3f982cf4SFabien Sanglard   // cancel the message via CancelMessage.
91*3f982cf4SFabien Sanglard   template <typename RequestTRval>
92*3f982cf4SFabien Sanglard   typename std::enable_if<
93*3f982cf4SFabien Sanglard       !std::is_lvalue_reference<RequestTRval>::value &&
94*3f982cf4SFabien Sanglard           std::is_same<typename std::decay<RequestTRval>::type,
95*3f982cf4SFabien Sanglard                        RequestT>::value,
96*3f982cf4SFabien Sanglard       Error>::type
WriteMessage(absl::optional<uint64_t> id,RequestTRval && message)97*3f982cf4SFabien Sanglard   WriteMessage(absl::optional<uint64_t> id, RequestTRval&& message) {
98*3f982cf4SFabien Sanglard     auto* request_msg = RequestCoderTraits::serial_request(message);
99*3f982cf4SFabien Sanglard     if (connection_) {
100*3f982cf4SFabien Sanglard       request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
101*3f982cf4SFabien Sanglard       Error result =
102*3f982cf4SFabien Sanglard           connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
103*3f982cf4SFabien Sanglard       if (!result.ok()) {
104*3f982cf4SFabien Sanglard         return result;
105*3f982cf4SFabien Sanglard       }
106*3f982cf4SFabien Sanglard       sent_.emplace_back(RequestWithId{id, std::move(message)});
107*3f982cf4SFabien Sanglard       EnsureResponseWatch();
108*3f982cf4SFabien Sanglard     } else {
109*3f982cf4SFabien Sanglard       to_send_.emplace_back(RequestWithId{id, std::move(message)});
110*3f982cf4SFabien Sanglard     }
111*3f982cf4SFabien Sanglard     return Error::None();
112*3f982cf4SFabien Sanglard   }
113*3f982cf4SFabien Sanglard 
114*3f982cf4SFabien Sanglard   template <typename RequestTRval>
115*3f982cf4SFabien Sanglard   typename std::enable_if<
116*3f982cf4SFabien Sanglard       !std::is_lvalue_reference<RequestTRval>::value &&
117*3f982cf4SFabien Sanglard           std::is_same<typename std::decay<RequestTRval>::type,
118*3f982cf4SFabien Sanglard                        RequestT>::value,
119*3f982cf4SFabien Sanglard       Error>::type
WriteMessage(RequestTRval && message)120*3f982cf4SFabien Sanglard   WriteMessage(RequestTRval&& message) {
121*3f982cf4SFabien Sanglard     return WriteMessage(absl::nullopt, std::move(message));
122*3f982cf4SFabien Sanglard   }
123*3f982cf4SFabien Sanglard 
124*3f982cf4SFabien Sanglard   // Remove the message that was originally written with |id| from the send and
125*3f982cf4SFabien Sanglard   // sent queues so that we are no longer looking for a response.
CancelMessage(uint64_t id)126*3f982cf4SFabien Sanglard   void CancelMessage(uint64_t id) {
127*3f982cf4SFabien Sanglard     to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(),
128*3f982cf4SFabien Sanglard                                   [&id](const RequestWithId& msg) {
129*3f982cf4SFabien Sanglard                                     return id == msg.id;
130*3f982cf4SFabien Sanglard                                   }),
131*3f982cf4SFabien Sanglard                    to_send_.end());
132*3f982cf4SFabien Sanglard     sent_.erase(std::remove_if(
133*3f982cf4SFabien Sanglard                     sent_.begin(), sent_.end(),
134*3f982cf4SFabien Sanglard                     [&id](const RequestWithId& msg) { return id == msg.id; }),
135*3f982cf4SFabien Sanglard                 sent_.end());
136*3f982cf4SFabien Sanglard     if (sent_.empty()) {
137*3f982cf4SFabien Sanglard       response_watch_ = MessageDemuxer::MessageWatch();
138*3f982cf4SFabien Sanglard     }
139*3f982cf4SFabien Sanglard   }
140*3f982cf4SFabien Sanglard 
141*3f982cf4SFabien Sanglard   // Assign a ProtocolConnection to this handler for writing messages.
SetConnection(ProtocolConnection * connection)142*3f982cf4SFabien Sanglard   void SetConnection(ProtocolConnection* connection) {
143*3f982cf4SFabien Sanglard     connection_ = connection;
144*3f982cf4SFabien Sanglard     for (auto& message : to_send_) {
145*3f982cf4SFabien Sanglard       auto* request_msg = RequestCoderTraits::serial_request(message.request);
146*3f982cf4SFabien Sanglard       request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
147*3f982cf4SFabien Sanglard       Error result =
148*3f982cf4SFabien Sanglard           connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
149*3f982cf4SFabien Sanglard       if (result.ok()) {
150*3f982cf4SFabien Sanglard         sent_.emplace_back(std::move(message));
151*3f982cf4SFabien Sanglard       } else {
152*3f982cf4SFabien Sanglard         delegate_->OnError(&message.request, result);
153*3f982cf4SFabien Sanglard       }
154*3f982cf4SFabien Sanglard     }
155*3f982cf4SFabien Sanglard     if (!to_send_.empty()) {
156*3f982cf4SFabien Sanglard       EnsureResponseWatch();
157*3f982cf4SFabien Sanglard     }
158*3f982cf4SFabien Sanglard     to_send_.clear();
159*3f982cf4SFabien Sanglard   }
160*3f982cf4SFabien Sanglard 
161*3f982cf4SFabien Sanglard   // MessageDemuxer::MessageCallback overrides.
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)162*3f982cf4SFabien Sanglard   ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
163*3f982cf4SFabien Sanglard                                   uint64_t connection_id,
164*3f982cf4SFabien Sanglard                                   msgs::Type message_type,
165*3f982cf4SFabien Sanglard                                   const uint8_t* buffer,
166*3f982cf4SFabien Sanglard                                   size_t buffer_size,
167*3f982cf4SFabien Sanglard                                   Clock::time_point now) override {
168*3f982cf4SFabien Sanglard     if (message_type != RequestT::kResponseType) {
169*3f982cf4SFabien Sanglard       return 0;
170*3f982cf4SFabien Sanglard     }
171*3f982cf4SFabien Sanglard     typename RequestT::ResponseMsgType response;
172*3f982cf4SFabien Sanglard     ssize_t result =
173*3f982cf4SFabien Sanglard         RequestCoderTraits::kDecoder(buffer, buffer_size, &response);
174*3f982cf4SFabien Sanglard     if (result < 0) {
175*3f982cf4SFabien Sanglard       return 0;
176*3f982cf4SFabien Sanglard     }
177*3f982cf4SFabien Sanglard     auto it = std::find_if(
178*3f982cf4SFabien Sanglard         sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) {
179*3f982cf4SFabien Sanglard           return RequestCoderTraits::serial_request(msg.request)->request_id ==
180*3f982cf4SFabien Sanglard                  response.request_id;
181*3f982cf4SFabien Sanglard         });
182*3f982cf4SFabien Sanglard     if (it != sent_.end()) {
183*3f982cf4SFabien Sanglard       delegate_->OnMatchedResponse(&it->request, &response,
184*3f982cf4SFabien Sanglard                                    connection_->endpoint_id());
185*3f982cf4SFabien Sanglard       sent_.erase(it);
186*3f982cf4SFabien Sanglard       if (sent_.empty()) {
187*3f982cf4SFabien Sanglard         response_watch_ = MessageDemuxer::MessageWatch();
188*3f982cf4SFabien Sanglard       }
189*3f982cf4SFabien Sanglard     } else {
190*3f982cf4SFabien Sanglard       OSP_LOG_WARN << "got response for unknown request id: "
191*3f982cf4SFabien Sanglard                    << response.request_id;
192*3f982cf4SFabien Sanglard     }
193*3f982cf4SFabien Sanglard     return result;
194*3f982cf4SFabien Sanglard   }
195*3f982cf4SFabien Sanglard 
196*3f982cf4SFabien Sanglard  private:
197*3f982cf4SFabien Sanglard   struct RequestWithId {
198*3f982cf4SFabien Sanglard     absl::optional<uint64_t> id;
199*3f982cf4SFabien Sanglard     RequestT request;
200*3f982cf4SFabien Sanglard   };
201*3f982cf4SFabien Sanglard 
EnsureResponseWatch()202*3f982cf4SFabien Sanglard   void EnsureResponseWatch() {
203*3f982cf4SFabien Sanglard     if (!response_watch_) {
204*3f982cf4SFabien Sanglard       response_watch_ = NetworkServiceManager::Get()
205*3f982cf4SFabien Sanglard                             ->GetProtocolConnectionClient()
206*3f982cf4SFabien Sanglard                             ->message_demuxer()
207*3f982cf4SFabien Sanglard                             ->WatchMessageType(connection_->endpoint_id(),
208*3f982cf4SFabien Sanglard                                                RequestT::kResponseType, this);
209*3f982cf4SFabien Sanglard     }
210*3f982cf4SFabien Sanglard   }
211*3f982cf4SFabien Sanglard 
GetNextRequestId(uint64_t endpoint_id)212*3f982cf4SFabien Sanglard   uint64_t GetNextRequestId(uint64_t endpoint_id) {
213*3f982cf4SFabien Sanglard     return NetworkServiceManager::Get()
214*3f982cf4SFabien Sanglard         ->GetProtocolConnectionClient()
215*3f982cf4SFabien Sanglard         ->endpoint_request_ids()
216*3f982cf4SFabien Sanglard         ->GetNextRequestId(endpoint_id);
217*3f982cf4SFabien Sanglard   }
218*3f982cf4SFabien Sanglard 
219*3f982cf4SFabien Sanglard   ProtocolConnection* connection_ = nullptr;
220*3f982cf4SFabien Sanglard   Delegate* const delegate_;
221*3f982cf4SFabien Sanglard   std::vector<RequestWithId> to_send_;
222*3f982cf4SFabien Sanglard   std::vector<RequestWithId> sent_;
223*3f982cf4SFabien Sanglard   MessageDemuxer::MessageWatch response_watch_;
224*3f982cf4SFabien Sanglard 
225*3f982cf4SFabien Sanglard   OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler);
226*3f982cf4SFabien Sanglard };
227*3f982cf4SFabien Sanglard 
228*3f982cf4SFabien Sanglard }  // namespace osp
229*3f982cf4SFabien Sanglard }  // namespace openscreen
230*3f982cf4SFabien Sanglard 
231*3f982cf4SFabien Sanglard #endif  // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
232