// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
#define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "osp/public/message_demuxer.h"
#include "osp/public/network_service_manager.h"
#include "osp/public/protocol_connection.h"
#include "platform/base/error.h"
#include "platform/base/macros.h"
#include "util/osp_logging.h"

namespace openscreen {
namespace osp {

template <typename T>
using MessageDecodingFunction = ssize_t (*)(const uint8_t*, size_t, T*);

// Provides a uniform way of accessing important properties of a
// request/response message pair from a template: request encode function,
// response decode function, request serializable data member.
31*3f982cf4SFabien Sanglard template <typename T> 32*3f982cf4SFabien Sanglard struct DefaultRequestCoderTraits { 33*3f982cf4SFabien Sanglard public: 34*3f982cf4SFabien Sanglard using RequestMsgType = typename T::RequestMsgType; 35*3f982cf4SFabien Sanglard static constexpr MessageEncodingFunction<RequestMsgType> kEncoder = 36*3f982cf4SFabien Sanglard T::kEncoder; 37*3f982cf4SFabien Sanglard static constexpr MessageDecodingFunction<typename T::ResponseMsgType> 38*3f982cf4SFabien Sanglard kDecoder = T::kDecoder; 39*3f982cf4SFabien Sanglard serial_requestDefaultRequestCoderTraits40*3f982cf4SFabien Sanglard static const RequestMsgType* serial_request(const T& data) { 41*3f982cf4SFabien Sanglard return &data.request; 42*3f982cf4SFabien Sanglard } serial_requestDefaultRequestCoderTraits43*3f982cf4SFabien Sanglard static RequestMsgType* serial_request(T& data) { return &data.request; } 44*3f982cf4SFabien Sanglard }; 45*3f982cf4SFabien Sanglard 46*3f982cf4SFabien Sanglard // Provides a wrapper for the common pattern of sending a request message and 47*3f982cf4SFabien Sanglard // waiting for a response message with a matching |request_id| field. It also 48*3f982cf4SFabien Sanglard // handles the business of queueing messages to be sent until a protocol 49*3f982cf4SFabien Sanglard // connection is available. 50*3f982cf4SFabien Sanglard // 51*3f982cf4SFabien Sanglard // Messages are written using WriteMessage. This will queue messages if there 52*3f982cf4SFabien Sanglard // is no protocol connection or write them immediately if there is. When a 53*3f982cf4SFabien Sanglard // matching response is received via the MessageDemuxer (taken from the global 54*3f982cf4SFabien Sanglard // ProtocolConnectionClient), OnMatchedResponse is called on the provided 55*3f982cf4SFabien Sanglard // Delegate object along with the original request that it matches. 
56*3f982cf4SFabien Sanglard template <typename RequestT, 57*3f982cf4SFabien Sanglard typename RequestCoderTraits = DefaultRequestCoderTraits<RequestT>> 58*3f982cf4SFabien Sanglard class RequestResponseHandler : public MessageDemuxer::MessageCallback { 59*3f982cf4SFabien Sanglard public: 60*3f982cf4SFabien Sanglard class Delegate { 61*3f982cf4SFabien Sanglard public: 62*3f982cf4SFabien Sanglard 63*3f982cf4SFabien Sanglard virtual void OnMatchedResponse(RequestT* request, 64*3f982cf4SFabien Sanglard typename RequestT::ResponseMsgType* response, 65*3f982cf4SFabien Sanglard uint64_t endpoint_id) = 0; 66*3f982cf4SFabien Sanglard virtual void OnError(RequestT* request, Error error) = 0; 67*3f982cf4SFabien Sanglard 68*3f982cf4SFabien Sanglard protected: 69*3f982cf4SFabien Sanglard virtual ~Delegate() = default; 70*3f982cf4SFabien Sanglard }; 71*3f982cf4SFabien Sanglard RequestResponseHandler(Delegate * delegate)72*3f982cf4SFabien Sanglard explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {} ~RequestResponseHandler()73*3f982cf4SFabien Sanglard ~RequestResponseHandler() { Reset(); } 74*3f982cf4SFabien Sanglard Reset()75*3f982cf4SFabien Sanglard void Reset() { 76*3f982cf4SFabien Sanglard connection_ = nullptr; 77*3f982cf4SFabien Sanglard for (auto& message : to_send_) { 78*3f982cf4SFabien Sanglard delegate_->OnError(&message.request, Error::Code::kRequestCancelled); 79*3f982cf4SFabien Sanglard } 80*3f982cf4SFabien Sanglard to_send_.clear(); 81*3f982cf4SFabien Sanglard for (auto& message : sent_) { 82*3f982cf4SFabien Sanglard delegate_->OnError(&message.request, Error::Code::kRequestCancelled); 83*3f982cf4SFabien Sanglard } 84*3f982cf4SFabien Sanglard sent_.clear(); 85*3f982cf4SFabien Sanglard response_watch_ = MessageDemuxer::MessageWatch(); 86*3f982cf4SFabien Sanglard } 87*3f982cf4SFabien Sanglard 88*3f982cf4SFabien Sanglard // Write a message to the underlying protocol connection, or queue it until 89*3f982cf4SFabien Sanglard // one is provided via 
SetConnection. If |id| is provided, it can be used to 90*3f982cf4SFabien Sanglard // cancel the message via CancelMessage. 91*3f982cf4SFabien Sanglard template <typename RequestTRval> 92*3f982cf4SFabien Sanglard typename std::enable_if< 93*3f982cf4SFabien Sanglard !std::is_lvalue_reference<RequestTRval>::value && 94*3f982cf4SFabien Sanglard std::is_same<typename std::decay<RequestTRval>::type, 95*3f982cf4SFabien Sanglard RequestT>::value, 96*3f982cf4SFabien Sanglard Error>::type WriteMessage(absl::optional<uint64_t> id,RequestTRval && message)97*3f982cf4SFabien Sanglard WriteMessage(absl::optional<uint64_t> id, RequestTRval&& message) { 98*3f982cf4SFabien Sanglard auto* request_msg = RequestCoderTraits::serial_request(message); 99*3f982cf4SFabien Sanglard if (connection_) { 100*3f982cf4SFabien Sanglard request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); 101*3f982cf4SFabien Sanglard Error result = 102*3f982cf4SFabien Sanglard connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); 103*3f982cf4SFabien Sanglard if (!result.ok()) { 104*3f982cf4SFabien Sanglard return result; 105*3f982cf4SFabien Sanglard } 106*3f982cf4SFabien Sanglard sent_.emplace_back(RequestWithId{id, std::move(message)}); 107*3f982cf4SFabien Sanglard EnsureResponseWatch(); 108*3f982cf4SFabien Sanglard } else { 109*3f982cf4SFabien Sanglard to_send_.emplace_back(RequestWithId{id, std::move(message)}); 110*3f982cf4SFabien Sanglard } 111*3f982cf4SFabien Sanglard return Error::None(); 112*3f982cf4SFabien Sanglard } 113*3f982cf4SFabien Sanglard 114*3f982cf4SFabien Sanglard template <typename RequestTRval> 115*3f982cf4SFabien Sanglard typename std::enable_if< 116*3f982cf4SFabien Sanglard !std::is_lvalue_reference<RequestTRval>::value && 117*3f982cf4SFabien Sanglard std::is_same<typename std::decay<RequestTRval>::type, 118*3f982cf4SFabien Sanglard RequestT>::value, 119*3f982cf4SFabien Sanglard Error>::type WriteMessage(RequestTRval && message)120*3f982cf4SFabien Sanglard 
WriteMessage(RequestTRval&& message) { 121*3f982cf4SFabien Sanglard return WriteMessage(absl::nullopt, std::move(message)); 122*3f982cf4SFabien Sanglard } 123*3f982cf4SFabien Sanglard 124*3f982cf4SFabien Sanglard // Remove the message that was originally written with |id| from the send and 125*3f982cf4SFabien Sanglard // sent queues so that we are no longer looking for a response. CancelMessage(uint64_t id)126*3f982cf4SFabien Sanglard void CancelMessage(uint64_t id) { 127*3f982cf4SFabien Sanglard to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(), 128*3f982cf4SFabien Sanglard [&id](const RequestWithId& msg) { 129*3f982cf4SFabien Sanglard return id == msg.id; 130*3f982cf4SFabien Sanglard }), 131*3f982cf4SFabien Sanglard to_send_.end()); 132*3f982cf4SFabien Sanglard sent_.erase(std::remove_if( 133*3f982cf4SFabien Sanglard sent_.begin(), sent_.end(), 134*3f982cf4SFabien Sanglard [&id](const RequestWithId& msg) { return id == msg.id; }), 135*3f982cf4SFabien Sanglard sent_.end()); 136*3f982cf4SFabien Sanglard if (sent_.empty()) { 137*3f982cf4SFabien Sanglard response_watch_ = MessageDemuxer::MessageWatch(); 138*3f982cf4SFabien Sanglard } 139*3f982cf4SFabien Sanglard } 140*3f982cf4SFabien Sanglard 141*3f982cf4SFabien Sanglard // Assign a ProtocolConnection to this handler for writing messages. 
SetConnection(ProtocolConnection * connection)142*3f982cf4SFabien Sanglard void SetConnection(ProtocolConnection* connection) { 143*3f982cf4SFabien Sanglard connection_ = connection; 144*3f982cf4SFabien Sanglard for (auto& message : to_send_) { 145*3f982cf4SFabien Sanglard auto* request_msg = RequestCoderTraits::serial_request(message.request); 146*3f982cf4SFabien Sanglard request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); 147*3f982cf4SFabien Sanglard Error result = 148*3f982cf4SFabien Sanglard connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); 149*3f982cf4SFabien Sanglard if (result.ok()) { 150*3f982cf4SFabien Sanglard sent_.emplace_back(std::move(message)); 151*3f982cf4SFabien Sanglard } else { 152*3f982cf4SFabien Sanglard delegate_->OnError(&message.request, result); 153*3f982cf4SFabien Sanglard } 154*3f982cf4SFabien Sanglard } 155*3f982cf4SFabien Sanglard if (!to_send_.empty()) { 156*3f982cf4SFabien Sanglard EnsureResponseWatch(); 157*3f982cf4SFabien Sanglard } 158*3f982cf4SFabien Sanglard to_send_.clear(); 159*3f982cf4SFabien Sanglard } 160*3f982cf4SFabien Sanglard 161*3f982cf4SFabien Sanglard // MessageDemuxer::MessageCallback overrides. 
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)162*3f982cf4SFabien Sanglard ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id, 163*3f982cf4SFabien Sanglard uint64_t connection_id, 164*3f982cf4SFabien Sanglard msgs::Type message_type, 165*3f982cf4SFabien Sanglard const uint8_t* buffer, 166*3f982cf4SFabien Sanglard size_t buffer_size, 167*3f982cf4SFabien Sanglard Clock::time_point now) override { 168*3f982cf4SFabien Sanglard if (message_type != RequestT::kResponseType) { 169*3f982cf4SFabien Sanglard return 0; 170*3f982cf4SFabien Sanglard } 171*3f982cf4SFabien Sanglard typename RequestT::ResponseMsgType response; 172*3f982cf4SFabien Sanglard ssize_t result = 173*3f982cf4SFabien Sanglard RequestCoderTraits::kDecoder(buffer, buffer_size, &response); 174*3f982cf4SFabien Sanglard if (result < 0) { 175*3f982cf4SFabien Sanglard return 0; 176*3f982cf4SFabien Sanglard } 177*3f982cf4SFabien Sanglard auto it = std::find_if( 178*3f982cf4SFabien Sanglard sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) { 179*3f982cf4SFabien Sanglard return RequestCoderTraits::serial_request(msg.request)->request_id == 180*3f982cf4SFabien Sanglard response.request_id; 181*3f982cf4SFabien Sanglard }); 182*3f982cf4SFabien Sanglard if (it != sent_.end()) { 183*3f982cf4SFabien Sanglard delegate_->OnMatchedResponse(&it->request, &response, 184*3f982cf4SFabien Sanglard connection_->endpoint_id()); 185*3f982cf4SFabien Sanglard sent_.erase(it); 186*3f982cf4SFabien Sanglard if (sent_.empty()) { 187*3f982cf4SFabien Sanglard response_watch_ = MessageDemuxer::MessageWatch(); 188*3f982cf4SFabien Sanglard } 189*3f982cf4SFabien Sanglard } else { 190*3f982cf4SFabien Sanglard OSP_LOG_WARN << "got response for unknown request id: " 191*3f982cf4SFabien Sanglard << response.request_id; 192*3f982cf4SFabien Sanglard } 193*3f982cf4SFabien Sanglard return result; 194*3f982cf4SFabien Sanglard } 
195*3f982cf4SFabien Sanglard 196*3f982cf4SFabien Sanglard private: 197*3f982cf4SFabien Sanglard struct RequestWithId { 198*3f982cf4SFabien Sanglard absl::optional<uint64_t> id; 199*3f982cf4SFabien Sanglard RequestT request; 200*3f982cf4SFabien Sanglard }; 201*3f982cf4SFabien Sanglard EnsureResponseWatch()202*3f982cf4SFabien Sanglard void EnsureResponseWatch() { 203*3f982cf4SFabien Sanglard if (!response_watch_) { 204*3f982cf4SFabien Sanglard response_watch_ = NetworkServiceManager::Get() 205*3f982cf4SFabien Sanglard ->GetProtocolConnectionClient() 206*3f982cf4SFabien Sanglard ->message_demuxer() 207*3f982cf4SFabien Sanglard ->WatchMessageType(connection_->endpoint_id(), 208*3f982cf4SFabien Sanglard RequestT::kResponseType, this); 209*3f982cf4SFabien Sanglard } 210*3f982cf4SFabien Sanglard } 211*3f982cf4SFabien Sanglard GetNextRequestId(uint64_t endpoint_id)212*3f982cf4SFabien Sanglard uint64_t GetNextRequestId(uint64_t endpoint_id) { 213*3f982cf4SFabien Sanglard return NetworkServiceManager::Get() 214*3f982cf4SFabien Sanglard ->GetProtocolConnectionClient() 215*3f982cf4SFabien Sanglard ->endpoint_request_ids() 216*3f982cf4SFabien Sanglard ->GetNextRequestId(endpoint_id); 217*3f982cf4SFabien Sanglard } 218*3f982cf4SFabien Sanglard 219*3f982cf4SFabien Sanglard ProtocolConnection* connection_ = nullptr; 220*3f982cf4SFabien Sanglard Delegate* const delegate_; 221*3f982cf4SFabien Sanglard std::vector<RequestWithId> to_send_; 222*3f982cf4SFabien Sanglard std::vector<RequestWithId> sent_; 223*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch response_watch_; 224*3f982cf4SFabien Sanglard 225*3f982cf4SFabien Sanglard OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler); 226*3f982cf4SFabien Sanglard }; 227*3f982cf4SFabien Sanglard 228*3f982cf4SFabien Sanglard } // namespace osp 229*3f982cf4SFabien Sanglard } // namespace openscreen 230*3f982cf4SFabien Sanglard 231*3f982cf4SFabien Sanglard #endif // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ 232