1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/server/web_socket_encoder.h"
6
7 #include <limits>
8 #include <string_view>
9 #include <utility>
10
11 #include "base/check.h"
12 #include "base/memory/ptr_util.h"
13 #include "base/strings/strcat.h"
14 #include "base/strings/string_number_conversions.h"
15 #include "net/base/io_buffer.h"
16 #include "net/websockets/websocket_deflate_parameters.h"
17 #include "net/websockets/websocket_extension.h"
18 #include "net/websockets/websocket_extension_parser.h"
19 #include "net/websockets/websocket_frame.h"
20
21 namespace net {
22
23 const char WebSocketEncoder::kClientExtensions[] =
24 "permessage-deflate; client_max_window_bits";
25
26 namespace {
27
// Input/output chunk size handed to every WebSocketInflater created here.
constexpr int kInflaterChunkSize = 16 * 1024;

// Bit layout of the RFC 6455 (hybi-17) base framing header.

// First header byte: FIN, RSV1-3 and the 4-bit opcode.
constexpr unsigned char kFinalBit = 0x80;
constexpr unsigned char kReserved1Bit = 0x40;
constexpr unsigned char kReserved2Bit = 0x20;
constexpr unsigned char kReserved3Bit = 0x10;
constexpr unsigned char kOpCodeMask = 0xF;

// Second header byte: MASK bit and the 7-bit payload length.
constexpr unsigned char kMaskBit = 0x80;
constexpr unsigned char kPayloadLengthMask = 0x7F;

// Values of the 7-bit length field: 0-125 is the length itself, 126 means a
// 2-byte extended length follows, 127 means an 8-byte extended length follows.
constexpr size_t kMaxSingleBytePayloadLength = 125;
constexpr size_t kTwoBytePayloadLengthField = 126;
constexpr size_t kEightBytePayloadLengthField = 127;

// Size of the masking key that follows the header of masked frames.
constexpr size_t kMaskingKeyWidthInBytes = 4;
44
DecodeFrameHybi17(std::string_view frame,bool client_frame,int * bytes_consumed,std::string * output,bool * compressed)45 WebSocket::ParseResult DecodeFrameHybi17(std::string_view frame,
46 bool client_frame,
47 int* bytes_consumed,
48 std::string* output,
49 bool* compressed) {
50 size_t data_length = frame.length();
51 if (data_length < 2)
52 return WebSocket::FRAME_INCOMPLETE;
53
54 const char* buffer_begin = const_cast<char*>(frame.data());
55 const char* p = buffer_begin;
56 const char* buffer_end = p + data_length;
57
58 unsigned char first_byte = *p++;
59 unsigned char second_byte = *p++;
60
61 bool final = (first_byte & kFinalBit) != 0;
62 bool reserved1 = (first_byte & kReserved1Bit) != 0;
63 bool reserved2 = (first_byte & kReserved2Bit) != 0;
64 bool reserved3 = (first_byte & kReserved3Bit) != 0;
65 int op_code = first_byte & kOpCodeMask;
66 bool masked = (second_byte & kMaskBit) != 0;
67 *compressed = reserved1;
68 if (reserved2 || reserved3)
69 return WebSocket::FRAME_ERROR; // Only compression extension is supported.
70
71 bool closed = false;
72 switch (op_code) {
73 case WebSocketFrameHeader::OpCodeEnum::kOpCodeClose:
74 closed = true;
75 break;
76
77 case WebSocketFrameHeader::OpCodeEnum::kOpCodeText:
78 case WebSocketFrameHeader::OpCodeEnum::
79 kOpCodeContinuation: // Treated in the same as kOpCodeText.
80 case WebSocketFrameHeader::OpCodeEnum::kOpCodePing:
81 case WebSocketFrameHeader::OpCodeEnum::kOpCodePong:
82 break;
83
84 case WebSocketFrameHeader::OpCodeEnum::kOpCodeBinary: // We don't support
85 // binary frames yet.
86 default:
87 return WebSocket::FRAME_ERROR;
88 }
89
90 if (client_frame && !masked) // In Hybi-17 spec client MUST mask its frame.
91 return WebSocket::FRAME_ERROR;
92
93 uint64_t payload_length64 = second_byte & kPayloadLengthMask;
94 if (payload_length64 > kMaxSingleBytePayloadLength) {
95 int extended_payload_length_size;
96 if (payload_length64 == kTwoBytePayloadLengthField) {
97 extended_payload_length_size = 2;
98 } else {
99 DCHECK(payload_length64 == kEightBytePayloadLengthField);
100 extended_payload_length_size = 8;
101 }
102 if (buffer_end - p < extended_payload_length_size)
103 return WebSocket::FRAME_INCOMPLETE;
104 payload_length64 = 0;
105 for (int i = 0; i < extended_payload_length_size; ++i) {
106 payload_length64 <<= 8;
107 payload_length64 |= static_cast<unsigned char>(*p++);
108 }
109 }
110
111 size_t actual_masking_key_length = masked ? kMaskingKeyWidthInBytes : 0;
112 static const uint64_t max_payload_length = 0x7FFFFFFFFFFFFFFFull;
113 static size_t max_length = std::numeric_limits<size_t>::max();
114 if (payload_length64 > max_payload_length ||
115 payload_length64 + actual_masking_key_length > max_length) {
116 // WebSocket frame length too large.
117 return WebSocket::FRAME_ERROR;
118 }
119 size_t payload_length = static_cast<size_t>(payload_length64);
120
121 size_t total_length = actual_masking_key_length + payload_length;
122 if (static_cast<size_t>(buffer_end - p) < total_length)
123 return WebSocket::FRAME_INCOMPLETE;
124
125 if (masked) {
126 output->resize(payload_length);
127 const char* masking_key = p;
128 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes);
129 for (size_t i = 0; i < payload_length; ++i) // Unmask the payload.
130 (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes];
131 } else {
132 output->assign(p, p + payload_length);
133 }
134
135 size_t pos = p + actual_masking_key_length + payload_length - buffer_begin;
136 *bytes_consumed = pos;
137
138 if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePing)
139 return WebSocket::FRAME_PING;
140
141 if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePong)
142 return WebSocket::FRAME_PONG;
143
144 if (closed)
145 return WebSocket::FRAME_CLOSE;
146
147 return final ? WebSocket::FRAME_OK_FINAL : WebSocket::FRAME_OK_MIDDLE;
148 }
149
EncodeFrameHybi17(std::string_view message,int masking_key,bool compressed,WebSocketFrameHeader::OpCodeEnum op_code,std::string * output)150 void EncodeFrameHybi17(std::string_view message,
151 int masking_key,
152 bool compressed,
153 WebSocketFrameHeader::OpCodeEnum op_code,
154 std::string* output) {
155 std::vector<char> frame;
156 size_t data_length = message.length();
157
158 int reserved1 = compressed ? kReserved1Bit : 0;
159 frame.push_back(kFinalBit | op_code | reserved1);
160 char mask_key_bit = masking_key != 0 ? kMaskBit : 0;
161 if (data_length <= kMaxSingleBytePayloadLength) {
162 frame.push_back(static_cast<char>(data_length) | mask_key_bit);
163 } else if (data_length <= 0xFFFF) {
164 frame.push_back(kTwoBytePayloadLengthField | mask_key_bit);
165 frame.push_back((data_length & 0xFF00) >> 8);
166 frame.push_back(data_length & 0xFF);
167 } else {
168 frame.push_back(kEightBytePayloadLengthField | mask_key_bit);
169 char extended_payload_length[8];
170 size_t remaining = data_length;
171 // Fill the length into extended_payload_length in the network byte order.
172 for (int i = 0; i < 8; ++i) {
173 extended_payload_length[7 - i] = remaining & 0xFF;
174 remaining >>= 8;
175 }
176 frame.insert(frame.end(), extended_payload_length,
177 extended_payload_length + 8);
178 DCHECK(!remaining);
179 }
180
181 const char* data = const_cast<char*>(message.data());
182 if (masking_key != 0) {
183 const char* mask_bytes = reinterpret_cast<char*>(&masking_key);
184 frame.insert(frame.end(), mask_bytes, mask_bytes + 4);
185 for (size_t i = 0; i < data_length; ++i) // Mask the payload.
186 frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]);
187 } else {
188 frame.insert(frame.end(), data, data + data_length);
189 }
190 *output = std::string(frame.data(), frame.size());
191 }
192
193 } // anonymous namespace
194
195 // static
CreateServer()196 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() {
197 return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
198 }
199
200 // static
CreateServer(const std::string & extensions,WebSocketDeflateParameters * deflate_parameters)201 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer(
202 const std::string& extensions,
203 WebSocketDeflateParameters* deflate_parameters) {
204 WebSocketExtensionParser parser;
205 if (!parser.Parse(extensions)) {
206 // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the
207 // connection.
208 return nullptr;
209 }
210
211 for (const auto& extension : parser.extensions()) {
212 std::string failure_message;
213 WebSocketDeflateParameters offer;
214 if (!offer.Initialize(extension, &failure_message) ||
215 !offer.IsValidAsRequest(&failure_message)) {
216 // We decline unknown / malformed extensions.
217 continue;
218 }
219
220 WebSocketDeflateParameters response = offer;
221 if (offer.is_client_max_window_bits_specified() &&
222 !offer.has_client_max_window_bits_value()) {
223 // We need to choose one value for the response.
224 response.SetClientMaxWindowBits(15);
225 }
226 DCHECK(response.IsValidAsResponse());
227 DCHECK(offer.IsCompatibleWith(response));
228 auto deflater = std::make_unique<WebSocketDeflater>(
229 response.server_context_take_over_mode());
230 auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
231 kInflaterChunkSize);
232 if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) ||
233 !inflater->Initialize(response.PermissiveClientMaxWindowBits())) {
234 // For some reason we cannot accept the parameters.
235 continue;
236 }
237 *deflate_parameters = response;
238 return base::WrapUnique(new WebSocketEncoder(
239 FOR_SERVER, std::move(deflater), std::move(inflater)));
240 }
241
242 // We cannot find an acceptable offer.
243 return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
244 }
245
246 // static
CreateClient(const std::string & response_extensions)247 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateClient(
248 const std::string& response_extensions) {
249 // TODO(yhirano): Add a way to return an error.
250
251 WebSocketExtensionParser parser;
252 if (!parser.Parse(response_extensions)) {
253 // Parse error. Note that there are two cases here.
254 // 1) There is no Sec-WebSocket-Extensions header.
255 // 2) There is a malformed Sec-WebSocketExtensions header.
256 // We should return a deflate-disabled encoder for the former case and
257 // fail the connection for the latter case.
258 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
259 }
260 if (parser.extensions().size() != 1) {
261 // Only permessage-deflate extension is supported.
262 // TODO (yhirano): Fail the connection.
263 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
264 }
265 const auto& extension = parser.extensions()[0];
266 WebSocketDeflateParameters params;
267 std::string failure_message;
268 if (!params.Initialize(extension, &failure_message) ||
269 !params.IsValidAsResponse(&failure_message)) {
270 // TODO (yhirano): Fail the connection.
271 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
272 }
273
274 auto deflater = std::make_unique<WebSocketDeflater>(
275 params.client_context_take_over_mode());
276 auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
277 kInflaterChunkSize);
278 if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) ||
279 !inflater->Initialize(params.PermissiveServerMaxWindowBits())) {
280 // TODO (yhirano): Fail the connection.
281 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
282 }
283
284 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, std::move(deflater),
285 std::move(inflater)));
286 }
287
WebSocketEncoder(Type type,std::unique_ptr<WebSocketDeflater> deflater,std::unique_ptr<WebSocketInflater> inflater)288 WebSocketEncoder::WebSocketEncoder(Type type,
289 std::unique_ptr<WebSocketDeflater> deflater,
290 std::unique_ptr<WebSocketInflater> inflater)
291 : type_(type),
292 deflater_(std::move(deflater)),
293 inflater_(std::move(inflater)) {}
294
295 WebSocketEncoder::~WebSocketEncoder() = default;
296
DecodeFrame(std::string_view frame,int * bytes_consumed,std::string * output)297 WebSocket::ParseResult WebSocketEncoder::DecodeFrame(std::string_view frame,
298 int* bytes_consumed,
299 std::string* output) {
300 bool compressed;
301 std::string current_output;
302 WebSocket::ParseResult result = DecodeFrameHybi17(
303 frame, type_ == FOR_SERVER, bytes_consumed, ¤t_output, &compressed);
304 switch (result) {
305 case WebSocket::FRAME_OK_FINAL:
306 case WebSocket::FRAME_OK_MIDDLE: {
307 if (continuation_message_frames_.empty())
308 is_current_message_compressed_ = compressed;
309 continuation_message_frames_.push_back(current_output);
310
311 if (result == WebSocket::FRAME_OK_FINAL) {
312 *output = base::StrCat(continuation_message_frames_);
313 continuation_message_frames_.clear();
314 if (is_current_message_compressed_ && !Inflate(output)) {
315 return WebSocket::FRAME_ERROR;
316 }
317 }
318 break;
319 }
320
321 case WebSocket::FRAME_PING:
322 *output = current_output;
323 break;
324
325 default:
326 // This function doesn't need special handling for other parse results.
327 break;
328 }
329
330 return result;
331 }
332
EncodeTextFrame(std::string_view frame,int masking_key,std::string * output)333 void WebSocketEncoder::EncodeTextFrame(std::string_view frame,
334 int masking_key,
335 std::string* output) {
336 std::string compressed;
337 constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeText;
338 if (Deflate(frame, &compressed))
339 EncodeFrameHybi17(compressed, masking_key, true, op_code, output);
340 else
341 EncodeFrameHybi17(frame, masking_key, false, op_code, output);
342 }
343
EncodeCloseFrame(std::string_view frame,int masking_key,std::string * output)344 void WebSocketEncoder::EncodeCloseFrame(std::string_view frame,
345 int masking_key,
346 std::string* output) {
347 constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeClose;
348 EncodeFrameHybi17(frame, masking_key, false, op_code, output);
349 }
350
EncodePongFrame(std::string_view frame,int masking_key,std::string * output)351 void WebSocketEncoder::EncodePongFrame(std::string_view frame,
352 int masking_key,
353 std::string* output) {
354 constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodePong;
355 EncodeFrameHybi17(frame, masking_key, false, op_code, output);
356 }
357
Inflate(std::string * message)358 bool WebSocketEncoder::Inflate(std::string* message) {
359 if (!inflater_)
360 return false;
361 if (!inflater_->AddBytes(message->data(), message->length()))
362 return false;
363 if (!inflater_->Finish())
364 return false;
365
366 std::vector<char> output;
367 while (inflater_->CurrentOutputSize() > 0) {
368 scoped_refptr<IOBufferWithSize> chunk =
369 inflater_->GetOutput(inflater_->CurrentOutputSize());
370 if (!chunk.get())
371 return false;
372 output.insert(output.end(), chunk->data(), chunk->data() + chunk->size());
373 }
374
375 *message =
376 output.size() ? std::string(output.data(), output.size()) : std::string();
377 return true;
378 }
379
Deflate(std::string_view message,std::string * output)380 bool WebSocketEncoder::Deflate(std::string_view message, std::string* output) {
381 if (!deflater_)
382 return false;
383 if (!deflater_->AddBytes(message.data(), message.length())) {
384 deflater_->Finish();
385 return false;
386 }
387 if (!deflater_->Finish())
388 return false;
389 scoped_refptr<IOBufferWithSize> buffer =
390 deflater_->GetOutput(deflater_->CurrentOutputSize());
391 if (!buffer.get())
392 return false;
393 *output = std::string(buffer->data(), buffer->size());
394 return true;
395 }
396
397 } // namespace net
398