1 // Copyright 2022 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <atomic>
#include <cctype>
#include <cstdint>
#include <fstream>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <queue>
#include <set>
#include <string>
#include <thread>
#include <vector>

#include <openssl/sha.h>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"

#include "src/core/ext/transport/chttp2/transport/huffsyms.h"
37 
38 ///////////////////////////////////////////////////////////////////////////////
39 // SHA256 hash handling
40 // We need strong uniqueness checks of some very long strings - so we hash
41 // them with SHA256 and compare.
42 struct Hash {
43   uint8_t bytes[SHA256_DIGEST_LENGTH];
operator ==Hash44   bool operator==(const Hash& other) const {
45     return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) == 0;
46   }
operator <Hash47   bool operator<(const Hash& other) const {
48     return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) < 0;
49   }
ToStringHash50   std::string ToString() const {
51     std::string result;
52     for (int i = 0; i < SHA256_DIGEST_LENGTH; i++) {
53       absl::StrAppend(&result, absl::Hex(bytes[i], absl::kZeroPad2));
54     }
55     return result;
56   }
57 };
58 
59 // Given a vector of ints (T), return a Hash object with the sha256
60 template <typename T>
HashVec(const std::vector<T> & v)61 Hash HashVec(const std::vector<T>& v) {
62   Hash h;
63   SHA256(reinterpret_cast<const uint8_t*>(v.data()), v.size() * sizeof(T),
64          h.bytes);
65   return h;
66 }
67 
68 ///////////////////////////////////////////////////////////////////////////////
69 // BitQueue
70 // A utility that treats a sequence of bits like a queue
71 class BitQueue {
72  public:
BitQueue(unsigned mask,int len)73   BitQueue(unsigned mask, int len) : mask_(mask), len_(len) {}
BitQueue()74   BitQueue() : BitQueue(0, 0) {}
75 
76   // Return the most significant bit (the front of the queue)
Front() const77   int Front() const { return (mask_ >> (len_ - 1)) & 1; }
78   // Pop one bit off the queue
Pop()79   void Pop() {
80     mask_ &= ~(1 << (len_ - 1));
81     len_--;
82   }
Empty() const83   bool Empty() const { return len_ == 0; }
length() const84   int length() const { return len_; }
mask() const85   unsigned mask() const { return mask_; }
86 
87   // Text representation of the queue
ToString() const88   std::string ToString() const {
89     return absl::StrCat(absl::Hex(mask_), "/", len_);
90   }
91 
92   // Comparisons so that we can use BitQueue as a key in a std::map
operator <(const BitQueue & other) const93   bool operator<(const BitQueue& other) const {
94     return std::tie(mask_, len_) < std::tie(other.mask_, other.len_);
95   }
96 
97  private:
98   // The bits
99   unsigned mask_;
100   // How many bits have we
101   int len_;
102 };
103 
104 ///////////////////////////////////////////////////////////////////////////////
105 // Symbol sets for the huffman tree
106 
107 // A Sym is one symbol in the tree, and the bits that we need to read to decode
108 // that symbol. As we progress through decoding we remove bits from the symbol,
109 // but also condense the number of symbols we're considering.
110 struct Sym {
111   BitQueue bits;
112   int symbol;
113 
operator <Sym114   bool operator<(const Sym& other) const {
115     return std::tie(bits, symbol) < std::tie(other.bits, other.symbol);
116   }
117 };
118 
// A SymSet is all the symbols we're considering at some time.
// It is a plain vector of Sym; it compares lexicographically through
// Sym::operator<.
using SymSet = std::vector<Sym>;
121 
122 // Debug utility to turn a SymSet into a string
SymSetString(const SymSet & syms)123 std::string SymSetString(const SymSet& syms) {
124   std::vector<std::string> parts;
125   for (const Sym& sym : syms) {
126     parts.push_back(absl::StrCat(sym.symbol, ":", sym.bits.ToString()));
127   }
128   return absl::StrJoin(parts, ",");
129 }
130 
131 // Initial SymSet - all the symbols [0..256] with their bits initialized from
132 // the http2 static huffman tree.
AllSyms()133 SymSet AllSyms() {
134   SymSet syms;
135   for (int i = 0; i < GRPC_CHTTP2_NUM_HUFFSYMS; i++) {
136     Sym sym;
137     sym.bits =
138         BitQueue(grpc_chttp2_huffsyms[i].bits, grpc_chttp2_huffsyms[i].length);
139     sym.symbol = i;
140     syms.push_back(sym);
141   }
142   return syms;
143 }
144 
// What should we do after reading a set of bits?
// This is the result type of ActionsFor(), which brace-initializes it in
// member order (emit, consumed, remaining).
struct ReadActions {
  // Emit these symbols
  std::vector<int> emit;
  // Number of bits that were consumed by the read
  int consumed;
  // Remaining SymSet that we need to consider on the next read action
  SymSet remaining;
};
154 
155 // Given a SymSet \a pending, read through the bits in \a index and determine
156 // what actions the decoder should take.
157 // allow_multiple controls the behavior should we get to the last bit in pending
158 // and hence know which symbol to emit, but we still have bits in index.
159 // We could either start decoding the next symbol (allow_multiple == true), or
160 // we could stop (allow_multiple == false).
161 // If allow_multiple is true we tend to emit more per read op, but generate
162 // bigger tables.
ActionsFor(BitQueue index,SymSet pending,bool allow_multiple)163 ReadActions ActionsFor(BitQueue index, SymSet pending, bool allow_multiple) {
164   std::vector<int> emit;
165   int len_start = index.length();
166   int len_consume = len_start;
167 
168   // We read one bit in index at a time, so whilst we have bits...
169   while (!index.Empty()) {
170     SymSet next_pending;
171     // For each symbol in the pending set
172     for (auto sym : pending) {
173       // If the first bit doesn't match, then that symbol is not part of our
174       // remaining set.
175       if (sym.bits.Front() != index.Front()) continue;
176       sym.bits.Pop();
177       next_pending.push_back(sym);
178     }
179     switch (next_pending.size()) {
180       case 0:
181         // There should be no bit patterns that are undecodable.
182         abort();
183       case 1:
184         // If we have one symbol left, we need to have decoded all of it.
185         if (!next_pending[0].bits.Empty()) abort();
186         // Emit that symbol
187         emit.push_back(next_pending[0].symbol);
188         // Track how many bits we've read.
189         len_consume = index.length() - 1;
190         // If we allow multiple, reprime pending and continue, otherwise stop.
191         if (!allow_multiple) goto done;
192         pending = AllSyms();
193         break;
194       default:
195         pending = std::move(next_pending);
196         break;
197     }
198     // Finished with this bit, continue with next
199     index.Pop();
200   }
201 done:
202   return ReadActions{std::move(emit), len_start - len_consume, pending};
203 }
204 
205 ///////////////////////////////////////////////////////////////////////////////
206 // MatchCase
207 // A variant that helps us bunch together related ReadActions
208 
// A Matched in a MatchCase indicates that we need to emit some number of
// symbols
struct Matched {
  // How many symbols to emit.
  int emits;

