1 // Copyright 2022 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <atomic>
16 #include <cstdint>
17 #include <fstream>
18 #include <limits>
19 #include <map>
20 #include <memory>
21 #include <numeric>
22 #include <queue>
23 #include <set>
24 #include <string>
25 #include <thread>
26 #include <vector>
27
28 #include <openssl/sha.h>
29
30 #include "absl/memory/memory.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "absl/types/optional.h"
34 #include "absl/types/variant.h"
35
36 #include "src/core/ext/transport/chttp2/transport/huffsyms.h"
37
38 ///////////////////////////////////////////////////////////////////////////////
39 // SHA256 hash handling
40 // We need strong uniqueness checks of some very long strings - so we hash
41 // them with SHA256 and compare.
42 struct Hash {
43 uint8_t bytes[SHA256_DIGEST_LENGTH];
operator ==Hash44 bool operator==(const Hash& other) const {
45 return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) == 0;
46 }
operator <Hash47 bool operator<(const Hash& other) const {
48 return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) < 0;
49 }
ToStringHash50 std::string ToString() const {
51 std::string result;
52 for (int i = 0; i < SHA256_DIGEST_LENGTH; i++) {
53 absl::StrAppend(&result, absl::Hex(bytes[i], absl::kZeroPad2));
54 }
55 return result;
56 }
57 };
58
59 // Given a vector of ints (T), return a Hash object with the sha256
60 template <typename T>
HashVec(const std::vector<T> & v)61 Hash HashVec(const std::vector<T>& v) {
62 Hash h;
63 SHA256(reinterpret_cast<const uint8_t*>(v.data()), v.size() * sizeof(T),
64 h.bytes);
65 return h;
66 }
67
68 ///////////////////////////////////////////////////////////////////////////////
69 // BitQueue
70 // A utility that treats a sequence of bits like a queue
71 class BitQueue {
72 public:
BitQueue(unsigned mask,int len)73 BitQueue(unsigned mask, int len) : mask_(mask), len_(len) {}
BitQueue()74 BitQueue() : BitQueue(0, 0) {}
75
76 // Return the most significant bit (the front of the queue)
Front() const77 int Front() const { return (mask_ >> (len_ - 1)) & 1; }
78 // Pop one bit off the queue
Pop()79 void Pop() {
80 mask_ &= ~(1 << (len_ - 1));
81 len_--;
82 }
Empty() const83 bool Empty() const { return len_ == 0; }
length() const84 int length() const { return len_; }
mask() const85 unsigned mask() const { return mask_; }
86
87 // Text representation of the queue
ToString() const88 std::string ToString() const {
89 return absl::StrCat(absl::Hex(mask_), "/", len_);
90 }
91
92 // Comparisons so that we can use BitQueue as a key in a std::map
operator <(const BitQueue & other) const93 bool operator<(const BitQueue& other) const {
94 return std::tie(mask_, len_) < std::tie(other.mask_, other.len_);
95 }
96
97 private:
98 // The bits
99 unsigned mask_;
100 // How many bits have we
101 int len_;
102 };
103
104 ///////////////////////////////////////////////////////////////////////////////
105 // Symbol sets for the huffman tree
106
107 // A Sym is one symbol in the tree, and the bits that we need to read to decode
108 // that symbol. As we progress through decoding we remove bits from the symbol,
109 // but also condense the number of symbols we're considering.
110 struct Sym {
111 BitQueue bits;
112 int symbol;
113
operator <Sym114 bool operator<(const Sym& other) const {
115 return std::tie(bits, symbol) < std::tie(other.bits, other.symbol);
116 }
117 };
118
119 // A SymSet is all the symbols we're considering at some time
120 using SymSet = std::vector<Sym>;
121
122 // Debug utility to turn a SymSet into a string
SymSetString(const SymSet & syms)123 std::string SymSetString(const SymSet& syms) {
124 std::vector<std::string> parts;
125 for (const Sym& sym : syms) {
126 parts.push_back(absl::StrCat(sym.symbol, ":", sym.bits.ToString()));
127 }
128 return absl::StrJoin(parts, ",");
129 }
130
131 // Initial SymSet - all the symbols [0..256] with their bits initialized from
132 // the http2 static huffman tree.
AllSyms()133 SymSet AllSyms() {
134 SymSet syms;
135 for (int i = 0; i < GRPC_CHTTP2_NUM_HUFFSYMS; i++) {
136 Sym sym;
137 sym.bits =
138 BitQueue(grpc_chttp2_huffsyms[i].bits, grpc_chttp2_huffsyms[i].length);
139 sym.symbol = i;
140 syms.push_back(sym);
141 }
142 return syms;
143 }
144
145 // What whould we do after reading a set of bits?
146 struct ReadActions {
147 // Emit these symbols
148 std::vector<int> emit;
149 // Number of bits that were consumed by the read
150 int consumed;
151 // Remaining SymSet that we need to consider on the next read action
152 SymSet remaining;
153 };
154
155 // Given a SymSet \a pending, read through the bits in \a index and determine
156 // what actions the decoder should take.
157 // allow_multiple controls the behavior should we get to the last bit in pending
158 // and hence know which symbol to emit, but we still have bits in index.
159 // We could either start decoding the next symbol (allow_multiple == true), or
160 // we could stop (allow_multiple == false).
161 // If allow_multiple is true we tend to emit more per read op, but generate
162 // bigger tables.
ActionsFor(BitQueue index,SymSet pending,bool allow_multiple)163 ReadActions ActionsFor(BitQueue index, SymSet pending, bool allow_multiple) {
164 std::vector<int> emit;
165 int len_start = index.length();
166 int len_consume = len_start;
167
168 // We read one bit in index at a time, so whilst we have bits...
169 while (!index.Empty()) {
170 SymSet next_pending;
171 // For each symbol in the pending set
172 for (auto sym : pending) {
173 // If the first bit doesn't match, then that symbol is not part of our
174 // remaining set.
175 if (sym.bits.Front() != index.Front()) continue;
176 sym.bits.Pop();
177 next_pending.push_back(sym);
178 }
179 switch (next_pending.size()) {
180 case 0:
181 // There should be no bit patterns that are undecodable.
182 abort();
183 case 1:
184 // If we have one symbol left, we need to have decoded all of it.
185 if (!next_pending[0].bits.Empty()) abort();
186 // Emit that symbol
187 emit.push_back(next_pending[0].symbol);
188 // Track how many bits we've read.
189 len_consume = index.length() - 1;
190 // If we allow multiple, reprime pending and continue, otherwise stop.
191 if (!allow_multiple) goto done;
192 pending = AllSyms();
193 break;
194 default:
195 pending = std::move(next_pending);
196 break;
197 }
198 // Finished with this bit, continue with next
199 index.Pop();
200 }
201 done:
202 return ReadActions{std::move(emit), len_start - len_consume, pending};
203 }
204
205 ///////////////////////////////////////////////////////////////////////////////
206 // MatchCase
207 // A variant that helps us bunch together related ReadActions
208
// A Matched in a MatchCase says that some number of symbols need to be
// emitted.
struct Matched {
  // How many symbols to emit.
  int emits;

