1 /*
2  * Copyright 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "test_channel_transport.h"
18 
19 #include <errno.h>   // for errno, EBADF
20 #include <stddef.h>  // for size_t
21 
22 #include <cstdint>      // for uint8_t
23 #include <cstring>      // for strerror
24 #include <type_traits>  // for remove_extent_t
25 
26 #include "log.h"
27 #include "net/async_data_channel.h"  // for AsyncDataChannel
28 
29 using std::vector;
30 
31 namespace rootcanal {
32 
SetUp(std::shared_ptr<AsyncDataChannelServer> server,ConnectCallback connection_callback)33 bool TestChannelTransport::SetUp(std::shared_ptr<AsyncDataChannelServer> server,
34                                  ConnectCallback connection_callback) {
35   socket_server_ = server;
36   socket_server_->SetOnConnectCallback(connection_callback);
37   socket_server_->StartListening();
38   return socket_server_ != nullptr;
39 }
40 
CleanUp()41 void TestChannelTransport::CleanUp() {
42   socket_server_->StopListening();
43   socket_server_->Close();
44 }
45 
OnCommandReady(AsyncDataChannel * socket,std::function<void (void)> unwatch)46 void TestChannelTransport::OnCommandReady(AsyncDataChannel* socket,
47                                           std::function<void(void)> unwatch) {
48   uint8_t command_name_size = 0;
49   ssize_t bytes_read = socket->Recv(&command_name_size, 1);
50   if (bytes_read != 1) {
51     INFO("Unexpected (command_name_size) bytes_read: {} != {}, {}", bytes_read, 1, strerror(errno));
52     socket->Close();
53   }
54   vector<uint8_t> command_name_raw;
55   command_name_raw.resize(command_name_size);
56   bytes_read = socket->Recv(command_name_raw.data(), command_name_size);
57   if (bytes_read != command_name_size) {
58     INFO("Unexpected (command_name) bytes_read: {} != {}, {}", bytes_read, command_name_size,
59          strerror(errno));
60   }
61   std::string command_name(command_name_raw.begin(), command_name_raw.end());
62 
63   if (command_name == "CLOSE_TEST_CHANNEL" || command_name.empty()) {
64     INFO("Test channel closed");
65     unwatch();
66     socket->Close();
67     return;
68   }
69 
70   uint8_t num_args = 0;
71   bytes_read = socket->Recv(&num_args, 1);
72   if (bytes_read != 1) {
73     INFO("Unexpected (num_args) bytes_read: {} != {}, {}", bytes_read, 1, strerror(errno));
74   }
75   vector<std::string> args;
76   for (uint8_t i = 0; i < num_args; ++i) {
77     uint8_t arg_size = 0;
78     bytes_read = socket->Recv(&arg_size, 1);
79     if (bytes_read != 1) {
80       INFO("Unexpected (arg_size) bytes_read: {} != {}, {}", bytes_read, 1, strerror(errno));
81     }
82     vector<uint8_t> arg;
83     arg.resize(arg_size);
84     bytes_read = socket->Recv(arg.data(), arg_size);
85     if (bytes_read != arg_size) {
86       INFO("Unexpected (arg) bytes_read: {} != {}, {}", bytes_read, arg_size, strerror(errno));
87     }
88     args.push_back(std::string(arg.begin(), arg.end()));
89   }
90 
91   command_handler_(command_name, args);
92 }
93 
SendResponse(std::shared_ptr<AsyncDataChannel> socket,const std::string & response)94 void TestChannelTransport::SendResponse(std::shared_ptr<AsyncDataChannel> socket,
95                                         const std::string& response) {
96   size_t size = response.size();
97   // Cap to 64K
98   if (size > 0xffff) {
99     size = 0xffff;
100   }
101   uint8_t size_buf[4] = {
102           static_cast<uint8_t>(size & 0xff), static_cast<uint8_t>((size >> 8) & 0xff),
103           static_cast<uint8_t>((size >> 16) & 0xff), static_cast<uint8_t>((size >> 24) & 0xff)};
104   ssize_t written = socket->Send(size_buf, 4);
105   if (written == -1 && errno == EBADF) {
106     WARNING("Unable to send a response.  EBADF");
107     return;
108   }
109   ASSERT_LOG(written == 4, "What happened? written = %zd errno = %d", written, errno);
110   written = socket->Send(reinterpret_cast<const uint8_t*>(response.c_str()), size);
111   ASSERT_LOG(written == static_cast<int>(size), "What happened? written = %zd errno = %d", written,
112              errno);
113 }
114 
RegisterCommandHandler(const std::function<void (const std::string &,const std::vector<std::string> &)> & callback)115 void TestChannelTransport::RegisterCommandHandler(
116         const std::function<void(const std::string&, const std::vector<std::string>&)>& callback) {
117   command_handler_ = callback;
118 }
119 
120 }  // namespace rootcanal
121