1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/transport_connect_sub_job.h"
6
7 #include <set>
8 #include <string>
9 #include <utility>
10
11 #include "base/check_op.h"
12 #include "base/functional/bind.h"
13 #include "base/notreached.h"
14 #include "net/base/ip_endpoint.h"
15 #include "net/base/net_errors.h"
16 #include "net/log/net_log_with_source.h"
17 #include "net/socket/client_socket_factory.h"
18 #include "net/socket/connection_attempts.h"
19 #include "net/socket/socket_performance_watcher.h"
20 #include "net/socket/socket_performance_watcher_factory.h"
21 #include "net/socket/websocket_endpoint_lock_manager.h"
22
23 namespace net {
24
25 namespace {
26
27 // StreamSocket wrapper that registers/unregisters the wrapped StreamSocket with
28 // a WebSocketEndpointLockManager on creation/destruction.
29 class WebSocketStreamSocket final : public StreamSocket {
30 public:
WebSocketStreamSocket(std::unique_ptr<StreamSocket> wrapped_socket,WebSocketEndpointLockManager * websocket_endpoint_lock_manager,const IPEndPoint & address)31 WebSocketStreamSocket(
32 std::unique_ptr<StreamSocket> wrapped_socket,
33 WebSocketEndpointLockManager* websocket_endpoint_lock_manager,
34 const IPEndPoint& address)
35 : wrapped_socket_(std::move(wrapped_socket)),
36 lock_releaser_(websocket_endpoint_lock_manager, address) {}
37
38 WebSocketStreamSocket(const WebSocketStreamSocket&) = delete;
39 WebSocketStreamSocket& operator=(const WebSocketStreamSocket&) = delete;
40
41 ~WebSocketStreamSocket() override = default;
42
43 // Socket implementation:
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)44 int Read(IOBuffer* buf,
45 int buf_len,
46 CompletionOnceCallback callback) override {
47 return wrapped_socket_->Read(buf, buf_len, std::move(callback));
48 }
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)49 int ReadIfReady(IOBuffer* buf,
50 int buf_len,
51 CompletionOnceCallback callback) override {
52 return wrapped_socket_->ReadIfReady(buf, buf_len, std::move(callback));
53 }
CancelReadIfReady()54 int CancelReadIfReady() override {
55 return wrapped_socket_->CancelReadIfReady();
56 }
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)57 int Write(IOBuffer* buf,
58 int buf_len,
59 CompletionOnceCallback callback,
60 const NetworkTrafficAnnotationTag& traffic_annotation) override {
61 return wrapped_socket_->Write(buf, buf_len, std::move(callback),
62 traffic_annotation);
63 }
SetReceiveBufferSize(int32_t size)64 int SetReceiveBufferSize(int32_t size) override {
65 return wrapped_socket_->SetReceiveBufferSize(size);
66 }
SetSendBufferSize(int32_t size)67 int SetSendBufferSize(int32_t size) override {
68 return wrapped_socket_->SetSendBufferSize(size);
69 }
SetDnsAliases(std::set<std::string> aliases)70 void SetDnsAliases(std::set<std::string> aliases) override {
71 wrapped_socket_->SetDnsAliases(std::move(aliases));
72 }
GetDnsAliases() const73 const std::set<std::string>& GetDnsAliases() const override {
74 return wrapped_socket_->GetDnsAliases();
75 }
76
77 // StreamSocket implementation:
Connect(CompletionOnceCallback callback)78 int Connect(CompletionOnceCallback callback) override {
79 return wrapped_socket_->Connect(std::move(callback));
80 }
Disconnect()81 void Disconnect() override { wrapped_socket_->Disconnect(); }
IsConnected() const82 bool IsConnected() const override { return wrapped_socket_->IsConnected(); }
IsConnectedAndIdle() const83 bool IsConnectedAndIdle() const override {
84 return wrapped_socket_->IsConnectedAndIdle();
85 }
GetPeerAddress(IPEndPoint * address) const86 int GetPeerAddress(IPEndPoint* address) const override {
87 return wrapped_socket_->GetPeerAddress(address);
88 }
GetLocalAddress(IPEndPoint * address) const89 int GetLocalAddress(IPEndPoint* address) const override {
90 return wrapped_socket_->GetLocalAddress(address);
91 }
NetLog() const92 const NetLogWithSource& NetLog() const override {
93 return wrapped_socket_->NetLog();
94 }
WasEverUsed() const95 bool WasEverUsed() const override { return wrapped_socket_->WasEverUsed(); }
GetNegotiatedProtocol() const96 NextProto GetNegotiatedProtocol() const override {
97 return wrapped_socket_->GetNegotiatedProtocol();
98 }
GetSSLInfo(SSLInfo * ssl_info)99 bool GetSSLInfo(SSLInfo* ssl_info) override {
100 return wrapped_socket_->GetSSLInfo(ssl_info);
101 }
GetTotalReceivedBytes() const102 int64_t GetTotalReceivedBytes() const override {
103 return wrapped_socket_->GetTotalReceivedBytes();
104 }
ApplySocketTag(const SocketTag & tag)105 void ApplySocketTag(const SocketTag& tag) override {
106 wrapped_socket_->ApplySocketTag(tag);
107 }
108
109 private:
110 std::unique_ptr<StreamSocket> wrapped_socket_;
111 WebSocketEndpointLockManager::LockReleaser lock_releaser_;
112 };
113
114 } // namespace
115
TransportConnectSubJob(std::vector<IPEndPoint> addresses,TransportConnectJob * parent_job,SubJobType type)116 TransportConnectSubJob::TransportConnectSubJob(
117 std::vector<IPEndPoint> addresses,
118 TransportConnectJob* parent_job,
119 SubJobType type)
120 : parent_job_(parent_job), addresses_(std::move(addresses)), type_(type) {}
121
122 TransportConnectSubJob::~TransportConnectSubJob() = default;
123
124 // Start connecting.
Start()125 int TransportConnectSubJob::Start() {
126 DCHECK_EQ(STATE_NONE, next_state_);
127 next_state_ = STATE_OBTAIN_LOCK;
128 return DoLoop(OK);
129 }
130
131 // Called by WebSocketEndpointLockManager when the lock becomes available.
GotEndpointLock()132 void TransportConnectSubJob::GotEndpointLock() {
133 DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_);
134 OnIOComplete(OK);
135 }
136
GetLoadState() const137 LoadState TransportConnectSubJob::GetLoadState() const {
138 switch (next_state_) {
139 case STATE_OBTAIN_LOCK:
140 case STATE_OBTAIN_LOCK_COMPLETE:
141 // TODO(ricea): Add a WebSocket-specific LOAD_STATE ?
142 return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET;
143 case STATE_TRANSPORT_CONNECT_COMPLETE:
144 case STATE_DONE:
145 return LOAD_STATE_CONNECTING;
146 case STATE_NONE:
147 return LOAD_STATE_IDLE;
148 }
149 NOTREACHED();
150 return LOAD_STATE_IDLE;
151 }
152
CurrentAddress() const153 const IPEndPoint& TransportConnectSubJob::CurrentAddress() const {
154 DCHECK_LT(current_address_index_, addresses_.size());
155 return addresses_[current_address_index_];
156 }
157
OnIOComplete(int result)158 void TransportConnectSubJob::OnIOComplete(int result) {
159 int rv = DoLoop(result);
160 if (rv != ERR_IO_PENDING)
161 parent_job_->OnSubJobComplete(rv, this); // |this| deleted
162 }
163
DoLoop(int result)164 int TransportConnectSubJob::DoLoop(int result) {
165 DCHECK_NE(next_state_, STATE_NONE);
166
167 int rv = result;
168 do {
169 State state = next_state_;
170 next_state_ = STATE_NONE;
171 switch (state) {
172 case STATE_OBTAIN_LOCK:
173 DCHECK_EQ(OK, rv);
174 rv = DoEndpointLock();
175 break;
176 case STATE_OBTAIN_LOCK_COMPLETE:
177 DCHECK_EQ(OK, rv);
178 rv = DoEndpointLockComplete();
179 break;
180 case STATE_TRANSPORT_CONNECT_COMPLETE:
181 rv = DoTransportConnectComplete(rv);
182 break;
183 default:
184 NOTREACHED();
185 rv = ERR_FAILED;
186 break;
187 }
188 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE &&
189 next_state_ != STATE_DONE);
190
191 return rv;
192 }
193
DoEndpointLock()194 int TransportConnectSubJob::DoEndpointLock() {
195 next_state_ = STATE_OBTAIN_LOCK_COMPLETE;
196 if (!parent_job_->websocket_endpoint_lock_manager()) {
197 return OK;
198 }
199 return parent_job_->websocket_endpoint_lock_manager()->LockEndpoint(
200 CurrentAddress(), this);
201 }
202
DoEndpointLockComplete()203 int TransportConnectSubJob::DoEndpointLockComplete() {
204 next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
205 AddressList one_address(CurrentAddress());
206
207 // Create a `SocketPerformanceWatcher`, and pass the ownership.
208 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher;
209 if (auto* factory = parent_job_->socket_performance_watcher_factory();
210 factory != nullptr) {
211 socket_performance_watcher = factory->CreateSocketPerformanceWatcher(
212 SocketPerformanceWatcherFactory::PROTOCOL_TCP,
213 CurrentAddress().address());
214 }
215
216 const NetLogWithSource& net_log = parent_job_->net_log();
217 transport_socket_ =
218 parent_job_->client_socket_factory()->CreateTransportClientSocket(
219 one_address, std::move(socket_performance_watcher),
220 parent_job_->network_quality_estimator(), net_log.net_log(),
221 net_log.source());
222
223 net_log.AddEvent(NetLogEventType::TRANSPORT_CONNECT_JOB_CONNECT_ATTEMPT, [&] {
224 auto dict = base::Value::Dict().Set("address", CurrentAddress().ToString());
225 transport_socket_->NetLog().source().AddToEventParameters(dict);
226 return dict;
227 });
228
229 // If `websocket_endpoint_lock_manager_` is non-null, this class now owns an
230 // endpoint lock. Wrap `socket` in a `WebSocketStreamSocket` to take ownership
231 // of the lock and release it when the socket goes out of scope. This must
232 // happen before any early returns in this method.
233 if (parent_job_->websocket_endpoint_lock_manager()) {
234 transport_socket_ = std::make_unique<WebSocketStreamSocket>(
235 std::move(transport_socket_),
236 parent_job_->websocket_endpoint_lock_manager(), CurrentAddress());
237 }
238
239 transport_socket_->ApplySocketTag(parent_job_->socket_tag());
240
241 // This use of base::Unretained() is safe because transport_socket_ is
242 // destroyed in the destructor.
243 return transport_socket_->Connect(base::BindOnce(
244 &TransportConnectSubJob::OnIOComplete, base::Unretained(this)));
245 }
246
DoTransportConnectComplete(int result)247 int TransportConnectSubJob::DoTransportConnectComplete(int result) {
248 next_state_ = STATE_DONE;
249 if (result != OK) {
250 // Drop the socket to release the endpoint lock, if any.
251 transport_socket_.reset();
252
253 parent_job_->connection_attempts_.push_back(
254 ConnectionAttempt(CurrentAddress(), result));
255
256 // Don't try the next address if entering suspend mode.
257 if (result != ERR_NETWORK_IO_SUSPENDED &&
258 current_address_index_ + 1 < addresses_.size()) {
259 // Try falling back to the next address in the list.
260 next_state_ = STATE_OBTAIN_LOCK;
261 ++current_address_index_;
262 result = OK;
263 }
264
265 return result;
266 }
267
268 return result;
269 }
270
271 } // namespace net
272