1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdint.h>
18 
19 #include <deque>
20 #include <memory>
21 #include <mutex>
22 #include <string>
23 #include <thread>
24 
25 #include <android-base/logging.h>
26 #include <android-base/stringprintf.h>
27 #include <android-base/thread_annotations.h>
28 
29 #include "adb_unique_fd.h"
30 #include "adb_utils.h"
31 #include "sysdeps.h"
32 #include "transport.h"
33 #include "types.h"
34 
CreateWakeFds(unique_fd * read,unique_fd * write)35 static void CreateWakeFds(unique_fd* read, unique_fd* write) {
36     // TODO: eventfd on linux?
37     int wake_fds[2];
38     int rc = adb_socketpair(wake_fds);
39     set_file_block_mode(wake_fds[0], false);
40     set_file_block_mode(wake_fds[1], false);
41     CHECK_EQ(0, rc);
42     *read = unique_fd(wake_fds[0]);
43     *write = unique_fd(wake_fds[1]);
44 }
45 
46 struct NonblockingFdConnection : public Connection {
NonblockingFdConnectionNonblockingFdConnection47     NonblockingFdConnection(unique_fd fd) : started_(false), fd_(std::move(fd)) {
48         set_file_block_mode(fd_.get(), false);
49         CreateWakeFds(&wake_fd_read_, &wake_fd_write_);
50     }
51 
SetRunningNonblockingFdConnection52     void SetRunning(bool value) {
53         std::lock_guard<std::mutex> lock(run_mutex_);
54         running_ = value;
55     }
56 
IsRunningNonblockingFdConnection57     bool IsRunning() {
58         std::lock_guard<std::mutex> lock(run_mutex_);
59         return running_;
60     }
61 
RunNonblockingFdConnection62     void Run(std::string* error) {
63         SetRunning(true);
64         while (IsRunning()) {
65             adb_pollfd pfds[2] = {
66                 {.fd = fd_.get(), .events = POLLIN},
67                 {.fd = wake_fd_read_.get(), .events = POLLIN},
68             };
69 
70             {
71                 std::lock_guard<std::mutex> lock(this->write_mutex_);
72                 if (!writable_) {
73                     pfds[0].events |= POLLOUT;
74                 }
75             }
76 
77             int rc = adb_poll(pfds, 2, -1);
78             if (rc == -1) {
79                 *error = android::base::StringPrintf("poll failed: %s", strerror(errno));
80                 return;
81             } else if (rc == 0) {
82                 LOG(FATAL) << "poll timed out with an infinite timeout?";
83             }
84 
85             if (pfds[0].revents) {
86                 if ((pfds[0].revents & POLLOUT)) {
87                     std::lock_guard<std::mutex> lock(this->write_mutex_);
88                     if (DispatchWrites() == WriteResult::Error) {
89                         *error = "write failed";
90                         return;
91                     }
92                 }
93 
94                 if (pfds[0].revents & POLLIN) {
95                     // TODO: Should we be getting blocks from a free list?
96                     auto block = IOVector::block_type(MAX_PAYLOAD);
97                     rc = adb_read(fd_.get(), &block[0], block.size());
98                     if (rc == -1) {
99                         *error = std::string("read failed: ") + strerror(errno);
100                         return;
101                     } else if (rc == 0) {
102                         *error = "read failed: EOF";
103                         return;
104                     }
105                     block.resize(rc);
106                     read_buffer_.append(std::move(block));
107 
108                     if (!read_header_ && read_buffer_.size() >= sizeof(amessage)) {
109                         auto header_buf = read_buffer_.take_front(sizeof(amessage)).coalesce();
110                         CHECK_EQ(sizeof(amessage), header_buf.size());
111                         read_header_ = std::make_unique<amessage>();
112                         memcpy(read_header_.get(), header_buf.data(), sizeof(amessage));
113                     }
114 
115                     if (read_header_ && read_buffer_.size() >= read_header_->data_length) {
116                         auto data_chain = read_buffer_.take_front(read_header_->data_length);
117 
118                         // TODO: Make apacket carry around a IOVector instead of coalescing.
119                         auto payload = std::move(data_chain).coalesce();
120                         auto packet = std::make_unique<apacket>();
121                         packet->msg = *read_header_;
122                         packet->payload = std::move(payload);
123                         read_header_ = nullptr;
124                         transport_->HandleRead(std::move(packet));
125                     }
126                 }
127             }
128 
129             if (pfds[1].revents) {
130                 uint64_t buf;
131                 rc = adb_read(wake_fd_read_.get(), &buf, sizeof(buf));
132                 CHECK_EQ(static_cast<int>(sizeof(buf)), rc);
133 
134                 // We were woken up either to add POLLOUT to our events, or to exit.
135                 // Do nothing.
136             }
137         }
138     }
139 
StartNonblockingFdConnection140     bool Start() override final {
141         if (started_.exchange(true)) {
142             LOG(FATAL) << "Connection started multiple times?";
143         }
144 
145         thread_ = std::thread([this]() {
146             std::string error = "connection closed";
147             Run(&error);
148             transport_->HandleError(error);
149         });
150         return true;
151     }
152 
StopNonblockingFdConnection153     void Stop() override final {
154         SetRunning(false);
155         WakeThread();
156         thread_.join();
157     }
158 
DoTlsHandshakeNonblockingFdConnection159     bool DoTlsHandshake(RSA* key, std::string* auth_key) override final {
160         LOG(FATAL) << "Not supported yet";
161         return false;
162     }
163 
WakeThreadNonblockingFdConnection164     void WakeThread() {
165         uint64_t buf = 0;
166         if (TEMP_FAILURE_RETRY(adb_write(wake_fd_write_.get(), &buf, sizeof(buf))) != sizeof(buf)) {
167             LOG(FATAL) << "failed to wake up thread";
168         }
169     }
170 
171     enum class WriteResult {
172         Error,
173         Completed,
174         TryAgain,
175     };
176 
DispatchWritesNonblockingFdConnection177     WriteResult DispatchWrites() REQUIRES(write_mutex_) {
178         CHECK(!write_buffer_.empty());
179         auto iovs = write_buffer_.iovecs();
180         ssize_t rc = adb_writev(fd_.get(), iovs.data(), iovs.size());
181         if (rc == -1) {
182             if (errno == EAGAIN || errno == EWOULDBLOCK) {
183                 writable_ = false;
184                 return WriteResult::TryAgain;
185             }
186 
187             return WriteResult::Error;
188         } else if (rc == 0) {
189             errno = 0;
190             return WriteResult::Error;
191         }
192 
193         write_buffer_.drop_front(rc);
194         writable_ = write_buffer_.empty();
195         if (write_buffer_.empty()) {
196             return WriteResult::Completed;
197         }
198 
199         // There's data left in the range, which means our write returned early.
200         return WriteResult::TryAgain;
201     }
202 
WriteNonblockingFdConnection203     bool Write(std::unique_ptr<apacket> packet) final {
204         std::lock_guard<std::mutex> lock(write_mutex_);
205         const char* header_begin = reinterpret_cast<const char*>(&packet->msg);
206         const char* header_end = header_begin + sizeof(packet->msg);
207         auto header_block = IOVector::block_type(header_begin, header_end);
208         write_buffer_.append(std::move(header_block));
209         if (!packet->payload.empty()) {
210             write_buffer_.append(std::move(packet->payload));
211         }
212 
213         WriteResult result = DispatchWrites();
214         if (result == WriteResult::TryAgain) {
215             WakeThread();
216         }
217         return result != WriteResult::Error;
218     }
219 
220     std::thread thread_;
221 
222     std::atomic<bool> started_;
223     std::mutex run_mutex_;
224     bool running_ GUARDED_BY(run_mutex_);
225 
226     std::unique_ptr<amessage> read_header_;
227     IOVector read_buffer_;
228 
229     unique_fd fd_;
230     unique_fd wake_fd_read_;
231     unique_fd wake_fd_write_;
232 
233     std::mutex write_mutex_;
234     bool writable_ GUARDED_BY(write_mutex_) = true;
235     IOVector write_buffer_ GUARDED_BY(write_mutex_);
236 
237     IOVector incoming_queue_;
238 };
239 
FromFd(unique_fd fd)240 std::unique_ptr<Connection> Connection::FromFd(unique_fd fd) {
241     return std::make_unique<NonblockingFdConnection>(std::move(fd));
242 }
243