xref: /aosp_15_r20/external/cronet/net/third_party/quiche/src/quiche/quic/moqt/test_tools/moqt_test_message.h (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright (c) 2023 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_
6 #define QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_
7 
8 #include <cstddef>
9 #include <cstdint>
10 #include <cstring>
11 #include <memory>
12 #include <optional>
13 #include <vector>
14 
15 #include "absl/strings/string_view.h"
16 #include "absl/types/variant.h"
17 #include "quiche/quic/core/quic_data_reader.h"
18 #include "quiche/quic/core/quic_data_writer.h"
19 #include "quiche/quic/core/quic_time.h"
20 #include "quiche/quic/moqt/moqt_messages.h"
21 #include "quiche/quic/platform/api/quic_logging.h"
22 #include "quiche/quic/platform/api/quic_test.h"
23 #include "quiche/common/platform/api/quiche_export.h"
24 #include "quiche/common/quiche_endian.h"
25 
26 namespace moqt::test {
27 
28 // Base class containing a wire image and the corresponding structured
29 // representation of an example of each message. It allows parser and framer
30 // tests to iterate through all message types without much specialized code.
31 class QUICHE_NO_EXPORT TestMessageBase {
32  public:
TestMessageBase(MoqtMessageType message_type)33   TestMessageBase(MoqtMessageType message_type) : message_type_(message_type) {}
34   virtual ~TestMessageBase() = default;
message_type()35   MoqtMessageType message_type() const { return message_type_; }
36 
37   typedef absl::variant<MoqtClientSetup, MoqtServerSetup, MoqtObject,
38                         MoqtSubscribe, MoqtSubscribeOk, MoqtSubscribeError,
39                         MoqtUnsubscribe, MoqtSubscribeDone, MoqtAnnounce,
40                         MoqtAnnounceOk, MoqtAnnounceError, MoqtUnannounce,
41                         MoqtGoAway>
42       MessageStructuredData;
43 
44   // The total actual size of the message.
total_message_size()45   size_t total_message_size() const { return wire_image_size_; }
46 
PacketSample()47   absl::string_view PacketSample() const {
48     return absl::string_view(wire_image_, wire_image_size_);
49   }
50 
set_wire_image_size(size_t wire_image_size)51   void set_wire_image_size(size_t wire_image_size) {
52     wire_image_size_ = wire_image_size;
53   }
54 
55   // Returns a copy of the structured data for the message.
56   virtual MessageStructuredData structured_data() const = 0;
57 
58   // Compares |values| to the derived class's structured data to make sure
59   // they are equal.
60   virtual bool EqualFieldValues(MessageStructuredData& values) const = 0;
61 
62   // Expand all varints in the message. This is pure virtual because each
63   // message has a different layout of varints.
64   virtual void ExpandVarints() = 0;
65 
66  protected:
SetWireImage(uint8_t * wire_image,size_t wire_image_size)67   void SetWireImage(uint8_t* wire_image, size_t wire_image_size) {
68     memcpy(wire_image_, wire_image, wire_image_size);
69     wire_image_size_ = wire_image_size;
70   }
71 
72   // Expands all the varints in the message, alternating between making them 2,
73   // 4, and 8 bytes long. Updates length fields accordingly.
74   // Each character in |varints| corresponds to a byte in the original message.
75   // If there is a 'v', it is a varint that should be expanded. If '-', skip
76   // to the next byte.
ExpandVarintsImpl(absl::string_view varints)77   void ExpandVarintsImpl(absl::string_view varints) {
78     int next_varint_len = 2;
79     char new_wire_image[kMaxMessageHeaderSize + 1];
80     quic::QuicDataReader reader(
81         absl::string_view(wire_image_, wire_image_size_));
82     quic::QuicDataWriter writer(sizeof(new_wire_image), new_wire_image);
83     size_t i = 0;
84     while (!reader.IsDoneReading()) {
85       if (i >= varints.length() || varints[i++] == '-') {
86         uint8_t byte;
87         reader.ReadUInt8(&byte);
88         writer.WriteUInt8(byte);
89         continue;
90       }
91       uint64_t value;
92       reader.ReadVarInt62(&value);
93       writer.WriteVarInt62WithForcedLength(
94           value, static_cast<quiche::QuicheVariableLengthIntegerLength>(
95                      next_varint_len));
96       next_varint_len *= 2;
97       if (next_varint_len == 16) {
98         next_varint_len = 2;
99       }
100     }
101     memcpy(wire_image_, new_wire_image, writer.length());
102     wire_image_size_ = writer.length();
103   }
104 
105  protected:
106   MoqtMessageType message_type_;
107 
108  private:
109   char wire_image_[kMaxMessageHeaderSize + 20];
110   size_t wire_image_size_;
111 };
112 
113 // Base class for the two subtypes of Object Message.
114 class QUICHE_NO_EXPORT ObjectMessage : public TestMessageBase {
115  public:
ObjectMessage(MoqtMessageType type)116   ObjectMessage(MoqtMessageType type) : TestMessageBase(type) {
117     object_.forwarding_preference = GetForwardingPreference(type);
118   }
119 
EqualFieldValues(MessageStructuredData & values)120   bool EqualFieldValues(MessageStructuredData& values) const override {
121     auto cast = std::get<MoqtObject>(values);
122     if (cast.subscribe_id != object_.subscribe_id) {
123       QUIC_LOG(INFO) << "OBJECT Track ID mismatch";
124       return false;
125     }
126     if (cast.track_alias != object_.track_alias) {
127       QUIC_LOG(INFO) << "OBJECT Track ID mismatch";
128       return false;
129     }
130     if (cast.group_id != object_.group_id) {
131       QUIC_LOG(INFO) << "OBJECT Group Sequence mismatch";
132       return false;
133     }
134     if (cast.object_id != object_.object_id) {
135       QUIC_LOG(INFO) << "OBJECT Object Sequence mismatch";
136       return false;
137     }
138     if (cast.object_send_order != object_.object_send_order) {
139       QUIC_LOG(INFO) << "OBJECT Object Send Order mismatch";
140       return false;
141     }
142     if (cast.forwarding_preference != object_.forwarding_preference) {
143       QUIC_LOG(INFO) << "OBJECT Object Send Order mismatch";
144       return false;
145     }
146     if (cast.payload_length != object_.payload_length) {
147       QUIC_LOG(INFO) << "OBJECT Payload Length mismatch";
148       return false;
149     }
150     return true;
151   }
152 
structured_data()153   MessageStructuredData structured_data() const override {
154     return TestMessageBase::MessageStructuredData(object_);
155   }
156 
157  protected:
158   MoqtObject object_ = {
159       /*subscribe_id=*/3,
160       /*track_alias=*/4,
161       /*group_id*/ 5,
162       /*object_id=*/6,
163       /*object_send_order=*/7,
164       /*forwarding_preference=*/MoqtForwardingPreference::kTrack,
165       /*payload_length=*/std::nullopt,
166   };
167 };
168 
169 class QUICHE_NO_EXPORT ObjectStreamMessage : public ObjectMessage {
170  public:
ObjectStreamMessage()171   ObjectStreamMessage() : ObjectMessage(MoqtMessageType::kObjectStream) {
172     SetWireImage(raw_packet_, sizeof(raw_packet_));
173     object_.forwarding_preference = MoqtForwardingPreference::kObject;
174   }
175 
ExpandVarints()176   void ExpandVarints() override {
177     ExpandVarintsImpl("vvvvvv");  // first six fields are varints
178   }
179 
180  private:
181   uint8_t raw_packet_[9] = {
182       0x00, 0x03, 0x04, 0x05, 0x06, 0x07,  // varints
183       0x66, 0x6f, 0x6f,                    // payload = "foo"
184   };
185 };
186 
187 class QUICHE_NO_EXPORT ObjectDatagramMessage : public ObjectMessage {
188  public:
ObjectDatagramMessage()189   ObjectDatagramMessage() : ObjectMessage(MoqtMessageType::kObjectDatagram) {
190     SetWireImage(raw_packet_, sizeof(raw_packet_));
191     object_.forwarding_preference = MoqtForwardingPreference::kDatagram;
192   }
193 
ExpandVarints()194   void ExpandVarints() override {
195     ExpandVarintsImpl("vvvvvv");  // first six fields are varints
196   }
197 
198  private:
199   uint8_t raw_packet_[9] = {
200       0x01, 0x03, 0x04, 0x05, 0x06, 0x07,  // varints
201       0x66, 0x6f, 0x6f,                    // payload = "foo"
202   };
203 };
204 
205 // Concatentation of the base header and the object-specific header. Follow-on
206 // object headers are handled in a different class.
207 class QUICHE_NO_EXPORT StreamHeaderTrackMessage : public ObjectMessage {
208  public:
StreamHeaderTrackMessage()209   StreamHeaderTrackMessage()
210       : ObjectMessage(MoqtMessageType::kStreamHeaderTrack) {
211     SetWireImage(raw_packet_, sizeof(raw_packet_));
212     object_.forwarding_preference = MoqtForwardingPreference::kTrack;
213     object_.payload_length = 3;
214   }
215 
ExpandVarints()216   void ExpandVarints() override {
217     ExpandVarintsImpl("--vvvvvv");  // six one-byte varints
218   }
219 
220  private:
221   // Some tests check that a FIN sent at the halfway point of a message results
222   // in an error. Without the unnecessary expanded varint 0x0405, the halfway
223   // point falls at the end of the Stream Header, which is legal. Expand the
224   // varint so that the FIN would be illegal.
225   uint8_t raw_packet_[11] = {
226       0x40, 0x50,              // two byte type field
227       0x03, 0x04, 0x07,        // varints
228       0x05, 0x06,              // object middler
229       0x03, 0x66, 0x6f, 0x6f,  // payload = "foo"
230   };
231 };
232 
233 // Used only for tests that process multiple objects on one stream.
234 class QUICHE_NO_EXPORT StreamMiddlerTrackMessage : public ObjectMessage {
235  public:
StreamMiddlerTrackMessage()236   StreamMiddlerTrackMessage()
237       : ObjectMessage(MoqtMessageType::kStreamHeaderTrack) {
238     SetWireImage(raw_packet_, sizeof(raw_packet_));
239     object_.forwarding_preference = MoqtForwardingPreference::kTrack;
240     object_.payload_length = 3;
241     object_.group_id = 9;
242     object_.object_id = 10;
243   }
244 
ExpandVarints()245   void ExpandVarints() override { ExpandVarintsImpl("vvv"); }
246 
247  private:
248   uint8_t raw_packet_[6] = {
249       0x09, 0x0a, 0x03, 0x62, 0x61, 0x72,  // object middler; payload = "bar"
250   };
251 };
252 
253 class QUICHE_NO_EXPORT StreamHeaderGroupMessage : public ObjectMessage {
254  public:
StreamHeaderGroupMessage()255   StreamHeaderGroupMessage()
256       : ObjectMessage(MoqtMessageType::kStreamHeaderGroup) {
257     SetWireImage(raw_packet_, sizeof(raw_packet_));
258     object_.forwarding_preference = MoqtForwardingPreference::kGroup;
259     object_.payload_length = 3;
260   }
261 
ExpandVarints()262   void ExpandVarints() override {
263     ExpandVarintsImpl("--vvvvvv");  // six one-byte varints
264   }
265 
266  private:
267   uint8_t raw_packet_[11] = {
268       0x40, 0x51,                    // two-byte type field
269       0x03, 0x04, 0x05, 0x07,        // varints
270       0x06, 0x03, 0x66, 0x6f, 0x6f,  // object middler; payload = "foo"
271   };
272 };
273 
274 // Used only for tests that process multiple objects on one stream.
275 class QUICHE_NO_EXPORT StreamMiddlerGroupMessage : public ObjectMessage {
276  public:
StreamMiddlerGroupMessage()277   StreamMiddlerGroupMessage()
278       : ObjectMessage(MoqtMessageType::kStreamHeaderGroup) {
279     SetWireImage(raw_packet_, sizeof(raw_packet_));
280     object_.forwarding_preference = MoqtForwardingPreference::kGroup;
281     object_.payload_length = 3;
282     object_.object_id = 9;
283   }
284 
ExpandVarints()285   void ExpandVarints() override { ExpandVarintsImpl("vvv"); }
286 
287  private:
288   uint8_t raw_packet_[5] = {
289       0x09, 0x03, 0x62, 0x61, 0x72,  // object middler; payload = "bar"
290   };
291 };
292 
293 class QUICHE_NO_EXPORT ClientSetupMessage : public TestMessageBase {
294  public:
ClientSetupMessage(bool webtrans)295   explicit ClientSetupMessage(bool webtrans)
296       : TestMessageBase(MoqtMessageType::kClientSetup) {
297     if (webtrans) {
298       // Should not send PATH.
299       client_setup_.path = std::nullopt;
300       raw_packet_[5] = 0x01;  // only one parameter
301       SetWireImage(raw_packet_, sizeof(raw_packet_) - 5);
302     } else {
303       SetWireImage(raw_packet_, sizeof(raw_packet_));
304     }
305   }
306 
EqualFieldValues(MessageStructuredData & values)307   bool EqualFieldValues(MessageStructuredData& values) const override {
308     auto cast = std::get<MoqtClientSetup>(values);
309     if (cast.supported_versions.size() !=
310         client_setup_.supported_versions.size()) {
311       QUIC_LOG(INFO) << "CLIENT_SETUP number of supported versions mismatch";
312       return false;
313     }
314     for (uint64_t i = 0; i < cast.supported_versions.size(); ++i) {
315       // Listed versions are 1 and 2, in that order.
316       if (cast.supported_versions[i] != client_setup_.supported_versions[i]) {
317         QUIC_LOG(INFO) << "CLIENT_SETUP supported version mismatch";
318         return false;
319       }
320     }
321     if (cast.role != client_setup_.role) {
322       QUIC_LOG(INFO) << "CLIENT_SETUP role mismatch";
323       return false;
324     }
325     if (cast.path != client_setup_.path) {
326       QUIC_LOG(INFO) << "CLIENT_SETUP path mismatch";
327       return false;
328     }
329     return true;
330   }
331 
ExpandVarints()332   void ExpandVarints() override {
333     if (client_setup_.path.has_value()) {
334       ExpandVarintsImpl("--vvvvvv-vv---");
335       // first two bytes are already a 2B varint. Also, don't expand parameter
336       // varints because that messes up the parameter length field.
337     } else {
338       ExpandVarintsImpl("--vvvvvv-");
339     }
340   }
341 
structured_data()342   MessageStructuredData structured_data() const override {
343     return TestMessageBase::MessageStructuredData(client_setup_);
344   }
345 
346  private:
347   uint8_t raw_packet_[14] = {
348       0x40, 0x40,                   // type
349       0x02, 0x01, 0x02,             // versions
350       0x02,                         // 2 parameters
351       0x00, 0x01, 0x03,             // role = PubSub
352       0x01, 0x03, 0x66, 0x6f, 0x6f  // path = "foo"
353   };
354   MoqtClientSetup client_setup_ = {
355       /*supported_versions=*/std::vector<MoqtVersion>(
356           {static_cast<MoqtVersion>(1), static_cast<MoqtVersion>(2)}),
357       /*role=*/MoqtRole::kPubSub,
358       /*path=*/"foo",
359   };
360 };
361 
362 class QUICHE_NO_EXPORT ServerSetupMessage : public TestMessageBase {
363  public:
ServerSetupMessage()364   explicit ServerSetupMessage()
365       : TestMessageBase(MoqtMessageType::kServerSetup) {
366     SetWireImage(raw_packet_, sizeof(raw_packet_));
367   }
368 
EqualFieldValues(MessageStructuredData & values)369   bool EqualFieldValues(MessageStructuredData& values) const override {
370     auto cast = std::get<MoqtServerSetup>(values);
371     if (cast.selected_version != server_setup_.selected_version) {
372       QUIC_LOG(INFO) << "SERVER_SETUP selected version mismatch";
373       return false;
374     }
375     if (cast.role != server_setup_.role) {
376       QUIC_LOG(INFO) << "SERVER_SETUP role mismatch";
377       return false;
378     }
379     return true;
380   }
381 
ExpandVarints()382   void ExpandVarints() override {
383     ExpandVarintsImpl("--vvvv-");  // first two bytes are already a 2b varint
384   }
385 
structured_data()386   MessageStructuredData structured_data() const override {
387     return TestMessageBase::MessageStructuredData(server_setup_);
388   }
389 
390  private:
391   uint8_t raw_packet_[7] = {
392       0x40, 0x41,        // type
393       0x01, 0x01,        // version, one param
394       0x00, 0x01, 0x03,  // role = PubSub
395   };
396   MoqtServerSetup server_setup_ = {
397       /*selected_version=*/static_cast<MoqtVersion>(1),
398       /*role=*/MoqtRole::kPubSub,
399   };
400 };
401 
402 class QUICHE_NO_EXPORT SubscribeMessage : public TestMessageBase {
403  public:
SubscribeMessage()404   SubscribeMessage() : TestMessageBase(MoqtMessageType::kSubscribe) {
405     SetWireImage(raw_packet_, sizeof(raw_packet_));
406   }
407 
EqualFieldValues(MessageStructuredData & values)408   bool EqualFieldValues(MessageStructuredData& values) const override {
409     auto cast = std::get<MoqtSubscribe>(values);
410     if (cast.subscribe_id != subscribe_.subscribe_id) {
411       QUIC_LOG(INFO) << "SUBSCRIBE subscribe ID mismatch";
412       return false;
413     }
414     if (cast.track_alias != subscribe_.track_alias) {
415       QUIC_LOG(INFO) << "SUBSCRIBE track alias mismatch";
416       return false;
417     }
418     if (cast.track_namespace != subscribe_.track_namespace) {
419       QUIC_LOG(INFO) << "SUBSCRIBE track namespace mismatch";
420       return false;
421     }
422     if (cast.track_name != subscribe_.track_name) {
423       QUIC_LOG(INFO) << "SUBSCRIBE track name mismatch";
424       return false;
425     }
426     if (cast.start_group != subscribe_.start_group) {
427       QUIC_LOG(INFO) << "SUBSCRIBE start group mismatch";
428       return false;
429     }
430     if (cast.start_object != subscribe_.start_object) {
431       QUIC_LOG(INFO) << "SUBSCRIBE start object mismatch";
432       return false;
433     }
434     if (cast.end_group != subscribe_.end_group) {
435       QUIC_LOG(INFO) << "SUBSCRIBE end group mismatch";
436       return false;
437     }
438     if (cast.end_object != subscribe_.end_object) {
439       QUIC_LOG(INFO) << "SUBSCRIBE end object mismatch";
440       return false;
441     }
442     if (cast.authorization_info != subscribe_.authorization_info) {
443       QUIC_LOG(INFO) << "SUBSCRIBE authorization info mismatch";
444       return false;
445     }
446     return true;
447   }
448 
ExpandVarints()449   void ExpandVarints() override {
450     ExpandVarintsImpl("vvvv---v----vvvvvvvvv");
451   }
452 
structured_data()453   MessageStructuredData structured_data() const override {
454     return TestMessageBase::MessageStructuredData(subscribe_);
455   }
456 
457  private:
458   uint8_t raw_packet_[24] = {
459     0x03,
460     0x01,
461     0x02,  // id and alias
462     0x03,
463     0x66,
464     0x6f,
465     0x6f,  // track_namespace = "foo"
466     0x04,
467     0x61,
468     0x62,
469     0x63,
470     0x64,  // track_name = "abcd"
471     0x02,
472     0x04,  // start_group = 4 (relative previous)
473     0x01,
474     0x01,  // start_object = 1 (absolute)
475     0x00,  // end_group = none
476     0x00,  // end_object = none
477            // TODO(martinduke): figure out what to do about the missing num
478            // parameters field.
479     0x01,  // 1 parameter
480     0x02,
481     0x03,
482     0x62,
483     0x61,
484     0x72,  // authorization_info = "bar"
485   };
486 
487   MoqtSubscribe subscribe_ = {
488       /*subscribe_id=*/1,
489       /*track_alias=*/2,
490       /*track_namespace=*/"foo",
491       /*track_name=*/"abcd",
492       /*start_group=*/MoqtSubscribeLocation(false, (int64_t)(-4)),
493       /*start_object=*/MoqtSubscribeLocation(true, (uint64_t)1),
494       /*end_group=*/std::nullopt,
495       /*end_object=*/std::nullopt,
496       /*authorization_info=*/"bar",
497   };
498 };
499 
500 class QUICHE_NO_EXPORT SubscribeOkMessage : public TestMessageBase {
501  public:
SubscribeOkMessage()502   SubscribeOkMessage() : TestMessageBase(MoqtMessageType::kSubscribeOk) {
503     SetWireImage(raw_packet_, sizeof(raw_packet_));
504   }
505 
EqualFieldValues(MessageStructuredData & values)506   bool EqualFieldValues(MessageStructuredData& values) const override {
507     auto cast = std::get<MoqtSubscribeOk>(values);
508     if (cast.subscribe_id != subscribe_ok_.subscribe_id) {
509       QUIC_LOG(INFO) << "SUBSCRIBE OK subscribe ID mismatch";
510       return false;
511     }
512     if (cast.expires != subscribe_ok_.expires) {
513       QUIC_LOG(INFO) << "SUBSCRIBE OK expiration mismatch";
514       return false;
515     }
516     if (cast.largest_id != subscribe_ok_.largest_id) {
517       QUIC_LOG(INFO) << "SUBSCRIBE OK largest ID mismatch";
518       return false;
519     }
520     return true;
521   }
522 
ExpandVarints()523   void ExpandVarints() override { ExpandVarintsImpl("vvv-vv"); }
524 
structured_data()525   MessageStructuredData structured_data() const override {
526     return TestMessageBase::MessageStructuredData(subscribe_ok_);
527   }
528 
SetInvalidContentExists()529   void SetInvalidContentExists() {
530     raw_packet_[3] = 0x02;
531     SetWireImage(raw_packet_, sizeof(raw_packet_));
532   }
533 
534  private:
535   uint8_t raw_packet_[6] = {
536       0x04, 0x01, 0x03,  // subscribe_id = 1, expires = 3
537       0x01, 0x0c, 0x14,  // largest_group_id = 12, largest_object_id = 20,
538   };
539 
540   MoqtSubscribeOk subscribe_ok_ = {
541       /*subscribe_id=*/1,
542       /*expires=*/quic::QuicTimeDelta::FromMilliseconds(3),
543       /*largest_id=*/FullSequence(12, 20),
544   };
545 };
546 
547 class QUICHE_NO_EXPORT SubscribeErrorMessage : public TestMessageBase {
548  public:
SubscribeErrorMessage()549   SubscribeErrorMessage() : TestMessageBase(MoqtMessageType::kSubscribeError) {
550     SetWireImage(raw_packet_, sizeof(raw_packet_));
551   }
552 
EqualFieldValues(MessageStructuredData & values)553   bool EqualFieldValues(MessageStructuredData& values) const override {
554     auto cast = std::get<MoqtSubscribeError>(values);
555     if (cast.subscribe_id != subscribe_error_.subscribe_id) {
556       QUIC_LOG(INFO) << "SUBSCRIBE ERROR subscribe_id mismatch";
557       return false;
558     }
559     if (cast.error_code != subscribe_error_.error_code) {
560       QUIC_LOG(INFO) << "SUBSCRIBE ERROR error code mismatch";
561       return false;
562     }
563     if (cast.reason_phrase != subscribe_error_.reason_phrase) {
564       QUIC_LOG(INFO) << "SUBSCRIBE ERROR reason phrase mismatch";
565       return false;
566     }
567     if (cast.track_alias != subscribe_error_.track_alias) {
568       QUIC_LOG(INFO) << "SUBSCRIBE ERROR track alias mismatch";
569       return false;
570     }
571     return true;
572   }
573 
ExpandVarints()574   void ExpandVarints() override { ExpandVarintsImpl("vvvv---v"); }
575 
structured_data()576   MessageStructuredData structured_data() const override {
577     return TestMessageBase::MessageStructuredData(subscribe_error_);
578   }
579 
580  private:
581   uint8_t raw_packet_[8] = {
582       0x05, 0x02,              // subscribe_id = 2
583       0x01,                    // error_code = 2
584       0x03, 0x62, 0x61, 0x72,  // reason_phrase = "bar"
585       0x04,                    // track_alias = 4
586   };
587 
588   MoqtSubscribeError subscribe_error_ = {
589       /*subscribe_id=*/2,
590       /*subscribe=*/SubscribeErrorCode::kInvalidRange,
591       /*reason_phrase=*/"bar",
592       /*track_alias=*/4,
593   };
594 };
595 
596 class QUICHE_NO_EXPORT UnsubscribeMessage : public TestMessageBase {
597  public:
UnsubscribeMessage()598   UnsubscribeMessage() : TestMessageBase(MoqtMessageType::kUnsubscribe) {
599     SetWireImage(raw_packet_, sizeof(raw_packet_));
600   }
601 
EqualFieldValues(MessageStructuredData & values)602   bool EqualFieldValues(MessageStructuredData& values) const override {
603     auto cast = std::get<MoqtUnsubscribe>(values);
604     if (cast.subscribe_id != unsubscribe_.subscribe_id) {
605       QUIC_LOG(INFO) << "UNSUBSCRIBE subscribe ID mismatch";
606       return false;
607     }
608     return true;
609   }
610 
ExpandVarints()611   void ExpandVarints() override { ExpandVarintsImpl("vv"); }
612 
structured_data()613   MessageStructuredData structured_data() const override {
614     return TestMessageBase::MessageStructuredData(unsubscribe_);
615   }
616 
617  private:
618   uint8_t raw_packet_[2] = {
619       0x0a, 0x03,  // subscribe_id = 3
620   };
621 
622   MoqtUnsubscribe unsubscribe_ = {
623       /*subscribe_id=*/3,
624   };
625 };
626 
627 class QUICHE_NO_EXPORT SubscribeDoneMessage : public TestMessageBase {
628  public:
SubscribeDoneMessage()629   SubscribeDoneMessage() : TestMessageBase(MoqtMessageType::kSubscribeDone) {
630     SetWireImage(raw_packet_, sizeof(raw_packet_));
631   }
632 
EqualFieldValues(MessageStructuredData & values)633   bool EqualFieldValues(MessageStructuredData& values) const override {
634     auto cast = std::get<MoqtSubscribeDone>(values);
635     if (cast.subscribe_id != subscribe_done_.subscribe_id) {
636       QUIC_LOG(INFO) << "SUBSCRIBE_DONE subscribe ID mismatch";
637       return false;
638     }
639     if (cast.status_code != subscribe_done_.status_code) {
640       QUIC_LOG(INFO) << "SUBSCRIBE_DONE status code mismatch";
641       return false;
642     }
643     if (cast.reason_phrase != subscribe_done_.reason_phrase) {
644       QUIC_LOG(INFO) << "SUBSCRIBE_DONE reason phrase mismatch";
645       return false;
646     }
647     if (cast.final_id != subscribe_done_.final_id) {
648       QUIC_LOG(INFO) << "SUBSCRIBE_DONE final ID mismatch";
649       return false;
650     }
651     return true;
652   }
653 
ExpandVarints()654   void ExpandVarints() override { ExpandVarintsImpl("vvvv---vv"); }
655 
structured_data()656   MessageStructuredData structured_data() const override {
657     return TestMessageBase::MessageStructuredData(subscribe_done_);
658   }
659 
SetInvalidContentExists()660   void SetInvalidContentExists() {
661     raw_packet_[6] = 0x02;
662     SetWireImage(raw_packet_, sizeof(raw_packet_));
663   }
664 
665  private:
666   uint8_t raw_packet_[9] = {
667       0x0b, 0x02, 0x03,  // subscribe_id = 2, error_code = 3,
668       0x02, 0x68, 0x69,  // reason_phrase = "hi"
669       0x01, 0x08, 0x0c,  // final_id = (8,12)
670   };
671 
672   MoqtSubscribeDone subscribe_done_ = {
673       /*subscribe_id=*/2,
674       /*error_code=*/3,
675       /*reason_phrase=*/"hi",
676       /*final_id=*/FullSequence(8, 12),
677   };
678 };
679 
680 class QUICHE_NO_EXPORT AnnounceMessage : public TestMessageBase {
681  public:
AnnounceMessage()682   AnnounceMessage() : TestMessageBase(MoqtMessageType::kAnnounce) {
683     SetWireImage(raw_packet_, sizeof(raw_packet_));
684   }
685 
EqualFieldValues(MessageStructuredData & values)686   bool EqualFieldValues(MessageStructuredData& values) const override {
687     auto cast = std::get<MoqtAnnounce>(values);
688     if (cast.track_namespace != announce_.track_namespace) {
689       QUIC_LOG(INFO) << "ANNOUNCE MESSAGE track namespace mismatch";
690       return false;
691     }
692     if (cast.authorization_info != announce_.authorization_info) {
693       QUIC_LOG(INFO) << "ANNOUNCE MESSAGE authorization info mismatch";
694       return false;
695     }
696     return true;
697   }
698 
ExpandVarints()699   void ExpandVarints() override { ExpandVarintsImpl("vv---vvv---"); }
700 
structured_data()701   MessageStructuredData structured_data() const override {
702     return TestMessageBase::MessageStructuredData(announce_);
703   }
704 
705  private:
706   uint8_t raw_packet_[11] = {
707       0x06, 0x03, 0x66, 0x6f, 0x6f,  // track_namespace = "foo"
708       0x01,                          // 1 parameter
709       0x02, 0x03, 0x62, 0x61, 0x72,  // authorization_info = "bar"
710   };
711 
712   MoqtAnnounce announce_ = {
713       /*track_namespace=*/"foo",
714       /*authorization_info=*/"bar",
715   };
716 };
717 
718 class QUICHE_NO_EXPORT AnnounceOkMessage : public TestMessageBase {
719  public:
AnnounceOkMessage()720   AnnounceOkMessage() : TestMessageBase(MoqtMessageType::kAnnounceOk) {
721     SetWireImage(raw_packet_, sizeof(raw_packet_));
722   }
723 
EqualFieldValues(MessageStructuredData & values)724   bool EqualFieldValues(MessageStructuredData& values) const override {
725     auto cast = std::get<MoqtAnnounceOk>(values);
726     if (cast.track_namespace != announce_ok_.track_namespace) {
727       QUIC_LOG(INFO) << "ANNOUNCE OK MESSAGE track namespace mismatch";
728       return false;
729     }
730     return true;
731   }
732 
ExpandVarints()733   void ExpandVarints() override { ExpandVarintsImpl("vv---"); }
734 
structured_data()735   MessageStructuredData structured_data() const override {
736     return TestMessageBase::MessageStructuredData(announce_ok_);
737   }
738 
739  private:
740   uint8_t raw_packet_[5] = {
741       0x07, 0x03, 0x66, 0x6f, 0x6f,  // track_namespace = "foo"
742   };
743 
744   MoqtAnnounceOk announce_ok_ = {
745       /*track_namespace=*/"foo",
746   };
747 };
748 
749 class QUICHE_NO_EXPORT AnnounceErrorMessage : public TestMessageBase {
750  public:
AnnounceErrorMessage()751   AnnounceErrorMessage() : TestMessageBase(MoqtMessageType::kAnnounceError) {
752     SetWireImage(raw_packet_, sizeof(raw_packet_));
753   }
754 
EqualFieldValues(MessageStructuredData & values)755   bool EqualFieldValues(MessageStructuredData& values) const override {
756     auto cast = std::get<MoqtAnnounceError>(values);
757     if (cast.track_namespace != announce_error_.track_namespace) {
758       QUIC_LOG(INFO) << "ANNOUNCE ERROR track namespace mismatch";
759       return false;
760     }
761     if (cast.error_code != announce_error_.error_code) {
762       QUIC_LOG(INFO) << "ANNOUNCE ERROR error code mismatch";
763       return false;
764     }
765     if (cast.reason_phrase != announce_error_.reason_phrase) {
766       QUIC_LOG(INFO) << "ANNOUNCE ERROR reason phrase mismatch";
767       return false;
768     }
769     return true;
770   }
771 
ExpandVarints()772   void ExpandVarints() override { ExpandVarintsImpl("vv---vv---"); }
773 
structured_data()774   MessageStructuredData structured_data() const override {
775     return TestMessageBase::MessageStructuredData(announce_error_);
776   }
777 
778  private:
779   uint8_t raw_packet_[10] = {
780       0x08, 0x03, 0x66, 0x6f, 0x6f,  // track_namespace = "foo"
781       0x01,                          // error_code = 1
782       0x03, 0x62, 0x61, 0x72,        // reason_phrase = "bar"
783   };
784 
785   MoqtAnnounceError announce_error_ = {
786       /*track_namespace=*/"foo",
787       /*error_code=*/MoqtAnnounceErrorCode::kAnnounceNotSupported,
788       /*reason_phrase=*/"bar",
789   };
790 };
791 
792 class QUICHE_NO_EXPORT UnannounceMessage : public TestMessageBase {
793  public:
UnannounceMessage()794   UnannounceMessage() : TestMessageBase(MoqtMessageType::kUnannounce) {
795     SetWireImage(raw_packet_, sizeof(raw_packet_));
796   }
797 
EqualFieldValues(MessageStructuredData & values)798   bool EqualFieldValues(MessageStructuredData& values) const override {
799     auto cast = std::get<MoqtUnannounce>(values);
800     if (cast.track_namespace != unannounce_.track_namespace) {
801       QUIC_LOG(INFO) << "UNSUBSCRIBE full track name mismatch";
802       return false;
803     }
804     return true;
805   }
806 
ExpandVarints()807   void ExpandVarints() override { ExpandVarintsImpl("vv---"); }
808 
structured_data()809   MessageStructuredData structured_data() const override {
810     return TestMessageBase::MessageStructuredData(unannounce_);
811   }
812 
813  private:
814   uint8_t raw_packet_[5] = {
815       0x09, 0x03, 0x66, 0x6f, 0x6f,  // track_namespace
816   };
817 
818   MoqtUnannounce unannounce_ = {
819       /*track_namespace=*/"foo",
820   };
821 };
822 
823 class QUICHE_NO_EXPORT GoAwayMessage : public TestMessageBase {
824  public:
GoAwayMessage()825   GoAwayMessage() : TestMessageBase(MoqtMessageType::kGoAway) {
826     SetWireImage(raw_packet_, sizeof(raw_packet_));
827   }
828 
EqualFieldValues(MessageStructuredData & values)829   bool EqualFieldValues(MessageStructuredData& values) const override {
830     auto cast = std::get<MoqtGoAway>(values);
831     if (cast.new_session_uri != goaway_.new_session_uri) {
832       QUIC_LOG(INFO) << "UNSUBSCRIBE full track name mismatch";
833       return false;
834     }
835     return true;
836   }
837 
ExpandVarints()838   void ExpandVarints() override { ExpandVarintsImpl("vv---"); }
839 
structured_data()840   MessageStructuredData structured_data() const override {
841     return TestMessageBase::MessageStructuredData(goaway_);
842   }
843 
844  private:
845   uint8_t raw_packet_[5] = {
846       0x10, 0x03, 0x66, 0x6f, 0x6f,
847   };
848 
849   MoqtGoAway goaway_ = {
850       /*new_session_uri=*/"foo",
851   };
852 };
853 
854 // Factory function for test messages.
CreateTestMessage(MoqtMessageType message_type,bool is_webtrans)855 static inline std::unique_ptr<TestMessageBase> CreateTestMessage(
856     MoqtMessageType message_type, bool is_webtrans) {
857   switch (message_type) {
858     case MoqtMessageType::kObjectStream:
859       return std::make_unique<ObjectStreamMessage>();
860     case MoqtMessageType::kObjectDatagram:
861       return std::make_unique<ObjectDatagramMessage>();
862     case MoqtMessageType::kSubscribe:
863       return std::make_unique<SubscribeMessage>();
864     case MoqtMessageType::kSubscribeOk:
865       return std::make_unique<SubscribeOkMessage>();
866     case MoqtMessageType::kSubscribeError:
867       return std::make_unique<SubscribeErrorMessage>();
868     case MoqtMessageType::kUnsubscribe:
869       return std::make_unique<UnsubscribeMessage>();
870     case MoqtMessageType::kSubscribeDone:
871       return std::make_unique<SubscribeDoneMessage>();
872     case MoqtMessageType::kAnnounce:
873       return std::make_unique<AnnounceMessage>();
874     case MoqtMessageType::kAnnounceOk:
875       return std::make_unique<AnnounceOkMessage>();
876     case MoqtMessageType::kAnnounceError:
877       return std::make_unique<AnnounceErrorMessage>();
878     case MoqtMessageType::kUnannounce:
879       return std::make_unique<UnannounceMessage>();
880     case MoqtMessageType::kGoAway:
881       return std::make_unique<GoAwayMessage>();
882     case MoqtMessageType::kClientSetup:
883       return std::make_unique<ClientSetupMessage>(is_webtrans);
884     case MoqtMessageType::kServerSetup:
885       return std::make_unique<ServerSetupMessage>();
886     case MoqtMessageType::kStreamHeaderTrack:
887       return std::make_unique<StreamHeaderTrackMessage>();
888     case MoqtMessageType::kStreamHeaderGroup:
889       return std::make_unique<StreamHeaderGroupMessage>();
890     default:
891       return nullptr;
892   }
893 }
894 
895 }  // namespace moqt::test
896 
897 #endif  // QUICHE_QUIC_MOQT_TEST_TOOLS_MOQT_TEST_MESSAGE_H_
898