xref: /aosp_15_r20/external/cronet/net/third_party/quiche/src/quiche/quic/moqt/moqt_session_test.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2023 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "quiche/quic/moqt/moqt_session.h"
6 
7 #include <cstdint>
8 #include <cstring>
9 #include <memory>
10 #include <optional>
11 #include <string>
12 #include <utility>
13 
14 #include "absl/status/status.h"
15 #include "absl/strings/string_view.h"
16 #include "absl/types/span.h"
17 #include "quiche/quic/core/quic_data_reader.h"
18 #include "quiche/quic/core/quic_time.h"
19 #include "quiche/quic/core/quic_types.h"
20 #include "quiche/quic/moqt/moqt_messages.h"
21 #include "quiche/quic/moqt/moqt_parser.h"
22 #include "quiche/quic/moqt/moqt_track.h"
23 #include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
24 #include "quiche/quic/platform/api/quic_test.h"
25 #include "quiche/common/quiche_stream.h"
26 #include "quiche/web_transport/test_tools/mock_web_transport.h"
27 #include "quiche/web_transport/web_transport.h"
28 
29 namespace moqt {
30 
31 namespace test {
32 
33 namespace {
34 
35 using ::testing::_;
36 using ::testing::AnyNumber;
37 using ::testing::Return;
38 using ::testing::StrictMock;
39 
40 constexpr webtransport::StreamId kControlStreamId = 4;
41 constexpr webtransport::StreamId kIncomingUniStreamId = 15;
42 constexpr webtransport::StreamId kOutgoingUniStreamId = 14;
43 
44 constexpr MoqtSessionParameters default_parameters = {
45     /*version=*/MoqtVersion::kDraft03,
46     /*perspective=*/quic::Perspective::IS_CLIENT,
47     /*using_webtrans=*/true,
48     /*path=*/std::string(),
49     /*deliver_partial_objects=*/false,
50 };
51 
52 // Returns nullopt if there is not enough in |message| to extract a type
ExtractMessageType(const absl::string_view message)53 static std::optional<MoqtMessageType> ExtractMessageType(
54     const absl::string_view message) {
55   quic::QuicDataReader reader(message);
56   uint64_t value;
57   if (!reader.ReadVarInt62(&value)) {
58     return std::nullopt;
59   }
60   return static_cast<MoqtMessageType>(value);
61 }
62 
63 }  // namespace
64 
65 class MoqtSessionPeer {
66  public:
CreateControlStream(MoqtSession * session,webtransport::test::MockStream * stream)67   static std::unique_ptr<MoqtParserVisitor> CreateControlStream(
68       MoqtSession* session, webtransport::test::MockStream* stream) {
69     auto new_stream = std::make_unique<MoqtSession::Stream>(
70         session, stream, /*is_control_stream=*/true);
71     session->control_stream_ = kControlStreamId;
72     EXPECT_CALL(*stream, visitor())
73         .Times(AnyNumber())
74         .WillRepeatedly(Return(new_stream.get()));
75     return new_stream;
76   }
77 
CreateUniStream(MoqtSession * session,webtransport::Stream * stream)78   static std::unique_ptr<MoqtParserVisitor> CreateUniStream(
79       MoqtSession* session, webtransport::Stream* stream) {
80     auto new_stream = std::make_unique<MoqtSession::Stream>(
81         session, stream, /*is_control_stream=*/false);
82     return new_stream;
83   }
84 
85   // In the test OnSessionReady, the session creates a stream and then passes
86   // its unique_ptr to the mock webtransport stream. This function casts
87   // that unique_ptr into a MoqtSession::Stream*, which is a private class of
88   // MoqtSession, and then casts again into MoqtParserVisitor so that the test
89   // can inject packets into that stream.
90   // This function is useful for any test that wants to inject packets on a
91   // stream created by the MoqtSession.
FetchParserVisitorFromWebtransportStreamVisitor(MoqtSession * session,webtransport::StreamVisitor * visitor)92   static MoqtParserVisitor* FetchParserVisitorFromWebtransportStreamVisitor(
93       MoqtSession* session, webtransport::StreamVisitor* visitor) {
94     return (MoqtSession::Stream*)visitor;
95   }
96 
CreateRemoteTrack(MoqtSession * session,const FullTrackName & name,RemoteTrack::Visitor * visitor,uint64_t track_alias)97   static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name,
98                                 RemoteTrack::Visitor* visitor,
99                                 uint64_t track_alias) {
100     session->remote_tracks_.try_emplace(track_alias, name, track_alias,
101                                         visitor);
102     session->remote_track_aliases_.try_emplace(name, track_alias);
103   }
104 
AddActiveSubscribe(MoqtSession * session,uint64_t subscribe_id,MoqtSubscribe & subscribe,RemoteTrack::Visitor * visitor)105   static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id,
106                                  MoqtSubscribe& subscribe,
107                                  RemoteTrack::Visitor* visitor) {
108     session->active_subscribes_[subscribe_id] = {subscribe, visitor};
109   }
110 
AddSubscription(MoqtSession * session,FullTrackName & name,uint64_t subscribe_id,uint64_t track_alias,uint64_t start_group,uint64_t start_object)111   static void AddSubscription(MoqtSession* session, FullTrackName& name,
112                               uint64_t subscribe_id, uint64_t track_alias,
113                               uint64_t start_group, uint64_t start_object) {
114     auto it = session->local_tracks_.find(name);
115     ASSERT_NE(it, session->local_tracks_.end());
116     LocalTrack& track = it->second;
117     track.set_track_alias(track_alias);
118     track.AddWindow(subscribe_id, start_group, start_object);
119     session->used_track_aliases_.emplace(track_alias);
120   }
121 
next_sequence(MoqtSession * session,FullTrackName & name)122   static FullSequence next_sequence(MoqtSession* session, FullTrackName& name) {
123     auto it = session->local_tracks_.find(name);
124     EXPECT_NE(it, session->local_tracks_.end());
125     LocalTrack& track = it->second;
126     return track.next_sequence();
127   }
128 
set_peer_role(MoqtSession * session,MoqtRole role)129   static void set_peer_role(MoqtSession* session, MoqtRole role) {
130     session->peer_role_ = role;
131   }
132 
remote_track(MoqtSession * session,uint64_t track_alias)133   static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) {
134     return session->remote_tracks_.find(track_alias)->second;
135   }
136 };
137 
138 class MoqtSessionTest : public quic::test::QuicTest {
139  public:
MoqtSessionTest()140   MoqtSessionTest()
141       : session_(&mock_session_, default_parameters,
142                  session_callbacks_.AsSessionCallbacks()) {}
~MoqtSessionTest()143   ~MoqtSessionTest() {
144     EXPECT_CALL(session_callbacks_.session_deleted_callback, Call());
145   }
146 
147   MockSessionCallbacks session_callbacks_;
148   StrictMock<webtransport::test::MockSession> mock_session_;
149   MoqtSession session_;
150 };
151 
// The fixture's session is built from default_parameters, which request the
// client perspective.
TEST_F(MoqtSessionTest, Queries) {
  const auto perspective = session_.perspective();
  EXPECT_EQ(perspective, quic::Perspective::IS_CLIENT);
}
155 
156 // Verify the session sends CLIENT_SETUP on the control stream.
TEST_F(MoqtSessionTest,OnSessionReady)157 TEST_F(MoqtSessionTest, OnSessionReady) {
158   StrictMock<webtransport::test::MockStream> mock_stream;
159   EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream())
160       .WillOnce(Return(&mock_stream));
161   std::unique_ptr<webtransport::StreamVisitor> visitor;
162   // Save a reference to MoqtSession::Stream
163   EXPECT_CALL(mock_stream, SetVisitor(_))
164       .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> new_visitor) {
165         visitor = std::move(new_visitor);
166       });
167   EXPECT_CALL(mock_stream, GetStreamId())
168       .WillOnce(Return(webtransport::StreamId(4)));
169   EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream));
170   bool correct_message = false;
171   EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); });
172   EXPECT_CALL(mock_stream, Writev(_, _))
173       .WillOnce([&](absl::Span<const absl::string_view> data,
174                     const quiche::StreamWriteOptions& options) {
175         correct_message = true;
176         EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kClientSetup);
177         return absl::OkStatus();
178       });
179   session_.OnSessionReady();
180   EXPECT_TRUE(correct_message);
181 
182   // Receive SERVER_SETUP
183   MoqtParserVisitor* stream_input =
184       MoqtSessionPeer::FetchParserVisitorFromWebtransportStreamVisitor(
185           &session_, visitor.get());
186   // Handle the server setup
187   MoqtServerSetup setup = {
188       MoqtVersion::kDraft03,
189       MoqtRole::kPubSub,
190   };
191   EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1);
192   stream_input->OnServerSetupMessage(setup);
193 }
194 
// Verifies that a server-perspective session answers CLIENT_SETUP with
// SERVER_SETUP on the control stream and reports the session as established.
TEST_F(MoqtSessionTest, OnClientSetup) {
  MoqtSessionParameters server_parameters = {
      /*version=*/MoqtVersion::kDraft03,
      /*perspective=*/quic::Perspective::IS_SERVER,
      /*using_webtrans=*/true,
      /*path=*/"",
      /*deliver_partial_objects=*/false,
  };
  MoqtSession server_session(&mock_session_, server_parameters,
                             session_callbacks_.AsSessionCallbacks());
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream);
  MoqtClientSetup setup = {
      /*supported_versions=*/{MoqtVersion::kDraft03},
      /*role=*/MoqtRole::kPubSub,
      /*path=*/std::nullopt,
  };
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup);
        return absl::OkStatus();
      });
  EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0));
  EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1);
  stream_input->OnClientSetupMessage(setup);
  // Confirm the SERVER_SETUP write actually happened during this call.
  EXPECT_TRUE(correct_message);
}
225 
// Verifies that when the underlying WebTransport session closes, the error
// message is forwarded to the session_terminated callback.
TEST_F(MoqtSessionTest, OnSessionClosed) {
  bool terminated_callback_ran = false;
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_))
      .WillOnce([&](absl::string_view error_message) {
        terminated_callback_ran = true;
        EXPECT_EQ(error_message, "foo");
      });
  const auto close_code = webtransport::SessionErrorCode(1);
  session_.OnSessionClosed(close_code, "foo");
  EXPECT_TRUE(terminated_callback_ran);
}
236 
// Verifies the accept loop for bidirectional streams. The session accepts a
// stream, installs a visitor, and reads from it via visitor()->OnCanRead(). It
// keeps accepting until the WebTransport session returns nullptr.
TEST_F(MoqtSessionTest, OnIncomingBidirectionalStream) {
  // All expectations below must be met in the order they are declared.
  ::testing::InSequence in_order;
  StrictMock<webtransport::test::MockStream> incoming_stream;
  StrictMock<webtransport::test::MockStreamVisitor> stream_visitor;
  EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream())
      .WillOnce(Return(&incoming_stream));
  // The installed visitor is dropped by the mock; visitor() hands back the
  // mock visitor instead so OnCanRead() can be observed.
  EXPECT_CALL(incoming_stream, SetVisitor(_)).Times(1);
  EXPECT_CALL(incoming_stream, visitor()).WillOnce(Return(&stream_visitor));
  EXPECT_CALL(stream_visitor, OnCanRead()).Times(1);
  EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream())
      .WillOnce(Return(nullptr));
  session_.OnIncomingBidirectionalStreamAvailable();
}
250 
// Verifies the accept loop for unidirectional streams. The session accepts a
// stream, installs a visitor, and reads from it via visitor()->OnCanRead(). It
// keeps accepting until the WebTransport session returns nullptr.
TEST_F(MoqtSessionTest, OnIncomingUnidirectionalStream) {
  // All expectations below must be met in the order they are declared.
  ::testing::InSequence in_order;
  StrictMock<webtransport::test::MockStream> incoming_stream;
  StrictMock<webtransport::test::MockStreamVisitor> stream_visitor;
  EXPECT_CALL(mock_session_, AcceptIncomingUnidirectionalStream())
      .WillOnce(Return(&incoming_stream));
  // The installed visitor is dropped by the mock; visitor() hands back the
  // mock visitor instead so OnCanRead() can be observed.
  EXPECT_CALL(incoming_stream, SetVisitor(_)).Times(1);
  EXPECT_CALL(incoming_stream, visitor()).WillOnce(Return(&stream_visitor));
  EXPECT_CALL(stream_visitor, OnCanRead()).Times(1);
  EXPECT_CALL(mock_session_, AcceptIncomingUnidirectionalStream())
      .WillOnce(Return(nullptr));
  session_.OnIncomingUnidirectionalStreamAvailable();
}
264 
TEST_F(MoqtSessionTest,Error)265 TEST_F(MoqtSessionTest, Error) {
266   bool reported_error = false;
267   EXPECT_CALL(
268       mock_session_,
269       CloseSession(static_cast<uint64_t>(MoqtError::kParameterLengthMismatch),
270                    "foo"))
271       .Times(1);
272   EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_))
273       .WillOnce([&](absl::string_view error_message) {
274         reported_error = (error_message == "foo");
275       });
276   session_.Error(MoqtError::kParameterLengthMismatch, "foo");
277   EXPECT_TRUE(reported_error);
278 }
279 
// Verifies that SUBSCRIBE for an unknown track is answered with
// SUBSCRIBE_ERROR, and that the same request yields SUBSCRIBE_OK once the
// track has been registered with AddLocalTrack().
TEST_F(MoqtSessionTest, AddLocalTrack) {
  MoqtSubscribe request = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  // Request for track returns SUBSCRIBE_ERROR.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]),
                  MoqtMessageType::kSubscribeError);
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);

  // Add the track. Now Subscribe should succeed.
  MockLocalTrackVisitor local_track_visitor;
  session_.AddLocalTrack(FullTrackName("foo", "bar"),
                         MoqtForwardingPreference::kObject,
                         &local_track_visitor);
  // Reset the flag so the check below passes only if SUBSCRIBE_OK is written.
  correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);
}
324 
// Verifies that Announce() sends ANNOUNCE on the control stream and that an
// ANNOUNCE_OK reply resolves the callback with no error.
TEST_F(MoqtSessionTest, AnnounceWithOk) {
  testing::MockFunction<void(
      absl::string_view track_namespace,
      std::optional<MoqtAnnounceErrorReason> error_message)>
      announce_resolved_callback;
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
  // Start false so the EXPECT_TRUE below actually proves ANNOUNCE was written.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounce);
        return absl::OkStatus();
      });
  session_.Announce("foo", announce_resolved_callback.AsStdFunction());
  EXPECT_TRUE(correct_message);

  MoqtAnnounceOk ok = {
      /*track_namespace=*/"foo",
  };
  correct_message = false;
  EXPECT_CALL(announce_resolved_callback, Call(_, _))
      .WillOnce([&](absl::string_view track_namespace,
                    std::optional<MoqtAnnounceErrorReason> error) {
        correct_message = true;
        EXPECT_EQ(track_namespace, "foo");
        EXPECT_FALSE(error.has_value());
      });
  stream_input->OnAnnounceOkMessage(ok);
  EXPECT_TRUE(correct_message);
}
359 
// Verifies that Announce() sends ANNOUNCE and that an ANNOUNCE_ERROR reply
// resolves the callback with the peer's error code and reason phrase.
TEST_F(MoqtSessionTest, AnnounceWithError) {
  testing::MockFunction<void(
      absl::string_view track_namespace,
      std::optional<MoqtAnnounceErrorReason> error_message)>
      announce_resolved_callback;
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
  // Start false so the EXPECT_TRUE below actually proves ANNOUNCE was written.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounce);
        return absl::OkStatus();
      });
  session_.Announce("foo", announce_resolved_callback.AsStdFunction());
  EXPECT_TRUE(correct_message);

  MoqtAnnounceError error = {
      /*track_namespace=*/"foo",
      /*error_code=*/MoqtAnnounceErrorCode::kInternalError,
      /*reason_phrase=*/"Test error",
  };
  correct_message = false;
  EXPECT_CALL(announce_resolved_callback, Call(_, _))
      .WillOnce([&](absl::string_view track_namespace,
                    std::optional<MoqtAnnounceErrorReason> error) {
        correct_message = true;
        EXPECT_EQ(track_namespace, "foo");
        ASSERT_TRUE(error.has_value());
        EXPECT_EQ(error->error_code, MoqtAnnounceErrorCode::kInternalError);
        EXPECT_EQ(error->reason_phrase, "Test error");
      });
  stream_input->OnAnnounceErrorMessage(error);
  EXPECT_TRUE(correct_message);
}
398 
// Verifies that HasSubscribers() is false for unknown and unsubscribed local
// tracks, and becomes true once a peer SUBSCRIBE is accepted.
TEST_F(MoqtSessionTest, HasSubscribers) {
  MockLocalTrackVisitor local_track_visitor;
  FullTrackName ftn("foo", "bar");
  EXPECT_FALSE(session_.HasSubscribers(ftn));
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kGroup,
                         &local_track_visitor);
  EXPECT_FALSE(session_.HasSubscribers(ftn));

  // Peer subscribes.
  MoqtSubscribe request = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  // Start false so the EXPECT_TRUE below proves SUBSCRIBE_OK was written.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);
  EXPECT_TRUE(session_.HasSubscribers(ftn));
}
434 
// Verifies that a SUBSCRIBE starting before the track's next sequence asks the
// local track visitor for past objects and is still answered with SUBSCRIBE_OK.
TEST_F(MoqtSessionTest, SubscribeForPast) {
  MockLocalTrackVisitor local_track_visitor;
  FullTrackName ftn("foo", "bar");
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kObject,
                         &local_track_visitor);

  // Send Sequence (2, 0) so that next_sequence is set correctly.
  session_.PublishObject(ftn, 2, 0, 0, "foo", true);
  // Peer subscribes to (0, 0)
  MoqtSubscribe request = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  // Start false so the EXPECT_TRUE below proves SUBSCRIBE_OK was written.
  bool correct_message = false;
  EXPECT_CALL(local_track_visitor, OnSubscribeForPast(_))
      .WillOnce(Return(std::nullopt));
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);
}
471 
// Verifies that SubscribeCurrentGroup() sends SUBSCRIBE and that a matching
// SUBSCRIBE_OK is reported to the remote track visitor with no error.
TEST_F(MoqtSessionTest, SubscribeWithOk) {
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  MockRemoteTrackVisitor remote_track_visitor;
  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe);
        return absl::OkStatus();
      });
  session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor, "");
  EXPECT_TRUE(correct_message);

  MoqtSubscribeOk ok = {
      /*subscribe_id=*/0,
      /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0),
  };
  correct_message = false;
  EXPECT_CALL(remote_track_visitor, OnReply(_, _))
      .WillOnce([&](const FullTrackName& ftn,
                    std::optional<absl::string_view> error_message) {
        correct_message = true;
        EXPECT_EQ(ftn, FullTrackName("foo", "bar"));
        EXPECT_FALSE(error_message.has_value());
      });
  stream_input->OnSubscribeOkMessage(ok);
  EXPECT_TRUE(correct_message);
}
503 
// Verifies that SubscribeCurrentGroup() sends SUBSCRIBE and that a matching
// SUBSCRIBE_ERROR is reported to the remote track visitor with its reason.
TEST_F(MoqtSessionTest, SubscribeWithError) {
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  MockRemoteTrackVisitor remote_track_visitor;
  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe);
        return absl::OkStatus();
      });
  session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor, "");
  EXPECT_TRUE(correct_message);

  MoqtSubscribeError error = {
      /*subscribe_id=*/0,
      /*error_code=*/SubscribeErrorCode::kInvalidRange,
      /*reason_phrase=*/"deadbeef",
      /*track_alias=*/2,
  };
  correct_message = false;
  EXPECT_CALL(remote_track_visitor, OnReply(_, _))
      .WillOnce([&](const FullTrackName& ftn,
                    std::optional<absl::string_view> error_message) {
        correct_message = true;
        EXPECT_EQ(ftn, FullTrackName("foo", "bar"));
        // Guard the dereference: an empty optional here would otherwise be UB.
        ASSERT_TRUE(error_message.has_value());
        EXPECT_EQ(*error_message, "deadbeef");
      });
  stream_input->OnSubscribeErrorMessage(error);
  EXPECT_TRUE(correct_message);
}
537 
// Verifies that an incoming ANNOUNCE is answered with ANNOUNCE_OK on the
// control stream when the incoming_announce callback returns std::nullopt.
TEST_F(MoqtSessionTest, ReplyToAnnounce) {
  StrictMock<webtransport::test::MockStream> control_stream;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
  MoqtAnnounce incoming_announce = {
      /*track_namespace=*/"foo",
  };
  bool announce_ok_written = false;
  EXPECT_CALL(session_callbacks_.incoming_announce_callback, Call("foo"))
      .WillOnce(Return(std::nullopt));
  EXPECT_CALL(control_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        announce_ok_written = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounceOk);
        return absl::OkStatus();
      });
  control_input->OnAnnounceMessage(incoming_announce);
  EXPECT_TRUE(announce_ok_written);
}
558 
// Verifies that a complete object arriving on a unidirectional stream for a
// known track alias reaches that track's visitor as a single fragment.
TEST_F(MoqtSessionTest, IncomingObject) {
  MockRemoteTrackVisitor track_visitor;
  FullTrackName ftn("foo", "bar");
  std::string payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &track_visitor, 2);
  MoqtObject object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  EXPECT_CALL(track_visitor, OnObjectFragment(_, _, _, _, _, _, _)).Times(1);
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(object, payload, /*end_of_message=*/true);
}
582 
// With deliver_partial_objects=false (default_parameters), a 16-byte object
// arriving as two 8-byte pieces reaches the visitor as exactly one fragment.
TEST_F(MoqtSessionTest, IncomingPartialObject) {
  MockRemoteTrackVisitor track_visitor;
  FullTrackName ftn("foo", "bar");
  std::string payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &track_visitor, 2);
  MoqtObject object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/16,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  EXPECT_CALL(track_visitor, OnObjectFragment(_, _, _, _, _, _, _)).Times(1);
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  // First half, then the second half that completes the object.
  data_input->OnObjectMessage(object, payload, /*end_of_message=*/false);
  data_input->OnObjectMessage(object, payload, /*end_of_message=*/true);
}
607 
// With deliver_partial_objects=true, each piece of a 16-byte object is passed
// to the visitor as soon as it arrives, so two pieces give two fragments.
TEST_F(MoqtSessionTest, IncomingPartialObjectNoBuffer) {
  MoqtSessionParameters partial_delivery_parameters = {
      /*version=*/MoqtVersion::kDraft03,
      /*perspective=*/quic::Perspective::IS_CLIENT,
      /*using_webtrans=*/true,
      /*path=*/"",
      /*deliver_partial_objects=*/true,
  };
  MoqtSession partial_session(&mock_session_, partial_delivery_parameters,
                              session_callbacks_.AsSessionCallbacks());
  MockRemoteTrackVisitor track_visitor;
  FullTrackName ftn("foo", "bar");
  std::string payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&partial_session, ftn, &track_visitor, 2);
  MoqtObject object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/16,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&partial_session, &data_stream);

  EXPECT_CALL(track_visitor, OnObjectFragment(_, _, _, _, _, _, _)).Times(2);
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  // First half, then the second half that completes the object.
  data_input->OnObjectMessage(object, payload, /*end_of_message=*/false);
  data_input->OnObjectMessage(object, payload, /*end_of_message=*/true);
}
641 
// Verifies that an object arriving while a SUBSCRIBE is still pending (no
// SUBSCRIBE_OK yet, no remote track registered) is delivered to the pending
// subscription's visitor. The later SUBSCRIBE_OK is then reported via OnReply().
TEST_F(MoqtSessionTest, ObjectBeforeSubscribeOk) {
  MockRemoteTrackVisitor visitor;
  FullTrackName ftn("foo", "bar");
  std::string payload = "deadbeef";
  MoqtSubscribe subscribe = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/ftn.track_namespace,
      /*track_name=*/ftn.track_name,
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
  };
  MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor);
  MoqtObject object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  // Lambda parameter names differ from the locals above to avoid shadowing
  // |payload|.
  EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _, _))
      .WillOnce([&](const FullTrackName& delivered_name,
                    uint64_t delivered_group, uint64_t delivered_object,
                    uint64_t delivered_send_order,
                    MoqtForwardingPreference delivered_preference,
                    absl::string_view delivered_payload, bool delivered_eom) {
        EXPECT_EQ(delivered_name, ftn);
        EXPECT_EQ(delivered_group, object.group_id);
        EXPECT_EQ(delivered_object, object.object_id);
      });
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(object, payload, true);

  // SUBSCRIBE_OK arrives after the object.
  MoqtSubscribeOk subscribe_ok = {
      /*subscribe_id=*/1,
      /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0),
      /*largest_id=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> control_stream;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
  EXPECT_CALL(visitor, OnReply(_, _)).Times(1);
  control_input->OnSubscribeOkMessage(subscribe_ok);
}
696 
TEST_F(MoqtSessionTest,ObjectBeforeSubscribeError)697 TEST_F(MoqtSessionTest, ObjectBeforeSubscribeError) {
698   MockRemoteTrackVisitor visitor;
699   FullTrackName ftn("foo", "bar");
700   std::string payload = "deadbeef";
701   MoqtSubscribe subscribe = {
702       /*subscribe_id=*/1,
703       /*track_alias=*/2,
704       /*track_namespace=*/ftn.track_namespace,
705       /*track_name=*/ftn.track_name,
706       /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
707       /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
708       /*end_group=*/std::nullopt,
709       /*end_object=*/std::nullopt,
710   };
711   MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor);
712   MoqtObject object = {
713       /*subscribe_id=*/1,
714       /*track_alias=*/2,
715       /*group_sequence=*/0,
716       /*object_sequence=*/0,
717       /*object_send_order=*/0,
718       /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
719       /*payload_length=*/8,
720   };
721   StrictMock<webtransport::test::MockStream> mock_stream;
722   std::unique_ptr<MoqtParserVisitor> object_stream =
723       MoqtSessionPeer::CreateUniStream(&session_, &mock_stream);
724 
725   EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _, _))
726       .WillOnce([&](const FullTrackName& full_track_name,
727                     uint64_t group_sequence, uint64_t object_sequence,
728                     uint64_t object_send_order,
729                     MoqtForwardingPreference forwarding_preference,
730                     absl::string_view payload, bool end_of_message) {
731         EXPECT_EQ(full_track_name, ftn);
732         EXPECT_EQ(group_sequence, object.group_id);
733         EXPECT_EQ(object_sequence, object.object_id);
734       });
735   EXPECT_CALL(mock_stream, GetStreamId())
736       .WillRepeatedly(Return(kIncomingUniStreamId));
737   object_stream->OnObjectMessage(object, payload, true);
738 
739   // SUBSCRIBE_ERROR arrives
740   MoqtSubscribeError subscribe_error = {
741       /*subscribe_id=*/1,
742       /*error_code=*/SubscribeErrorCode::kRetryTrackAlias,
743       /*reason_phrase=*/"foo",
744       /*track_alias =*/3,
745   };
746   StrictMock<webtransport::test::MockStream> mock_control_stream;
747   std::unique_ptr<MoqtParserVisitor> control_stream =
748       MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
749   EXPECT_CALL(mock_session_,
750               CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
751                            "Received SUBSCRIBE_ERROR after object"))
752       .Times(1);
753   control_stream->OnSubscribeErrorMessage(subscribe_error);
754 }
755 
TEST_F(MoqtSessionTest, TwoEarlyObjectsDifferentForwarding) {
  // Two early objects on the same subscription that disagree on forwarding
  // preference must close the session with a protocol violation.
  MockRemoteTrackVisitor remote_visitor;
  FullTrackName track("foo", "bar");
  std::string object_payload = "deadbeef";
  MoqtSubscribe pending_subscribe = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/track.track_namespace,
      /*track_name=*/track.track_name,
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
  };
  MoqtSessionPeer::AddActiveSubscribe(&session_, 1, pending_subscribe,
                                      &remote_visitor);
  MoqtObject early_object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  // Only the first object is expected to reach the visitor.
  EXPECT_CALL(remote_visitor,
              OnObjectFragment(track, early_object.group_id,
                               early_object.object_id, _, _, _, _))
      .Times(1);
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(early_object, object_payload, true);

  // The next object in the group switches its preference to kObject.
  early_object.object_id += 1;
  early_object.forwarding_preference = MoqtForwardingPreference::kObject;
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Forwarding preference changes mid-track"))
      .Times(1);
  data_input->OnObjectMessage(early_object, object_payload, true);
}
805 
TEST_F(MoqtSessionTest,EarlyObjectForwardingDoesNotMatchTrack)806 TEST_F(MoqtSessionTest, EarlyObjectForwardingDoesNotMatchTrack) {
807   MockRemoteTrackVisitor visitor;
808   FullTrackName ftn("foo", "bar");
809   std::string payload = "deadbeef";
810   MoqtSubscribe subscribe = {
811       /*subscribe_id=*/1,
812       /*track_alias=*/2,
813       /*track_namespace=*/ftn.track_namespace,
814       /*track_name=*/ftn.track_name,
815       /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
816       /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
817       /*end_group=*/std::nullopt,
818       /*end_object=*/std::nullopt,
819   };
820   MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor);
821   MoqtObject object = {
822       /*subscribe_id=*/1,
823       /*track_alias=*/2,
824       /*group_sequence=*/0,
825       /*object_sequence=*/0,
826       /*object_send_order=*/0,
827       /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
828       /*payload_length=*/8,
829   };
830   StrictMock<webtransport::test::MockStream> mock_stream;
831   std::unique_ptr<MoqtParserVisitor> object_stream =
832       MoqtSessionPeer::CreateUniStream(&session_, &mock_stream);
833 
834   EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _, _))
835       .WillOnce([&](const FullTrackName& full_track_name,
836                     uint64_t group_sequence, uint64_t object_sequence,
837                     uint64_t object_send_order,
838                     MoqtForwardingPreference forwarding_preference,
839                     absl::string_view payload, bool end_of_message) {
840         EXPECT_EQ(full_track_name, ftn);
841         EXPECT_EQ(group_sequence, object.group_id);
842         EXPECT_EQ(object_sequence, object.object_id);
843       });
844   EXPECT_CALL(mock_stream, GetStreamId())
845       .WillRepeatedly(Return(kIncomingUniStreamId));
846   object_stream->OnObjectMessage(object, payload, true);
847 
848   MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor, 2);
849   // The track already exists, and has a different forwarding preference.
850   MoqtSessionPeer::remote_track(&session_, 2)
851       .CheckForwardingPreference(MoqtForwardingPreference::kObject);
852 
853   // SUBSCRIBE_OK arrives
854   MoqtSubscribeOk ok = {
855       /*subscribe_id=*/1,
856       /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0),
857       /*largest_id=*/std::nullopt,
858   };
859   StrictMock<webtransport::test::MockStream> mock_control_stream;
860   std::unique_ptr<MoqtParserVisitor> control_stream =
861       MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
862   EXPECT_CALL(mock_session_,
863               CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
864                            "Forwarding preference different in early objects"))
865       .Times(1);
866   control_stream->OnSubscribeOkMessage(ok);
867 }
868 
TEST_F(MoqtSessionTest, CreateUniStreamAndSend) {
  StrictMock<webtransport::test::MockStream> mock_stream;
  FullTrackName ftn("foo", "bar");
  MockLocalTrackVisitor track_visitor;
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kObject,
                         &track_visitor);
  MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0);

  // No subscription covers group 4 (presumably it precedes the subscription's
  // start group), so this is a no-op except to update next_sequence.
  EXPECT_CALL(mock_stream, Writev(_, _)).Times(0);
  session_.PublishObject(ftn, 4, 1, 0, "deadbeef", true);
  EXPECT_EQ(MoqtSessionPeer::next_sequence(&session_, ftn), FullSequence(4, 2));

  // Publish in window: the session opens a new outgoing unidirectional stream.
  EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
      .WillOnce(Return(true));
  EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
      .WillOnce(Return(&mock_stream));
  EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1);
  EXPECT_CALL(mock_stream, GetStreamId())
      .WillRepeatedly(Return(kOutgoingUniStreamId));
  // Send on the stream.
  EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
      .WillOnce(Return(&mock_stream));
  bool correct_message = false;
  // Verify the first six message fields are sent correctly.
  uint8_t kExpectedMessage[] = {0x00, 0x00, 0x02, 0x05, 0x00, 0x00};
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        // Check the length first so that a missing or short first buffer
        // fails the test instead of making memcmp read out of bounds.
        correct_message =
            !data.empty() && data[0].size() >= sizeof(kExpectedMessage) &&
            0 == memcmp(data[0].data(), kExpectedMessage,
                        sizeof(kExpectedMessage));
        return absl::OkStatus();
      });
  session_.PublishObject(ftn, 5, 0, 0, "deadbeef", true);
  EXPECT_TRUE(correct_message);
}
906 
907 // TODO: Test operation with multiple streams.
908 
909 // Error cases
910 
TEST_F(MoqtSessionTest, CannotOpenUniStream) {
  FullTrackName ftn("foo", "bar");
  MockLocalTrackVisitor track_visitor;
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kObject,
                         &track_visitor);
  MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0);
  // When no further outgoing unidirectional stream can be opened, publishing
  // an object on the subscribed track must report failure.
  EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
      .WillOnce(Return(false));
  EXPECT_FALSE(session_.PublishObject(ftn, 5, 0, 0, "deadbeef", true));
}
923 
TEST_F(MoqtSessionTest, GetStreamByIdFails) {
  // The outgoing stream opens successfully, but looking it up by ID yields
  // nullptr, so PublishObject must report failure.
  StrictMock<webtransport::test::MockStream> outgoing_stream;
  FullTrackName full_name("foo", "bar");
  MockLocalTrackVisitor local_visitor;
  session_.AddLocalTrack(full_name, MoqtForwardingPreference::kObject,
                         &local_visitor);
  MoqtSessionPeer::AddSubscription(&session_, full_name, 0, 2, 5, 0);

  EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
      .WillOnce(Return(true));
  EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
      .WillOnce(Return(&outgoing_stream));
  EXPECT_CALL(outgoing_stream, SetVisitor(_)).Times(1);
  EXPECT_CALL(outgoing_stream, GetStreamId())
      .WillRepeatedly(Return(kOutgoingUniStreamId));
  // The lookup by ID finds nothing.
  EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
      .WillOnce(Return(nullptr));

  EXPECT_FALSE(session_.PublishObject(full_name, 5, 0, 0, "deadbeef", true));
}
942 
TEST_F(MoqtSessionTest, SubscribeProposesBadTrackAlias) {
  MockLocalTrackVisitor local_track_visitor;
  FullTrackName ftn("foo", "bar");
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kGroup,
                         &local_track_visitor);
  MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0);

  // Peer subscribes with a track alias that does not match the existing
  // subscription's alias (2); the session must answer with SUBSCRIBE_ERROR.
  MoqtSubscribe request = {
      /*subscribe_id=*/1,
      /*track_alias=*/3,  // Doesn't match 2.
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  // Starts false so the final EXPECT_TRUE only passes if a SUBSCRIBE_ERROR
  // was actually written.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        // Never dereference an empty optional: a buffer too short to hold a
        // message type simply counts as the wrong message.
        std::optional<MoqtMessageType> message_type =
            data.empty() ? std::nullopt : ExtractMessageType(data[0]);
        correct_message = message_type.has_value() &&
                          *message_type == MoqtMessageType::kSubscribeError;
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);
}
977 
TEST_F(MoqtSessionTest, OneBidirectionalStreamClient) {
  // The client opens the control stream and sends CLIENT_SETUP on it; a
  // further bidirectional stream from the peer is a protocol violation.
  StrictMock<webtransport::test::MockStream> control_mock;
  EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream())
      .WillOnce(Return(&control_mock));
  // Keep the stream visitor the session installs (MoqtSession::Stream) so that
  // visitor() can hand it back.
  std::unique_ptr<webtransport::StreamVisitor> installed_visitor;
  EXPECT_CALL(control_mock, SetVisitor(_))
      .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> incoming) {
        installed_visitor = std::move(incoming);
      });
  EXPECT_CALL(control_mock, GetStreamId())
      .WillOnce(Return(webtransport::StreamId(4)));
  EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&control_mock));
  bool sent_client_setup = false;
  EXPECT_CALL(control_mock, visitor()).WillOnce([&] {
    return installed_visitor.get();
  });
  EXPECT_CALL(control_mock, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        sent_client_setup = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kClientSetup);
        return absl::OkStatus();
      });
  session_.OnSessionReady();
  EXPECT_TRUE(sent_client_setup);

  // The peer now attempts to open a second bidirectional stream.
  bool error_reported = false;
  EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream())
      .WillOnce(Return(&control_mock));
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Bidirectional stream already open"))
      .Times(1);
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_))
      .WillOnce([&](absl::string_view error_message) {
        error_reported = error_message == "Bidirectional stream already open";
      });
  session_.OnIncomingBidirectionalStreamAvailable();
  EXPECT_TRUE(error_reported);
}
1018 
TEST_F(MoqtSessionTest, OneBidirectionalStreamServer) {
  // A server session answers CLIENT_SETUP on its control stream with
  // SERVER_SETUP. It then rejects any further bidirectional stream from the
  // peer as a protocol violation.
  MoqtSessionParameters server_parameters = {
      /*version=*/MoqtVersion::kDraft03,
      /*perspective=*/quic::Perspective::IS_SERVER,
      /*using_webtrans=*/true,
      /*path=*/"",
      /*deliver_partial_objects=*/false,
  };
  MoqtSession server_session(&mock_session_, server_parameters,
                             session_callbacks_.AsSessionCallbacks());
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream);
  MoqtClientSetup setup = {
      /*supported_versions=*/{MoqtVersion::kDraft03},
      /*role=*/MoqtRole::kPubSub,
      /*path=*/std::nullopt,
  };
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup);
        return absl::OkStatus();
      });
  EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0));
  EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1);
  stream_input->OnClientSetupMessage(setup);
  // The SERVER_SETUP reply must have been written before going further.
  EXPECT_TRUE(correct_message);

  // Peer tries to open a bidi stream.
  bool reported_error = false;
  EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream())
      .WillOnce(Return(&mock_stream));
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Bidirectional stream already open"))
      .Times(1);
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_))
      .WillOnce([&](absl::string_view error_message) {
        reported_error = (error_message == "Bidirectional stream already open");
      });
  server_session.OnIncomingBidirectionalStreamAvailable();
  EXPECT_TRUE(reported_error);
}
1064 
TEST_F(MoqtSessionTest, ReceiveUnsubscribe) {
  // An UNSUBSCRIBE for subscribe_id 0 removes the single subscriber added to
  // the local track.
  FullTrackName track("foo", "bar");
  MockLocalTrackVisitor local_visitor;
  session_.AddLocalTrack(track, MoqtForwardingPreference::kTrack,
                         &local_visitor);
  MoqtSessionPeer::AddSubscription(&session_, track, 0, 1, 3, 4);
  EXPECT_TRUE(session_.HasSubscribers(track));

  StrictMock<webtransport::test::MockStream> control_mock;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &control_mock);
  MoqtUnsubscribe unsubscribe_message = {
      /*subscribe_id=*/0,
  };
  control_input->OnUnsubscribeMessage(unsubscribe_message);
  EXPECT_FALSE(session_.HasSubscribers(track));
}
1080 
TEST_F(MoqtSessionTest, SendDatagram) {
  // Publishing on a datagram-preference track sends the object, header plus
  // payload, as one datagram.
  FullTrackName track("foo", "bar");
  MockLocalTrackVisitor local_visitor;
  session_.AddLocalTrack(track, MoqtForwardingPreference::kDatagram,
                         &local_visitor);
  MoqtSessionPeer::AddSubscription(&session_, track, 0, 2, 5, 0);

  // Publish in window. Six header bytes followed by the ASCII payload
  // "deadbeef".
  const char kExpectedDatagram[] = {
      0x01, 0x00, 0x02, 0x05, 0x00, 0x00, 0x64,
      0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66,
  };
  const absl::string_view expected_datagram(kExpectedDatagram,
                                            sizeof(kExpectedDatagram));
  bool datagram_matches = false;
  EXPECT_CALL(mock_session_, SendOrQueueDatagram(_))
      .WillOnce([&](absl::string_view datagram) {
        // string_view equality checks the length before the bytes.
        datagram_matches = (datagram == expected_datagram);
        return webtransport::DatagramStatus(
            webtransport::DatagramStatusCode::kSuccess, "");
      });
  session_.PublishObject(track, 5, 0, 0, "deadbeef", true);
  EXPECT_TRUE(datagram_matches);
}
1106 
TEST_F(MoqtSessionTest, ReceiveDatagram) {
  // A datagram for a known remote track (alias 2) is delivered to the visitor
  // as one complete object fragment.
  MockRemoteTrackVisitor visitor;
  FullTrackName track("foo", "bar");
  std::string expected_payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&session_, track, &visitor, 2);
  MoqtObject expected_object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kDatagram,
      /*payload_length=*/8,
  };
  // Six header bytes (presumably type, subscribe ID, alias, group, object,
  // send order) followed by the ASCII payload "deadbeef".
  char datagram[] = {0x01, 0x01, 0x02, 0x00, 0x00, 0x00, 0x64,
                     0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66};
  EXPECT_CALL(visitor,
              OnObjectFragment(track, expected_object.group_id,
                               expected_object.object_id,
                               expected_object.object_send_order,
                               expected_object.forwarding_preference,
                               expected_payload, true))
      .Times(1);
  session_.OnDatagramReceived(absl::string_view(datagram, sizeof(datagram)));
}
1130 
TEST_F(MoqtSessionTest, ForwardingPreferenceMismatch) {
  // After a track delivers an object with one forwarding preference, an
  // object with a different preference must close the session.
  MockRemoteTrackVisitor visitor;
  FullTrackName track("foo", "bar");
  std::string object_payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&session_, track, &visitor, 2);
  MoqtObject object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _, _)).Times(1);
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(object, object_payload, true);

  // The following object switches from kGroup to kTrack.
  object.forwarding_preference = MoqtForwardingPreference::kTrack;
  object.object_id += 1;
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Forwarding preference changes mid-track"))
      .Times(1);
  data_input->OnObjectMessage(object, object_payload, true);
}
1161 
TEST_F(MoqtSessionTest, AnnounceToPublisher) {
  // With a publisher peer, Announce() resolves its callback exactly once
  // (presumably with an error; the arguments are not checked here).
  MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kPublisher);
  testing::MockFunction<void(absl::string_view,
                             std::optional<MoqtAnnounceErrorReason>)>
      on_announce_resolved;
  EXPECT_CALL(on_announce_resolved, Call(_, _)).Times(1);
  session_.Announce("foo", on_announce_resolved.AsStdFunction());
}
1171 
TEST_F(MoqtSessionTest, SubscribeFromPublisher) {
  // A SUBSCRIBE from a peer whose role is publisher is a protocol violation.
  MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kPublisher);
  MoqtSubscribe subscribe_request = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> control_mock;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &control_mock);
  // The session closes and the termination callback fires once.
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Received SUBSCRIBE from publisher"))
      .Times(1);
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)).Times(1);
  control_input->OnSubscribeMessage(subscribe_request);
}
1196 
TEST_F(MoqtSessionTest, AnnounceFromSubscriber) {
  // An ANNOUNCE from a peer whose role is subscriber is a protocol violation.
  MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber);
  StrictMock<webtransport::test::MockStream> control_mock;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &control_mock);
  MoqtAnnounce announce_message = {
      /*track_namespace=*/"foo",
  };
  // The session closes and the termination callback fires once.
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Received ANNOUNCE from Subscriber"))
      .Times(1);
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)).Times(1);
  control_input->OnAnnounceMessage(announce_message);
}
1212 
// TODO: Cover more error cases in the tests above.
1214 
1215 }  // namespace test
1216 
1217 }  // namespace moqt
1218