1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/transport_client_socket_test_util.h"
6
7 #include <string>
8
9 #include "base/functional/bind.h"
10 #include "base/memory/raw_ptr.h"
11 #include "base/memory/ref_counted.h"
12 #include "base/run_loop.h"
13 #include "net/base/address_list.h"
14 #include "net/base/io_buffer.h"
15 #include "net/base/ip_address.h"
16 #include "net/base/net_errors.h"
17 #include "net/base/test_completion_callback.h"
18 #include "net/log/net_log_event_type.h"
19 #include "net/log/net_log_source.h"
20 #include "net/log/net_log_with_source.h"
21 #include "net/log/test_net_log.h"
22 #include "net/log/test_net_log_util.h"
23 #include "net/socket/client_socket_factory.h"
24 #include "net/socket/tcp_client_socket.h"
25 #include "net/socket/tcp_server_socket.h"
26 #include "net/test/gtest_util.h"
27 #include "net/test/test_with_task_environment.h"
28 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
29 #include "testing/gmock/include/gmock/gmock.h"
30 #include "testing/gtest/include/gtest/gtest.h"
31 #include "testing/platform_test.h"
32
33 using net::test::IsError;
34 using net::test::IsOk;
35
36 namespace net {
37
namespace {

// Reply text the tests expect to receive from the server side of the
// connection after SendRequestAndResponse(); the Read test compares the bytes
// it receives against this string verbatim.
constexpr char kServerReply[] = "HTTP/1.1 404 Not Found";

}  // namespace
43
44 class TransportClientSocketTest : public ::testing::Test,
45 public WithTaskEnvironment {
46 public:
TransportClientSocketTest()47 TransportClientSocketTest()
48 : socket_factory_(ClientSocketFactory::GetDefaultFactory()) {}
49
50 ~TransportClientSocketTest() override = default;
51
52 // Testcase hooks
53 void SetUp() override;
54
CloseServerSocket()55 void CloseServerSocket() {
56 // delete the connected_sock_, which will close it.
57 connected_sock_.reset();
58 }
59
AcceptCallback(int res)60 void AcceptCallback(int res) {
61 ASSERT_THAT(res, IsOk());
62 connect_loop_.Quit();
63 }
64
65 // Establishes a connection to the server.
66 void EstablishConnection(TestCompletionCallback* callback);
67
68 protected:
69 base::RunLoop connect_loop_;
70 uint16_t listen_port_ = 0;
71 RecordingNetLogObserver net_log_observer_;
72 const raw_ptr<ClientSocketFactory> socket_factory_;
73 std::unique_ptr<StreamSocket> sock_;
74 std::unique_ptr<StreamSocket> connected_sock_;
75
76 private:
77 std::unique_ptr<TCPServerSocket> listen_sock_;
78 };
79
SetUp()80 void TransportClientSocketTest::SetUp() {
81 // Open a server socket on an ephemeral port.
82 listen_sock_ = std::make_unique<TCPServerSocket>(nullptr, NetLogSource());
83 IPEndPoint local_address(IPAddress::IPv4Localhost(), 0);
84 ASSERT_THAT(
85 listen_sock_->Listen(local_address, 1, /*ipv6_only=*/std::nullopt),
86 IsOk());
87 // Get the server's address (including the actual port number).
88 ASSERT_THAT(listen_sock_->GetLocalAddress(&local_address), IsOk());
89 listen_port_ = local_address.port();
90 listen_sock_->Accept(
91 &connected_sock_,
92 base::BindOnce(&TransportClientSocketTest::AcceptCallback,
93 base::Unretained(this)));
94
95 AddressList addr = AddressList::CreateFromIPAddress(
96 IPAddress::IPv4Localhost(), listen_port_);
97 sock_ = socket_factory_->CreateTransportClientSocket(
98 addr, nullptr, nullptr, NetLog::Get(), NetLogSource());
99 }
100
EstablishConnection(TestCompletionCallback * callback)101 void TransportClientSocketTest::EstablishConnection(
102 TestCompletionCallback* callback) {
103 int rv = sock_->Connect(callback->callback());
104 // Wait for |listen_sock_| to accept a connection.
105 connect_loop_.Run();
106 // Now wait for the client socket to accept the connection.
107 EXPECT_THAT(callback->GetResult(rv), IsOk());
108 }
109
// Verifies the connect lifecycle:
// - Connect() establishes a connection.
// - The SOCKET_ALIVE and TCP_CONNECT NetLog events are logged.
// - Disconnect() tears the connection down.
TEST_F(TransportClientSocketTest, Connect) {
  TestCompletionCallback connect_callback;
  EXPECT_FALSE(sock_->IsConnected());

  int result = sock_->Connect(connect_callback.callback());
  // Block until |listen_sock_| has accepted the connection.
  connect_loop_.Run();

  // By now the socket has begun logging its lifetime and the connect attempt.
  const auto connect_start_entries = net_log_observer_.GetEntries();
  EXPECT_TRUE(LogContainsBeginEvent(connect_start_entries, 0,
                                    NetLogEventType::SOCKET_ALIVE));
  EXPECT_TRUE(LogContainsBeginEvent(connect_start_entries, 1,
                                    NetLogEventType::TCP_CONNECT));

  // If the client-side connect did not finish synchronously, it must be
  // pending; wait for it and require success.
  if (result != OK) {
    ASSERT_EQ(result, ERR_IO_PENDING);
    result = connect_callback.WaitForResult();
    EXPECT_EQ(result, OK);
  }

  EXPECT_TRUE(sock_->IsConnected());
  const auto connect_end_entries = net_log_observer_.GetEntries();
  EXPECT_TRUE(LogContainsEndEvent(connect_end_entries, -1,
                                  NetLogEventType::TCP_CONNECT));

  sock_->Disconnect();
  EXPECT_FALSE(sock_->IsConnected());
}
138
// Checks IsConnected() and IsConnectedAndIdle() at each stage:
// - before connecting;
// - while connected and idle;
// - while unread data is buffered;
// - after draining it;
// - after the server closes its end.
TEST_F(TransportClientSocketTest, IsConnected) {
  auto read_buf = base::MakeRefCounted<IOBufferWithSize>(4096);
  TestCompletionCallback callback;
  const size_t reply_len = strlen(kServerReply);

  // A fresh, unconnected socket is neither connected nor idle.
  EXPECT_FALSE(sock_->IsConnected());
  EXPECT_FALSE(sock_->IsConnectedAndIdle());

  EstablishConnection(&callback);

  // Connected with nothing to read: connected and idle.
  EXPECT_TRUE(sock_->IsConnected());
  EXPECT_TRUE(sock_->IsConnectedAndIdle());

  // Round 1: the server replies but keeps its end open.
  SendRequestAndResponse(sock_.get(), connected_sock_.get());

  // Consume a single byte so we know the reply has arrived.
  uint32_t drained =
      DrainStreamSocket(sock_.get(), read_buf.get(), 1, 1, &callback);
  ASSERT_EQ(drained, 1u);

  // The rest of the reply is still buffered: connected, but not idle.
  EXPECT_TRUE(sock_->IsConnected());
  EXPECT_FALSE(sock_->IsConnectedAndIdle());

  drained = DrainStreamSocket(sock_.get(), read_buf.get(), 4096,
                              reply_len - 1, &callback);
  ASSERT_EQ(drained, reply_len - 1);

  // Everything has been read: connected and idle again.
  EXPECT_TRUE(sock_->IsConnected());
  EXPECT_TRUE(sock_->IsConnectedAndIdle());

  // Round 2: the server closes its end right after replying.
  SendRequestAndResponse(sock_.get(), connected_sock_.get());
  CloseServerSocket();

  drained = DrainStreamSocket(sock_.get(), read_buf.get(), 1, 1, &callback);
  ASSERT_EQ(drained, 1u);

  // Buffered data still keeps the socket connected but not idle.
  EXPECT_TRUE(sock_->IsConnected());
  EXPECT_FALSE(sock_->IsConnectedAndIdle());

  drained = DrainStreamSocket(sock_.get(), read_buf.get(), 4096,
                              reply_len - 1, &callback);
  ASSERT_EQ(drained, reply_len - 1);

  // With the data drained, the peer's close should make the socket report
  // disconnected. If the close has not been processed yet, a Read() that
  // observes EOF (0) settles it.
  if (sock_->IsConnected()) {
    const int read_result =
        sock_->Read(read_buf.get(), 4096, callback.callback());
    EXPECT_EQ(0, callback.GetResult(read_result));
    EXPECT_FALSE(sock_->IsConnected());
  }
  EXPECT_FALSE(sock_->IsConnectedAndIdle());
}
199
// Reads the complete server reply in one go. Then verifies that a pending
// Read() completes with 0 (EOF) once the server closes its end.
TEST_F(TransportClientSocketTest, Read) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  SendRequestAndResponse(sock_.get(), connected_sock_.get());

  auto read_buf = base::MakeRefCounted<IOBufferWithSize>(4096);
  const uint32_t received = DrainStreamSocket(
      sock_.get(), read_buf.get(), 4096, strlen(kServerReply), &callback);
  ASSERT_EQ(received, strlen(kServerReply));
  ASSERT_EQ(std::string(kServerReply),
            std::string(read_buf->data(), received));

  // The whole reply has been consumed, so another Read() has to wait. Closing
  // the server while it is pending must complete it with 0 (EOF).
  const int pending_result =
      sock_->Read(read_buf.get(), 4096, callback.callback());
  ASSERT_THAT(pending_result, IsError(ERR_IO_PENDING));
  CloseServerSocket();
  EXPECT_EQ(0, callback.WaitForResult());
}
220
// Reads the server reply one byte at a time. Then verifies that a pending
// one-byte Read() completes with 0 (EOF) when the server closes its end.
TEST_F(TransportClientSocketTest, Read_SmallChunks) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  SendRequestAndResponse(sock_.get(), connected_sock_.get());

  auto one_byte_buf = base::MakeRefCounted<IOBufferWithSize>(1);
  for (uint32_t total_read = 0; total_read < strlen(kServerReply);) {
    int result = sock_->Read(one_byte_buf.get(), 1, callback.callback());
    EXPECT_TRUE(result >= 0 || result == ERR_IO_PENDING);

    // Wait for the byte if the read did not complete synchronously.
    result = callback.GetResult(result);

    ASSERT_EQ(1, result);
    total_read += result;
  }

  // The whole reply has been consumed. The next Read() must pend until the
  // server closes, then complete with 0 (EOF).
  const int pending_result =
      sock_->Read(one_byte_buf.get(), 1, callback.callback());
  ASSERT_THAT(pending_result, IsError(ERR_IO_PENDING));
  CloseServerSocket();
  EXPECT_EQ(0, callback.WaitForResult());
}
247
// Performs one partial Read() of the server reply, reading at most 16 bytes.
// The test then ends with the rest of the reply still unread; tearing down in
// that state must not crash.
TEST_F(TransportClientSocketTest, Read_Interrupted) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  SendRequestAndResponse(sock_.get(), connected_sock_.get());

  // Do a partial read and then exit. This test should not crash!
  auto buf = base::MakeRefCounted<IOBufferWithSize>(16);
  int rv = sock_->Read(buf.get(), 16, callback.callback());
  EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);

  rv = callback.GetResult(rv);

  // The server has already replied, so the read must return data. A plain
  // "!= 0" check would also accept a negative net error code as success.
  EXPECT_GT(rv, 0);
}
263
// Starts a Read() before any data is available, then fills the send buffer
// with Write()s until one pends. The server drains what was written and sends
// its reply. Both the pending Write() and the pending Read() must then
// complete successfully.
TEST_F(TransportClientSocketTest, FullDuplex_ReadFirst) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  // Read first. There's no data, so it should return ERR_IO_PENDING.
  const int kBufLen = 4096;
  auto buf = base::MakeRefCounted<IOBufferWithSize>(kBufLen);
  int rv = sock_->Read(buf.get(), kBufLen, callback.callback());
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  const int kWriteBufLen = 64 * 1024;
  auto request_buffer = base::MakeRefCounted<IOBufferWithSize>(kWriteBufLen);
  char* request_data = request_buffer->data();
  memset(request_data, 'A', kWriteBufLen);
  TestCompletionCallback write_callback;

  // Keep writing until the send buffer is full and Write() pends.
  int bytes_written = 0;
  while (true) {
    rv = sock_->Write(request_buffer.get(), kWriteBufLen,
                      write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
    // A synchronous Write() of a non-empty buffer must make progress;
    // accepting 0 here would loop forever.
    ASSERT_TRUE(rv > 0 || rv == ERR_IO_PENDING);
    if (rv == ERR_IO_PENDING) {
      // Drain the bytes already written on the server side so the pending
      // Write() can proceed, and send the reply the pending Read() waits for.
      ReadDataOfExpectedLength(connected_sock_.get(), bytes_written);
      SendServerResponse(connected_sock_.get());
      rv = write_callback.WaitForResult();
      // The pending Write() must complete without error. Its result was
      // previously overwritten below without being checked.
      EXPECT_GE(rv, 0);
      break;
    }
    bytes_written += rv;
  }

  // At this point, both read and write have returned ERR_IO_PENDING, and the
  // write callback has executed. We wait for the read callback to run now to
  // make sure that the socket can handle full duplex communications.

  rv = callback.WaitForResult();
  // The server has sent its reply, so the read must return data, not EOF (0).
  // This matches the FullDuplex_WriteFirst check.
  EXPECT_GT(rv, 0);
}
301
// Fills the client's send buffer with Write()s until one pends, then starts a
// Read() that must also pend. The server drains the written data and sends
// its reply. Both pending operations must then complete successfully.
TEST_F(TransportClientSocketTest, FullDuplex_WriteFirst) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  const int kWriteBufLen = 64 * 1024;
  auto request_buffer = base::MakeRefCounted<IOBufferWithSize>(kWriteBufLen);
  char* request_data = request_buffer->data();
  memset(request_data, 'A', kWriteBufLen);
  TestCompletionCallback write_callback;

  // Write until the send buffer is full and Write() returns ERR_IO_PENDING.
  int bytes_written = 0;
  while (true) {
    int rv =
        sock_->Write(request_buffer.get(), kWriteBufLen,
                     write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
    // A synchronous Write() of a non-empty buffer must make progress;
    // accepting 0 here would loop forever.
    ASSERT_TRUE(rv > 0 || rv == ERR_IO_PENDING);

    if (rv == ERR_IO_PENDING)
      break;
    bytes_written += rv;
  }

  // Now we have the Write() blocked on ERR_IO_PENDING. It's time to force the
  // Read() to block on ERR_IO_PENDING too.

  const int kBufLen = 4096;
  auto buf = base::MakeRefCounted<IOBufferWithSize>(kBufLen);
  while (true) {
    int rv = sock_->Read(buf.get(), kBufLen, callback.callback());
    // A result of 0 means the connection was closed (EOF). A closed
    // connection keeps returning 0, so tolerating it here would spin forever.
    // Only stray data (rv > 0) or ERR_IO_PENDING are acceptable.
    ASSERT_TRUE(rv > 0 || rv == ERR_IO_PENDING);
    if (rv == ERR_IO_PENDING)
      break;
  }

  // At this point, both read and write have returned ERR_IO_PENDING. Now we
  // run the write and read callbacks to make sure they can handle full duplex
  // communications.

  ReadDataOfExpectedLength(connected_sock_.get(), bytes_written);
  SendServerResponse(connected_sock_.get());
  int rv = write_callback.WaitForResult();
  EXPECT_GE(rv, 0);

  rv = callback.WaitForResult();
  EXPECT_GT(rv, 0);
}
348
349 } // namespace net
350