1*ec63e07aSXin Li // Copyright 2020 Google LLC
2*ec63e07aSXin Li //
3*ec63e07aSXin Li // Licensed under the Apache License, Version 2.0 (the "License");
4*ec63e07aSXin Li // you may not use this file except in compliance with the License.
5*ec63e07aSXin Li // You may obtain a copy of the License at
6*ec63e07aSXin Li //
7*ec63e07aSXin Li // https://www.apache.org/licenses/LICENSE-2.0
8*ec63e07aSXin Li //
9*ec63e07aSXin Li // Unless required by applicable law or agreed to in writing, software
10*ec63e07aSXin Li // distributed under the License is distributed on an "AS IS" BASIS,
11*ec63e07aSXin Li // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*ec63e07aSXin Li // See the License for the specific language governing permissions and
13*ec63e07aSXin Li // limitations under the License.
14*ec63e07aSXin Li
#include "test_utils.h"  // NOLINT(build/include)

#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <cerrno>
#include <charconv>
#include <memory>
#include <string>
#include <system_error>
#include <thread>  // NOLINT(build/c++11)

#include <absl/strings/match.h>
#include "absl/status/statusor.h"
#include "sandboxed_api/util/status_macros.h"
30*ec63e07aSXin Li
31*ec63e07aSXin Li namespace curl::tests {
32*ec63e07aSXin Li
33*ec63e07aSXin Li int CurlTestUtils::port_;
34*ec63e07aSXin Li
35*ec63e07aSXin Li std::thread CurlTestUtils::server_thread_;
36*ec63e07aSXin Li
CurlTestSetUp()37*ec63e07aSXin Li absl::Status CurlTestUtils::CurlTestSetUp() {
38*ec63e07aSXin Li // Initialize sandbox2 and SAPI
39*ec63e07aSXin Li sandbox_ = std::make_unique<curl::CurlSapiSandbox>();
40*ec63e07aSXin Li SAPI_RETURN_IF_ERROR(sandbox_->Init());
41*ec63e07aSXin Li api_ = std::make_unique<curl::CurlApi>(sandbox_.get());
42*ec63e07aSXin Li
43*ec63e07aSXin Li // Initialize curl
44*ec63e07aSXin Li SAPI_ASSIGN_OR_RETURN(curl::CURL * curl_handle, api_->curl_easy_init());
45*ec63e07aSXin Li if (!curl_handle) {
46*ec63e07aSXin Li return absl::UnavailableError("curl_easy_init returned nullptr");
47*ec63e07aSXin Li }
48*ec63e07aSXin Li curl_ = std::make_unique<sapi::v::RemotePtr>(curl_handle);
49*ec63e07aSXin Li
50*ec63e07aSXin Li int curl_code = 0;
51*ec63e07aSXin Li
52*ec63e07aSXin Li // Specify request URL
53*ec63e07aSXin Li sapi::v::ConstCStr sapi_url(kUrl);
54*ec63e07aSXin Li SAPI_ASSIGN_OR_RETURN(
55*ec63e07aSXin Li curl_code, api_->curl_easy_setopt_ptr(curl_.get(), curl::CURLOPT_URL,
56*ec63e07aSXin Li sapi_url.PtrBefore()));
57*ec63e07aSXin Li if (curl_code != curl::CURLE_OK) {
58*ec63e07aSXin Li return absl::UnavailableError(absl::StrCat(
59*ec63e07aSXin Li "curl_easy_setopt_ptr returned with the error code ", curl_code));
60*ec63e07aSXin Li }
61*ec63e07aSXin Li
62*ec63e07aSXin Li // Set port
63*ec63e07aSXin Li SAPI_ASSIGN_OR_RETURN(curl_code, api_->curl_easy_setopt_long(
64*ec63e07aSXin Li curl_.get(), curl::CURLOPT_PORT, port_));
65*ec63e07aSXin Li if (curl_code != curl::CURLE_OK) {
66*ec63e07aSXin Li return absl::UnavailableError(absl::StrCat(
67*ec63e07aSXin Li "curl_easy_setopt_long returned with the error code ", curl_code));
68*ec63e07aSXin Li }
69*ec63e07aSXin Li
70*ec63e07aSXin Li // Generate pointer to the WriteToMemory callback
71*ec63e07aSXin Li void* function_ptr;
72*ec63e07aSXin Li SAPI_RETURN_IF_ERROR(
73*ec63e07aSXin Li sandbox_->rpc_channel()->Symbol("WriteToMemory", &function_ptr));
74*ec63e07aSXin Li sapi::v::RemotePtr remote_function_ptr(function_ptr);
75*ec63e07aSXin Li
76*ec63e07aSXin Li // Set WriteToMemory as the write function
77*ec63e07aSXin Li SAPI_ASSIGN_OR_RETURN(curl_code, api_->curl_easy_setopt_ptr(
78*ec63e07aSXin Li curl_.get(), curl::CURLOPT_WRITEFUNCTION,
79*ec63e07aSXin Li &remote_function_ptr));
80*ec63e07aSXin Li if (curl_code != curl::CURLE_OK) {
81*ec63e07aSXin Li return absl::UnavailableError(absl::StrCat(
82*ec63e07aSXin Li "curl_easy_setopt_ptr returned with the error code ", curl_code));
83*ec63e07aSXin Li }
84*ec63e07aSXin Li
85*ec63e07aSXin Li // Pass memory chunk object to the callback
86*ec63e07aSXin Li chunk_ = std::make_unique<sapi::v::LenVal>(0);
87*ec63e07aSXin Li SAPI_ASSIGN_OR_RETURN(
88*ec63e07aSXin Li curl_code, api_->curl_easy_setopt_ptr(
89*ec63e07aSXin Li curl_.get(), curl::CURLOPT_WRITEDATA, chunk_->PtrBoth()));
90*ec63e07aSXin Li if (curl_code != curl::CURLE_OK) {
91*ec63e07aSXin Li return absl::UnavailableError(absl::StrCat(
92*ec63e07aSXin Li "curl_easy_setopt_ptr returned with the error code ", curl_code));
93*ec63e07aSXin Li }
94*ec63e07aSXin Li
95*ec63e07aSXin Li return absl::OkStatus();
96*ec63e07aSXin Li }
97*ec63e07aSXin Li
CurlTestTearDown()98*ec63e07aSXin Li absl::Status CurlTestUtils::CurlTestTearDown() {
99*ec63e07aSXin Li // Cleanup curl
100*ec63e07aSXin Li return api_->curl_easy_cleanup(curl_.get());
101*ec63e07aSXin Li }
102*ec63e07aSXin Li
PerformRequest()103*ec63e07aSXin Li absl::StatusOr<std::string> CurlTestUtils::PerformRequest() {
104*ec63e07aSXin Li // Perform the request
105*ec63e07aSXin Li SAPI_ASSIGN_OR_RETURN(int curl_code, api_->curl_easy_perform(curl_.get()));
106*ec63e07aSXin Li if (curl_code != curl::CURLE_OK) {
107*ec63e07aSXin Li return absl::UnavailableError(absl::StrCat(
108*ec63e07aSXin Li "curl_easy_perform returned with the error code ", curl_code));
109*ec63e07aSXin Li }
110*ec63e07aSXin Li
111*ec63e07aSXin Li // Get pointer to the memory chunk
112*ec63e07aSXin Li SAPI_RETURN_IF_ERROR(sandbox_->TransferFromSandboxee(chunk_.get()));
113*ec63e07aSXin Li return std::string(reinterpret_cast<char*>(chunk_->GetData()));
114*ec63e07aSXin Li }
115*ec63e07aSXin Li
116*ec63e07aSXin Li namespace {
117*ec63e07aSXin Li
118*ec63e07aSXin Li // Read the socket until str is completely read
ReadUntil(const int socket,const std::string & str,const size_t max_request_size)119*ec63e07aSXin Li std::string ReadUntil(const int socket, const std::string& str,
120*ec63e07aSXin Li const size_t max_request_size) {
121*ec63e07aSXin Li std::string str_read;
122*ec63e07aSXin Li str_read.reserve(max_request_size);
123*ec63e07aSXin Li
124*ec63e07aSXin Li // Read one char at a time until str is suffix of buf
125*ec63e07aSXin Li while (!absl::EndsWith(str_read, str)) {
126*ec63e07aSXin Li char next_char;
127*ec63e07aSXin Li if (str_read.size() >= max_request_size ||
128*ec63e07aSXin Li read(socket, &next_char, 1) < 1) {
129*ec63e07aSXin Li return "";
130*ec63e07aSXin Li }
131*ec63e07aSXin Li str_read += next_char;
132*ec63e07aSXin Li }
133*ec63e07aSXin Li
134*ec63e07aSXin Li return str_read;
135*ec63e07aSXin Li }
136*ec63e07aSXin Li
// Parses the Content-Length value out of a raw HTTP header block.
// Returns 0 when no "Content-Length: " header is present, the parsed value
// when it is a well-formed non-negative integer, and -1 when the value is
// empty, non-numeric, negative, out of range, 64 or more bytes long, or not
// terminated by CRLF.
// Parsing never throws: this runs on the mock server thread, where an
// uncaught exception (as std::stoi raises on bad input) would terminate the
// whole test process.
ssize_t GetContentLength(const std::string& headers) {
  constexpr char kContentLength[] = "Content-Length: ";
  // sizeof counts the terminating NUL
  constexpr size_t kContentLengthSize = sizeof(kContentLength) - 1;
  // Maximum accepted length of the raw header value
  constexpr size_t kMaxValueSize = 64;

  // Find the Content-Length header
  const size_t header_start = headers.find(kContentLength);
  if (header_start == std::string::npos) {
    return 0;  // No Content-Length field
  }

  // Locate the value, which runs up to the end of the line
  const size_t value_start = header_start + kContentLengthSize;
  const size_t value_end = headers.find("\r\n", value_start);
  if (value_end == std::string::npos ||
      value_end - value_start >= kMaxValueSize) {
    return -1;
  }

  // HTTP allows optional whitespace around a field value
  const char* first = headers.data() + value_start;
  const char* last = headers.data() + value_end;
  while (first != last && (*first == ' ' || *first == '\t')) {
    ++first;
  }
  while (last != first && (last[-1] == ' ' || last[-1] == '\t')) {
    --last;
  }

  // The whole value must be a valid integer that fits in ssize_t
  ssize_t length = 0;
  const auto [parse_end, parse_error] = std::from_chars(first, last, length);
  if (parse_error != std::errc() || parse_end != last || length < 0) {
    return -1;
  }
  return length;
}
160*ec63e07aSXin Li
// Reads exactly `content_bytes` bytes from the socket and returns them.
// Returns an empty string if EOF or a read error occurs before all bytes
// arrive. Reads interrupted by a signal (EINTR) are retried.
std::string ReadExact(int socket, size_t content_bytes) {
  std::string content(content_bytes, '\0');
  size_t received = 0;

  while (received < content_bytes) {
    // Asking for at most the remaining byte count can never consume data past
    // the requested content, so there is no need to read one byte at a time.
    const ssize_t bytes_read =
        read(socket, &content[received], content_bytes - received);
    if (bytes_read < 0 && errno == EINTR) {
      continue;
    }
    if (bytes_read <= 0) {
      return "";
    }
    received += static_cast<size_t>(bytes_read);
  }

  return content;
}
177*ec63e07aSXin Li
178*ec63e07aSXin Li // Listen on the socket and answer back to requests
ServerLoop(int listening_socket,sockaddr_in socket_address)179*ec63e07aSXin Li void ServerLoop(int listening_socket, sockaddr_in socket_address) {
180*ec63e07aSXin Li socklen_t socket_address_size = sizeof(socket_address);
181*ec63e07aSXin Li
182*ec63e07aSXin Li // Listen on the socket (maximum 1 connection)
183*ec63e07aSXin Li if (listen(listening_socket, 1) == -1) {
184*ec63e07aSXin Li return;
185*ec63e07aSXin Li }
186*ec63e07aSXin Li
187*ec63e07aSXin Li // Keep accepting connections until the thread is terminated
188*ec63e07aSXin Li // (i.e. server_thread_ is assigned to a new thread or destroyed)
189*ec63e07aSXin Li for (;;) {
190*ec63e07aSXin Li // File descriptor to the connection socket
191*ec63e07aSXin Li // This blocks the thread until a connection is established
192*ec63e07aSXin Li int accepted_socket =
193*ec63e07aSXin Li accept(listening_socket, reinterpret_cast<sockaddr*>(&socket_address),
194*ec63e07aSXin Li reinterpret_cast<socklen_t*>(&socket_address_size));
195*ec63e07aSXin Li if (accepted_socket == -1) {
196*ec63e07aSXin Li return;
197*ec63e07aSXin Li }
198*ec63e07aSXin Li
199*ec63e07aSXin Li constexpr int kMaxRequestSize = 4096;
200*ec63e07aSXin Li
201*ec63e07aSXin Li // Read until the end of the headers
202*ec63e07aSXin Li std::string headers =
203*ec63e07aSXin Li ReadUntil(accepted_socket, "\r\n\r\n", kMaxRequestSize);
204*ec63e07aSXin Li
205*ec63e07aSXin Li if (headers == "") {
206*ec63e07aSXin Li close(accepted_socket);
207*ec63e07aSXin Li return;
208*ec63e07aSXin Li }
209*ec63e07aSXin Li
210*ec63e07aSXin Li // Get the length of the request content
211*ec63e07aSXin Li ssize_t content_length = GetContentLength(headers);
212*ec63e07aSXin Li if (content_length > kMaxRequestSize - headers.size() ||
213*ec63e07aSXin Li content_length < 0) {
214*ec63e07aSXin Li close(accepted_socket);
215*ec63e07aSXin Li return;
216*ec63e07aSXin Li }
217*ec63e07aSXin Li
218*ec63e07aSXin Li // Read the request content
219*ec63e07aSXin Li std::string content = ReadExact(accepted_socket, content_length);
220*ec63e07aSXin Li
221*ec63e07aSXin Li // Prepare a response for the request
222*ec63e07aSXin Li std::string http_response =
223*ec63e07aSXin Li "HTTP/1.1 200 OK\nContent-Type: text/plain\nContent-Length: ";
224*ec63e07aSXin Li
225*ec63e07aSXin Li if (headers.substr(0, 3) == "GET") {
226*ec63e07aSXin Li http_response += "2\r\n\r\nOK";
227*ec63e07aSXin Li
228*ec63e07aSXin Li } else if (headers.substr(0, 4) == "POST") {
229*ec63e07aSXin Li http_response +=
230*ec63e07aSXin Li std::to_string(content.size()) + "\r\n\r\n" + std::string{content};
231*ec63e07aSXin Li
232*ec63e07aSXin Li } else {
233*ec63e07aSXin Li close(accepted_socket);
234*ec63e07aSXin Li return;
235*ec63e07aSXin Li }
236*ec63e07aSXin Li
237*ec63e07aSXin Li // Ignore any errors, the connection will be closed anyway
238*ec63e07aSXin Li write(accepted_socket, http_response.c_str(), http_response.size());
239*ec63e07aSXin Li
240*ec63e07aSXin Li // Close the socket
241*ec63e07aSXin Li close(accepted_socket);
242*ec63e07aSXin Li }
243*ec63e07aSXin Li }
244*ec63e07aSXin Li
245*ec63e07aSXin Li } // namespace
246*ec63e07aSXin Li
StartMockServer()247*ec63e07aSXin Li void CurlTestUtils::StartMockServer() {
248*ec63e07aSXin Li // Get the socket file descriptor
249*ec63e07aSXin Li int listening_socket = socket(AF_INET, SOCK_STREAM, 0);
250*ec63e07aSXin Li
251*ec63e07aSXin Li // Create the socket address object
252*ec63e07aSXin Li // The port is set to 0, meaning that it will be auto assigned
253*ec63e07aSXin Li // Only local connections can access this socket
254*ec63e07aSXin Li sockaddr_in socket_address{AF_INET, 0, htonl(INADDR_LOOPBACK)};
255*ec63e07aSXin Li socklen_t socket_address_size = sizeof(socket_address);
256*ec63e07aSXin Li if (listening_socket == -1) {
257*ec63e07aSXin Li return;
258*ec63e07aSXin Li }
259*ec63e07aSXin Li
260*ec63e07aSXin Li // Bind the file descriptor to the socket address object
261*ec63e07aSXin Li if (bind(listening_socket, reinterpret_cast<sockaddr*>(&socket_address),
262*ec63e07aSXin Li socket_address_size) == -1) {
263*ec63e07aSXin Li return;
264*ec63e07aSXin Li }
265*ec63e07aSXin Li
266*ec63e07aSXin Li // Assign an available port to the socket address object
267*ec63e07aSXin Li if (getsockname(listening_socket,
268*ec63e07aSXin Li reinterpret_cast<sockaddr*>(&socket_address),
269*ec63e07aSXin Li &socket_address_size) == -1) {
270*ec63e07aSXin Li return;
271*ec63e07aSXin Li }
272*ec63e07aSXin Li
273*ec63e07aSXin Li // Get the port number
274*ec63e07aSXin Li port_ = ntohs(socket_address.sin_port);
275*ec63e07aSXin Li
276*ec63e07aSXin Li // Set server_thread_ operation to socket listening
277*ec63e07aSXin Li server_thread_ = std::thread(ServerLoop, listening_socket, socket_address);
278*ec63e07aSXin Li }
279*ec63e07aSXin Li
280*ec63e07aSXin Li } // namespace curl::tests
281