  // Matched values are ordered by their emit count.
  bool operator<(const Matched& other) const {
    return other.emits > emits;
  }
};
217 
218 // Unmatched says we didn't emit anything and we need to keep decoding
219 struct Unmatched {
220   SymSet syms;
221 
operator <Unmatched222   bool operator<(const Unmatched& other) const { return syms < other.syms; }
223 };
224 
// Emit end of stream
struct End {
  // End carries no data, so no End orders before another.
  bool operator<(End /*other*/) const {
    return false;
  }
};
229 
// A MatchCase holds exactly one of Matched, Unmatched or End.
using MatchCase = absl::variant<Matched, Unmatched, End>;
231 
232 ///////////////////////////////////////////////////////////////////////////////
233 // Text & numeric helper functions
234 
// Given a vector of lines, prefix every line with n indents (an indent is
// 2 spaces) and return the result.
std::vector<std::string> IndentLines(std::vector<std::string> lines,
                                     int n = 1) {
  const std::string prefix(2 * n, ' ');
  for (std::string& line : lines) {
    line.insert(0, prefix);
  }
  return lines;
}
245 
// Given a snake_case_name return a PascalCaseName.
// Underscores are dropped; the first character and every character that
// follows an underscore is upper-cased, all other characters are copied as is.
std::string ToPascalCase(const std::string& in) {
  std::string out;
  out.reserve(in.size());
  bool next_upper = true;
  for (const char c : in) {
    if (c == '_') {
      next_upper = true;
      continue;
    }
    if (next_upper) {
      // toupper requires a value representable as unsigned char (or EOF);
      // passing a negative plain char is undefined behavior, so convert first.
      out.push_back(
          static_cast<char>(std::toupper(static_cast<unsigned char>(c))));
      next_upper = false;
    } else {
      out.push_back(c);
    }
  }
  return out;
}
264 
265 // Return a uint type for some number of bits (16 -> uint16_t, 32 -> uint32_t)
Uint(int bits)266 std::string Uint(int bits) { return absl::StrCat("uint", bits, "_t"); }
267 
// Given a maximum value, how many bits wide a uint must be to store it:
// 8 up to 255, 16 up to 65535, and 32 for anything larger.
int TypeBitsForMax(int max) {
  if (max > 65535) return 32;
  if (max > 255) return 16;
  return 8;
}
278 
279 // Combine Uint & TypeBitsForMax to make for more concise code
TypeForMax(int max)280 std::string TypeForMax(int max) { return Uint(TypeBitsForMax(max)); }
281 
// How many bits are needed to encode a value: the smallest n such that
// x < 2^n. Returns 0 for x <= 0.
int BitsForMaxValue(int x) {
  int n = 0;
  // Count bits by shifting the value down instead of shifting a one up:
  // comparing against (1 << n) overflows int once x reaches 2^30.
  for (unsigned rest = x > 0 ? static_cast<unsigned>(x) : 0u; rest != 0;
       rest >>= 1) {
    n++;
  }
  return n;
}
288 
289 ///////////////////////////////////////////////////////////////////////////////
290 // Codegen framework
291 // Some helpers so we don't need to generate all the code linearly, which helps
292 // organize this a little more nicely.
293 
294 // An Item is our primitive for code generation, it can generate some lines
295 // that it would like to emit - those lines are fed to a parent item that might
296 // generate more lines or mutate the ones we return, and so on until codegen
297 // is complete.
298 class Item {
299  public:
300   virtual ~Item() = default;
301   virtual std::vector<std::string> ToLines() const = 0;
ToString() const302   std::string ToString() const {
303     return absl::StrCat(absl::StrJoin(ToLines(), "\n"), "\n");
304   }
305 };
// Owning pointer to an Item.
using ItemPtr = std::unique_ptr<Item>;
307 
308 // An item that emits one line (the one given as an argument!)
309 class String : public Item {
310  public:
String(std::string s)311   explicit String(std::string s) : s_(std::move(s)) {}
ToLines() const312   std::vector<std::string> ToLines() const override { return {s_}; }
313 
314  private:
315   std::string s_;
316 };
317 
318 // An item that returns a fixed copyright notice and autogenerated note text.
319 class Prelude final : public Item {
320  public:
ToLines() const321   std::vector<std::string> ToLines() const {
322     return {
323         "// Copyright 2022 gRPC authors.",
324         "//",
325         "// Licensed under the Apache License, Version 2.0 (the "
326         "\"License\");",
327         "// you may not use this file except in compliance with the License.",
328         "// You may obtain a copy of the License at",
329         "//",
330         "//     http://www.apache.org/licenses/LICENSE-2.0",
331         "//",
332         "// Unless required by applicable law or agreed to in writing, "
333         "software",
334         "// distributed under the License is distributed on an \"AS IS\" "
335         "BASIS,",
336         "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or "
337         "implied.",
338         "// See the License for the specific language governing permissions "
339         "and",
340         "// limitations under the License.",
341         "",
342         std::string(80, '/'),
343         "// This file is autogenerated: see "
344         "tools/codegen/core/gen_huffman_decompressor.cc",
345         ""};
346   }
347 };
348 
349 class Switch;
350 
351 // A Sink is an Item that we can add more Items to.
352 // At codegen time it calls each of its children in turn and concatenates
353 // their results together.
354 class Sink : public Item {
355  public:
ToLines() const356   std::vector<std::string> ToLines() const override {
357     std::vector<std::string> lines;
358     for (const auto& item : children_) {
359       for (const auto& line : item->ToLines()) {
360         lines.push_back(line);
361       }
362     }
363     return lines;
364   }
365 
366   // Add one string to our output.
Add(std::string s)367   void Add(std::string s) {
368     children_.push_back(std::make_unique<String>(std::move(s)));
369   }
370 
371   // Add an item of type T to our output (constructing it with args).
372   template <typename T, typename... Args>
Add(Args &&...args)373   T* Add(Args&&... args) {
374     auto v = std::make_unique<T>(std::forward<Args>(args)...);
375     auto* r = v.get();
376     children_.push_back(std::move(v));
377     return r;
378   }
379 
380  private:
381   std::vector<ItemPtr> children_;
382 };
383 
384 // A sink that indents its lines by one indent (2 spaces)
385 class Indent : public Sink {
386  public:
ToLines() const387   std::vector<std::string> ToLines() const override {
388     return IndentLines(Sink::ToLines());
389   }
390 };
391 
392 // A Sink that wraps its lines in a while block
393 class While : public Sink {
394  public:
While(std::string cond)395   explicit While(std::string cond) : cond_(std::move(cond)) {}
ToLines() const396   std::vector<std::string> ToLines() const override {
397     std::vector<std::string> lines;
398     lines.push_back(absl::StrCat("while (", cond_, ") {"));
399     for (const auto& line : IndentLines(Sink::ToLines())) {
400       lines.push_back(line);
401     }
402     lines.push_back("}");
403     return lines;
404   }
405 
406  private:
407   std::string cond_;
408 };
409 
410 // A switch statement.
411 // Cases can be modified by calling the Case member.
412 // Identical cases are collapsed into 'case X: case Y:' type blocks.
413 class Switch : public Item {
414  public:
415   // \a cond is the condition to place at the head of the switch statement.
416   // eg. "switch (cond) {".
Switch(std::string cond)417   explicit Switch(std::string cond) : cond_(std::move(cond)) {}
ToLines() const418   std::vector<std::string> ToLines() const override {
419     std::map<std::string, std::vector<std::string>> reverse_map;
420     for (const auto& kv : cases_) {
421       reverse_map[kv.second.ToString()].push_back(kv.first);
422     }
423     std::vector<std::string> lines;
424     lines.push_back(absl::StrCat("switch (", cond_, ") {"));
425     for (const auto& kv : reverse_map) {
426       for (const auto& cond : kv.second) {
427         lines.push_back(absl::StrCat("  case ", cond, ":"));
428       }
429       lines.back().append(" {");
430       for (const auto& case_line :
431            IndentLines(cases_.find(kv.second[0])->second.ToLines(), 2)) {
432         lines.push_back(case_line);
433       }
434       lines.push_back("  }");
435     }
436     lines.push_back("}");
437     return lines;
438   }
439 
Case(std::string cond)440   Sink* Case(std::string cond) { return &cases_[cond]; }
441 
442  private:
443   std::string cond_;
444   std::map<std::string, Sink> cases_;
445 };
446 
447 ///////////////////////////////////////////////////////////////////////////////
448 // BuildCtx declaration
449 // Shared state for one code gen attempt
450 
451 class TableBuilder;
452 class FunMaker;
453 
454 class BuildCtx {
455  public:
BuildCtx(std::vector<int> max_bits_for_depth,Sink * global_fns,Sink * global_decls,Sink * global_values,FunMaker * fun_maker)456   BuildCtx(std::vector<int> max_bits_for_depth, Sink* global_fns,
457            Sink* global_decls, Sink* global_values, FunMaker* fun_maker)
458       : max_bits_for_depth_(std::move(max_bits_for_depth)),
459         global_fns_(global_fns),
460         global_decls_(global_decls),
461         global_values_(global_values),
462         fun_maker_(fun_maker) {}
463 
464   void AddStep(SymSet start_syms, int num_bits, bool is_top, bool refill,
465                int depth, Sink* out);
466   void AddMatchBody(TableBuilder* table_builder, std::string index,
467                     std::string ofs, const MatchCase& match_case, bool is_top,
468                     bool refill, int depth, Sink* out);
469   void AddDone(SymSet start_syms, int num_bits, bool all_ones_so_far,
470                Sink* out);
471 
NewId()472   int NewId() { return next_id_++; }
MaxBitsForTop() const473   int MaxBitsForTop() const { return max_bits_for_depth_[0]; }
474 
PreviousNameForArtifact(std::string proposed_name,Hash hash)475   absl::optional<std::string> PreviousNameForArtifact(std::string proposed_name,
476                                                       Hash hash) {
477     auto it = arrays_.find(hash);
478     if (it == arrays_.end()) {
479       arrays_.emplace(hash, proposed_name);
480       return absl::nullopt;
481     }
482     return it->second;
483   }
484 
global_fns() const485   Sink* global_fns() const { return global_fns_; }
global_decls() const486   Sink* global_decls() const { return global_decls_; }
global_values() const487   Sink* global_values() const { return global_values_; }
488 
489  private:
490   const std::vector<int> max_bits_for_depth_;
491   std::map<Hash, std::string> arrays_;
492   int next_id_ = 1;
493   Sink* const global_fns_;
494   Sink* const global_decls_;
495   Sink* const global_values_;
496   FunMaker* const fun_maker_;
497 };
498 
499 ///////////////////////////////////////////////////////////////////////////////
500 // TableBuilder
501 // All our magic for building decode tables.
// We have two kinds of tables to generate:
503 // 1. op tables that translate a bit sequence to which decode case we should
504 //    execute (and arguments to it), and
505 // 2. emit tables that translate an index given by the op table and tell us
506 //    which symbols to emit
507 // Op table format
508 // Our opcodes contain an offset into an emit table, a number of bits consumed
509 // and an operation. The consumed bits are how many of the presented to us bits
510 // we actually took. The operation tells whether to emit some symbols (and how
511 // many) or to keep decoding.
512 // Optimization 1:
513 // op tables are essentially dense maps of bits -> opcode, and it turns out
514 // that *many* of the opcodes repeat across index bits for some of our tables
515 // so for those we split the table into two levels: first level indexes into
516 // a child table, and the child table contains the deduped opcodes.
517 // Optimization 2:
// Emit tables are a big list of uint8_ts, and are indexed into by the op
519 // table (with an offset and length) - since many symbols get repeated, we try
520 // to overlay the symbols in the emit table to reduce the size.
521 // Optimization 3:
522 // We shard the table into some number of slices and use the top bits of the
523 // incoming lookup to select the shard. This tends to allow us to use smaller
524 // types to represent the table, saving on footprint.
525 
526 class TableBuilder {
527  public:
TableBuilder(BuildCtx * ctx)528   explicit TableBuilder(BuildCtx* ctx) : ctx_(ctx), id_(ctx->NewId()) {}
529 
530   // Append one case to the table
Add(int match_case,std::vector<uint8_t> emit,int consumed_bits)531   void Add(int match_case, std::vector<uint8_t> emit, int consumed_bits) {
532     elems_.push_back({match_case, std::move(emit), consumed_bits});
533     max_consumed_bits_ = std::max(max_consumed_bits_, consumed_bits);
534     max_match_case_ = std::max(max_match_case_, match_case);
535   }
536 
537   // Build the table
Build() const538   void Build() const {
539     Choose()->Build(this, BitsForMaxValue(elems_.size() - 1));
540   }
541 
542   // Generate a call to the accessor function for the emit table
EmitAccessor(std::string index,std::string offset)543   std::string EmitAccessor(std::string index, std::string offset) {
544     return absl::StrCat("GetEmit", id_, "(", index, ", ", offset, ")");
545   }
546 
547   // Generate a call to the accessor function for the op table
OpAccessor(std::string index)548   std::string OpAccessor(std::string index) {
549     return absl::StrCat("GetOp", id_, "(", index, ")");
550   }
551 
ConsumeBits() const552   int ConsumeBits() const { return BitsForMaxValue(max_consumed_bits_); }
MatchBits() const553   int MatchBits() const { return BitsForMaxValue(max_match_case_); }
554 
555  private:
  // One element in the op table.
  // Add() brace-initializes elements in member order, so the order of the
  // fields matters.
  struct Elem {
    // The match case given to Add()
    int match_case;
    // The symbols to emit, given to Add()
    std::vector<uint8_t> emit;
    // The consumed bit count given to Add()
    int consumed_bits;
  };
562 
563   // A nested slice is one slice of a table using two level lookup
564   // - i.e. we look at an outer table to get an index into the inner table,
565   //   and then fetch the result from there.
566   struct NestedSlice {
567     std::vector<uint8_t> emit;
568     std::vector<uint64_t> inner;
569     std::vector<int> outer;
570 
571     // Various sizes return number of bits to be generated
572 
InnerSizeTableBuilder::NestedSlice573     size_t InnerSize() const {
574       return inner.size() *
575              TypeBitsForMax(*std::max_element(inner.begin(), inner.end()));
576     }
577 
OuterSizeTableBuilder::NestedSlice578     size_t OuterSize() const {
579       return outer.size() *
580              TypeBitsForMax(*std::max_element(outer.begin(), outer.end()));
581     }
582 
EmitSizeTableBuilder::NestedSlice583     size_t EmitSize() const { return emit.size() * 8; }
584   };
585 
586   // A slice is one part of a larger table.
587   struct Slice {
588     std::vector<uint8_t> emit;
589     std::vector<uint64_t> ops;
590 
591     // Various sizes return number of bits to be generated
592 
OpsSizeTableBuilder::Slice593     size_t OpsSize() const {
594       return ops.size() *
595              TypeBitsForMax(*std::max_element(ops.begin(), ops.end()));
596     }
597 
EmitSizeTableBuilder::Slice598     size_t EmitSize() const { return emit.size() * 8; }
599 
600     // Given a vector of symbols to emit, return the offset into the emit table
601     // that they're at (adding them to the emit table if necessary).
OffsetOfTableBuilder::Slice602     int OffsetOf(const std::vector<uint8_t>& x) {
603       if (x.empty()) return 0;
604       auto r = std::search(emit.begin(), emit.end(), x.begin(), x.end());
605       if (r == emit.end()) {
606         // look for a partial match @ end
607         for (size_t check_len = x.size() - 1; check_len > 0; check_len--) {
608           if (emit.size() < check_len) continue;
609           bool matches = true;
610           for (size_t i = 0; matches && i < check_len; i++) {
611             if (emit[emit.size() - check_len + i] != x[i]) matches = false;
612           }
613           if (matches) {
614             int offset = emit.size() - check_len;
615             for (size_t i = check_len; i < x.size(); i++) {
616               emit.push_back(x[i]);
617             }
618             return offset;
619           }
620         }
621         // add new
622         int result = emit.size();
623         for (auto v : x) emit.push_back(v);
624         return result;
625       }
626       return r - emit.begin();
627     }
628 
629     // Convert this slice to a nested slice.
MakeNestedSliceTableBuilder::Slice630     NestedSlice MakeNestedSlice() const {
631       NestedSlice result;
632       result.emit = emit;
633       std::map<uint64_t, int> op_to_inner;
634       for (auto v : ops) {
635         auto it = op_to_inner.find(v);
636         if (it == op_to_inner.end()) {
637           it = op_to_inner.emplace(v, op_to_inner.size()).first;
638           result.inner.push_back(v);
639         }
640         result.outer.push_back(it->second);
641       }
642       return result;
643     }
644   };
645 
646   // An EncodeOption is a potential way of encoding a table.
647   struct EncodeOption {
648     // Overall size (in bits) of the table encoding
649     virtual size_t Size() const = 0;
650     // Generate the code
651     virtual void Build(const TableBuilder* builder, int op_bits) const = 0;
~EncodeOptionTableBuilder::EncodeOption652     virtual ~EncodeOption() {}
653   };
654 
655   // NestedTable is a table that uses two level lookup for each slice
656   struct NestedTable : public EncodeOption {
657     std::vector<NestedSlice> slices;
658     int slice_bits;
SizeTableBuilder::NestedTable659     size_t Size() const override {
660       size_t sum = 0;
661       std::vector<Hash> h_emit;
662       std::vector<Hash> h_inner;
663       std::vector<Hash> h_outer;
664       for (size_t i = 0; i < slices.size(); i++) {
665         h_emit.push_back(HashVec(slices[i].emit));
666         h_inner.push_back(HashVec(slices[i].inner));
667         h_outer.push_back(HashVec(slices[i].outer));
668       }
669       std::set<Hash> seen;
670       for (size_t i = 0; i < slices.size(); i++) {
671         // Try to account for deduplication in the size calculation.
672         if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
673         if (seen.count(h_outer[i]) == 0) sum += slices[i].OuterSize();
674         if (seen.count(h_inner[i]) == 0) sum += slices[i].OuterSize();
675         seen.insert(h_emit[i]);
676         seen.insert(h_outer[i]);
677         seen.insert(h_inner[i]);
678       }
679       if (slice_bits != 0) sum += 3 * 64 * slices.size();
680       return sum;
681     }
BuildTableBuilder::NestedTable682     void Build(const TableBuilder* builder, int op_bits) const override {
683       Sink* const global_fns = builder->ctx_->global_fns();
684       Sink* const global_decls = builder->ctx_->global_decls();
685       Sink* const global_values = builder->ctx_->global_values();
686       const int id = builder->id_;
687       std::vector<std::string> lines;
688       const uint64_t max_inner = MaxInner();
689       const uint64_t max_outer = MaxOuter();
690       std::vector<std::unique_ptr<Array>> emit_names;
691       std::vector<std::unique_ptr<Array>> inner_names;
692       std::vector<std::unique_ptr<Array>> outer_names;
693       for (size_t i = 0; i < slices.size(); i++) {
694         emit_names.push_back(builder->GenArray(
695             slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
696             "uint8_t", slices[i].emit, true, global_decls, global_values));
697         inner_names.push_back(builder->GenArray(
698             slice_bits != 0, absl::StrCat("table", id, "_", i, "_inner"),
699             TypeForMax(max_inner), slices[i].inner, true, global_decls,
700             global_values));
701         outer_names.push_back(builder->GenArray(
702             slice_bits != 0, absl::StrCat("table", id, "_", i, "_outer"),
703             TypeForMax(max_outer), slices[i].outer, false, global_decls,
704             global_values));
705       }
706       if (slice_bits == 0) {
707         global_fns->Add(absl::StrCat(
708             "static inline uint64_t GetOp", id, "(size_t i) { return ",
709             inner_names[0]->Index(outer_names[0]->Index("i")), "; }"));
710         global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
711                                      "(size_t, size_t emit) { return ",
712                                      emit_names[0]->Index("emit"), "; }"));
713       } else {
714         GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
715                     global_values);
716         GenCompound(id, inner_names, "inner", TypeForMax(max_inner),
717                     global_decls, global_values);
718         GenCompound(id, outer_names, "outer", TypeForMax(max_outer),
719                     global_decls, global_values);
720         global_fns->Add(absl::StrCat(
721             "static inline uint64_t GetOp", id, "(size_t i) { return table", id,
722             "_inner_[i >> ", op_bits - slice_bits, "][table", id,
723             "_outer_[i >> ", op_bits - slice_bits, "][i & 0x",
724             absl::Hex((1 << (op_bits - slice_bits)) - 1), "]]; }"));
725         global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
726                                      "(size_t i, size_t emit) { return table",
727                                      id, "_emit_[i >> ", op_bits - slice_bits,
728                                      "][emit]; }"));
729       }
730     }
MaxInnerTableBuilder::NestedTable731     uint64_t MaxInner() const {
732       uint64_t max_inner = 0;
733       for (size_t i = 0; i < slices.size(); i++) {
734         max_inner = std::max(
735             max_inner,
736             *std::max_element(slices[i].inner.begin(), slices[i].inner.end()));
737       }
738       return max_inner;
739     }
MaxOuterTableBuilder::NestedTable740     int MaxOuter() const {
741       int max_outer = 0;
742       for (size_t i = 0; i < slices.size(); i++) {
743         max_outer = std::max(
744             max_outer,
745             *std::max_element(slices[i].outer.begin(), slices[i].outer.end()));
746       }
747       return max_outer;
748     }
749   };
750 
751   // Encoding that uses single level lookup for each slice.
752   struct Table : public EncodeOption {
753     std::vector<Slice> slices;
754     int slice_bits;
SizeTableBuilder::Table755     size_t Size() const override {
756       size_t sum = 0;
757       std::vector<Hash> h_emit;
758       std::vector<Hash> h_ops;
759       for (size_t i = 0; i < slices.size(); i++) {
760         h_emit.push_back(HashVec(slices[i].emit));
761         h_ops.push_back(HashVec(slices[i].ops));
762       }
763       std::set<Hash> seen;
764       for (size_t i = 0; i < slices.size(); i++) {
765         if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
766         if (seen.count(h_ops[i]) == 0) sum += slices[i].OpsSize();
767         seen.insert(h_emit[i]);
768         seen.insert(h_ops[i]);
769       }
770       return sum + 3 * 64 * slices.size();
771     }
BuildTableBuilder::Table772     void Build(const TableBuilder* builder, int op_bits) const override {
773       Sink* const global_fns = builder->ctx_->global_fns();
774       Sink* const global_decls = builder->ctx_->global_decls();
775       Sink* const global_values = builder->ctx_->global_values();
776       uint64_t max_op = MaxOp();
777       const int id = builder->id_;
778       std::vector<std::unique_ptr<Array>> emit_names;
779       std::vector<std::unique_ptr<Array>> ops_names;
780       for (size_t i = 0; i < slices.size(); i++) {
781         emit_names.push_back(builder->GenArray(
782             slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
783             "uint8_t", slices[i].emit, true, global_decls, global_values));
784         ops_names.push_back(builder->GenArray(
785             slice_bits != 0, absl::StrCat("table", id, "_", i, "_ops"),
786             TypeForMax(max_op), slices[i].ops, true, global_decls,
787             global_values));
788       }
789       if (slice_bits == 0) {
790         global_fns->Add(absl::StrCat("static inline uint64_t GetOp", id,
791                                      "(size_t i) { return ",
792                                      ops_names[0]->Index("i"), "; }"));
793         global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
794                                      "(size_t, size_t emit) { return ",
795                                      emit_names[0]->Index("emit"), "; }"));
796       } else {
797         GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
798                     global_values);
799         GenCompound(id, ops_names, "ops", TypeForMax(max_op), global_decls,
800                     global_values);
801         global_fns->Add(absl::StrCat(
802             "static inline uint64_t GetOp", id, "(size_t i) { return table", id,
803             "_ops_[i >> ", op_bits - slice_bits, "][i & 0x",
804             absl::Hex((1 << (op_bits - slice_bits)) - 1), "]; }"));
805         global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
806                                      "(size_t i, size_t emit) { return table",
807                                      id, "_emit_[i >> ", op_bits - slice_bits,
808                                      "][emit]; }"));
809       }
810     }
MaxOpTableBuilder::Table811     uint64_t MaxOp() const {
812       uint64_t max_op = 0;
813       for (size_t i = 0; i < slices.size(); i++) {
814         max_op = std::max(max_op, *std::max_element(slices[i].ops.begin(),
815                                                     slices[i].ops.end()));
816       }
817       return max_op;
818     }
819     // Convert to a two-level lookup
MakeNestedTableTableBuilder::Table820     std::unique_ptr<NestedTable> MakeNestedTable() {
821       std::unique_ptr<NestedTable> result(new NestedTable);
822       result->slice_bits = slice_bits;
823       for (const auto& slice : slices) {
824         result->slices.push_back(slice.MakeNestedSlice());
825       }
826       return result;
827     }
828   };
829 
830   // Given a number of slices (2**slice_bits), generate a table that uses a
831   // single level lookup for each slice based on our input.
MakeTable(size_t slice_bits) const832   std::unique_ptr<Table> MakeTable(size_t slice_bits) const {
833     std::unique_ptr<Table> table = std::make_unique<Table>();
834     int slices = 1 << slice_bits;
835     table->slices.resize(slices);
836     table->slice_bits = slice_bits;
837     const int pack_consume_bits = ConsumeBits();
838     const int pack_match_bits = MatchBits();
839     for (size_t i = 0; i < slices; i++) {
840       auto& slice = table->slices[i];
841       for (size_t j = 0; j < elems_.size() / slices; j++) {
842         const auto& elem = elems_[i * elems_.size() / slices + j];
843         slice.ops.push_back(elem.consumed_bits |
844                             (elem.match_case << pack_consume_bits) |
845                             (slice.OffsetOf(elem.emit)
846                              << (pack_consume_bits + pack_match_bits)));
847       }
848     }
849     return table;
850   }
851 
  // Interface describing how the generated code reads one element of a table.
  // Implementations either name a real emitted array or stand in for one with
  // an equivalent expression.
  class Array {
   public:
    virtual ~Array() = default;
    // Returns the text of an expression yielding the element selected by the
    // index expression `value`.
    virtual std::string Index(absl::string_view value) = 0;
    // Returns the name of the backing array; GenCompound() writes it into the
    // initializer of a compound table. Implementations without a backing
    // array abort.
    virtual std::string ArrayName() = 0;
    // Relative cost of this representation; ArrayToFunction() keeps the
    // candidate with the lowest cost.
    virtual int Cost() = 0;
  };
