xref: /aosp_15_r20/external/cronet/net/third_party/quiche/src/quiche/quic/moqt/moqt_session.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2023 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "quiche/quic/moqt/moqt_session.h"
6 
7 #include <array>
8 #include <cstdint>
9 #include <memory>
10 #include <optional>
11 #include <string>
12 #include <utility>
13 #include <vector>
14 
15 
16 #include "absl/algorithm/container.h"
17 #include "absl/status/status.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/string_view.h"
20 #include "absl/types/span.h"
21 #include "quiche/quic/core/quic_types.h"
22 #include "quiche/quic/moqt/moqt_messages.h"
23 #include "quiche/quic/moqt/moqt_parser.h"
24 #include "quiche/quic/moqt/moqt_subscribe_windows.h"
25 #include "quiche/quic/moqt/moqt_track.h"
26 #include "quiche/quic/platform/api/quic_bug_tracker.h"
27 #include "quiche/common/platform/api/quiche_logging.h"
28 #include "quiche/common/quiche_buffer_allocator.h"
29 #include "quiche/common/quiche_stream.h"
30 #include "quiche/web_transport/web_transport.h"
31 
// Log prefix naming which side of the MoQT session emitted the message. Only
// usable inside member functions that provide perspective().
#define ENDPOINT \
  (perspective() != Perspective::IS_SERVER ? "MoQT Client: " : "MoQT Server: ")
