xref: /aosp_15_r20/external/cronet/net/test/embedded_test_server/http2_connection.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2021 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/test/embedded_test_server/http2_connection.h"
6 
7 #include <memory>
8 
9 #include "base/functional/bind.h"
10 #include "base/functional/callback_helpers.h"
11 #include "base/memory/raw_ptr.h"
12 #include "base/memory/raw_ref.h"
13 #include "base/strings/strcat.h"
14 #include "base/strings/string_piece.h"
15 #include "base/task/sequenced_task_runner.h"
16 #include "net/http/http_response_headers.h"
17 #include "net/http/http_status_code.h"
18 #include "net/socket/stream_socket.h"
19 #include "net/ssl/ssl_info.h"
20 #include "net/test/embedded_test_server/embedded_test_server.h"
21 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
22 
23 namespace net {
24 
25 namespace {
26 
GenerateHeaders(HttpStatusCode status,base::StringPairs headers)27 std::vector<http2::adapter::Header> GenerateHeaders(HttpStatusCode status,
28                                                     base::StringPairs headers) {
29   std::vector<http2::adapter::Header> response_vector;
30   response_vector.emplace_back(
31       http2::adapter::HeaderRep(std::string(":status")),
32       http2::adapter::HeaderRep(base::NumberToString(status)));
33   for (const auto& header : headers) {
34     // Connection (and related) headers are considered malformed and will
35     // result in a client error
36     if (base::EqualsCaseInsensitiveASCII(header.first, "connection"))
37       continue;
38     response_vector.emplace_back(
39         http2::adapter::HeaderRep(base::ToLowerASCII(header.first)),
40         http2::adapter::HeaderRep(header.second));
41   }
42 
43   return response_vector;
44 }
45 
46 }  // namespace
47 
48 namespace test_server {
49 
50 class Http2Connection::DataFrameSource
51     : public http2::adapter::DataFrameSource {
52  public:
DataFrameSource(Http2Connection * connection,const StreamId & stream_id)53   explicit DataFrameSource(Http2Connection* connection,
54                            const StreamId& stream_id)
55       : connection_(connection), stream_id_(stream_id) {}
56   ~DataFrameSource() override = default;
57   DataFrameSource(const DataFrameSource&) = delete;
58   DataFrameSource& operator=(const DataFrameSource&) = delete;
59 
SelectPayloadLength(size_t max_length)60   std::pair<int64_t, bool> SelectPayloadLength(size_t max_length) override {
61     if (chunks_.empty())
62       return {kBlocked, last_frame_};
63 
64     bool finished = (chunks_.size() <= 1) &&
65                     (chunks_.front().size() <= max_length) && last_frame_;
66 
67     return {std::min(chunks_.front().size(), max_length), finished};
68   }
69 
Send(std::string_view frame_header,size_t payload_length)70   bool Send(std::string_view frame_header, size_t payload_length) override {
71     std::string concatenated =
72         base::StrCat({frame_header, chunks_.front().substr(0, payload_length)});
73     const int64_t result = connection_->OnReadyToSend(concatenated);
74     // Write encountered error.
75     if (result < 0) {
76       connection_->OnConnectionError(ConnectionError::kSendError);
77       return false;
78     }
79 
80     // Write blocked.
81     if (result == 0) {
82       connection_->blocked_streams_.insert(*stream_id_);
83       return false;
84     }
85 
86     if (static_cast<const size_t>(result) < concatenated.size()) {
87       // Probably need to handle this better within this test class.
88       QUICHE_LOG(DFATAL)
89           << "DATA frame not fully flushed. Connection will be corrupt!";
90       connection_->OnConnectionError(ConnectionError::kSendError);
91       return false;
92     }
93 
94     chunks_.front().erase(0, payload_length);
95 
96     if (chunks_.front().empty())
97       chunks_.pop();
98 
99     if (chunks_.empty() && send_completion_callback_) {
100       std::move(send_completion_callback_).Run();
101     }
102 
103     return true;
104   }
105 
send_fin() const106   bool send_fin() const override { return true; }
107 
AddChunk(std::string chunk)108   void AddChunk(std::string chunk) { chunks_.push(std::move(chunk)); }
set_last_frame(bool last_frame)109   void set_last_frame(bool last_frame) { last_frame_ = last_frame; }
SetSendCompletionCallback(base::OnceClosure callback)110   void SetSendCompletionCallback(base::OnceClosure callback) {
111     send_completion_callback_ = std::move(callback);
112   }
113 
114  private:
115   const raw_ptr<Http2Connection> connection_;
116   const raw_ref<const StreamId, DanglingUntriaged> stream_id_;
117   std::queue<std::string> chunks_;
118   bool last_frame_ = false;
119   base::OnceClosure send_completion_callback_;
120 };
121 
122 // Corresponds to an HTTP/2 stream
123 class Http2Connection::ResponseDelegate : public HttpResponseDelegate {
124  public:
ResponseDelegate(Http2Connection * connection,StreamId stream_id)125   ResponseDelegate(Http2Connection* connection, StreamId stream_id)
126       : stream_id_(stream_id), connection_(connection) {}
127   ~ResponseDelegate() override = default;
128   ResponseDelegate(const ResponseDelegate&) = delete;
129   ResponseDelegate& operator=(const ResponseDelegate&) = delete;
130 
AddResponse(std::unique_ptr<HttpResponse> response)131   void AddResponse(std::unique_ptr<HttpResponse> response) override {
132     responses_.push_back(std::move(response));
133   }
134 
SendResponseHeaders(HttpStatusCode status,const std::string & status_reason,const base::StringPairs & headers)135   void SendResponseHeaders(HttpStatusCode status,
136                            const std::string& status_reason,
137                            const base::StringPairs& headers) override {
138     std::unique_ptr<DataFrameSource> data_frame =
139         std::make_unique<DataFrameSource>(connection_, stream_id_);
140     data_frame_ = data_frame.get();
141     connection_->adapter()->SubmitResponse(
142         stream_id_, GenerateHeaders(status, headers), std::move(data_frame));
143     connection_->SendIfNotProcessing();
144   }
145 
SendRawResponseHeaders(const std::string & headers)146   void SendRawResponseHeaders(const std::string& headers) override {
147     scoped_refptr<HttpResponseHeaders> parsed_headers =
148         HttpResponseHeaders::TryToCreate(headers);
149     if (parsed_headers->response_code() == 0) {
150       connection_->OnConnectionError(ConnectionError::kParseError);
151       LOG(ERROR) << "raw headers could not be parsed";
152     }
153     base::StringPairs header_pairs;
154     size_t iter = 0;
155     std::string key, value;
156     while (parsed_headers->EnumerateHeaderLines(&iter, &key, &value))
157       header_pairs.emplace_back(key, value);
158     SendResponseHeaders(
159         static_cast<HttpStatusCode>(parsed_headers->response_code()),
160         /*status_reason=*/"", header_pairs);
161   }
162 
SendContents(const std::string & contents,base::OnceClosure callback)163   void SendContents(const std::string& contents,
164                     base::OnceClosure callback) override {
165     DCHECK(data_frame_);
166     data_frame_->AddChunk(contents);
167     data_frame_->SetSendCompletionCallback(std::move(callback));
168     connection_->adapter()->ResumeStream(stream_id_);
169     connection_->SendIfNotProcessing();
170   }
171 
FinishResponse()172   void FinishResponse() override {
173     data_frame_->set_last_frame(true);
174     connection_->adapter()->ResumeStream(stream_id_);
175     connection_->SendIfNotProcessing();
176   }
177 
SendContentsAndFinish(const std::string & contents)178   void SendContentsAndFinish(const std::string& contents) override {
179     data_frame_->set_last_frame(true);
180     SendContents(contents, base::DoNothing());
181   }
182 
SendHeadersContentAndFinish(HttpStatusCode status,const std::string & status_reason,const base::StringPairs & headers,const std::string & contents)183   void SendHeadersContentAndFinish(HttpStatusCode status,
184                                    const std::string& status_reason,
185                                    const base::StringPairs& headers,
186                                    const std::string& contents) override {
187     std::unique_ptr<DataFrameSource> data_frame =
188         std::make_unique<DataFrameSource>(connection_, stream_id_);
189     data_frame->AddChunk(contents);
190     data_frame->set_last_frame(true);
191     connection_->adapter()->SubmitResponse(
192         stream_id_, GenerateHeaders(status, headers), std::move(data_frame));
193     connection_->SendIfNotProcessing();
194   }
GetWeakPtr()195   base::WeakPtr<ResponseDelegate> GetWeakPtr() {
196     return weak_factory_.GetWeakPtr();
197   }
198 
199  private:
200   std::vector<std::unique_ptr<HttpResponse>> responses_;
201   StreamId stream_id_;
202   const raw_ptr<Http2Connection> connection_;
203   raw_ptr<DataFrameSource, DanglingUntriaged> data_frame_;
204   base::WeakPtrFactory<ResponseDelegate> weak_factory_{this};
205 };
206 
Http2Connection(std::unique_ptr<StreamSocket> socket,EmbeddedTestServerConnectionListener * connection_listener,EmbeddedTestServer * embedded_test_server)207 Http2Connection::Http2Connection(
208     std::unique_ptr<StreamSocket> socket,
209     EmbeddedTestServerConnectionListener* connection_listener,
210     EmbeddedTestServer* embedded_test_server)
211     : socket_(std::move(socket)),
212       connection_listener_(connection_listener),
213       embedded_test_server_(embedded_test_server),
214       read_buf_(base::MakeRefCounted<IOBufferWithSize>(4096)) {
215   http2::adapter::OgHttp2Adapter::Options options;
216   options.perspective = http2::adapter::Perspective::kServer;
217   adapter_ = http2::adapter::OgHttp2Adapter::Create(*this, options);
218 }
219 
220 Http2Connection::~Http2Connection() = default;
221 
OnSocketReady()222 void Http2Connection::OnSocketReady() {
223   ReadData();
224 }
225 
ReadData()226 void Http2Connection::ReadData() {
227   while (true) {
228     int rv = socket_->Read(
229         read_buf_.get(), read_buf_->size(),
230         base::BindOnce(&Http2Connection::OnDataRead, base::Unretained(this)));
231     if (rv == ERR_IO_PENDING)
232       return;
233     if (!HandleData(rv))
234       return;
235   }
236 }
237 
OnDataRead(int rv)238 void Http2Connection::OnDataRead(int rv) {
239   if (HandleData(rv))
240     ReadData();
241 }
242 
HandleData(int rv)243 bool Http2Connection::HandleData(int rv) {
244   if (rv <= 0) {
245     embedded_test_server_->RemoveConnection(this);
246     return false;
247   }
248 
249   if (connection_listener_)
250     connection_listener_->ReadFromSocket(*socket_, rv);
251 
252   std::string_view remaining_buffer(read_buf_->data(), rv);
253   while (!remaining_buffer.empty()) {
254     int result = adapter_->ProcessBytes(remaining_buffer);
255     if (result < 0)
256       return false;
257     remaining_buffer = remaining_buffer.substr(result);
258   }
259 
260   // Any frames and data sources will be queued up and sent all at once below
261   DCHECK(!processing_responses_);
262   processing_responses_ = true;
263   while (!ready_streams_.empty()) {
264     StreamId stream_id = ready_streams_.front();
265     ready_streams_.pop();
266     auto delegate = std::make_unique<ResponseDelegate>(this, stream_id);
267     ResponseDelegate* delegate_ptr = delegate.get();
268     response_map_[stream_id] = std::move(delegate);
269     embedded_test_server_->HandleRequest(delegate_ptr->GetWeakPtr(),
270                                          std::move(request_map_[stream_id]));
271     request_map_.erase(stream_id);
272   }
273   adapter_->Send();
274   processing_responses_ = false;
275   return true;
276 }
277 
Socket()278 StreamSocket* Http2Connection::Socket() {
279   return socket_.get();
280 }
281 
TakeSocket()282 std::unique_ptr<StreamSocket> Http2Connection::TakeSocket() {
283   return std::move(socket_);
284 }
285 
GetWeakPtr()286 base::WeakPtr<HttpConnection> Http2Connection::GetWeakPtr() {
287   return weak_factory_.GetWeakPtr();
288 }
289 
OnReadyToSend(std::string_view serialized)290 int64_t Http2Connection::OnReadyToSend(std::string_view serialized) {
291   if (write_buf_)
292     return kSendBlocked;
293 
294   write_buf_ = base::MakeRefCounted<DrainableIOBuffer>(
295       base::MakeRefCounted<StringIOBuffer>(std::string(serialized)),
296       serialized.size());
297   SendInternal();
298   return serialized.size();
299 }
300 
OnCloseStream(StreamId stream_id,http2::adapter::Http2ErrorCode error_code)301 bool Http2Connection::OnCloseStream(StreamId stream_id,
302                                     http2::adapter::Http2ErrorCode error_code) {
303   response_map_.erase(stream_id);
304   return true;
305 }
306 
SendInternal()307 void Http2Connection::SendInternal() {
308   DCHECK(socket_);
309   DCHECK(write_buf_);
310   while (write_buf_->BytesRemaining() > 0) {
311     int rv = socket_->Write(write_buf_.get(), write_buf_->BytesRemaining(),
312                             base::BindOnce(&Http2Connection::OnSendInternalDone,
313                                            base::Unretained(this)),
314                             TRAFFIC_ANNOTATION_FOR_TESTS);
315     if (rv == ERR_IO_PENDING)
316       return;
317 
318     if (rv < 0) {
319       embedded_test_server_->RemoveConnection(this);
320       break;
321     }
322 
323     write_buf_->DidConsume(rv);
324   }
325   write_buf_ = nullptr;
326 }
327 
OnSendInternalDone(int rv)328 void Http2Connection::OnSendInternalDone(int rv) {
329   DCHECK(write_buf_);
330   if (rv < 0) {
331     embedded_test_server_->RemoveConnection(this);
332     write_buf_ = nullptr;
333     return;
334   }
335   write_buf_->DidConsume(rv);
336 
337   SendInternal();
338 
339   if (!write_buf_) {
340     // Now that writing is no longer blocked, any blocked streams can be
341     // resumed.
342     for (const auto& stream_id : blocked_streams_)
343       adapter_->ResumeStream(stream_id);
344 
345     if (adapter_->want_write()) {
346       base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
347           FROM_HERE, base::BindOnce(&Http2Connection::SendIfNotProcessing,
348                                     weak_factory_.GetWeakPtr()));
349     }
350   }
351 }
352 
SendIfNotProcessing()353 void Http2Connection::SendIfNotProcessing() {
354   if (!processing_responses_) {
355     processing_responses_ = true;
356     adapter_->Send();
357     processing_responses_ = false;
358   }
359 }
360 
361 http2::adapter::Http2VisitorInterface::OnHeaderResult
OnHeaderForStream(http2::adapter::Http2StreamId stream_id,std::string_view key,std::string_view value)362 Http2Connection::OnHeaderForStream(http2::adapter::Http2StreamId stream_id,
363                                    std::string_view key,
364                                    std::string_view value) {
365   header_map_[stream_id][std::string(key)] = std::string(value);
366   return http2::adapter::Http2VisitorInterface::HEADER_OK;
367 }
368 
OnEndHeadersForStream(http2::adapter::Http2StreamId stream_id)369 bool Http2Connection::OnEndHeadersForStream(
370     http2::adapter::Http2StreamId stream_id) {
371   HttpRequest::HeaderMap header_map = header_map_[stream_id];
372   auto request = std::make_unique<HttpRequest>();
373   // TODO(crbug.com/1375303): Handle proxy cases.
374   request->relative_url = header_map[":path"];
375   request->base_url = GURL(header_map[":authority"]);
376   request->method_string = header_map[":method"];
377   request->method = HttpRequestParser::GetMethodType(request->method_string);
378   request->headers = header_map;
379 
380   request->has_content = false;
381 
382   SSLInfo ssl_info;
383   DCHECK(socket_->GetSSLInfo(&ssl_info));
384   request->ssl_info = ssl_info;
385   request_map_[stream_id] = std::move(request);
386 
387   return true;
388 }
389 
OnEndStream(http2::adapter::Http2StreamId stream_id)390 bool Http2Connection::OnEndStream(http2::adapter::Http2StreamId stream_id) {
391   ready_streams_.push(stream_id);
392   return true;
393 }
394 
OnFrameHeader(StreamId,size_t,uint8_t,uint8_t)395 bool Http2Connection::OnFrameHeader(StreamId /*stream_id*/,
396                                     size_t /*length*/,
397                                     uint8_t /*type*/,
398                                     uint8_t /*flags*/) {
399   return true;
400 }
401 
OnBeginHeadersForStream(StreamId stream_id)402 bool Http2Connection::OnBeginHeadersForStream(StreamId stream_id) {
403   return true;
404 }
405 
OnBeginDataForStream(StreamId stream_id,size_t payload_length)406 bool Http2Connection::OnBeginDataForStream(StreamId stream_id,
407                                            size_t payload_length) {
408   return true;
409 }
410 
OnDataForStream(StreamId stream_id,std::string_view data)411 bool Http2Connection::OnDataForStream(StreamId stream_id,
412                                       std::string_view data) {
413   auto request = request_map_.find(stream_id);
414   if (request == request_map_.end()) {
415     // We should not receive data before receiving headers.
416     return false;
417   }
418 
419   request->second->has_content = true;
420   request->second->content.append(data);
421   adapter_->MarkDataConsumedForStream(stream_id, data.size());
422   return true;
423 }
424 
OnDataPaddingLength(StreamId stream_id,size_t padding_length)425 bool Http2Connection::OnDataPaddingLength(StreamId stream_id,
426                                           size_t padding_length) {
427   adapter_->MarkDataConsumedForStream(stream_id, padding_length);
428   return true;
429 }
430 
OnGoAway(StreamId last_accepted_stream_id,http2::adapter::Http2ErrorCode error_code,std::string_view opaque_data)431 bool Http2Connection::OnGoAway(StreamId last_accepted_stream_id,
432                                http2::adapter::Http2ErrorCode error_code,
433                                std::string_view opaque_data) {
434   return true;
435 }
436 
OnBeforeFrameSent(uint8_t frame_type,StreamId stream_id,size_t length,uint8_t flags)437 int Http2Connection::OnBeforeFrameSent(uint8_t frame_type,
438                                        StreamId stream_id,
439                                        size_t length,
440                                        uint8_t flags) {
441   return 0;
442 }
443 
OnFrameSent(uint8_t frame_type,StreamId stream_id,size_t length,uint8_t flags,uint32_t error_code)444 int Http2Connection::OnFrameSent(uint8_t frame_type,
445                                  StreamId stream_id,
446                                  size_t length,
447                                  uint8_t flags,
448                                  uint32_t error_code) {
449   return 0;
450 }
451 
OnInvalidFrame(StreamId stream_id,InvalidFrameError error)452 bool Http2Connection::OnInvalidFrame(StreamId stream_id,
453                                      InvalidFrameError error) {
454   return true;
455 }
456 
OnMetadataForStream(StreamId stream_id,std::string_view metadata)457 bool Http2Connection::OnMetadataForStream(StreamId stream_id,
458                                           std::string_view metadata) {
459   return true;
460 }
461 
OnMetadataEndForStream(StreamId stream_id)462 bool Http2Connection::OnMetadataEndForStream(StreamId stream_id) {
463   return true;
464 }
465 
466 }  // namespace test_server
467 
468 }  // namespace net
469