1 // Copyright (c) 2023 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_
6 #define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_
7
8 #include <cstddef>
9 #include <cstdint>
10 #include <cstring>
11 #include <memory>
12 #include <optional>
13 #include <vector>
14
15 #include "absl/strings/string_view.h"
16 #include "absl/types/variant.h"
17 #include "quiche/quic/core/quic_data_reader.h"
18 #include "quiche/quic/core/quic_data_writer.h"
19 #include "quiche/quic/core/quic_time.h"
20 #include "quiche/quic/moqt/moqt_messages.h"
21 #include "quiche/quic/platform/api/quic_logging.h"
22 #include "quiche/quic/platform/api/quic_test.h"
23 #include "quiche/common/platform/api/quiche_export.h"
24 #include "quiche/common/quiche_endian.h"
25
26 namespace moqt::test {
27
28 // Base class containing a wire image and the corresponding structured
29 // representation of an example of each message. It allows parser and framer
30 // tests to iterate through all message types without much specialized code.
31 class QUICHE_NO_EXPORT TestMessageBase {
32 public:
TestMessageBase(MoqtMessageType message_type)33 TestMessageBase(MoqtMessageType message_type) : message_type_(message_type) {}
34 virtual ~TestMessageBase() = default;
message_type()35 MoqtMessageType message_type() const { return message_type_; }
36
37 typedef absl::variant<MoqtClientSetup, MoqtServerSetup, MoqtObject,
38 MoqtSubscribe, MoqtSubscribeOk, MoqtSubscribeError,
39 MoqtUnsubscribe, MoqtSubscribeDone, MoqtAnnounce,
40 MoqtAnnounceOk, MoqtAnnounceError, MoqtUnannounce,
41 MoqtGoAway>
42 MessageStructuredData;
43
44 // The total actual size of the message.
total_message_size()45 size_t total_message_size() const { return wire_image_size_; }
46
PacketSample()47 absl::string_view PacketSample() const {
48 return absl::string_view(wire_image_, wire_image_size_);
49 }
50
set_wire_image_size(size_t wire_image_size)51 void set_wire_image_size(size_t wire_image_size) {
52 wire_image_size_ = wire_image_size;
53 }
54
55 // Returns a copy of the structured data for the message.
56 virtual MessageStructuredData structured_data() const = 0;
57
58 // Compares |values| to the derived class's structured data to make sure
59 // they are equal.
60 virtual bool EqualFieldValues(MessageStructuredData& values) const = 0;
61
62 // Expand all varints in the message. This is pure virtual because each
63 // message has a different layout of varints.
64 virtual void ExpandVarints() = 0;
65
66 protected:
SetWireImage(uint8_t * wire_image,size_t wire_image_size)67 void SetWireImage(uint8_t* wire_image, size_t wire_image_size) {
68 memcpy(wire_image_, wire_image, wire_image_size);
69 wire_image_size_ = wire_image_size;
70 }
71
72 // Expands all the varints in the message, alternating between making them 2,
73 // 4, and 8 bytes long. Updates length fields accordingly.
74 // Each character in |varints| corresponds to a byte in the original message.
75 // If there is a 'v', it is a varint that should be expanded. If '-', skip
76 // to the next byte.
ExpandVarintsImpl(absl::string_view varints)77 void ExpandVarintsImpl(absl::string_view varints) {
78 int next_varint_len = 2;
79 char new_wire_image[kMaxMessageHeaderSize + 1];
80 quic::QuicDataReader reader(
81 absl::string_view(wire_image_, wire_image_size_));
82 quic::QuicDataWriter writer(sizeof(new_wire_image), new_wire_image);
83 size_t i = 0;
84 while (!reader.IsDoneReading()) {
85 if (i >= varints.length() || varints[i++] == '-') {
86 uint8_t byte;
87 reader.ReadUInt8(&byte);
88 writer.WriteUInt8(byte);
89 continue;
90 }
91 uint64_t value;
92 reader.ReadVarInt62(&value);
93 writer.WriteVarInt62WithForcedLength(
94 value, static_cast<quiche::QuicheVariableLengthIntegerLength>(
95 next_varint_len));
96 next_varint_len *= 2;
97 if (next_varint_len == 16) {
98 next_varint_len = 2;
99 }
100 }
101 memcpy(wire_image_, new_wire_image, writer.length());
102 wire_image_size_ = writer.length();
103 }
104
105 protected:
106 MoqtMessageType message_type_;
107
108 private:
109 char wire_image_[kMaxMessageHeaderSize + 20];
110 size_t wire_image_size_;
111 };
112
113 // Base class for the two subtypes of Object Message.
114 class QUICHE_NO_EXPORT ObjectMessage : public TestMessageBase {
115 public:
ObjectMessage(MoqtMessageType type)116 ObjectMessage(MoqtMessageType type) : TestMessageBase(type) {
117 object_.forwarding_preference = GetForwardingPreference(type);
118 }
119
EqualFieldValues(MessageStructuredData & values)120 bool EqualFieldValues(MessageStructuredData& values) const override {
121 auto cast = std::get<MoqtObject>(values);
122 if (cast.subscribe_id != object_.subscribe_id) {
123 QUIC_LOG(INFO) << "OBJECT Track ID mismatch";
124 return false;
125 }
126 if (cast.track_alias != object_.track_alias) {
127 QUIC_LOG(INFO) << "OBJECT Track ID mismatch";
128 return false;
129 }
130 if (cast.group_id != object_.group_id) {
131 QUIC_LOG(INFO) << "OBJECT Group Sequence mismatch";
132 return false;
133 }
134 if (cast.object_id != object_.object_id) {
135 QUIC_LOG(INFO) << "OBJECT Object Sequence mismatch";
136 return false;
137 }
138 if (cast.object_send_order != object_.object_send_order) {
139 QUIC_LOG(INFO) << "OBJECT Object Send Order mismatch";
140 return false;
141 }
142 if (cast.forwarding_preference != object_.forwarding_preference) {
143 QUIC_LOG(INFO) << "OBJECT Object Send Order mismatch";
144 return false;
145 }
146 if (cast.payload_length != object_.payload_length) {
147 QUIC_LOG(INFO) << "OBJECT Payload Length mismatch";
148 return false;
149 }
150 return true;
151 }
152
structured_data()153 MessageStructuredData structured_data() const override {
154 return TestMessageBase::MessageStructuredData(object_);
155 }
156
157 protected:
158 MoqtObject object_ = {
159 /*subscribe_id=*/3,
160 /*track_alias=*/4,
161 /*group_id*/ 5,
162 /*object_id=*/6,
163 /*object_send_order=*/7,
164 /*forwarding_preference=*/MoqtForwardingPreference::kTrack,
165 /*payload_length=*/std::nullopt,
166 };
167 };
168
169 class QUICHE_NO_EXPORT ObjectStreamMessage : public ObjectMessage {
170 public:
ObjectStreamMessage()171 ObjectStreamMessage() : ObjectMessage(MoqtMessageType::kObjectStream) {
172 SetWireImage(raw_packet_, sizeof(raw_packet_));
173 object_.forwarding_preference = MoqtForwardingPreference::kObject;
174 }
175
ExpandVarints()176 void ExpandVarints() override {
177 ExpandVarintsImpl("vvvvvv"); // first six fields are varints
178 }
179
180 private:
181 uint8_t raw_packet_[9] = {
182 0x00, 0x03, 0x04, 0x05, 0x06, 0x07, // varints
183 0x66, 0x6f, 0x6f, // payload = "foo"
184 };
185 };
186
187 class QUICHE_NO_EXPORT ObjectDatagramMessage : public ObjectMessage {
188 public:
ObjectDatagramMessage()189 ObjectDatagramMessage() : ObjectMessage(MoqtMessageType::kObjectDatagram) {
190 SetWireImage(raw_packet_, sizeof(raw_packet_));
191 object_.forwarding_preference = MoqtForwardingPreference::kDatagram;
192 }
193
ExpandVarints()194 void ExpandVarints() override {
195 ExpandVarintsImpl("vvvvvv"); // first six fields are varints
196 }
197
198 private:
199 uint8_t raw_packet_[9] = {
200 0x01, 0x03, 0x04, 0x05, 0x06, 0x07, // varints
201 0x66, 0x6f, 0x6f, // payload = "foo"
202 };
203 };
204
205 // Concatentation of the base header and the object-specific header. Follow-on
206 // object headers are handled in a different class.
207 class QUICHE_NO_EXPORT StreamHeaderTrackMessage : public ObjectMessage {
208 public:
StreamHeaderTrackMessage()209 StreamHeaderTrackMessage()
210 : ObjectMessage(MoqtMessageType::kStreamHeaderTrack) {
211 SetWireImage(raw_packet_, sizeof(raw_packet_));
212 object_.forwarding_preference = MoqtForwardingPreference::kTrack;
213 object_.payload_length = 3;
214 }
215
ExpandVarints()216 void ExpandVarints() override {
217 ExpandVarintsImpl("--vvvvvv"); // six one-byte varints
218 }
219
220 private:
221 // Some tests check that a FIN sent at the halfway point of a message results
222 // in an error. Without the unnecessary expanded varint 0x0405, the halfway
223 // point falls at the end of the Stream Header, which is legal. Expand the
224 // varint so that the FIN would be illegal.
225 uint8_t raw_packet_[11] = {
226 0x40, 0x50, // two byte type field
227 0x03, 0x04, 0x07, // varints
228 0x05, 0x06, // object middler
229 0x03, 0x66, 0x6f, 0x6f, // payload = "foo"
230 };
231 };
232
233 // Used only for tests that process multiple objects on one stream.
234 class QUICHE_NO_EXPORT StreamMiddlerTrackMessage : public ObjectMessage {
235 public:
StreamMiddlerTrackMessage()236 StreamMiddlerTrackMessage()
237 : ObjectMessage(MoqtMessageType::kStreamHeaderTrack) {
238 SetWireImage(raw_packet_, sizeof(raw_packet_));
239 object_.forwarding_preference = MoqtForwardingPreference::kTrack;
240 object_.payload_length = 3;
241 object_.group_id = 9;
242 object_.object_id = 10;
243 }
244
ExpandVarints()245 void ExpandVarints() override { ExpandVarintsImpl("vvv"); }
246
247 private:
248 uint8_t raw_packet_[6] = {
249 0x09, 0x0a, 0x03, 0x62, 0x61, 0x72, // object middler; payload = "bar"
250 };
251 };
252
253 class QUICHE_NO_EXPORT StreamHeaderGroupMessage : public ObjectMessage {
254 public:
StreamHeaderGroupMessage()255 StreamHeaderGroupMessage()
256 : ObjectMessage(MoqtMessageType::kStreamHeaderGroup) {
257 SetWireImage(raw_packet_, sizeof(raw_packet_));
258 object_.forwarding_preference = MoqtForwardingPreference::kGroup;
259 object_.payload_length = 3;
260 }
261
ExpandVarints()262 void ExpandVarints() override {
263 ExpandVarintsImpl("--vvvvvv"); // six one-byte varints
264 }
265
266 private:
267 uint8_t raw_packet_[11] = {
268 0x40, 0x51, // two-byte type field
269 0x03, 0x04, 0x05, 0x07, // varints
270 0x06, 0x03, 0x66, 0x6f, 0x6f, // object middler; payload = "foo"
271 };
272 };
273
274 // Used only for tests that process multiple objects on one stream.
275 class QUICHE_NO_EXPORT StreamMiddlerGroupMessage : public ObjectMessage {
276 public:
StreamMiddlerGroupMessage()277 StreamMiddlerGroupMessage()
278 : ObjectMessage(MoqtMessageType::kStreamHeaderGroup) {
279 SetWireImage(raw_packet_, sizeof(raw_packet_));
280 object_.forwarding_preference = MoqtForwardingPreference::kGroup;
281 object_.payload_length = 3;
282 object_.object_id = 9;
283 }
284
ExpandVarints()285 void ExpandVarints() override { ExpandVarintsImpl("vvv"); }
286
287 private:
288 uint8_t raw_packet_[5] = {
289 0x09, 0x03, 0x62, 0x61, 0x72, // object middler; payload = "bar"
290 };
291 };
292
293 class QUICHE_NO_EXPORT ClientSetupMessage : public TestMessageBase {
294 public:
ClientSetupMessage(bool webtrans)295 explicit ClientSetupMessage(bool webtrans)
296 : TestMessageBase(MoqtMessageType::kClientSetup) {
297 if (webtrans) {
298 // Should not send PATH.
299 client_setup_.path = std::nullopt;
300 raw_packet_[5] = 0x01; // only one parameter
301 SetWireImage(raw_packet_, sizeof(raw_packet_) - 5);
302 } else {
303 SetWireImage(raw_packet_, sizeof(raw_packet_));
304 }
305 }
306
EqualFieldValues(MessageStructuredData & values)307 bool EqualFieldValues(MessageStructuredData& values) const override {
308 auto cast = std::get<MoqtClientSetup>(values);
309 if (cast.supported_versions.size() !=
310 client_setup_.supported_versions.size()) {
311 QUIC_LOG(INFO) << "CLIENT_SETUP number of supported versions mismatch";
312 return false;
313 }
314 for (uint64_t i = 0; i < cast.supported_versions.size(); ++i) {
315 // Listed versions are 1 and 2, in that order.
316 if (cast.supported_versions[i] != client_setup_.supported_versions[i]) {
317 QUIC_LOG(INFO) << "CLIENT_SETUP supported version mismatch";
318 return false;
319 }
320 }
321 if (cast.role != client_setup_.role) {
322 QUIC_LOG(INFO) << "CLIENT_SETUP role mismatch";
323 return false;
324 }
325 if (cast.path != client_setup_.path) {
326 QUIC_LOG(INFO) << "CLIENT_SETUP path mismatch";
327 return false;
328 }
329 return true;
330 }
331
ExpandVarints()332 void ExpandVarints() override {
333 if (client_setup_.path.has_value()) {
334 ExpandVarintsImpl("--vvvvvv-vv---");
335 // first two bytes are already a 2B varint. Also, don't expand parameter
336 // varints because that messes up the parameter length field.
337 } else {
338 ExpandVarintsImpl("--vvvvvv-");
339 }
340 }
341
structured_data()342 MessageStructuredData structured_data() const override {
343 return TestMessageBase::MessageStructuredData(client_setup_);
344 }
345
346 private:
347 uint8_t raw_packet_[14] = {
348 0x40, 0x40, // type
349 0x02, 0x01, 0x02, // versions
350 0x02, // 2 parameters
351 0x00, 0x01, 0x03, // role = PubSub
352 0x01, 0x03, 0x66, 0x6f, 0x6f // path = "foo"
353 };
354 MoqtClientSetup client_setup_ = {
355 /*supported_versions=*/std::vector<MoqtVersion>(
356 {static_cast<MoqtVersion>(1), static_cast<MoqtVersion>(2)}),
357 /*role=*/MoqtRole::kPubSub,
358 /*path=*/"foo",
359 };
360 };
361
362 class QUICHE_NO_EXPORT ServerSetupMessage : public TestMessageBase {
363 public:
ServerSetupMessage()364 explicit ServerSetupMessage()
365 : TestMessageBase(MoqtMessageType::kServerSetup) {
366 SetWireImage(raw_packet_, sizeof(raw_packet_));
367 }
368
EqualFieldValues(MessageStructuredData & values)369 bool EqualFieldValues(MessageStructuredData& values) const override {
370 auto cast = std::get<MoqtServerSetup>(values);
371 if (cast.selected_version != server_setup_.selected_version) {
372 QUIC_LOG(INFO) << "SERVER_SETUP selected version mismatch";
373 return false;
374 }
375 if (cast.role != server_setup_.role) {
376 QUIC_LOG(INFO) << "SERVER_SETUP role mismatch";
377 return false;
378 }
379 return true;
380 }
381
ExpandVarints()382 void ExpandVarints() override {
383 ExpandVarintsImpl("--vvvv-"); // first two bytes are already a 2b varint
384 }
385
structured_data()386 MessageStructuredData structured_data() const override {
387 return TestMessageBase::MessageStructuredData(server_setup_);
388 }
389
390 private:
391 uint8_t raw_packet_[7] = {
392 0x40, 0x41, // type
393 0x01, 0x01, // version, one param
394 0x00, 0x01, 0x03, // role = PubSub
395 };
396 MoqtServerSetup server_setup_ = {
397 /*selected_version=*/static_cast<MoqtVersion>(1),
398 /*role=*/MoqtRole::kPubSub,
399 };
400 };
401
402 class QUICHE_NO_EXPORT SubscribeMessage : public TestMessageBase {
403 public:
SubscribeMessage()404 SubscribeMessage() : TestMessageBase(MoqtMessageType::kSubscribe) {
405 SetWireImage(raw_packet_, sizeof(raw_packet_));
406 }
407
EqualFieldValues(MessageStructuredData & values)408 bool EqualFieldValues(MessageStructuredData& values) const override {
409 auto cast = std::get<MoqtSubscribe>(values);
410 if (cast.subscribe_id != subscribe_.subscribe_id) {
411 QUIC_LOG(INFO) << "SUBSCRIBE subscribe ID mismatch";
412 return false;
413 }
414 if (cast.track_alias != subscribe_.track_alias) {
415 QUIC_LOG(INFO) << "SUBSCRIBE track alias mismatch";
416 return false;
417 }
418 if (cast.track_namespace != subscribe_.track_namespace) {
419 QUIC_LOG(INFO) << "SUBSCRIBE track namespace mismatch";
420 return false;
421 }
422 if (cast.track_name != subscribe_.track_name) {
423 QUIC_LOG(INFO) << "SUBSCRIBE track name mismatch";
424 return false;
425 }
426 if (cast.start_group != subscribe_.start_group) {
427 QUIC_LOG(INFO) << "SUBSCRIBE start group mismatch";
428 return false;
429 }
430 if (cast.start_object != subscribe_.start_object) {
431 QUIC_LOG(INFO) << "SUBSCRIBE start object mismatch";
432 return false;
433 }
434 if (cast.end_group != subscribe_.end_group) {
435 QUIC_LOG(INFO) << "SUBSCRIBE end group mismatch";
436 return false;
437 }
438 if (cast.end_object != subscribe_.end_object) {
439 QUIC_LOG(INFO) << "SUBSCRIBE end object mismatch";
440 return false;
441 }
442 if (cast.authorization_info != subscribe_.authorization_info) {
443 QUIC_LOG(INFO) << "SUBSCRIBE authorization info mismatch";
444 return false;
445 }
446 return true;
447 }
448
ExpandVarints()449 void ExpandVarints() override {
450 ExpandVarintsImpl("vvvv---v----vvvvvvvvv");
451 }
452
structured_data()453 MessageStructuredData structured_data() const override {
454 return TestMessageBase::MessageStructuredData(subscribe_);
455 }
456
457 private:
458 uint8_t raw_packet_[24] = {
459 0x03,
460 0x01,
461 0x02, // id and alias
462 0x03,
463 0x66,
464 0x6f,
465 0x6f, // track_namespace = "foo"
466 0x04,
467 0x61,
468 0x62,
469 0x63,
470 0x64, // track_name = "abcd"
471 0x02,
472 0x04, // start_group = 4 (relative previous)
473 0x01,
474 0x01, // start_object = 1 (absolute)
475 0x00, // end_group = none
476 0x00, // end_object = none
477 // TODO(martinduke): figure out what to do about the missing num
478 // parameters field.
479 0x01, // 1 parameter
480 0x02,
481 0x03,
482 0x62,
483 0x61,
484 0x72, // authorization_info = "bar"
485 };
486
487 MoqtSubscribe subscribe_ = {
488 /*subscribe_id=*/1,
489 /*track_alias=*/2,
490 /*track_namespace=*/"foo",
491 /*track_name=*/"abcd",
492 /*start_group=*/MoqtSubscribeLocation(false, (int64_t)(-4)),
493 /*start_object=*/MoqtSubscribeLocation(true, (uint64_t)1),
494 /*end_group=*/std::nullopt,
495 /*end_object=*/std::nullopt,
496 /*authorization_info=*/"bar",
497 };
498 };
499
500 class QUICHE_NO_EXPORT SubscribeOkMessage : public TestMessageBase {
501 public:
SubscribeOkMessage()502 SubscribeOkMessage() : TestMessageBase(MoqtMessageType::kSubscribeOk) {
503 SetWireImage(raw_packet_, sizeof(raw_packet_));
504 }
505
EqualFieldValues(MessageStructuredData & values)506 bool EqualFieldValues(MessageStructuredData& values) const override {
507 auto cast = std::get<MoqtSubscribeOk>(values);
508 if (cast.subscribe_id != subscribe_ok_.subscribe_id) {
509 QUIC_LOG(INFO) << "SUBSCRIBE OK subscribe ID mismatch";
510 return false;
511 }
512 if (cast.expires != subscribe_ok_.expires) {
513 QUIC_LOG(INFO) << "SUBSCRIBE OK expiration mismatch";
514 return false;
515 }
516 if (cast.largest_id != subscribe_ok_.largest_id) {
517 QUIC_LOG(INFO) << "SUBSCRIBE OK largest ID mismatch";
518 return false;
519 }
520 return true;
521 }
522
ExpandVarints()523 void ExpandVarints() override { ExpandVarintsImpl("vvv-vv"); }
524
structured_data()525 MessageStructuredData structured_data() const override {
526 return TestMessageBase::MessageStructuredData(subscribe_ok_);
527 }
528
SetInvalidContentExists()529 void SetInvalidContentExists() {
530 raw_packet_[3] = 0x02;
531 SetWireImage(raw_packet_, sizeof(raw_packet_));
532 }
533
534 private:
535 uint8_t raw_packet_[6] = {
536 0x04, 0x01, 0x03, // subscribe_id = 1, expires = 3
537 0x01, 0x0c, 0x14, // largest_group_id = 12, largest_object_id = 20,
538 };
539
540 MoqtSubscribeOk subscribe_ok_ = {
541 /*subscribe_id=*/1,
542 /*expires=*/quic::QuicTimeDelta::FromMilliseconds(3),
543 /*largest_id=*/FullSequence(12, 20),
544 };
545 };
546
547 class QUICHE_NO_EXPORT SubscribeErrorMessage : public TestMessageBase {
548 public:
SubscribeErrorMessage()549 SubscribeErrorMessage() : TestMessageBase(MoqtMessageType::kSubscribeError) {
550 SetWireImage(raw_packet_, sizeof(raw_packet_));
551 }
552
EqualFieldValues(MessageStructuredData & values)553 bool EqualFieldValues(MessageStructuredData& values) const override {
554 auto cast = std::get<MoqtSubscribeError>(values);
555 if (cast.subscribe_id != subscribe_error_.subscribe_id) {
556 QUIC_LOG(INFO) << "SUBSCRIBE ERROR subscribe_id mismatch";
557 return false;
558 }
559 if (cast.error_code != subscribe_error_.error_code) {
560 QUIC_LOG(INFO) << "SUBSCRIBE ERROR error code mismatch";
561 return false;
562 }
563 if (cast.reason_phrase != subscribe_error_.reason_phrase) {
564 QUIC_LOG(INFO) << "SUBSCRIBE ERROR reason phrase mismatch";
565 return false;
566 }
567 if (cast.track_alias != subscribe_error_.track_alias) {
568 QUIC_LOG(INFO) << "SUBSCRIBE ERROR track alias mismatch";
569 return false;
570 }
571 return true;
572 }
573
ExpandVarints()574 void ExpandVarints() override { ExpandVarintsImpl("vvvv---v"); }
575
structured_data()576 MessageStructuredData structured_data() const override {
577 return TestMessageBase::MessageStructuredData(subscribe_error_);
578 }
579
580 private:
581 uint8_t raw_packet_[8] = {
582 0x05, 0x02, // subscribe_id = 2
583 0x01, // error_code = 2
584 0x03, 0x62, 0x61, 0x72, // reason_phrase = "bar"
585 0x04, // track_alias = 4
586 };
587
588 MoqtSubscribeError subscribe_error_ = {
589 /*subscribe_id=*/2,
590 /*subscribe=*/SubscribeErrorCode::kInvalidRange,
591 /*reason_phrase=*/"bar",
592 /*track_alias=*/4,
593 };
594 };
595
596 class QUICHE_NO_EXPORT UnsubscribeMessage : public TestMessageBase {
597 public:
UnsubscribeMessage()598 UnsubscribeMessage() : TestMessageBase(MoqtMessageType::kUnsubscribe) {
599 SetWireImage(raw_packet_, sizeof(raw_packet_));
600 }
601
EqualFieldValues(MessageStructuredData & values)602 bool EqualFieldValues(MessageStructuredData& values) const override {
603 auto cast = std::get<MoqtUnsubscribe>(values);
604 if (cast.subscribe_id != unsubscribe_.subscribe_id) {
605 QUIC_LOG(INFO) << "UNSUBSCRIBE subscribe ID mismatch";
606 return false;
607 }
608 return true;
609 }
610
ExpandVarints()611 void ExpandVarints() override { ExpandVarintsImpl("vv"); }
612
structured_data()613 MessageStructuredData structured_data() const override {
614 return TestMessageBase::MessageStructuredData(unsubscribe_);
615 }
616
617 private:
618 uint8_t raw_packet_[2] = {
619 0x0a, 0x03, // subscribe_id = 3
620 };
621
622 MoqtUnsubscribe unsubscribe_ = {
623 /*subscribe_id=*/3,
624 };
625 };
626
627 class QUICHE_NO_EXPORT SubscribeDoneMessage : public TestMessageBase {
628 public:
SubscribeDoneMessage()629 SubscribeDoneMessage() : TestMessageBase(MoqtMessageType::kSubscribeDone) {
630 SetWireImage(raw_packet_, sizeof(raw_packet_));
631 }
632
EqualFieldValues(MessageStructuredData & values)633 bool EqualFieldValues(MessageStructuredData& values) const override {
634 auto cast = std::get<MoqtSubscribeDone>(values);
635 if (cast.subscribe_id != subscribe_done_.subscribe_id) {
636 QUIC_LOG(INFO) << "SUBSCRIBE_DONE subscribe ID mismatch";
637 return false;
638 }
639 if (cast.status_code != subscribe_done_.status_code) {
640 QUIC_LOG(INFO) << "SUBSCRIBE_DONE status code mismatch";
641 return false;
642 }
643 if (cast.reason_phrase != subscribe_done_.reason_phrase) {
644 QUIC_LOG(INFO) << "SUBSCRIBE_DONE reason phrase mismatch";
645 return false;
646 }
647 if (cast.final_id != subscribe_done_.final_id) {
648 QUIC_LOG(INFO) << "SUBSCRIBE_DONE final ID mismatch";
649 return false;
650 }
651 return true;
652 }
653
ExpandVarints()654 void ExpandVarints() override { ExpandVarintsImpl("vvvv---vv"); }
655
structured_data()656 MessageStructuredData structured_data() const override {
657 return TestMessageBase::MessageStructuredData(subscribe_done_);
658 }
659
SetInvalidContentExists()660 void SetInvalidContentExists() {
661 raw_packet_[6] = 0x02;
662 SetWireImage(raw_packet_, sizeof(raw_packet_));
663 }
664
665 private:
666 uint8_t raw_packet_[9] = {
667 0x0b, 0x02, 0x03, // subscribe_id = 2, error_code = 3,
668 0x02, 0x68, 0x69, // reason_phrase = "hi"
669 0x01, 0x08, 0x0c, // final_id = (8,12)
670 };
671
672 MoqtSubscribeDone subscribe_done_ = {
673 /*subscribe_id=*/2,
674 /*error_code=*/3,
675 /*reason_phrase=*/"hi",
676 /*final_id=*/FullSequence(8, 12),
677 };
678 };
679
680 class QUICHE_NO_EXPORT AnnounceMessage : public TestMessageBase {
681 public:
AnnounceMessage()682 AnnounceMessage() : TestMessageBase(MoqtMessageType::kAnnounce) {
683 SetWireImage(raw_packet_, sizeof(raw_packet_));
684 }
685
EqualFieldValues(MessageStructuredData & values)686 bool EqualFieldValues(MessageStructuredData& values) const override {
687 auto cast = std::get<MoqtAnnounce>(values);
688 if (cast.track_namespace != announce_.track_namespace) {
689 QUIC_LOG(INFO) << "ANNOUNCE MESSAGE track namespace mismatch";
690 return false;
691 }
692 if (cast.authorization_info != announce_.authorization_info) {
693 QUIC_LOG(INFO) << "ANNOUNCE MESSAGE authorization info mismatch";
694 return false;
695 }
696 return true;
697 }
698
ExpandVarints()699 void ExpandVarints() override { ExpandVarintsImpl("vv---vvv---"); }
700
structured_data()701 MessageStructuredData structured_data() const override {
702 return TestMessageBase::MessageStructuredData(announce_);
703 }
704
705 private:
706 uint8_t raw_packet_[11] = {
707 0x06, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo"
708 0x01, // 1 parameter
709 0x02, 0x03, 0x62, 0x61, 0x72, // authorization_info = "bar"
710 };
711
712 MoqtAnnounce announce_ = {
713 /*track_namespace=*/"foo",
714 /*authorization_info=*/"bar",
715 };
716 };
717
718 class QUICHE_NO_EXPORT AnnounceOkMessage : public TestMessageBase {
719 public:
AnnounceOkMessage()720 AnnounceOkMessage() : TestMessageBase(MoqtMessageType::kAnnounceOk) {
721 SetWireImage(raw_packet_, sizeof(raw_packet_));
722 }
723
EqualFieldValues(MessageStructuredData & values)724 bool EqualFieldValues(MessageStructuredData& values) const override {
725 auto cast = std::get<MoqtAnnounceOk>(values);
726 if (cast.track_namespace != announce_ok_.track_namespace) {
727 QUIC_LOG(INFO) << "ANNOUNCE OK MESSAGE track namespace mismatch";
728 return false;
729 }
730 return true;
731 }
732
ExpandVarints()733 void ExpandVarints() override { ExpandVarintsImpl("vv---"); }
734
structured_data()735 MessageStructuredData structured_data() const override {
736 return TestMessageBase::MessageStructuredData(announce_ok_);
737 }
738
739 private:
740 uint8_t raw_packet_[5] = {
741 0x07, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo"
742 };
743
744 MoqtAnnounceOk announce_ok_ = {
745 /*track_namespace=*/"foo",
746 };
747 };
748
749 class QUICHE_NO_EXPORT AnnounceErrorMessage : public TestMessageBase {
750 public:
AnnounceErrorMessage()751 AnnounceErrorMessage() : TestMessageBase(MoqtMessageType::kAnnounceError) {
752 SetWireImage(raw_packet_, sizeof(raw_packet_));
753 }
754
EqualFieldValues(MessageStructuredData & values)755 bool EqualFieldValues(MessageStructuredData& values) const override {
756 auto cast = std::get<MoqtAnnounceError>(values);
757 if (cast.track_namespace != announce_error_.track_namespace) {
758 QUIC_LOG(INFO) << "ANNOUNCE ERROR track namespace mismatch";
759 return false;
760 }
761 if (cast.error_code != announce_error_.error_code) {
762 QUIC_LOG(INFO) << "ANNOUNCE ERROR error code mismatch";
763 return false;
764 }
765 if (cast.reason_phrase != announce_error_.reason_phrase) {
766 QUIC_LOG(INFO) << "ANNOUNCE ERROR reason phrase mismatch";
767 return false;
768 }
769 return true;
770 }
771
ExpandVarints()772 void ExpandVarints() override { ExpandVarintsImpl("vv---vv---"); }
773
structured_data()774 MessageStructuredData structured_data() const override {
775 return TestMessageBase::MessageStructuredData(announce_error_);
776 }
777
778 private:
779 uint8_t raw_packet_[10] = {
780 0x08, 0x03, 0x66, 0x6f, 0x6f, // track_namespace = "foo"
781 0x01, // error_code = 1
782 0x03, 0x62, 0x61, 0x72, // reason_phrase = "bar"
783 };
784
785 MoqtAnnounceError announce_error_ = {
786 /*track_namespace=*/"foo",
787 /*error_code=*/MoqtAnnounceErrorCode::kAnnounceNotSupported,
788 /*reason_phrase=*/"bar",
789 };
790 };
791
792 class QUICHE_NO_EXPORT UnannounceMessage : public TestMessageBase {
793 public:
UnannounceMessage()794 UnannounceMessage() : TestMessageBase(MoqtMessageType::kUnannounce) {
795 SetWireImage(raw_packet_, sizeof(raw_packet_));
796 }
797
EqualFieldValues(MessageStructuredData & values)798 bool EqualFieldValues(MessageStructuredData& values) const override {
799 auto cast = std::get<MoqtUnannounce>(values);
800 if (cast.track_namespace != unannounce_.track_namespace) {
801 QUIC_LOG(INFO) << "UNSUBSCRIBE full track name mismatch";
802 return false;
803 }
804 return true;
805 }
806
ExpandVarints()807 void ExpandVarints() override { ExpandVarintsImpl("vv---"); }
808
structured_data()809 MessageStructuredData structured_data() const override {
810 return TestMessageBase::MessageStructuredData(unannounce_);
811 }
812
813 private:
814 uint8_t raw_packet_[5] = {
815 0x09, 0x03, 0x66, 0x6f, 0x6f, // track_namespace
816 };
817
818 MoqtUnannounce unannounce_ = {
819 /*track_namespace=*/"foo",
820 };
821 };
822
823 class QUICHE_NO_EXPORT GoAwayMessage : public TestMessageBase {
824 public:
GoAwayMessage()825 GoAwayMessage() : TestMessageBase(MoqtMessageType::kGoAway) {
826 SetWireImage(raw_packet_, sizeof(raw_packet_));
827 }
828
EqualFieldValues(MessageStructuredData & values)829 bool EqualFieldValues(MessageStructuredData& values) const override {
830 auto cast = std::get<MoqtGoAway>(values);
831 if (cast.new_session_uri != goaway_.new_session_uri) {
832 QUIC_LOG(INFO) << "UNSUBSCRIBE full track name mismatch";
833 return false;
834 }
835 return true;
836 }
837
ExpandVarints()838 void ExpandVarints() override { ExpandVarintsImpl("vv---"); }
839
structured_data()840 MessageStructuredData structured_data() const override {
841 return TestMessageBase::MessageStructuredData(goaway_);
842 }
843
844 private:
845 uint8_t raw_packet_[5] = {
846 0x10, 0x03, 0x66, 0x6f, 0x6f,
847 };
848
849 MoqtGoAway goaway_ = {
850 /*new_session_uri=*/"foo",
851 };
852 };
853
854 // Factory function for test messages.
CreateTestMessage(MoqtMessageType message_type,bool is_webtrans)855 static inline std::unique_ptr<TestMessageBase> CreateTestMessage(
856 MoqtMessageType message_type, bool is_webtrans) {
857 switch (message_type) {
858 case MoqtMessageType::kObjectStream:
859 return std::make_unique<ObjectStreamMessage>();
860 case MoqtMessageType::kObjectDatagram:
861 return std::make_unique<ObjectDatagramMessage>();
862 case MoqtMessageType::kSubscribe:
863 return std::make_unique<SubscribeMessage>();
864 case MoqtMessageType::kSubscribeOk:
865 return std::make_unique<SubscribeOkMessage>();
866 case MoqtMessageType::kSubscribeError:
867 return std::make_unique<SubscribeErrorMessage>();
868 case MoqtMessageType::kUnsubscribe:
869 return std::make_unique<UnsubscribeMessage>();
870 case MoqtMessageType::kSubscribeDone:
871 return std::make_unique<SubscribeDoneMessage>();
872 case MoqtMessageType::kAnnounce:
873 return std::make_unique<AnnounceMessage>();
874 case MoqtMessageType::kAnnounceOk:
875 return std::make_unique<AnnounceOkMessage>();
876 case MoqtMessageType::kAnnounceError:
877 return std::make_unique<AnnounceErrorMessage>();
878 case MoqtMessageType::kUnannounce:
879 return std::make_unique<UnannounceMessage>();
880 case MoqtMessageType::kGoAway:
881 return std::make_unique<GoAwayMessage>();
882 case MoqtMessageType::kClientSetup:
883 return std::make_unique<ClientSetupMessage>(is_webtrans);
884 case MoqtMessageType::kServerSetup:
885 return std::make_unique<ServerSetupMessage>();
886 case MoqtMessageType::kStreamHeaderTrack:
887 return std::make_unique<StreamHeaderTrackMessage>();
888 case MoqtMessageType::kStreamHeaderGroup:
889 return std::make_unique<StreamHeaderGroupMessage>();
890 default:
891 return nullptr;
892 }
893 }
894
895 } // namespace moqt::test
896
897 #endif // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_
898