859 
860   class NamedArray : public Array {
861    public:
NamedArray(std::string name)862     explicit NamedArray(std::string name) : name_(std::move(name)) {}
Index(absl::string_view value)863     std::string Index(absl::string_view value) override {
864       return absl::StrCat(name_, "[", value, "]");
865     }
ArrayName()866     std::string ArrayName() override { return name_; }
Cost()867     int Cost() override { abort(); }
868 
869    private:
870     std::string name_;
871   };
872 
873   class IdentityArray : public Array {
874    public:
Index(absl::string_view value)875     std::string Index(absl::string_view value) override {
876       return std::string(value);
877     }
ArrayName()878     std::string ArrayName() override { abort(); }
Cost()879     int Cost() override { return 0; }
880   };
881 
882   class ConstantArray : public Array {
883    public:
ConstantArray(std::string value)884     explicit ConstantArray(std::string value) : value_(std::move(value)) {}
Index(absl::string_view index)885     std::string Index(absl::string_view index) override {
886       return absl::StrCat("((void)", index, ", ", value_, ")");
887     }
ArrayName()888     std::string ArrayName() override { abort(); }
Cost()889     int Cost() override { return 0; }
890 
891    private:
892     std::string value_;
893   };
894 
895   class OffsetArray : public Array {
896    public:
OffsetArray(int offset)897     explicit OffsetArray(int offset) : offset_(offset) {}
Index(absl::string_view value)898     std::string Index(absl::string_view value) override {
899       return absl::StrCat(value, " + ", offset_);
900     }
ArrayName()901     std::string ArrayName() override { abort(); }
Cost()902     int Cost() override { return 10; }
903 
904    private:
905     int offset_;
906   };
907 
908   class LinearDivideArray : public Array {
909    public:
LinearDivideArray(int offset,int divisor)910     LinearDivideArray(int offset, int divisor)
911         : offset_(offset), divisor_(divisor) {}
Index(absl::string_view value)912     std::string Index(absl::string_view value) override {
913       return absl::StrCat(value, "/", divisor_, " + ", offset_);
914     }
ArrayName()915     std::string ArrayName() override { abort(); }
Cost()916     int Cost() override { return 20 + (offset_ != 0 ? 10 : 0); }
917 
918    private:
919     int offset_;
920     int divisor_;
921   };
922 
923   class TwoElemArray : public Array {
924    public:
TwoElemArray(std::string value0,std::string value1)925     TwoElemArray(std::string value0, std::string value1)
926         : value0_(std::move(value0)), value1_(std::move(value1)) {}
Index(absl::string_view value)927     std::string Index(absl::string_view value) override {
928       return absl::StrCat(value, " ? ", value1_, " : ", value0_);
929     }
ArrayName()930     std::string ArrayName() override { abort(); }
Cost()931     int Cost() override { return 40; }
932 
933    private:
934     std::string value0_;
935     std::string value1_;
936   };
937 
938   class Composite2Array : public Array {
939    public:
Composite2Array(std::unique_ptr<Array> a,std::unique_ptr<Array> b,int split)940     Composite2Array(std::unique_ptr<Array> a, std::unique_ptr<Array> b,
941                     int split)
942         : a_(std::move(a)), b_(std::move(b)), split_(split) {}
Index(absl::string_view value)943     std::string Index(absl::string_view value) override {
944       return absl::StrCat(
945           "(", value, " < ", split_, " ? (", a_->Index(value), ") : (",
946           b_->Index(absl::StrCat("(", value, "-", split_, ")")), "))");
947     }
ArrayName()948     std::string ArrayName() override { abort(); }
Cost()949     int Cost() override { return 40 + a_->Cost() + b_->Cost(); }
950 
951    private:
952     std::unique_ptr<Array> a_;
953     std::unique_ptr<Array> b_;
954     int split_;
955   };
956 
957   // Helper to generate a compound table (an array of arrays)
GenCompound(int id,const std::vector<std::unique_ptr<Array>> & arrays,std::string ext,std::string type,Sink * global_decls,Sink * global_values)958   static void GenCompound(int id,
959                           const std::vector<std::unique_ptr<Array>>& arrays,
960                           std::string ext, std::string type, Sink* global_decls,
961                           Sink* global_values) {
962     global_decls->Add(absl::StrCat("static const ", type, "* const table", id,
963                                    "_", ext, "_[", arrays.size(), "];"));
964     global_values->Add(absl::StrCat("const ", type,
965                                     "* const HuffDecoderCommon::table", id, "_",
966                                     ext, "_[", arrays.size(), "] = {"));
967     for (const std::unique_ptr<Array>& array : arrays) {
968       global_values->Add(absl::StrCat("  ", array->ArrayName(), ","));
969     }
970     global_values->Add("};");
971   }
972 
973   // Try to create a simple function equivalent to a mapping implied by a set of
974   // values.
975   static const int kMaxArrayToFunctionRecursions = 1;
976   template <typename T>
ArrayToFunction(const std::vector<T> & values,int recurse=kMaxArrayToFunctionRecursions)977   static std::unique_ptr<Array> ArrayToFunction(
978       const std::vector<T>& values,
979       int recurse = kMaxArrayToFunctionRecursions) {
980     std::unique_ptr<Array> best = nullptr;
981     auto note_solution = [&best](std::unique_ptr<Array> a) {
982       if (best != nullptr && best->Cost() <= a->Cost()) return;
983       best = std::move(a);
984     };
985     // constant => k,k,k,k,...
986     bool is_constant = true;
987     for (size_t i = 1; i < values.size(); i++) {
988       if (values[i] != values[0]) {
989         is_constant = false;
990         break;
991       }
992     }
993     if (is_constant) {
994       note_solution(std::make_unique<ConstantArray>(absl::StrCat(values[0])));
995     }
996     // identity => 0,1,2,3,...
997     bool is_identity = true;
998     for (size_t i = 0; i < values.size(); i++) {
999       if (values[i] != i) {
1000         is_identity = false;
1001         break;
1002       }
1003     }
1004     if (is_identity) {
1005       note_solution(std::make_unique<IdentityArray>());
1006     }
1007     // offset => k,k+1,k+2,k+3,...
1008     bool is_offset = true;
1009     for (size_t i = 1; i < values.size(); i++) {
1010       if (values[i] - values[0] != i) {
1011         is_offset = false;
1012         break;
1013       }
1014     }
1015     if (is_offset) {
1016       note_solution(std::make_unique<OffsetArray>(values[0]));
1017     }
1018     // offset => k,k,k+1,k+1,...
1019     for (int d = 2; d < 32; d++) {
1020       bool is_linear = true;
1021       for (size_t i = 1; i < values.size(); i++) {
1022         if (values[i] - values[0] != (i / d)) {
1023           is_linear = false;
1024           break;
1025         }
1026       }
1027       if (is_linear) {
1028         note_solution(std::make_unique<LinearDivideArray>(values[0], d));
1029       }
1030     }
1031     // Two items can be resolved with a conditional
1032     if (values.size() == 2) {
1033       note_solution(std::make_unique<TwoElemArray>(absl::StrCat(values[0]),
1034                                                    absl::StrCat(values[1])));
1035     }
1036     if ((recurse > 0 && values.size() >= 6) ||
1037         (recurse == kMaxArrayToFunctionRecursions)) {
1038       for (size_t i = 1; i < values.size() - 1; i++) {
1039         std::vector<T> left(values.begin(), values.begin() + i);
1040         std::vector<T> right(values.begin() + i, values.end());
1041         std::unique_ptr<Array> left_array = ArrayToFunction(left, recurse - 1);
1042         std::unique_ptr<Array> right_array =
1043             ArrayToFunction(right, recurse - 1);
1044         if (left_array && right_array) {
1045           note_solution(std::make_unique<Composite2Array>(
1046               std::move(left_array), std::move(right_array), i));
1047         }
1048       }
1049     }
1050     return best;
1051   }
1052 
1053   // Helper to generate an array of values
1054   template <typename T>
GenArray(bool force_array,std::string name,std::string type,const std::vector<T> & values,bool hex,Sink * global_decls,Sink * global_values) const1055   std::unique_ptr<Array> GenArray(bool force_array, std::string name,
1056                                   std::string type,
1057                                   const std::vector<T>& values, bool hex,
1058                                   Sink* global_decls,
1059                                   Sink* global_values) const {
1060     if (!force_array) {
1061       auto fn = ArrayToFunction(values);
1062       if (fn != nullptr) return fn;
1063     }
1064     auto previous_name = ctx_->PreviousNameForArtifact(name, HashVec(values));
1065     if (previous_name.has_value()) {
1066       return std::make_unique<NamedArray>(absl::StrCat(*previous_name, "_"));
1067     }
1068     std::vector<std::string> elems;
1069     elems.reserve(values.size());
1070     for (const auto& elem : values) {
1071       if (hex) {
1072         if (type == "uint8_t") {
1073           elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad2)));
1074         } else if (type == "uint16_t") {
1075           elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad4)));
1076         } else {
1077           elems.push_back(absl::StrCat("0x", absl::Hex(elem, absl::kZeroPad8)));
1078         }
1079       } else {
1080         elems.push_back(absl::StrCat(elem));
1081       }
1082     }
1083     std::string data = absl::StrJoin(elems, ", ");
1084     global_decls->Add(absl::StrCat("static const ", type, " ", name, "_[",
1085                                    values.size(), "];"));
1086     global_values->Add(absl::StrCat("const ", type, " HuffDecoderCommon::",
1087                                     name, "_[", values.size(), "] = {"));
1088     global_values->Add(absl::StrCat("  ", data));
1089     global_values->Add("};");
1090     return std::make_unique<NamedArray>(absl::StrCat(name, "_"));
1091   }
1092 
1093   // Choose an encoding for this set of tables.
1094   // We try all available values for slice count and choose the one that gives
1095   // the smallest footprint.
Choose() const1096   std::unique_ptr<EncodeOption> Choose() const {
1097     std::unique_ptr<EncodeOption> chosen;
1098     size_t best_size = std::numeric_limits<size_t>::max();
1099     for (size_t slice_bits = 0; (1 << slice_bits) < elems_.size();
1100          slice_bits++) {
1101       auto raw = MakeTable(slice_bits);
1102       size_t raw_size = raw->Size();
1103       auto nested = raw->MakeNestedTable();
1104       size_t nested_size = nested->Size();
1105       if (raw_size < best_size) {
1106         chosen = std::move(raw);
1107         best_size = raw_size;
1108       }
1109       if (nested_size < best_size) {
1110         chosen = std::move(nested);
1111         best_size = nested_size;
1112       }
1113     }
1114     return chosen;
1115   }
1116 
  // Build context; GenArray() asks it for the name of a previously generated
  // artifact, and Table::Build() takes the global sinks from it.
  BuildCtx* const ctx_;
  // Table entries; MakeTable() distributes them evenly across the slices.
  std::vector<Elem> elems_;
  // NOTE(review): presumably the largest consumed_bits / match_case seen so
  // far, backing ConsumeBits() / MatchBits() - confirm against Add().
  int max_consumed_bits_ = 0;
  int max_match_case_ = 0;
  // Identifier baked into generated names such as table<id>_<slice>_emit.
  const int id_;
