1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks5_client_socket.h"
6
7 #include <algorithm>
8 #include <iterator>
9 #include <map>
10 #include <memory>
11 #include <utility>
12
13 #include "base/containers/span.h"
14 #include "base/memory/ptr_util.h"
15 #include "base/memory/raw_ptr.h"
16 #include "base/sys_byteorder.h"
17 #include "build/build_config.h"
18 #include "net/base/address_list.h"
19 #include "net/base/test_completion_callback.h"
20 #include "net/base/winsock_init.h"
21 #include "net/log/net_log_event_type.h"
22 #include "net/log/test_net_log.h"
23 #include "net/log/test_net_log_util.h"
24 #include "net/socket/client_socket_factory.h"
25 #include "net/socket/socket_test_util.h"
26 #include "net/socket/tcp_client_socket.h"
27 #include "net/test/gtest_util.h"
28 #include "net/test/test_with_task_environment.h"
29 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
30 #include "testing/gmock/include/gmock/gmock.h"
31 #include "testing/gtest/include/gtest/gtest.h"
32 #include "testing/platform_test.h"
33
34 using net::test::IsError;
35 using net::test::IsOk;
36
37 //-----------------------------------------------------------------------------
38
39 namespace net {
40
41 class NetLog;
42
43 namespace {
44
45 // Base class to test SOCKS5ClientSocket
46 class SOCKS5ClientSocketTest : public PlatformTest, public WithTaskEnvironment {
47 public:
48 SOCKS5ClientSocketTest();
49
50 SOCKS5ClientSocketTest(const SOCKS5ClientSocketTest&) = delete;
51 SOCKS5ClientSocketTest& operator=(const SOCKS5ClientSocketTest&) = delete;
52
53 // Create a SOCKSClientSocket on top of a MockSocket.
54 std::unique_ptr<SOCKS5ClientSocket> BuildMockSocket(
55 base::span<const MockRead> reads,
56 base::span<const MockWrite> writes,
57 const std::string& hostname,
58 int port,
59 NetLog* net_log);
60
61 void SetUp() override;
62
63 protected:
64 const uint16_t kNwPort;
65 RecordingNetLogObserver net_log_observer_;
66 std::unique_ptr<SOCKS5ClientSocket> user_sock_;
67 AddressList address_list_;
68 // Filled in by BuildMockSocket() and owned by its return value
69 // (which |user_sock| is set to).
70 raw_ptr<StreamSocket> tcp_sock_;
71 TestCompletionCallback callback_;
72 std::unique_ptr<SocketDataProvider> data_;
73 };
74
SOCKS5ClientSocketTest()75 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
76 : kNwPort(base::HostToNet16(80)) {}
77
78 // Set up platform before every test case
SetUp()79 void SOCKS5ClientSocketTest::SetUp() {
80 PlatformTest::SetUp();
81
82 // Create the "localhost" AddressList used by the TCP connection to connect.
83 address_list_ =
84 AddressList::CreateFromIPAddress(IPAddress::IPv4Localhost(), 1080);
85 }
86
BuildMockSocket(base::span<const MockRead> reads,base::span<const MockWrite> writes,const std::string & hostname,int port,NetLog * net_log)87 std::unique_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
88 base::span<const MockRead> reads,
89 base::span<const MockWrite> writes,
90 const std::string& hostname,
91 int port,
92 NetLog* net_log) {
93 TestCompletionCallback callback;
94 data_ = std::make_unique<StaticSocketDataProvider>(reads, writes);
95 auto tcp_sock = std::make_unique<MockTCPClientSocket>(address_list_, net_log,
96 data_.get());
97 tcp_sock_ = tcp_sock.get();
98
99 int rv = tcp_sock_->Connect(callback.callback());
100 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
101 rv = callback.WaitForResult();
102 EXPECT_THAT(rv, IsOk());
103 EXPECT_TRUE(tcp_sock_->IsConnected());
104
105 // The SOCKS5ClientSocket takes ownership of |tcp_sock_|, but keep a
106 // non-owning pointer to it.
107 return std::make_unique<SOCKS5ClientSocket>(std::move(tcp_sock),
108 HostPortPair(hostname, port),
109 TRAFFIC_ANNOTATION_FOR_TESTS);
110 }
111
112 // Tests a complete SOCKS5 handshake and the disconnection.
TEST_F(SOCKS5ClientSocketTest,CompleteHandshake)113 TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
114 const std::string payload_write = "random data";
115 const std::string payload_read = "moar random data";
116
117 const char kOkRequest[] = {
118 0x05, // Version
119 0x01, // Command (CONNECT)
120 0x00, // Reserved.
121 0x03, // Address type (DOMAINNAME).
122 0x09, // Length of domain (9)
123 // Domain string:
124 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
125 0x00, 0x50, // 16-bit port (80)
126 };
127
128 MockWrite data_writes[] = {
129 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
130 MockWrite(ASYNC, kOkRequest, std::size(kOkRequest)),
131 MockWrite(ASYNC, payload_write.data(), payload_write.size())};
132 MockRead data_reads[] = {
133 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
134 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
135 MockRead(ASYNC, payload_read.data(), payload_read.size()) };
136
137 user_sock_ =
138 BuildMockSocket(data_reads, data_writes, "localhost", 80, NetLog::Get());
139
140 // At this state the TCP connection is completed but not the SOCKS handshake.
141 EXPECT_TRUE(tcp_sock_->IsConnected());
142 EXPECT_FALSE(user_sock_->IsConnected());
143
144 int rv = user_sock_->Connect(callback_.callback());
145 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
146 EXPECT_FALSE(user_sock_->IsConnected());
147
148 auto net_log_entries = net_log_observer_.GetEntries();
149 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
150 NetLogEventType::SOCKS5_CONNECT));
151
152 rv = callback_.WaitForResult();
153
154 EXPECT_THAT(rv, IsOk());
155 EXPECT_TRUE(user_sock_->IsConnected());
156
157 net_log_entries = net_log_observer_.GetEntries();
158 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
159 NetLogEventType::SOCKS5_CONNECT));
160
161 auto buffer = base::MakeRefCounted<IOBufferWithSize>(payload_write.size());
162 memcpy(buffer->data(), payload_write.data(), payload_write.size());
163 rv = user_sock_->Write(buffer.get(), payload_write.size(),
164 callback_.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
165 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
166 rv = callback_.WaitForResult();
167 EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
168
169 buffer = base::MakeRefCounted<IOBufferWithSize>(payload_read.size());
170 rv =
171 user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
172 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
173 rv = callback_.WaitForResult();
174 EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
175 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
176
177 user_sock_->Disconnect();
178 EXPECT_FALSE(tcp_sock_->IsConnected());
179 EXPECT_FALSE(user_sock_->IsConnected());
180 }
181
182 // Test that you can call Connect() again after having called Disconnect().
TEST_F(SOCKS5ClientSocketTest,ConnectAndDisconnectTwice)183 TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
184 const std::string hostname = "my-host-name";
185 const char kSOCKS5DomainRequest[] = {
186 0x05, // VER
187 0x01, // CMD
188 0x00, // RSV
189 0x03, // ATYPE
190 };
191
192 std::string request(kSOCKS5DomainRequest, std::size(kSOCKS5DomainRequest));
193 request.push_back(static_cast<char>(hostname.size()));
194 request.append(hostname);
195 request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
196
197 for (int i = 0; i < 2; ++i) {
198 MockWrite data_writes[] = {
199 MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
200 MockWrite(SYNCHRONOUS, request.data(), request.size())
201 };
202 MockRead data_reads[] = {
203 MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
204 MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
205 };
206
207 user_sock_ =
208 BuildMockSocket(data_reads, data_writes, hostname, 80, nullptr);
209
210 int rv = user_sock_->Connect(callback_.callback());
211 EXPECT_THAT(rv, IsOk());
212 EXPECT_TRUE(user_sock_->IsConnected());
213
214 user_sock_->Disconnect();
215 EXPECT_FALSE(user_sock_->IsConnected());
216 }
217 }
218
219 // Test that we fail trying to connect to a hostname longer than 255 bytes.
TEST_F(SOCKS5ClientSocketTest,LargeHostNameFails)220 TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
221 // Create a string of length 256, where each character is 'x'.
222 std::string large_host_name;
223 std::fill_n(std::back_inserter(large_host_name), 256, 'x');
224
225 // Create a SOCKS socket, with mock transport socket.
226 MockWrite data_writes[] = {MockWrite()};
227 MockRead data_reads[] = {MockRead()};
228 user_sock_ =
229 BuildMockSocket(data_reads, data_writes, large_host_name, 80, nullptr);
230
231 // Try to connect -- should fail (without having read/written anything to
232 // the transport socket first) because the hostname is too long.
233 TestCompletionCallback callback;
234 int rv = user_sock_->Connect(callback.callback());
235 EXPECT_THAT(rv, IsError(ERR_SOCKS_CONNECTION_FAILED));
236 }
237
// Verifies that the SOCKS5 handshake still completes when the greeting or the
// CONNECT request/response is split across several transport reads or writes.
TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
  const std::string hostname = "www.google.com";

  const char kOkRequest[] = {
      0x05,  // Version
      0x01,  // Command (CONNECT)
      0x00,  // Reserved.
      0x03,  // Address type (DOMAINNAME).
      0x0E,  // Length of domain (14)
      // Domain string:
      'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
      0x00, 0x50,  // 16-bit port (80)
  };

  // Shared by every sub-case below; only the scripted traffic differs.
  // Builds a SOCKS socket over |reads| / |writes| and runs the asynchronous
  // handshake to completion. Checks the result and that the NetLog shows the
  // SOCKS5_CONNECT begin/end events.
  auto expect_handshake_succeeds = [&](base::span<const MockRead> reads,
                                       base::span<const MockWrite> writes) {
    user_sock_ = BuildMockSocket(reads, writes, hostname, 80, NetLog::Get());
    EXPECT_THAT(user_sock_->Connect(callback_.callback()),
                IsError(ERR_IO_PENDING));

    auto net_log_entries = net_log_observer_.GetEntries();
    EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
                                      NetLogEventType::SOCKS5_CONNECT));

    EXPECT_THAT(callback_.WaitForResult(), IsOk());
    EXPECT_TRUE(user_sock_->IsConnected());

    net_log_entries = net_log_observer_.GetEntries();
    EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
                                    NetLogEventType::SOCKS5_CONNECT));
  };

  {
    SCOPED_TRACE("partial greet request write");
    const char partial1[] = {0x05, 0x01};
    const char partial2[] = {0x00};
    MockWrite data_writes[] = {
        MockWrite(ASYNC, partial1, std::size(partial1)),
        MockWrite(ASYNC, partial2, std::size(partial2)),
        MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
    MockRead data_reads[] = {
        MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
        MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength)};
    expect_handshake_succeeds(data_reads, data_writes);
  }

  {
    SCOPED_TRACE("partial greet response read");
    const char partial1[] = {0x05};
    const char partial2[] = {0x00};
    MockWrite data_writes[] = {
        MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
        MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
    MockRead data_reads[] = {
        MockRead(ASYNC, partial1, std::size(partial1)),
        MockRead(ASYNC, partial2, std::size(partial2)),
        MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength)};
    expect_handshake_succeeds(data_reads, data_writes);
  }

  {
    SCOPED_TRACE("partial handshake request write");
    const int kSplitPoint = 3;  // Break handshake write into two parts.
    MockWrite data_writes[] = {
        MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
        MockWrite(ASYNC, kOkRequest, kSplitPoint),
        MockWrite(ASYNC, kOkRequest + kSplitPoint,
                  std::size(kOkRequest) - kSplitPoint)};
    MockRead data_reads[] = {
        MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
        MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength)};
    expect_handshake_succeeds(data_reads, data_writes);
  }

  {
    SCOPED_TRACE("partial handshake response read");
    const int kSplitPoint = 6;  // Break the handshake read into two parts.
    MockWrite data_writes[] = {
        MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
        MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
    MockRead data_reads[] = {
        MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
        MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint),
        MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint,
                 kSOCKS5OkResponseLength - kSplitPoint)};
    expect_handshake_succeeds(data_reads, data_writes);
  }
}
362
TEST_F(SOCKS5ClientSocketTest,Tag)363 TEST_F(SOCKS5ClientSocketTest, Tag) {
364 StaticSocketDataProvider data;
365 auto tagging_sock = std::make_unique<MockTaggingStreamSocket>(
366 std::make_unique<MockTCPClientSocket>(address_list_, NetLog::Get(),
367 &data));
368 auto* tagging_sock_ptr = tagging_sock.get();
369
370 // |socket| takes ownership of |tagging_sock|, but keep a non-owning pointer
371 // to it.
372 SOCKS5ClientSocket socket(std::move(tagging_sock),
373 HostPortPair("localhost", 80),
374 TRAFFIC_ANNOTATION_FOR_TESTS);
375
376 EXPECT_EQ(tagging_sock_ptr->tag(), SocketTag());
377 #if BUILDFLAG(IS_ANDROID)
378 SocketTag tag(0x12345678, 0x87654321);
379 socket.ApplySocketTag(tag);
380 EXPECT_EQ(tagging_sock_ptr->tag(), tag);
381 #endif // BUILDFLAG(IS_ANDROID)
382 }
383
384 } // namespace
385
386 } // namespace net
387