  // Ordered by the number of symbols emitted.
  bool operator<(const Matched& other) const {
    return other.emits > emits;
  }
};
217
218 // Unmatched says we didn't emit anything and we need to keep decoding
219 struct Unmatched {
220 SymSet syms;
221
operator <Unmatched222 bool operator<(const Unmatched& other) const { return syms < other.syms; }
223 };
224
// Emit end of stream.
struct End {
  // All End values are equivalent, so none orders before another.
  bool operator<(End) const {
    return false;
  }
};
229
230 using MatchCase = absl::variant<Matched, Unmatched, End>;
231
232 ///////////////////////////////////////////////////////////////////////////////
233 // Text & numeric helper functions
234
// Prefix every line in \a lines with \a n indents (2 spaces each) and return
// the result.
std::vector<std::string> IndentLines(std::vector<std::string> lines,
                                     int n = 1) {
  const std::string prefix(2 * n, ' ');
  for (std::string& line : lines) {
    line.insert(0, prefix);
  }
  return lines;
}
245
// Given a snake_case_name return a PascalCaseName: underscores are dropped
// and the character following each underscore (and the very first character)
// is upper-cased.
std::string ToPascalCase(const std::string& in) {
  std::string out;
  bool capitalize_next = true;
  for (const char c : in) {
    if (c == '_') {
      capitalize_next = true;
      continue;
    }
    if (capitalize_next) {
      // toupper requires a value representable as unsigned char; passing a
      // negative char directly is undefined behavior.
      out.push_back(static_cast<char>(toupper(static_cast<unsigned char>(c))));
      capitalize_next = false;
    } else {
      out.push_back(c);
    }
  }
  return out;
}
264
// Name of the unsigned integer type with the given number of bits
// (16 -> uint16_t, 32 -> uint32_t).
std::string Uint(int bits) {
  std::string name = "uint";
  name += std::to_string(bits);
  name += "_t";
  return name;
}
267
// Given a maximum value, the width in bits (8, 16 or 32) of the uint used to
// store it.
int TypeBitsForMax(int max) {
  if (max > 65535) return 32;
  if (max > 255) return 16;
  return 8;
}
278
279 // Combine Uint & TypeBitsForMax to make for more concise code
TypeForMax(int max)280 std::string TypeForMax(int max) { return Uint(TypeBitsForMax(max)); }
281
// How many bits are needed to encode a value: the smallest n such that
// x < 2**n. Returns 0 for zero and for negative values.
int BitsForMaxValue(int x) {
  int n = 0;
  // Shift x down instead of shifting 1 up: `1 << n` overflows int once x
  // reaches 2**30.
  for (; x > 0; x >>= 1) {
    n++;
  }
  return n;
}
288
289 ///////////////////////////////////////////////////////////////////////////////
290 // Codegen framework
291 // Some helpers so we don't need to generate all the code linearly, which helps
292 // organize this a little more nicely.
293
// An Item is our primitive for code generation. It produces the lines it
// would like to emit; those lines are handed to a parent item that may add
// more lines or rewrite them, and so on until codegen is complete.
class Item {
 public:
  virtual ~Item() = default;

  // The lines this item wants to emit.
  virtual std::vector<std::string> ToLines() const = 0;

