xref: /aosp_15_r20/external/cronet/net/third_party/quiche/src/quiche/quic/tools/connect_tunnel_test.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2022 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "quiche/quic/tools/connect_tunnel.h"
6 
7 #include <cstdint>
8 #include <utility>
9 
10 #include "absl/container/flat_hash_set.h"
11 #include "absl/status/status.h"
12 #include "absl/status/statusor.h"
13 #include "absl/strings/str_cat.h"
14 #include "absl/strings/string_view.h"
15 #include "quiche/quic/core/connecting_client_socket.h"
16 #include "quiche/quic/core/quic_connection_id.h"
17 #include "quiche/quic/core/quic_error_codes.h"
18 #include "quiche/quic/core/quic_types.h"
19 #include "quiche/quic/core/socket_factory.h"
20 #include "quiche/quic/platform/api/quic_socket_address.h"
21 #include "quiche/quic/platform/api/quic_test_loopback.h"
22 #include "quiche/quic/test_tools/quic_test_utils.h"
23 #include "quiche/quic/tools/quic_backend_response.h"
24 #include "quiche/quic/tools/quic_simple_server_backend.h"
25 #include "quiche/common/platform/api/quiche_mem_slice.h"
26 #include "quiche/common/platform/api/quiche_test.h"
27 #include "quiche/spdy/core/http2_header_block.h"
28 
29 namespace quic::test {
30 namespace {
31 
32 using ::testing::_;
33 using ::testing::AllOf;
34 using ::testing::AnyOf;
35 using ::testing::ElementsAre;
36 using ::testing::Eq;
37 using ::testing::Ge;
38 using ::testing::Gt;
39 using ::testing::InvokeWithoutArgs;
40 using ::testing::IsEmpty;
41 using ::testing::Matcher;
42 using ::testing::NiceMock;
43 using ::testing::Pair;
44 using ::testing::Property;
45 using ::testing::Return;
46 using ::testing::StrictMock;
47 
48 class MockRequestHandler : public QuicSimpleServerBackend::RequestHandler {
49  public:
connection_id() const50   QuicConnectionId connection_id() const override {
51     return TestConnectionId(41212);
52   }
stream_id() const53   QuicStreamId stream_id() const override { return 100; }
peer_host() const54   std::string peer_host() const override { return "127.0.0.1"; }
55 
56   MOCK_METHOD(QuicSpdyStream*, GetStream, (), (override));
57   MOCK_METHOD(void, OnResponseBackendComplete,
58               (const QuicBackendResponse* response), (override));
59   MOCK_METHOD(void, SendStreamData, (absl::string_view data, bool close_stream),
60               (override));
61   MOCK_METHOD(void, TerminateStreamWithError, (QuicResetStreamError error),
62               (override));
63 };
64 
65 class MockSocketFactory : public SocketFactory {
66  public:
67   MOCK_METHOD(std::unique_ptr<ConnectingClientSocket>, CreateTcpClientSocket,
68               (const quic::QuicSocketAddress& peer_address,
69                QuicByteCount receive_buffer_size,
70                QuicByteCount send_buffer_size,
71                ConnectingClientSocket::AsyncVisitor* async_visitor),
72               (override));
73   MOCK_METHOD(std::unique_ptr<ConnectingClientSocket>,
74               CreateConnectingUdpClientSocket,
75               (const quic::QuicSocketAddress& peer_address,
76                QuicByteCount receive_buffer_size,
77                QuicByteCount send_buffer_size,
78                ConnectingClientSocket::AsyncVisitor* async_visitor),
79               (override));
80 };
81 
82 class MockSocket : public ConnectingClientSocket {
83  public:
84   MOCK_METHOD(absl::Status, ConnectBlocking, (), (override));
85   MOCK_METHOD(void, ConnectAsync, (), (override));
86   MOCK_METHOD(void, Disconnect, (), (override));
87   MOCK_METHOD(absl::StatusOr<QuicSocketAddress>, GetLocalAddress, (),
88               (override));
89   MOCK_METHOD(absl::StatusOr<quiche::QuicheMemSlice>, ReceiveBlocking,
90               (QuicByteCount max_size), (override));
91   MOCK_METHOD(void, ReceiveAsync, (QuicByteCount max_size), (override));
92   MOCK_METHOD(absl::Status, SendBlocking, (std::string data), (override));
93   MOCK_METHOD(absl::Status, SendBlocking, (quiche::QuicheMemSlice data),
94               (override));
95   MOCK_METHOD(void, SendAsync, (std::string data), (override));
96   MOCK_METHOD(void, SendAsync, (quiche::QuicheMemSlice data), (override));
97 };
98 
99 class ConnectTunnelTest : public quiche::test::QuicheTest {
100  public:
SetUp()101   void SetUp() override {
102 #if defined(_WIN32)
103     WSADATA wsa_data;
104     const WORD version_required = MAKEWORD(2, 2);
105     ASSERT_EQ(WSAStartup(version_required, &wsa_data), 0);
106 #endif
107     auto socket = std::make_unique<StrictMock<MockSocket>>();
108     socket_ = socket.get();
109     ON_CALL(socket_factory_,
110             CreateTcpClientSocket(
111                 AnyOf(QuicSocketAddress(TestLoopback4(), kAcceptablePort),
112                       QuicSocketAddress(TestLoopback6(), kAcceptablePort)),
113                 _, _, &tunnel_))
114         .WillByDefault(Return(ByMove(std::move(socket))));
115   }
116 
117  protected:
118   static constexpr absl::string_view kAcceptableDestination = "localhost";
119   static constexpr uint16_t kAcceptablePort = 977;
120 
121   StrictMock<MockRequestHandler> request_handler_;
122   NiceMock<MockSocketFactory> socket_factory_;
123   StrictMock<MockSocket>* socket_;
124 
125   ConnectTunnel tunnel_{
126       &request_handler_,
127       &socket_factory_,
128       /*acceptable_destinations=*/
129       {{std::string(kAcceptableDestination), kAcceptablePort},
130        {TestLoopback4().ToString(), kAcceptablePort},
131        {absl::StrCat("[", TestLoopback6().ToString(), "]"), kAcceptablePort}}};
132 };
133 
// A CONNECT to an acceptable hostname connects the destination socket, starts
// an async receive, and reports a ":status 200" response. Closing the client
// stream disconnects the destination.
TEST_F(ConnectTunnelTest, OpenTunnel) {
  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  // Simulate the outstanding async receive being cancelled on disconnect.
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() {
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));

  // The response is checked field by field: INCOMPLETE_RESPONSE (presumably
  // so the stream stays open for tunneled data), only ":status: 200", and no
  // trailers or body.
  EXPECT_CALL(request_handler_,
              OnResponseBackendComplete(
                  AllOf(Property(&QuicBackendResponse::response_type,
                                 QuicBackendResponse::INCOMPLETE_RESPONSE),
                        Property(&QuicBackendResponse::headers,
                                 ElementsAre(Pair(":status", "200"))),
                        Property(&QuicBackendResponse::trailers, IsEmpty()),
                        Property(&QuicBackendResponse::body, IsEmpty()))));

  spdy::Http2HeaderBlock request_headers;
  request_headers[":method"] = "CONNECT";
  request_headers[":authority"] =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);

  tunnel_.OpenTunnel(request_headers);
  EXPECT_TRUE(tunnel_.IsConnectedToDestination());
  tunnel_.OnClientStreamClose();
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());
}
165 
// Same as OpenTunnel, but the CONNECT authority is the IPv4 loopback literal
// rather than a hostname.
TEST_F(ConnectTunnelTest, OpenTunnelToIpv4LiteralDestination) {
  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  // Simulate the outstanding async receive being cancelled on disconnect.
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() {
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));

  // Expect an incomplete "200" response with no trailers and no body.
  EXPECT_CALL(request_handler_,
              OnResponseBackendComplete(
                  AllOf(Property(&QuicBackendResponse::response_type,
                                 QuicBackendResponse::INCOMPLETE_RESPONSE),
                        Property(&QuicBackendResponse::headers,
                                 ElementsAre(Pair(":status", "200"))),
                        Property(&QuicBackendResponse::trailers, IsEmpty()),
                        Property(&QuicBackendResponse::body, IsEmpty()))));

  spdy::Http2HeaderBlock request_headers;
  request_headers[":method"] = "CONNECT";
  request_headers[":authority"] =
      absl::StrCat(TestLoopback4().ToString(), ":", kAcceptablePort);

  tunnel_.OpenTunnel(request_headers);
  EXPECT_TRUE(tunnel_.IsConnectedToDestination());
  tunnel_.OnClientStreamClose();
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());
}
197 
// Same as OpenTunnel, but the CONNECT authority is the bracketed IPv6 loopback
// literal rather than a hostname.
TEST_F(ConnectTunnelTest, OpenTunnelToIpv6LiteralDestination) {
  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  // Simulate the outstanding async receive being cancelled on disconnect.
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() {
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));

  // Expect an incomplete "200" response with no trailers and no body.
  EXPECT_CALL(request_handler_,
              OnResponseBackendComplete(
                  AllOf(Property(&QuicBackendResponse::response_type,
                                 QuicBackendResponse::INCOMPLETE_RESPONSE),
                        Property(&QuicBackendResponse::headers,
                                 ElementsAre(Pair(":status", "200"))),
                        Property(&QuicBackendResponse::trailers, IsEmpty()),
                        Property(&QuicBackendResponse::body, IsEmpty()))));

  spdy::Http2HeaderBlock request_headers;
  request_headers[":method"] = "CONNECT";
  request_headers[":authority"] =
      absl::StrCat("[", TestLoopback6().ToString(), "]:", kAcceptablePort);

  tunnel_.OpenTunnel(request_headers);
  EXPECT_TRUE(tunnel_.IsConnectedToDestination());
  tunnel_.OnClientStreamClose();
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());
}
229 
// A CONNECT request lacking ":authority" is reset with MESSAGE_ERROR, and no
// destination connection is made (the strict mock socket expects no calls).
TEST_F(ConnectTunnelTest, OpenTunnelWithMalformedRequest) {
  // Deliberately leave out the ":authority" pseudo-header.
  spdy::Http2HeaderBlock headers;
  headers[":method"] = "CONNECT";

  constexpr uint64_t kExpectedCode =
      static_cast<uint64_t>(QuicHttp3ErrorCode::MESSAGE_ERROR);
  EXPECT_CALL(request_handler_,
              TerminateStreamWithError(Property(
                  &QuicResetStreamError::ietf_application_code, kExpectedCode)));

  tunnel_.OpenTunnel(headers);
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());
  tunnel_.OnClientStreamClose();
}
244 
// A CONNECT to a host:port that is not in the acceptable-destinations list is
// reset with REQUEST_REJECTED, and no destination connection is made.
TEST_F(ConnectTunnelTest, OpenTunnelWithUnacceptableDestination) {
  spdy::Http2HeaderBlock headers;
  headers[":method"] = "CONNECT";
  headers[":authority"] = "unacceptable.test:100";

  constexpr uint64_t kExpectedCode =
      static_cast<uint64_t>(QuicHttp3ErrorCode::REQUEST_REJECTED);
  EXPECT_CALL(request_handler_,
              TerminateStreamWithError(Property(
                  &QuicResetStreamError::ietf_application_code, kExpectedCode)));

  tunnel_.OpenTunnel(headers);
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());
  tunnel_.OnClientStreamClose();
}
260 
// Bytes delivered by the destination socket are forwarded to the client stream
// without closing it.
TEST_F(ConnectTunnelTest, ReceiveFromDestination) {
  static constexpr absl::string_view kPayload = "\x11\x22\x33\x44\x55";

  spdy::Http2HeaderBlock headers;
  headers[":method"] = "CONNECT";
  headers[":authority"] =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);

  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  // Expected twice: presumably once when the tunnel opens and again after the
  // payload has been handled.
  EXPECT_CALL(*socket_, ReceiveAsync(Ge(kPayload.size()))).Times(2);
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this] {
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));
  EXPECT_CALL(request_handler_, OnResponseBackendComplete(_));
  EXPECT_CALL(request_handler_,
              SendStreamData(kPayload, /*close_stream=*/false));

  tunnel_.OpenTunnel(headers);

  // Feed the payload in as if the socket had just read it.
  tunnel_.ReceiveComplete(MemSliceFromString(kPayload));

  tunnel_.OnClientStreamClose();
}
286 
// Data the client sends through the tunnel is written to the destination
// socket via SendBlocking(std::string).
TEST_F(ConnectTunnelTest, SendToDestination) {
  static constexpr absl::string_view kPayload = "\x11\x22\x33\x44\x55";

  spdy::Http2HeaderBlock headers;
  headers[":method"] = "CONNECT";
  headers[":authority"] =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);

  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  // Select the std::string overload of SendBlocking explicitly.
  EXPECT_CALL(*socket_, SendBlocking(Matcher<std::string>(Eq(kPayload))))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this] {
    tunnel_.ReceiveComplete(absl::CancelledError());
  }));
  EXPECT_CALL(request_handler_, OnResponseBackendComplete(_));

  tunnel_.OpenTunnel(headers);
  tunnel_.SendDataToDestination(kPayload);
  tunnel_.OnClientStreamClose();
}
309 
// An empty receive from the destination ends the tunnel. An empty payload
// with close_stream=true is sent to the client, and the tunnel reports itself
// as no longer connected.
TEST_F(ConnectTunnelTest, DestinationDisconnect) {
  spdy::Http2HeaderBlock headers;
  headers[":method"] = "CONNECT";
  headers[":authority"] =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);

  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  EXPECT_CALL(*socket_, Disconnect());
  EXPECT_CALL(request_handler_, OnResponseBackendComplete(_));
  EXPECT_CALL(request_handler_, SendStreamData("", /*close_stream=*/true));

  tunnel_.OpenTunnel(headers);

  // A zero-length slice stands in for the destination closing its side.
  tunnel_.ReceiveComplete(quiche::QuicheMemSlice());
  EXPECT_FALSE(tunnel_.IsConnectedToDestination());

  tunnel_.OnClientStreamClose();
}
332 
// A receive error from the destination socket resets the client stream with
// CONNECT_ERROR and disconnects the socket.
TEST_F(ConnectTunnelTest, DestinationTcpConnectionError) {
  spdy::Http2HeaderBlock headers;
  headers[":method"] = "CONNECT";
  headers[":authority"] =
      absl::StrCat(kAcceptableDestination, ":", kAcceptablePort);

  EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*socket_, ReceiveAsync(Gt(0)));
  EXPECT_CALL(*socket_, Disconnect());

  constexpr uint64_t kExpectedCode =
      static_cast<uint64_t>(QuicHttp3ErrorCode::CONNECT_ERROR);
  EXPECT_CALL(request_handler_, OnResponseBackendComplete(_));
  EXPECT_CALL(request_handler_,
              TerminateStreamWithError(Property(
                  &QuicResetStreamError::ietf_application_code, kExpectedCode)));

  tunnel_.OpenTunnel(headers);

  // Simulate receiving an error from the socket.
  tunnel_.ReceiveComplete(absl::UnknownError("error"));

  tunnel_.OnClientStreamClose();
}
356 
357 }  // namespace
358 }  // namespace quic::test
359