xref: /aosp_15_r20/external/cronet/net/socket/socks5_client_socket_unittest.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socks5_client_socket.h"
6 
7 #include <algorithm>
8 #include <iterator>
9 #include <map>
10 #include <memory>
11 #include <utility>
12 
13 #include "base/containers/span.h"
14 #include "base/memory/ptr_util.h"
15 #include "base/memory/raw_ptr.h"
16 #include "base/sys_byteorder.h"
17 #include "build/build_config.h"
18 #include "net/base/address_list.h"
19 #include "net/base/test_completion_callback.h"
20 #include "net/base/winsock_init.h"
21 #include "net/log/net_log_event_type.h"
22 #include "net/log/test_net_log.h"
23 #include "net/log/test_net_log_util.h"
24 #include "net/socket/client_socket_factory.h"
25 #include "net/socket/socket_test_util.h"
26 #include "net/socket/tcp_client_socket.h"
27 #include "net/test/gtest_util.h"
28 #include "net/test/test_with_task_environment.h"
29 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
30 #include "testing/gmock/include/gmock/gmock.h"
31 #include "testing/gtest/include/gtest/gtest.h"
32 #include "testing/platform_test.h"
33 
34 using net::test::IsError;
35 using net::test::IsOk;
36 
37 //-----------------------------------------------------------------------------
38 
39 namespace net {
40 
41 class NetLog;
42 
43 namespace {
44 
45 // Base class to test SOCKS5ClientSocket
46 class SOCKS5ClientSocketTest : public PlatformTest, public WithTaskEnvironment {
47  public:
48   SOCKS5ClientSocketTest();
49 
50   SOCKS5ClientSocketTest(const SOCKS5ClientSocketTest&) = delete;
51   SOCKS5ClientSocketTest& operator=(const SOCKS5ClientSocketTest&) = delete;
52 
53   // Create a SOCKSClientSocket on top of a MockSocket.
54   std::unique_ptr<SOCKS5ClientSocket> BuildMockSocket(
55       base::span<const MockRead> reads,
56       base::span<const MockWrite> writes,
57       const std::string& hostname,
58       int port,
59       NetLog* net_log);
60 
61   void SetUp() override;
62 
63  protected:
64   const uint16_t kNwPort;
65   RecordingNetLogObserver net_log_observer_;
66   std::unique_ptr<SOCKS5ClientSocket> user_sock_;
67   AddressList address_list_;
68   // Filled in by BuildMockSocket() and owned by its return value
69   // (which |user_sock| is set to).
70   raw_ptr<StreamSocket> tcp_sock_;
71   TestCompletionCallback callback_;
72   std::unique_ptr<SocketDataProvider> data_;
73 };
74 
SOCKS5ClientSocketTest()75 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
76     : kNwPort(base::HostToNet16(80)) {}
77 
78 // Set up platform before every test case
SetUp()79 void SOCKS5ClientSocketTest::SetUp() {
80   PlatformTest::SetUp();
81 
82   // Create the "localhost" AddressList used by the TCP connection to connect.
83   address_list_ =
84       AddressList::CreateFromIPAddress(IPAddress::IPv4Localhost(), 1080);
85 }
86 
BuildMockSocket(base::span<const MockRead> reads,base::span<const MockWrite> writes,const std::string & hostname,int port,NetLog * net_log)87 std::unique_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
88     base::span<const MockRead> reads,
89     base::span<const MockWrite> writes,
90     const std::string& hostname,
91     int port,
92     NetLog* net_log) {
93   TestCompletionCallback callback;
94   data_ = std::make_unique<StaticSocketDataProvider>(reads, writes);
95   auto tcp_sock = std::make_unique<MockTCPClientSocket>(address_list_, net_log,
96                                                         data_.get());
97   tcp_sock_ = tcp_sock.get();
98 
99   int rv = tcp_sock_->Connect(callback.callback());
100   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
101   rv = callback.WaitForResult();
102   EXPECT_THAT(rv, IsOk());
103   EXPECT_TRUE(tcp_sock_->IsConnected());
104 
105   // The SOCKS5ClientSocket takes ownership of |tcp_sock_|, but keep a
106   // non-owning pointer to it.
107   return std::make_unique<SOCKS5ClientSocket>(std::move(tcp_sock),
108                                               HostPortPair(hostname, port),
109                                               TRAFFIC_ANNOTATION_FOR_TESTS);
110 }
111 
112 // Tests a complete SOCKS5 handshake and the disconnection.
TEST_F(SOCKS5ClientSocketTest,CompleteHandshake)113 TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
114   const std::string payload_write = "random data";
115   const std::string payload_read = "moar random data";
116 
117   const char kOkRequest[] = {
118     0x05,  // Version
119     0x01,  // Command (CONNECT)
120     0x00,  // Reserved.
121     0x03,  // Address type (DOMAINNAME).
122     0x09,  // Length of domain (9)
123     // Domain string:
124     'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
125     0x00, 0x50,  // 16-bit port (80)
126   };
127 
128   MockWrite data_writes[] = {
129       MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
130       MockWrite(ASYNC, kOkRequest, std::size(kOkRequest)),
131       MockWrite(ASYNC, payload_write.data(), payload_write.size())};
132   MockRead data_reads[] = {
133       MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
134       MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
135       MockRead(ASYNC, payload_read.data(), payload_read.size()) };
136 
137   user_sock_ =
138       BuildMockSocket(data_reads, data_writes, "localhost", 80, NetLog::Get());
139 
140   // At this state the TCP connection is completed but not the SOCKS handshake.
141   EXPECT_TRUE(tcp_sock_->IsConnected());
142   EXPECT_FALSE(user_sock_->IsConnected());
143 
144   int rv = user_sock_->Connect(callback_.callback());
145   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
146   EXPECT_FALSE(user_sock_->IsConnected());
147 
148   auto net_log_entries = net_log_observer_.GetEntries();
149   EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
150                                     NetLogEventType::SOCKS5_CONNECT));
151 
152   rv = callback_.WaitForResult();
153 
154   EXPECT_THAT(rv, IsOk());
155   EXPECT_TRUE(user_sock_->IsConnected());
156 
157   net_log_entries = net_log_observer_.GetEntries();
158   EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
159                                   NetLogEventType::SOCKS5_CONNECT));
160 
161   auto buffer = base::MakeRefCounted<IOBufferWithSize>(payload_write.size());
162   memcpy(buffer->data(), payload_write.data(), payload_write.size());
163   rv = user_sock_->Write(buffer.get(), payload_write.size(),
164                          callback_.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
165   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
166   rv = callback_.WaitForResult();
167   EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
168 
169   buffer = base::MakeRefCounted<IOBufferWithSize>(payload_read.size());
170   rv =
171       user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
172   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
173   rv = callback_.WaitForResult();
174   EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
175   EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
176 
177   user_sock_->Disconnect();
178   EXPECT_FALSE(tcp_sock_->IsConnected());
179   EXPECT_FALSE(user_sock_->IsConnected());
180 }
181 
182 // Test that you can call Connect() again after having called Disconnect().
TEST_F(SOCKS5ClientSocketTest,ConnectAndDisconnectTwice)183 TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
184   const std::string hostname = "my-host-name";
185   const char kSOCKS5DomainRequest[] = {
186       0x05,  // VER
187       0x01,  // CMD
188       0x00,  // RSV
189       0x03,  // ATYPE
190   };
191 
192   std::string request(kSOCKS5DomainRequest, std::size(kSOCKS5DomainRequest));
193   request.push_back(static_cast<char>(hostname.size()));
194   request.append(hostname);
195   request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
196 
197   for (int i = 0; i < 2; ++i) {
198     MockWrite data_writes[] = {
199         MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
200         MockWrite(SYNCHRONOUS, request.data(), request.size())
201     };
202     MockRead data_reads[] = {
203         MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
204         MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
205     };
206 
207     user_sock_ =
208         BuildMockSocket(data_reads, data_writes, hostname, 80, nullptr);
209 
210     int rv = user_sock_->Connect(callback_.callback());
211     EXPECT_THAT(rv, IsOk());
212     EXPECT_TRUE(user_sock_->IsConnected());
213 
214     user_sock_->Disconnect();
215     EXPECT_FALSE(user_sock_->IsConnected());
216   }
217 }
218 
219 // Test that we fail trying to connect to a hostname longer than 255 bytes.
TEST_F(SOCKS5ClientSocketTest,LargeHostNameFails)220 TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
221   // Create a string of length 256, where each character is 'x'.
222   std::string large_host_name;
223   std::fill_n(std::back_inserter(large_host_name), 256, 'x');
224 
225   // Create a SOCKS socket, with mock transport socket.
226   MockWrite data_writes[] = {MockWrite()};
227   MockRead data_reads[] = {MockRead()};
228   user_sock_ =
229       BuildMockSocket(data_reads, data_writes, large_host_name, 80, nullptr);
230 
231   // Try to connect -- should fail (without having read/written anything to
232   // the transport socket first) because the hostname is too long.
233   TestCompletionCallback callback;
234   int rv = user_sock_->Connect(callback.callback());
235   EXPECT_THAT(rv, IsError(ERR_SOCKS_CONNECTION_FAILED));
236 }
237 
TEST_F(SOCKS5ClientSocketTest,PartialReadWrites)238 TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
239   const std::string hostname = "www.google.com";
240 
241   const char kOkRequest[] = {
242     0x05,  // Version
243     0x01,  // Command (CONNECT)
244     0x00,  // Reserved.
245     0x03,  // Address type (DOMAINNAME).
246     0x0E,  // Length of domain (14)
247     // Domain string:
248     'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
249     0x00, 0x50,  // 16-bit port (80)
250   };
251 
252   // Test for partial greet request write
253   {
254     const char partial1[] = { 0x05, 0x01 };
255     const char partial2[] = { 0x00 };
256     MockWrite data_writes[] = {
257         MockWrite(ASYNC, partial1, std::size(partial1)),
258         MockWrite(ASYNC, partial2, std::size(partial2)),
259         MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
260     MockRead data_reads[] = {
261         MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
262         MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
263     user_sock_ =
264         BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
265     int rv = user_sock_->Connect(callback_.callback());
266     EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
267 
268     auto net_log_entries = net_log_observer_.GetEntries();
269     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
270                                       NetLogEventType::SOCKS5_CONNECT));
271 
272     rv = callback_.WaitForResult();
273     EXPECT_THAT(rv, IsOk());
274     EXPECT_TRUE(user_sock_->IsConnected());
275 
276     net_log_entries = net_log_observer_.GetEntries();
277     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
278                                     NetLogEventType::SOCKS5_CONNECT));
279   }
280 
281   // Test for partial greet response read
282   {
283     const char partial1[] = { 0x05 };
284     const char partial2[] = { 0x00 };
285     MockWrite data_writes[] = {
286         MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
287         MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
288     MockRead data_reads[] = {
289         MockRead(ASYNC, partial1, std::size(partial1)),
290         MockRead(ASYNC, partial2, std::size(partial2)),
291         MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength)};
292     user_sock_ =
293         BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
294     int rv = user_sock_->Connect(callback_.callback());
295     EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
296 
297     auto net_log_entries = net_log_observer_.GetEntries();
298     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
299                                       NetLogEventType::SOCKS5_CONNECT));
300     rv = callback_.WaitForResult();
301     EXPECT_THAT(rv, IsOk());
302     EXPECT_TRUE(user_sock_->IsConnected());
303     net_log_entries = net_log_observer_.GetEntries();
304     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
305                                     NetLogEventType::SOCKS5_CONNECT));
306   }
307 
308   // Test for partial handshake request write.
309   {
310     const int kSplitPoint = 3;  // Break handshake write into two parts.
311     MockWrite data_writes[] = {
312         MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
313         MockWrite(ASYNC, kOkRequest, kSplitPoint),
314         MockWrite(ASYNC, kOkRequest + kSplitPoint,
315                   std::size(kOkRequest) - kSplitPoint)};
316     MockRead data_reads[] = {
317         MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
318         MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
319     user_sock_ =
320         BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
321     int rv = user_sock_->Connect(callback_.callback());
322     EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
323     auto net_log_entries = net_log_observer_.GetEntries();
324     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
325                                       NetLogEventType::SOCKS5_CONNECT));
326     rv = callback_.WaitForResult();
327     EXPECT_THAT(rv, IsOk());
328     EXPECT_TRUE(user_sock_->IsConnected());
329     net_log_entries = net_log_observer_.GetEntries();
330     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
331                                     NetLogEventType::SOCKS5_CONNECT));
332   }
333 
334   // Test for partial handshake response read
335   {
336     const int kSplitPoint = 6;  // Break the handshake read into two parts.
337     MockWrite data_writes[] = {
338         MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
339         MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
340     MockRead data_reads[] = {
341         MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
342         MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint),
343         MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint,
344                  kSOCKS5OkResponseLength - kSplitPoint)
345     };
346 
347     user_sock_ =
348         BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
349     int rv = user_sock_->Connect(callback_.callback());
350     EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
351     auto net_log_entries = net_log_observer_.GetEntries();
352     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
353                                       NetLogEventType::SOCKS5_CONNECT));
354     rv = callback_.WaitForResult();
355     EXPECT_THAT(rv, IsOk());
356     EXPECT_TRUE(user_sock_->IsConnected());
357     net_log_entries = net_log_observer_.GetEntries();
358     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
359                                     NetLogEventType::SOCKS5_CONNECT));
360   }
361 }
362 
TEST_F(SOCKS5ClientSocketTest,Tag)363 TEST_F(SOCKS5ClientSocketTest, Tag) {
364   StaticSocketDataProvider data;
365   auto tagging_sock = std::make_unique<MockTaggingStreamSocket>(
366       std::make_unique<MockTCPClientSocket>(address_list_, NetLog::Get(),
367                                             &data));
368   auto* tagging_sock_ptr = tagging_sock.get();
369 
370   // |socket| takes ownership of |tagging_sock|, but keep a non-owning pointer
371   // to it.
372   SOCKS5ClientSocket socket(std::move(tagging_sock),
373                             HostPortPair("localhost", 80),
374                             TRAFFIC_ANNOTATION_FOR_TESTS);
375 
376   EXPECT_EQ(tagging_sock_ptr->tag(), SocketTag());
377 #if BUILDFLAG(IS_ANDROID)
378   SocketTag tag(0x12345678, 0x87654321);
379   socket.ApplySocketTag(tag);
380   EXPECT_EQ(tagging_sock_ptr->tag(), tag);
381 #endif  // BUILDFLAG(IS_ANDROID)
382 }
383 
384 }  // namespace
385 
386 }  // namespace net
387