1 // Copyright 2023 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "quiche/quic/moqt/moqt_session.h"
6
7 #include <array>
8 #include <cstdint>
9 #include <memory>
10 #include <optional>
11 #include <string>
12 #include <utility>
13 #include <vector>
14
15
16 #include "absl/algorithm/container.h"
17 #include "absl/status/status.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/string_view.h"
20 #include "absl/types/span.h"
21 #include "quiche/quic/core/quic_types.h"
22 #include "quiche/quic/moqt/moqt_messages.h"
23 #include "quiche/quic/moqt/moqt_parser.h"
24 #include "quiche/quic/moqt/moqt_subscribe_windows.h"
25 #include "quiche/quic/moqt/moqt_track.h"
26 #include "quiche/quic/platform/api/quic_bug_tracker.h"
27 #include "quiche/common/platform/api/quiche_logging.h"
28 #include "quiche/common/quiche_buffer_allocator.h"
29 #include "quiche/common/quiche_stream.h"
30 #include "quiche/web_transport/web_transport.h"
31
// Log-line prefix identifying which side of the connection emitted a message.
// Expands to a call to perspective() and uses Perspective (from the `using`
// declaration below), so it can only appear where both are in scope; in this
// file that is MoqtSession and MoqtSession::Stream member functions.
#define ENDPOINT \
  (perspective() == Perspective::IS_SERVER ? "MoQT Server: " : "MoQT Client: ")