1122 };
1123 
1124 ///////////////////////////////////////////////////////////////////////////////
1125 // FunMaker
1126 // Handles generating the code for various functions.
1127 
1128 class FunMaker {
1129  public:
FunMaker(Sink * sink)1130   explicit FunMaker(Sink* sink) : sink_(sink) {}
1131 
1132   // Generate a refill function - that ensures the incoming bitmask has enough
1133   // bits for the next step.
RefillTo(int n)1134   std::string RefillTo(int n) {
1135     if (have_refills_.count(n) == 0) {
1136       have_refills_.insert(n);
1137       auto fn = NewFun(absl::StrCat("RefillTo", n), "bool");
1138       auto s = fn->Add<Switch>("buffer_len_");
1139       for (int i = 0; i < n; i++) {
1140         auto c = s->Case(absl::StrCat(i));
1141         const int bytes_needed = (n - i + 7) / 8;
1142         c->Add(absl::StrCat("return ", ReadBytes(bytes_needed), ";"));
1143       }
1144       fn->Add("return true;");
1145     }
1146     return absl::StrCat("RefillTo", n, "()");
1147   }
1148 
1149   // At callsite, generate a call to a new function with base name
1150   // base_name (new functions get a suffix of how many instances of base_name
1151   // there have been).
1152   // Return a sink to fill in the body of the new function.
CallNewFun(std::string base_name,Sink * callsite)1153   Sink* CallNewFun(std::string base_name, Sink* callsite) {
1154     std::string name = absl::StrCat(base_name, have_funs_[base_name]++);
1155     callsite->Add(absl::StrCat(name, "();"));
1156     return NewFun(name, "void");
1157   }
1158 
1159  private:
NewFun(std::string name,std::string returns)1160   Sink* NewFun(std::string name, std::string returns) {
1161     sink_->Add(absl::StrCat(returns, " ", name, "() {"));
1162     auto fn = sink_->Add<Indent>();
1163     sink_->Add("}");
1164     return fn;
1165   }
1166 
1167   // Bring in some number of bytes from the input stream to our current read
1168   // bits.
ReadBytes(int bytes_needed)1169   std::string ReadBytes(int bytes_needed) {
1170     if (have_reads_.count(bytes_needed) == 0) {
1171       have_reads_.insert(bytes_needed);
1172       auto fn = NewFun(absl::StrCat("Read", bytes_needed), "bool");
1173       fn->Add(absl::StrCat("if (end_ - begin_ < ", bytes_needed,
1174                            ") return false;"));
1175       fn->Add(absl::StrCat("buffer_ <<= ", 8 * bytes_needed, ";"));
1176       for (int i = 0; i < bytes_needed; i++) {
1177         fn->Add(absl::StrCat("buffer_ |= static_cast<uint64_t>(*begin_++) << ",
1178                              8 * (bytes_needed - i - 1), ";"));
1179       }
1180       fn->Add(absl::StrCat("buffer_len_ += ", 8 * bytes_needed, ";"));
1181       fn->Add("return true;");
1182     }
1183     return absl::StrCat("Read", bytes_needed, "()");
1184   }
1185 
1186   std::set<int> have_refills_;
1187   std::set<int> have_reads_;
1188   std::map<std::string, int> have_funs_;
1189   Sink* sink_;
1190 };
1191 
1192 ///////////////////////////////////////////////////////////////////////////////
1193 // BuildCtx implementation
1194 
AddDone(SymSet start_syms,int num_bits,bool all_ones_so_far,Sink * out)1195 void BuildCtx::AddDone(SymSet start_syms, int num_bits, bool all_ones_so_far,
1196                        Sink* out) {
1197   out->Add("done_ = true;");
1198   if (num_bits == 1) {
1199     if (!all_ones_so_far) out->Add("ok_ = false;");
1200     return;
1201   }
1202   // we must have 0 < buffer_len_ < num_bits
1203   auto s = out->Add<Switch>("buffer_len_");
1204   auto c0 = s->Case("0");
1205   if (!all_ones_so_far) c0->Add("ok_ = false;");
1206   c0->Add("return;");
1207   for (int i = 1; i < num_bits; i++) {
1208     auto c = s->Case(absl::StrCat(i));
1209     SymSet maybe;
1210     for (auto sym : start_syms) {
1211       if (sym.bits.length() > i) continue;
1212       maybe.push_back(sym);
1213     }
1214     if (maybe.empty()) {
1215       if (all_ones_so_far) {
1216         c->Add("ok_ = (buffer_ & ((1<<buffer_len_)-1)) == (1<<buffer_len_)-1;");
1217       } else {
1218         c->Add("ok_ = false;");
1219       }
1220       c->Add("return;");
1221       continue;
1222     }
1223     TableBuilder table_builder(this);
1224     enum Cases {
1225       kNoEmitOk,
1226       kFail,
1227       kEmitOk,
1228     };
1229     for (size_t n = 0; n < (1 << i); n++) {
1230       if (all_ones_so_far && n == (1 << i) - 1) {
1231         table_builder.Add(kNoEmitOk, {}, 0);
1232         goto next;
1233       }
1234       for (auto sym : maybe) {
1235         if ((n >> (i - sym.bits.length())) == sym.bits.mask()) {
1236           for (int j = 0; j < (i - sym.bits.length()); j++) {
1237             if ((n & (1 << j)) == 0) {
1238               table_builder.Add(kFail, {}, 0);
1239               goto next;
1240             }
1241           }
1242           table_builder.Add(kEmitOk, {static_cast<uint8_t>(sym.symbol)}, 0);
1243           goto next;
1244         }
1245       }
1246       table_builder.Add(kFail, {}, 0);
1247     next:;
1248     }
1249     table_builder.Build();
1250     c->Add(absl::StrCat("const auto index = buffer_ & ", (1 << i) - 1, ";"));
1251     c->Add(absl::StrCat("const auto op = ", table_builder.OpAccessor("index"),
1252                         ";"));
1253     if (table_builder.ConsumeBits() != 0) {
1254       fprintf(stderr, "consume bits = %d\n", table_builder.ConsumeBits());
1255       abort();
1256     }
1257     auto s_fin = c->Add<Switch>(
1258         absl::StrCat("op & ", (1 << table_builder.MatchBits()) - 1));
1259     auto emit_ok = s_fin->Case(absl::StrCat(kEmitOk));
1260     emit_ok->Add(absl::StrCat(
1261         "sink_(",
1262         table_builder.EmitAccessor(
1263             "index", absl::StrCat("op >>", table_builder.MatchBits())),
1264         ");"));
1265     emit_ok->Add("break;");
1266     auto fail = s_fin->Case(absl::StrCat(kFail));
1267     fail->Add("ok_ = false;");
1268     fail->Add("break;");
1269     c->Add("return;");
1270   }
1271 }
1272 
AddStep(SymSet start_syms,int num_bits,bool is_top,bool refill,int depth,Sink * out)1273 void BuildCtx::AddStep(SymSet start_syms, int num_bits, bool is_top,
1274                        bool refill, int depth, Sink* out) {
1275   TableBuilder table_builder(this);
1276   if (refill) {
1277     out->Add(absl::StrCat("if (!", fun_maker_->RefillTo(num_bits), ") {"));
1278     auto ifblk = out->Add<Indent>();
1279     if (!is_top) {
1280       Sym some = start_syms[0];
1281       auto sym = grpc_chttp2_huffsyms[some.symbol];
1282       int consumed_len = (sym.length - some.bits.length());
1283       uint32_t consumed_mask = sym.bits >> some.bits.length();
1284       bool all_ones_so_far = consumed_mask == ((1 << consumed_len) - 1);
1285       AddDone(start_syms, num_bits, all_ones_so_far,
1286               fun_maker_->CallNewFun("Done", ifblk));
1287       ifblk->Add("return;");
1288     } else {
1289       AddDone(start_syms, num_bits, true,
1290               fun_maker_->CallNewFun("Done", ifblk));
1291       ifblk->Add("break;");
1292     }
1293     out->Add("}");
1294   }
1295   out->Add(absl::StrCat("const auto index = (buffer_ >> (buffer_len_ - ",
1296                         num_bits, ")) & 0x", absl::Hex((1 << num_bits) - 1),
1297                         ";"));
1298   std::map<MatchCase, int> match_cases;
1299   for (int i = 0; i < (1 << num_bits); i++) {
1300     auto actions = ActionsFor(BitQueue(i, num_bits), start_syms, is_top);
1301     auto add_case = [&match_cases](MatchCase match_case) {
1302       if (match_cases.find(match_case) == match_cases.end()) {
1303         match_cases[match_case] = match_cases.size();
1304       }
1305       return match_cases[match_case];
1306     };
1307     if (actions.emit.size() == 1 && actions.emit[0] == 256) {
1308       table_builder.Add(add_case(End{}), {}, actions.consumed);
1309     } else if (actions.consumed == 0) {
1310       table_builder.Add(add_case(Unmatched{std::move(actions.remaining)}), {},
1311                         num_bits);
1312     } else {
1313       std::vector<uint8_t> emit;
1314       for (auto sym : actions.emit) emit.push_back(sym);
1315       table_builder.Add(
1316           add_case(Matched{static_cast<int>(actions.emit.size())}),
1317           std::move(emit), actions.consumed);
1318     }
1319   }
1320   table_builder.Build();
1321   out->Add(
1322       absl::StrCat("const auto op = ", table_builder.OpAccessor("index"), ";"));
1323   out->Add(absl::StrCat("const int consumed = op & ",
1324                         (1 << table_builder.ConsumeBits()) - 1, ";"));
1325   out->Add("buffer_len_ -= consumed;");
1326   out->Add(absl::StrCat("const auto emit_ofs = op >> ",
1327                         table_builder.ConsumeBits() + table_builder.MatchBits(),
1328                         ";"));
1329   if (match_cases.size() == 1) {
1330     AddMatchBody(&table_builder, "index", "emit_ofs",
1331                  match_cases.begin()->first, is_top, refill, depth, out);
1332   } else {
1333     auto s = out->Add<Switch>(
1334         absl::StrCat("(op >> ", table_builder.ConsumeBits(), ") & ",
1335                      (1 << table_builder.MatchBits()) - 1));
1336     for (auto kv : match_cases) {
1337       auto c = s->Case(absl::StrCat(kv.second));
1338       AddMatchBody(&table_builder, "index", "emit_ofs", kv.first, is_top,
1339                    refill, depth, c);
1340       c->Add("break;");
1341     }
1342   }
1343 }
1344 
AddMatchBody(TableBuilder * table_builder,std::string index,std::string ofs,const MatchCase & match_case,bool is_top,bool refill,int depth,Sink * out)1345 void BuildCtx::AddMatchBody(TableBuilder* table_builder, std::string index,
1346                             std::string ofs, const MatchCase& match_case,
1347                             bool is_top, bool refill, int depth, Sink* out) {
1348   if (absl::holds_alternative<End>(match_case)) {
1349     out->Add("begin_ = end_;");
1350     out->Add("buffer_len_ = 0;");
1351     return;
1352   }
1353   if (auto* p = absl::get_if<Unmatched>(&match_case)) {
1354     if (refill) {
1355       int max_bits = 0;
1356       for (auto sym : p->syms) max_bits = std::max(max_bits, sym.bits.length());
1357       AddStep(p->syms,
1358               depth + 1 >= max_bits_for_depth_.size()
1359                   ? max_bits
1360                   : std::min(max_bits, max_bits_for_depth_[depth + 1]),
1361               false, true, depth + 1,
1362               fun_maker_->CallNewFun("DecodeStep", out));
1363     }
1364     return;
1365   }
1366   const auto& matched = absl::get<Matched>(match_case);
1367   for (int i = 0; i < matched.emits; i++) {
1368     out->Add(absl::StrCat(
1369         "sink_(",
1370         table_builder->EmitAccessor(index, absl::StrCat(ofs, " + ", i)), ");"));
1371   }
1372 }
1373 
1374 ///////////////////////////////////////////////////////////////////////////////
1375 // Driver code
1376 
// Generated header and source code
// Build() fills both fields; main() writes the winning pair to decode_huff.h
// and decode_huff.cc respectively.
struct BuildOutput {
  // Text of the generated header file.
  std::string header;
  // Text of the generated source file.
  std::string source;
};
1382 
// Given max_bits_for_depth = {n1,n2,n3,...}
// Build a decoder that first considers n1 bits, then n2, then n3, ...
// Returns the text of the generated header and source file.
BuildOutput Build(std::vector<int> max_bits_for_depth) {
  auto hdr = std::make_unique<Sink>();
  auto src = std::make_unique<Sink>();
  hdr->Add<Prelude>();
  src->Add<Prelude>();
  hdr->Add("#ifndef GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_DECODE_HUFF_H");
  hdr->Add("#define GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_DECODE_HUFF_H");
  src->Add(
      "#include \"src/core/ext/transport/chttp2/transport/decode_huff.h\"");
  hdr->Add("#include <cstddef>");
  hdr->Add("#include <grpc/support/port_platform.h>");
  src->Add("#include <grpc/support/port_platform.h>");
  hdr->Add("#include <cstdint>");
  // Record the bit widths this decoder was generated for.
  hdr->Add(
      absl::StrCat("// GEOMETRY: ", absl::StrJoin(max_bits_for_depth, ",")));
  hdr->Add("namespace grpc_core {");
  src->Add("namespace grpc_core {");
  // HuffDecoderCommon receives the table accessors (protected) and the table
  // declarations (private). The nested sinks are created here, between the
  // surrounding lines, and are only handed to the build context further down.
  hdr->Add("class HuffDecoderCommon {");
  hdr->Add(" protected:");
  auto global_fns = hdr->Add<Indent>();
  hdr->Add(" private:");
  auto global_decls = hdr->Add<Indent>();
  hdr->Add("};");
  hdr->Add(
      "template<typename F> class HuffDecoder : public HuffDecoderCommon {");
  hdr->Add(" public:");
  auto pub = hdr->Add<Indent>();
  hdr->Add(" private:");
  auto prv = hdr->Add<Indent>();
  // Helper functions made by FunMaker go into their own sink, created ahead
  // of the member declarations added to `prv` below.
  FunMaker fun_maker(prv->Add<Sink>());
  hdr->Add("};");
  hdr->Add("}  // namespace grpc_core");
  hdr->Add("#endif  // GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_DECODE_HUFF_H");
  // Table definitions go into the source file, inside the namespace.
  auto global_values = src->Add<Indent>();
  src->Add("}  // namespace grpc_core");
  BuildCtx ctx(std::move(max_bits_for_depth), global_fns, global_decls,
               global_values, &fun_maker);
  // constructor
  pub->Add(
      "HuffDecoder(F sink, const uint8_t* begin, const uint8_t* end) : "
      "sink_(sink), begin_(begin), end_(end) {}");
  // members
  prv->Add("F sink_;");
  prv->Add("const uint8_t* begin_;");
  prv->Add("const uint8_t* const end_;");
  prv->Add("uint64_t buffer_ = 0;");
  prv->Add("int buffer_len_ = 0;");
  prv->Add("bool ok_ = true;");
  prv->Add("bool done_ = false;");
  // main fn
  // Run() repeats the top-level decode step until done_ is set, then reports
  // ok_.
  pub->Add("bool Run() {");
  auto body = pub->Add<Indent>();
  body->Add("while (!done_) {");
  ctx.AddStep(AllSyms(), ctx.MaxBitsForTop(), true, true, 0,
              body->Add<Indent>());
  body->Add("}");
  body->Add("return ok_;");
  pub->Add("}");
  return {hdr->ToString(), src->ToString()};
}
1445 
// Generate all permutations of max_bits_for_depth for the Build function.
// Every step is between 5 bits (the minimum; needed for http2 I think) and
// 16 bits wide. Steps are appended until fewer than 5 of the 30 bits to
// cover remain. The number of steps is capped by the configurable max_depth:
// a sequence that reaches max_depth steps is only kept if it covers exactly
// 30 bits. A max_depth of zero or less yields no permutations.
class PermutationBuilder {
 public:
  explicit PermutationBuilder(int max_depth) : max_depth_(max_depth) {}

