xref: /aosp_15_r20/external/ot-br-posix/third_party/Simple-web-server/repo/utility.hpp (revision 4a64e381480ef79f0532b2421e44e6ee336b8e0d)
1 #ifndef SIMPLE_WEB_UTILITY_HPP
2 #define SIMPLE_WEB_UTILITY_HPP
3 
#include "status_code.hpp"
#include <algorithm>
#include <atomic>
#include <cctype>
#include <chrono>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
14 
// SW_DEPRECATED expands to the compiler-specific "deprecated" attribute, or to nothing on
// compilers that are not recognized. A definition provided beforehand takes precedence.
#if !defined(SW_DEPRECATED)
#  if defined(__GNUC__) || defined(__clang__)
#    define SW_DEPRECATED __attribute__((deprecated))
#  elif defined(_MSC_VER)
#    define SW_DEPRECATED __declspec(deprecated)
#  else
#    define SW_DEPRECATED
#  endif
#endif

// SimpleWeb::string_view selects the best available non-owning string type:
// std::string_view from C++17 on, boost::string_ref when Boost may be used,
// and a plain const reference to std::string as the last resort.
#if __cplusplus > 201402L || _MSVC_LANG > 201402L
#include <string_view>
#elif !defined(ASIO_STANDALONE)
#include <boost/utility/string_ref.hpp>
#endif