34
35 namespace moqt {
36
37 using ::quic::Perspective;
38
GetControlStream()39 MoqtSession::Stream* MoqtSession::GetControlStream() {
40 if (!control_stream_.has_value()) {
41 return nullptr;
42 }
43 webtransport::Stream* raw_stream = session_->GetStreamById(*control_stream_);
44 if (raw_stream == nullptr) {
45 return nullptr;
46 }
47 return static_cast<Stream*>(raw_stream->visitor());
48 }
49
SendControlMessage(quiche::QuicheBuffer message)50 void MoqtSession::SendControlMessage(quiche::QuicheBuffer message) {
51 Stream* control_stream = GetControlStream();
52 if (control_stream == nullptr) {
53 QUICHE_LOG(DFATAL) << "Trying to send a message on the control stream "
54 "while it does not exist";
55 return;
56 }
57 control_stream->SendOrBufferMessage(std::move(message));
58 }
59
OnSessionReady()60 void MoqtSession::OnSessionReady() {
61 QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session ready";
62 if (parameters_.perspective == Perspective::IS_SERVER) {
63 return;
64 }
65
66 webtransport::Stream* control_stream =
67 session_->OpenOutgoingBidirectionalStream();
68 if (control_stream == nullptr) {
69 Error(MoqtError::kInternalError, "Unable to open a control stream");
70 return;
71 }
72 control_stream->SetVisitor(std::make_unique<Stream>(
73 this, control_stream, /*is_control_stream=*/true));
74 control_stream_ = control_stream->GetStreamId();
75 MoqtClientSetup setup = MoqtClientSetup{
76 .supported_versions = std::vector<MoqtVersion>{parameters_.version},
77 .role = MoqtRole::kPubSub,
78 };
79 if (!parameters_.using_webtrans) {
80 setup.path = parameters_.path;
81 }
82 SendControlMessage(framer_.SerializeClientSetup(setup));
83 QUIC_DLOG(INFO) << ENDPOINT << "Send the SETUP message";
84 }
85
OnSessionClosed(webtransport::SessionErrorCode,const std::string & error_message)86 void MoqtSession::OnSessionClosed(webtransport::SessionErrorCode,
87 const std::string& error_message) {
88 if (!error_.empty()) {
89 // Avoid erroring out twice.
90 return;
91 }
92 QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session closed with message: "
93 << error_message;
94 error_ = error_message;
95 std::move(callbacks_.session_terminated_callback)(error_message);
96 }
97
OnIncomingBidirectionalStreamAvailable()98 void MoqtSession::OnIncomingBidirectionalStreamAvailable() {
99 while (webtransport::Stream* stream =
100 session_->AcceptIncomingBidirectionalStream()) {
101 if (control_stream_.has_value()) {
102 Error(MoqtError::kProtocolViolation, "Bidirectional stream already open");
103 return;
104 }
105 stream->SetVisitor(std::make_unique<Stream>(this, stream));
106 stream->visitor()->OnCanRead();
107 }
108 }
OnIncomingUnidirectionalStreamAvailable()109 void MoqtSession::OnIncomingUnidirectionalStreamAvailable() {
110 while (webtransport::Stream* stream =
111 session_->AcceptIncomingUnidirectionalStream()) {
112 stream->SetVisitor(std::make_unique<Stream>(this, stream));
113 stream->visitor()->OnCanRead();
114 }
115 }
116
OnDatagramReceived(absl::string_view datagram)117 void MoqtSession::OnDatagramReceived(absl::string_view datagram) {
118 MoqtObject message;
119 absl::string_view payload = MoqtParser::ProcessDatagram(datagram, message);
120 if (payload.empty()) {
121 Error(MoqtError::kProtocolViolation, "Malformed datagram");
122 return;
123 }
124 QUICHE_DLOG(INFO) << ENDPOINT
125 << "Received OBJECT message in datagram for subscribe_id "
126 << message.subscribe_id << " for track alias "
127 << message.track_alias << " with sequence "
128 << message.group_id << ":" << message.object_id
129 << " send_order " << message.object_send_order << " length "
130 << payload.size();
131 auto [full_track_name, visitor] = TrackPropertiesFromAlias(message);
132 if (visitor != nullptr) {
133 visitor->OnObjectFragment(full_track_name, message.group_id,
134 message.object_id, message.object_send_order,
135 message.forwarding_preference, payload, true);
136 }
137 }
138
Error(MoqtError code,absl::string_view error)139 void MoqtSession::Error(MoqtError code, absl::string_view error) {
140 if (!error_.empty()) {
141 // Avoid erroring out twice.
142 return;
143 }
144 QUICHE_DLOG(INFO) << ENDPOINT << "MOQT session closed with code: "
145 << static_cast<int>(code) << " and message: " << error;
146 error_ = std::string(error);
147 session_->CloseSession(static_cast<uint64_t>(code), error);
148 std::move(callbacks_.session_terminated_callback)(error);
149 }
150
AddLocalTrack(const FullTrackName & full_track_name,MoqtForwardingPreference forwarding_preference,LocalTrack::Visitor * visitor)151 void MoqtSession::AddLocalTrack(const FullTrackName& full_track_name,
152 MoqtForwardingPreference forwarding_preference,
153 LocalTrack::Visitor* visitor) {
154 local_tracks_.try_emplace(full_track_name, full_track_name,
155 forwarding_preference, visitor);
156 }
157
158 // TODO: Create state that allows ANNOUNCE_OK/ERROR on spurious namespaces to
159 // trigger session errors.
Announce(absl::string_view track_namespace,MoqtOutgoingAnnounceCallback announce_callback)160 void MoqtSession::Announce(absl::string_view track_namespace,
161 MoqtOutgoingAnnounceCallback announce_callback) {
162 if (peer_role_ == MoqtRole::kPublisher) {
163 std::move(announce_callback)(
164 track_namespace,
165 MoqtAnnounceErrorReason{MoqtAnnounceErrorCode::kInternalError,
166 "ANNOUNCE cannot be sent to Publisher"});
167 return;
168 }
169 if (pending_outgoing_announces_.contains(track_namespace)) {
170 std::move(announce_callback)(
171 track_namespace,
172 MoqtAnnounceErrorReason{
173 MoqtAnnounceErrorCode::kInternalError,
174 "ANNOUNCE message already outstanding for namespace"});
175 return;
176 }
177 MoqtAnnounce message;
178 message.track_namespace = track_namespace;
179 SendControlMessage(framer_.SerializeAnnounce(message));
180 QUIC_DLOG(INFO) << ENDPOINT << "Sent ANNOUNCE message for "
181 << message.track_namespace;
182 pending_outgoing_announces_[track_namespace] = std::move(announce_callback);
183 }
184
HasSubscribers(const FullTrackName & full_track_name) const185 bool MoqtSession::HasSubscribers(const FullTrackName& full_track_name) const {
186 auto it = local_tracks_.find(full_track_name);
187 return (it != local_tracks_.end() && it->second.HasSubscriber());
188 }
189
SubscribeAbsolute(absl::string_view track_namespace,absl::string_view name,uint64_t start_group,uint64_t start_object,RemoteTrack::Visitor * visitor,absl::string_view auth_info)190 bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace,
191 absl::string_view name,
192 uint64_t start_group, uint64_t start_object,
193 RemoteTrack::Visitor* visitor,
194 absl::string_view auth_info) {
195 MoqtSubscribe message;
196 message.track_namespace = track_namespace;
197 message.track_name = name;
198 message.start_group = MoqtSubscribeLocation(true, start_group);
199 message.start_object = MoqtSubscribeLocation(true, start_object);
200 message.end_group = std::nullopt;
201 message.end_object = std::nullopt;
202 if (!auth_info.empty()) {
203 message.authorization_info = std::move(auth_info);
204 }
205 return Subscribe(message, visitor);
206 }
207
SubscribeAbsolute(absl::string_view track_namespace,absl::string_view name,uint64_t start_group,uint64_t start_object,uint64_t end_group,uint64_t end_object,RemoteTrack::Visitor * visitor,absl::string_view auth_info)208 bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace,
209 absl::string_view name,
210 uint64_t start_group, uint64_t start_object,
211 uint64_t end_group, uint64_t end_object,
212 RemoteTrack::Visitor* visitor,
213 absl::string_view auth_info) {
214 if (end_group < start_group) {
215 QUIC_DLOG(ERROR) << "Subscription end is before beginning";
216 return false;
217 }
218 if (end_group == start_group && end_object < start_object) {
219 QUIC_DLOG(ERROR) << "Subscription end is before beginning";
220 return false;
221 }
222 MoqtSubscribe message;
223 message.track_namespace = track_namespace;
224 message.track_name = name;
225 message.start_group = MoqtSubscribeLocation(true, start_group);
226 message.start_object = MoqtSubscribeLocation(true, start_object);
227 message.end_group = MoqtSubscribeLocation(true, end_group);
228 message.end_object = MoqtSubscribeLocation(true, end_object);
229 if (!auth_info.empty()) {
230 message.authorization_info = std::move(auth_info);
231 }
232 return Subscribe(message, visitor);
233 }
234
SubscribeRelative(absl::string_view track_namespace,absl::string_view name,int64_t start_group,int64_t start_object,RemoteTrack::Visitor * visitor,absl::string_view auth_info)235 bool MoqtSession::SubscribeRelative(absl::string_view track_namespace,
236 absl::string_view name, int64_t start_group,
237 int64_t start_object,
238 RemoteTrack::Visitor* visitor,
239 absl::string_view auth_info) {
240 MoqtSubscribe message;
241 message.track_namespace = track_namespace;
242 message.track_name = name;
243 message.start_group = MoqtSubscribeLocation(false, start_group);
244 message.start_object = MoqtSubscribeLocation(false, start_object);
245 message.end_group = std::nullopt;
246 message.end_object = std::nullopt;
247 if (!auth_info.empty()) {
248 message.authorization_info = std::move(auth_info);
249 }
250 return Subscribe(message, visitor);
251 }
252
SubscribeCurrentGroup(absl::string_view track_namespace,absl::string_view name,RemoteTrack::Visitor * visitor,absl::string_view auth_info)253 bool MoqtSession::SubscribeCurrentGroup(absl::string_view track_namespace,
254 absl::string_view name,
255 RemoteTrack::Visitor* visitor,
256 absl::string_view auth_info) {
257 MoqtSubscribe message;
258 message.track_namespace = track_namespace;
259 message.track_name = name;
260 // First object of current group.
261 message.start_group = MoqtSubscribeLocation(false, (uint64_t)0);
262 message.start_object = MoqtSubscribeLocation(true, (int64_t)0);
263 message.end_group = std::nullopt;
264 message.end_object = std::nullopt;
265 if (!auth_info.empty()) {
266 message.authorization_info = std::move(auth_info);
267 }
268 return Subscribe(message, visitor);
269 }
270
Subscribe(MoqtSubscribe & message,RemoteTrack::Visitor * visitor)271 bool MoqtSession::Subscribe(MoqtSubscribe& message,
272 RemoteTrack::Visitor* visitor) {
273 if (peer_role_ == MoqtRole::kSubscriber) {
274 QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE to subscriber peer";
275 return false;
276 }
277 // TODO(martinduke): support authorization info
278 message.subscribe_id = next_subscribe_id_++;
279 FullTrackName ftn(std::string(message.track_namespace),
280 std::string(message.track_name));
281 auto it = remote_track_aliases_.find(ftn);
282 if (it != remote_track_aliases_.end()) {
283 message.track_alias = it->second;
284 if (message.track_alias >= next_remote_track_alias_) {
285 next_remote_track_alias_ = message.track_alias + 1;
286 }
287 } else {
288 message.track_alias = next_remote_track_alias_++;
289 }
290 SendControlMessage(framer_.SerializeSubscribe(message));
291 QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for "
292 << message.track_namespace << ":" << message.track_name;
293 active_subscribes_.try_emplace(message.subscribe_id, message, visitor);
294 return true;
295 }
296
OpenUnidirectionalStream()297 std::optional<webtransport::StreamId> MoqtSession::OpenUnidirectionalStream() {
298 if (!session_->CanOpenNextOutgoingUnidirectionalStream()) {
299 return std::nullopt;
300 }
301 webtransport::Stream* new_stream =
302 session_->OpenOutgoingUnidirectionalStream();
303 if (new_stream == nullptr) {
304 return std::nullopt;
305 }
306 new_stream->SetVisitor(std::make_unique<Stream>(this, new_stream, false));
307 return new_stream->GetStreamId();
308 }
309
310 std::pair<FullTrackName, RemoteTrack::Visitor*>
TrackPropertiesFromAlias(const MoqtObject & message)311 MoqtSession::TrackPropertiesFromAlias(const MoqtObject& message) {
312 auto it = remote_tracks_.find(message.track_alias);
313 RemoteTrack::Visitor* visitor = nullptr;
314 if (it == remote_tracks_.end()) {
315 // SUBSCRIBE_OK has not arrived yet, but deliver it.
316 auto subscribe_it = active_subscribes_.find(message.subscribe_id);
317 if (subscribe_it == active_subscribes_.end()) {
318 return std::pair<FullTrackName, RemoteTrack::Visitor*>(
319 {{"", ""}, nullptr});
320 }
321 ActiveSubscribe& subscribe = subscribe_it->second;
322 visitor = subscribe.visitor;
323 subscribe.received_object = true;
324 if (subscribe.forwarding_preference.has_value()) {
325 if (message.forwarding_preference != *subscribe.forwarding_preference) {
326 Error(MoqtError::kProtocolViolation,
327 "Forwarding preference changes mid-track");
328 return std::pair<FullTrackName, RemoteTrack::Visitor*>(
329 {{"", ""}, nullptr});
330 }
331 } else {
332 subscribe.forwarding_preference = message.forwarding_preference;
333 }
334 return std::pair<FullTrackName, RemoteTrack::Visitor*>(
335 {{subscribe.message.track_namespace, subscribe.message.track_name},
336 subscribe.visitor});
337 }
338 RemoteTrack& track = it->second;
339 if (!track.CheckForwardingPreference(message.forwarding_preference)) {
340 // Incorrect forwarding preference.
341 Error(MoqtError::kProtocolViolation,
342 "Forwarding preference changes mid-track");
343 return std::pair<FullTrackName, RemoteTrack::Visitor*>({{"", ""}, nullptr});
344 }
345 return std::pair<FullTrackName, RemoteTrack::Visitor*>(
346 {{track.full_track_name().track_namespace,
347 track.full_track_name().track_name},
348 track.visitor()});
349 }
350
PublishObject(const FullTrackName & full_track_name,uint64_t group_id,uint64_t object_id,uint64_t object_send_order,absl::string_view payload,bool end_of_stream)351 bool MoqtSession::PublishObject(const FullTrackName& full_track_name,
352 uint64_t group_id, uint64_t object_id,
353 uint64_t object_send_order,
354 absl::string_view payload, bool end_of_stream) {
355 auto track_it = local_tracks_.find(full_track_name);
356 if (track_it == local_tracks_.end()) {
357 QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT for nonexistent track";
358 return false;
359 }
360 LocalTrack& track = track_it->second;
361 MoqtForwardingPreference forwarding_preference =
362 track.forwarding_preference();
363 if ((forwarding_preference == MoqtForwardingPreference::kObject ||
364 forwarding_preference == MoqtForwardingPreference::kDatagram) &&
365 !end_of_stream) {
366 QUIC_BUG(MoqtSession_PublishObject_end_of_stream_required)
367 << "Forwarding preferences of Object or Datagram require stream to be "
368 "immediately closed";
369 return false;
370 }
371 track.SentSequence(FullSequence(group_id, object_id));
372 std::vector<SubscribeWindow*> subscriptions =
373 track.ShouldSend({group_id, object_id});
374 if (subscriptions.empty()) {
375 return true;
376 }
377 MoqtObject object;
378 QUICHE_DCHECK(track.track_alias().has_value());
379 object.track_alias = *track.track_alias();
380 object.group_id = group_id;
381 object.object_id = object_id;
382 object.object_send_order = object_send_order;
383 object.forwarding_preference = forwarding_preference;
384 object.payload_length = payload.size();
385 int failures = 0;
386 quiche::StreamWriteOptions write_options;
387 write_options.set_send_fin(end_of_stream);
388 for (auto subscription : subscriptions) {
389 if (forwarding_preference == MoqtForwardingPreference::kDatagram) {
390 object.subscribe_id = subscription->subscribe_id();
391 quiche::QuicheBuffer datagram =
392 framer_.SerializeObjectDatagram(object, payload);
393 // TODO(martinduke): It's OK to just silently fail, but better to notify
394 // the app on errors.
395 session_->SendOrQueueDatagram(datagram.AsStringView());
396 continue;
397 }
398 bool new_stream = false;
399 std::optional<webtransport::StreamId> stream_id =
400 subscription->GetStreamForSequence(FullSequence(group_id, object_id));
401 if (!stream_id.has_value()) {
402 new_stream = true;
403 stream_id = OpenUnidirectionalStream();
404 if (!stream_id.has_value()) {
405 QUICHE_DLOG(ERROR) << ENDPOINT
406 << "Sending OBJECT to nonexistent stream";
407 ++failures;
408 continue;
409 }
410 if (!end_of_stream) {
411 subscription->AddStream(group_id, object_id, *stream_id);
412 }
413 }
414 webtransport::Stream* stream = session_->GetStreamById(*stream_id);
415 if (stream == nullptr) {
416 QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT to nonexistent stream "
417 << *stream_id;
418 ++failures;
419 continue;
420 }
421 object.subscribe_id = subscription->subscribe_id();
422 quiche::QuicheBuffer header =
423 framer_.SerializeObjectHeader(object, new_stream);
424 std::array<absl::string_view, 2> views = {header.AsStringView(), payload};
425 if (!stream->Writev(views, write_options).ok()) {
426 QUICHE_DLOG(ERROR) << ENDPOINT << "Failed to write OBJECT message";
427 ++failures;
428 continue;
429 }
430 QUICHE_LOG(INFO) << ENDPOINT << "Sending object length " << payload.length()
431 << " for " << full_track_name.track_namespace << ":"
432 << full_track_name.track_name << " with sequence "
433 << object.group_id << ":" << object.object_id
434 << " on stream " << *stream_id;
435 if (end_of_stream && !new_stream) {
436 subscription->RemoveStream(group_id, object_id);
437 }
438 }
439 return (failures == 0);
440 }
441
OnCanRead()442 void MoqtSession::Stream::OnCanRead() {
443 bool fin =
444 quiche::ProcessAllReadableRegions(*stream_, [&](absl::string_view chunk) {
445 parser_.ProcessData(chunk, /*end_of_stream=*/false);
446 });
447 if (fin) {
448 parser_.ProcessData("", /*end_of_stream=*/true);
449 }
450 }
// No-op: nothing in this file resumes pending writes when the stream becomes
// writable (control messages are written with buffer_unconditionally set).
void MoqtSession::Stream::OnCanWrite() {}
OnResetStreamReceived(webtransport::StreamErrorCode error)452 void MoqtSession::Stream::OnResetStreamReceived(
453 webtransport::StreamErrorCode error) {
454 if (is_control_stream_.has_value() && *is_control_stream_) {
455 session_->Error(
456 MoqtError::kProtocolViolation,
457 absl::StrCat("Control stream reset with error code ", error));
458 }
459 }
OnStopSendingReceived(webtransport::StreamErrorCode error)460 void MoqtSession::Stream::OnStopSendingReceived(
461 webtransport::StreamErrorCode error) {
462 if (is_control_stream_.has_value() && *is_control_stream_) {
463 session_->Error(
464 MoqtError::kProtocolViolation,
465 absl::StrCat("Control stream reset with error code ", error));
466 }
467 }
468
OnObjectMessage(const MoqtObject & message,absl::string_view payload,bool end_of_message)469 void MoqtSession::Stream::OnObjectMessage(const MoqtObject& message,
470 absl::string_view payload,
471 bool end_of_message) {
472 if (is_control_stream_ == true) {
473 session_->Error(MoqtError::kProtocolViolation,
474 "Received OBJECT message on control stream");
475 return;
476 }
477 QUICHE_DLOG(INFO)
478 << ENDPOINT << "Received OBJECT message on stream "
479 << stream_->GetStreamId() << " for subscribe_id " << message.subscribe_id
480 << " for track alias " << message.track_alias << " with sequence "
481 << message.group_id << ":" << message.object_id << " send_order "
482 << message.object_send_order << " forwarding_preference "
483 << MoqtForwardingPreferenceToString(message.forwarding_preference)
484 << " length " << payload.size() << " explicit length "
485 << (message.payload_length.has_value() ? (int)*message.payload_length
486 : -1)
487 << (end_of_message ? "F" : "");
488 if (!session_->parameters_.deliver_partial_objects) {
489 if (!end_of_message) { // Buffer partial object.
490 absl::StrAppend(&partial_object_, payload);
491 return;
492 }
493 if (!partial_object_.empty()) { // Completes the object
494 absl::StrAppend(&partial_object_, payload);
495 payload = absl::string_view(partial_object_);
496 }
497 }
498 auto [full_track_name, visitor] = session_->TrackPropertiesFromAlias(message);
499 if (visitor != nullptr) {
500 visitor->OnObjectFragment(full_track_name, message.group_id,
501 message.object_id, message.object_send_order,
502 message.forwarding_preference, payload,
503 end_of_message);
504 }
505 partial_object_.clear();
506 }
507
OnClientSetupMessage(const MoqtClientSetup & message)508 void MoqtSession::Stream::OnClientSetupMessage(const MoqtClientSetup& message) {
509 if (is_control_stream_.has_value()) {
510 if (!*is_control_stream_) {
511 session_->Error(MoqtError::kProtocolViolation,
512 "Received SETUP on non-control stream");
513 return;
514 }
515 } else {
516 is_control_stream_ = true;
517 }
518 session_->control_stream_ = stream_->GetStreamId();
519 if (perspective() == Perspective::IS_CLIENT) {
520 session_->Error(MoqtError::kProtocolViolation,
521 "Received CLIENT_SETUP from server");
522 return;
523 }
524 if (absl::c_find(message.supported_versions, session_->parameters_.version) ==
525 message.supported_versions.end()) {
526 // TODO(martinduke): Is this the right error code? See issue #346.
527 session_->Error(MoqtError::kProtocolViolation,
528 absl::StrCat("Version mismatch: expected 0x",
529 absl::Hex(session_->parameters_.version)));
530 return;
531 }
532 QUICHE_DLOG(INFO) << ENDPOINT << "Received the SETUP message";
533 if (session_->parameters_.perspective == Perspective::IS_SERVER) {
534 MoqtServerSetup response;
535 response.selected_version = session_->parameters_.version;
536 response.role = MoqtRole::kPubSub;
537 SendOrBufferMessage(session_->framer_.SerializeServerSetup(response));
538 QUIC_DLOG(INFO) << ENDPOINT << "Sent the SETUP message";
539 }
540 // TODO: handle role and path.
541 std::move(session_->callbacks_.session_established_callback)();
542 session_->peer_role_ = *message.role;
543 }
544
OnServerSetupMessage(const MoqtServerSetup & message)545 void MoqtSession::Stream::OnServerSetupMessage(const MoqtServerSetup& message) {
546 if (is_control_stream_.has_value()) {
547 if (!*is_control_stream_) {
548 session_->Error(MoqtError::kProtocolViolation,
549 "Received SETUP on non-control stream");
550 return;
551 }
552 } else {
553 is_control_stream_ = true;
554 }
555 if (perspective() == Perspective::IS_SERVER) {
556 session_->Error(MoqtError::kProtocolViolation,
557 "Received SERVER_SETUP from client");
558 return;
559 }
560 if (message.selected_version != session_->parameters_.version) {
561 // TODO(martinduke): Is this the right error code? See issue #346.
562 session_->Error(MoqtError::kProtocolViolation,
563 absl::StrCat("Version mismatch: expected 0x",
564 absl::Hex(session_->parameters_.version)));
565 return;
566 }
567 QUIC_DLOG(INFO) << ENDPOINT << "Received the SETUP message";
568 // TODO: handle role and path.
569 std::move(session_->callbacks_.session_established_callback)();
570 session_->peer_role_ = *message.role;
571 }
572
SendSubscribeError(const MoqtSubscribe & message,SubscribeErrorCode error_code,absl::string_view reason_phrase,uint64_t track_alias)573 void MoqtSession::Stream::SendSubscribeError(const MoqtSubscribe& message,
574 SubscribeErrorCode error_code,
575 absl::string_view reason_phrase,
576 uint64_t track_alias) {
577 MoqtSubscribeError subscribe_error;
578 subscribe_error.subscribe_id = message.subscribe_id;
579 subscribe_error.error_code = error_code;
580 subscribe_error.reason_phrase = reason_phrase;
581 subscribe_error.track_alias = track_alias;
582 SendOrBufferMessage(
583 session_->framer_.SerializeSubscribeError(subscribe_error));
584 }
585
OnSubscribeMessage(const MoqtSubscribe & message)586 void MoqtSession::Stream::OnSubscribeMessage(const MoqtSubscribe& message) {
587 std::string reason_phrase = "";
588 if (!CheckIfIsControlStream()) {
589 return;
590 }
591 if (session_->peer_role_ == MoqtRole::kPublisher) {
592 QUIC_DLOG(INFO) << ENDPOINT << "Publisher peer sent SUBSCRIBE";
593 session_->Error(MoqtError::kProtocolViolation,
594 "Received SUBSCRIBE from publisher");
595 return;
596 }
597 QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE for "
598 << message.track_namespace << ":" << message.track_name;
599 auto it = session_->local_tracks_.find(FullTrackName(
600 std::string(message.track_namespace), std::string(message.track_name)));
601 if (it == session_->local_tracks_.end()) {
602 QUIC_DLOG(INFO) << ENDPOINT << "Rejected because "
603 << message.track_namespace << ":" << message.track_name
604 << " does not exist";
605 SendSubscribeError(message, SubscribeErrorCode::kInternalError,
606 "Track does not exist", message.track_alias);
607 return;
608 }
609 LocalTrack& track = it->second;
610 if ((track.track_alias().has_value() &&
611 message.track_alias != *track.track_alias()) ||
612 session_->used_track_aliases_.contains(message.track_alias)) {
613 // Propose a different track_alias.
614 SendSubscribeError(message, SubscribeErrorCode::kRetryTrackAlias,
615 "Track alias already exists",
616 session_->next_local_track_alias_++);
617 return;
618 } else { // Use client-provided alias.
619 track.set_track_alias(message.track_alias);
620 if (message.track_alias >= session_->next_local_track_alias_) {
621 session_->next_local_track_alias_ = message.track_alias + 1;
622 }
623 session_->used_track_aliases_.insert(message.track_alias);
624 }
625 std::optional<FullSequence> start = session_->LocationToAbsoluteNumber(
626 track, message.start_group, message.start_object);
627 QUICHE_DCHECK(start.has_value()); // Parser enforces this.
628 std::optional<FullSequence> end = session_->LocationToAbsoluteNumber(
629 track, message.end_group, message.end_object);
630 if (start < track.next_sequence() && track.visitor() != nullptr) {
631 // TODO: Rework this. It's not good that the session notifies the
632 // application -- presumably triggering the send of a bunch of objects --
633 // and only then sends the Subscribe OK.
634 SubscribeWindow window =
635 end.has_value()
636 ? SubscribeWindow(message.subscribe_id,
637 track.forwarding_preference(), start->group,
638 start->object, end->group, end->object)
639 : SubscribeWindow(message.subscribe_id,
640 track.forwarding_preference(), start->group,
641 start->object);
642 std::optional<absl::string_view> past_objects_available =
643 track.visitor()->OnSubscribeForPast(window);
644 if (past_objects_available.has_value()) {
645 SendSubscribeError(message, SubscribeErrorCode::kInternalError,
646 "Object does not exist", message.track_alias);
647 return;
648 }
649 }
650 MoqtSubscribeOk subscribe_ok;
651 subscribe_ok.subscribe_id = message.subscribe_id;
652 SendOrBufferMessage(session_->framer_.SerializeSubscribeOk(subscribe_ok));
653 QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for "
654 << message.track_namespace << ":" << message.track_name;
655 if (!end.has_value()) {
656 track.AddWindow(message.subscribe_id, start->group, start->object);
657 return;
658 }
659 track.AddWindow(message.subscribe_id, start->group, start->object, end->group,
660 end->object);
661 }
662
OnSubscribeOkMessage(const MoqtSubscribeOk & message)663 void MoqtSession::Stream::OnSubscribeOkMessage(const MoqtSubscribeOk& message) {
664 if (!CheckIfIsControlStream()) {
665 return;
666 }
667 auto it = session_->active_subscribes_.find(message.subscribe_id);
668 if (it == session_->active_subscribes_.end()) {
669 session_->Error(MoqtError::kProtocolViolation,
670 "Received SUBSCRIBE_OK for nonexistent subscribe");
671 return;
672 }
673 MoqtSubscribe& subscribe = it->second.message;
674 QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_OK for "
675 << "subscribe_id = " << message.subscribe_id << " "
676 << subscribe.track_namespace << ":" << subscribe.track_name;
677 // Copy the Remote Track from session_->active_subscribes_ to
678 // session_->remote_tracks_.
679 FullTrackName ftn(subscribe.track_namespace, subscribe.track_name);
680 RemoteTrack::Visitor* visitor = it->second.visitor;
681 auto [track_iter, new_entry] = session_->remote_tracks_.try_emplace(
682 subscribe.track_alias, ftn, subscribe.track_alias, visitor);
683 if (it->second.forwarding_preference.has_value()) {
684 if (!track_iter->second.CheckForwardingPreference(
685 *it->second.forwarding_preference)) {
686 session_->Error(MoqtError::kProtocolViolation,
687 "Forwarding preference different in early objects");
688 return;
689 }
690 }
691 // TODO: handle expires.
692 if (visitor != nullptr) {
693 visitor->OnReply(ftn, std::nullopt);
694 }
695 session_->active_subscribes_.erase(it);
696 }
697
OnSubscribeErrorMessage(const MoqtSubscribeError & message)698 void MoqtSession::Stream::OnSubscribeErrorMessage(
699 const MoqtSubscribeError& message) {
700 if (!CheckIfIsControlStream()) {
701 return;
702 }
703 auto it = session_->active_subscribes_.find(message.subscribe_id);
704 if (it == session_->active_subscribes_.end()) {
705 session_->Error(MoqtError::kProtocolViolation,
706 "Received SUBSCRIBE_ERROR for nonexistent subscribe");
707 return;
708 }
709 if (it->second.received_object) {
710 session_->Error(MoqtError::kProtocolViolation,
711 "Received SUBSCRIBE_ERROR after object");
712 return;
713 }
714 MoqtSubscribe& subscribe = it->second.message;
715 QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_ERROR for "
716 << "subscribe_id = " << message.subscribe_id << " ("
717 << subscribe.track_namespace << ":" << subscribe.track_name
718 << ")" << ", error = " << static_cast<int>(message.error_code)
719 << " (" << message.reason_phrase << ")";
720 RemoteTrack::Visitor* visitor = it->second.visitor;
721 FullTrackName ftn(subscribe.track_namespace, subscribe.track_name);
722 if (message.error_code == SubscribeErrorCode::kRetryTrackAlias) {
723 // Automatically resubscribe with new alias.
724 session_->remote_track_aliases_[ftn] = message.track_alias;
725 session_->Subscribe(subscribe, visitor);
726 } else if (visitor != nullptr) {
727 visitor->OnReply(ftn, message.reason_phrase);
728 }
729 session_->active_subscribes_.erase(it);
730 }
731
OnUnsubscribeMessage(const MoqtUnsubscribe & message)732 void MoqtSession::Stream::OnUnsubscribeMessage(const MoqtUnsubscribe& message) {
733 // Search all the tracks to find the subscribe ID.
734 for (auto& [name, track] : session_->local_tracks_) {
735 track.DeleteWindow(message.subscribe_id);
736 }
737 // TODO(martinduke): Send SUBSCRIBE_DONE in response.
738 }
739
OnAnnounceMessage(const MoqtAnnounce & message)740 void MoqtSession::Stream::OnAnnounceMessage(const MoqtAnnounce& message) {
741 if (session_->peer_role_ == MoqtRole::kSubscriber) {
742 QUIC_DLOG(INFO) << ENDPOINT << "Subscriber peer sent SUBSCRIBE";
743 session_->Error(MoqtError::kProtocolViolation,
744 "Received ANNOUNCE from Subscriber");
745 return;
746 }
747 if (!CheckIfIsControlStream()) {
748 return;
749 }
750 std::optional<MoqtAnnounceErrorReason> error =
751 session_->callbacks_.incoming_announce_callback(message.track_namespace);
752 if (error.has_value()) {
753 MoqtAnnounceError reply;
754 reply.track_namespace = message.track_namespace;
755 reply.error_code = error->error_code;
756 reply.reason_phrase = error->reason_phrase;
757 SendOrBufferMessage(session_->framer_.SerializeAnnounceError(reply));
758 return;
759 }
760 MoqtAnnounceOk ok;
761 ok.track_namespace = message.track_namespace;
762 SendOrBufferMessage(session_->framer_.SerializeAnnounceOk(ok));
763 }
764
OnAnnounceOkMessage(const MoqtAnnounceOk & message)765 void MoqtSession::Stream::OnAnnounceOkMessage(const MoqtAnnounceOk& message) {
766 if (!CheckIfIsControlStream()) {
767 return;
768 }
769 auto it = session_->pending_outgoing_announces_.find(message.track_namespace);
770 if (it == session_->pending_outgoing_announces_.end()) {
771 session_->Error(MoqtError::kProtocolViolation,
772 "Received ANNOUNCE_OK for nonexistent announce");
773 return;
774 }
775 std::move(it->second)(message.track_namespace, std::nullopt);
776 session_->pending_outgoing_announces_.erase(it);
777 }
778
OnAnnounceErrorMessage(const MoqtAnnounceError & message)779 void MoqtSession::Stream::OnAnnounceErrorMessage(
780 const MoqtAnnounceError& message) {
781 if (!CheckIfIsControlStream()) {
782 return;
783 }
784 auto it = session_->pending_outgoing_announces_.find(message.track_namespace);
785 if (it == session_->pending_outgoing_announces_.end()) {
786 session_->Error(MoqtError::kProtocolViolation,
787 "Received ANNOUNCE_ERROR for nonexistent announce");
788 return;
789 }
790 std::move(it->second)(
791 message.track_namespace,
792 MoqtAnnounceErrorReason{message.error_code,
793 std::string(message.reason_phrase)});
794 session_->pending_outgoing_announces_.erase(it);
795 }
796
OnParsingError(MoqtError error_code,absl::string_view reason)797 void MoqtSession::Stream::OnParsingError(MoqtError error_code,
798 absl::string_view reason) {
799 session_->Error(error_code, absl::StrCat("Parse error: ", reason));
800 }
801
CheckIfIsControlStream()802 bool MoqtSession::Stream::CheckIfIsControlStream() {
803 if (!is_control_stream_.has_value()) {
804 session_->Error(MoqtError::kProtocolViolation,
805 "Received SUBSCRIBE_REQUEST as first message");
806 return false;
807 }
808 if (!*is_control_stream_) {
809 session_->Error(MoqtError::kProtocolViolation,
810 "Received SUBSCRIBE_REQUEST on non-control stream");
811 return false;
812 }
813 return true;
814 }
815
LocationToAbsoluteNumber(const LocalTrack & track,const std::optional<MoqtSubscribeLocation> & group,const std::optional<MoqtSubscribeLocation> & object)816 std::optional<FullSequence> MoqtSession::LocationToAbsoluteNumber(
817 const LocalTrack& track, const std::optional<MoqtSubscribeLocation>& group,
818 const std::optional<MoqtSubscribeLocation>& object) {
819 FullSequence sequence;
820 if (!group.has_value() || !object.has_value()) {
821 return std::nullopt;
822 }
823 if (group->absolute) {
824 sequence.group = group->absolute_value;
825 } else {
826 sequence.group = track.next_sequence().group + group->relative_value;
827 }
828 if (object->absolute) {
829 sequence.object = object->absolute_value;
830 } else {
831 // Subtract 1 because the relative value is computed from the largest sent
832 // sequence number, not the next one.
833 sequence.object = track.next_sequence().object + object->relative_value - 1;
834 }
835 return sequence;
836 }
837
SendOrBufferMessage(quiche::QuicheBuffer message,bool fin)838 void MoqtSession::Stream::SendOrBufferMessage(quiche::QuicheBuffer message,
839 bool fin) {
840 quiche::StreamWriteOptions options;
841 options.set_send_fin(fin);
842 options.set_buffer_unconditionally(true);
843 std::array<absl::string_view, 1> write_vector = {message.AsStringView()};
844 absl::Status success = stream_->Writev(absl::MakeSpan(write_vector), options);
845 if (!success.ok()) {
846 session_->Error(MoqtError::kInternalError,
847 "Failed to write a control message");
848 }
849 }
850
851 } // namespace moqt
852