// Copyright 2007 The RE2 Authors.  All Rights Reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Compiled regular expression representation.
// Tested by compile_test.cc

#include "re2/prog.h"

#if defined(__AVX2__)
#include <immintrin.h>
#ifdef _MSC_VER
#include <intrin.h>
#endif
#endif
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <memory>
#include <utility>

#include "util/util.h"
#include "util/logging.h"
#include "util/strutil.h"
#include "re2/bitmap256.h"
#include "re2/stringpiece.h"

namespace re2 {

// Constructors per Inst opcode

void Prog::Inst::InitAlt(uint32_t out, uint32_t out1) {
  DCHECK_EQ(out_opcode_, 0);
  set_out_opcode(out, kInstAlt);
  out1_ = out1;
}

void Prog::Inst::InitByteRange(int lo, int hi, int foldcase, uint32_t out) {
  DCHECK_EQ(out_opcode_, 0);
  set_out_opcode(out, kInstByteRange);
  lo_ = lo & 0xFF;
  hi_ = hi & 0xFF;
  hint_foldcase_ = foldcase&1;
}

void Prog::Inst::InitCapture(int cap, uint32_t out) {
  DCHECK_EQ(out_opcode_, 0);
  set_out_opcode(out, kInstCapture);
  cap_ = cap;
}

void Prog::Inst::InitEmptyWidth(EmptyOp empty, uint32_t out) {
  DCHECK_EQ(out_opcode_, 0);
  set_out_opcode(out, kInstEmptyWidth);
  empty_ = empty;
}

void Prog::Inst::InitMatch(int32_t id) {
  DCHECK_EQ(out_opcode_, 0);
  set_opcode(kInstMatch);
  match_id_ = id;
}

void Prog::Inst::InitNop(uint32_t out) {
  DCHECK_EQ(out_opcode_, 0);
  set_opcode(kInstNop);
}

void Prog::Inst::InitFail() {
  DCHECK_EQ(out_opcode_, 0);
  set_opcode(kInstFail);
}

std::string Prog::Inst::Dump() {
  switch (opcode()) {
    default:
      return StringPrintf("opcode %d", static_cast<int>(opcode()));

    case kInstAlt:
      return StringPrintf("alt -> %d | %d", out(), out1_);

    case kInstAltMatch:
      return StringPrintf("altmatch -> %d | %d", out(), out1_);

    case kInstByteRange:
      return StringPrintf("byte%s [%02x-%02x] %d -> %d",
                          foldcase() ? "/i" : "",
                          lo_, hi_, hint(), out());

    case kInstCapture:
      return StringPrintf("capture %d -> %d", cap_, out());

    case kInstEmptyWidth:
      return StringPrintf("emptywidth %#x -> %d",
                          static_cast<int>(empty_), out());

    case kInstMatch:
      return StringPrintf("match! %d", match_id());

    case kInstNop:
      return StringPrintf("nop -> %d", out());

    case kInstFail:
      return StringPrintf("fail");
  }
}

Prog::Prog()
  : anchor_start_(false),
    anchor_end_(false),
    reversed_(false),
    did_flatten_(false),
    did_onepass_(false),
    start_(0),
    start_unanchored_(0),
    size_(0),
    bytemap_range_(0),
    prefix_foldcase_(false),
    prefix_size_(0),
    list_count_(0),
    bit_state_text_max_size_(0),
    dfa_mem_(0),
    dfa_first_(NULL),
    dfa_longest_(NULL) {
}

Prog::~Prog() {
  DeleteDFA(dfa_longest_);
  DeleteDFA(dfa_first_);
  if (prefix_foldcase_)
    delete[] prefix_dfa_;
}

typedef SparseSet Workq;

static inline void AddToQueue(Workq* q, int id) {
  if (id != 0)
    q->insert(id);
}

static std::string ProgToString(Prog* prog, Workq* q) {
  std::string s;
  for (Workq::iterator i = q->begin(); i != q->end(); ++i) {
    int id = *i;
    Prog::Inst* ip = prog->inst(id);
    s += StringPrintf("%d. %s\n", id, ip->Dump().c_str());
    AddToQueue(q, ip->out());
    if (ip->opcode() == kInstAlt || ip->opcode() == kInstAltMatch)
      AddToQueue(q, ip->out1());
  }
  return s;
}

static std::string FlattenedProgToString(Prog* prog, int start) {
  std::string s;
  for (int id = start; id < prog->size(); id++) {
    Prog::Inst* ip = prog->inst(id);
    if (ip->last())
      s += StringPrintf("%d. %s\n", id, ip->Dump().c_str());
    else
      s += StringPrintf("%d+ %s\n", id, ip->Dump().c_str());
  }
  return s;
}

std::string Prog::Dump() {
  if (did_flatten_)
    return FlattenedProgToString(this, start_);

  Workq q(size_);
  AddToQueue(&q, start_);
  return ProgToString(this, &q);
}

std::string Prog::DumpUnanchored() {
  if (did_flatten_)
    return FlattenedProgToString(this, start_unanchored_);

  Workq q(size_);
  AddToQueue(&q, start_unanchored_);
  return ProgToString(this, &q);
}

std::string Prog::DumpByteMap() {
  std::string map;
  for (int c = 0; c < 256; c++) {
    int b = bytemap_[c];
    int lo = c;
    while (c < 256-1 && bytemap_[c+1] == b)
      c++;
    int hi = c;
    map += StringPrintf("[%02x-%02x] -> %d\n", lo, hi, b);
  }
  return map;
}

// Is ip a guaranteed match at end of text, perhaps after some capturing?
static bool IsMatch(Prog* prog, Prog::Inst* ip) {
  for (;;) {
    switch (ip->opcode()) {
      default:
        LOG(DFATAL) << "Unexpected opcode in IsMatch: " << ip->opcode();
        return false;

      case kInstAlt:
      case kInstAltMatch:
      case kInstByteRange:
      case kInstFail:
      case kInstEmptyWidth:
        return false;

      case kInstCapture:
      case kInstNop:
        ip = prog->inst(ip->out());
        break;

      case kInstMatch:
        return true;
    }
  }
}

// Peep-hole optimizer.
void Prog::Optimize() {
  Workq q(size_);

  // Eliminate nops.  Most are taken out during compilation
  // but a few are hard to avoid.
  q.clear();
  AddToQueue(&q, start_);
  for (Workq::iterator i = q.begin(); i != q.end(); ++i) {
    int id = *i;

    Inst* ip = inst(id);
    int j = ip->out();
    Inst* jp;
    while (j != 0 && (jp=inst(j))->opcode() == kInstNop) {
      j = jp->out();
    }
    ip->set_out(j);
    AddToQueue(&q, ip->out());

    if (ip->opcode() == kInstAlt) {
      j = ip->out1();
      while (j != 0 && (jp=inst(j))->opcode() == kInstNop) {
        j = jp->out();
      }
      ip->out1_ = j;
      AddToQueue(&q, ip->out1());
    }
  }

  // Insert kInstAltMatch instructions
  // Look for
  //   ip: Alt -> j | k
  //    j: ByteRange [00-FF] -> ip
  //    k: Match
  // or the reverse (the above is the greedy one).
  // Rewrite Alt to AltMatch.
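  //
  // Illustrative note (not in the original source): this loop shape is what
  // a trailing `(?s).*` tends to compile to - an Alt with one arm being a
  // ByteRange [00-FF] that loops back to the Alt and the other arm reaching
  // Match - so rewriting it to AltMatch lets execution engines recognize
  // "matches through to end of text" without stepping byte by byte.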
  q.clear();
  AddToQueue(&q, start_);
  for (Workq::iterator i = q.begin(); i != q.end(); ++i) {
    int id = *i;
    Inst* ip = inst(id);
    AddToQueue(&q, ip->out());
    if (ip->opcode() == kInstAlt)
      AddToQueue(&q, ip->out1());

    if (ip->opcode() == kInstAlt) {
      Inst* j = inst(ip->out());
      Inst* k = inst(ip->out1());
      if (j->opcode() == kInstByteRange && j->out() == id &&
          j->lo() == 0x00 && j->hi() == 0xFF &&
          IsMatch(this, k)) {
        ip->set_opcode(kInstAltMatch);
        continue;
      }
      if (IsMatch(this, j) &&
          k->opcode() == kInstByteRange && k->out() == id &&
          k->lo() == 0x00 && k->hi() == 0xFF) {
        ip->set_opcode(kInstAltMatch);
      }
    }
  }
}

uint32_t Prog::EmptyFlags(const StringPiece& text, const char* p) {
  int flags = 0;

  // ^ and \A
  if (p == text.data())
    flags |= kEmptyBeginText | kEmptyBeginLine;
  else if (p[-1] == '\n')
    flags |= kEmptyBeginLine;

  // $ and \z
  if (p == text.data() + text.size())
    flags |= kEmptyEndText | kEmptyEndLine;
  else if (p < text.data() + text.size() && p[0] == '\n')
    flags |= kEmptyEndLine;

  // \b and \B
  if (p == text.data() && p == text.data() + text.size()) {
    // no word boundary here
  } else if (p == text.data()) {
    if (IsWordChar(p[0]))
      flags |= kEmptyWordBoundary;
  } else if (p == text.data() + text.size()) {
    if (IsWordChar(p[-1]))
      flags |= kEmptyWordBoundary;
  } else {
    if (IsWordChar(p[-1]) != IsWordChar(p[0]))
      flags |= kEmptyWordBoundary;
  }
  if (!(flags & kEmptyWordBoundary))
    flags |= kEmptyNonWordBoundary;

  return flags;
}
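
// An illustrative example (not in the original source): for text "ab cd"
// with p pointing at the space, p[-1] == 'b' is a word character and
// p[0] == ' ' is not, so EmptyFlags() returns kEmptyWordBoundary; none of
// the begin/end-of-text or begin/end-of-line flags apply at that position.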

// ByteMapBuilder implements a coloring algorithm.
//
// The first phase is a series of "mark and merge" batches: we mark one or more
// [lo-hi] ranges, then merge them into our internal state. Batching is not for
// performance; rather, it means that the ranges are treated indistinguishably.
//
// Internally, the ranges are represented using a bitmap that stores the splits
// and a vector that stores the colors; both of them are indexed by the ranges'
// last bytes. Thus, in order to merge a [lo-hi] range, we split at lo-1 and at
// hi (if not already split), then recolor each range in between. The color map
// (i.e. from the old color to the new color) is maintained for the lifetime of
// the batch and so underpins this somewhat obscure approach to set operations.
//
// The second phase builds the bytemap from our internal state: we recolor each
// range, then store the new color (which is now the byte class) in each of the
// corresponding array elements. Finally, we output the number of byte classes.
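//
// A worked example (illustrative, not from the original source): suppose we
// mark and merge [61-7a] ('a'-'z') and then, in a second batch, [30-39]
// ('0'-'9'). The first batch splits at 60 and 7a and recolors [61-7a]; the
// second splits at 2f and 39 and recolors [30-39]. The second phase then
// scans from byte 0 upwards, assigning classes in order of first appearance:
// [00-2f] -> 0, [30-39] -> 1, [3a-60] -> 0, [61-7a] -> 2, [7b-ff] -> 0,
// for a total of 3 byte classes.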
class ByteMapBuilder {
 public:
  ByteMapBuilder() {
    // Initial state: the [0-255] range has color 256.
    // This will avoid problems during the second phase,
    // in which we assign byte classes numbered from 0.
    splits_.Set(255);
    colors_[255] = 256;
    nextcolor_ = 257;
  }

  void Mark(int lo, int hi);
  void Merge();
  void Build(uint8_t* bytemap, int* bytemap_range);

 private:
  int Recolor(int oldcolor);

  Bitmap256 splits_;
  int colors_[256];
  int nextcolor_;
  std::vector<std::pair<int, int>> colormap_;
  std::vector<std::pair<int, int>> ranges_;

  ByteMapBuilder(const ByteMapBuilder&) = delete;
  ByteMapBuilder& operator=(const ByteMapBuilder&) = delete;
};

void ByteMapBuilder::Mark(int lo, int hi) {
  DCHECK_GE(lo, 0);
  DCHECK_GE(hi, 0);
  DCHECK_LE(lo, 255);
  DCHECK_LE(hi, 255);
  DCHECK_LE(lo, hi);

  // Ignore any [0-255] ranges. They cause us to recolor every range, which
  // has no effect on the eventual result and is therefore a waste of time.
  if (lo == 0 && hi == 255)
    return;

  ranges_.emplace_back(lo, hi);
}

void ByteMapBuilder::Merge() {
  for (std::vector<std::pair<int, int>>::const_iterator it = ranges_.begin();
       it != ranges_.end();
       ++it) {
    int lo = it->first-1;
    int hi = it->second;

    if (0 <= lo && !splits_.Test(lo)) {
      splits_.Set(lo);
      int next = splits_.FindNextSetBit(lo+1);
      colors_[lo] = colors_[next];
    }
    if (!splits_.Test(hi)) {
      splits_.Set(hi);
      int next = splits_.FindNextSetBit(hi+1);
      colors_[hi] = colors_[next];
    }

    int c = lo+1;
    while (c < 256) {
      int next = splits_.FindNextSetBit(c);
      colors_[next] = Recolor(colors_[next]);
      if (next == hi)
        break;
      c = next+1;
    }
  }
  colormap_.clear();
  ranges_.clear();
}

void ByteMapBuilder::Build(uint8_t* bytemap, int* bytemap_range) {
  // Assign byte classes numbered from 0.
  nextcolor_ = 0;

  int c = 0;
  while (c < 256) {
    int next = splits_.FindNextSetBit(c);
    uint8_t b = static_cast<uint8_t>(Recolor(colors_[next]));
    while (c <= next) {
      bytemap[c] = b;
      c++;
    }
  }

  *bytemap_range = nextcolor_;
}

int ByteMapBuilder::Recolor(int oldcolor) {
  // Yes, this is a linear search. There can be at most 256
  // colors and there will typically be far fewer than that.
  // Also, we need to consider keys *and* values in order to
  // avoid recoloring a given range more than once per batch.
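  // For example (illustrative): if this batch already mapped 256 -> 0 and we
  // then visit a range that was already recolored to 0 within the batch,
  // matching on the value 0 returns 0 rather than allocating a fresh color.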
  std::vector<std::pair<int, int>>::const_iterator it =
      std::find_if(colormap_.begin(), colormap_.end(),
                   [=](const std::pair<int, int>& kv) -> bool {
                     return kv.first == oldcolor || kv.second == oldcolor;
                   });
  if (it != colormap_.end())
    return it->second;
  int newcolor = nextcolor_;
  nextcolor_++;
  colormap_.emplace_back(oldcolor, newcolor);
  return newcolor;
}

void Prog::ComputeByteMap() {
  // Fill in bytemap with byte classes for the program.
  // Ranges of bytes that are treated indistinguishably
  // will be mapped to a single byte class.
  ByteMapBuilder builder;

  // Don't repeat the work for ^ and $.
  bool marked_line_boundaries = false;
  // Don't repeat the work for \b and \B.
  bool marked_word_boundaries = false;

  for (int id = 0; id < size(); id++) {
    Inst* ip = inst(id);
    if (ip->opcode() == kInstByteRange) {
      int lo = ip->lo();
      int hi = ip->hi();
      builder.Mark(lo, hi);
      if (ip->foldcase() && lo <= 'z' && hi >= 'a') {
        int foldlo = lo;
        int foldhi = hi;
        if (foldlo < 'a')
          foldlo = 'a';
        if (foldhi > 'z')
          foldhi = 'z';
        if (foldlo <= foldhi) {
          foldlo += 'A' - 'a';
          foldhi += 'A' - 'a';
          builder.Mark(foldlo, foldhi);
        }
      }
      // If this Inst is not the last Inst in its list AND the next Inst is
      // also a ByteRange AND the Insts have the same out, defer the merge.
      if (!ip->last() &&
          inst(id+1)->opcode() == kInstByteRange &&
          ip->out() == inst(id+1)->out())
        continue;
      builder.Merge();
    } else if (ip->opcode() == kInstEmptyWidth) {
      if (ip->empty() & (kEmptyBeginLine|kEmptyEndLine) &&
          !marked_line_boundaries) {
        builder.Mark('\n', '\n');
        builder.Merge();
        marked_line_boundaries = true;
      }
      if (ip->empty() & (kEmptyWordBoundary|kEmptyNonWordBoundary) &&
          !marked_word_boundaries) {
        // We require two batches here: the first for ranges that are word
        // characters, the second for ranges that are not word characters.
        for (bool isword : {true, false}) {
          int j;
          for (int i = 0; i < 256; i = j) {
            for (j = i + 1; j < 256 &&
                            Prog::IsWordChar(static_cast<uint8_t>(i)) ==
                                Prog::IsWordChar(static_cast<uint8_t>(j));
                 j++)
              ;
            if (Prog::IsWordChar(static_cast<uint8_t>(i)) == isword)
              builder.Mark(i, j - 1);
          }
          builder.Merge();
        }
        marked_word_boundaries = true;
      }
    }
  }

  builder.Build(bytemap_, &bytemap_range_);

  if (0) {  // For debugging, use trivial bytemap.
    LOG(ERROR) << "Using trivial bytemap.";
    for (int i = 0; i < 256; i++)
      bytemap_[i] = static_cast<uint8_t>(i);
    bytemap_range_ = 256;
  }
}

// Prog::Flatten() implements a graph rewriting algorithm.
//
// The overall process is similar to epsilon removal, but retains some epsilon
// transitions: those from Capture and EmptyWidth instructions; and those from
// nullable subexpressions. (The latter avoids quadratic blowup in transitions
// in the worst case.) It might be best thought of as Alt instruction elision.
//
// In conceptual terms, it divides the Prog into "trees" of instructions, then
// traverses the "trees" in order to produce "lists" of instructions. A "tree"
// is one or more instructions that grow from one "root" instruction to one or
// more "leaf" instructions; if a "tree" has exactly one instruction, then the
// "root" is also the "leaf". In most cases, a "root" is the successor of some
// "leaf" (i.e. the "leaf" instruction's out() returns the "root" instruction)
// and is considered a "successor root". A "leaf" can be a ByteRange, Capture,
// EmptyWidth or Match instruction. However, this is insufficient for handling
// nested nullable subexpressions correctly, so in some cases, a "root" is the
// dominator of the instructions reachable from some "successor root" (i.e. it
// has an unreachable predecessor) and is considered a "dominator root". Since
// only Alt instructions can be "dominator roots" (other instructions would be
// "leaves"), only Alt instructions are required to be marked as predecessors.
//
// Dividing the Prog into "trees" comprises two passes: marking the "successor
// roots" and the predecessors; and marking the "dominator roots". Sorting the
// "successor roots" by their bytecode offsets enables iteration in order from
// greatest to least during the second pass; by working backwards in this case
// and flooding the graph no further than "leaves" and already marked "roots",
// it becomes possible to mark "dominator roots" without doing excessive work.
//
// Traversing the "trees" is just iterating over the "roots" in order of their
// marking and flooding the graph no further than "leaves" and "roots". When a
// "leaf" is reached, the instruction is copied with its successor remapped to
// its "root" number. When a "root" is reached, a Nop instruction is generated
// with its successor remapped similarly. As each "list" is produced, its last
// instruction is marked as such. After all of the "lists" have been produced,
// a pass over their instructions remaps their successors to bytecode offsets.
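//
// A small sketch (illustrative, not from the original source): for a pattern
// along the lines of "(?:a|b)c", the Alt over the two ByteRanges is elided,
// so the "list" for the enclosing "root" contains the ByteRange for 'a' and
// the ByteRange for 'b' - printed by FlattenedProgToString() as "N+ ..."
// then "M. ..." - and both of their out() fields point at the "root" whose
// list begins with the ByteRange for 'c'.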
void Prog::Flatten() {
  if (did_flatten_)
    return;
  did_flatten_ = true;

  // Scratch structures. It's important that these are reused by functions
  // that we call in loops because they would thrash the heap otherwise.
  SparseSet reachable(size());
  std::vector<int> stk;
  stk.reserve(size());

  // First pass: Marks "successor roots" and predecessors.
  // Builds the mapping from inst-ids to root-ids.
  SparseArray<int> rootmap(size());
  SparseArray<int> predmap(size());
  std::vector<std::vector<int>> predvec;
  MarkSuccessors(&rootmap, &predmap, &predvec, &reachable, &stk);

  // Second pass: Marks "dominator roots".
  SparseArray<int> sorted(rootmap);
  std::sort(sorted.begin(), sorted.end(), sorted.less);
  for (SparseArray<int>::const_iterator i = sorted.end() - 1;
       i != sorted.begin();
       --i) {
    if (i->index() != start_unanchored() && i->index() != start())
      MarkDominator(i->index(), &rootmap, &predmap, &predvec, &reachable, &stk);
  }

  // Third pass: Emits "lists". Remaps outs to root-ids.
  // Builds the mapping from root-ids to flat-ids.
  std::vector<int> flatmap(rootmap.size());
  std::vector<Inst> flat;
  flat.reserve(size());
  for (SparseArray<int>::const_iterator i = rootmap.begin();
       i != rootmap.end();
       ++i) {
    flatmap[i->value()] = static_cast<int>(flat.size());
    EmitList(i->index(), &rootmap, &flat, &reachable, &stk);
    flat.back().set_last();
    // We have the bounds of the "list", so this is the
    // most convenient point at which to compute hints.
    ComputeHints(&flat, flatmap[i->value()], static_cast<int>(flat.size()));
  }

  list_count_ = static_cast<int>(flatmap.size());
  for (int i = 0; i < kNumInst; i++)
    inst_count_[i] = 0;

  // Fourth pass: Remaps outs to flat-ids.
  // Counts instructions by opcode.
  for (int id = 0; id < static_cast<int>(flat.size()); id++) {
    Inst* ip = &flat[id];
    if (ip->opcode() != kInstAltMatch)  // handled in EmitList()
      ip->set_out(flatmap[ip->out()]);
    inst_count_[ip->opcode()]++;
  }

#if !defined(NDEBUG)
  // Address a `-Wunused-but-set-variable' warning from Clang 13.x.
  size_t total = 0;
  for (int i = 0; i < kNumInst; i++)
    total += inst_count_[i];
  CHECK_EQ(total, flat.size());
#endif

  // Remap start_unanchored and start.
  if (start_unanchored() == 0) {
    DCHECK_EQ(start(), 0);
  } else if (start_unanchored() == start()) {
    set_start_unanchored(flatmap[1]);
    set_start(flatmap[1]);
  } else {
    set_start_unanchored(flatmap[1]);
    set_start(flatmap[2]);
  }

  // Finally, replace the old instructions with the new instructions.
  size_ = static_cast<int>(flat.size());
  inst_ = PODArray<Inst>(size_);
  memmove(inst_.data(), flat.data(), size_*sizeof inst_[0]);

  // Populate the list heads for BitState.
  // 512 instructions limits the memory footprint to 1KiB.
  if (size_ <= 512) {
    list_heads_ = PODArray<uint16_t>(size_);
    // 0xFF makes it more obvious if we try to look up a non-head.
    memset(list_heads_.data(), 0xFF, size_*sizeof list_heads_[0]);
    for (int i = 0; i < list_count_; ++i)
      list_heads_[flatmap[i]] = i;
  }

  // BitState allocates a bitmap of size list_count_ * (text.size()+1)
  // for tracking pairs of possibilities that it has already explored.
  const size_t kBitStateBitmapMaxSize = 256*1024;  // max size in bits
  bit_state_text_max_size_ = kBitStateBitmapMaxSize / list_count_ - 1;
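  // For example (illustrative): with list_count_ == 8, the bitmap budget of
  // 256*1024 bits allows texts of up to 256*1024/8 - 1 = 32767 bytes.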
}

void Prog::MarkSuccessors(SparseArray<int>* rootmap,
                          SparseArray<int>* predmap,
                          std::vector<std::vector<int>>* predvec,
                          SparseSet* reachable, std::vector<int>* stk) {
  // Mark the kInstFail instruction.
  rootmap->set_new(0, rootmap->size());

  // Mark the start_unanchored and start instructions.
  if (!rootmap->has_index(start_unanchored()))
    rootmap->set_new(start_unanchored(), rootmap->size());
  if (!rootmap->has_index(start()))
    rootmap->set_new(start(), rootmap->size());

  reachable->clear();
  stk->clear();
  stk->push_back(start_unanchored());
  while (!stk->empty()) {
    int id = stk->back();
    stk->pop_back();
  Loop:
    if (reachable->contains(id))
      continue;
    reachable->insert_new(id);

    Inst* ip = inst(id);
    switch (ip->opcode()) {
      default:
        LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
        break;

      case kInstAltMatch:
      case kInstAlt:
        // Mark this instruction as a predecessor of each out.
        for (int out : {ip->out(), ip->out1()}) {
          if (!predmap->has_index(out)) {
            predmap->set_new(out, static_cast<int>(predvec->size()));
            predvec->emplace_back();
          }
          (*predvec)[predmap->get_existing(out)].emplace_back(id);
        }
        stk->push_back(ip->out1());
        id = ip->out();
        goto Loop;

      case kInstByteRange:
      case kInstCapture:
      case kInstEmptyWidth:
        // Mark the out of this instruction as a "root".
        if (!rootmap->has_index(ip->out()))
          rootmap->set_new(ip->out(), rootmap->size());
        id = ip->out();
        goto Loop;

      case kInstNop:
        id = ip->out();
        goto Loop;

      case kInstMatch:
      case kInstFail:
        break;
    }
  }
}

void Prog::MarkDominator(int root, SparseArray<int>* rootmap,
                         SparseArray<int>* predmap,
                         std::vector<std::vector<int>>* predvec,
                         SparseSet* reachable, std::vector<int>* stk) {
  reachable->clear();
  stk->clear();
  stk->push_back(root);
  while (!stk->empty()) {
    int id = stk->back();
    stk->pop_back();
  Loop:
    if (reachable->contains(id))
      continue;
    reachable->insert_new(id);

    if (id != root && rootmap->has_index(id)) {
      // We reached another "tree" via epsilon transition.
      continue;
    }

    Inst* ip = inst(id);
    switch (ip->opcode()) {
      default:
        LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
        break;

      case kInstAltMatch:
      case kInstAlt:
        stk->push_back(ip->out1());
        id = ip->out();
        goto Loop;

      case kInstByteRange:
      case kInstCapture:
      case kInstEmptyWidth:
        break;

      case kInstNop:
        id = ip->out();
        goto Loop;

      case kInstMatch:
      case kInstFail:
        break;
    }
  }

  for (SparseSet::const_iterator i = reachable->begin();
       i != reachable->end();
       ++i) {
    int id = *i;
    if (predmap->has_index(id)) {
      for (int pred : (*predvec)[predmap->get_existing(id)]) {
        if (!reachable->contains(pred)) {
          // id has a predecessor that cannot be reached from root!
          // Therefore, id must be a "root" too - mark it as such.
          if (!rootmap->has_index(id))
            rootmap->set_new(id, rootmap->size());
        }
      }
    }
  }
}

void Prog::EmitList(int root, SparseArray<int>* rootmap,
                    std::vector<Inst>* flat,
                    SparseSet* reachable, std::vector<int>* stk) {
  reachable->clear();
  stk->clear();
  stk->push_back(root);
  while (!stk->empty()) {
    int id = stk->back();
    stk->pop_back();
  Loop:
    if (reachable->contains(id))
      continue;
    reachable->insert_new(id);

    if (id != root && rootmap->has_index(id)) {
      // We reached another "tree" via epsilon transition. Emit a kInstNop
      // instruction so that the Prog does not become quadratically larger.
      flat->emplace_back();
      flat->back().set_opcode(kInstNop);
      flat->back().set_out(rootmap->get_existing(id));
      continue;
    }

    Inst* ip = inst(id);
    switch (ip->opcode()) {
      default:
        LOG(DFATAL) << "unhandled opcode: " << ip->opcode();
        break;

      case kInstAltMatch:
        flat->emplace_back();
        flat->back().set_opcode(kInstAltMatch);
        flat->back().set_out(static_cast<int>(flat->size()));
        flat->back().out1_ = static_cast<uint32_t>(flat->size())+1;
        FALLTHROUGH_INTENDED;

      case kInstAlt:
        stk->push_back(ip->out1());
        id = ip->out();
        goto Loop;

      case kInstByteRange:
      case kInstCapture:
      case kInstEmptyWidth:
        flat->emplace_back();
        memmove(&flat->back(), ip, sizeof *ip);
        flat->back().set_out(rootmap->get_existing(ip->out()));
        break;

      case kInstNop:
        id = ip->out();
        goto Loop;

      case kInstMatch:
      case kInstFail:
        flat->emplace_back();
        memmove(&flat->back(), ip, sizeof *ip);
        break;
    }
  }
}

// For each ByteRange instruction in [begin, end), computes a hint to execution
// engines: the delta to the next instruction (in flat) worth exploring iff the
// current instruction matched.
//
// Implements a coloring algorithm related to ByteMapBuilder, but in this case,
// colors are instructions and recoloring ranges precisely identifies conflicts
// between instructions. Iterating backwards over [begin, end) is guaranteed to
// identify the nearest conflict (if any) with only linear complexity.
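//
// For example (illustrative, not from the original source): in a "list" of
// three ByteRange instructions at flat-ids 10, 11 and 12 with ranges [0-9],
// [a-z] and [0-9] respectively, instruction 10 gets hint 2 (if it matched,
// 11 cannot match the same byte, so skip straight to 12), while instructions
// 11 and 12 keep hint 0 because no later instruction in the list conflicts.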
void Prog::ComputeHints(std::vector<Inst>* flat, int begin, int end) {
  Bitmap256 splits;
  int colors[256];

  bool dirty = false;
  for (int id = end; id >= begin; --id) {
    if (id == end ||
        (*flat)[id].opcode() != kInstByteRange) {
      if (dirty) {
        dirty = false;
        splits.Clear();
      }
      splits.Set(255);
      colors[255] = id;
      // At this point, the [0-255] range is colored with id.
      // Thus, hints cannot point beyond id; and if id == end,
      // hints that would have pointed to id will be 0 instead.
      continue;
    }
    dirty = true;

    // We recolor the [lo-hi] range with id. Note that first ratchets backwards
    // from end to the nearest conflict (if any) during recoloring.
    int first = end;
    auto Recolor = [&](int lo, int hi) {
      // Like ByteMapBuilder, we split at lo-1 and at hi.
      --lo;

      if (0 <= lo && !splits.Test(lo)) {
        splits.Set(lo);
        int next = splits.FindNextSetBit(lo+1);
        colors[lo] = colors[next];
      }
      if (!splits.Test(hi)) {
        splits.Set(hi);
        int next = splits.FindNextSetBit(hi+1);
        colors[hi] = colors[next];
      }

      int c = lo+1;
      while (c < 256) {
        int next = splits.FindNextSetBit(c);
        // Ratchet backwards...
        first = std::min(first, colors[next]);
        // Recolor with id - because it's the new nearest conflict!
        colors[next] = id;
        if (next == hi)
          break;
        c = next+1;
      }
    };

    Inst* ip = &(*flat)[id];
    int lo = ip->lo();
    int hi = ip->hi();
    Recolor(lo, hi);
    if (ip->foldcase() && lo <= 'z' && hi >= 'a') {
      int foldlo = lo;
      int foldhi = hi;
      if (foldlo < 'a')
        foldlo = 'a';
      if (foldhi > 'z')
        foldhi = 'z';
      if (foldlo <= foldhi) {
        foldlo += 'A' - 'a';
        foldhi += 'A' - 'a';
        Recolor(foldlo, foldhi);
      }
    }

    if (first != end) {
      uint16_t hint = static_cast<uint16_t>(std::min(first - id, 32767));
      ip->hint_foldcase_ |= hint<<1;
    }
  }
}

// The final state will always be this, which frees up a register for the hot
// loop and thus avoids the spilling that can occur when building with Clang.
static const size_t kShiftDFAFinal = 9;

// This function takes the prefix as std::string (i.e. not const std::string&
// as normal) because it's going to clobber it, so a temporary is convenient.
static uint64_t* BuildShiftDFA(std::string prefix) {
  // This constant is for convenience now and also for correctness later when
  // we clobber the prefix, but still need to know how long it was initially.
  const size_t size = prefix.size();

  // Construct the NFA.
  // The table is indexed by input byte; each element is a bitfield of states
  // reachable by the input byte. Given a bitfield of the current states, the
  // bitfield of states reachable from those is - for this specific purpose -
  // always ((ncurr << 1) | 1). Intersecting the reachability bitfields gives
  // the bitfield of the next states reached by stepping over the input byte.
  // Credits for this technique: the Hyperscan paper by Geoff Langdale et al.
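  // A worked example (illustrative, not from the original source): for the
  // prefix "ab", nfa['a'] == 0b011 and nfa['b'] == 0b101 (every byte keeps
  // bit 0 set for the implicit `\C*?`). Starting from ncurr == 0b001, input
  // 'a' yields nfa['a'] & 0b011 == 0b011 ({0,1}); then input 'b' yields
  // nfa['b'] & 0b111 == 0b101 ({0,2}), and bit 2 means "ab" has been seen.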
  uint16_t nfa[256]{};
  for (size_t i = 0; i < size; ++i) {
    uint8_t b = prefix[i];
    nfa[b] |= 1 << (i+1);
  }
  // This is the `\C*?` for unanchored search.
  for (int b = 0; b < 256; ++b)
    nfa[b] |= 1;

  // This maps from DFA state to NFA states; the reverse mapping is used when
  // recording transitions and gets implemented with plain old linear search.
  // The "Shift DFA" technique limits this to ten states when using uint64_t;
  // to allow for the initial state, we use at most nine bytes of the prefix.
  // That same limit is also why uint16_t is sufficient for the NFA bitfield.
  uint16_t states[kShiftDFAFinal+1]{};
  states[0] = 1;
  for (size_t dcurr = 0; dcurr < size; ++dcurr) {
    uint8_t b = prefix[dcurr];
    uint16_t ncurr = states[dcurr];
    uint16_t nnext = nfa[b] & ((ncurr << 1) | 1);
    size_t dnext = dcurr+1;
    if (dnext == size)
      dnext = kShiftDFAFinal;
    states[dnext] = nnext;
  }

  // Sort and unique the bytes of the prefix to avoid repeating work while we
  // record transitions. This clobbers the prefix, but it's no longer needed.
  std::sort(prefix.begin(), prefix.end());
  prefix.erase(std::unique(prefix.begin(), prefix.end()), prefix.end());

  // Construct the DFA.
  // The table is indexed by input byte; each element is effectively a packed
  // array of uint6_t; each array value will be multiplied by six in order to
  // avoid having to do so later in the hot loop as well as masking/shifting.
  // Credits for this technique: "Shift-based DFAs" on GitHub by Per Vognsen.
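  // In other words (illustrative): state s's transition on byte b lives in
  // bits [s*6, s*6+6) of dfa[b], stored pre-multiplied as dnext*6, so the
  // hot loop can step with just `curr = dfa[b] >> (curr & 63)` - the low six
  // bits of curr always hold the current state times six.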
  uint64_t* dfa = new uint64_t[256]{};
  // Record a transition from each state for each of the bytes of the prefix.
  // Note that all other input bytes go back to the initial state by default.
  for (size_t dcurr = 0; dcurr < size; ++dcurr) {
    for (uint8_t b : prefix) {
      uint16_t ncurr = states[dcurr];
      uint16_t nnext = nfa[b] & ((ncurr << 1) | 1);
      size_t dnext = 0;
      while (states[dnext] != nnext)
        ++dnext;
      dfa[b] |= static_cast<uint64_t>(dnext * 6) << (dcurr * 6);
      // Convert ASCII letters to uppercase and record the extra transitions.
      // Note that ASCII letters are guaranteed to be lowercase at this point
      // because that's how the parser normalises them. #FunFact: 'k' and 's'
      // match U+212A and U+017F, respectively, so they won't occur here when
      // using UTF-8 encoding because the parser will emit character classes.
      if ('a' <= b && b <= 'z') {
        b -= 'a' - 'A';
        dfa[b] |= static_cast<uint64_t>(dnext * 6) << (dcurr * 6);
      }
    }
  }
  // This lets the final state "saturate", which will matter for performance:
  // in the hot loop, we check for a match only at the end of each iteration,
  // so we must keep signalling the match until we get around to checking it.
  for (int b = 0; b < 256; ++b)
    dfa[b] |= static_cast<uint64_t>(kShiftDFAFinal * 6) << (kShiftDFAFinal * 6);

  return dfa;
}

void Prog::ConfigurePrefixAccel(const std::string& prefix,
                                bool prefix_foldcase) {
  prefix_foldcase_ = prefix_foldcase;
  prefix_size_ = prefix.size();
  if (prefix_foldcase_) {
    // Use PrefixAccel_ShiftDFA().
    // ... and no more than nine bytes of the prefix. (See above for details.)
    prefix_size_ = std::min(prefix_size_, kShiftDFAFinal);
    prefix_dfa_ = BuildShiftDFA(prefix.substr(0, prefix_size_));
  } else if (prefix_size_ != 1) {
    // Use PrefixAccel_FrontAndBack().
    prefix_front_ = prefix.front();
    prefix_back_ = prefix.back();
  } else {
    // Use memchr(3).
    prefix_front_ = prefix.front();
  }
}

const void* Prog::PrefixAccel_ShiftDFA(const void* data, size_t size) {
  if (size < prefix_size_)
    return NULL;

  uint64_t curr = 0;

  // At the time of writing, rough benchmarks on a Broadwell machine showed
  // that this unroll factor (i.e. eight) achieves a speedup factor of two.
  if (size >= 8) {
    const uint8_t* p = reinterpret_cast<const uint8_t*>(data);
    const uint8_t* endp = p + (size&~7);
    do {
      uint8_t b0 = p[0];
      uint8_t b1 = p[1];
      uint8_t b2 = p[2];
      uint8_t b3 = p[3];
      uint8_t b4 = p[4];
      uint8_t b5 = p[5];
      uint8_t b6 = p[6];
      uint8_t b7 = p[7];

      uint64_t next0 = prefix_dfa_[b0];
      uint64_t next1 = prefix_dfa_[b1];
      uint64_t next2 = prefix_dfa_[b2];
      uint64_t next3 = prefix_dfa_[b3];
      uint64_t next4 = prefix_dfa_[b4];
      uint64_t next5 = prefix_dfa_[b5];
      uint64_t next6 = prefix_dfa_[b6];
      uint64_t next7 = prefix_dfa_[b7];

      uint64_t curr0 = next0 >> (curr  & 63);
      uint64_t curr1 = next1 >> (curr0 & 63);
      uint64_t curr2 = next2 >> (curr1 & 63);
      uint64_t curr3 = next3 >> (curr2 & 63);
      uint64_t curr4 = next4 >> (curr3 & 63);
      uint64_t curr5 = next5 >> (curr4 & 63);
      uint64_t curr6 = next6 >> (curr5 & 63);
      uint64_t curr7 = next7 >> (curr6 & 63);

      if ((curr7 & 63) == kShiftDFAFinal * 6) {
        // At the time of writing, using the same masking subexpressions from
        // the preceding lines caused Clang to clutter the hot loop computing
        // them - even though they aren't actually needed for shifting! Hence
        // these rewritten conditions, which achieve a speedup factor of two.
        if (((curr7-curr0) & 63) == 0) return p+1-prefix_size_;
        if (((curr7-curr1) & 63) == 0) return p+2-prefix_size_;
        if (((curr7-curr2) & 63) == 0) return p+3-prefix_size_;
        if (((curr7-curr3) & 63) == 0) return p+4-prefix_size_;
        if (((curr7-curr4) & 63) == 0) return p+5-prefix_size_;
        if (((curr7-curr5) & 63) == 0) return p+6-prefix_size_;
        if (((curr7-curr6) & 63) == 0) return p+7-prefix_size_;
        if (((curr7-curr7) & 63) == 0) return p+8-prefix_size_;
      }

      curr = curr7;
      p += 8;
    } while (p != endp);
    data = p;
    size = size&7;
  }

  const uint8_t* p = reinterpret_cast<const uint8_t*>(data);
  const uint8_t* endp = p + size;
  while (p != endp) {
    uint8_t b = *p++;
    uint64_t next = prefix_dfa_[b];
    curr = next >> (curr & 63);
    if ((curr & 63) == kShiftDFAFinal * 6)
      return p-prefix_size_;
  }
  return NULL;
}

#if defined(__AVX2__)
// Finds the least significant non-zero bit in n.
static int FindLSBSet(uint32_t n) {
  DCHECK_NE(n, 0);
#if defined(__GNUC__)
  return __builtin_ctz(n);
#elif defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))
  unsigned long c;
  _BitScanForward(&c, n);
  return static_cast<int>(c);
#else
  int c = 31;
  for (int shift = 1 << 4; shift != 0; shift >>= 1) {
    uint32_t word = n << shift;
    if (word != 0) {
      n = word;
      c -= shift;
    }
  }
  return c;
#endif
}
#endif

const void* Prog::PrefixAccel_FrontAndBack(const void* data, size_t size) {
  DCHECK_GE(prefix_size_, 2);
  if (size < prefix_size_)
    return NULL;
  // Don't bother searching the last prefix_size_-1 bytes for prefix_front_.
  // This also means that probing for prefix_back_ doesn't go out of bounds.
  size -= prefix_size_-1;

#if defined(__AVX2__)
  // Use AVX2 to look for prefix_front_ and prefix_back_ 32 bytes at a time.
  if (size >= sizeof(__m256i)) {
    const __m256i* fp = reinterpret_cast<const __m256i*>(
        reinterpret_cast<const char*>(data));
    const __m256i* bp = reinterpret_cast<const __m256i*>(
        reinterpret_cast<const char*>(data) + prefix_size_-1);
    const __m256i* endfp = fp + size/sizeof(__m256i);
    const __m256i f_set1 = _mm256_set1_epi8(prefix_front_);
    const __m256i b_set1 = _mm256_set1_epi8(prefix_back_);
    do {
      const __m256i f_loadu = _mm256_loadu_si256(fp++);
      const __m256i b_loadu = _mm256_loadu_si256(bp++);
      const __m256i f_cmpeq = _mm256_cmpeq_epi8(f_set1, f_loadu);
      const __m256i b_cmpeq = _mm256_cmpeq_epi8(b_set1, b_loadu);
      const int fb_testz = _mm256_testz_si256(f_cmpeq, b_cmpeq);
      if (fb_testz == 0) {  // ZF: 1 means zero, 0 means non-zero.
        const __m256i fb_and = _mm256_and_si256(f_cmpeq, b_cmpeq);
        const int fb_movemask = _mm256_movemask_epi8(fb_and);
        const int fb_ctz = FindLSBSet(fb_movemask);
        return reinterpret_cast<const char*>(fp-1) + fb_ctz;
      }
    } while (fp != endfp);
    data = fp;
    size = size%sizeof(__m256i);
  }
#endif

  const char* p0 = reinterpret_cast<const char*>(data);
  for (const char* p = p0;; p++) {
    DCHECK_GE(size, static_cast<size_t>(p-p0));
    p = reinterpret_cast<const char*>(memchr(p, prefix_front_, size - (p-p0)));
    if (p == NULL || p[prefix_size_-1] == prefix_back_)
      return p;
  }
}

}  // namespace re2