  // All lines joined with newlines, followed by one trailing newline.
  std::string ToString() const {
    const std::vector<std::string> lines = ToLines();
    std::string text;
    for (size_t i = 0; i < lines.size(); ++i) {
      if (i != 0) text.push_back('\n');
      text.append(lines[i]);
    }
    text.push_back('\n');
    return text;
  }
};
using ItemPtr = std::unique_ptr<Item>;
307
308 // An item that emits one line (the one given as an argument!)
309 class String : public Item {
310 public:
String(std::string s)311 explicit String(std::string s) : s_(std::move(s)) {}
ToLines() const312 std::vector<std::string> ToLines() const override { return {s_}; }
313
314 private:
315 std::string s_;
316 };
317
318 // An item that returns a fixed copyright notice and autogenerated note text.
319 class Prelude final : public Item {
320 public:
ToLines() const321 std::vector<std::string> ToLines() const {
322 return {
323 "// Copyright 2022 gRPC authors.",
324 "//",
325 "// Licensed under the Apache License, Version 2.0 (the "
326 "\"License\");",
327 "// you may not use this file except in compliance with the License.",
328 "// You may obtain a copy of the License at",
329 "//",
330 "// http://www.apache.org/licenses/LICENSE-2.0",
331 "//",
332 "// Unless required by applicable law or agreed to in writing, "
333 "software",
334 "// distributed under the License is distributed on an \"AS IS\" "
335 "BASIS,",
336 "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or "
337 "implied.",
338 "// See the License for the specific language governing permissions "
339 "and",
340 "// limitations under the License.",
341 "",
342 std::string(80, '/'),
343 "// This file is autogenerated: see "
344 "tools/codegen/core/gen_huffman_decompressor.cc",
345 ""};
346 }
347 };
348
349 class Switch;
350
351 // A Sink is an Item that we can add more Items to.
352 // At codegen time it calls each of its children in turn and concatenates
353 // their results together.
354 class Sink : public Item {
355 public:
ToLines() const356 std::vector<std::string> ToLines() const override {
357 std::vector<std::string> lines;
358 for (const auto& item : children_) {
359 for (const auto& line : item->ToLines()) {
360 lines.push_back(line);
361 }
362 }
363 return lines;
364 }
365
366 // Add one string to our output.
Add(std::string s)367 void Add(std::string s) {
368 children_.push_back(std::make_unique<String>(std::move(s)));
369 }
370
371 // Add an item of type T to our output (constructing it with args).
372 template <typename T, typename... Args>
Add(Args &&...args)373 T* Add(Args&&... args) {
374 auto v = std::make_unique<T>(std::forward<Args>(args)...);
375 auto* r = v.get();
376 children_.push_back(std::move(v));
377 return r;
378 }
379
380 private:
381 std::vector<ItemPtr> children_;
382 };
383
384 // A sink that indents its lines by one indent (2 spaces)
385 class Indent : public Sink {
386 public:
ToLines() const387 std::vector<std::string> ToLines() const override {
388 return IndentLines(Sink::ToLines());
389 }
390 };
391
392 // A Sink that wraps its lines in a while block
393 class While : public Sink {
394 public:
While(std::string cond)395 explicit While(std::string cond) : cond_(std::move(cond)) {}
ToLines() const396 std::vector<std::string> ToLines() const override {
397 std::vector<std::string> lines;
398 lines.push_back(absl::StrCat("while (", cond_, ") {"));
399 for (const auto& line : IndentLines(Sink::ToLines())) {
400 lines.push_back(line);
401 }
402 lines.push_back("}");
403 return lines;
404 }
405
406 private:
407 std::string cond_;
408 };
409
410 // A switch statement.
411 // Cases can be modified by calling the Case member.
412 // Identical cases are collapsed into 'case X: case Y:' type blocks.
413 class Switch : public Item {
414 public:
415 // \a cond is the condition to place at the head of the switch statement.
416 // eg. "switch (cond) {".
Switch(std::string cond)417 explicit Switch(std::string cond) : cond_(std::move(cond)) {}
ToLines() const418 std::vector<std::string> ToLines() const override {
419 std::map<std::string, std::vector<std::string>> reverse_map;
420 for (const auto& kv : cases_) {
421 reverse_map[kv.second.ToString()].push_back(kv.first);
422 }
423 std::vector<std::string> lines;
424 lines.push_back(absl::StrCat("switch (", cond_, ") {"));
425 for (const auto& kv : reverse_map) {
426 for (const auto& cond : kv.second) {
427 lines.push_back(absl::StrCat(" case ", cond, ":"));
428 }
429 lines.back().append(" {");
430 for (const auto& case_line :
431 IndentLines(cases_.find(kv.second[0])->second.ToLines(), 2)) {
432 lines.push_back(case_line);
433 }
434 lines.push_back(" }");
435 }
436 lines.push_back("}");
437 return lines;
438 }
439
Case(std::string cond)440 Sink* Case(std::string cond) { return &cases_[cond]; }
441
442 private:
443 std::string cond_;
444 std::map<std::string, Sink> cases_;
445 };
446
447 ///////////////////////////////////////////////////////////////////////////////
448 // BuildCtx declaration
449 // Shared state for one code gen attempt
450
451 class TableBuilder;
452 class FunMaker;
453
454 class BuildCtx {
455 public:
BuildCtx(std::vector<int> max_bits_for_depth,Sink * global_fns,Sink * global_decls,Sink * global_values,FunMaker * fun_maker)456 BuildCtx(std::vector<int> max_bits_for_depth, Sink* global_fns,
457 Sink* global_decls, Sink* global_values, FunMaker* fun_maker)
458 : max_bits_for_depth_(std::move(max_bits_for_depth)),
459 global_fns_(global_fns),
460 global_decls_(global_decls),
461 global_values_(global_values),
462 fun_maker_(fun_maker) {}
463
464 void AddStep(SymSet start_syms, int num_bits, bool is_top, bool refill,
465 int depth, Sink* out);
466 void AddMatchBody(TableBuilder* table_builder, std::string index,
467 std::string ofs, const MatchCase& match_case, bool is_top,
468 bool refill, int depth, Sink* out);
469 void AddDone(SymSet start_syms, int num_bits, bool all_ones_so_far,
470 Sink* out);
471
NewId()472 int NewId() { return next_id_++; }
MaxBitsForTop() const473 int MaxBitsForTop() const { return max_bits_for_depth_[0]; }
474
PreviousNameForArtifact(std::string proposed_name,Hash hash)475 absl::optional<std::string> PreviousNameForArtifact(std::string proposed_name,
476 Hash hash) {
477 auto it = arrays_.find(hash);
478 if (it == arrays_.end()) {
479 arrays_.emplace(hash, proposed_name);
480 return absl::nullopt;
481 }
482 return it->second;
483 }
484
global_fns() const485 Sink* global_fns() const { return global_fns_; }
global_decls() const486 Sink* global_decls() const { return global_decls_; }
global_values() const487 Sink* global_values() const { return global_values_; }
488
489 private:
490 const std::vector<int> max_bits_for_depth_;
491 std::map<Hash, std::string> arrays_;
492 int next_id_ = 1;
493 Sink* const global_fns_;
494 Sink* const global_decls_;
495 Sink* const global_values_;
496 FunMaker* const fun_maker_;
497 };
498
499 ///////////////////////////////////////////////////////////////////////////////
500 // TableBuilder
501 // All our magic for building decode tables.
// We have two kinds of tables to generate:
503 // 1. op tables that translate a bit sequence to which decode case we should
504 // execute (and arguments to it), and
505 // 2. emit tables that translate an index given by the op table and tell us
506 // which symbols to emit
507 // Op table format
508 // Our opcodes contain an offset into an emit table, a number of bits consumed
509 // and an operation. The consumed bits are how many of the presented to us bits
510 // we actually took. The operation tells whether to emit some symbols (and how
511 // many) or to keep decoding.
512 // Optimization 1:
513 // op tables are essentially dense maps of bits -> opcode, and it turns out
514 // that *many* of the opcodes repeat across index bits for some of our tables
515 // so for those we split the table into two levels: first level indexes into
516 // a child table, and the child table contains the deduped opcodes.
517 // Optimization 2:
// Emit tables are a big list of uint8_ts, and are indexed into by the op
519 // table (with an offset and length) - since many symbols get repeated, we try
520 // to overlay the symbols in the emit table to reduce the size.
521 // Optimization 3:
522 // We shard the table into some number of slices and use the top bits of the
523 // incoming lookup to select the shard. This tends to allow us to use smaller
524 // types to represent the table, saving on footprint.
525
526 class TableBuilder {
527 public:
TableBuilder(BuildCtx * ctx)528 explicit TableBuilder(BuildCtx* ctx) : ctx_(ctx), id_(ctx->NewId()) {}
529
530 // Append one case to the table
Add(int match_case,std::vector<uint8_t> emit,int consumed_bits)531 void Add(int match_case, std::vector<uint8_t> emit, int consumed_bits) {
532 elems_.push_back({match_case, std::move(emit), consumed_bits});
533 max_consumed_bits_ = std::max(max_consumed_bits_, consumed_bits);
534 max_match_case_ = std::max(max_match_case_, match_case);
535 }
536
537 // Build the table
Build() const538 void Build() const {
539 Choose()->Build(this, BitsForMaxValue(elems_.size() - 1));
540 }
541
542 // Generate a call to the accessor function for the emit table
EmitAccessor(std::string index,std::string offset)543 std::string EmitAccessor(std::string index, std::string offset) {
544 return absl::StrCat("GetEmit", id_, "(", index, ", ", offset, ")");
545 }
546
547 // Generate a call to the accessor function for the op table
OpAccessor(std::string index)548 std::string OpAccessor(std::string index) {
549 return absl::StrCat("GetOp", id_, "(", index, ")");
550 }
551
ConsumeBits() const552 int ConsumeBits() const { return BitsForMaxValue(max_consumed_bits_); }
MatchBits() const553 int MatchBits() const { return BitsForMaxValue(max_match_case_); }
554
555 private:
// One element in the op table.
struct Elem {
  // Which match case this element selects.
  int match_case;
  // Symbols to emit for this element.
  std::vector<uint8_t> emit;
  // Number of bits consumed.
  int consumed_bits;
};
562
563 // A nested slice is one slice of a table using two level lookup
564 // - i.e. we look at an outer table to get an index into the inner table,
565 // and then fetch the result from there.
566 struct NestedSlice {
567 std::vector<uint8_t> emit;
568 std::vector<uint64_t> inner;
569 std::vector<int> outer;
570
571 // Various sizes return number of bits to be generated
572
InnerSizeTableBuilder::NestedSlice573 size_t InnerSize() const {
574 return inner.size() *
575 TypeBitsForMax(*std::max_element(inner.begin(), inner.end()));
576 }
577
OuterSizeTableBuilder::NestedSlice578 size_t OuterSize() const {
579 return outer.size() *
580 TypeBitsForMax(*std::max_element(outer.begin(), outer.end()));
581 }
582
EmitSizeTableBuilder::NestedSlice583 size_t EmitSize() const { return emit.size() * 8; }
584 };
585
586 // A slice is one part of a larger table.
587 struct Slice {
588 std::vector<uint8_t> emit;
589 std::vector<uint64_t> ops;
590
591 // Various sizes return number of bits to be generated
592
OpsSizeTableBuilder::Slice593 size_t OpsSize() const {
594 return ops.size() *
595 TypeBitsForMax(*std::max_element(ops.begin(), ops.end()));
596 }
597
EmitSizeTableBuilder::Slice598 size_t EmitSize() const { return emit.size() * 8; }
599
600 // Given a vector of symbols to emit, return the offset into the emit table
601 // that they're at (adding them to the emit table if necessary).
OffsetOfTableBuilder::Slice602 int OffsetOf(const std::vector<uint8_t>& x) {
603 if (x.empty()) return 0;
604 auto r = std::search(emit.begin(), emit.end(), x.begin(), x.end());
605 if (r == emit.end()) {
606 // look for a partial match @ end
607 for (size_t check_len = x.size() - 1; check_len > 0; check_len--) {
608 if (emit.size() < check_len) continue;
609 bool matches = true;
610 for (size_t i = 0; matches && i < check_len; i++) {
611 if (emit[emit.size() - check_len + i] != x[i]) matches = false;
612 }
613 if (matches) {
614 int offset = emit.size() - check_len;
615 for (size_t i = check_len; i < x.size(); i++) {
616 emit.push_back(x[i]);
617 }
618 return offset;
619 }
620 }
621 // add new
622 int result = emit.size();
623 for (auto v : x) emit.push_back(v);
624 return result;
625 }
626 return r - emit.begin();
627 }
628
629 // Convert this slice to a nested slice.
MakeNestedSliceTableBuilder::Slice630 NestedSlice MakeNestedSlice() const {
631 NestedSlice result;
632 result.emit = emit;
633 std::map<uint64_t, int> op_to_inner;
634 for (auto v : ops) {
635 auto it = op_to_inner.find(v);
636 if (it == op_to_inner.end()) {
637 it = op_to_inner.emplace(v, op_to_inner.size()).first;
638 result.inner.push_back(v);
639 }
640 result.outer.push_back(it->second);
641 }
642 return result;
643 }
644 };
645
646 // An EncodeOption is a potential way of encoding a table.
647 struct EncodeOption {
648 // Overall size (in bits) of the table encoding
649 virtual size_t Size() const = 0;
650 // Generate the code
651 virtual void Build(const TableBuilder* builder, int op_bits) const = 0;
~EncodeOptionTableBuilder::EncodeOption652 virtual ~EncodeOption() {}
653 };
654
655 // NestedTable is a table that uses two level lookup for each slice
656 struct NestedTable : public EncodeOption {
657 std::vector<NestedSlice> slices;
658 int slice_bits;
SizeTableBuilder::NestedTable659 size_t Size() const override {
660 size_t sum = 0;
661 std::vector<Hash> h_emit;
662 std::vector<Hash> h_inner;
663 std::vector<Hash> h_outer;
664 for (size_t i = 0; i < slices.size(); i++) {
665 h_emit.push_back(HashVec(slices[i].emit));
666 h_inner.push_back(HashVec(slices[i].inner));
667 h_outer.push_back(HashVec(slices[i].outer));
668 }
669 std::set<Hash> seen;
670 for (size_t i = 0; i < slices.size(); i++) {
671 // Try to account for deduplication in the size calculation.
672 if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
673 if (seen.count(h_outer[i]) == 0) sum += slices[i].OuterSize();
674 if (seen.count(h_inner[i]) == 0) sum += slices[i].OuterSize();
675 seen.insert(h_emit[i]);
676 seen.insert(h_outer[i]);
677 seen.insert(h_inner[i]);
678 }
679 if (slice_bits != 0) sum += 3 * 64 * slices.size();
680 return sum;
681 }
BuildTableBuilder::NestedTable682 void Build(const TableBuilder* builder, int op_bits) const override {
683 Sink* const global_fns = builder->ctx_->global_fns();
684 Sink* const global_decls = builder->ctx_->global_decls();
685 Sink* const global_values = builder->ctx_->global_values();
686 const int id = builder->id_;
687 std::vector<std::string> lines;
688 const uint64_t max_inner = MaxInner();
689 const uint64_t max_outer = MaxOuter();
690 std::vector<std::unique_ptr<Array>> emit_names;
691 std::vector<std::unique_ptr<Array>> inner_names;
692 std::vector<std::unique_ptr<Array>> outer_names;
693 for (size_t i = 0; i < slices.size(); i++) {
694 emit_names.push_back(builder->GenArray(
695 slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
696 "uint8_t", slices[i].emit, true, global_decls, global_values));
697 inner_names.push_back(builder->GenArray(
698 slice_bits != 0, absl::StrCat("table", id, "_", i, "_inner"),
699 TypeForMax(max_inner), slices[i].inner, true, global_decls,
700 global_values));
701 outer_names.push_back(builder->GenArray(
702 slice_bits != 0, absl::StrCat("table", id, "_", i, "_outer"),
703 TypeForMax(max_outer), slices[i].outer, false, global_decls,
704 global_values));
705 }
706 if (slice_bits == 0) {
707 global_fns->Add(absl::StrCat(
708 "static inline uint64_t GetOp", id, "(size_t i) { return ",
709 inner_names[0]->Index(outer_names[0]->Index("i")), "; }"));
710 global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
711 "(size_t, size_t emit) { return ",
712 emit_names[0]->Index("emit"), "; }"));
713 } else {
714 GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
715 global_values);
716 GenCompound(id, inner_names, "inner", TypeForMax(max_inner),
717 global_decls, global_values);
718 GenCompound(id, outer_names, "outer", TypeForMax(max_outer),
719 global_decls, global_values);
720 global_fns->Add(absl::StrCat(
721 "static inline uint64_t GetOp", id, "(size_t i) { return table", id,
722 "_inner_[i >> ", op_bits - slice_bits, "][table", id,
723 "_outer_[i >> ", op_bits - slice_bits, "][i & 0x",
724 absl::Hex((1 << (op_bits - slice_bits)) - 1), "]]; }"));
725 global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
726 "(size_t i, size_t emit) { return table",
727 id, "_emit_[i >> ", op_bits - slice_bits,
728 "][emit]; }"));
729 }
730 }
MaxInnerTableBuilder::NestedTable731 uint64_t MaxInner() const {
732 uint64_t max_inner = 0;
733 for (size_t i = 0; i < slices.size(); i++) {
734 max_inner = std::max(
735 max_inner,
736 *std::max_element(slices[i].inner.begin(), slices[i].inner.end()));
737 }
738 return max_inner;
739 }
MaxOuterTableBuilder::NestedTable740 int MaxOuter() const {
741 int max_outer = 0;
742 for (size_t i = 0; i < slices.size(); i++) {
743 max_outer = std::max(
744 max_outer,
745 *std::max_element(slices[i].outer.begin(), slices[i].outer.end()));
746 }
747 return max_outer;
748 }
749 };
750
751 // Encoding that uses single level lookup for each slice.
752 struct Table : public EncodeOption {
753 std::vector<Slice> slices;
754 int slice_bits;
SizeTableBuilder::Table755 size_t Size() const override {
756 size_t sum = 0;
757 std::vector<Hash> h_emit;
758 std::vector<Hash> h_ops;
759 for (size_t i = 0; i < slices.size(); i++) {
760 h_emit.push_back(HashVec(slices[i].emit));
761 h_ops.push_back(HashVec(slices[i].ops));
762 }
763 std::set<Hash> seen;
764 for (size_t i = 0; i < slices.size(); i++) {
765 if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
766 if (seen.count(h_ops[i]) == 0) sum += slices[i].OpsSize();
767 seen.insert(h_emit[i]);
768 seen.insert(h_ops[i]);
769 }
770 return sum + 3 * 64 * slices.size();
771 }
BuildTableBuilder::Table772 void Build(const TableBuilder* builder, int op_bits) const override {
773 Sink* const global_fns = builder->ctx_->global_fns();
774 Sink* const global_decls = builder->ctx_->global_decls();
775 Sink* const global_values = builder->ctx_->global_values();
776 uint64_t max_op = MaxOp();
777 const int id = builder->id_;
778 std::vector<std::unique_ptr<Array>> emit_names;
779 std::vector<std::unique_ptr<Array>> ops_names;
780 for (size_t i = 0; i < slices.size(); i++) {
781 emit_names.push_back(builder->GenArray(
782 slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
783 "uint8_t", slices[i].emit, true, global_decls, global_values));
784 ops_names.push_back(builder->GenArray(
785 slice_bits != 0, absl::StrCat("table", id, "_", i, "_ops"),
786 TypeForMax(max_op), slices[i].ops, true, global_decls,
787 global_values));
788 }
789 if (slice_bits == 0) {
790 global_fns->Add(absl::StrCat("static inline uint64_t GetOp", id,
791 "(size_t i) { return ",
792 ops_names[0]->Index("i"), "; }"));
793 global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
794 "(size_t, size_t emit) { return ",
795 emit_names[0]->Index("emit"), "; }"));
796 } else {
797 GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
798 global_values);
799 GenCompound(id, ops_names, "ops", TypeForMax(max_op), global_decls,
800 global_values);
801 global_fns->Add(absl::StrCat(
802 "static inline uint64_t GetOp", id, "(size_t i) { return table", id,
803 "_ops_[i >> ", op_bits - slice_bits, "][i & 0x",
804 absl::Hex((1 << (op_bits - slice_bits)) - 1), "]; }"));
805 global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
806 "(size_t i, size_t emit) { return table",
807 id, "_emit_[i >> ", op_bits - slice_bits,
808 "][emit]; }"));
809 }
810 }
MaxOpTableBuilder::Table811 uint64_t MaxOp() const {
812 uint64_t max_op = 0;
813 for (size_t i = 0; i < slices.size(); i++) {
814 max_op = std::max(max_op, *std::max_element(slices[i].ops.begin(),
815 slices[i].ops.end()));
816 }
817 return max_op;
818 }
819 // Convert to a two-level lookup
MakeNestedTableTableBuilder::Table820 std::unique_ptr<NestedTable> MakeNestedTable() {
821 std::unique_ptr<NestedTable> result(new NestedTable);
822 result->slice_bits = slice_bits;
823 for (const auto& slice : slices) {
824 result->slices.push_back(slice.MakeNestedSlice());
825 }
826 return result;
827 }
828 };
829
830 // Given a number of slices (2**slice_bits), generate a table that uses a
831 // single level lookup for each slice based on our input.
MakeTable(size_t slice_bits) const832 std::unique_ptr<Table> MakeTable(size_t slice_bits) const {
833 std::unique_ptr<Table> table = std::make_unique<Table>();
834 int slices = 1 << slice_bits;
835 table->slices.resize(slices);
836 table->slice_bits = slice_bits;
837 const int pack_consume_bits = ConsumeBits();
838 const int pack_match_bits = MatchBits();
839 for (size_t i = 0; i < slices; i++) {
840 auto& slice = table->slices[i];
841 for (size_t j = 0; j < elems_.size() / slices; j++) {
842 const auto& elem = elems_[i * elems_.size() / slices + j];
843 slice.ops.push_back(elem.consumed_bits |
844 (elem.match_case << pack_consume_bits) |
845 (slice.OffsetOf(elem.emit)
846 << (pack_consume_bits + pack_match_bits)));
847 }
848 }
849 return table;
850 }
851
852 class Array {
853 public:
854 virtual ~Array() = default;
855 virtual std::string Index(absl::string_view value) = 0;
856 virtual std::string ArrayName() = 0;
857 virtual int Cost() = 0;
858 };
859
860 class NamedArray : public Array {
861 public:
NamedArray(std::string name)862 explicit NamedArray(std::string name) : name_(std::move(name)) {}
Index(absl::string_view value)863 std::string Index(absl::string_view value) override {
864 return absl::StrCat(name_, "[", value, "]");
865 }
ArrayName()866 std::string ArrayName() override { return name_; }
Cost()867 int Cost() override { abort(); }
868
869 private:
870 std::string name_;
871 };
872
873 class IdentityArray : public Array {
874 public:
Index(absl::string_view value)875 std::string Index(absl::string_view value) override {
876 return std::string(value);
877 }
ArrayName()878 std::string ArrayName() override { abort(); }
Cost()879 int Cost() override { return 0; }
880 };
881
882 class ConstantArray : public Array {
883 public:
ConstantArray(std::string value)884 explicit ConstantArray(std::string value) : value_(std::move(value)) {}
Index(absl::string_view index)885 std::string Index(absl::string_view index) override {
886 return absl::StrCat("((void)", index, ", ", value_, ")");
887 }
ArrayName()888 std::string ArrayName() override { abort(); }
Cost()889 int Cost() override { return 0; }
890
891 private:
892 std::string value_;
893 };
894
895 class OffsetArray : public Array {
896 public:
OffsetArray(int offset)897 explicit OffsetArray(int offset) : offset_(offset) {}
Index(absl::string_view value)898 std::string Index(absl::string_view value) override {
899 return absl::StrCat(value, " + ", offset_);
900 }
ArrayName()901 std::string ArrayName() override { abort(); }
Cost()902 int Cost() override { return 10; }
903
904 private:
905 int offset_;
906 };
907
908 class LinearDivideArray : public Array {
909 public:
LinearDivideArray(int offset,int divisor)910 LinearDivideArray(int offset, int divisor)
911 : offset_(offset), divisor_(divisor) {}
Index(absl::string_view value)912 std::string Index(absl::string_view value) override {
913 return absl::StrCat(value, "/", divisor_, " + ", offset_);
914 }
ArrayName()915 std::string ArrayName() override { abort(); }
Cost()916 int Cost() override { return 20 + (offset_ != 0 ? 10 : 0); }
917
918 private:
919 int offset_;
920 int divisor_;
921 };
922
923 class TwoElemArray : public Array {
924 public:
TwoElemArray(std::string value0,std::string value1)925 TwoElemArray(std::string value0, std::string value1)
926 : value0_(std::move(value0)), value1_(std::move(value1)) {}
Index(absl::string_view value)927 std::string Index(absl::string_view value) override {
928 return absl::StrCat(value, " ? ", value1_, " : ", value0_);
929 }
ArrayName()930 std::string ArrayName() override { abort(); }
Cost()931 int Cost() override { return 40; }
932
933 private:
934 std::string value0_;
935 std::string value1_;
936 };
937
938 class Composite2Array : public Array {
939 public:
Composite2Array(std::unique_ptr<Array> a,std::unique_ptr<Array> b,int split)940 Composite2Array(std::unique_ptr<Array> a, std::unique_ptr<Array> b,
941 int split)
942 : a_(std::move(a)), b_(std::move(b)), split_(split) {}
Index(absl::string_view value)943 std::string Index(absl::string_view value) override {
944 return absl::StrCat(
945 "(", value, " < ", split_, " ? (", a_->Index(value), ") : (",
946 b_->Index(absl::StrCat("(", value, "-", split_, ")")), "))");
947 }
ArrayName()948 std::string ArrayName() override { abort(); }
Cost()949 int Cost() override { return 40 + a_->Cost() + b_->Cost(); }
950
951 private:
952 std::unique_ptr<Array> a_;
953 std::unique_ptr<Array> b_;
954 int split_;
955 };
956
957 // Helper to generate a compound table (an array of arrays)
GenCompound(int id,const std::vector<std::unique_ptr<Array>> & arrays,std::string ext,std::string type,Sink * global_decls,Sink * global_values)958 static void GenCompound(int id,
959 const std::vector<std::unique_ptr<Array>>& arrays,
960 std::string ext, std::string type, Sink* global_decls,
961 Sink* global_values) {
962 global_decls->Add(absl::StrCat("static const ", type, "* const table", id,
963 "_", ext, "_[", arrays.size(), "];"));
964 global_values->Add(absl::StrCat("const ", type,
965 "* const HuffDecoderCommon::table", id, "_",
966 ext, "_[", arrays.size(), "] = {"));
967 for (const std::unique_ptr<Array>& array : arrays) {
968 global_values->Add(absl::StrCat(" ", array->ArrayName(), ","));
969 }
970 global_values->Add("};");
971 }
972
973 // Try to create a simple function equivalent to a mapping implied by a set of
974 // values.
975 static const int kMaxArrayToFunctionRecursions = 1;
976 template <typename T>
ArrayToFunction(const std::vector<T> & values,int recurse=kMaxArrayToFunctionRecursions)977 static std::unique_ptr<Array> ArrayToFunction(
978 const std::vector<T>& values,
979 int recurse = kMaxArrayToFunctionRecursions) {
980 std::unique_ptr<Array> best = nullptr;
981 auto note_solution = [&best](std::unique_ptr<Array> a) {
982 if (best != nullptr && best->Cost() <= a->Cost()) return;
983 best = std::move(a);
984 };
985 // constant => k,k,k,k,...
986 bool is_constant = true;
987 for (size_t i = 1; i < values.size(); i++) {
988 if (values[i] != values[0]) {
989 is_constant = false;
990 break;
991 }
992 }
993 if (is_constant) {
994 note_solution(std::make_unique<ConstantArray>(absl::StrCat(values[0])));
995 }
996 // identity => 0,1,2,3,...
997 bool is_identity = true;
998 for (size_t i = 0; i < values.size(); i++) {
999 if (values[i] != i) {
1000 is_identity = false;
1001 break;
1002 }
1003 }
1004 if (is_identity) {
1005 note_solution(std::make_unique<IdentityArray>());
1006 }
1007 // offset => k,k+1,k+2,k+3,...
1008 bool is_offset = true;
1009 for (size_t i = 1; i < values.size(); i++) {
1010 if (values[i] - values[0] != i) {
1011 is_offset = false;
1012 break;
1013 }
1014 }
1015 if (is_offset) {
1016 note_solution(std::make_unique<OffsetArray>(values[0]));
1017 }
1018 // offset => k,k,k+1,k+1,...
1019 for (int d = 2; d < 32; d++) {
1020 bool is_linear = true;
1021 for (size_t i = 1; i < values.size(); i++) {
1022 if (values[i] - values[0] != (i / d)) {
1023 is_linear = false;
1024 break;
1025 }
1026 }
1027 if (is_linear) {
1028 note_solution(std::make_unique<LinearDivideArray>(values[0], d));
1029 }
1030 }
1031 // Two items can be resolved with a conditional
1032 if (values.size() == 2) {
1033 note_solution(std::make_unique<TwoElemArray>(absl::StrCat(values[0]),
1034 absl::StrCat(values[1])));
1035 }
1036 if ((recurse > 0 && values.size() >= 6) ||
1037 (recurse == kMaxArrayToFunctionRecursions)) {
1038 for (size_t i = 1; i < values.size() - 1; i++) {
1039 std::vector<T> left(values.begin(), values.begin() + i);
1040 std::vector<T> right(values.begin() + i, values.end());
1041 std::unique_ptr<Array> left_array = ArrayToFunction(left, recurse - 1);
1042 std::unique_ptr<Array> right_array =
1043 ArrayToFunction(right, recurse - 1);
1044 if (left_array && right_array) {
1045 note_solution(std::make_unique<Composite2Array>(
1046 std::move(left_array), std::move(right_array), i));
1047 }
1048 }
1049 }
1050 return best;
1051 }
1052
1053 // Helper to generate an array of values
1054 template <typename T>
GenArray(bool force_array,std::string name,std::string type,const std::vector<T> & values,bool hex,Sink * global_decls,Sink * global_values) const1055 std::unique_ptr<Array> GenArray(bool force_array, std::string name,
1056 std::string type,
1057 const std::vector<T>& values, bool hex,
1058 Sink* global_decls,
1059 Sink* global_values) const {
1060 if (!force_array) {
1061 auto fn = ArrayToFunction(values);
1062 if (fn != nullptr) return fn;
1063 }
1064 auto previous_name = ctx_->PreviousNameForArtifact(name, HashVec(values));
1065 if (previous_name.has_value()) {
1066 return std::make_unique<NamedArray>(absl::StrCat(*previous_name, "_"));
1067 }
1068 std::vector<std::string> elems;
1069 elems.reserve(values.size());
1070 for (const auto& elem : values) {
1071 if (hex) {
1072 if (type == "uint8_t") {
1073 elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad2)));
1074 } else if (type == "uint16_t") {
1075 elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad4)));
1076 } else {
1077 elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad8)));
1078 }
1079 } else {
1080 elems.push_back(absl::StrCat(elem));
1081 }
1082 }
1083 std::string data = absl::StrJoin(elems, ", ");
1084 global_decls->Add(absl::StrCat("static const ", type, " ", name, "_[",
1085 values.size(), "];"));
1086 global_values->Add(absl::StrCat("const ", type, " HuffDecoderCommon::",
1087 name, "_[", values.size(), "] = {"));
1088 global_values->Add(absl::StrCat(" ", data));
1089 global_values->Add("};");
1090 return std::make_unique<NamedArray>(absl::StrCat(name, "_"));
1091 }
1092
1093 // Choose an encoding for this set of tables.
1094 // We try all available values for slice count and choose the one that gives
1095 // the smallest footprint.
Choose() const1096 std::unique_ptr<EncodeOption> Choose() const {
1097 std::unique_ptr<EncodeOption> chosen;
1098 size_t best_size = std::numeric_limits<size_t>::max();
1099 for (size_t slice_bits = 0; (1 << slice_bits) < elems_.size();
1100 slice_bits++) {
1101 auto raw = MakeTable(slice_bits);
1102 size_t raw_size = raw->Size();
1103 auto nested = raw->MakeNestedTable();
1104 size_t nested_size = nested->Size();
1105 if (raw_size < best_size) {
1106 chosen = std::move(raw);
1107 best_size = raw_size;
1108 }
1109 if (nested_size < best_size) {
1110 chosen = std::move(nested);
1111 best_size = nested_size;
1112 }
1113 }
1114 return chosen;
1115 }
1116
1117 BuildCtx* const ctx_;
1118 std::vector<Elem> elems_;
1119 int max_consumed_bits_ = 0;
1120 int max_match_case_ = 0;
1121 const int id_;
1122 };
1123
1124 ///////////////////////////////////////////////////////////////////////////////
1125 // FunMaker
1126 // Handles generating the code for various functions.
1127
1128 class FunMaker {
1129 public:
FunMaker(Sink * sink)1130 explicit FunMaker(Sink* sink) : sink_(sink) {}
1131
1132 // Generate a refill function - that ensures the incoming bitmask has enough
1133 // bits for the next step.
RefillTo(int n)1134 std::string RefillTo(int n) {
1135 if (have_refills_.count(n) == 0) {
1136 have_refills_.insert(n);
1137 auto fn = NewFun(absl::StrCat("RefillTo", n), "bool");
1138 auto s = fn->Add<Switch>("buffer_len_");
1139 for (int i = 0; i < n; i++) {
1140 auto c = s->Case(absl::StrCat(i));
1141 const int bytes_needed = (n - i + 7) / 8;
1142 c->Add(absl::StrCat("return ", ReadBytes(bytes_needed), ";"));
1143 }
1144 fn->Add("return true;");
1145 }
1146 return absl::StrCat("RefillTo", n, "()");
1147 }
1148
1149 // At callsite, generate a call to a new function with base name
1150 // base_name (new functions get a suffix of how many instances of base_name
1151 // there have been).
1152 // Return a sink to fill in the body of the new function.
CallNewFun(std::string base_name,Sink * callsite)1153 Sink* CallNewFun(std::string base_name, Sink* callsite) {
1154 std::string name = absl::StrCat(base_name, have_funs_[base_name]++);
1155 callsite->Add(absl::StrCat(name, "();"));
1156 return NewFun(name, "void");
1157 }
1158
1159 private:
NewFun(std::string name,std::string returns)1160 Sink* NewFun(std::string name, std::string returns) {
1161 sink_->Add(absl::StrCat(returns, " ", name, "() {"));
1162 auto fn = sink_->Add<Indent>();
1163 sink_->Add("}");
1164 return fn;
1165 }
1166
1167 // Bring in some number of bytes from the input stream to our current read
1168 // bits.
ReadBytes(int bytes_needed)1169 std::string ReadBytes(int bytes_needed) {
1170 if (have_reads_.count(bytes_needed) == 0) {
1171 have_reads_.insert(bytes_needed);
1172 auto fn = NewFun(absl::StrCat("Read", bytes_needed), "bool");
1173 fn->Add(absl::StrCat("if (end_ - begin_ < ", bytes_needed,
1174 ") return false;"));
1175 fn->Add(absl::StrCat("buffer_ <<= ", 8 * bytes_needed, ";"));
1176 for (int i = 0; i < bytes_needed; i++) {
1177 fn->Add(absl::StrCat("buffer_ |= static_cast<uint64_t>(*begin_++) << ",
1178 8 * (bytes_needed - i - 1), ";"));
1179 }
1180 fn->Add(absl::StrCat("buffer_len_ += ", 8 * bytes_needed, ";"));
1181 fn->Add("return true;");
1182 }
1183 return absl::StrCat("Read", bytes_needed, "()");
1184 }
1185
1186 std::set<int> have_refills_;
1187 std::set<int> have_reads_;
1188 std::map<std::string, int> have_funs_;
1189 Sink* sink_;
1190 };
1191
1192 ///////////////////////////////////////////////////////////////////////////////
1193 // BuildCtx implementation
1194
AddDone(SymSet start_syms,int num_bits,bool all_ones_so_far,Sink * out)1195 void BuildCtx::AddDone(SymSet start_syms, int num_bits, bool all_ones_so_far,
1196 Sink* out) {
1197 out->Add("done_ = true;");
1198 if (num_bits == 1) {
1199 if (!all_ones_so_far) out->Add("ok_ = false;");
1200 return;
1201 }
1202 // we must have 0 < buffer_len_ < num_bits
1203 auto s = out->Add<Switch>("buffer_len_");
1204 auto c0 = s->Case("0");
1205 if (!all_ones_so_far) c0->Add("ok_ = false;");
1206 c0->Add("return;");
1207 for (int i = 1; i < num_bits; i++) {
1208 auto c = s->Case(absl::StrCat(i));
1209 SymSet maybe;
1210 for (auto sym : start_syms) {
1211 if (sym.bits.length() > i) continue;
1212 maybe.push_back(sym);
1213 }
1214 if (maybe.empty()) {
1215 if (all_ones_so_far) {
1216 c->Add("ok_ = (buffer_ & ((1<<buffer_len_)-1)) == (1<<buffer_len_)-1;");
1217 } else {
1218 c->Add("ok_ = false;");
1219 }
1220 c->Add("return;");
1221 continue;
1222 }
1223 TableBuilder table_builder(this);
1224 enum Cases {
1225 kNoEmitOk,
1226 kFail,
1227 kEmitOk,
1228 };
1229 for (size_t n = 0; n < (1 << i); n++) {
1230 if (all_ones_so_far && n == (1 << i) - 1) {
1231 table_builder.Add(kNoEmitOk, {}, 0);
1232 goto next;
1233 }
1234 for (auto sym : maybe) {
1235 if ((n >> (i - sym.bits.length())) == sym.bits.mask()) {
1236 for (int j = 0; j < (i - sym.bits.length()); j++) {
1237 if ((n & (1 << j)) == 0) {
1238 table_builder.Add(kFail, {}, 0);
1239 goto next;
1240 }
1241 }
1242 table_builder.Add(kEmitOk, {static_cast<uint8_t>(sym.symbol)}, 0);
1243 goto next;
1244 }
1245 }
1246 table_builder.Add(kFail, {}, 0);
1247 next:;
1248 }
1249 table_builder.Build();
1250 c->Add(absl::StrCat("const auto index = buffer_ & ", (1 << i) - 1, ";"));
1251 c->Add(absl::StrCat("const auto op = ", table_builder.OpAccessor("index"),
1252 ";"));
1253 if (table_builder.ConsumeBits() != 0) {
1254 fprintf(stderr, "consume bits = %d\n", table_builder.ConsumeBits());
1255 abort();
1256 }
1257 auto s_fin = c->Add<Switch>(
1258 absl::StrCat("op & ", (1 << table_builder.MatchBits()) - 1));
1259 auto emit_ok = s_fin->Case(absl::StrCat(kEmitOk));
1260 emit_ok->Add(absl::StrCat(
1261 "sink_(",
1262 table_builder.EmitAccessor(
1263 "index", absl::StrCat("op >>", table_builder.MatchBits())),
1264 ");"));
1265 emit_ok->Add("break;");
1266 auto fail = s_fin->Case(absl::StrCat(kFail));
1267 fail->Add("ok_ = false;");
1268 fail->Add("break;");
1269 c->Add("return;");
1270 }
1271 }
1272
AddStep(SymSet start_syms,int num_bits,bool is_top,bool refill,int depth,Sink * out)1273 void BuildCtx::AddStep(SymSet start_syms, int num_bits, bool is_top,
1274 bool refill, int depth, Sink* out) {
1275 TableBuilder table_builder(this);
1276 if (refill) {
1277 out->Add(absl::StrCat("if (!", fun_maker_->RefillTo(num_bits), ") {"));
1278 auto ifblk = out->Add<Indent>();
1279 if (!is_top) {
1280 Sym some = start_syms[0];
1281 auto sym = grpc_chttp2_huffsyms[some.symbol];
1282 int consumed_len = (sym.length - some.bits.length());
1283 uint32_t consumed_mask = sym.bits >> some.bits.length();
1284 bool all_ones_so_far = consumed_mask == ((1 << consumed_len) - 1);
1285 AddDone(start_syms, num_bits, all_ones_so_far,
1286 fun_maker_->CallNewFun("Done", ifblk));
1287 ifblk->Add("return;");
1288 } else {
1289 AddDone(start_syms, num_bits, true,
1290 fun_maker_->CallNewFun("Done", ifblk));
1291 ifblk->Add("break;");
1292 }
1293 out->Add("}");
1294 }
1295 out->Add(absl::StrCat("const auto index = (buffer_ >> (buffer_len_ - ",
1296 num_bits, ")) & 0x", absl::Hex((1 << num_bits) - 1),
1297 ";"));
1298 std::map<MatchCase, int> match_cases;
1299 for (int i = 0; i < (1 << num_bits); i++) {
1300 auto actions = ActionsFor(BitQueue(i, num_bits), start_syms, is_top);
1301 auto add_case = [&match_cases](MatchCase match_case) {
1302 if (match_cases.find(match_case) == match_cases.end()) {
1303 match_cases[match_case] = match_cases.size();
1304 }
1305 return match_cases[match_case];
1306 };
1307 if (actions.emit.size() == 1 && actions.emit[0] == 256) {
1308 table_builder.Add(add_case(End{}), {}, actions.consumed);
1309 } else if (actions.consumed == 0) {
1310 table_builder.Add(add_case(Unmatched{std::move(actions.remaining)}), {},
1311 num_bits);
1312 } else {
1313 std::vector<uint8_t> emit;
1314 for (auto sym : actions.emit) emit.push_back(sym);
1315 table_builder.Add(
1316 add_case(Matched{static_cast<int>(actions.emit.size())}),
1317 std::move(emit), actions.consumed);
1318 }
1319 }
1320 table_builder.Build();
1321 out->Add(
1322 absl::StrCat("const auto op = ", table_builder.OpAccessor("index"), ";"));
1323 out->Add(absl::StrCat("const int consumed = op & ",
1324 (1 << table_builder.ConsumeBits()) - 1, ";"));
1325 out->Add("buffer_len_ -= consumed;");
1326 out->Add(absl::StrCat("const auto emit_ofs = op >> ",
1327 table_builder.ConsumeBits() + table_builder.MatchBits(),
1328 ";"));
1329 if (match_cases.size() == 1) {
1330 AddMatchBody(&table_builder, "index", "emit_ofs",
1331 match_cases.begin()->first, is_top, refill, depth, out);
1332 } else {
1333 auto s = out->Add<Switch>(
1334 absl::StrCat("(op >> ", table_builder.ConsumeBits(), ") & ",
1335 (1 << table_builder.MatchBits()) - 1));
1336 for (auto kv : match_cases) {
1337 auto c = s->Case(absl::StrCat(kv.second));
1338 AddMatchBody(&table_builder, "index", "emit_ofs", kv.first, is_top,
1339 refill, depth, c);
1340 c->Add("break;");
1341 }
1342 }
1343 }
1344
AddMatchBody(TableBuilder * table_builder,std::string index,std::string ofs,const MatchCase & match_case,bool is_top,bool refill,int depth,Sink * out)1345 void BuildCtx::AddMatchBody(TableBuilder* table_builder, std::string index,
1346 std::string ofs, const MatchCase& match_case,
1347 bool is_top, bool refill, int depth, Sink* out) {
1348 if (absl::holds_alternative<End>(match_case)) {
1349 out->Add("begin_ = end_;");
1350 out->Add("buffer_len_ = 0;");
1351 return;
1352 }
1353 if (auto* p = absl::get_if<Unmatched>(&match_case)) {
1354 if (refill) {
1355 int max_bits = 0;
1356 for (auto sym : p->syms) max_bits = std::max(max_bits, sym.bits.length());
1357 AddStep(p->syms,
1358 depth + 1 >= max_bits_for_depth_.size()
1359 ? max_bits
1360 : std::min(max_bits, max_bits_for_depth_[depth + 1]),
1361 false, true, depth + 1,
1362 fun_maker_->CallNewFun("DecodeStep", out));
1363 }
1364 return;
1365 }
1366 const auto& matched = absl::get<Matched>(match_case);
1367 for (int i = 0; i < matched.emits; i++) {
1368 out->Add(absl::StrCat(
1369 "sink_(",
1370 table_builder->EmitAccessor(index, absl::StrCat(ofs, " + ", i)), ");"));
1371 }
1372 }
1373
1374 ///////////////////////////////////////////////////////////////////////////////
1375 // Driver code
1376
// Generated header and source code, as produced by one call to Build().
struct BuildOutput {
  // Full text of the generated header (written to decode_huff.h by main).
  std::string header;
  // Full text of the generated source (written to decode_huff.cc by main).
  std::string source;
};
1382
1383 // Given max_bits_for_depth = {n1,n2,n3,...}
1384 // Build a decoder that first considers n1 bits, then n2, then n3, ...
Build(std::vector<int> max_bits_for_depth)1385 BuildOutput Build(std::vector<int> max_bits_for_depth) {
1386 auto hdr = std::make_unique<Sink>();
1387 auto src = std::make_unique<Sink>();
1388 hdr->Add<Prelude>();
1389 src->Add<Prelude>();
1390 hdr->Add("#ifndef GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_DECODE_HUFF_H");
1391 hdr->Add("#define GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_DECODE_HUFF_H");
1392 src->Add(
1393 "#include \"src/core/ext/transport/chttp2/transport/decode_huff.h\"");
1394 hdr->Add("#include <cstddef>");
1395 hdr->Add("#include <grpc/support/port_platform.h>");
1396 src->Add("#include <grpc/support/port_platform.h>");
1397 hdr->Add("#include <cstdint>");
1398 hdr->Add(
1399 absl::StrCat("// GEOMETRY: ", absl::StrJoin(max_bits_for_depth, ",")));
1400 hdr->Add("namespace grpc_core {");
1401 src->Add("namespace grpc_core {");
1402 hdr->Add("class HuffDecoderCommon {");
1403 hdr->Add(" protected:");
1404 auto global_fns = hdr->Add<Indent>();
1405 hdr->Add(" private:");
1406 auto global_decls = hdr->Add<Indent>();
1407 hdr->Add("};");
1408 hdr->Add(
1409 "template<typename F> class HuffDecoder : public HuffDecoderCommon {");
1410 hdr->Add(" public:");
1411 auto pub = hdr->Add<Indent>();
1412 hdr->Add(" private:");
1413 auto prv = hdr->Add<Indent>();
1414 FunMaker fun_maker(prv->Add<Sink>());
1415 hdr->Add("};");
1416 hdr->Add("} // namespace grpc_core");
1417 hdr->Add("#endif // GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_DECODE_HUFF_H");
1418 auto global_values = src->Add<Indent>();
1419 src->Add("} // namespace grpc_core");
1420 BuildCtx ctx(std::move(max_bits_for_depth), global_fns, global_decls,
1421 global_values, &fun_maker);
1422 // constructor
1423 pub->Add(
1424 "HuffDecoder(F sink, const uint8_t* begin, const uint8_t* end) : "
1425 "sink_(sink), begin_(begin), end_(end) {}");
1426 // members
1427 prv->Add("F sink_;");
1428 prv->Add("const uint8_t* begin_;");
1429 prv->Add("const uint8_t* const end_;");
1430 prv->Add("uint64_t buffer_ = 0;");
1431 prv->Add("int buffer_len_ = 0;");
1432 prv->Add("bool ok_ = true;");
1433 prv->Add("bool done_ = false;");
1434 // main fn
1435 pub->Add("bool Run() {");
1436 auto body = pub->Add<Indent>();
1437 body->Add("while (!done_) {");
1438 ctx.AddStep(AllSyms(), ctx.MaxBitsForTop(), true, true, 0,
1439 body->Add<Indent>());
1440 body->Add("}");
1441 body->Add("return ok_;");
1442 pub->Add("}");
1443 return {hdr->ToString(), src->ToString()};
1444 }
1445
// Generates every candidate max_bits_for_depth sequence for the Build
// function. Each step is between 5 bits (a minimum the original author
// believed HTTP/2 needs) and 16 bits, steps never sum to more than 30, and a
// sequence is complete once no further 5 bit step fits. Sequences of exactly
// max_depth steps are only kept when they sum to exactly 30.
class PermutationBuilder {
 public:
  explicit PermutationBuilder(int max_depth) : max_depth_(max_depth) {}

