1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/http/http_response_body_drainer.h"
6
7 #include <stdint.h>
8
9 #include <cstring>
10 #include <set>
11 #include <string_view>
12 #include <utility>
13
14 #include "base/compiler_specific.h"
15 #include "base/functional/bind.h"
16 #include "base/location.h"
17 #include "base/memory/raw_ptr.h"
18 #include "base/memory/weak_ptr.h"
19 #include "base/no_destructor.h"
20 #include "base/run_loop.h"
21 #include "base/task/single_thread_task_runner.h"
22 #include "net/base/completion_once_callback.h"
23 #include "net/base/io_buffer.h"
24 #include "net/base/net_errors.h"
25 #include "net/base/test_completion_callback.h"
26 #include "net/cert/mock_cert_verifier.h"
27 #include "net/http/http_network_session.h"
28 #include "net/http/http_server_properties.h"
29 #include "net/http/http_stream.h"
30 #include "net/http/transport_security_state.h"
31 #include "net/proxy_resolution/configured_proxy_resolution_service.h"
32 #include "net/quic/quic_context.h"
33 #include "net/socket/socket_test_util.h"
34 #include "net/ssl/ssl_config_service_defaults.h"
35 #include "net/test/test_with_task_environment.h"
36 #include "testing/gtest/include/gtest/gtest.h"
37
38 namespace net {
39
40 namespace {
41
42 const int kMagicChunkSize = 1024;
43 static_assert((HttpResponseBodyDrainer::kDrainBodyBufferSize %
44 kMagicChunkSize) == 0,
45 "chunk size needs to divide evenly into buffer size");
46
47 class CloseResultWaiter {
48 public:
49 CloseResultWaiter() = default;
50
51 CloseResultWaiter(const CloseResultWaiter&) = delete;
52 CloseResultWaiter& operator=(const CloseResultWaiter&) = delete;
53
WaitForResult()54 int WaitForResult() {
55 CHECK(!waiting_for_result_);
56 while (!have_result_) {
57 waiting_for_result_ = true;
58 loop_.Run();
59 waiting_for_result_ = false;
60 }
61 return result_;
62 }
63
set_result(bool result)64 void set_result(bool result) {
65 result_ = result;
66 have_result_ = true;
67 if (waiting_for_result_) {
68 loop_.Quit();
69 }
70 }
71
72 private:
73 int result_ = false;
74 bool have_result_ = false;
75 bool waiting_for_result_ = false;
76 base::RunLoop loop_;
77 };
78
79 class MockHttpStream : public HttpStream {
80 public:
MockHttpStream(CloseResultWaiter * result_waiter)81 explicit MockHttpStream(CloseResultWaiter* result_waiter)
82 : result_waiter_(result_waiter) {}
83
84 MockHttpStream(const MockHttpStream&) = delete;
85 MockHttpStream& operator=(const MockHttpStream&) = delete;
86
87 ~MockHttpStream() override = default;
88
89 // HttpStream implementation.
RegisterRequest(const HttpRequestInfo * request_info)90 void RegisterRequest(const HttpRequestInfo* request_info) override {}
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)91 int InitializeStream(bool can_send_early,
92 RequestPriority priority,
93 const NetLogWithSource& net_log,
94 CompletionOnceCallback callback) override {
95 return ERR_UNEXPECTED;
96 }
SendRequest(const HttpRequestHeaders & request_headers,HttpResponseInfo * response,CompletionOnceCallback callback)97 int SendRequest(const HttpRequestHeaders& request_headers,
98 HttpResponseInfo* response,
99 CompletionOnceCallback callback) override {
100 return ERR_UNEXPECTED;
101 }
ReadResponseHeaders(CompletionOnceCallback callback)102 int ReadResponseHeaders(CompletionOnceCallback callback) override {
103 return ERR_UNEXPECTED;
104 }
105
IsConnectionReused() const106 bool IsConnectionReused() const override { return false; }
SetConnectionReused()107 void SetConnectionReused() override {}
CanReuseConnection() const108 bool CanReuseConnection() const override { return can_reuse_connection_; }
GetTotalReceivedBytes() const109 int64_t GetTotalReceivedBytes() const override { return 0; }
GetTotalSentBytes() const110 int64_t GetTotalSentBytes() const override { return 0; }
GetAlternativeService(AlternativeService * alternative_service) const111 bool GetAlternativeService(
112 AlternativeService* alternative_service) const override {
113 return false;
114 }
GetSSLInfo(SSLInfo * ssl_info)115 void GetSSLInfo(SSLInfo* ssl_info) override {}
GetRemoteEndpoint(IPEndPoint * endpoint)116 int GetRemoteEndpoint(IPEndPoint* endpoint) override {
117 return ERR_UNEXPECTED;
118 }
119
120 // Mocked API
121 int ReadResponseBody(IOBuffer* buf,
122 int buf_len,
123 CompletionOnceCallback callback) override;
Close(bool not_reusable)124 void Close(bool not_reusable) override {
125 CHECK(!closed_);
126 closed_ = true;
127 result_waiter_->set_result(not_reusable);
128 }
129
RenewStreamForAuth()130 std::unique_ptr<HttpStream> RenewStreamForAuth() override { return nullptr; }
131
IsResponseBodyComplete() const132 bool IsResponseBodyComplete() const override { return is_complete_; }
133
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const134 bool GetLoadTimingInfo(LoadTimingInfo* load_timing_info) const override {
135 return false;
136 }
137
Drain(HttpNetworkSession *)138 void Drain(HttpNetworkSession*) override {}
139
PopulateNetErrorDetails(NetErrorDetails * details)140 void PopulateNetErrorDetails(NetErrorDetails* details) override { return; }
141
SetPriority(RequestPriority priority)142 void SetPriority(RequestPriority priority) override {}
143
GetDnsAliases() const144 const std::set<std::string>& GetDnsAliases() const override {
145 static const base::NoDestructor<std::set<std::string>> nullset_result;
146 return *nullset_result;
147 }
148
GetAcceptChViaAlps() const149 std::string_view GetAcceptChViaAlps() const override { return {}; }
150
151 // Methods to tweak/observer mock behavior:
set_stall_reads_forever()152 void set_stall_reads_forever() { stall_reads_forever_ = true; }
153
set_num_chunks(int num_chunks)154 void set_num_chunks(int num_chunks) { num_chunks_ = num_chunks; }
155
set_sync()156 void set_sync() { is_sync_ = true; }
157
set_is_last_chunk_zero_size()158 void set_is_last_chunk_zero_size() { is_last_chunk_zero_size_ = true; }
159
160 // Sets result value of CanReuseConnection. Defaults to true.
set_can_reuse_connection(bool can_reuse_connection)161 void set_can_reuse_connection(bool can_reuse_connection) {
162 can_reuse_connection_ = can_reuse_connection;
163 }
164
SetRequestHeadersCallback(RequestHeadersCallback callback)165 void SetRequestHeadersCallback(RequestHeadersCallback callback) override {}
166
167 private:
168 int ReadResponseBodyImpl(IOBuffer* buf, int buf_len);
169 void CompleteRead();
170
closed() const171 bool closed() const { return closed_; }
172
173 const raw_ptr<CloseResultWaiter> result_waiter_;
174 scoped_refptr<IOBuffer> user_buf_;
175 CompletionOnceCallback callback_;
176 int buf_len_ = 0;
177 bool closed_ = false;
178 bool stall_reads_forever_ = false;
179 int num_chunks_ = 0;
180 bool is_sync_ = false;
181 bool is_last_chunk_zero_size_ = false;
182 bool is_complete_ = false;
183 bool can_reuse_connection_ = true;
184
185 base::WeakPtrFactory<MockHttpStream> weak_factory_{this};
186 };
187
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)188 int MockHttpStream::ReadResponseBody(IOBuffer* buf,
189 int buf_len,
190 CompletionOnceCallback callback) {
191 CHECK(!callback.is_null());
192 CHECK(callback_.is_null());
193 CHECK(buf);
194
195 if (stall_reads_forever_)
196 return ERR_IO_PENDING;
197
198 if (is_complete_)
199 return ERR_UNEXPECTED;
200
201 if (!is_sync_) {
202 user_buf_ = buf;
203 buf_len_ = buf_len;
204 callback_ = std::move(callback);
205 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
206 FROM_HERE, base::BindOnce(&MockHttpStream::CompleteRead,
207 weak_factory_.GetWeakPtr()));
208 return ERR_IO_PENDING;
209 } else {
210 return ReadResponseBodyImpl(buf, buf_len);
211 }
212 }
213
ReadResponseBodyImpl(IOBuffer * buf,int buf_len)214 int MockHttpStream::ReadResponseBodyImpl(IOBuffer* buf, int buf_len) {
215 if (is_last_chunk_zero_size_ && num_chunks_ == 1) {
216 buf_len = 0;
217 } else {
218 if (buf_len > kMagicChunkSize)
219 buf_len = kMagicChunkSize;
220 std::memset(buf->data(), 1, buf_len);
221 }
222 num_chunks_--;
223 if (!num_chunks_)
224 is_complete_ = true;
225
226 return buf_len;
227 }
228
CompleteRead()229 void MockHttpStream::CompleteRead() {
230 int result = ReadResponseBodyImpl(user_buf_.get(), buf_len_);
231 user_buf_ = nullptr;
232 std::move(callback_).Run(result);
233 }
234
235 class HttpResponseBodyDrainerTest : public TestWithTaskEnvironment {
236 protected:
HttpResponseBodyDrainerTest()237 HttpResponseBodyDrainerTest()
238 : proxy_resolution_service_(
239 ConfiguredProxyResolutionService::CreateDirect()),
240 ssl_config_service_(std::make_unique<SSLConfigServiceDefaults>()),
241 http_server_properties_(std::make_unique<HttpServerProperties>()),
242 session_(CreateNetworkSession()),
243 mock_stream_(new MockHttpStream(&result_waiter_)) {
244 drainer_ = std::make_unique<HttpResponseBodyDrainer>(mock_stream_);
245 }
246
247 ~HttpResponseBodyDrainerTest() override = default;
248
CreateNetworkSession()249 std::unique_ptr<HttpNetworkSession> CreateNetworkSession() {
250 HttpNetworkSessionContext context;
251 context.client_socket_factory = &socket_factory_;
252 context.proxy_resolution_service = proxy_resolution_service_.get();
253 context.ssl_config_service = ssl_config_service_.get();
254 context.http_server_properties = http_server_properties_.get();
255 context.cert_verifier = &cert_verifier_;
256 context.transport_security_state = &transport_security_state_;
257 context.quic_context = &quic_context_;
258 return std::make_unique<HttpNetworkSession>(HttpNetworkSessionParams(),
259 context);
260 }
261
262 std::unique_ptr<ProxyResolutionService> proxy_resolution_service_;
263 std::unique_ptr<SSLConfigService> ssl_config_service_;
264 std::unique_ptr<HttpServerProperties> http_server_properties_;
265 MockCertVerifier cert_verifier_;
266 TransportSecurityState transport_security_state_;
267 QuicContext quic_context_;
268 MockClientSocketFactory socket_factory_;
269 const std::unique_ptr<HttpNetworkSession> session_;
270 CloseResultWaiter result_waiter_;
271 const raw_ptr<MockHttpStream, AcrossTasksDanglingUntriaged>
272 mock_stream_; // Owned by |drainer_|.
273 std::unique_ptr<HttpResponseBodyDrainer> drainer_;
274 };
275
TEST_F(HttpResponseBodyDrainerTest,DrainBodySyncSingleOK)276 TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncSingleOK) {
277 mock_stream_->set_num_chunks(1);
278 mock_stream_->set_sync();
279 session_->StartResponseDrainer(std::move(drainer_));
280 EXPECT_FALSE(result_waiter_.WaitForResult());
281 }
282
TEST_F(HttpResponseBodyDrainerTest,DrainBodySyncOK)283 TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncOK) {
284 mock_stream_->set_num_chunks(3);
285 mock_stream_->set_sync();
286 session_->StartResponseDrainer(std::move(drainer_));
287 EXPECT_FALSE(result_waiter_.WaitForResult());
288 }
289
TEST_F(HttpResponseBodyDrainerTest,DrainBodyAsyncOK)290 TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncOK) {
291 mock_stream_->set_num_chunks(3);
292 session_->StartResponseDrainer(std::move(drainer_));
293 EXPECT_FALSE(result_waiter_.WaitForResult());
294 }
295
// Tests the case where the final chunk is 0 bytes. This can happen when the
// terminating 0-byte chunk of a chunked-encoded HTTP response is read by a
// final call to ReadResponseBody, after all data has been returned from the
// HttpStream.
TEST_F(HttpResponseBodyDrainerTest,DrainBodyAsyncEmptyChunk)299 TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncEmptyChunk) {
300 mock_stream_->set_num_chunks(4);
301 mock_stream_->set_is_last_chunk_zero_size();
302 session_->StartResponseDrainer(std::move(drainer_));
303 EXPECT_FALSE(result_waiter_.WaitForResult());
304 }
305
TEST_F(HttpResponseBodyDrainerTest,DrainBodySyncEmptyChunk)306 TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncEmptyChunk) {
307 mock_stream_->set_num_chunks(4);
308 mock_stream_->set_sync();
309 mock_stream_->set_is_last_chunk_zero_size();
310 session_->StartResponseDrainer(std::move(drainer_));
311 EXPECT_FALSE(result_waiter_.WaitForResult());
312 }
313
TEST_F(HttpResponseBodyDrainerTest,DrainBodySizeEqualsDrainBuffer)314 TEST_F(HttpResponseBodyDrainerTest, DrainBodySizeEqualsDrainBuffer) {
315 mock_stream_->set_num_chunks(
316 HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize);
317 session_->StartResponseDrainer(std::move(drainer_));
318 EXPECT_FALSE(result_waiter_.WaitForResult());
319 }
320
TEST_F(HttpResponseBodyDrainerTest,DrainBodyTimeOut)321 TEST_F(HttpResponseBodyDrainerTest, DrainBodyTimeOut) {
322 mock_stream_->set_num_chunks(2);
323 mock_stream_->set_stall_reads_forever();
324 session_->StartResponseDrainer(std::move(drainer_));
325 EXPECT_TRUE(result_waiter_.WaitForResult());
326 }
327
TEST_F(HttpResponseBodyDrainerTest,CancelledBySession)328 TEST_F(HttpResponseBodyDrainerTest, CancelledBySession) {
329 mock_stream_->set_num_chunks(2);
330 mock_stream_->set_stall_reads_forever();
331 session_->StartResponseDrainer(std::move(drainer_));
332 // HttpNetworkSession should delete |drainer_|.
333 }
334
TEST_F(HttpResponseBodyDrainerTest,DrainBodyTooLarge)335 TEST_F(HttpResponseBodyDrainerTest, DrainBodyTooLarge) {
336 int too_many_chunks =
337 HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize;
338 too_many_chunks += 1; // Now it's too large.
339
340 mock_stream_->set_num_chunks(too_many_chunks);
341 session_->StartResponseDrainer(std::move(drainer_));
342 EXPECT_TRUE(result_waiter_.WaitForResult());
343 }
344
TEST_F(HttpResponseBodyDrainerTest,DrainBodyCantReuse)345 TEST_F(HttpResponseBodyDrainerTest, DrainBodyCantReuse) {
346 mock_stream_->set_num_chunks(1);
347 mock_stream_->set_can_reuse_connection(false);
348 session_->StartResponseDrainer(std::move(drainer_));
349 EXPECT_TRUE(result_waiter_.WaitForResult());
350 }
351
352 } // namespace
353
354 } // namespace net
355