namespace SimpleWeb {
#if __cplusplus > 201402L || _MSVC_LANG > 201402L
  using string_view = std::string_view;
#elif !defined(ASIO_STANDALONE)
  using string_view = boost::string_ref;
#else
  using string_view = const std::string &;
#endif
} // namespace SimpleWeb
40 
41 namespace SimpleWeb {
/// Returns true if str1 and str2 have the same length and are equal after mapping every
/// character through std::tolower, i.e. letter case is ignored as defined by the current C locale.
inline bool case_insensitive_equal(const std::string &str1, const std::string &str2) noexcept {
  return str1.size() == str2.size() &&
         std::equal(str1.begin(), str1.end(), str2.begin(), [](char a, char b) {
           // std::tolower is only defined for values representable as unsigned char (or EOF).
           // A plain char holding a byte >= 0x80 is negative where char is signed, so convert first.
           return std::tolower(static_cast<unsigned char>(a)) == std::tolower(static_cast<unsigned char>(b));
         });
}

/// Key-equality predicate for CaseInsensitiveMultimap; delegates to case_insensitive_equal().
class CaseInsensitiveEqual {
public:
  bool operator()(const std::string &str1, const std::string &str2) const noexcept {
    return case_insensitive_equal(str1, str2);
  }
};

/// Hash functor for CaseInsensitiveMultimap. Characters are folded with
/// std::tolower(static_cast<unsigned char>(c)) exactly as in case_insensitive_equal(),
/// so strings that compare equal also produce equal hashes.
// Based on https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x/2595226#2595226
class CaseInsensitiveHash {
public:
  std::size_t operator()(const std::string &str) const noexcept {
    std::size_t h = 0;
    std::hash<int> hash;
    for(auto c : str)
      h ^= hash(std::tolower(static_cast<unsigned char>(c))) + 0x9e3779b9 + (h << 6) + (h >> 2);
    return h;
  }
};

/// Multimap from string keys to string values in which key lookup ignores letter case.
using CaseInsensitiveMultimap = std::unordered_multimap<std::string, std::string, CaseInsensitiveHash, CaseInsensitiveEqual>;
67 
/// Percent encoding and decoding
class Percent {
public:
  /// Returns percent-encoded string.
  /// Characters 0-9, A-Z, a-z, '-', '.', '_' and '~' are copied unchanged; every other byte is
  /// written as '%' followed by two uppercase hexadecimal digits.
  static std::string encode(const std::string &value) noexcept {
    static const char hex_chars[] = "0123456789ABCDEF";

    std::string result;
    result.reserve(value.size()); // Minimum size of result

    for(auto chr : value) {
      if((chr >= '0' && chr <= '9') || (chr >= 'A' && chr <= 'Z') || (chr >= 'a' && chr <= 'z') || chr == '-' || chr == '.' || chr == '_' || chr == '~')
        result += chr;
      else {
        auto byte = static_cast<unsigned char>(chr);
        result += '%';
        result += hex_chars[byte >> 4];
        result += hex_chars[byte & 15];
      }
    }

    return result;
  }

  /// Returns percent-decoded string.
  /// "%XY", where X and Y are hexadecimal digits (either case), becomes the byte 0xXY, and '+'
  /// becomes a space. A '%' that is not followed by two hexadecimal digits is copied literally
  /// instead of being turned into an arbitrary byte.
  static std::string decode(const std::string &value) noexcept {
    std::string result;
    result.reserve(value.size() / 3 + (value.size() % 3)); // Minimum size of result

    for(std::size_t i = 0; i < value.size(); ++i) {
      auto chr = value[i];
      if(chr == '%' && i + 2 < value.size()) {
        // Validate both digits ourselves: strtol would accept signs, leading whitespace and
        // partial input ("%-1", "% 1", "%4z"), and yields 0 for "%zz".
        auto high = hex_value(value[i + 1]);
        auto low = hex_value(value[i + 2]);
        if(high >= 0 && low >= 0) {
          result += static_cast<char>(high * 16 + low);
          i += 2;
          continue;
        }
      }
      if(chr == '+')
        result += ' ';
      else
        result += chr;
    }

    return result;
  }

private:
  /// Returns the value (0-15) of hexadecimal digit chr, or -1 if chr is not a hexadecimal digit.
  static int hex_value(char chr) noexcept {
    if(chr >= '0' && chr <= '9')
      return chr - '0';
    if(chr >= 'A' && chr <= 'F')
      return chr - 'A' + 10;
    if(chr >= 'a' && chr <= 'f')
      return chr - 'a' + 10;
    return -1;
  }
};
110 
111   /// Query string creation and parsing
112   class QueryString {
113   public:
114     /// Returns query string created from given field names and values
create(const CaseInsensitiveMultimap & fields)115     static std::string create(const CaseInsensitiveMultimap &fields) noexcept {
116       std::string result;
117 
118       bool first = true;
119       for(auto &field : fields) {
120         result += (!first ? "&" : "") + field.first + '=' + Percent::encode(field.second);
121         first = false;
122       }
123 
124       return result;
125     }
126 
127     /// Returns query keys with percent-decoded values.
parse(const std::string & query_string)128     static CaseInsensitiveMultimap parse(const std::string &query_string) noexcept {
129       CaseInsensitiveMultimap result;
130 
131       if(query_string.empty())
132         return result;
133 
134       std::size_t name_pos = 0;
135       auto name_end_pos = std::string::npos;
136       auto value_pos = std::string::npos;
137       for(std::size_t c = 0; c < query_string.size(); ++c) {
138         if(query_string[c] == '&') {
139           auto name = query_string.substr(name_pos, (name_end_pos == std::string::npos ? c : name_end_pos) - name_pos);
140           if(!name.empty()) {
141             auto value = value_pos == std::string::npos ? std::string() : query_string.substr(value_pos, c - value_pos);
142             result.emplace(std::move(name), Percent::decode(value));
143           }
144           name_pos = c + 1;
145           name_end_pos = std::string::npos;
146           value_pos = std::string::npos;
147         }
148         else if(query_string[c] == '=' && name_end_pos == std::string::npos) {
149           name_end_pos = c;
150           value_pos = c + 1;
151         }
152       }
153       if(name_pos < query_string.size()) {
154         auto name = query_string.substr(name_pos, (name_end_pos == std::string::npos ? std::string::npos : name_end_pos - name_pos));
155         if(!name.empty()) {
156           auto value = value_pos >= query_string.size() ? std::string() : query_string.substr(value_pos);
157           result.emplace(std::move(name), Percent::decode(value));
158         }
159       }
160 
161       return result;
162     }
163   };
164 
165   class HttpHeader {
166   public:
167     /// Parse header fields from stream
parse(std::istream & stream)168     static CaseInsensitiveMultimap parse(std::istream &stream) noexcept {
169       CaseInsensitiveMultimap result;
170       std::string line;
171       std::size_t param_end;
172       while(getline(stream, line) && (param_end = line.find(':')) != std::string::npos) {
173         std::size_t value_start = param_end + 1;
174         while(value_start + 1 < line.size() && line[value_start] == ' ')
175           ++value_start;
176         if(value_start < line.size())
177           result.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - (line.back() == '\r' ? 1 : 0)));
178       }
179       return result;
180     }
181 
182     class FieldValue {
183     public:
184       class SemicolonSeparatedAttributes {
185       public:
186         /// Parse Set-Cookie or Content-Disposition from given header field value.
187         /// Attribute values are percent-decoded.
parse(const std::string & value)188         static CaseInsensitiveMultimap parse(const std::string &value) {
189           CaseInsensitiveMultimap result;
190 
191           std::size_t name_start_pos = std::string::npos;
192           std::size_t name_end_pos = std::string::npos;
193           std::size_t value_start_pos = std::string::npos;
194           for(std::size_t c = 0; c < value.size(); ++c) {
195             if(name_start_pos == std::string::npos) {
196               if(value[c] != ' ' && value[c] != ';')
197                 name_start_pos = c;
198             }
199             else {
200               if(name_end_pos == std::string::npos) {
201                 if(value[c] == ';') {
202                   result.emplace(value.substr(name_start_pos, c - name_start_pos), std::string());
203                   name_start_pos = std::string::npos;
204                 }
205                 else if(value[c] == '=')
206                   name_end_pos = c;
207               }
208               else {
209                 if(value_start_pos == std::string::npos) {
210                   if(value[c] == '"' && c + 1 < value.size())
211                     value_start_pos = c + 1;
212                   else
213                     value_start_pos = c;
214                 }
215                 else if(value[c] == '"' || value[c] == ';') {
216                   result.emplace(value.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(value.substr(value_start_pos, c - value_start_pos)));
217                   name_start_pos = std::string::npos;
218                   name_end_pos = std::string::npos;
219                   value_start_pos = std::string::npos;
220                 }
221               }
222             }
223           }
224           if(name_start_pos != std::string::npos) {
225             if(name_end_pos == std::string::npos)
226               result.emplace(value.substr(name_start_pos), std::string());
227             else if(value_start_pos != std::string::npos) {
228               if(value.back() == '"')
229                 result.emplace(value.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(value.substr(value_start_pos, value.size() - 1)));
230               else
231                 result.emplace(value.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(value.substr(value_start_pos)));
232             }
233           }
234 
235           return result;
236         }
237       };
238     };
239   };
240 
241   class RequestMessage {
242   public:
243     /** Parse request line and header fields from a request stream.
244      *
245      * @param[in]  stream       Stream to parse.
246      * @param[out] method       HTTP method.
247      * @param[out] path         Path from request URI.
248      * @param[out] query_string Query string from request URI.
249      * @param[out] version      HTTP version.
250      * @param[out] header       Header fields.
251      *
252      * @return True if stream is parsed successfully, false if not.
253      */
parse(std::istream & stream,std::string & method,std::string & path,std::string & query_string,std::string & version,CaseInsensitiveMultimap & header)254     static bool parse(std::istream &stream, std::string &method, std::string &path, std::string &query_string, std::string &version, CaseInsensitiveMultimap &header) noexcept {
255       std::string line;
256       std::size_t method_end;
257       if(getline(stream, line) && (method_end = line.find(' ')) != std::string::npos) {
258         method = line.substr(0, method_end);
259 
260         std::size_t query_start = std::string::npos;
261         std::size_t path_and_query_string_end = std::string::npos;
262         for(std::size_t i = method_end + 1; i < line.size(); ++i) {
263           if(line[i] == '?' && (i + 1) < line.size() && query_start == std::string::npos)
264             query_start = i + 1;
265           else if(line[i] == ' ') {
266             path_and_query_string_end = i;
267             break;
268           }
269         }
270         if(path_and_query_string_end != std::string::npos) {
271           if(query_start != std::string::npos) {
272             path = line.substr(method_end + 1, query_start - method_end - 2);
273             query_string = line.substr(query_start, path_and_query_string_end - query_start);
274           }
275           else
276             path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1);
277 
278           std::size_t protocol_end;
279           if((protocol_end = line.find('/', path_and_query_string_end + 1)) != std::string::npos) {
280             if(line.compare(path_and_query_string_end + 1, protocol_end - path_and_query_string_end - 1, "HTTP") != 0)
281               return false;
282             version = line.substr(protocol_end + 1, line.size() - protocol_end - 2);
283           }
284           else
285             return false;
286 
287           header = HttpHeader::parse(stream);
288         }
289         else
290           return false;
291       }
292       else
293         return false;
294       return true;
295     }
296   };
297 
298   class ResponseMessage {
299   public:
300     /** Parse status line and header fields from a response stream.
301      *
302      * @param[in]  stream      Stream to parse.
303      * @param[out] version     HTTP version.
304      * @param[out] status_code HTTP status code.
305      * @param[out] header      Header fields.
306      *
307      * @return True if stream is parsed successfully, false if not.
308      */
parse(std::istream & stream,std::string & version,std::string & status_code,CaseInsensitiveMultimap & header)309     static bool parse(std::istream &stream, std::string &version, std::string &status_code, CaseInsensitiveMultimap &header) noexcept {
310       std::string line;
311       std::size_t version_end;
312       if(getline(stream, line) && (version_end = line.find(' ')) != std::string::npos) {
313         if(5 < line.size())
314           version = line.substr(5, version_end - 5);
315         else
316           return false;
317         if((version_end + 1) < line.size())
318           status_code = line.substr(version_end + 1, line.size() - (version_end + 1) - (line.back() == '\r' ? 1 : 0));
319         else
320           return false;
321 
322         header = HttpHeader::parse(stream);
323       }
324       else
325         return false;
326       return true;
327     }
328   };
329 
/// Date class working with formats specified in RFC 7231 Date/Time Formats
class Date {
public:
  /// Returns the given std::chrono::system_clock::time_point as a string with the following format: Wed, 31 Jul 2019 11:34:23 GMT.
  ///
  /// Returns an empty string if the time point cannot be converted to UTC calendar time.
  /// The most recent result is cached and returned again for time points in the same whole
  /// second; a function-local mutex serializes access to that cache.
  static std::string to_string(const std::chrono::system_clock::time_point time_point) noexcept {
    static std::string result_cache;
    static std::time_t result_cache_time = 0;

    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    // Key the cache on the whole-second value that is actually formatted below. Comparing the
    // difference of two time points instead returned a stale string for time points less than a
    // second apart that lie on different sides of a second boundary (e.g. x.9 s and (x+1).1 s).
    auto epoch_seconds = std::chrono::system_clock::to_time_t(time_point);
    if(!result_cache.empty() && epoch_seconds == result_cache_time)
      return result_cache;

    std::tm tm;
#if defined(_MSC_VER) || defined(__MINGW32__)
    if(gmtime_s(&tm, &epoch_seconds) != 0)
      return {};
    auto utc = &tm;
#else
    auto utc = gmtime_r(&epoch_seconds, &tm);
    if(!utc)
      return {};
#endif

    static const char *const week_days[] = {"Sun, ", "Mon, ", "Tue, ", "Wed, ", "Thu, ", "Fri, ", "Sat, "};
    static const char *const months[] = {" Jan ", " Feb ", " Mar ", " Apr ", " May ", " Jun ",
                                         " Jul ", " Aug ", " Sep ", " Oct ", " Nov ", " Dec "};
    if(utc->tm_wday < 0 || utc->tm_wday > 6 || utc->tm_mon < 0 || utc->tm_mon > 11)
      return {};

    std::string result;
    result.reserve(29);

    // Appends a value in the range 0-99 as exactly two decimal digits.
    auto append_two_digits = [&result](int value) {
      result += static_cast<char>('0' + value / 10);
      result += static_cast<char>('0' + value % 10);
    };

    result += week_days[utc->tm_wday];
    append_two_digits(utc->tm_mday);
    result += months[utc->tm_mon];

    auto year = utc->tm_year + 1900; // Four digits are written; assumes year 1000-9999.
    append_two_digits(year / 100);
    append_two_digits(year % 100);
    result += ' ';

    append_two_digits(utc->tm_hour);
    result += ':';
    append_two_digits(utc->tm_min);
    result += ':';
    append_two_digits(utc->tm_sec);

    result += " GMT";

    // Update the cache only after a successful conversion, so a failed call cannot leave a
    // cached string that belongs to a different second.
    result_cache = result;
    result_cache_time = epoch_seconds;
    return result;
  }
};
413 } // namespace SimpleWeb
414 
// spin_loop_pause() calls the _mm_pause() intrinsic on platforms where it is known to be
// available (SSE2, or MSVC 2013+ targeting x86/x64) and does nothing elsewhere.
// TODO: need verification that the MSVC checks below are correct.
#if defined(__SSE2__)
#include <emmintrin.h>
#elif defined(_MSC_VER) && _MSC_VER >= 1800 && (defined(_M_X64) || defined(_M_IX86))
#include <intrin.h>
#endif

