1 // Copyright 2022 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "quiche/quic/tools/connect_tunnel.h"
6
7 #include <cstdint>
8 #include <utility>
9
10 #include "absl/container/flat_hash_set.h"
11 #include "absl/status/status.h"
12 #include "absl/status/statusor.h"
13 #include "absl/strings/str_cat.h"
14 #include "absl/strings/string_view.h"
15 #include "quiche/quic/core/connecting_client_socket.h"
16 #include "quiche/quic/core/quic_connection_id.h"
17 #include "quiche/quic/core/quic_error_codes.h"
18 #include "quiche/quic/core/quic_types.h"
19 #include "quiche/quic/core/socket_factory.h"
20 #include "quiche/quic/platform/api/quic_socket_address.h"
21 #include "quiche/quic/platform/api/quic_test_loopback.h"
22 #include "quiche/quic/test_tools/quic_test_utils.h"
23 #include "quiche/quic/tools/quic_backend_response.h"
24 #include "quiche/quic/tools/quic_simple_server_backend.h"
25 #include "quiche/common/platform/api/quiche_mem_slice.h"
26 #include "quiche/common/platform/api/quiche_test.h"
27 #include "quiche/spdy/core/http2_header_block.h"
28
29 namespace quic::test {
30 namespace {
31
32 using ::testing::_;
33 using ::testing::AllOf;
34 using ::testing::AnyOf;
35 using ::testing::ElementsAre;
36 using ::testing::Eq;
37 using ::testing::Ge;
38 using ::testing::Gt;
39 using ::testing::InvokeWithoutArgs;
40 using ::testing::IsEmpty;
41 using ::testing::Matcher;
42 using ::testing::NiceMock;
43 using ::testing::Pair;
44 using ::testing::Property;
45 using ::testing::Return;
46 using ::testing::StrictMock;
47
48 class MockRequestHandler : public QuicSimpleServerBackend::RequestHandler {
49 public:
connection_id() const50 QuicConnectionId connection_id() const override {
51 return TestConnectionId(41212);
52 }
stream_id() const53 QuicStreamId stream_id() const override { return 100; }
peer_host() const54 std::string peer_host() const override { return "127.0.0.1"; }
55
56 MOCK_METHOD(QuicSpdyStream*, GetStream, (), (override));
57 MOCK_METHOD(void, OnResponseBackendComplete,
58 (const QuicBackendResponse* response), (override));
59 MOCK_METHOD(void, SendStreamData, (absl::string_view data, bool close_stream),
60 (override));
61 MOCK_METHOD(void, TerminateStreamWithError, (QuicResetStreamError error),
62 (override));
63 };
64
65 class MockSocketFactory : public SocketFactory {
66 public:
67 MOCK_METHOD(std::unique_ptr<ConnectingClientSocket>, CreateTcpClientSocket,
68 (const quic::QuicSocketAddress& peer_address,
69 QuicByteCount receive_buffer_size,
70 QuicByteCount send_buffer_size,
71 ConnectingClientSocket::AsyncVisitor* async_visitor),
72 (override));
73 MOCK_METHOD(std::unique_ptr<ConnectingClientSocket>,
74 CreateConnectingUdpClientSocket,
75 (const quic::QuicSocketAddress& peer_address,
76 QuicByteCount receive_buffer_size,
77 QuicByteCount send_buffer_size,
78 ConnectingClientSocket::AsyncVisitor* async_visitor),
79 (override));
80 };
81
82 class MockSocket : public ConnectingClientSocket {
83 public:
84 MOCK_METHOD(absl::Status, ConnectBlocking, (), (override));
85 MOCK_METHOD(void, ConnectAsync, (), (override));
86 MOCK_METHOD(void, Disconnect, (), (override));
87 MOCK_METHOD(absl::StatusOr<QuicSocketAddress>, GetLocalAddress, (),
88 (override));
89 MOCK_METHOD(absl::StatusOr<quiche::QuicheMemSlice>, ReceiveBlocking,
90 (QuicByteCount max_size), (override));
91 MOCK_METHOD(void, ReceiveAsync, (QuicByteCount max_size), (override));
92 MOCK_METHOD(absl::Status, SendBlocking, (std::string data), (override));
93 MOCK_METHOD(absl::Status, SendBlocking, (quiche::QuicheMemSlice data),
94 (override));
95 MOCK_METHOD(void, SendAsync, (std::string data), (override));
96 MOCK_METHOD(void, SendAsync, (quiche::QuicheMemSlice data), (override));
97 };
98
99 class ConnectTunnelTest : public quiche::test::QuicheTest {
100 public:
SetUp()101 void SetUp() override {
102 #if defined(_WIN32)
103 WSADATA wsa_data;
104 const WORD version_required = MAKEWORD(2, 2);
105 ASSERT_EQ(WSAStartup(version_required, &wsa_data), 0);
106 #endif
107 auto socket = std::make_unique<StrictMock<MockSocket>>();
108 socket_ = socket.get();
109 ON_CALL(socket_factory_,
110 CreateTcpClientSocket(
111 AnyOf(QuicSocketAddress(TestLoopback4(), kAcceptablePort),
112 QuicSocketAddress(TestLoopback6(), kAcceptablePort)),
113 _, _, &tunnel_))
114 .WillByDefault(Return(ByMove(std::move(socket))));
115 }
116
117 protected:
118 static constexpr absl::string_view kAcceptableDestination = "localhost";
119 static constexpr uint16_t kAcceptablePort = 977;
120
121 StrictMock<MockRequestHandler> request_handler_;
122 NiceMock<MockSocketFactory> socket_factory_;
123 StrictMock<MockSocket>* socket_;
124
125 ConnectTunnel tunnel_{
126 &request_handler_,
127 &socket_factory_,
128 /*acceptable_destinations=*/
129 {{std::string(kAcceptableDestination), kAcceptablePort},
130 {TestLoopback4().ToString(), kAcceptablePort},
131 {absl::StrCat("[", TestLoopback6().ToString(), "]"), kAcceptablePort}}};
132 };
133
// Opens a tunnel to the acceptable hostname destination and checks that the
// tunnel reports a connection until the client stream is closed.
TEST_F(ConnectTunnelTest, OpenTunnel) {
  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  // When the tunnel disconnects the socket, report the receive as cancelled.
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() {
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));

  // The response is verified field by field through matchers: an incomplete
  // response carrying only a ":status: 200" header, no trailers and no body.
  EXPECT_CALL(request_handler_,
              OnResponseBackendComplete(
                  AllOf(Property(&QuicBackendResponse::response_type,
                                 QuicBackendResponse::INCOMPLETE_RESPONSE),
                        Property(&QuicBackendResponse::headers,
                                 ElementsAre(Pair(":status", "200"))),
                        Property(&QuicBackendResponse::trailers, IsEmpty()),
                        Property(&QuicBackendResponse::body, IsEmpty()))));

  spdy::Http2HeaderBlock request_headers;
  request_headers[":method"] = "CONNECT";
  request_headers[":authority"] =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);

  tunnel_.OpenTunnel(request_headers);
  EXPECT_TRUE(tunnel_.IsConnectedToDestination());
  tunnel_.OnClientStreamClose();
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());
}
165
// Same flow as opening a tunnel by hostname, but the ":authority" header names
// the IPv4 loopback literal.
TEST_F(ConnectTunnelTest, OpenTunnelToIpv4LiteralDestination) {
  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  // When the tunnel disconnects the socket, report the receive as cancelled.
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() {
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));

  // The response is verified field by field through matchers: an incomplete
  // response carrying only a ":status: 200" header, no trailers and no body.
  EXPECT_CALL(request_handler_,
              OnResponseBackendComplete(
                  AllOf(Property(&QuicBackendResponse::response_type,
                                 QuicBackendResponse::INCOMPLETE_RESPONSE),
                        Property(&QuicBackendResponse::headers,
                                 ElementsAre(Pair(":status", "200"))),
                        Property(&QuicBackendResponse::trailers, IsEmpty()),
                        Property(&QuicBackendResponse::body, IsEmpty()))));

  spdy::Http2HeaderBlock request_headers;
  request_headers[":method"] = "CONNECT";
  request_headers[":authority"] =
      absl::StrCat(TestLoopback4().ToString(), ":", kAcceptablePort);

  tunnel_.OpenTunnel(request_headers);
  EXPECT_TRUE(tunnel_.IsConnectedToDestination());
  tunnel_.OnClientStreamClose();
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());
}
197
// Same flow as opening a tunnel by hostname, but the ":authority" header names
// the bracketed IPv6 loopback literal.
TEST_F(ConnectTunnelTest, OpenTunnelToIpv6LiteralDestination) {
  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  // When the tunnel disconnects the socket, report the receive as cancelled.
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() {
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));

  // The response is verified field by field through matchers: an incomplete
  // response carrying only a ":status: 200" header, no trailers and no body.
  EXPECT_CALL(request_handler_,
              OnResponseBackendComplete(
                  AllOf(Property(&QuicBackendResponse::response_type,
                                 QuicBackendResponse::INCOMPLETE_RESPONSE),
                        Property(&QuicBackendResponse::headers,
                                 ElementsAre(Pair(":status", "200"))),
                        Property(&QuicBackendResponse::trailers, IsEmpty()),
                        Property(&QuicBackendResponse::body, IsEmpty()))));

  spdy::Http2HeaderBlock request_headers;
  request_headers[":method"] = "CONNECT";
  request_headers[":authority"] =
      absl::StrCat("[", TestLoopback6().ToString(), "]:", kAcceptablePort);

  tunnel_.OpenTunnel(request_headers);
  EXPECT_TRUE(tunnel_.IsConnectedToDestination());
  tunnel_.OnClientStreamClose();
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());
}
229
// A CONNECT request that lacks an ":authority" header must terminate the
// stream with MESSAGE_ERROR and leave the tunnel unconnected.
TEST_F(ConnectTunnelTest, OpenTunnelWithMalformedRequest) {
  // Deliberately omit ":authority".
  spdy::Http2HeaderBlock malformed_request;
  malformed_request[":method"] = "CONNECT";

  const auto expected_code =
      static_cast<uint64_t>(QuicHttp3ErrorCode::MESSAGE_ERROR);
  EXPECT_CALL(request_handler_,
              TerminateStreamWithError(
                  Property(&QuicResetStreamError::ietf_application_code,
                           expected_code)));

  tunnel_.OpenTunnel(malformed_request);
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());

  tunnel_.OnClientStreamClose();
}
244
// A CONNECT request for a destination outside the acceptable list must
// terminate the stream with REQUEST_REJECTED and leave the tunnel unconnected.
TEST_F(ConnectTunnelTest, OpenTunnelWithUnacceptableDestination) {
  spdy::Http2HeaderBlock rejected_request;
  rejected_request[":method"] = "CONNECT";
  rejected_request[":authority"] = "unacceptable.test:100";

  const auto expected_code =
      static_cast<uint64_t>(QuicHttp3ErrorCode::REQUEST_REJECTED);
  EXPECT_CALL(request_handler_,
              TerminateStreamWithError(
                  Property(&QuicResetStreamError::ietf_application_code,
                           expected_code)));

  tunnel_.OpenTunnel(rejected_request);
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());

  tunnel_.OnClientStreamClose();
}
260
// Data completed on the destination socket must be forwarded to the request
// handler without closing the stream, and a further receive must be started.
TEST_F(ConnectTunnelTest, ReceiveFromDestination) {
  static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55";

  const std::string authority =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);
  spdy::Http2HeaderBlock connect_request;
  connect_request[":method"] = "CONNECT";
  connect_request[":authority"] = authority;

  // Request-handler side: one response, then the forwarded payload.
  EXPECT_CALL(request_handler_, OnResponseBackendComplete(_));
  EXPECT_CALL(request_handler_, SendStreamData(kData, /*close_stream=*/false));

  // Socket side: two receives are expected in total, each with room for at
  // least `kData`.
  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Ge(kData.size()))).Times(2);
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this] {
    // Report the receive as cancelled once the socket is disconnected.
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));

  tunnel_.OpenTunnel(connect_request);

  // Deliver `kData` as if it had arrived from the destination.
  tunnel_.ReceiveComplete(MemSliceFromString(kData));

  tunnel_.OnClientStreamClose();
}
286
// Data handed to the tunnel must be written to the destination socket with a
// blocking send of the same bytes.
TEST_F(ConnectTunnelTest, SendToDestination) {
  static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55";

  const std::string authority =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);
  spdy::Http2HeaderBlock connect_request;
  connect_request[":method"] = "CONNECT";
  connect_request[":authority"] = authority;

  EXPECT_CALL(request_handler_, OnResponseBackendComplete(_));

  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  // The typed matcher selects the `std::string` overload of SendBlocking().
  EXPECT_CALL(*socket_, SendBlocking(Matcher<std::string>(Eq(kData))))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this] {
    // Report the receive as cancelled once the socket is disconnected.
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));

  tunnel_.OpenTunnel(connect_request);
  tunnel_.SendDataToDestination(kData);
  tunnel_.OnClientStreamClose();
}
309
// An empty receive from the destination must close the client stream with an
// empty final write and leave the tunnel disconnected.
TEST_F(ConnectTunnelTest, DestinationDisconnect) {
  const std::string authority =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);
  spdy::Http2HeaderBlock connect_request;
  connect_request[":method"] = "CONNECT";
  connect_request[":authority"] = authority;

  EXPECT_CALL(request_handler_, OnResponseBackendComplete(_));
  EXPECT_CALL(request_handler_, SendStreamData("", /*close_stream=*/true));

  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  EXPECT_CALL(*socket_, Disconnect());

  tunnel_.OpenTunnel(connect_request);

  // Complete the receive with an empty slice.
  tunnel_.ReceiveComplete(quiche::QuicheMemSlice());
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());

  tunnel_.OnClientStreamClose();
}
332
// A receive that completes with an error status must terminate the client
// stream with CONNECT_ERROR and disconnect the destination socket.
TEST_F(ConnectTunnelTest, DestinationTcpConnectionError) {
  const std::string authority =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);
  spdy::Http2HeaderBlock connect_request;
  connect_request[":method"] = "CONNECT";
  connect_request[":authority"] = authority;

  const auto expected_code =
      static_cast<uint64_t>(QuicHttp3ErrorCode::CONNECT_ERROR);
  EXPECT_CALL(request_handler_, OnResponseBackendComplete(_));
  EXPECT_CALL(request_handler_,
              TerminateStreamWithError(
                  Property(&QuicResetStreamError::ietf_application_code,
                           expected_code)));

  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  EXPECT_CALL(*socket_, Disconnect());

  tunnel_.OpenTunnel(connect_request);

  // Complete the receive with an error status.
  tunnel_.ReceiveComplete(absl::UnknownError("error"));

  tunnel_.OnClientStreamClose();
}
356
357 } // namespace
358 } // namespace quic::test
359