1*3f982cf4SFabien Sanglard // Copyright 2018 The Chromium Authors. All rights reserved.
2*3f982cf4SFabien Sanglard // Use of this source code is governed by a BSD-style license that can be
3*3f982cf4SFabien Sanglard // found in the LICENSE file.
4*3f982cf4SFabien Sanglard
5*3f982cf4SFabien Sanglard #include "osp/public/message_demuxer.h"
6*3f982cf4SFabien Sanglard
7*3f982cf4SFabien Sanglard #include <memory>
8*3f982cf4SFabien Sanglard #include <utility>
9*3f982cf4SFabien Sanglard
10*3f982cf4SFabien Sanglard #include "osp/impl/quic/quic_connection.h"
11*3f982cf4SFabien Sanglard #include "platform/base/error.h"
12*3f982cf4SFabien Sanglard #include "util/big_endian.h"
13*3f982cf4SFabien Sanglard #include "util/osp_logging.h"
14*3f982cf4SFabien Sanglard
15*3f982cf4SFabien Sanglard namespace openscreen {
16*3f982cf4SFabien Sanglard namespace osp {
17*3f982cf4SFabien Sanglard
18*3f982cf4SFabien Sanglard // static
19*3f982cf4SFabien Sanglard // Decodes a varUint, expecting it to follow the encoding format described here:
20*3f982cf4SFabien Sanglard // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
DecodeVarUint(const std::vector<uint8_t> & buffer,size_t * num_bytes_decoded)21*3f982cf4SFabien Sanglard ErrorOr<uint64_t> MessageTypeDecoder::DecodeVarUint(
22*3f982cf4SFabien Sanglard const std::vector<uint8_t>& buffer,
23*3f982cf4SFabien Sanglard size_t* num_bytes_decoded) {
24*3f982cf4SFabien Sanglard if (buffer.size() == 0) {
25*3f982cf4SFabien Sanglard return Error::Code::kCborIncompleteMessage;
26*3f982cf4SFabien Sanglard }
27*3f982cf4SFabien Sanglard
28*3f982cf4SFabien Sanglard uint8_t num_type_bytes = static_cast<uint8_t>(buffer[0] >> 6 & 0x03);
29*3f982cf4SFabien Sanglard *num_bytes_decoded = 0x1 << num_type_bytes;
30*3f982cf4SFabien Sanglard
31*3f982cf4SFabien Sanglard // Ensure that ReadBigEndian won't read beyond the end of the buffer. Also,
32*3f982cf4SFabien Sanglard // since we expect the id to be followed by the message, equality is not valid
33*3f982cf4SFabien Sanglard if (buffer.size() <= *num_bytes_decoded) {
34*3f982cf4SFabien Sanglard return Error::Code::kCborIncompleteMessage;
35*3f982cf4SFabien Sanglard }
36*3f982cf4SFabien Sanglard
37*3f982cf4SFabien Sanglard switch (num_type_bytes) {
38*3f982cf4SFabien Sanglard case 0:
39*3f982cf4SFabien Sanglard return ReadBigEndian<uint8_t>(&buffer[0]) & ~0xC0;
40*3f982cf4SFabien Sanglard case 1:
41*3f982cf4SFabien Sanglard return ReadBigEndian<uint16_t>(&buffer[0]) & ~(0xC0 << 8);
42*3f982cf4SFabien Sanglard case 2:
43*3f982cf4SFabien Sanglard return ReadBigEndian<uint32_t>(&buffer[0]) & ~(0xC0 << 24);
44*3f982cf4SFabien Sanglard case 3:
45*3f982cf4SFabien Sanglard return ReadBigEndian<uint64_t>(&buffer[0]) & ~(uint64_t{0xC0} << 56);
46*3f982cf4SFabien Sanglard default:
47*3f982cf4SFabien Sanglard OSP_NOTREACHED();
48*3f982cf4SFabien Sanglard }
49*3f982cf4SFabien Sanglard }
50*3f982cf4SFabien Sanglard
51*3f982cf4SFabien Sanglard // static
52*3f982cf4SFabien Sanglard // Decodes the Type of message, expecting it to follow the encoding format
53*3f982cf4SFabien Sanglard // described here:
54*3f982cf4SFabien Sanglard // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
DecodeType(const std::vector<uint8_t> & buffer,size_t * num_bytes_decoded)55*3f982cf4SFabien Sanglard ErrorOr<msgs::Type> MessageTypeDecoder::DecodeType(
56*3f982cf4SFabien Sanglard const std::vector<uint8_t>& buffer,
57*3f982cf4SFabien Sanglard size_t* num_bytes_decoded) {
58*3f982cf4SFabien Sanglard ErrorOr<uint64_t> message_type =
59*3f982cf4SFabien Sanglard MessageTypeDecoder::DecodeVarUint(buffer, num_bytes_decoded);
60*3f982cf4SFabien Sanglard if (message_type.is_error()) {
61*3f982cf4SFabien Sanglard return message_type.error();
62*3f982cf4SFabien Sanglard }
63*3f982cf4SFabien Sanglard
64*3f982cf4SFabien Sanglard msgs::Type parsed_type =
65*3f982cf4SFabien Sanglard msgs::TypeEnumValidator::SafeCast(message_type.value());
66*3f982cf4SFabien Sanglard if (parsed_type == msgs::Type::kUnknown) {
67*3f982cf4SFabien Sanglard return Error::Code::kCborInvalidMessage;
68*3f982cf4SFabien Sanglard }
69*3f982cf4SFabien Sanglard
70*3f982cf4SFabien Sanglard return parsed_type;
71*3f982cf4SFabien Sanglard }
72*3f982cf4SFabien Sanglard
73*3f982cf4SFabien Sanglard // static
74*3f982cf4SFabien Sanglard constexpr size_t MessageDemuxer::kDefaultBufferLimit;
75*3f982cf4SFabien Sanglard
76*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch::MessageWatch() = default;
77*3f982cf4SFabien Sanglard
MessageWatch(MessageDemuxer * parent,bool is_default,uint64_t endpoint_id,msgs::Type message_type)78*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent,
79*3f982cf4SFabien Sanglard bool is_default,
80*3f982cf4SFabien Sanglard uint64_t endpoint_id,
81*3f982cf4SFabien Sanglard msgs::Type message_type)
82*3f982cf4SFabien Sanglard : parent_(parent),
83*3f982cf4SFabien Sanglard is_default_(is_default),
84*3f982cf4SFabien Sanglard endpoint_id_(endpoint_id),
85*3f982cf4SFabien Sanglard message_type_(message_type) {}
86*3f982cf4SFabien Sanglard
MessageWatch(MessageDemuxer::MessageWatch && other)87*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch::MessageWatch(
88*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch&& other) noexcept
89*3f982cf4SFabien Sanglard : parent_(other.parent_),
90*3f982cf4SFabien Sanglard is_default_(other.is_default_),
91*3f982cf4SFabien Sanglard endpoint_id_(other.endpoint_id_),
92*3f982cf4SFabien Sanglard message_type_(other.message_type_) {
93*3f982cf4SFabien Sanglard other.parent_ = nullptr;
94*3f982cf4SFabien Sanglard }
95*3f982cf4SFabien Sanglard
~MessageWatch()96*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch::~MessageWatch() {
97*3f982cf4SFabien Sanglard if (parent_) {
98*3f982cf4SFabien Sanglard if (is_default_) {
99*3f982cf4SFabien Sanglard OSP_VLOG << "dropping default handler for type: "
100*3f982cf4SFabien Sanglard << static_cast<int>(message_type_);
101*3f982cf4SFabien Sanglard parent_->StopDefaultMessageTypeWatch(message_type_);
102*3f982cf4SFabien Sanglard } else {
103*3f982cf4SFabien Sanglard OSP_VLOG << "dropping handler for type: "
104*3f982cf4SFabien Sanglard << static_cast<int>(message_type_);
105*3f982cf4SFabien Sanglard parent_->StopWatchingMessageType(endpoint_id_, message_type_);
106*3f982cf4SFabien Sanglard }
107*3f982cf4SFabien Sanglard }
108*3f982cf4SFabien Sanglard }
109*3f982cf4SFabien Sanglard
operator =(MessageWatch && other)110*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=(
111*3f982cf4SFabien Sanglard MessageWatch&& other) noexcept {
112*3f982cf4SFabien Sanglard using std::swap;
113*3f982cf4SFabien Sanglard swap(parent_, other.parent_);
114*3f982cf4SFabien Sanglard swap(is_default_, other.is_default_);
115*3f982cf4SFabien Sanglard swap(endpoint_id_, other.endpoint_id_);
116*3f982cf4SFabien Sanglard swap(message_type_, other.message_type_);
117*3f982cf4SFabien Sanglard return *this;
118*3f982cf4SFabien Sanglard }
119*3f982cf4SFabien Sanglard
MessageDemuxer(ClockNowFunctionPtr now_function,size_t buffer_limit=kDefaultBufferLimit)120*3f982cf4SFabien Sanglard MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function,
121*3f982cf4SFabien Sanglard size_t buffer_limit = kDefaultBufferLimit)
122*3f982cf4SFabien Sanglard : now_function_(now_function), buffer_limit_(buffer_limit) {
123*3f982cf4SFabien Sanglard OSP_DCHECK(now_function_);
124*3f982cf4SFabien Sanglard }
125*3f982cf4SFabien Sanglard
126*3f982cf4SFabien Sanglard MessageDemuxer::~MessageDemuxer() = default;
127*3f982cf4SFabien Sanglard
WatchMessageType(uint64_t endpoint_id,msgs::Type message_type,MessageCallback * callback)128*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType(
129*3f982cf4SFabien Sanglard uint64_t endpoint_id,
130*3f982cf4SFabien Sanglard msgs::Type message_type,
131*3f982cf4SFabien Sanglard MessageCallback* callback) {
132*3f982cf4SFabien Sanglard auto callbacks_entry = message_callbacks_.find(endpoint_id);
133*3f982cf4SFabien Sanglard if (callbacks_entry == message_callbacks_.end()) {
134*3f982cf4SFabien Sanglard callbacks_entry =
135*3f982cf4SFabien Sanglard message_callbacks_
136*3f982cf4SFabien Sanglard .emplace(endpoint_id, std::map<msgs::Type, MessageCallback*>{})
137*3f982cf4SFabien Sanglard .first;
138*3f982cf4SFabien Sanglard }
139*3f982cf4SFabien Sanglard auto emplace_result = callbacks_entry->second.emplace(message_type, callback);
140*3f982cf4SFabien Sanglard if (!emplace_result.second)
141*3f982cf4SFabien Sanglard return MessageWatch();
142*3f982cf4SFabien Sanglard auto endpoint_entry = buffers_.find(endpoint_id);
143*3f982cf4SFabien Sanglard if (endpoint_entry != buffers_.end()) {
144*3f982cf4SFabien Sanglard for (auto& buffer : endpoint_entry->second) {
145*3f982cf4SFabien Sanglard if (buffer.second.empty())
146*3f982cf4SFabien Sanglard continue;
147*3f982cf4SFabien Sanglard auto buffered_type = static_cast<msgs::Type>(buffer.second[0]);
148*3f982cf4SFabien Sanglard if (message_type == buffered_type) {
149*3f982cf4SFabien Sanglard HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry,
150*3f982cf4SFabien Sanglard &buffer.second);
151*3f982cf4SFabien Sanglard }
152*3f982cf4SFabien Sanglard }
153*3f982cf4SFabien Sanglard }
154*3f982cf4SFabien Sanglard return MessageWatch(this, false, endpoint_id, message_type);
155*3f982cf4SFabien Sanglard }
156*3f982cf4SFabien Sanglard
SetDefaultMessageTypeWatch(msgs::Type message_type,MessageCallback * callback)157*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch(
158*3f982cf4SFabien Sanglard msgs::Type message_type,
159*3f982cf4SFabien Sanglard MessageCallback* callback) {
160*3f982cf4SFabien Sanglard auto emplace_result = default_callbacks_.emplace(message_type, callback);
161*3f982cf4SFabien Sanglard if (!emplace_result.second)
162*3f982cf4SFabien Sanglard return MessageWatch();
163*3f982cf4SFabien Sanglard for (auto& endpoint_buffers : buffers_) {
164*3f982cf4SFabien Sanglard auto endpoint_id = endpoint_buffers.first;
165*3f982cf4SFabien Sanglard for (auto& stream_map : endpoint_buffers.second) {
166*3f982cf4SFabien Sanglard if (stream_map.second.empty())
167*3f982cf4SFabien Sanglard continue;
168*3f982cf4SFabien Sanglard auto buffered_type = static_cast<msgs::Type>(stream_map.second[0]);
169*3f982cf4SFabien Sanglard if (message_type == buffered_type) {
170*3f982cf4SFabien Sanglard auto connection_id = stream_map.first;
171*3f982cf4SFabien Sanglard auto callbacks_entry = message_callbacks_.find(endpoint_id);
172*3f982cf4SFabien Sanglard HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry,
173*3f982cf4SFabien Sanglard &stream_map.second);
174*3f982cf4SFabien Sanglard }
175*3f982cf4SFabien Sanglard }
176*3f982cf4SFabien Sanglard }
177*3f982cf4SFabien Sanglard return MessageWatch(this, true, 0, message_type);
178*3f982cf4SFabien Sanglard }
179*3f982cf4SFabien Sanglard
OnStreamData(uint64_t endpoint_id,uint64_t connection_id,const uint8_t * data,size_t data_size)180*3f982cf4SFabien Sanglard void MessageDemuxer::OnStreamData(uint64_t endpoint_id,
181*3f982cf4SFabien Sanglard uint64_t connection_id,
182*3f982cf4SFabien Sanglard const uint8_t* data,
183*3f982cf4SFabien Sanglard size_t data_size) {
184*3f982cf4SFabien Sanglard OSP_VLOG << __func__ << ": [" << endpoint_id << ", " << connection_id
185*3f982cf4SFabien Sanglard << "] - (" << data_size << ")";
186*3f982cf4SFabien Sanglard auto& stream_map = buffers_[endpoint_id];
187*3f982cf4SFabien Sanglard if (!data_size) {
188*3f982cf4SFabien Sanglard stream_map.erase(connection_id);
189*3f982cf4SFabien Sanglard if (stream_map.empty())
190*3f982cf4SFabien Sanglard buffers_.erase(endpoint_id);
191*3f982cf4SFabien Sanglard return;
192*3f982cf4SFabien Sanglard }
193*3f982cf4SFabien Sanglard std::vector<uint8_t>& buffer = stream_map[connection_id];
194*3f982cf4SFabien Sanglard buffer.insert(buffer.end(), data, data + data_size);
195*3f982cf4SFabien Sanglard
196*3f982cf4SFabien Sanglard auto callbacks_entry = message_callbacks_.find(endpoint_id);
197*3f982cf4SFabien Sanglard HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer);
198*3f982cf4SFabien Sanglard
199*3f982cf4SFabien Sanglard if (buffer.size() > buffer_limit_)
200*3f982cf4SFabien Sanglard stream_map.erase(connection_id);
201*3f982cf4SFabien Sanglard }
202*3f982cf4SFabien Sanglard
StopWatchingMessageType(uint64_t endpoint_id,msgs::Type message_type)203*3f982cf4SFabien Sanglard void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id,
204*3f982cf4SFabien Sanglard msgs::Type message_type) {
205*3f982cf4SFabien Sanglard auto& message_map = message_callbacks_[endpoint_id];
206*3f982cf4SFabien Sanglard auto it = message_map.find(message_type);
207*3f982cf4SFabien Sanglard message_map.erase(it);
208*3f982cf4SFabien Sanglard }
209*3f982cf4SFabien Sanglard
StopDefaultMessageTypeWatch(msgs::Type message_type)210*3f982cf4SFabien Sanglard void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) {
211*3f982cf4SFabien Sanglard default_callbacks_.erase(message_type);
212*3f982cf4SFabien Sanglard }
213*3f982cf4SFabien Sanglard
HandleStreamBufferLoop(uint64_t endpoint_id,uint64_t connection_id,std::map<uint64_t,std::map<msgs::Type,MessageCallback * >>::iterator callbacks_entry,std::vector<uint8_t> * buffer)214*3f982cf4SFabien Sanglard MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop(
215*3f982cf4SFabien Sanglard uint64_t endpoint_id,
216*3f982cf4SFabien Sanglard uint64_t connection_id,
217*3f982cf4SFabien Sanglard std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator
218*3f982cf4SFabien Sanglard callbacks_entry,
219*3f982cf4SFabien Sanglard std::vector<uint8_t>* buffer) {
220*3f982cf4SFabien Sanglard HandleStreamBufferResult result;
221*3f982cf4SFabien Sanglard do {
222*3f982cf4SFabien Sanglard result = {false, 0};
223*3f982cf4SFabien Sanglard if (callbacks_entry != message_callbacks_.end()) {
224*3f982cf4SFabien Sanglard OSP_VLOG << "attempting endpoint-specific handling";
225*3f982cf4SFabien Sanglard result = HandleStreamBuffer(endpoint_id, connection_id,
226*3f982cf4SFabien Sanglard &callbacks_entry->second, buffer);
227*3f982cf4SFabien Sanglard }
228*3f982cf4SFabien Sanglard if (!result.handled) {
229*3f982cf4SFabien Sanglard if (!default_callbacks_.empty()) {
230*3f982cf4SFabien Sanglard OSP_VLOG << "attempting generic message handling";
231*3f982cf4SFabien Sanglard result = HandleStreamBuffer(endpoint_id, connection_id,
232*3f982cf4SFabien Sanglard &default_callbacks_, buffer);
233*3f982cf4SFabien Sanglard }
234*3f982cf4SFabien Sanglard }
235*3f982cf4SFabien Sanglard OSP_VLOG_IF(!result.handled) << "no message handler matched";
236*3f982cf4SFabien Sanglard } while (result.consumed && !buffer->empty());
237*3f982cf4SFabien Sanglard return result;
238*3f982cf4SFabien Sanglard }
239*3f982cf4SFabien Sanglard
240*3f982cf4SFabien Sanglard // TODO(rwkeane) Use absl::Span for the buffer
HandleStreamBuffer(uint64_t endpoint_id,uint64_t connection_id,std::map<msgs::Type,MessageCallback * > * message_callbacks,std::vector<uint8_t> * buffer)241*3f982cf4SFabien Sanglard MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer(
242*3f982cf4SFabien Sanglard uint64_t endpoint_id,
243*3f982cf4SFabien Sanglard uint64_t connection_id,
244*3f982cf4SFabien Sanglard std::map<msgs::Type, MessageCallback*>* message_callbacks,
245*3f982cf4SFabien Sanglard std::vector<uint8_t>* buffer) {
246*3f982cf4SFabien Sanglard size_t consumed = 0;
247*3f982cf4SFabien Sanglard size_t total_consumed = 0;
248*3f982cf4SFabien Sanglard bool handled = false;
249*3f982cf4SFabien Sanglard do {
250*3f982cf4SFabien Sanglard consumed = 0;
251*3f982cf4SFabien Sanglard size_t msg_type_byte_length;
252*3f982cf4SFabien Sanglard ErrorOr<msgs::Type> message_type =
253*3f982cf4SFabien Sanglard MessageTypeDecoder::DecodeType(*buffer, &msg_type_byte_length);
254*3f982cf4SFabien Sanglard if (message_type.is_error()) {
255*3f982cf4SFabien Sanglard buffer->clear();
256*3f982cf4SFabien Sanglard break;
257*3f982cf4SFabien Sanglard }
258*3f982cf4SFabien Sanglard auto callback_entry = message_callbacks->find(message_type.value());
259*3f982cf4SFabien Sanglard if (callback_entry == message_callbacks->end())
260*3f982cf4SFabien Sanglard break;
261*3f982cf4SFabien Sanglard handled = true;
262*3f982cf4SFabien Sanglard OSP_VLOG << "handling message type "
263*3f982cf4SFabien Sanglard << static_cast<int>(message_type.value());
264*3f982cf4SFabien Sanglard auto consumed_or_error = callback_entry->second->OnStreamMessage(
265*3f982cf4SFabien Sanglard endpoint_id, connection_id, message_type.value(),
266*3f982cf4SFabien Sanglard buffer->data() + msg_type_byte_length,
267*3f982cf4SFabien Sanglard buffer->size() - msg_type_byte_length, now_function_());
268*3f982cf4SFabien Sanglard if (!consumed_or_error) {
269*3f982cf4SFabien Sanglard if (consumed_or_error.error().code() !=
270*3f982cf4SFabien Sanglard Error::Code::kCborIncompleteMessage) {
271*3f982cf4SFabien Sanglard buffer->clear();
272*3f982cf4SFabien Sanglard break;
273*3f982cf4SFabien Sanglard }
274*3f982cf4SFabien Sanglard } else {
275*3f982cf4SFabien Sanglard consumed = consumed_or_error.value();
276*3f982cf4SFabien Sanglard buffer->erase(buffer->begin(),
277*3f982cf4SFabien Sanglard buffer->begin() + consumed + msg_type_byte_length);
278*3f982cf4SFabien Sanglard }
279*3f982cf4SFabien Sanglard total_consumed += consumed;
280*3f982cf4SFabien Sanglard } while (consumed && !buffer->empty());
281*3f982cf4SFabien Sanglard return HandleStreamBufferResult{handled, total_consumed};
282*3f982cf4SFabien Sanglard }
283*3f982cf4SFabien Sanglard
StopWatching(MessageDemuxer::MessageWatch * watch)284*3f982cf4SFabien Sanglard void StopWatching(MessageDemuxer::MessageWatch* watch) {
285*3f982cf4SFabien Sanglard *watch = MessageDemuxer::MessageWatch();
286*3f982cf4SFabien Sanglard }
287*3f982cf4SFabien Sanglard
288*3f982cf4SFabien Sanglard } // namespace osp
289*3f982cf4SFabien Sanglard } // namespace openscreen
290