  // Returns all sequences, in depth first order with smaller steps first.
  std::vector<std::vector<int>> Run() {
    std::vector<int> prefix;
    Extend(prefix, 0);
    return std::move(perms_);
  }

 private:
  // Records `prefix` if it is complete, otherwise tries every permitted next
  // step. `total` is the sum of `prefix`; `prefix` is restored before
  // returning.
  void Extend(std::vector<int>& prefix, int total) {
    constexpr int kTotalBits = 30;
    constexpr int kMinStep = 5;
    constexpr int kMaxStep = 16;
    const size_t depth = prefix.size();
    if (depth > max_depth_) return;
    if (depth == max_depth_ && total != kTotalBits) return;
    if (total + kMinStep > kTotalBits) {
      perms_.push_back(prefix);
      return;
    }
    const int largest_step = std::min(kTotalBits - total, kMaxStep);
    for (int step = kMinStep; step <= largest_step; step++) {
      prefix.push_back(step);
      Extend(prefix, total + step);
      prefix.pop_back();
    }
  }

  const int max_depth_;
  std::vector<std::vector<int>> perms_;
};
1478
// Writes `content` to the file `filename`, replacing any existing content.
// A failure to open or to write the file is reported on stderr and aborts the
// program, so that the generator can not finish "successfully" without having
// produced its output.
void WriteFile(std::string filename, std::string content) {
  std::ofstream ofs(filename);
  if (!ofs.is_open()) {
    fprintf(stderr, "failed to open %s for writing\n", filename.c_str());
    abort();
  }
  ofs << content;
  // close() flushes; a failed flush is reported through the stream state.
  ofs.close();
  if (ofs.fail()) {
    fprintf(stderr, "failed to write %s\n", filename.c_str());
    abort();
  }
}
1484
main(void)1485 int main(void) {
1486 BuildOutput best;
1487 size_t best_len = std::numeric_limits<size_t>::max();
1488 std::vector<std::unique_ptr<BuildOutput>> results;
1489 std::queue<std::thread> threads;
1490 // Generate all permutations of max_bits_for_depth for the Build function.
1491 // Then generate all variations of the code.
1492 for (const auto& perm : PermutationBuilder(30).Run()) {
1493 while (threads.size() > 200) {
1494 threads.front().join();
1495 threads.pop();
1496 }
1497 results.emplace_back(std::make_unique<BuildOutput>());
1498 threads.emplace([perm, r = results.back().get()] { *r = Build(perm); });
1499 }
1500 while (!threads.empty()) {
1501 threads.front().join();
1502 threads.pop();
1503 }
1504 // Choose the variation that generates the least code, weighted towards header
1505 // length
1506 for (auto& r : results) {
1507 size_t l = 5 * r->header.length() + r->source.length();
1508 if (l < best_len) {
1509 best_len = l;
1510 best = std::move(*r);
1511 }
1512 }
1513 WriteFile("src/core/ext/transport/chttp2/transport/decode_huff.h",
1514 best.header);
1515 WriteFile("src/core/ext/transport/chttp2/transport/decode_huff.cc",
1516 best.source);
1517 return 0;
1518 }
1519