34 
35 namespace moqt {
36 
37 using ::quic::Perspective;
38 
GetControlStream()39 MoqtSession::Stream* MoqtSession::GetControlStream() {
40   if (!control_stream_.has_value()) {
41     return nullptr;
42   }
43   webtransport::Stream* raw_stream = session_->GetStreamById(*control_stream_);
44   if (raw_stream == nullptr) {
45     return nullptr;
46   }
47   return static_cast<Stream*>(raw_stream->visitor());
48 }
49 
SendControlMessage(quiche::QuicheBuffer message)50 void MoqtSession::SendControlMessage(quiche::QuicheBuffer message) {
51   Stream* control_stream = GetControlStream();
52   if (control_stream == nullptr) {
53     QUICHE_LOG(DFATAL) << "Trying to send a message on the control stream "
54                           "while it does not exist";
55     return;
56   }
57   control_stream->SendOrBufferMessage(std::move(message));
58 }
59 
OnSessionReady()60 void MoqtSession::OnSessionReady() {
61   QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session ready";
62   if (parameters_.perspective == Perspective::IS_SERVER) {
63     return;
64   }
65 
66   webtransport::Stream* control_stream =
67       session_->OpenOutgoingBidirectionalStream();
68   if (control_stream == nullptr) {
69     Error(MoqtError::kInternalError, "Unable to open a control stream");
70     return;
71   }
72   control_stream->SetVisitor(std::make_unique<Stream>(
73       this, control_stream, /*is_control_stream=*/true));
74   control_stream_ = control_stream->GetStreamId();
75   MoqtClientSetup setup = MoqtClientSetup{
76       .supported_versions = std::vector<MoqtVersion>{parameters_.version},
77       .role = MoqtRole::kPubSub,
78   };
79   if (!parameters_.using_webtrans) {
80     setup.path = parameters_.path;
81   }
82   SendControlMessage(framer_.SerializeClientSetup(setup));
83   QUIC_DLOG(INFO) << ENDPOINT << "Send the SETUP message";
84 }
85 
OnSessionClosed(webtransport::SessionErrorCode,const std::string & error_message)86 void MoqtSession::OnSessionClosed(webtransport::SessionErrorCode,
87                                   const std::string& error_message) {
88   if (!error_.empty()) {
89     // Avoid erroring out twice.
90     return;
91   }
92   QUICHE_DLOG(INFO) << ENDPOINT << "Underlying session closed with message: "
93                     << error_message;
94   error_ = error_message;
95   std::move(callbacks_.session_terminated_callback)(error_message);
96 }
97 
OnIncomingBidirectionalStreamAvailable()98 void MoqtSession::OnIncomingBidirectionalStreamAvailable() {
99   while (webtransport::Stream* stream =
100              session_->AcceptIncomingBidirectionalStream()) {
101     if (control_stream_.has_value()) {
102       Error(MoqtError::kProtocolViolation, "Bidirectional stream already open");
103       return;
104     }
105     stream->SetVisitor(std::make_unique<Stream>(this, stream));
106     stream->visitor()->OnCanRead();
107   }
108 }
OnIncomingUnidirectionalStreamAvailable()109 void MoqtSession::OnIncomingUnidirectionalStreamAvailable() {
110   while (webtransport::Stream* stream =
111              session_->AcceptIncomingUnidirectionalStream()) {
112     stream->SetVisitor(std::make_unique<Stream>(this, stream));
113     stream->visitor()->OnCanRead();
114   }
115 }
116 
OnDatagramReceived(absl::string_view datagram)117 void MoqtSession::OnDatagramReceived(absl::string_view datagram) {
118   MoqtObject message;
119   absl::string_view payload = MoqtParser::ProcessDatagram(datagram, message);
120   if (payload.empty()) {
121     Error(MoqtError::kProtocolViolation, "Malformed datagram");
122     return;
123   }
124   QUICHE_DLOG(INFO) << ENDPOINT
125                     << "Received OBJECT message in datagram for subscribe_id "
126                     << message.subscribe_id << " for track alias "
127                     << message.track_alias << " with sequence "
128                     << message.group_id << ":" << message.object_id
129                     << " send_order " << message.object_send_order << " length "
130                     << payload.size();
131   auto [full_track_name, visitor] = TrackPropertiesFromAlias(message);
132   if (visitor != nullptr) {
133     visitor->OnObjectFragment(full_track_name, message.group_id,
134                               message.object_id, message.object_send_order,
135                               message.forwarding_preference, payload, true);
136   }
137 }
138 
Error(MoqtError code,absl::string_view error)139 void MoqtSession::Error(MoqtError code, absl::string_view error) {
140   if (!error_.empty()) {
141     // Avoid erroring out twice.
142     return;
143   }
144   QUICHE_DLOG(INFO) << ENDPOINT << "MOQT session closed with code: "
145                     << static_cast<int>(code) << " and message: " << error;
146   error_ = std::string(error);
147   session_->CloseSession(static_cast<uint64_t>(code), error);
148   std::move(callbacks_.session_terminated_callback)(error);
149 }
150 
AddLocalTrack(const FullTrackName & full_track_name,MoqtForwardingPreference forwarding_preference,LocalTrack::Visitor * visitor)151 void MoqtSession::AddLocalTrack(const FullTrackName& full_track_name,
152                                 MoqtForwardingPreference forwarding_preference,
153                                 LocalTrack::Visitor* visitor) {
154   local_tracks_.try_emplace(full_track_name, full_track_name,
155                             forwarding_preference, visitor);
156 }
157 
158 // TODO: Create state that allows ANNOUNCE_OK/ERROR on spurious namespaces to
159 // trigger session errors.
Announce(absl::string_view track_namespace,MoqtOutgoingAnnounceCallback announce_callback)160 void MoqtSession::Announce(absl::string_view track_namespace,
161                            MoqtOutgoingAnnounceCallback announce_callback) {
162   if (peer_role_ == MoqtRole::kPublisher) {
163     std::move(announce_callback)(
164         track_namespace,
165         MoqtAnnounceErrorReason{MoqtAnnounceErrorCode::kInternalError,
166                                 "ANNOUNCE cannot be sent to Publisher"});
167     return;
168   }
169   if (pending_outgoing_announces_.contains(track_namespace)) {
170     std::move(announce_callback)(
171         track_namespace,
172         MoqtAnnounceErrorReason{
173             MoqtAnnounceErrorCode::kInternalError,
174             "ANNOUNCE message already outstanding for namespace"});
175     return;
176   }
177   MoqtAnnounce message;
178   message.track_namespace = track_namespace;
179   SendControlMessage(framer_.SerializeAnnounce(message));
180   QUIC_DLOG(INFO) << ENDPOINT << "Sent ANNOUNCE message for "
181                   << message.track_namespace;
182   pending_outgoing_announces_[track_namespace] = std::move(announce_callback);
183 }
184 
HasSubscribers(const FullTrackName & full_track_name) const185 bool MoqtSession::HasSubscribers(const FullTrackName& full_track_name) const {
186   auto it = local_tracks_.find(full_track_name);
187   return (it != local_tracks_.end() && it->second.HasSubscriber());
188 }
189 
SubscribeAbsolute(absl::string_view track_namespace,absl::string_view name,uint64_t start_group,uint64_t start_object,RemoteTrack::Visitor * visitor,absl::string_view auth_info)190 bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace,
191                                     absl::string_view name,
192                                     uint64_t start_group, uint64_t start_object,
193                                     RemoteTrack::Visitor* visitor,
194                                     absl::string_view auth_info) {
195   MoqtSubscribe message;
196   message.track_namespace = track_namespace;
197   message.track_name = name;
198   message.start_group = MoqtSubscribeLocation(true, start_group);
199   message.start_object = MoqtSubscribeLocation(true, start_object);
200   message.end_group = std::nullopt;
201   message.end_object = std::nullopt;
202   if (!auth_info.empty()) {
203     message.authorization_info = std::move(auth_info);
204   }
205   return Subscribe(message, visitor);
206 }
207 
SubscribeAbsolute(absl::string_view track_namespace,absl::string_view name,uint64_t start_group,uint64_t start_object,uint64_t end_group,uint64_t end_object,RemoteTrack::Visitor * visitor,absl::string_view auth_info)208 bool MoqtSession::SubscribeAbsolute(absl::string_view track_namespace,
209                                     absl::string_view name,
210                                     uint64_t start_group, uint64_t start_object,
211                                     uint64_t end_group, uint64_t end_object,
212                                     RemoteTrack::Visitor* visitor,
213                                     absl::string_view auth_info) {
214   if (end_group < start_group) {
215     QUIC_DLOG(ERROR) << "Subscription end is before beginning";
216     return false;
217   }
218   if (end_group == start_group && end_object < start_object) {
219     QUIC_DLOG(ERROR) << "Subscription end is before beginning";
220     return false;
221   }
222   MoqtSubscribe message;
223   message.track_namespace = track_namespace;
224   message.track_name = name;
225   message.start_group = MoqtSubscribeLocation(true, start_group);
226   message.start_object = MoqtSubscribeLocation(true, start_object);
227   message.end_group = MoqtSubscribeLocation(true, end_group);
228   message.end_object = MoqtSubscribeLocation(true, end_object);
229   if (!auth_info.empty()) {
230     message.authorization_info = std::move(auth_info);
231   }
232   return Subscribe(message, visitor);
233 }
234 
SubscribeRelative(absl::string_view track_namespace,absl::string_view name,int64_t start_group,int64_t start_object,RemoteTrack::Visitor * visitor,absl::string_view auth_info)235 bool MoqtSession::SubscribeRelative(absl::string_view track_namespace,
236                                     absl::string_view name, int64_t start_group,
237                                     int64_t start_object,
238                                     RemoteTrack::Visitor* visitor,
239                                     absl::string_view auth_info) {
240   MoqtSubscribe message;
241   message.track_namespace = track_namespace;
242   message.track_name = name;
243   message.start_group = MoqtSubscribeLocation(false, start_group);
244   message.start_object = MoqtSubscribeLocation(false, start_object);
245   message.end_group = std::nullopt;
246   message.end_object = std::nullopt;
247   if (!auth_info.empty()) {
248     message.authorization_info = std::move(auth_info);
249   }
250   return Subscribe(message, visitor);
251 }
252 
SubscribeCurrentGroup(absl::string_view track_namespace,absl::string_view name,RemoteTrack::Visitor * visitor,absl::string_view auth_info)253 bool MoqtSession::SubscribeCurrentGroup(absl::string_view track_namespace,
254                                         absl::string_view name,
255                                         RemoteTrack::Visitor* visitor,
256                                         absl::string_view auth_info) {
257   MoqtSubscribe message;
258   message.track_namespace = track_namespace;
259   message.track_name = name;
260   // First object of current group.
261   message.start_group = MoqtSubscribeLocation(false, (uint64_t)0);
262   message.start_object = MoqtSubscribeLocation(true, (int64_t)0);
263   message.end_group = std::nullopt;
264   message.end_object = std::nullopt;
265   if (!auth_info.empty()) {
266     message.authorization_info = std::move(auth_info);
267   }
268   return Subscribe(message, visitor);
269 }
270 
Subscribe(MoqtSubscribe & message,RemoteTrack::Visitor * visitor)271 bool MoqtSession::Subscribe(MoqtSubscribe& message,
272                             RemoteTrack::Visitor* visitor) {
273   if (peer_role_ == MoqtRole::kSubscriber) {
274     QUIC_DLOG(INFO) << ENDPOINT << "Tried to send SUBSCRIBE to subscriber peer";
275     return false;
276   }
277   // TODO(martinduke): support authorization info
278   message.subscribe_id = next_subscribe_id_++;
279   FullTrackName ftn(std::string(message.track_namespace),
280                     std::string(message.track_name));
281   auto it = remote_track_aliases_.find(ftn);
282   if (it != remote_track_aliases_.end()) {
283     message.track_alias = it->second;
284     if (message.track_alias >= next_remote_track_alias_) {
285       next_remote_track_alias_ = message.track_alias + 1;
286     }
287   } else {
288     message.track_alias = next_remote_track_alias_++;
289   }
290   SendControlMessage(framer_.SerializeSubscribe(message));
291   QUIC_DLOG(INFO) << ENDPOINT << "Sent SUBSCRIBE message for "
292                   << message.track_namespace << ":" << message.track_name;
293   active_subscribes_.try_emplace(message.subscribe_id, message, visitor);
294   return true;
295 }
296 
OpenUnidirectionalStream()297 std::optional<webtransport::StreamId> MoqtSession::OpenUnidirectionalStream() {
298   if (!session_->CanOpenNextOutgoingUnidirectionalStream()) {
299     return std::nullopt;
300   }
301   webtransport::Stream* new_stream =
302       session_->OpenOutgoingUnidirectionalStream();
303   if (new_stream == nullptr) {
304     return std::nullopt;
305   }
306   new_stream->SetVisitor(std::make_unique<Stream>(this, new_stream, false));
307   return new_stream->GetStreamId();
308 }
309 
310 std::pair<FullTrackName, RemoteTrack::Visitor*>
TrackPropertiesFromAlias(const MoqtObject & message)311 MoqtSession::TrackPropertiesFromAlias(const MoqtObject& message) {
312   auto it = remote_tracks_.find(message.track_alias);
313   RemoteTrack::Visitor* visitor = nullptr;
314   if (it == remote_tracks_.end()) {
315     // SUBSCRIBE_OK has not arrived yet, but deliver it.
316     auto subscribe_it = active_subscribes_.find(message.subscribe_id);
317     if (subscribe_it == active_subscribes_.end()) {
318       return std::pair<FullTrackName, RemoteTrack::Visitor*>(
319           {{"", ""}, nullptr});
320     }
321     ActiveSubscribe& subscribe = subscribe_it->second;
322     visitor = subscribe.visitor;
323     subscribe.received_object = true;
324     if (subscribe.forwarding_preference.has_value()) {
325       if (message.forwarding_preference != *subscribe.forwarding_preference) {
326         Error(MoqtError::kProtocolViolation,
327               "Forwarding preference changes mid-track");
328         return std::pair<FullTrackName, RemoteTrack::Visitor*>(
329             {{"", ""}, nullptr});
330       }
331     } else {
332       subscribe.forwarding_preference = message.forwarding_preference;
333     }
334     return std::pair<FullTrackName, RemoteTrack::Visitor*>(
335         {{subscribe.message.track_namespace, subscribe.message.track_name},
336          subscribe.visitor});
337   }
338   RemoteTrack& track = it->second;
339   if (!track.CheckForwardingPreference(message.forwarding_preference)) {
340     // Incorrect forwarding preference.
341     Error(MoqtError::kProtocolViolation,
342           "Forwarding preference changes mid-track");
343     return std::pair<FullTrackName, RemoteTrack::Visitor*>({{"", ""}, nullptr});
344   }
345   return std::pair<FullTrackName, RemoteTrack::Visitor*>(
346       {{track.full_track_name().track_namespace,
347         track.full_track_name().track_name},
348        track.visitor()});
349 }
350 
PublishObject(const FullTrackName & full_track_name,uint64_t group_id,uint64_t object_id,uint64_t object_send_order,absl::string_view payload,bool end_of_stream)351 bool MoqtSession::PublishObject(const FullTrackName& full_track_name,
352                                 uint64_t group_id, uint64_t object_id,
353                                 uint64_t object_send_order,
354                                 absl::string_view payload, bool end_of_stream) {
355   auto track_it = local_tracks_.find(full_track_name);
356   if (track_it == local_tracks_.end()) {
357     QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT for nonexistent track";
358     return false;
359   }
360   LocalTrack& track = track_it->second;
361   MoqtForwardingPreference forwarding_preference =
362       track.forwarding_preference();
363   if ((forwarding_preference == MoqtForwardingPreference::kObject ||
364        forwarding_preference == MoqtForwardingPreference::kDatagram) &&
365       !end_of_stream) {
366     QUIC_BUG(MoqtSession_PublishObject_end_of_stream_required)
367         << "Forwarding preferences of Object or Datagram require stream to be "
368            "immediately closed";
369     return false;
370   }
371   track.SentSequence(FullSequence(group_id, object_id));
372   std::vector<SubscribeWindow*> subscriptions =
373       track.ShouldSend({group_id, object_id});
374   if (subscriptions.empty()) {
375     return true;
376   }
377   MoqtObject object;
378   QUICHE_DCHECK(track.track_alias().has_value());
379   object.track_alias = *track.track_alias();
380   object.group_id = group_id;
381   object.object_id = object_id;
382   object.object_send_order = object_send_order;
383   object.forwarding_preference = forwarding_preference;
384   object.payload_length = payload.size();
385   int failures = 0;
386   quiche::StreamWriteOptions write_options;
387   write_options.set_send_fin(end_of_stream);
388   for (auto subscription : subscriptions) {
389     if (forwarding_preference == MoqtForwardingPreference::kDatagram) {
390       object.subscribe_id = subscription->subscribe_id();
391       quiche::QuicheBuffer datagram =
392           framer_.SerializeObjectDatagram(object, payload);
393       // TODO(martinduke): It's OK to just silently fail, but better to notify
394       // the app on errors.
395       session_->SendOrQueueDatagram(datagram.AsStringView());
396       continue;
397     }
398     bool new_stream = false;
399     std::optional<webtransport::StreamId> stream_id =
400         subscription->GetStreamForSequence(FullSequence(group_id, object_id));
401     if (!stream_id.has_value()) {
402       new_stream = true;
403       stream_id = OpenUnidirectionalStream();
404       if (!stream_id.has_value()) {
405         QUICHE_DLOG(ERROR) << ENDPOINT
406                            << "Sending OBJECT to nonexistent stream";
407         ++failures;
408         continue;
409       }
410       if (!end_of_stream) {
411         subscription->AddStream(group_id, object_id, *stream_id);
412       }
413     }
414     webtransport::Stream* stream = session_->GetStreamById(*stream_id);
415     if (stream == nullptr) {
416       QUICHE_DLOG(ERROR) << ENDPOINT << "Sending OBJECT to nonexistent stream "
417                          << *stream_id;
418       ++failures;
419       continue;
420     }
421     object.subscribe_id = subscription->subscribe_id();
422     quiche::QuicheBuffer header =
423         framer_.SerializeObjectHeader(object, new_stream);
424     std::array<absl::string_view, 2> views = {header.AsStringView(), payload};
425     if (!stream->Writev(views, write_options).ok()) {
426       QUICHE_DLOG(ERROR) << ENDPOINT << "Failed to write OBJECT message";
427       ++failures;
428       continue;
429     }
430     QUICHE_LOG(INFO) << ENDPOINT << "Sending object length " << payload.length()
431                      << " for " << full_track_name.track_namespace << ":"
432                      << full_track_name.track_name << " with sequence "
433                      << object.group_id << ":" << object.object_id
434                      << " on stream " << *stream_id;
435     if (end_of_stream && !new_stream) {
436       subscription->RemoveStream(group_id, object_id);
437     }
438   }
439   return (failures == 0);
440 }
441 
OnCanRead()442 void MoqtSession::Stream::OnCanRead() {
443   bool fin =
444       quiche::ProcessAllReadableRegions(*stream_, [&](absl::string_view chunk) {
445         parser_.ProcessData(chunk, /*end_of_stream=*/false);
446       });
447   if (fin) {
448     parser_.ProcessData("", /*end_of_stream=*/true);
449   }
450 }
OnCanWrite()451 void MoqtSession::Stream::OnCanWrite() {}
OnResetStreamReceived(webtransport::StreamErrorCode error)452 void MoqtSession::Stream::OnResetStreamReceived(
453     webtransport::StreamErrorCode error) {
454   if (is_control_stream_.has_value() && *is_control_stream_) {
455     session_->Error(
456         MoqtError::kProtocolViolation,
457         absl::StrCat("Control stream reset with error code ", error));
458   }
459 }
OnStopSendingReceived(webtransport::StreamErrorCode error)460 void MoqtSession::Stream::OnStopSendingReceived(
461     webtransport::StreamErrorCode error) {
462   if (is_control_stream_.has_value() && *is_control_stream_) {
463     session_->Error(
464         MoqtError::kProtocolViolation,
465         absl::StrCat("Control stream reset with error code ", error));
466   }
467 }
468 
OnObjectMessage(const MoqtObject & message,absl::string_view payload,bool end_of_message)469 void MoqtSession::Stream::OnObjectMessage(const MoqtObject& message,
470                                           absl::string_view payload,
471                                           bool end_of_message) {
472   if (is_control_stream_ == true) {
473     session_->Error(MoqtError::kProtocolViolation,
474                     "Received OBJECT message on control stream");
475     return;
476   }
477   QUICHE_DLOG(INFO)
478       << ENDPOINT << "Received OBJECT message on stream "
479       << stream_->GetStreamId() << " for subscribe_id " << message.subscribe_id
480       << " for track alias " << message.track_alias << " with sequence "
481       << message.group_id << ":" << message.object_id << " send_order "
482       << message.object_send_order << " forwarding_preference "
483       << MoqtForwardingPreferenceToString(message.forwarding_preference)
484       << " length " << payload.size() << " explicit length "
485       << (message.payload_length.has_value() ? (int)*message.payload_length
486                                              : -1)
487       << (end_of_message ? "F" : "");
488   if (!session_->parameters_.deliver_partial_objects) {
489     if (!end_of_message) {  // Buffer partial object.
490       absl::StrAppend(&partial_object_, payload);
491       return;
492     }
493     if (!partial_object_.empty()) {  // Completes the object
494       absl::StrAppend(&partial_object_, payload);
495       payload = absl::string_view(partial_object_);
496     }
497   }
498   auto [full_track_name, visitor] = session_->TrackPropertiesFromAlias(message);
499   if (visitor != nullptr) {
500     visitor->OnObjectFragment(full_track_name, message.group_id,
501                               message.object_id, message.object_send_order,
502                               message.forwarding_preference, payload,
503                               end_of_message);
504   }
505   partial_object_.clear();
506 }
507 
OnClientSetupMessage(const MoqtClientSetup & message)508 void MoqtSession::Stream::OnClientSetupMessage(const MoqtClientSetup& message) {
509   if (is_control_stream_.has_value()) {
510     if (!*is_control_stream_) {
511       session_->Error(MoqtError::kProtocolViolation,
512                       "Received SETUP on non-control stream");
513       return;
514     }
515   } else {
516     is_control_stream_ = true;
517   }
518   session_->control_stream_ = stream_->GetStreamId();
519   if (perspective() == Perspective::IS_CLIENT) {
520     session_->Error(MoqtError::kProtocolViolation,
521                     "Received CLIENT_SETUP from server");
522     return;
523   }
524   if (absl::c_find(message.supported_versions, session_->parameters_.version) ==
525       message.supported_versions.end()) {
526     // TODO(martinduke): Is this the right error code? See issue #346.
527     session_->Error(MoqtError::kProtocolViolation,
528                     absl::StrCat("Version mismatch: expected 0x",
529                                  absl::Hex(session_->parameters_.version)));
530     return;
531   }
532   QUICHE_DLOG(INFO) << ENDPOINT << "Received the SETUP message";
533   if (session_->parameters_.perspective == Perspective::IS_SERVER) {
534     MoqtServerSetup response;
535     response.selected_version = session_->parameters_.version;
536     response.role = MoqtRole::kPubSub;
537     SendOrBufferMessage(session_->framer_.SerializeServerSetup(response));
538     QUIC_DLOG(INFO) << ENDPOINT << "Sent the SETUP message";
539   }
540   // TODO: handle role and path.
541   std::move(session_->callbacks_.session_established_callback)();
542   session_->peer_role_ = *message.role;
543 }
544 
OnServerSetupMessage(const MoqtServerSetup & message)545 void MoqtSession::Stream::OnServerSetupMessage(const MoqtServerSetup& message) {
546   if (is_control_stream_.has_value()) {
547     if (!*is_control_stream_) {
548       session_->Error(MoqtError::kProtocolViolation,
549                       "Received SETUP on non-control stream");
550       return;
551     }
552   } else {
553     is_control_stream_ = true;
554   }
555   if (perspective() == Perspective::IS_SERVER) {
556     session_->Error(MoqtError::kProtocolViolation,
557                     "Received SERVER_SETUP from client");
558     return;
559   }
560   if (message.selected_version != session_->parameters_.version) {
561     // TODO(martinduke): Is this the right error code? See issue #346.
562     session_->Error(MoqtError::kProtocolViolation,
563                     absl::StrCat("Version mismatch: expected 0x",
564                                  absl::Hex(session_->parameters_.version)));
565     return;
566   }
567   QUIC_DLOG(INFO) << ENDPOINT << "Received the SETUP message";
568   // TODO: handle role and path.
569   std::move(session_->callbacks_.session_established_callback)();
570   session_->peer_role_ = *message.role;
571 }
572 
SendSubscribeError(const MoqtSubscribe & message,SubscribeErrorCode error_code,absl::string_view reason_phrase,uint64_t track_alias)573 void MoqtSession::Stream::SendSubscribeError(const MoqtSubscribe& message,
574                                              SubscribeErrorCode error_code,
575                                              absl::string_view reason_phrase,
576                                              uint64_t track_alias) {
577   MoqtSubscribeError subscribe_error;
578   subscribe_error.subscribe_id = message.subscribe_id;
579   subscribe_error.error_code = error_code;
580   subscribe_error.reason_phrase = reason_phrase;
581   subscribe_error.track_alias = track_alias;
582   SendOrBufferMessage(
583       session_->framer_.SerializeSubscribeError(subscribe_error));
584 }
585 
OnSubscribeMessage(const MoqtSubscribe & message)586 void MoqtSession::Stream::OnSubscribeMessage(const MoqtSubscribe& message) {
587   std::string reason_phrase = "";
588   if (!CheckIfIsControlStream()) {
589     return;
590   }
591   if (session_->peer_role_ == MoqtRole::kPublisher) {
592     QUIC_DLOG(INFO) << ENDPOINT << "Publisher peer sent SUBSCRIBE";
593     session_->Error(MoqtError::kProtocolViolation,
594                     "Received SUBSCRIBE from publisher");
595     return;
596   }
597   QUIC_DLOG(INFO) << ENDPOINT << "Received a SUBSCRIBE for "
598                   << message.track_namespace << ":" << message.track_name;
599   auto it = session_->local_tracks_.find(FullTrackName(
600       std::string(message.track_namespace), std::string(message.track_name)));
601   if (it == session_->local_tracks_.end()) {
602     QUIC_DLOG(INFO) << ENDPOINT << "Rejected because "
603                     << message.track_namespace << ":" << message.track_name
604                     << " does not exist";
605     SendSubscribeError(message, SubscribeErrorCode::kInternalError,
606                        "Track does not exist", message.track_alias);
607     return;
608   }
609   LocalTrack& track = it->second;
610   if ((track.track_alias().has_value() &&
611        message.track_alias != *track.track_alias()) ||
612       session_->used_track_aliases_.contains(message.track_alias)) {
613     // Propose a different track_alias.
614     SendSubscribeError(message, SubscribeErrorCode::kRetryTrackAlias,
615                        "Track alias already exists",
616                        session_->next_local_track_alias_++);
617     return;
618   } else {  // Use client-provided alias.
619     track.set_track_alias(message.track_alias);
620     if (message.track_alias >= session_->next_local_track_alias_) {
621       session_->next_local_track_alias_ = message.track_alias + 1;
622     }
623     session_->used_track_aliases_.insert(message.track_alias);
624   }
625   std::optional<FullSequence> start = session_->LocationToAbsoluteNumber(
626       track, message.start_group, message.start_object);
627   QUICHE_DCHECK(start.has_value());  // Parser enforces this.
628   std::optional<FullSequence> end = session_->LocationToAbsoluteNumber(
629       track, message.end_group, message.end_object);
630   if (start < track.next_sequence() && track.visitor() != nullptr) {
631     // TODO: Rework this. It's not good that the session notifies the
632     // application -- presumably triggering the send of a bunch of objects --
633     // and only then sends the Subscribe OK.
634     SubscribeWindow window =
635         end.has_value()
636             ? SubscribeWindow(message.subscribe_id,
637                               track.forwarding_preference(), start->group,
638                               start->object, end->group, end->object)
639             : SubscribeWindow(message.subscribe_id,
640                               track.forwarding_preference(), start->group,
641                               start->object);
642     std::optional<absl::string_view> past_objects_available =
643         track.visitor()->OnSubscribeForPast(window);
644     if (past_objects_available.has_value()) {
645       SendSubscribeError(message, SubscribeErrorCode::kInternalError,
646                          "Object does not exist", message.track_alias);
647       return;
648     }
649   }
650   MoqtSubscribeOk subscribe_ok;
651   subscribe_ok.subscribe_id = message.subscribe_id;
652   SendOrBufferMessage(session_->framer_.SerializeSubscribeOk(subscribe_ok));
653   QUIC_DLOG(INFO) << ENDPOINT << "Created subscription for "
654                   << message.track_namespace << ":" << message.track_name;
655   if (!end.has_value()) {
656     track.AddWindow(message.subscribe_id, start->group, start->object);
657     return;
658   }
659   track.AddWindow(message.subscribe_id, start->group, start->object, end->group,
660                   end->object);
661 }
662 
OnSubscribeOkMessage(const MoqtSubscribeOk & message)663 void MoqtSession::Stream::OnSubscribeOkMessage(const MoqtSubscribeOk& message) {
664   if (!CheckIfIsControlStream()) {
665     return;
666   }
667   auto it = session_->active_subscribes_.find(message.subscribe_id);
668   if (it == session_->active_subscribes_.end()) {
669     session_->Error(MoqtError::kProtocolViolation,
670                     "Received SUBSCRIBE_OK for nonexistent subscribe");
671     return;
672   }
673   MoqtSubscribe& subscribe = it->second.message;
674   QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_OK for "
675                   << "subscribe_id = " << message.subscribe_id << " "
676                   << subscribe.track_namespace << ":" << subscribe.track_name;
677   // Copy the Remote Track from session_->active_subscribes_ to
678   // session_->remote_tracks_.
679   FullTrackName ftn(subscribe.track_namespace, subscribe.track_name);
680   RemoteTrack::Visitor* visitor = it->second.visitor;
681   auto [track_iter, new_entry] = session_->remote_tracks_.try_emplace(
682       subscribe.track_alias, ftn, subscribe.track_alias, visitor);
683   if (it->second.forwarding_preference.has_value()) {
684     if (!track_iter->second.CheckForwardingPreference(
685             *it->second.forwarding_preference)) {
686       session_->Error(MoqtError::kProtocolViolation,
687                       "Forwarding preference different in early objects");
688       return;
689     }
690   }
691   // TODO: handle expires.
692   if (visitor != nullptr) {
693     visitor->OnReply(ftn, std::nullopt);
694   }
695   session_->active_subscribes_.erase(it);
696 }
697 
OnSubscribeErrorMessage(const MoqtSubscribeError & message)698 void MoqtSession::Stream::OnSubscribeErrorMessage(
699     const MoqtSubscribeError& message) {
700   if (!CheckIfIsControlStream()) {
701     return;
702   }
703   auto it = session_->active_subscribes_.find(message.subscribe_id);
704   if (it == session_->active_subscribes_.end()) {
705     session_->Error(MoqtError::kProtocolViolation,
706                     "Received SUBSCRIBE_ERROR for nonexistent subscribe");
707     return;
708   }
709   if (it->second.received_object) {
710     session_->Error(MoqtError::kProtocolViolation,
711                     "Received SUBSCRIBE_ERROR after object");
712     return;
713   }
714   MoqtSubscribe& subscribe = it->second.message;
715   QUIC_DLOG(INFO) << ENDPOINT << "Received the SUBSCRIBE_ERROR for "
716                   << "subscribe_id = " << message.subscribe_id << " ("
717                   << subscribe.track_namespace << ":" << subscribe.track_name
718                   << ")" << ", error = " << static_cast<int>(message.error_code)
719                   << " (" << message.reason_phrase << ")";
720   RemoteTrack::Visitor* visitor = it->second.visitor;
721   FullTrackName ftn(subscribe.track_namespace, subscribe.track_name);
722   if (message.error_code == SubscribeErrorCode::kRetryTrackAlias) {
723     // Automatically resubscribe with new alias.
724     session_->remote_track_aliases_[ftn] = message.track_alias;
725     session_->Subscribe(subscribe, visitor);
726   } else if (visitor != nullptr) {
727     visitor->OnReply(ftn, message.reason_phrase);
728   }
729   session_->active_subscribes_.erase(it);
730 }
731 
OnUnsubscribeMessage(const MoqtUnsubscribe & message)732 void MoqtSession::Stream::OnUnsubscribeMessage(const MoqtUnsubscribe& message) {
733   // Search all the tracks to find the subscribe ID.
734   for (auto& [name, track] : session_->local_tracks_) {
735     track.DeleteWindow(message.subscribe_id);
736   }
737   // TODO(martinduke): Send SUBSCRIBE_DONE in response.
738 }
739 
OnAnnounceMessage(const MoqtAnnounce & message)740 void MoqtSession::Stream::OnAnnounceMessage(const MoqtAnnounce& message) {
741   if (session_->peer_role_ == MoqtRole::kSubscriber) {
742     QUIC_DLOG(INFO) << ENDPOINT << "Subscriber peer sent SUBSCRIBE";
743     session_->Error(MoqtError::kProtocolViolation,
744                     "Received ANNOUNCE from Subscriber");
745     return;
746   }
747   if (!CheckIfIsControlStream()) {
748     return;
749   }
750   std::optional<MoqtAnnounceErrorReason> error =
751       session_->callbacks_.incoming_announce_callback(message.track_namespace);
752   if (error.has_value()) {
753     MoqtAnnounceError reply;
754     reply.track_namespace = message.track_namespace;
755     reply.error_code = error->error_code;
756     reply.reason_phrase = error->reason_phrase;
757     SendOrBufferMessage(session_->framer_.SerializeAnnounceError(reply));
758     return;
759   }
760   MoqtAnnounceOk ok;
761   ok.track_namespace = message.track_namespace;
762   SendOrBufferMessage(session_->framer_.SerializeAnnounceOk(ok));
763 }
764 
OnAnnounceOkMessage(const MoqtAnnounceOk & message)765 void MoqtSession::Stream::OnAnnounceOkMessage(const MoqtAnnounceOk& message) {
766   if (!CheckIfIsControlStream()) {
767     return;
768   }
769   auto it = session_->pending_outgoing_announces_.find(message.track_namespace);
770   if (it == session_->pending_outgoing_announces_.end()) {
771     session_->Error(MoqtError::kProtocolViolation,
772                     "Received ANNOUNCE_OK for nonexistent announce");
773     return;
774   }
775   std::move(it->second)(message.track_namespace, std::nullopt);
776   session_->pending_outgoing_announces_.erase(it);
777 }
778 
OnAnnounceErrorMessage(const MoqtAnnounceError & message)779 void MoqtSession::Stream::OnAnnounceErrorMessage(
780     const MoqtAnnounceError& message) {
781   if (!CheckIfIsControlStream()) {
782     return;
783   }
784   auto it = session_->pending_outgoing_announces_.find(message.track_namespace);
785   if (it == session_->pending_outgoing_announces_.end()) {
786     session_->Error(MoqtError::kProtocolViolation,
787                     "Received ANNOUNCE_ERROR for nonexistent announce");
788     return;
789   }
790   std::move(it->second)(
791       message.track_namespace,
792       MoqtAnnounceErrorReason{message.error_code,
793                               std::string(message.reason_phrase)});
794   session_->pending_outgoing_announces_.erase(it);
795 }
796 
OnParsingError(MoqtError error_code,absl::string_view reason)797 void MoqtSession::Stream::OnParsingError(MoqtError error_code,
798                                          absl::string_view reason) {
799   session_->Error(error_code, absl::StrCat("Parse error: ", reason));
800 }
801 
CheckIfIsControlStream()802 bool MoqtSession::Stream::CheckIfIsControlStream() {
803   if (!is_control_stream_.has_value()) {
804     session_->Error(MoqtError::kProtocolViolation,
805                     "Received SUBSCRIBE_REQUEST as first message");
806     return false;
807   }
808   if (!*is_control_stream_) {
809     session_->Error(MoqtError::kProtocolViolation,
810                     "Received SUBSCRIBE_REQUEST on non-control stream");
811     return false;
812   }
813   return true;
814 }
815 
LocationToAbsoluteNumber(const LocalTrack & track,const std::optional<MoqtSubscribeLocation> & group,const std::optional<MoqtSubscribeLocation> & object)816 std::optional<FullSequence> MoqtSession::LocationToAbsoluteNumber(
817     const LocalTrack& track, const std::optional<MoqtSubscribeLocation>& group,
818     const std::optional<MoqtSubscribeLocation>& object) {
819   FullSequence sequence;
820   if (!group.has_value() || !object.has_value()) {
821     return std::nullopt;
822   }
823   if (group->absolute) {
824     sequence.group = group->absolute_value;
825   } else {
826     sequence.group = track.next_sequence().group + group->relative_value;
827   }
828   if (object->absolute) {
829     sequence.object = object->absolute_value;
830   } else {
831     // Subtract 1 because the relative value is computed from the largest sent
832     // sequence number, not the next one.
833     sequence.object = track.next_sequence().object + object->relative_value - 1;
834   }
835   return sequence;
836 }
837 
SendOrBufferMessage(quiche::QuicheBuffer message,bool fin)838 void MoqtSession::Stream::SendOrBufferMessage(quiche::QuicheBuffer message,
839                                               bool fin) {
840   quiche::StreamWriteOptions options;
841   options.set_send_fin(fin);
842   options.set_buffer_unconditionally(true);
843   std::array<absl::string_view, 1> write_vector = {message.AsStringView()};
844   absl::Status success = stream_->Writev(absl::MakeSpan(write_vector), options);
845   if (!success.ok()) {
846     session_->Error(MoqtError::kInternalError,
847                     "Failed to write a control message");
848   }
849 }
850 
851 }  // namespace moqt
852