  // Returns every valid sequence of step widths, in generation order
  // (depth-first, smaller widths first).
  std::vector<std::vector<int>> Run() {
    perms_.clear();  // keep repeated calls well defined
    Step({});
    return std::move(perms_);
  }

 private:
  // Extend `so_far` by every allowed next width, recording each sequence
  // that can not be extended any further.
  void Step(std::vector<int> so_far) {
    constexpr int kTotalBits = 30;
    constexpr int kMinStepBits = 5;
    constexpr int kMaxStepBits = 16;
    // Compare depths as ints: comparing size() against a negative max_depth_
    // would otherwise silently mean "unlimited".
    const int depth = static_cast<int>(so_far.size());
    const int sum_so_far = std::accumulate(so_far.begin(), so_far.end(), 0);
    if (depth > max_depth_ ||
        (depth == max_depth_ && sum_so_far != kTotalBits)) {
      return;
    }
    if (sum_so_far + kMinStepBits > kTotalBits) {
      // No room for another step: this sequence is complete.
      perms_.emplace_back(std::move(so_far));
      return;
    }
    const int widest = std::min(kTotalBits - sum_so_far, kMaxStepBits);
    for (int width = kMinStepBits; width <= widest; width++) {
      std::vector<int> extended = so_far;
      extended.push_back(width);
      Step(std::move(extended));
    }
  }