namespace SimpleWeb {
  /// Hint for busy-wait loops; see the comment above for when it is a no-op.
  inline void spin_loop_pause() noexcept {
#if defined(__SSE2__) || (defined(_MSC_VER) && _MSC_VER >= 1800 && (defined(_M_X64) || defined(_M_IX86)))
    _mm_pause();
#endif
  }

  /// Allows, for instance, Asio handlers to be cancelled without stopping asio::io_service.
  /// A handler calls continue_lock() and returns early when it gets nullptr; stop() waits for
  /// all outstanding locks and then makes every later continue_lock() return nullptr.
  class ScopeRunner {
    /// Number of live SharedLock objects, or -1 once stop() has cancelled all scopes.
    std::atomic<long> count;

  public:
    /// Handle returned by continue_lock(); its destructor releases the share it holds on count.
    class SharedLock {
      friend class ScopeRunner;
      std::atomic<long> &scope_count;
      SharedLock(std::atomic<long> &shared_count) noexcept : scope_count(shared_count) {}
      SharedLock &operator=(const SharedLock &) = delete;
      SharedLock(const SharedLock &) = delete;

    public:
      ~SharedLock() noexcept {
        --scope_count;
      }
    };

    ScopeRunner() noexcept : count{0} {}

    /// Returns nullptr if scope should be exited, or a shared lock otherwise.
    /// The shared lock ensures that a potential destructor call is delayed until all locks are released.
    std::unique_ptr<SharedLock> continue_lock() noexcept {
      long observed = count;
      for(;;) {
        if(observed < 0)
          return nullptr; // stop() has already run
        if(count.compare_exchange_weak(observed, observed + 1))
          return std::unique_ptr<SharedLock>(new SharedLock(count));
        spin_loop_pause();
      }
    }

    /// Blocks until all shared locks are released, then prevents future shared locks.
    void stop() noexcept {
      for(long observed = 0; !count.compare_exchange_weak(observed, -1); observed = 0) {
        if(observed < 0)
          return; // already stopped
        spin_loop_pause();
      }
    }
  };
} // namespace SimpleWeb
479 
480 #endif // SIMPLE_WEB_UTILITY_HPP
481