// Copyright 2023 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_http3_handshake_stream.h" #include #include #include "base/check.h" #include "base/check_op.h" #include "base/functional/bind.h" #include "base/functional/callback.h" #include "base/memory/scoped_refptr.h" #include "base/strings/strcat.h" #include "base/strings/stringprintf.h" #include "base/time/time.h" #include "net/base/ip_endpoint.h" #include "net/http/http_request_headers.h" #include "net/http/http_request_info.h" #include "net/http/http_response_headers.h" #include "net/http/http_response_info.h" #include "net/http/http_status_code.h" #include "net/spdy/spdy_http_utils.h" #include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "net/websockets/websocket_basic_stream.h" #include "net/websockets/websocket_deflate_predictor_impl.h" #include "net/websockets/websocket_deflate_stream.h" #include "net/websockets/websocket_handshake_constants.h" #include "net/websockets/websocket_handshake_request_info.h" namespace net { struct AlternativeService; namespace { bool ValidateStatus(const HttpResponseHeaders* headers) { return headers->GetStatusLine() == "HTTP/1.1 200"; } } // namespace WebSocketHttp3HandshakeStream::WebSocketHttp3HandshakeStream( std::unique_ptr session, WebSocketStream::ConnectDelegate* connect_delegate, std::vector requested_sub_protocols, std::vector requested_extensions, WebSocketStreamRequestAPI* request, std::set dns_aliases) : session_(std::move(session)), connect_delegate_(connect_delegate), requested_sub_protocols_(std::move(requested_sub_protocols)), requested_extensions_(std::move(requested_extensions)), stream_request_(request), dns_aliases_(std::move(dns_aliases)) { DCHECK(connect_delegate); DCHECK(request); } WebSocketHttp3HandshakeStream::~WebSocketHttp3HandshakeStream() { RecordHandshakeResult(result_); } void WebSocketHttp3HandshakeStream::RegisterRequest( const HttpRequestInfo* request_info) { DCHECK(request_info); DCHECK(request_info->traffic_annotation.is_valid()); request_info_ = request_info; } int WebSocketHttp3HandshakeStream::InitializeStream( bool can_send_early, RequestPriority priority, const NetLogWithSource& net_log, CompletionOnceCallback callback) { priority_ = priority; net_log_ = net_log; request_time_ = base::Time::Now(); return OK; } int WebSocketHttp3HandshakeStream::SendRequest( const HttpRequestHeaders& request_headers, HttpResponseInfo* response, CompletionOnceCallback callback) { DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketKey)); DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketProtocol)); DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketExtensions)); DCHECK(request_headers.HasHeader(HttpRequestHeaders::kOrigin)); DCHECK(request_headers.HasHeader(websockets::kUpgrade)); DCHECK(request_headers.HasHeader(HttpRequestHeaders::kConnection)); DCHECK(request_headers.HasHeader(websockets::kSecWebSocketVersion)); if (!session_) { constexpr int rv = ERR_CONNECTION_CLOSED; OnFailure("Connection closed before sending request.", rv, std::nullopt); return rv; } http_response_info_ = response; IPEndPoint address; int result = session_->GetPeerAddress(&address); if (result != OK) { OnFailure("Error getting IP address.", result, std::nullopt); return result; } http_response_info_->remote_endpoint = address; auto request = std::make_unique( request_info_->url, base::Time::Now()); request->headers = request_headers; AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, requested_extensions_, &request->headers); AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, requested_sub_protocols_, &request->headers); CreateSpdyHeadersFromHttpRequestForWebSocket( request_info_->url, request->headers, &http3_request_headers_); connect_delegate_->OnStartOpeningHandshake(std::move(request)); callback_ = std::move(callback); std::unique_ptr stream_adapter = session_->CreateWebSocketQuicStreamAdapter( this, base::BindOnce( &WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest, base::Unretained(this)), NetworkTrafficAnnotationTag(request_info_->traffic_annotation)); if (!stream_adapter) { return ERR_IO_PENDING; } ReceiveAdapterAndStartRequest(std::move(stream_adapter)); return OK; } int WebSocketHttp3HandshakeStream::ReadResponseHeaders( CompletionOnceCallback callback) { if (stream_closed_) { return stream_error_; } if (response_headers_complete_) { return ValidateResponse(); } callback_ = std::move(callback); return ERR_IO_PENDING; } // TODO(momoka): Implement this. int WebSocketHttp3HandshakeStream::ReadResponseBody( IOBuffer* buf, int buf_len, CompletionOnceCallback callback) { return OK; } void WebSocketHttp3HandshakeStream::Close(bool not_reusable) { if (stream_adapter_) { stream_adapter_->Disconnect(); stream_closed_ = true; stream_error_ = ERR_CONNECTION_CLOSED; } } // TODO(momoka): Implement this. bool WebSocketHttp3HandshakeStream::IsResponseBodyComplete() const { return false; } // TODO(momoka): Implement this. bool WebSocketHttp3HandshakeStream::IsConnectionReused() const { return true; } // TODO(momoka): Implement this. void WebSocketHttp3HandshakeStream::SetConnectionReused() {} // TODO(momoka): Implement this. bool WebSocketHttp3HandshakeStream::CanReuseConnection() const { return false; } // TODO(momoka): Implement this. int64_t WebSocketHttp3HandshakeStream::GetTotalReceivedBytes() const { return 0; } // TODO(momoka): Implement this. int64_t WebSocketHttp3HandshakeStream::GetTotalSentBytes() const { return 0; } // TODO(momoka): Implement this. bool WebSocketHttp3HandshakeStream::GetAlternativeService( AlternativeService* alternative_service) const { return false; } // TODO(momoka): Implement this. bool WebSocketHttp3HandshakeStream::GetLoadTimingInfo( LoadTimingInfo* load_timing_info) const { return false; } // TODO(momoka): Implement this. void WebSocketHttp3HandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {} // TODO(momoka): Implement this. int WebSocketHttp3HandshakeStream::GetRemoteEndpoint(IPEndPoint* endpoint) { return 0; } // TODO(momoka): Implement this. void WebSocketHttp3HandshakeStream::Drain(HttpNetworkSession* session) {} // TODO(momoka): Implement this. void WebSocketHttp3HandshakeStream::SetPriority(RequestPriority priority) {} // TODO(momoka): Implement this. void WebSocketHttp3HandshakeStream::PopulateNetErrorDetails( NetErrorDetails* details) {} // TODO(momoka): Implement this. std::unique_ptr WebSocketHttp3HandshakeStream::RenewStreamForAuth() { return nullptr; } // TODO(momoka): Implement this. const std::set& WebSocketHttp3HandshakeStream::GetDnsAliases() const { return dns_aliases_; } // TODO(momoka): Implement this. std::string_view WebSocketHttp3HandshakeStream::GetAcceptChViaAlps() const { return {}; } // WebSocketHandshakeStreamBase methods. // TODO(momoka): Implement this. std::unique_ptr WebSocketHttp3HandshakeStream::Upgrade() { DCHECK(extension_params_.get()); stream_adapter_->clear_delegate(); std::unique_ptr basic_stream = std::make_unique(std::move(stream_adapter_), nullptr, sub_protocol_, extensions_, net_log_); if (!extension_params_->deflate_enabled) { return basic_stream; } return std::make_unique( std::move(basic_stream), extension_params_->deflate_parameters, std::make_unique()); } bool WebSocketHttp3HandshakeStream::CanReadFromStream() const { return stream_adapter_ && stream_adapter_->is_initialized(); } base::WeakPtr WebSocketHttp3HandshakeStream::GetWeakPtr() { return weak_ptr_factory_.GetWeakPtr(); } void WebSocketHttp3HandshakeStream::OnHeadersSent() { std::move(callback_).Run(OK); } void WebSocketHttp3HandshakeStream::OnHeadersReceived( const spdy::Http2HeaderBlock& response_headers) { DCHECK(!response_headers_complete_); DCHECK(http_response_info_); response_headers_complete_ = true; const int rv = SpdyHeadersToHttpResponse(response_headers, http_response_info_); DCHECK_NE(rv, ERR_INCOMPLETE_HTTP2_HEADERS); // Do not store SSLInfo in the response here, HttpNetworkTransaction will take // care of that part. http_response_info_->was_alpn_negotiated = true; http_response_info_->response_time = base::Time::Now(); http_response_info_->request_time = request_time_; http_response_info_->connection_info = HttpConnectionInfo::kHTTP2; http_response_info_->alpn_negotiated_protocol = HttpConnectionInfoToString(http_response_info_->connection_info); if (callback_) { std::move(callback_).Run(ValidateResponse()); } } void WebSocketHttp3HandshakeStream::OnClose(int status) { DCHECK(stream_adapter_); DCHECK_GT(ERR_IO_PENDING, status); stream_closed_ = true; stream_error_ = status; stream_adapter_.reset(); // If response headers have already been received, // then ValidateResponse() sets `result_`. if (!response_headers_complete_) { result_ = HandshakeResult::HTTP3_FAILED; } OnFailure(base::StrCat({"Stream closed with error: ", ErrorToString(status)}), status, std::nullopt); if (callback_) { std::move(callback_).Run(status); } } void WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest( std::unique_ptr adapter) { stream_adapter_ = std::move(adapter); // WriteHeaders returns synchronously. stream_adapter_->WriteHeaders(std::move(http3_request_headers_), false); } int WebSocketHttp3HandshakeStream::ValidateResponse() { DCHECK(http_response_info_); const HttpResponseHeaders* headers = http_response_info_->headers.get(); const int response_code = headers->response_code(); switch (response_code) { case HTTP_OK: return ValidateUpgradeResponse(headers); // We need to pass these through for authentication to work. case HTTP_UNAUTHORIZED: case HTTP_PROXY_AUTHENTICATION_REQUIRED: return OK; // Other status codes are potentially risky (see the warnings in the // WHATWG WebSocket API spec) and so are dropped by default. default: OnFailure( base::StringPrintf( "Error during WebSocket handshake: Unexpected response code: %d", headers->response_code()), ERR_FAILED, headers->response_code()); result_ = HandshakeResult::HTTP3_INVALID_STATUS; return ERR_INVALID_RESPONSE; } } int WebSocketHttp3HandshakeStream::ValidateUpgradeResponse( const HttpResponseHeaders* headers) { extension_params_ = std::make_unique(); std::string failure_message; if (!ValidateStatus(headers)) { result_ = HandshakeResult::HTTP3_INVALID_STATUS; } else if (!ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_, &failure_message)) { result_ = HandshakeResult::HTTP3_FAILED_SUBPROTO; } else if (!ValidateExtensions(headers, &extensions_, &failure_message, extension_params_.get())) { result_ = HandshakeResult::HTTP3_FAILED_EXTENSIONS; } else { result_ = HandshakeResult::HTTP3_CONNECTED; return OK; } const int rv = ERR_INVALID_RESPONSE; OnFailure("Error during WebSocket handshake: " + failure_message, rv, std::nullopt); return rv; } // TODO(momoka): Implement this. void WebSocketHttp3HandshakeStream::OnFailure( const std::string& message, int net_error, std::optional response_code) { stream_request_->OnFailure(message, net_error, response_code); } } // namespace net