1 // Copyright 2023 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <cstdint>
6 #include <memory>
7 #include <optional>
8 #include <string>
9
10 #include "absl/strings/string_view.h"
11 #include "quiche/quic/core/crypto/quic_compressed_certs_cache.h"
12 #include "quiche/quic/core/crypto/quic_crypto_client_config.h"
13 #include "quiche/quic/core/crypto/quic_crypto_server_config.h"
14 #include "quiche/quic/core/crypto/quic_random.h"
15 #include "quiche/quic/core/quic_config.h"
16 #include "quiche/quic/core/quic_generic_session.h"
17 #include "quiche/quic/core/quic_types.h"
18 #include "quiche/quic/moqt/moqt_messages.h"
19 #include "quiche/quic/moqt/moqt_session.h"
20 #include "quiche/quic/moqt/moqt_track.h"
21 #include "quiche/quic/moqt/tools/moqt_mock_visitor.h"
22 #include "quiche/quic/test_tools/crypto_test_utils.h"
23 #include "quiche/quic/test_tools/simulator/simulator.h"
24 #include "quiche/quic/test_tools/simulator/test_harness.h"
25 #include "quiche/common/platform/api/quiche_test.h"
26
27 namespace moqt::test {
28
29 namespace {
30
31 using ::quic::simulator::Simulator;
32 using ::testing::_;
33 using ::testing::Assign;
34 using ::testing::Return;
35
36 class ClientEndpoint : public quic::simulator::QuicEndpointWithConnection {
37 public:
ClientEndpoint(Simulator * simulator,const std::string & name,const std::string & peer_name,MoqtVersion version)38 ClientEndpoint(Simulator* simulator, const std::string& name,
39 const std::string& peer_name, MoqtVersion version)
40 : QuicEndpointWithConnection(simulator, name, peer_name,
41 quic::Perspective::IS_CLIENT,
42 quic::GetQuicVersionsForGenericSession()),
43 crypto_config_(
44 quic::test::crypto_test_utils::ProofVerifierForTesting()),
45 quic_session_(connection_.get(), false, nullptr, quic::QuicConfig(),
46 "test.example.com", 443, "moqt", &session_,
47 /*visitor_owned=*/false, nullptr, &crypto_config_),
48 session_(
49 &quic_session_,
50 MoqtSessionParameters{.version = version,
51 .perspective = quic::Perspective::IS_CLIENT,
52 .using_webtrans = false,
53 .deliver_partial_objects = false},
54 callbacks_.AsSessionCallbacks()) {
55 quic_session_.Initialize();
56 }
57
session()58 MoqtSession* session() { return &session_; }
quic_session()59 quic::QuicGenericClientSession* quic_session() { return &quic_session_; }
established_callback()60 testing::MockFunction<void()>& established_callback() {
61 return callbacks_.session_established_callback;
62 }
terminated_callback()63 testing::MockFunction<void(absl::string_view)>& terminated_callback() {
64 return callbacks_.session_terminated_callback;
65 }
callbacks()66 MockSessionCallbacks& callbacks() { return callbacks_; }
67
68 private:
69 MockSessionCallbacks callbacks_;
70 quic::QuicCryptoClientConfig crypto_config_;
71 quic::QuicGenericClientSession quic_session_;
72 MoqtSession session_;
73 };
74
75 class ServerEndpoint : public quic::simulator::QuicEndpointWithConnection {
76 public:
ServerEndpoint(Simulator * simulator,const std::string & name,const std::string & peer_name,MoqtVersion version)77 ServerEndpoint(Simulator* simulator, const std::string& name,
78 const std::string& peer_name, MoqtVersion version)
79 : QuicEndpointWithConnection(simulator, name, peer_name,
80 quic::Perspective::IS_SERVER,
81 quic::GetQuicVersionsForGenericSession()),
82 compressed_certs_cache_(
83 quic::QuicCompressedCertsCache::kQuicCompressedCertsCacheSize),
84 crypto_config_(quic::QuicCryptoServerConfig::TESTING,
85 quic::QuicRandom::GetInstance(),
86 quic::test::crypto_test_utils::ProofSourceForTesting(),
87 quic::KeyExchangeSource::Default()),
88 quic_session_(connection_.get(), false, nullptr, quic::QuicConfig(),
89 "moqt", &session_,
90 /*visitor_owned=*/false, nullptr, &crypto_config_,
91 &compressed_certs_cache_),
92 session_(
93 &quic_session_,
94 MoqtSessionParameters{.version = version,
95 .perspective = quic::Perspective::IS_SERVER,
96 .using_webtrans = false,
97 .deliver_partial_objects = false},
98 callbacks_.AsSessionCallbacks()) {
99 quic_session_.Initialize();
100 }
101
session()102 MoqtSession* session() { return &session_; }
established_callback()103 testing::MockFunction<void()>& established_callback() {
104 return callbacks_.session_established_callback;
105 }
terminated_callback()106 testing::MockFunction<void(absl::string_view)>& terminated_callback() {
107 return callbacks_.session_terminated_callback;
108 }
callbacks()109 MockSessionCallbacks& callbacks() { return callbacks_; }
110
111 private:
112 MockSessionCallbacks callbacks_;
113 quic::QuicCompressedCertsCache compressed_certs_cache_;
114 quic::QuicCryptoServerConfig crypto_config_;
115 quic::QuicGenericServerSession quic_session_;
116 MoqtSession session_;
117 };
118
119 class MoqtIntegrationTest : public quiche::test::QuicheTest {
120 public:
CreateDefaultEndpoints()121 void CreateDefaultEndpoints() {
122 client_ = std::make_unique<ClientEndpoint>(
123 &test_harness_.simulator(), "Client", "Server", MoqtVersion::kDraft03);
124 server_ = std::make_unique<ServerEndpoint>(
125 &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft03);
126 test_harness_.set_client(client_.get());
127 test_harness_.set_server(server_.get());
128 }
129
WireUpEndpoints()130 void WireUpEndpoints() { test_harness_.WireUpEndpoints(); }
131
EstablishSession()132 void EstablishSession() {
133 CreateDefaultEndpoints();
134 WireUpEndpoints();
135
136 client_->quic_session()->CryptoConnect();
137 bool client_established = false;
138 bool server_established = false;
139 EXPECT_CALL(client_->established_callback(), Call())
140 .WillOnce(Assign(&client_established, true));
141 EXPECT_CALL(server_->established_callback(), Call())
142 .WillOnce(Assign(&server_established, true));
143 bool success = test_harness_.RunUntilWithDefaultTimeout(
144 [&]() { return client_established && server_established; });
145 QUICHE_CHECK(success);
146 }
147
148 protected:
149 quic::simulator::TestHarness test_harness_;
150
151 std::unique_ptr<ClientEndpoint> client_;
152 std::unique_ptr<ServerEndpoint> server_;
153 };
154
// A client and server using the same MoQT version both report an
// established session.
TEST_F(MoqtIntegrationTest, Handshake) {
  CreateDefaultEndpoints();
  WireUpEndpoints();

  bool client_established = false;
  bool server_established = false;
  // Register expectations before starting the handshake: gMock requires an
  // expectation to exist before the mock can be called.
  EXPECT_CALL(client_->established_callback(), Call())
      .WillOnce(Assign(&client_established, true));
  EXPECT_CALL(server_->established_callback(), Call())
      .WillOnce(Assign(&server_established, true));

  client_->quic_session()->CryptoConnect();
  bool success = test_harness_.RunUntilWithDefaultTimeout(
      [&]() { return client_established && server_established; });
  EXPECT_TRUE(success);
}
170
// When the client uses a MoQT version the server does not (the client uses
// kUnrecognizedVersionForTests, the server kDraft03), neither side
// establishes a session and both report termination.
TEST_F(MoqtIntegrationTest, VersionMismatch) {
  client_ = std::make_unique<ClientEndpoint>(
      &test_harness_.simulator(), "Client", "Server",
      MoqtVersion::kUnrecognizedVersionForTests);
  server_ = std::make_unique<ServerEndpoint>(
      &test_harness_.simulator(), "Server", "Client", MoqtVersion::kDraft03);
  test_harness_.set_client(client_.get());
  test_harness_.set_server(server_.get());
  WireUpEndpoints();

  bool client_terminated = false;
  bool server_terminated = false;
  // Register all expectations before starting the handshake: gMock requires
  // an expectation to exist before the mock can be called.
  EXPECT_CALL(client_->established_callback(), Call()).Times(0);
  EXPECT_CALL(server_->established_callback(), Call()).Times(0);
  EXPECT_CALL(client_->terminated_callback(), Call(_))
      .WillOnce(Assign(&client_terminated, true));
  EXPECT_CALL(server_->terminated_callback(), Call(_))
      .WillOnce(Assign(&server_terminated, true));

  client_->quic_session()->CryptoConnect();
  bool success = test_harness_.RunUntilWithDefaultTimeout(
      [&]() { return client_terminated && server_terminated; });
  EXPECT_TRUE(success);
}
194
// When the server's incoming-announce callback returns no error for "foo",
// the client's announce callback fires for "foo" without an error.
TEST_F(MoqtIntegrationTest, AnnounceSuccess) {
  EstablishSession();
  EXPECT_CALL(server_->callbacks().incoming_announce_callback, Call("foo"))
      .WillOnce(Return(std::nullopt));

  testing::MockFunction<void(
      absl::string_view track_namespace,
      std::optional<MoqtAnnounceErrorReason> error_message)>
      announce_callback;
  bool matches = false;
  // Set the expectation before handing the callback to Announce(): gMock
  // requires an expectation to exist before the mock can be called.
  EXPECT_CALL(announce_callback, Call(_, _))
      .WillOnce([&](absl::string_view track_namespace,
                    std::optional<MoqtAnnounceErrorReason> error) {
        matches = true;
        EXPECT_EQ(track_namespace, "foo");
        EXPECT_FALSE(error.has_value());
      });
  client_->session()->Announce("foo", announce_callback.AsStdFunction());

  bool success =
      test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; });
  EXPECT_TRUE(success);
}
216
// After the client's ANNOUNCE of "foo" succeeds, the server subscribes to
// "/catalog" in that namespace from within the client's announce callback.
// The test passes once the server's visitor receives a reply to that
// subscription.
// NOTE(review): OnReply(_, _) also matches an error reply, and no local
// track is registered on the client here. This may only prove that *some*
// reply arrives; confirm that is the intent.
TEST_F(MoqtIntegrationTest, AnnounceSuccessSubscribeInResponse) {
  EstablishSession();
  EXPECT_CALL(server_->callbacks().incoming_announce_callback, Call("foo"))
      .WillOnce(Return(std::nullopt));

  MockRemoteTrackVisitor server_visitor;
  testing::MockFunction<void(
      absl::string_view track_namespace,
      std::optional<MoqtAnnounceErrorReason> error_message)>
      announce_callback;
  bool matches = false;
  // Register every expectation before Announce() starts the exchange: gMock
  // requires an expectation to exist before the mock can be called.
  EXPECT_CALL(announce_callback, Call(_, _))
      .WillOnce([&](absl::string_view track_namespace,
                    std::optional<MoqtAnnounceErrorReason> error) {
        EXPECT_EQ(track_namespace, "foo");
        EXPECT_FALSE(error.has_value());
        server_->session()->SubscribeCurrentGroup(track_namespace, "/catalog",
                                                  &server_visitor);
      });
  EXPECT_CALL(server_visitor, OnReply(_, _)).WillOnce([&]() {
    matches = true;
  });
  client_->session()->Announce("foo", announce_callback.AsStdFunction());

  bool success =
      test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; });
  EXPECT_TRUE(success);
}
243
// The client's announce callback reports kAnnounceNotSupported for "foo".
// NOTE(review): no expectation is set on the server's
// incoming_announce_callback here, so the rejection presumably comes from the
// default behavior of MockSessionCallbacks; confirm.
TEST_F(MoqtIntegrationTest, AnnounceFailure) {
  EstablishSession();

  testing::MockFunction<void(
      absl::string_view track_namespace,
      std::optional<MoqtAnnounceErrorReason> error_message)>
      announce_callback;
  bool matches = false;
  // Set the expectation before handing the callback to Announce(): gMock
  // requires an expectation to exist before the mock can be called.
  EXPECT_CALL(announce_callback, Call(_, _))
      .WillOnce([&](absl::string_view track_namespace,
                    std::optional<MoqtAnnounceErrorReason> error) {
        matches = true;
        EXPECT_EQ(track_namespace, "foo");
        ASSERT_TRUE(error.has_value());
        EXPECT_EQ(error->error_code,
                  MoqtAnnounceErrorCode::kAnnounceNotSupported);
      });
  client_->session()->Announce("foo", announce_callback.AsStdFunction());

  bool success =
      test_harness_.RunUntilWithDefaultTimeout([&]() { return matches; });
  EXPECT_TRUE(success);
}
265
// An absolute-range SUBSCRIBE (arguments 0, 0) for a track that the server
// has added as a local track is answered with no error reason.
TEST_F(MoqtIntegrationTest, SubscribeAbsoluteOk) {
  EstablishSession();

  FullTrackName track("foo", "bar");
  MockLocalTrackVisitor publisher_visitor;
  MockRemoteTrackVisitor subscriber_visitor;
  server_->session()->AddLocalTrack(track, MoqtForwardingPreference::kObject,
                                    &publisher_visitor);

  // The reply must carry no error reason.
  const std::optional<absl::string_view> no_error = std::nullopt;
  bool got_reply = false;
  EXPECT_CALL(subscriber_visitor, OnReply(track, no_error))
      .WillOnce([&]() { got_reply = true; });
  client_->session()->SubscribeAbsolute(track.track_namespace,
                                        track.track_name, 0, 0,
                                        &subscriber_visitor);

  EXPECT_TRUE(test_harness_.RunUntilWithDefaultTimeout(
      [&]() { return got_reply; }));
}
284
// A relative-range SUBSCRIBE (arguments 10, 10) for a track that the server
// has added as a local track is answered with no error reason.
TEST_F(MoqtIntegrationTest, SubscribeRelativeOk) {
  EstablishSession();

  FullTrackName track("foo", "bar");
  MockLocalTrackVisitor publisher_visitor;
  MockRemoteTrackVisitor subscriber_visitor;
  server_->session()->AddLocalTrack(track, MoqtForwardingPreference::kObject,
                                    &publisher_visitor);

  // The reply must carry no error reason.
  const std::optional<absl::string_view> no_error = std::nullopt;
  bool got_reply = false;
  EXPECT_CALL(subscriber_visitor, OnReply(track, no_error))
      .WillOnce([&]() { got_reply = true; });
  client_->session()->SubscribeRelative(track.track_namespace,
                                        track.track_name, 10, 10,
                                        &subscriber_visitor);

  EXPECT_TRUE(test_harness_.RunUntilWithDefaultTimeout(
      [&]() { return got_reply; }));
}
303
// A current-group SUBSCRIBE for a track that the server has added as a local
// track is answered with no error reason.
TEST_F(MoqtIntegrationTest, SubscribeCurrentGroupOk) {
  EstablishSession();

  FullTrackName track("foo", "bar");
  MockLocalTrackVisitor publisher_visitor;
  MockRemoteTrackVisitor subscriber_visitor;
  server_->session()->AddLocalTrack(track, MoqtForwardingPreference::kObject,
                                    &publisher_visitor);

  // The reply must carry no error reason.
  const std::optional<absl::string_view> no_error = std::nullopt;
  bool got_reply = false;
  EXPECT_CALL(subscriber_visitor, OnReply(track, no_error))
      .WillOnce([&]() { got_reply = true; });
  client_->session()->SubscribeCurrentGroup(track.track_namespace,
                                            track.track_name,
                                            &subscriber_visitor);

  EXPECT_TRUE(test_harness_.RunUntilWithDefaultTimeout(
      [&]() { return got_reply; }));
}
322
// Subscribing to a track the server never added is answered with the reason
// "Track does not exist".
TEST_F(MoqtIntegrationTest, SubscribeError) {
  EstablishSession();

  FullTrackName missing_track("foo", "bar");
  MockRemoteTrackVisitor subscriber_visitor;

  const std::optional<absl::string_view> expected_error =
      "Track does not exist";
  bool got_error_reply = false;
  EXPECT_CALL(subscriber_visitor, OnReply(missing_track, expected_error))
      .WillOnce([&]() { got_error_reply = true; });
  client_->session()->SubscribeRelative(missing_track.track_namespace,
                                        missing_track.track_name, 10, 10,
                                        &subscriber_visitor);

  EXPECT_TRUE(test_harness_.RunUntilWithDefaultTimeout(
      [&]() { return got_error_reply; }));
}
338
339 } // namespace
340
341 } // namespace moqt::test
342