1 // Copyright 2023 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "quiche/quic/moqt/moqt_session.h"
6
7 #include <cstdint>
8 #include <cstring>
9 #include <memory>
10 #include <optional>
11 #include <string>
12 #include <utility>
13
14 #include "absl/status/status.h"
15 #include "absl/strings/string_view.h"
16 #include "absl/types/span.h"
17 #include "quiche/quic/core/quic_data_reader.h"
18 #include "quiche/quic/core/quic_time.h"
19 #include "quiche/quic/core/quic_types.h"
20 #include "quiche/quic/moqt/moqt_messages.h"
21 #include "quiche/quic/moqt/moqt_parser.h"
22 #include "quiche/quic/moqt/moqt_track.h"
23 #include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
24 #include "quiche/quic/platform/api/quic_test.h"
25 #include "quiche/common/quiche_stream.h"
26 #include "quiche/web_transport/test_tools/mock_web_transport.h"
27 #include "quiche/web_transport/web_transport.h"
28
29 namespace moqt {
30
31 namespace test {
32
33 namespace {
34
35 using ::testing::_;
36 using ::testing::AnyNumber;
37 using ::testing::Return;
38 using ::testing::StrictMock;
39
// Stream IDs used by the tests. kControlStreamId is the ID that
// MoqtSessionPeer::CreateControlStream stores as the session's control stream.
// kIncomingUniStreamId is returned from mocked GetStreamId() calls on incoming
// unidirectional streams; kOutgoingUniStreamId is presumably used by tests
// further down the file (TODO confirm).
constexpr webtransport::StreamId kControlStreamId = 4;
constexpr webtransport::StreamId kIncomingUniStreamId = 15;
constexpr webtransport::StreamId kOutgoingUniStreamId = 14;

// Parameters for the fixture's session: a draft-03 client over WebTransport
// with partial-object delivery disabled, so object fragments are buffered
// until the object is complete (see IncomingPartialObject vs.
// IncomingPartialObjectNoBuffer).
constexpr MoqtSessionParameters default_parameters = {
    /*version=*/MoqtVersion::kDraft03,
    /*perspective=*/quic::Perspective::IS_CLIENT,
    /*using_webtrans=*/true,
    /*path=*/std::string(),
    /*deliver_partial_objects=*/false,
};
51
52 // Returns nullopt if there is not enough in |message| to extract a type
ExtractMessageType(const absl::string_view message)53 static std::optional<MoqtMessageType> ExtractMessageType(
54 const absl::string_view message) {
55 quic::QuicDataReader reader(message);
56 uint64_t value;
57 if (!reader.ReadVarInt62(&value)) {
58 return std::nullopt;
59 }
60 return static_cast<MoqtMessageType>(value);
61 }
62
63 } // namespace
64
65 class MoqtSessionPeer {
66 public:
CreateControlStream(MoqtSession * session,webtransport::test::MockStream * stream)67 static std::unique_ptr<MoqtParserVisitor> CreateControlStream(
68 MoqtSession* session, webtransport::test::MockStream* stream) {
69 auto new_stream = std::make_unique<MoqtSession::Stream>(
70 session, stream, /*is_control_stream=*/true);
71 session->control_stream_ = kControlStreamId;
72 EXPECT_CALL(*stream, visitor())
73 .Times(AnyNumber())
74 .WillRepeatedly(Return(new_stream.get()));
75 return new_stream;
76 }
77
CreateUniStream(MoqtSession * session,webtransport::Stream * stream)78 static std::unique_ptr<MoqtParserVisitor> CreateUniStream(
79 MoqtSession* session, webtransport::Stream* stream) {
80 auto new_stream = std::make_unique<MoqtSession::Stream>(
81 session, stream, /*is_control_stream=*/false);
82 return new_stream;
83 }
84
85 // In the test OnSessionReady, the session creates a stream and then passes
86 // its unique_ptr to the mock webtransport stream. This function casts
87 // that unique_ptr into a MoqtSession::Stream*, which is a private class of
88 // MoqtSession, and then casts again into MoqtParserVisitor so that the test
89 // can inject packets into that stream.
90 // This function is useful for any test that wants to inject packets on a
91 // stream created by the MoqtSession.
FetchParserVisitorFromWebtransportStreamVisitor(MoqtSession * session,webtransport::StreamVisitor * visitor)92 static MoqtParserVisitor* FetchParserVisitorFromWebtransportStreamVisitor(
93 MoqtSession* session, webtransport::StreamVisitor* visitor) {
94 return (MoqtSession::Stream*)visitor;
95 }
96
CreateRemoteTrack(MoqtSession * session,const FullTrackName & name,RemoteTrack::Visitor * visitor,uint64_t track_alias)97 static void CreateRemoteTrack(MoqtSession* session, const FullTrackName& name,
98 RemoteTrack::Visitor* visitor,
99 uint64_t track_alias) {
100 session->remote_tracks_.try_emplace(track_alias, name, track_alias,
101 visitor);
102 session->remote_track_aliases_.try_emplace(name, track_alias);
103 }
104
AddActiveSubscribe(MoqtSession * session,uint64_t subscribe_id,MoqtSubscribe & subscribe,RemoteTrack::Visitor * visitor)105 static void AddActiveSubscribe(MoqtSession* session, uint64_t subscribe_id,
106 MoqtSubscribe& subscribe,
107 RemoteTrack::Visitor* visitor) {
108 session->active_subscribes_[subscribe_id] = {subscribe, visitor};
109 }
110
AddSubscription(MoqtSession * session,FullTrackName & name,uint64_t subscribe_id,uint64_t track_alias,uint64_t start_group,uint64_t start_object)111 static void AddSubscription(MoqtSession* session, FullTrackName& name,
112 uint64_t subscribe_id, uint64_t track_alias,
113 uint64_t start_group, uint64_t start_object) {
114 auto it = session->local_tracks_.find(name);
115 ASSERT_NE(it, session->local_tracks_.end());
116 LocalTrack& track = it->second;
117 track.set_track_alias(track_alias);
118 track.AddWindow(subscribe_id, start_group, start_object);
119 session->used_track_aliases_.emplace(track_alias);
120 }
121
next_sequence(MoqtSession * session,FullTrackName & name)122 static FullSequence next_sequence(MoqtSession* session, FullTrackName& name) {
123 auto it = session->local_tracks_.find(name);
124 EXPECT_NE(it, session->local_tracks_.end());
125 LocalTrack& track = it->second;
126 return track.next_sequence();
127 }
128
set_peer_role(MoqtSession * session,MoqtRole role)129 static void set_peer_role(MoqtSession* session, MoqtRole role) {
130 session->peer_role_ = role;
131 }
132
remote_track(MoqtSession * session,uint64_t track_alias)133 static RemoteTrack& remote_track(MoqtSession* session, uint64_t track_alias) {
134 return session->remote_tracks_.find(track_alias)->second;
135 }
136 };
137
138 class MoqtSessionTest : public quic::test::QuicTest {
139 public:
MoqtSessionTest()140 MoqtSessionTest()
141 : session_(&mock_session_, default_parameters,
142 session_callbacks_.AsSessionCallbacks()) {}
~MoqtSessionTest()143 ~MoqtSessionTest() {
144 EXPECT_CALL(session_callbacks_.session_deleted_callback, Call());
145 }
146
147 MockSessionCallbacks session_callbacks_;
148 StrictMock<webtransport::test::MockSession> mock_session_;
149 MoqtSession session_;
150 };
151
// The fixture's session reports the perspective configured in
// default_parameters.
TEST_F(MoqtSessionTest, Queries) {
  const quic::Perspective reported = session_.perspective();
  EXPECT_EQ(reported, quic::Perspective::IS_CLIENT);
}
155
156 // Verify the session sends CLIENT_SETUP on the control stream.
TEST_F(MoqtSessionTest,OnSessionReady)157 TEST_F(MoqtSessionTest, OnSessionReady) {
158 StrictMock<webtransport::test::MockStream> mock_stream;
159 EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream())
160 .WillOnce(Return(&mock_stream));
161 std::unique_ptr<webtransport::StreamVisitor> visitor;
162 // Save a reference to MoqtSession::Stream
163 EXPECT_CALL(mock_stream, SetVisitor(_))
164 .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> new_visitor) {
165 visitor = std::move(new_visitor);
166 });
167 EXPECT_CALL(mock_stream, GetStreamId())
168 .WillOnce(Return(webtransport::StreamId(4)));
169 EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream));
170 bool correct_message = false;
171 EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); });
172 EXPECT_CALL(mock_stream, Writev(_, _))
173 .WillOnce([&](absl::Span<const absl::string_view> data,
174 const quiche::StreamWriteOptions& options) {
175 correct_message = true;
176 EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kClientSetup);
177 return absl::OkStatus();
178 });
179 session_.OnSessionReady();
180 EXPECT_TRUE(correct_message);
181
182 // Receive SERVER_SETUP
183 MoqtParserVisitor* stream_input =
184 MoqtSessionPeer::FetchParserVisitorFromWebtransportStreamVisitor(
185 &session_, visitor.get());
186 // Handle the server setup
187 MoqtServerSetup setup = {
188 MoqtVersion::kDraft03,
189 MoqtRole::kPubSub,
190 };
191 EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1);
192 stream_input->OnServerSetupMessage(setup);
193 }
194
// A server-perspective session answers CLIENT_SETUP on the control stream
// with SERVER_SETUP and reports the session as established.
TEST_F(MoqtSessionTest, OnClientSetup) {
  MoqtSessionParameters server_parameters = {
      /*version=*/MoqtVersion::kDraft03,
      /*perspective=*/quic::Perspective::IS_SERVER,
      /*using_webtrans=*/true,
      /*path=*/"",
      /*deliver_partial_objects=*/false,
  };
  MoqtSession server_session(&mock_session_, server_parameters,
                             session_callbacks_.AsSessionCallbacks());
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream);
  MoqtClientSetup setup = {
      /*supported_versions=*/{MoqtVersion::kDraft03},
      /*role=*/MoqtRole::kPubSub,
      /*path=*/std::nullopt,
  };
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kServerSetup);
        return absl::OkStatus();
      });
  EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0));
  EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1);
  stream_input->OnClientSetupMessage(setup);
  // SERVER_SETUP must actually have been written.
  EXPECT_TRUE(correct_message);
}
225
// Closing the underlying WebTransport session surfaces the error message
// through the session-terminated callback.
TEST_F(MoqtSessionTest, OnSessionClosed) {
  bool terminated_callback_ran = false;
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_))
      .WillOnce([&](absl::string_view reason) {
        terminated_callback_ran = true;
        EXPECT_EQ(reason, "foo");
      });
  session_.OnSessionClosed(webtransport::SessionErrorCode(1), "foo");
  EXPECT_TRUE(terminated_callback_ran);
}
236
// When an incoming bidirectional stream becomes available, the session accepts
// it, installs a visitor, and triggers OnCanRead() on that visitor; it then
// presumably keeps accepting until AcceptIncomingBidirectionalStream()
// returns nullptr.
TEST_F(MoqtSessionTest, OnIncomingBidirectionalStream) {
  // Every expectation below must be satisfied in declaration order.
  ::testing::InSequence seq;
  StrictMock<webtransport::test::MockStream> mock_stream;
  StrictMock<webtransport::test::MockStreamVisitor> mock_stream_visitor;
  EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream())
      .WillOnce(Return(&mock_stream));
  EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1);
  // The installed visitor is discarded by the mock; visitor() hands back a
  // separate mock so the OnCanRead() call can be observed.
  EXPECT_CALL(mock_stream, visitor()).WillOnce(Return(&mock_stream_visitor));
  EXPECT_CALL(mock_stream_visitor, OnCanRead()).Times(1);
  // nullptr signals that no further streams are pending.
  EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream())
      .WillOnce(Return(nullptr));
  session_.OnIncomingBidirectionalStreamAvailable();
}
250
// Same as OnIncomingBidirectionalStream, for unidirectional streams: accept,
// install a visitor, call OnCanRead() on it, and presumably stop once
// AcceptIncomingUnidirectionalStream() returns nullptr.
TEST_F(MoqtSessionTest, OnIncomingUnidirectionalStream) {
  // Every expectation below must be satisfied in declaration order.
  ::testing::InSequence seq;
  StrictMock<webtransport::test::MockStream> mock_stream;
  StrictMock<webtransport::test::MockStreamVisitor> mock_stream_visitor;
  EXPECT_CALL(mock_session_, AcceptIncomingUnidirectionalStream())
      .WillOnce(Return(&mock_stream));
  EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1);
  // visitor() returns a separate mock so the OnCanRead() call is observable.
  EXPECT_CALL(mock_stream, visitor()).WillOnce(Return(&mock_stream_visitor));
  EXPECT_CALL(mock_stream_visitor, OnCanRead()).Times(1);
  // nullptr signals that no further streams are pending.
  EXPECT_CALL(mock_session_, AcceptIncomingUnidirectionalStream())
      .WillOnce(Return(nullptr));
  session_.OnIncomingUnidirectionalStreamAvailable();
}
264
TEST_F(MoqtSessionTest,Error)265 TEST_F(MoqtSessionTest, Error) {
266 bool reported_error = false;
267 EXPECT_CALL(
268 mock_session_,
269 CloseSession(static_cast<uint64_t>(MoqtError::kParameterLengthMismatch),
270 "foo"))
271 .Times(1);
272 EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_))
273 .WillOnce([&](absl::string_view error_message) {
274 reported_error = (error_message == "foo");
275 });
276 session_.Error(MoqtError::kParameterLengthMismatch, "foo");
277 EXPECT_TRUE(reported_error);
278 }
279
// SUBSCRIBE for a track that does not exist locally is rejected with
// SUBSCRIBE_ERROR; after the track is added, the same SUBSCRIBE is answered
// with SUBSCRIBE_OK.
TEST_F(MoqtSessionTest, AddLocalTrack) {
  MoqtSubscribe request = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  // Request for track returns SUBSCRIBE_ERROR.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]),
                  MoqtMessageType::kSubscribeError);
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);

  // Add the track. Now Subscribe should succeed.
  MockLocalTrackVisitor local_track_visitor;
  session_.AddLocalTrack(FullTrackName("foo", "bar"),
                         MoqtForwardingPreference::kObject,
                         &local_track_visitor);
  // Reset the flag; otherwise the final check would pass without any
  // SUBSCRIBE_OK being written.
  correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);
}
324
// Announce() writes ANNOUNCE on the control stream; a matching ANNOUNCE_OK
// resolves the callback with the namespace and no error.
TEST_F(MoqtSessionTest, AnnounceWithOk) {
  testing::MockFunction<void(
      absl::string_view track_namespace,
      std::optional<MoqtAnnounceErrorReason> error_message)>
      announce_resolved_callback;
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
  // Starts false so the check below proves ANNOUNCE was actually written.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounce);
        return absl::OkStatus();
      });
  session_.Announce("foo", announce_resolved_callback.AsStdFunction());
  EXPECT_TRUE(correct_message);

  MoqtAnnounceOk ok = {
      /*track_namespace=*/"foo",
  };
  correct_message = false;
  EXPECT_CALL(announce_resolved_callback, Call(_, _))
      .WillOnce([&](absl::string_view track_namespace,
                    std::optional<MoqtAnnounceErrorReason> error) {
        correct_message = true;
        EXPECT_EQ(track_namespace, "foo");
        EXPECT_FALSE(error.has_value());
      });
  stream_input->OnAnnounceOkMessage(ok);
  EXPECT_TRUE(correct_message);
}
359
// Announce() writes ANNOUNCE on the control stream; a matching ANNOUNCE_ERROR
// resolves the callback with the error code and reason phrase.
TEST_F(MoqtSessionTest, AnnounceWithError) {
  testing::MockFunction<void(
      absl::string_view track_namespace,
      std::optional<MoqtAnnounceErrorReason> error_message)>
      announce_resolved_callback;
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
  // Starts false so the check below proves ANNOUNCE was actually written.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounce);
        return absl::OkStatus();
      });
  session_.Announce("foo", announce_resolved_callback.AsStdFunction());
  EXPECT_TRUE(correct_message);

  MoqtAnnounceError error = {
      /*track_namespace=*/"foo",
      /*error_code=*/MoqtAnnounceErrorCode::kInternalError,
      /*reason_phrase=*/"Test error",
  };
  correct_message = false;
  EXPECT_CALL(announce_resolved_callback, Call(_, _))
      .WillOnce([&](absl::string_view track_namespace,
                    std::optional<MoqtAnnounceErrorReason> error) {
        correct_message = true;
        EXPECT_EQ(track_namespace, "foo");
        ASSERT_TRUE(error.has_value());
        EXPECT_EQ(error->error_code, MoqtAnnounceErrorCode::kInternalError);
        EXPECT_EQ(error->reason_phrase, "Test error");
      });
  stream_input->OnAnnounceErrorMessage(error);
  EXPECT_TRUE(correct_message);
}
398
// HasSubscribers() is false for unknown tracks and for local tracks without
// subscribers, and becomes true once a peer SUBSCRIBE is accepted.
TEST_F(MoqtSessionTest, HasSubscribers) {
  MockLocalTrackVisitor local_track_visitor;
  FullTrackName ftn("foo", "bar");
  EXPECT_FALSE(session_.HasSubscribers(ftn));
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kGroup,
                         &local_track_visitor);
  EXPECT_FALSE(session_.HasSubscribers(ftn));

  // Peer subscribes.
  MoqtSubscribe request = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  // Starts false so the check below proves SUBSCRIBE_OK was actually written.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);
  EXPECT_TRUE(session_.HasSubscribers(ftn));
}
434
// A SUBSCRIBE starting before the track's current position consults the
// local track visitor's OnSubscribeForPast() and is still answered with
// SUBSCRIBE_OK when that returns std::nullopt.
TEST_F(MoqtSessionTest, SubscribeForPast) {
  MockLocalTrackVisitor local_track_visitor;
  FullTrackName ftn("foo", "bar");
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kObject,
                         &local_track_visitor);

  // Send Sequence (2, 0) so that next_sequence is set correctly.
  session_.PublishObject(ftn, 2, 0, 0, "foo", true);
  // Peer subscribes to (0, 0)
  MoqtSubscribe request = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  // Starts false so the check below proves SUBSCRIBE_OK was actually written.
  bool correct_message = false;
  EXPECT_CALL(local_track_visitor, OnSubscribeForPast(_))
      .WillOnce(Return(std::nullopt));
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribeOk);
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);
}
471
// SubscribeCurrentGroup() writes SUBSCRIBE; a matching SUBSCRIBE_OK is
// reported to the remote track visitor without an error message.
TEST_F(MoqtSessionTest, SubscribeWithOk) {
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  MockRemoteTrackVisitor remote_track_visitor;
  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe);
        return absl::OkStatus();
      });
  session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor, "");
  // SUBSCRIBE must actually have been written.
  EXPECT_TRUE(correct_message);

  MoqtSubscribeOk ok = {
      /*subscribe_id=*/0,
      /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0),
  };
  correct_message = false;
  EXPECT_CALL(remote_track_visitor, OnReply(_, _))
      .WillOnce([&](const FullTrackName& ftn,
                    std::optional<absl::string_view> error_message) {
        correct_message = true;
        EXPECT_EQ(ftn, FullTrackName("foo", "bar"));
        EXPECT_FALSE(error_message.has_value());
      });
  stream_input->OnSubscribeOkMessage(ok);
  EXPECT_TRUE(correct_message);
}
503
// SubscribeCurrentGroup() writes SUBSCRIBE; a matching SUBSCRIBE_ERROR is
// reported to the remote track visitor with its reason phrase.
TEST_F(MoqtSessionTest, SubscribeWithError) {
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  MockRemoteTrackVisitor remote_track_visitor;
  EXPECT_CALL(mock_session_, GetStreamById(_)).WillOnce(Return(&mock_stream));
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kSubscribe);
        return absl::OkStatus();
      });
  session_.SubscribeCurrentGroup("foo", "bar", &remote_track_visitor, "");
  // SUBSCRIBE must actually have been written.
  EXPECT_TRUE(correct_message);

  MoqtSubscribeError error = {
      /*subscribe_id=*/0,
      /*error_code=*/SubscribeErrorCode::kInvalidRange,
      /*reason_phrase=*/"deadbeef",
      /*track_alias=*/2,
  };
  correct_message = false;
  EXPECT_CALL(remote_track_visitor, OnReply(_, _))
      .WillOnce([&](const FullTrackName& ftn,
                    std::optional<absl::string_view> error_message) {
        correct_message = true;
        EXPECT_EQ(ftn, FullTrackName("foo", "bar"));
        // Dereferencing an empty optional is UB; fail cleanly instead.
        ASSERT_TRUE(error_message.has_value());
        EXPECT_EQ(*error_message, "deadbeef");
      });
  stream_input->OnSubscribeErrorMessage(error);
  EXPECT_TRUE(correct_message);
}
537
// An incoming ANNOUNCE is passed to the incoming-announce callback; when that
// returns std::nullopt (presumably "no error"), the session answers with
// ANNOUNCE_OK on the control stream.
TEST_F(MoqtSessionTest, ReplyToAnnounce) {
  StrictMock<webtransport::test::MockStream> control_stream;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &control_stream);
  MoqtAnnounce incoming_announce = {
      /*track_namespace=*/"foo",
  };
  bool announce_ok_written = false;
  EXPECT_CALL(session_callbacks_.incoming_announce_callback, Call("foo"))
      .WillOnce(Return(std::nullopt));
  EXPECT_CALL(control_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& /*options*/) {
        announce_ok_written = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kAnnounceOk);
        return absl::OkStatus();
      });
  control_input->OnAnnounceMessage(incoming_announce);
  EXPECT_TRUE(announce_ok_written);
}
558
// A complete OBJECT arriving for a known track alias is delivered to that
// track's visitor exactly once.
TEST_F(MoqtSessionTest, IncomingObject) {
  MockRemoteTrackVisitor remote_visitor;
  FullTrackName track_name("foo", "bar");
  std::string object_payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&session_, track_name, &remote_visitor,
                                     /*track_alias=*/2);
  MoqtObject incoming = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  EXPECT_CALL(remote_visitor, OnObjectFragment(_, _, _, _, _, _, _)).Times(1);
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(incoming, object_payload, true);
}
582
// With partial delivery disabled (the fixture default), the two 8-byte halves
// of a 16-byte object reach the visitor as a single callback.
TEST_F(MoqtSessionTest, IncomingPartialObject) {
  MockRemoteTrackVisitor remote_visitor;
  FullTrackName track_name("foo", "bar");
  std::string half_payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&session_, track_name, &remote_visitor,
                                     /*track_alias=*/2);
  MoqtObject incoming = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/16,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  EXPECT_CALL(remote_visitor, OnObjectFragment(_, _, _, _, _, _, _)).Times(1);
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(incoming, half_payload, false);
  // The second half completes the object.
  data_input->OnObjectMessage(incoming, half_payload, true);
}
607
// With deliver_partial_objects=true, each fragment of a 16-byte object is
// handed to the visitor as soon as it arrives (two callbacks).
TEST_F(MoqtSessionTest, IncomingPartialObjectNoBuffer) {
  MoqtSessionParameters unbuffered_parameters = {
      /*version=*/MoqtVersion::kDraft03,
      /*perspective=*/quic::Perspective::IS_CLIENT,
      /*using_webtrans=*/true,
      /*path=*/"",
      /*deliver_partial_objects=*/true,
  };
  MoqtSession unbuffered_session(&mock_session_, unbuffered_parameters,
                                 session_callbacks_.AsSessionCallbacks());
  MockRemoteTrackVisitor remote_visitor;
  FullTrackName track_name("foo", "bar");
  std::string half_payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&unbuffered_session, track_name,
                                     &remote_visitor, /*track_alias=*/2);
  MoqtObject incoming = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/16,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&unbuffered_session, &data_stream);

  EXPECT_CALL(remote_visitor, OnObjectFragment(_, _, _, _, _, _, _)).Times(2);
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(incoming, half_payload, false);
  // The second half completes the object.
  data_input->OnObjectMessage(incoming, half_payload, true);
}
641
// An OBJECT that arrives while the SUBSCRIBE is still outstanding is delivered
// to the subscriber's visitor, and the later SUBSCRIBE_OK is still reported.
TEST_F(MoqtSessionTest, ObjectBeforeSubscribeOk) {
  MockRemoteTrackVisitor remote_visitor;
  FullTrackName track_name("foo", "bar");
  std::string object_payload = "deadbeef";
  MoqtSubscribe pending_subscribe = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/track_name.track_namespace,
      /*track_name=*/track_name.track_name,
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
  };
  MoqtSessionPeer::AddActiveSubscribe(&session_, 1, pending_subscribe,
                                      &remote_visitor);
  MoqtObject early_object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  EXPECT_CALL(remote_visitor, OnObjectFragment(_, _, _, _, _, _, _))
      .WillOnce([&](const FullTrackName& delivered_name, uint64_t group,
                    uint64_t object_id, uint64_t /*object_send_order*/,
                    MoqtForwardingPreference /*forwarding_preference*/,
                    absl::string_view /*fragment*/, bool /*end_of_message*/) {
        EXPECT_EQ(delivered_name, track_name);
        EXPECT_EQ(group, early_object.group_id);
        EXPECT_EQ(object_id, early_object.object_id);
      });
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(early_object, object_payload, true);

  // SUBSCRIBE_OK arrives afterwards on the control stream.
  MoqtSubscribeOk subscribe_ok = {
      /*subscribe_id=*/1,
      /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0),
      /*largest_id=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_control_stream;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
  EXPECT_CALL(remote_visitor, OnReply(_, _)).Times(1);
  control_input->OnSubscribeOkMessage(subscribe_ok);
}
696
// Once an OBJECT has been delivered for an outstanding SUBSCRIBE, a
// subsequent SUBSCRIBE_ERROR is a protocol violation that closes the session.
TEST_F(MoqtSessionTest, ObjectBeforeSubscribeError) {
  MockRemoteTrackVisitor remote_visitor;
  FullTrackName track_name("foo", "bar");
  std::string object_payload = "deadbeef";
  MoqtSubscribe pending_subscribe = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/track_name.track_namespace,
      /*track_name=*/track_name.track_name,
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
  };
  MoqtSessionPeer::AddActiveSubscribe(&session_, 1, pending_subscribe,
                                      &remote_visitor);
  MoqtObject early_object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  EXPECT_CALL(remote_visitor, OnObjectFragment(_, _, _, _, _, _, _))
      .WillOnce([&](const FullTrackName& delivered_name, uint64_t group,
                    uint64_t object_id, uint64_t /*object_send_order*/,
                    MoqtForwardingPreference /*forwarding_preference*/,
                    absl::string_view /*fragment*/, bool /*end_of_message*/) {
        EXPECT_EQ(delivered_name, track_name);
        EXPECT_EQ(group, early_object.group_id);
        EXPECT_EQ(object_id, early_object.object_id);
      });
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(early_object, object_payload, true);

  // SUBSCRIBE_ERROR arrives afterwards on the control stream.
  MoqtSubscribeError late_error = {
      /*subscribe_id=*/1,
      /*error_code=*/SubscribeErrorCode::kRetryTrackAlias,
      /*reason_phrase=*/"foo",
      /*track_alias =*/3,
  };
  StrictMock<webtransport::test::MockStream> mock_control_stream;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Received SUBSCRIBE_ERROR after object"))
      .Times(1);
  control_input->OnSubscribeErrorMessage(late_error);
}
755
// Two early OBJECTs on the same pending subscription must share a forwarding
// preference; a change between them closes the session as a protocol
// violation.
TEST_F(MoqtSessionTest, TwoEarlyObjectsDifferentForwarding) {
  MockRemoteTrackVisitor remote_visitor;
  FullTrackName track_name("foo", "bar");
  std::string object_payload = "deadbeef";
  MoqtSubscribe pending_subscribe = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/track_name.track_namespace,
      /*track_name=*/track_name.track_name,
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
  };
  MoqtSessionPeer::AddActiveSubscribe(&session_, 1, pending_subscribe,
                                      &remote_visitor);
  MoqtObject early_object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> data_stream;
  std::unique_ptr<MoqtParserVisitor> data_input =
      MoqtSessionPeer::CreateUniStream(&session_, &data_stream);

  EXPECT_CALL(remote_visitor, OnObjectFragment(_, _, _, _, _, _, _))
      .WillOnce([&](const FullTrackName& delivered_name, uint64_t group,
                    uint64_t object_id, uint64_t /*object_send_order*/,
                    MoqtForwardingPreference /*forwarding_preference*/,
                    absl::string_view /*fragment*/, bool /*end_of_message*/) {
        EXPECT_EQ(delivered_name, track_name);
        EXPECT_EQ(group, early_object.group_id);
        EXPECT_EQ(object_id, early_object.object_id);
      });
  EXPECT_CALL(data_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  data_input->OnObjectMessage(early_object, object_payload, true);

  // Second object: next ID, but with a different forwarding preference.
  early_object.forwarding_preference = MoqtForwardingPreference::kObject;
  ++early_object.object_id;
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Forwarding preference changes mid-track"))
      .Times(1);
  data_input->OnObjectMessage(early_object, object_payload, true);
}
805
// An object with kGroup forwarding arrives before SUBSCRIBE_OK. A remote track
// whose forwarding preference is kObject is then created under the same alias.
// When SUBSCRIBE_OK arrives, the mismatch must close the session with a
// protocol violation.
TEST_F(MoqtSessionTest, EarlyObjectForwardingDoesNotMatchTrack) {
  MockRemoteTrackVisitor visitor;
  FullTrackName ftn("foo", "bar");
  std::string payload = "deadbeef";
  // Register subscription 1 (track alias 2) with |visitor|.
  MoqtSubscribe subscribe = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/ftn.track_namespace,
      /*track_name=*/ftn.track_name,
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
  };
  MoqtSessionPeer::AddActiveSubscribe(&session_, 1, subscribe, &visitor);
  // Early object: group 0, object 0, group forwarding.
  MoqtObject object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> object_stream =
      MoqtSessionPeer::CreateUniStream(&session_, &mock_stream);

  // The early object is still delivered to the visitor exactly once.
  EXPECT_CALL(visitor, OnObjectFragment(_, _, _, _, _, _, _))
      .WillOnce([&](const FullTrackName& full_track_name,
                    uint64_t group_sequence, uint64_t object_sequence,
                    uint64_t object_send_order,
                    MoqtForwardingPreference forwarding_preference,
                    absl::string_view payload, bool end_of_message) {
        EXPECT_EQ(full_track_name, ftn);
        EXPECT_EQ(group_sequence, object.group_id);
        EXPECT_EQ(object_sequence, object.object_id);
      });
  EXPECT_CALL(mock_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  object_stream->OnObjectMessage(object, payload, true);

  MoqtSessionPeer::CreateRemoteTrack(&session_, ftn, &visitor, 2);
  // The track already exists, and has a different forwarding preference.
  // NOTE(review): the bool result is ignored here; presumably the first call
  // records kObject as the track's preference. Confirm in RemoteTrack.
  MoqtSessionPeer::remote_track(&session_, 2)
      .CheckForwardingPreference(MoqtForwardingPreference::kObject);

  // SUBSCRIBE_OK arrives
  MoqtSubscribeOk ok = {
      /*subscribe_id=*/1,
      /*expires=*/quic::QuicTimeDelta::FromMilliseconds(0),
      /*largest_id=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_control_stream;
  std::unique_ptr<MoqtParserVisitor> control_stream =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_control_stream);
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Forwarding preference different in early objects"))
      .Times(1);
  control_stream->OnSubscribeOkMessage(ok);
}
868
// Publishes to a local track that has one subscription. An object the
// subscription does not cover writes nothing. An in-window object opens a new
// outgoing unidirectional stream and writes a message whose leading bytes
// match the expected object header.
TEST_F(MoqtSessionTest, CreateUniStreamAndSend) {
  StrictMock<webtransport::test::MockStream> mock_stream;
  FullTrackName ftn("foo", "bar");
  MockLocalTrackVisitor track_visitor;
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kObject,
                         &track_visitor);
  MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0);

  // Object (4, 1) is not covered by the subscription, so nothing is written.
  // The call only updates next_sequence.
  EXPECT_CALL(mock_stream, Writev(_, _)).Times(0);
  session_.PublishObject(ftn, 4, 1, 0, "deadbeef", true);
  EXPECT_EQ(MoqtSessionPeer::next_sequence(&session_, ftn), FullSequence(4, 2));

  // Publish in window.
  EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
      .WillOnce(Return(true));
  EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
      .WillOnce(Return(&mock_stream));
  EXPECT_CALL(mock_stream, SetVisitor(_)).Times(1);
  EXPECT_CALL(mock_stream, GetStreamId())
      .WillRepeatedly(Return(kOutgoingUniStreamId));
  // Send on the stream
  EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
      .WillOnce(Return(&mock_stream));
  bool correct_message = false;
  // Expected first six message fields (each encodes as a single byte here).
  constexpr uint8_t kExpectedMessage[] = {0x00, 0x00, 0x02, 0x05, 0x00, 0x00};
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        // Check the size before comparing: memcmp on a shorter (or missing)
        // buffer would read out of bounds.
        correct_message =
            !data.empty() && data[0].size() >= sizeof(kExpectedMessage) &&
            memcmp(data[0].data(), kExpectedMessage,
                   sizeof(kExpectedMessage)) == 0;
        return absl::OkStatus();
      });
  session_.PublishObject(ftn, 5, 0, 0, "deadbeef", true);
  EXPECT_TRUE(correct_message);
}
906
907 // TODO: Test operation with multiple streams.
908
909 // Error cases
910
// If the transport refuses to open another outgoing unidirectional stream,
// PublishObject reports failure.
TEST_F(MoqtSessionTest, CannotOpenUniStream) {
  FullTrackName ftn("foo", "bar");
  MockLocalTrackVisitor track_visitor;
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kObject,
                         &track_visitor);
  MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0);
  EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
      .WillOnce(Return(false));
  EXPECT_FALSE(session_.PublishObject(ftn, 5, 0, 0, "deadbeef", true));
}
923
// Opening the outgoing stream succeeds, but looking it up again by ID returns
// null; PublishObject must then report failure.
TEST_F(MoqtSessionTest, GetStreamByIdFails) {
  StrictMock<webtransport::test::MockStream> outgoing_stream;
  FullTrackName full_name("foo", "bar");
  MockLocalTrackVisitor local_visitor;
  session_.AddLocalTrack(full_name, MoqtForwardingPreference::kObject,
                         &local_visitor);
  MoqtSessionPeer::AddSubscription(&session_, full_name, 0, 2, 5, 0);

  // The transport hands out a stream...
  EXPECT_CALL(mock_session_, CanOpenNextOutgoingUnidirectionalStream())
      .WillOnce(Return(true));
  EXPECT_CALL(mock_session_, OpenOutgoingUnidirectionalStream())
      .WillOnce(Return(&outgoing_stream));
  EXPECT_CALL(outgoing_stream, SetVisitor(_)).Times(1);
  EXPECT_CALL(outgoing_stream, GetStreamId())
      .WillRepeatedly(Return(kOutgoingUniStreamId));
  // ...but cannot find it again by its ID.
  EXPECT_CALL(mock_session_, GetStreamById(kOutgoingUniStreamId))
      .WillOnce(Return(nullptr));

  EXPECT_FALSE(session_.PublishObject(full_name, 5, 0, 0, "deadbeef", true));
}
942
// A local track already has a subscription using track alias 2. A new
// SUBSCRIBE for the same track that proposes alias 3 must be answered with
// SUBSCRIBE_ERROR on the control stream.
TEST_F(MoqtSessionTest, SubscribeProposesBadTrackAlias) {
  MockLocalTrackVisitor local_track_visitor;
  FullTrackName ftn("foo", "bar");
  session_.AddLocalTrack(ftn, MoqtForwardingPreference::kGroup,
                         &local_track_visitor);
  MoqtSessionPeer::AddSubscription(&session_, ftn, 0, 2, 5, 0);

  // Peer subscribes.
  MoqtSubscribe request = {
      /*subscribe_id=*/1,
      /*track_alias=*/3,  // Doesn't match 2.
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&session_, &mock_stream);
  // Start false so the final EXPECT_TRUE fails unless a SUBSCRIBE_ERROR was
  // actually written. Starting at true would make the check vacuous.
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        // Compare the optional directly: an unparseable or empty write yields
        // false instead of dereferencing nullopt.
        correct_message =
            !data.empty() &&
            ExtractMessageType(data[0]) == MoqtMessageType::kSubscribeError;
        return absl::OkStatus();
      });
  stream_input->OnSubscribeMessage(request);
  EXPECT_TRUE(correct_message);
}
977
// On OnSessionReady, a client opens an outgoing bidirectional control stream
// and writes CLIENT_SETUP on it. A later peer-initiated bidirectional stream
// must close the session with a protocol violation and report termination via
// the session_terminated callback.
TEST_F(MoqtSessionTest, OneBidirectionalStreamClient) {
  StrictMock<webtransport::test::MockStream> mock_stream;
  EXPECT_CALL(mock_session_, OpenOutgoingBidirectionalStream())
      .WillOnce(Return(&mock_stream));
  std::unique_ptr<webtransport::StreamVisitor> visitor;
  // Save a reference to MoqtSession::Stream
  // (the session hands ownership of its stream visitor to the stream; keep it
  // here so visitor() below can return it).
  EXPECT_CALL(mock_stream, SetVisitor(_))
      .WillOnce([&](std::unique_ptr<webtransport::StreamVisitor> new_visitor) {
        visitor = std::move(new_visitor);
      });
  EXPECT_CALL(mock_stream, GetStreamId())
      .WillOnce(Return(webtransport::StreamId(4)));
  EXPECT_CALL(mock_session_, GetStreamById(4)).WillOnce(Return(&mock_stream));
  bool correct_message = false;
  EXPECT_CALL(mock_stream, visitor()).WillOnce([&] { return visitor.get(); });
  // The first write on the control stream must be CLIENT_SETUP.
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        correct_message = true;
        EXPECT_EQ(*ExtractMessageType(data[0]), MoqtMessageType::kClientSetup);
        return absl::OkStatus();
      });
  session_.OnSessionReady();
  EXPECT_TRUE(correct_message);

  // Peer tries to open a bidi stream.
  bool reported_error = false;
  EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream())
      .WillOnce(Return(&mock_stream));
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Bidirectional stream already open"))
      .Times(1);
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_))
      .WillOnce([&](absl::string_view error_message) {
        reported_error = (error_message == "Bidirectional stream already open");
      });
  session_.OnIncomingBidirectionalStreamAvailable();
  EXPECT_TRUE(reported_error);
}
1018
// A server receives CLIENT_SETUP on its control stream and answers with
// SERVER_SETUP. A further peer-initiated bidirectional stream must close the
// session with a protocol violation and report termination.
TEST_F(MoqtSessionTest, OneBidirectionalStreamServer) {
  MoqtSessionParameters server_parameters = {
      /*version=*/MoqtVersion::kDraft03,
      /*perspective=*/quic::Perspective::IS_SERVER,
      /*using_webtrans=*/true,
      /*path=*/"",
      /*deliver_partial_objects=*/false,
  };
  MoqtSession server_session(&mock_session_, server_parameters,
                             session_callbacks_.AsSessionCallbacks());
  StrictMock<webtransport::test::MockStream> mock_stream;
  std::unique_ptr<MoqtParserVisitor> stream_input =
      MoqtSessionPeer::CreateControlStream(&server_session, &mock_stream);
  MoqtClientSetup setup = {
      /*supported_versions=*/{MoqtVersion::kDraft03},
      /*role=*/MoqtRole::kPubSub,
      /*path=*/std::nullopt,
  };
  bool correct_message = false;
  EXPECT_CALL(mock_stream, Writev(_, _))
      .WillOnce([&](absl::Span<const absl::string_view> data,
                    const quiche::StreamWriteOptions& options) {
        // Comparing the optional avoids dereferencing nullopt on a bad write.
        correct_message =
            !data.empty() &&
            ExtractMessageType(data[0]) == MoqtMessageType::kServerSetup;
        return absl::OkStatus();
      });
  EXPECT_CALL(mock_stream, GetStreamId()).WillOnce(Return(0));
  EXPECT_CALL(session_callbacks_.session_established_callback, Call()).Times(1);
  stream_input->OnClientSetupMessage(setup);
  // Without this check, a missing or wrong SERVER_SETUP would go unnoticed.
  EXPECT_TRUE(correct_message);

  // Peer tries to open a bidi stream.
  bool reported_error = false;
  EXPECT_CALL(mock_session_, AcceptIncomingBidirectionalStream())
      .WillOnce(Return(&mock_stream));
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Bidirectional stream already open"))
      .Times(1);
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_))
      .WillOnce([&](absl::string_view error_message) {
        reported_error = (error_message == "Bidirectional stream already open");
      });
  server_session.OnIncomingBidirectionalStreamAvailable();
  EXPECT_TRUE(reported_error);
}
1064
// An incoming UNSUBSCRIBE with subscribe_id 0 removes the subscription added
// below, leaving the local track with no subscribers.
TEST_F(MoqtSessionTest, ReceiveUnsubscribe) {
  FullTrackName full_name("foo", "bar");
  MockLocalTrackVisitor local_visitor;
  session_.AddLocalTrack(full_name, MoqtForwardingPreference::kTrack,
                         &local_visitor);
  MoqtSessionPeer::AddSubscription(&session_, full_name, 0, 1, 3, 4);
  EXPECT_TRUE(session_.HasSubscribers(full_name));

  StrictMock<webtransport::test::MockStream> control_mock;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &control_mock);
  MoqtUnsubscribe request{/*subscribe_id=*/0};
  control_input->OnUnsubscribeMessage(request);

  EXPECT_FALSE(session_.HasSubscribers(full_name));
}
1080
// Publishing an in-window object on a datagram-forwarded track sends one
// datagram containing the object header followed by the payload.
TEST_F(MoqtSessionTest, SendDatagram) {
  FullTrackName full_name("foo", "bar");
  MockLocalTrackVisitor local_visitor;
  session_.AddLocalTrack(full_name, MoqtForwardingPreference::kDatagram,
                         &local_visitor);
  MoqtSessionPeer::AddSubscription(&session_, full_name, 0, 2, 5, 0);

  // Publish in window. Six header bytes, then the ASCII bytes of "deadbeef".
  bool datagram_matches = false;
  uint8_t kExpectedMessage[] = {
      0x01, 0x00, 0x02, 0x05, 0x00, 0x00, 0x64,
      0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66,
  };
  EXPECT_CALL(mock_session_, SendOrQueueDatagram(_))
      .WillOnce([&](absl::string_view datagram) {
        // Only called once, so assigning directly matches the original's
        // "stay false on size mismatch" behaviour.
        datagram_matches =
            datagram.size() == sizeof(kExpectedMessage) &&
            memcmp(datagram.data(), kExpectedMessage,
                   sizeof(kExpectedMessage)) == 0;
        return webtransport::DatagramStatus(
            webtransport::DatagramStatusCode::kSuccess, "");
      });
  session_.PublishObject(full_name, 5, 0, 0, "deadbeef", true);
  EXPECT_TRUE(datagram_matches);
}
1106
// A datagram for the remote track with alias 2 is delivered to that track's
// visitor as one complete object fragment.
TEST_F(MoqtSessionTest, ReceiveDatagram) {
  MockRemoteTrackVisitor track_visitor;
  FullTrackName full_name("foo", "bar");
  std::string payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&session_, full_name, &track_visitor, 2);
  MoqtObject expected = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kDatagram,
      /*payload_length=*/8,
  };
  // Six header bytes, then the ASCII bytes of "deadbeef".
  char datagram[] = {0x01, 0x01, 0x02, 0x00, 0x00, 0x00, 0x64,
                     0x65, 0x61, 0x64, 0x62, 0x65, 0x65, 0x66};
  EXPECT_CALL(track_visitor,
              OnObjectFragment(full_name, expected.group_id, expected.object_id,
                               expected.object_send_order,
                               expected.forwarding_preference, payload, true))
      .Times(1);
  const absl::string_view wire_bytes(datagram, sizeof(datagram));
  session_.OnDatagramReceived(wire_bytes);
}
1130
// Objects on one remote track must keep a single forwarding preference. The
// first object (kGroup) is delivered; a following object using kTrack closes
// the session with a protocol violation.
TEST_F(MoqtSessionTest, ForwardingPreferenceMismatch) {
  MockRemoteTrackVisitor track_visitor;
  FullTrackName full_name("foo", "bar");
  std::string payload = "deadbeef";
  MoqtSessionPeer::CreateRemoteTrack(&session_, full_name, &track_visitor, 2);
  MoqtObject object = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*group_sequence=*/0,
      /*object_sequence=*/0,
      /*object_send_order=*/0,
      /*forwarding_preference=*/MoqtForwardingPreference::kGroup,
      /*payload_length=*/8,
  };
  StrictMock<webtransport::test::MockStream> incoming_stream;
  std::unique_ptr<MoqtParserVisitor> stream_visitor =
      MoqtSessionPeer::CreateUniStream(&session_, &incoming_stream);

  // First object is accepted and handed to the visitor.
  EXPECT_CALL(track_visitor, OnObjectFragment(_, _, _, _, _, _, _)).Times(1);
  EXPECT_CALL(incoming_stream, GetStreamId())
      .WillRepeatedly(Return(kIncomingUniStreamId));
  stream_visitor->OnObjectMessage(object, payload, true);

  // Next object on the same track switches its forwarding preference.
  object.forwarding_preference = MoqtForwardingPreference::kTrack;
  object.object_id += 1;
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Forwarding preference changes mid-track"))
      .Times(1);
  stream_visitor->OnObjectMessage(object, payload, true);
}
1161
// Announcing to a peer whose role is kPublisher resolves the announce callback
// exactly once.
TEST_F(MoqtSessionTest, AnnounceToPublisher) {
  MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kPublisher);
  using AnnounceResolvedSignature =
      void(absl::string_view track_namespace,
           std::optional<MoqtAnnounceErrorReason> error_message);
  testing::MockFunction<AnnounceResolvedSignature> on_announce_resolved;
  EXPECT_CALL(on_announce_resolved, Call(_, _)).Times(1);
  session_.Announce("foo", on_announce_resolved.AsStdFunction());
}
1171
// A SUBSCRIBE from a peer whose role is kPublisher closes the session with a
// protocol violation and invokes the session_terminated callback.
TEST_F(MoqtSessionTest, SubscribeFromPublisher) {
  MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kPublisher);
  StrictMock<webtransport::test::MockStream> control_mock;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &control_mock);
  MoqtSubscribe subscribe = {
      /*subscribe_id=*/1,
      /*track_alias=*/2,
      /*track_namespace=*/"foo",
      /*track_name=*/"bar",
      /*start_group=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*start_object=*/MoqtSubscribeLocation(true, static_cast<uint64_t>(0)),
      /*end_group=*/std::nullopt,
      /*end_object=*/std::nullopt,
      /*authorization_info=*/std::nullopt,
  };

  // Request for track returns Protocol Violation.
  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Received SUBSCRIBE from publisher"))
      .Times(1);
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)).Times(1);
  control_input->OnSubscribeMessage(subscribe);
}
1196
// An ANNOUNCE from a peer whose role is kSubscriber closes the session with a
// protocol violation and invokes the session_terminated callback.
TEST_F(MoqtSessionTest, AnnounceFromSubscriber) {
  MoqtSessionPeer::set_peer_role(&session_, MoqtRole::kSubscriber);
  StrictMock<webtransport::test::MockStream> control_mock;
  std::unique_ptr<MoqtParserVisitor> control_input =
      MoqtSessionPeer::CreateControlStream(&session_, &control_mock);
  MoqtAnnounce announcement = {/*track_namespace=*/"foo"};

  EXPECT_CALL(mock_session_,
              CloseSession(static_cast<uint64_t>(MoqtError::kProtocolViolation),
                           "Received ANNOUNCE from Subscriber"))
      .Times(1);
  EXPECT_CALL(session_callbacks_.session_terminated_callback, Call(_)).Times(1);
  control_input->OnAnnounceMessage(announcement);
}
1212
1213 // TODO: Cover more error cases in the above
1214
1215 } // namespace test
1216
1217 } // namespace moqt
1218