1 // Copyright (c) 2023 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "quiche/quic/moqt/moqt_framer.h"
6
7 #include <memory>
8 #include <optional>
9 #include <string>
10 #include <vector>
11
12 #include "absl/strings/str_cat.h"
13 #include "absl/strings/string_view.h"
14 #include "quiche/quic/moqt/moqt_messages.h"
15 #include "quiche/quic/moqt/test_tools/moqt_test_message.h"
16 #include "quiche/quic/platform/api/quic_expect_bug.h"
17 #include "quiche/quic/platform/api/quic_test.h"
18 #include "quiche/common/quiche_buffer_allocator.h"
19 #include "quiche/common/simple_buffer_allocator.h"
20
21 namespace moqt::test {
22
23 struct MoqtFramerTestParams {
MoqtFramerTestParamsmoqt::test::MoqtFramerTestParams24 MoqtFramerTestParams(MoqtMessageType message_type, bool uses_web_transport)
25 : message_type(message_type), uses_web_transport(uses_web_transport) {}
26 MoqtMessageType message_type;
27 bool uses_web_transport;
28 };
29
GetMoqtFramerTestParams()30 std::vector<MoqtFramerTestParams> GetMoqtFramerTestParams() {
31 std::vector<MoqtFramerTestParams> params;
32 std::vector<MoqtMessageType> message_types = {
33 MoqtMessageType::kObjectStream,
34 MoqtMessageType::kSubscribe,
35 MoqtMessageType::kSubscribeOk,
36 MoqtMessageType::kSubscribeError,
37 MoqtMessageType::kUnsubscribe,
38 MoqtMessageType::kSubscribeDone,
39 MoqtMessageType::kAnnounce,
40 MoqtMessageType::kAnnounceOk,
41 MoqtMessageType::kAnnounceError,
42 MoqtMessageType::kUnannounce,
43 MoqtMessageType::kGoAway,
44 MoqtMessageType::kClientSetup,
45 MoqtMessageType::kServerSetup,
46 MoqtMessageType::kStreamHeaderTrack,
47 MoqtMessageType::kStreamHeaderGroup,
48 };
49 std::vector<bool> uses_web_transport_bool = {
50 false,
51 true,
52 };
53 for (const MoqtMessageType message_type : message_types) {
54 if (message_type == MoqtMessageType::kClientSetup) {
55 for (const bool uses_web_transport : uses_web_transport_bool) {
56 params.push_back(
57 MoqtFramerTestParams(message_type, uses_web_transport));
58 }
59 } else {
60 // All other types are processed the same for either perspective or
61 // transport.
62 params.push_back(MoqtFramerTestParams(message_type, true));
63 }
64 }
65 return params;
66 }
67
ParamNameFormatter(const testing::TestParamInfo<MoqtFramerTestParams> & info)68 std::string ParamNameFormatter(
69 const testing::TestParamInfo<MoqtFramerTestParams>& info) {
70 return MoqtMessageTypeToString(info.param.message_type) + "_" +
71 (info.param.uses_web_transport ? "WebTransport" : "QUIC");
72 }
73
SerializeObject(MoqtFramer & framer,const MoqtObject & message,absl::string_view payload,bool is_first_in_stream)74 quiche::QuicheBuffer SerializeObject(MoqtFramer& framer,
75 const MoqtObject& message,
76 absl::string_view payload,
77 bool is_first_in_stream) {
78 MoqtObject adjusted_message = message;
79 adjusted_message.payload_length = payload.size();
80 quiche::QuicheBuffer header =
81 framer.SerializeObjectHeader(adjusted_message, is_first_in_stream);
82 if (header.empty()) {
83 return quiche::QuicheBuffer();
84 }
85 return quiche::QuicheBuffer::Copy(
86 quiche::SimpleBufferAllocator::Get(),
87 absl::StrCat(header.AsStringView(), payload));
88 }
89
90 class MoqtFramerTest
91 : public quic::test::QuicTestWithParam<MoqtFramerTestParams> {
92 public:
MoqtFramerTest()93 MoqtFramerTest()
94 : message_type_(GetParam().message_type),
95 webtrans_(GetParam().uses_web_transport),
96 buffer_allocator_(quiche::SimpleBufferAllocator::Get()),
97 framer_(buffer_allocator_, GetParam().uses_web_transport) {}
98
MakeMessage(MoqtMessageType message_type)99 std::unique_ptr<TestMessageBase> MakeMessage(MoqtMessageType message_type) {
100 return CreateTestMessage(message_type, webtrans_);
101 }
102
SerializeMessage(TestMessageBase::MessageStructuredData & structured_data)103 quiche::QuicheBuffer SerializeMessage(
104 TestMessageBase::MessageStructuredData& structured_data) {
105 switch (message_type_) {
106 case MoqtMessageType::kObjectStream:
107 case MoqtMessageType::kStreamHeaderTrack:
108 case MoqtMessageType::kStreamHeaderGroup: {
109 MoqtObject data = std::get<MoqtObject>(structured_data);
110 return SerializeObject(framer_, data, "foo", true);
111 }
112 case MoqtMessageType::kSubscribe: {
113 auto data = std::get<MoqtSubscribe>(structured_data);
114 return framer_.SerializeSubscribe(data);
115 }
116 case MoqtMessageType::kSubscribeOk: {
117 auto data = std::get<MoqtSubscribeOk>(structured_data);
118 return framer_.SerializeSubscribeOk(data);
119 }
120 case MoqtMessageType::kSubscribeError: {
121 auto data = std::get<MoqtSubscribeError>(structured_data);
122 return framer_.SerializeSubscribeError(data);
123 }
124 case MoqtMessageType::kUnsubscribe: {
125 auto data = std::get<MoqtUnsubscribe>(structured_data);
126 return framer_.SerializeUnsubscribe(data);
127 }
128 case MoqtMessageType::kSubscribeDone: {
129 auto data = std::get<MoqtSubscribeDone>(structured_data);
130 return framer_.SerializeSubscribeDone(data);
131 }
132 case MoqtMessageType::kAnnounce: {
133 auto data = std::get<MoqtAnnounce>(structured_data);
134 return framer_.SerializeAnnounce(data);
135 }
136 case moqt::MoqtMessageType::kAnnounceOk: {
137 auto data = std::get<MoqtAnnounceOk>(structured_data);
138 return framer_.SerializeAnnounceOk(data);
139 }
140 case moqt::MoqtMessageType::kAnnounceError: {
141 auto data = std::get<MoqtAnnounceError>(structured_data);
142 return framer_.SerializeAnnounceError(data);
143 }
144 case MoqtMessageType::kUnannounce: {
145 auto data = std::get<MoqtUnannounce>(structured_data);
146 return framer_.SerializeUnannounce(data);
147 }
148 case moqt::MoqtMessageType::kGoAway: {
149 auto data = std::get<MoqtGoAway>(structured_data);
150 return framer_.SerializeGoAway(data);
151 }
152 case MoqtMessageType::kClientSetup: {
153 auto data = std::get<MoqtClientSetup>(structured_data);
154 return framer_.SerializeClientSetup(data);
155 }
156 case MoqtMessageType::kServerSetup: {
157 auto data = std::get<MoqtServerSetup>(structured_data);
158 return framer_.SerializeServerSetup(data);
159 }
160 default:
161 // kObjectDatagram is a totally different code path.
162 return quiche::QuicheBuffer();
163 }
164 }
165
166 MoqtMessageType message_type_;
167 bool webtrans_;
168 quiche::SimpleBufferAllocator* buffer_allocator_;
169 MoqtFramer framer_;
170 };
171
172 INSTANTIATE_TEST_SUITE_P(MoqtFramerTests, MoqtFramerTest,
173 testing::ValuesIn(GetMoqtFramerTestParams()),
174 ParamNameFormatter);
175
// Serializing a canned message's structured data must reproduce exactly the
// message's sample bytes.
TEST_P(MoqtFramerTest, OneMessage) {
  std::unique_ptr<TestMessageBase> expected = MakeMessage(message_type_);
  auto input = expected->structured_data();
  quiche::QuicheBuffer serialized = SerializeMessage(input);
  EXPECT_EQ(serialized.size(), expected->total_message_size());
  EXPECT_EQ(serialized.AsStringView(), expected->PacketSample());
}
183
184 class MoqtFramerSimpleTest : public quic::test::QuicTest {
185 public:
MoqtFramerSimpleTest()186 MoqtFramerSimpleTest()
187 : buffer_allocator_(quiche::SimpleBufferAllocator::Get()),
188 framer_(buffer_allocator_, /*web_transport=*/true) {}
189
190 quiche::SimpleBufferAllocator* buffer_allocator_;
191 MoqtFramer framer_;
192 };
193
// A group stream: the first object is framed with is_first_in_stream == true
// and must match the group header sample; the next object is framed with
// is_first_in_stream == false and must match the group "middler" sample.
TEST_F(MoqtFramerSimpleTest, GroupMiddler) {
  StreamHeaderGroupMessage first;
  quiche::QuicheBuffer first_bytes =
      SerializeObject(framer_, std::get<MoqtObject>(first.structured_data()),
                      "foo", /*is_first_in_stream=*/true);
  EXPECT_EQ(first_bytes.size(), first.total_message_size());
  EXPECT_EQ(first_bytes.AsStringView(), first.PacketSample());

  StreamMiddlerGroupMessage next;
  quiche::QuicheBuffer next_bytes =
      SerializeObject(framer_, std::get<MoqtObject>(next.structured_data()),
                      "bar", /*is_first_in_stream=*/false);
  EXPECT_EQ(next_bytes.size(), next.total_message_size());
  EXPECT_EQ(next_bytes.AsStringView(), next.PacketSample());
}
207
// A track stream: the first object is framed with is_first_in_stream == true
// and must match the track header sample; the next object is framed with
// is_first_in_stream == false and must match the track "middler" sample.
TEST_F(MoqtFramerSimpleTest, TrackMiddler) {
  StreamHeaderTrackMessage first;
  quiche::QuicheBuffer first_bytes =
      SerializeObject(framer_, std::get<MoqtObject>(first.structured_data()),
                      "foo", /*is_first_in_stream=*/true);
  EXPECT_EQ(first_bytes.size(), first.total_message_size());
  EXPECT_EQ(first_bytes.AsStringView(), first.PacketSample());

  StreamMiddlerTrackMessage next;
  quiche::QuicheBuffer next_bytes =
      SerializeObject(framer_, std::get<MoqtObject>(next.structured_data()),
                      "bar", /*is_first_in_stream=*/false);
  EXPECT_EQ(next_bytes.size(), next.total_message_size());
  EXPECT_EQ(next_bytes.AsStringView(), next.PacketSample());
}
221
// SerializeObjectHeader must QUIC_BUG and return an empty buffer for object
// headers it cannot frame.
TEST_F(MoqtFramerSimpleTest, BadObjectInput) {
  MoqtObject object = {
      /*subscribe_id=*/3,
      /*track_alias=*/4,
      /*group_id=*/5,
      /*object_id=*/6,
      /*object_send_order=*/7,
      /*forwarding_preference=*/MoqtForwardingPreference::kDatagram,
      /*payload_length=*/std::nullopt,
  };
  quiche::QuicheBuffer buffer;

  // Datagram preference with is_first_in_stream == false is rejected.
  EXPECT_QUIC_BUG(buffer = framer_.SerializeObjectHeader(object, false),
                  "must be first");
  EXPECT_TRUE(buffer.empty());

  // Group preference with no payload_length (and not first in stream) is
  // rejected because the object length is unknown.
  object.forwarding_preference = MoqtForwardingPreference::kGroup;
  EXPECT_QUIC_BUG(buffer = framer_.SerializeObjectHeader(object, false),
                  "requires knowing the object length");
  EXPECT_TRUE(buffer.empty());
}
242
// Serializing an object plus payload as a datagram must reproduce the canned
// ObjectDatagramMessage sample bytes.
TEST_F(MoqtFramerSimpleTest, Datagram) {
  ObjectDatagramMessage expected;
  // NOTE(review): the forwarding preference here is kObject, not kDatagram;
  // the expected bytes still match, so SerializeObjectDatagram presumably
  // does not encode this field — confirm.
  MoqtObject object = {
      /*subscribe_id=*/3,
      /*track_alias=*/4,
      /*group_id=*/5,
      /*object_id=*/6,
      /*object_send_order=*/7,
      /*forwarding_preference=*/MoqtForwardingPreference::kObject,
      /*payload_length=*/std::nullopt,
  };
  std::string payload = "foo";
  quiche::QuicheBuffer serialized =
      framer_.SerializeObjectDatagram(object, payload);
  EXPECT_EQ(serialized.size(), expected.total_message_size());
  EXPECT_EQ(serialized.AsStringView(), expected.PacketSample());
}
260
261 } // namespace moqt::test
262