  const int max_depth_;
  std::vector<std::vector<int>> perms_;
};
1478 
// Replace the contents of `filename` with `content`.
// A file that can not be opened or fully written is reported on stderr and
// aborts the generator, so that a failed run can not leave stale or
// truncated output behind unnoticed.
void WriteFile(std::string filename, std::string content) {
  std::ofstream ofs(filename);
  ofs << content;
  // Flush explicitly so that late write errors are visible on the stream.
  ofs.flush();
  if (!ofs) {
    fprintf(stderr, "failed to write %s\n", filename.c_str());
    abort();
  }
}
1484 
main(void)1485 int main(void) {
1486   BuildOutput best;
1487   size_t best_len = std::numeric_limits<size_t>::max();
1488   std::vector<std::unique_ptr<BuildOutput>> results;
1489   std::queue<std::thread> threads;
1490   // Generate all permutations of max_bits_for_depth for the Build function.
1491   // Then generate all variations of the code.
1492   for (const auto& perm : PermutationBuilder(30).Run()) {
1493     while (threads.size() > 200) {
1494       threads.front().join();
1495       threads.pop();
1496     }
1497     results.emplace_back(std::make_unique<BuildOutput>());
1498     threads.emplace([perm, r = results.back().get()] { *r = Build(perm); });
1499   }
1500   while (!threads.empty()) {
1501     threads.front().join();
1502     threads.pop();
1503   }
1504   // Choose the variation that generates the least code, weighted towards header
1505   // length
1506   for (auto& r : results) {
1507     size_t l = 5 * r->header.length() + r->source.length();
1508     if (l < best_len) {
1509       best_len = l;
1510       best = std::move(*r);
1511     }
1512   }
1513   WriteFile("src/core/ext/transport/chttp2/transport/decode_huff.h",
1514             best.header);
1515   WriteFile("src/core/ext/transport/chttp2/transport/decode_huff.cc",
1516             best.source);
1517   return 0;
1518 }
1519