xref: /aosp_15_r20/prebuilts/build-tools/common/py3-stdlib/re/_parser.py (revision cda5da8d549138a6648c5ee6d7a49cf8f4a657be)
#
# Secret Labs' Regular Expression Engine
#
# convert re-style regular expression to sre pattern
#
# Copyright (c) 1998-2001 by Secret Labs AB.  All rights reserved.
#
# See the __init__.py file for information on usage and redistribution.
#
10
11"""Internal support module for sre"""
12
# XXX: show the string offset and the offending character in all error messages
14
15from ._constants import *
16
17SPECIAL_CHARS = ".\\[{()*+?^$|"
18REPEAT_CHARS = "*+?{"
19
20DIGITS = frozenset("0123456789")
21
22OCTDIGITS = frozenset("01234567")
23HEXDIGITS = frozenset("0123456789abcdefABCDEF")
24ASCIILETTERS = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
25
26WHITESPACE = frozenset(" \t\n\r\v\f")
27
28_REPEATCODES = frozenset({MIN_REPEAT, MAX_REPEAT, POSSESSIVE_REPEAT})
29_UNITCODES = frozenset({ANY, RANGE, IN, LITERAL, NOT_LITERAL, CATEGORY})
30
31ESCAPES = {
32    r"\a": (LITERAL, ord("\a")),
33    r"\b": (LITERAL, ord("\b")),
34    r"\f": (LITERAL, ord("\f")),
35    r"\n": (LITERAL, ord("\n")),
36    r"\r": (LITERAL, ord("\r")),
37    r"\t": (LITERAL, ord("\t")),
38    r"\v": (LITERAL, ord("\v")),
39    r"\\": (LITERAL, ord("\\"))
40}
41
42CATEGORIES = {
43    r"\A": (AT, AT_BEGINNING_STRING), # start of string
44    r"\b": (AT, AT_BOUNDARY),
45    r"\B": (AT, AT_NON_BOUNDARY),
46    r"\d": (IN, [(CATEGORY, CATEGORY_DIGIT)]),
47    r"\D": (IN, [(CATEGORY, CATEGORY_NOT_DIGIT)]),
48    r"\s": (IN, [(CATEGORY, CATEGORY_SPACE)]),
49    r"\S": (IN, [(CATEGORY, CATEGORY_NOT_SPACE)]),
50    r"\w": (IN, [(CATEGORY, CATEGORY_WORD)]),
51    r"\W": (IN, [(CATEGORY, CATEGORY_NOT_WORD)]),
52    r"\Z": (AT, AT_END_STRING), # end of string
53}
54
55FLAGS = {
56    # standard flags
57    "i": SRE_FLAG_IGNORECASE,
58    "L": SRE_FLAG_LOCALE,
59    "m": SRE_FLAG_MULTILINE,
60    "s": SRE_FLAG_DOTALL,
61    "x": SRE_FLAG_VERBOSE,
62    # extensions
63    "a": SRE_FLAG_ASCII,
64    "t": SRE_FLAG_TEMPLATE,
65    "u": SRE_FLAG_UNICODE,
66}
67
68TYPE_FLAGS = SRE_FLAG_ASCII | SRE_FLAG_LOCALE | SRE_FLAG_UNICODE
69GLOBAL_FLAGS = SRE_FLAG_DEBUG | SRE_FLAG_TEMPLATE
70
class State:
    """Bookkeeping shared by every part of one parse.

    Attributes:
        flags: the SRE flag bits currently in effect.
        groupdict: group name -> group number.
        groupwidths: per group number, the (min, max) width of the group's
            body once closed, or None while the group is still open.
            Slot 0 stands for group 0.
        lookbehindgroups: None outside lookbehinds; otherwise the group
            count at the moment the outermost enclosing lookbehind began.
        grouprefpos: group number -> position of the first conditional
            reference to that number.
    """

    def __init__(self):
        self.flags = 0
        self.groupdict = {}
        self.groupwidths = [None]  # group 0
        self.lookbehindgroups = None
        self.grouprefpos = {}

    @property
    def groups(self):
        """Number of group slots allocated so far, including group 0."""
        return len(self.groupwidths)

    def opengroup(self, name=None):
        """Allocate the next group number, optionally bound to *name*.

        Raises error if the group limit is exceeded or *name* is already
        used by another group.
        """
        gid = self.groups
        self.groupwidths.append(None)  # None marks the group as still open
        if self.groups > MAXGROUPS:
            raise error("too many groups")
        if name is None:
            return gid
        previous = self.groupdict.get(name)
        if previous is not None:
            raise error("redefinition of group name %r as group %d; "
                        "was group %d" % (name, gid, previous))
        self.groupdict[name] = gid
        return gid

    def closegroup(self, gid, p):
        """Close group *gid*, recording the width of its body *p*."""
        self.groupwidths[gid] = p.getwidth()

    def checkgroup(self, gid):
        """Return True if group *gid* exists and has been closed."""
        if gid >= self.groups:
            return False
        return self.groupwidths[gid] is not None

    def checklookbehindgroup(self, gid, source):
        """Inside a lookbehind, reject references that cannot be resolved.

        A reference is rejected when the group is still open or was
        opened inside the current lookbehind.  This is a no-op outside a
        lookbehind.
        """
        if self.lookbehindgroups is None:
            return
        if not self.checkgroup(gid):
            raise source.error('cannot refer to an open group')
        if gid >= self.lookbehindgroups:
            raise source.error('cannot refer to group defined in the same '
                               'lookbehind subpattern')
106
class SubPattern:
    """A parsed (sub)pattern in the parser's intermediate form.

    ``data`` is a list of ``(opcode, operand)`` tuples, ``state`` is the
    shared parser State (its ``groupwidths`` are read for group
    references), and ``width`` caches the result of ``getwidth()``.
    """

    def __init__(self, state, data=None):
        self.state = state
        # A caller-supplied list is used as-is, not copied.
        self.data = [] if data is None else data
        self.width = None

    def dump(self, level=0):
        """Print an indented, human-readable tree of this subpattern."""
        indent = level * "  "
        child_indent = (level + 1) * "  "
        for opcode, operand in self.data:
            print(indent + str(opcode), end='')
            if opcode is IN:
                # Set members are listed one per line beneath IN.
                print()
                for member_op, member_arg in operand:
                    print(child_indent + str(member_op), member_arg)
            elif opcode is BRANCH:
                print()
                for position, alternative in enumerate(operand[1]):
                    if position:
                        print(indent + "OR")
                    alternative.dump(level + 1)
            elif opcode is GROUPREF_EXISTS:
                condgroup, item_yes, item_no = operand
                print('', condgroup)
                item_yes.dump(level + 1)
                if item_no:
                    print(indent + "ELSE")
                    item_no.dump(level + 1)
            elif isinstance(operand, (tuple, list)):
                # Scalars print inline separated by spaces; nested
                # subpatterns start on a fresh line.  at_line_start records
                # whether output is currently at the beginning of a line.
                at_line_start = False
                for element in operand:
                    if isinstance(element, SubPattern):
                        if not at_line_start:
                            print()
                        element.dump(level + 1)
                        at_line_start = True
                    else:
                        if not at_line_start:
                            print(' ', end='')
                        print(element, end='')
                        at_line_start = False
                if not at_line_start:
                    print()
            else:
                print('', operand)

    def __repr__(self):
        return repr(self.data)

    def __len__(self):
        return len(self.data)

    def __delitem__(self, index):
        del self.data[index]

    def __getitem__(self, index):
        # A slice yields a new SubPattern sharing this parser state; a plain
        # index yields the raw (opcode, operand) tuple.
        if not isinstance(index, slice):
            return self.data[index]
        return SubPattern(self.state, self.data[index])

    def __setitem__(self, index, code):
        self.data[index] = code

    def insert(self, index, code):
        self.data.insert(index, code)

    def append(self, code):
        self.data.append(code)

    def getwidth(self):
        """Return ``(min, max)``: how many characters this can match.

        The result is cached in ``self.width``.  ``min`` is clamped to
        ``MAXREPEAT - 1`` and ``max`` to ``MAXREPEAT``.
        """
        if self.width is not None:
            return self.width
        lo = hi = 0
        for op, av in self.data:
            if op is BRANCH:
                # Narrowest alternative gives the minimum, widest the maximum.
                lows = [MAXREPEAT - 1]
                highs = [0]
                for alternative in av[1]:
                    alt_lo, alt_hi = alternative.getwidth()
                    lows.append(alt_lo)
                    highs.append(alt_hi)
                step_lo, step_hi = min(lows), max(highs)
            elif op is ATOMIC_GROUP:
                step_lo, step_hi = av.getwidth()
            elif op is SUBPATTERN:
                step_lo, step_hi = av[-1].getwidth()
            elif op in _REPEATCODES:
                # av is (min_count, max_count, item).
                item_lo, item_hi = av[2].getwidth()
                step_lo, step_hi = item_lo * av[0], item_hi * av[1]
            elif op in _UNITCODES:
                step_lo = step_hi = 1
            elif op is GROUPREF:
                step_lo, step_hi = self.state.groupwidths[av]
            elif op is GROUPREF_EXISTS:
                step_lo, step_hi = av[1].getwidth()
                if av[2] is None:
                    # With no "no" branch, the construct may match nothing.
                    step_lo = 0
                else:
                    no_lo, no_hi = av[2].getwidth()
                    step_lo = min(step_lo, no_lo)
                    step_hi = max(step_hi, no_hi)
            elif op is SUCCESS:
                break
            else:
                # Any other opcode contributes no width here.
                continue
            lo += step_lo
            hi += step_hi
        self.width = min(lo, MAXREPEAT - 1), min(hi, MAXREPEAT)
        return self.width
220
class Tokenizer:
    """Cursor over a pattern that hands out one token at a time.

    A token is either a single character, or a backslash together with
    the character after it.  ``next`` holds the pending token (None at the
    end of the pattern) and ``index`` is the offset just past it.  Bytes
    patterns are decoded as Latin-1 so they can be scanned as characters.
    """

    def __init__(self, string):
        self.istext = isinstance(string, str)
        self.string = string
        self.decoded_string = string if self.istext else str(string, 'latin1')
        self.index = 0
        self.next = None
        self.__next()

    def __next(self):
        # Load the token starting at self.index into self.next and advance
        # self.index past it; at the end of input self.next becomes None.
        start = self.index
        try:
            token = self.decoded_string[start]
        except IndexError:
            self.next = None
            return
        if token != "\\":
            self.index = start + 1
            self.next = token
            return
        # A backslash always pairs with the character that follows it.
        try:
            token += self.decoded_string[start + 1]
        except IndexError:
            raise error("bad escape (end of pattern)",
                        self.string, len(self.string) - 1) from None
        self.index = start + 2
        self.next = token

    def match(self, char):
        """Consume the pending token if it equals *char*; report success."""
        if char != self.next:
            return False
        self.__next()
        return True

    def get(self):
        """Consume and return the pending token (None at end of pattern)."""
        token = self.next
        self.__next()
        return token

    def getwhile(self, n, charset):
        """Consume up to *n* tokens that belong to *charset*; return them joined."""
        consumed = ''
        budget = n
        while budget > 0 and self.next in charset:
            consumed += self.next
            self.__next()
            budget -= 1
        return consumed

    def getuntil(self, terminator, name):
        """Consume tokens through *terminator* and return those before it.

        Raises an error if the pattern ends first or the collected text is
        empty; *name* describes the expected item in the messages.
        """
        collected = ''
        while True:
            token = self.get()
            if token is None:
                if collected:
                    raise self.error("missing %s, unterminated name" % terminator,
                                     len(collected))
                raise self.error("missing " + name)
            if token == terminator:
                if not collected:
                    raise self.error("missing " + name, 1)
                return collected
            collected += token

    @property
    def pos(self):
        return self.tell()

    def tell(self):
        """Return the offset at which the pending token starts."""
        pending = self.next or ''
        return self.index - len(pending)

    def seek(self, index):
        """Reposition at *index* and load the token found there."""
        self.index = index
        self.__next()

    def error(self, msg, offset=0):
        """Build (but do not raise) an error located *offset* before tell()."""
        if not self.istext:
            # Messages about bytes patterns are kept ASCII-only.
            msg = msg.encode('ascii', 'backslashreplace').decode('ascii')
        return error(msg, self.string, self.tell() - offset)

    def checkgroupname(self, name, offset, nested):
        """Validate a group *name*.

        Raises an error for non-identifiers.  Emits a DeprecationWarning
        for non-ASCII names in bytes patterns.
        """
        if not name.isidentifier():
            raise self.error("bad character in group name %r" % name,
                             len(name) + offset)
        if self.istext or name.isascii():
            return
        import warnings
        warnings.warn(
            "bad character in group name %a at position %d" %
            (name, self.tell() - len(name) - offset),
            DeprecationWarning, stacklevel=nested + 7
        )
306
def _class_escape(source, escape):
    # Translate an escape token that appears inside a [...] set into an
    # opcode tuple.  *escape* is the backslash-plus-character token from the
    # tokenizer; numeric escapes read their remaining digits from *source*.
    # Raises source.error for malformed or unknown escapes.
    code = ESCAPES.get(escape)
    if code:
        # Simple escapes; \b is a backspace here, not a word boundary.
        return code
    code = CATEGORIES.get(escape)
    if code and code[0] is IN:
        # Only set-valued categories (\d \D \s \S \w \W) are accepted here.
        return code
    kind = escape[1:2]
    try:
        width = {"x": 2, "u": 4, "U": 8}.get(kind)
        if width is not None and (kind == "x" or source.istext):
            # Fixed-width hex escape: \xhh always; \uhhhh and \Uhhhhhhhh
            # only in str patterns.
            escape += source.getwhile(width, HEXDIGITS)
            if len(escape) != width + 2:
                raise source.error("incomplete escape %s" % escape, len(escape))
            value = int(escape[2:], 16)
            chr(value)  # ValueError (reported as "bad escape") past U+10FFFF
            return LITERAL, value
        if kind == "N" and source.istext:
            # Named character, e.g. \N{EM DASH}.
            import unicodedata
            if not source.match('{'):
                raise source.error("missing {")
            charname = source.getuntil('}', 'character name')
            try:
                value = ord(unicodedata.lookup(charname))
            except (KeyError, TypeError):
                raise source.error("undefined character name %r" % charname,
                                   len(charname) + len(r'\N{}')) from None
            return LITERAL, value
        if kind in OCTDIGITS:
            # Octal escape of up to three digits, at most 0o377.
            escape += source.getwhile(2, OCTDIGITS)
            value = int(escape[1:], 8)
            if value > 0o377:
                raise source.error('octal escape value %s outside of '
                                   'range 0-0o377' % escape, len(escape))
            return LITERAL, value
        if kind in DIGITS:
            # \8 and \9 are rejected inside a set.
            raise ValueError
        if len(escape) == 2:
            # Unknown ASCII-letter escapes are rejected; any other escaped
            # character stands for itself.
            if kind in ASCIILETTERS:
                raise source.error('bad escape %s' % escape, len(escape))
            return LITERAL, ord(escape[1])
    except ValueError:
        pass
    raise source.error("bad escape %s" % escape, len(escape))
366
def _escape(source, escape, state):
    # Translate an escape token that appears outside a character set into
    # an opcode tuple.  Unlike _class_escape:
    #   - CATEGORIES is consulted first, so anchors (\A \b \B \Z) are
    #     accepted and \b means a word boundary;
    #   - \1-\99 may be group references, validated against *state*.
    # Raises source.error for malformed or unknown escapes.
    code = CATEGORIES.get(escape)
    if code:
        return code
    code = ESCAPES.get(escape)
    if code:
        return code
    kind = escape[1:2]
    try:
        width = {"x": 2, "u": 4, "U": 8}.get(kind)
        if width is not None and (kind == "x" or source.istext):
            # Fixed-width hex escape: \xhh always; \uhhhh and \Uhhhhhhhh
            # only in str patterns.
            escape += source.getwhile(width, HEXDIGITS)
            if len(escape) != width + 2:
                raise source.error("incomplete escape %s" % escape, len(escape))
            value = int(escape[2:], 16)
            chr(value)  # ValueError (reported as "bad escape") past U+10FFFF
            return LITERAL, value
        if kind == "N" and source.istext:
            # Named character, e.g. \N{EM DASH}.
            import unicodedata
            if not source.match('{'):
                raise source.error("missing {")
            charname = source.getuntil('}', 'character name')
            try:
                value = ord(unicodedata.lookup(charname))
            except (KeyError, TypeError):
                raise source.error("undefined character name %r" % charname,
                                   len(charname) + len(r'\N{}')) from None
            return LITERAL, value
        if kind == "0":
            # \0 plus up to two more octal digits is always an octal escape.
            escape += source.getwhile(2, OCTDIGITS)
            return LITERAL, int(escape[1:], 8)
        if kind in DIGITS:
            # Either a three-digit octal escape or a group reference.
            if source.next in DIGITS:
                escape += source.get()
                if (escape[1] in OCTDIGITS and escape[2] in OCTDIGITS and
                        source.next in OCTDIGITS):
                    escape += source.get()
                    value = int(escape[1:], 8)
                    if value > 0o377:
                        raise source.error('octal escape value %s outside of '
                                           'range 0-0o377' % escape,
                                           len(escape))
                    return LITERAL, value
            group = int(escape[1:])
            if group >= state.groups:
                raise source.error("invalid group reference %d" % group,
                                   len(escape) - 1)
            if not state.checkgroup(group):
                raise source.error("cannot refer to an open group",
                                   len(escape))
            state.checklookbehindgroup(group, source)
            return GROUPREF, group
        if len(escape) == 2:
            # Unknown ASCII-letter escapes are rejected; any other escaped
            # character stands for itself.
            if kind in ASCIILETTERS:
                raise source.error("bad escape %s" % escape, len(escape))
            return LITERAL, ord(escape[1])
    except ValueError:
        pass
    raise source.error("bad escape %s" % escape, len(escape))
443
444def _uniq(items):
445    return list(dict.fromkeys(items))
446
def _parse_sub(source, state, verbose, nested):
    """Parse an alternation ``a|b|c`` and return it as one SubPattern.

    Alternatives are parsed with ``_parse`` until no further ``|`` follows;
    the token that ends the alternation (``)`` or end of pattern) is not
    consumed here.  A single alternative is returned unchanged.  Otherwise:
      - opcodes shared at the front of every alternative are hoisted out;
      - if every remaining alternative is one literal or one non-negated
        set, the whole branch collapses into a single ``IN`` set;
      - failing that, a ``BRANCH`` node is emitted.

    *nested* is the current nesting depth (0 at the top level).
    """
    items = []
    while True:
        # The last argument is true only for the first alternative of the
        # top-level pattern; presumably _parse uses it to permit global
        # inline flags there (that handling is outside this chunk).
        items.append(_parse(source, state, verbose, nested + 1,
                            not nested and not items))
        if not source.match("|"):
            break
        if not nested:
            # At top level, re-read VERBOSE from state.flags so flags set
            # while parsing earlier alternatives apply to the following ones.
            verbose = state.flags & SRE_FLAG_VERBOSE

    if len(items) == 1:
        return items[0]

    subpattern = SubPattern(state)
    _hoist_common_prefix(items, subpattern)

    members = _branch_as_charset(items)
    if members is not None:
        # Store as a character set instead of a branch (the compiler may
        # optimize this even more).
        subpattern.append((IN, members))
    else:
        subpattern.append((BRANCH, (None, items)))
    return subpattern


def _hoist_common_prefix(items, subpattern):
    """Move opcodes that begin every alternative in *items* onto *subpattern*.

    Repeats while every alternative is non-empty and starts with an equal
    ``(opcode, operand)`` tuple; *items* are modified in place.
    """
    while all(items) and all(item[0] == items[0][0] for item in items[1:]):
        prefix = items[0][0]
        for item in items:
            del item[0]
        subpattern.append(prefix)


def _branch_as_charset(items):
    """Return merged set members if the alternatives can form one ``IN``.

    That is possible when every alternative is exactly one ``LITERAL`` or
    exactly one non-negated ``IN``; otherwise return None.
    """
    members = []
    for item in items:
        if len(item) != 1:
            return None
        op, av = item[0]
        if op is LITERAL:
            members.append((op, av))
        elif op is IN and av[0][0] is not NEGATE:
            members.extend(av)
        else:
            return None
    return _uniq(members)
506
507def _parse(source, state, verbose, nested, first=False):
508    # parse a simple pattern
509    subpattern = SubPattern(state)
510
511    # precompute constants into local variables
512    subpatternappend = subpattern.append
513    sourceget = source.get
514    sourcematch = source.match
515    _len = len
516    _ord = ord
517
518    while True:
519
520        this = source.next
521        if this is None:
522            break # end of pattern
523        if this in "|)":
524            break # end of subpattern
525        sourceget()
526
527        if verbose:
528            # skip whitespace and comments
529            if this in WHITESPACE:
530                continue
531            if this == "#":
532                while True:
533                    this = sourceget()
534                    if this is None or this == "\n":
535                        break
536                continue
537
538        if this[0] == "\\":
539            code = _escape(source, this, state)
540            subpatternappend(code)
541
542        elif this not in SPECIAL_CHARS:
543            subpatternappend((LITERAL, _ord(this)))
544
545        elif this == "[":
546            here = source.tell() - 1
547            # character set
548            set = []
549            setappend = set.append
550##          if sourcematch(":"):
551##              pass # handle character classes
552            if source.next == '[':
553                import warnings
554                warnings.warn(
555                    'Possible nested set at position %d' % source.tell(),
556                    FutureWarning, stacklevel=nested + 6
557                )
558            negate = sourcematch("^")
559            # check remaining characters
560            while True:
561                this = sourceget()
562                if this is None:
563                    raise source.error("unterminated character set",
564                                       source.tell() - here)
565                if this == "]" and set:
566                    break
567                elif this[0] == "\\":
568                    code1 = _class_escape(source, this)
569                else:
570                    if set and this in '-&~|' and source.next == this:
571                        import warnings
572                        warnings.warn(
573                            'Possible set %s at position %d' % (
574                                'difference' if this == '-' else
575                                'intersection' if this == '&' else
576                                'symmetric difference' if this == '~' else
577                                'union',
578                                source.tell() - 1),
579                            FutureWarning, stacklevel=nested + 6
580                        )
581                    code1 = LITERAL, _ord(this)
582                if sourcematch("-"):
583                    # potential range
584                    that = sourceget()
585                    if that is None:
586                        raise source.error("unterminated character set",
587                                           source.tell() - here)
588                    if that == "]":
589                        if code1[0] is IN:
590                            code1 = code1[1][0]
591                        setappend(code1)
592                        setappend((LITERAL, _ord("-")))
593                        break
594                    if that[0] == "\\":
595                        code2 = _class_escape(source, that)
596                    else:
597                        if that == '-':
598                            import warnings
599                            warnings.warn(
600                                'Possible set difference at position %d' % (
601                                    source.tell() - 2),
602                                FutureWarning, stacklevel=nested + 6
603                            )
604                        code2 = LITERAL, _ord(that)
605                    if code1[0] != LITERAL or code2[0] != LITERAL:
606                        msg = "bad character range %s-%s" % (this, that)
607                        raise source.error(msg, len(this) + 1 + len(that))
608                    lo = code1[1]
609                    hi = code2[1]
610                    if hi < lo:
611                        msg = "bad character range %s-%s" % (this, that)
612                        raise source.error(msg, len(this) + 1 + len(that))
613                    setappend((RANGE, (lo, hi)))
614                else:
615                    if code1[0] is IN:
616                        code1 = code1[1][0]
617                    setappend(code1)
618
619            set = _uniq(set)
620            # XXX: <fl> should move set optimization to compiler!
621            if _len(set) == 1 and set[0][0] is LITERAL:
622                # optimization
623                if negate:
624                    subpatternappend((NOT_LITERAL, set[0][1]))
625                else:
626                    subpatternappend(set[0])
627            else:
628                if negate:
629                    set.insert(0, (NEGATE, None))
630                # charmap optimization can't be added here because
631                # global flags still are not known
632                subpatternappend((IN, set))
633
634        elif this in REPEAT_CHARS:
635            # repeat previous item
636            here = source.tell()
637            if this == "?":
638                min, max = 0, 1
639            elif this == "*":
640                min, max = 0, MAXREPEAT
641
642            elif this == "+":
643                min, max = 1, MAXREPEAT
644            elif this == "{":
645                if source.next == "}":
646                    subpatternappend((LITERAL, _ord(this)))
647                    continue
648
649                min, max = 0, MAXREPEAT
650                lo = hi = ""
651                while source.next in DIGITS:
652                    lo += sourceget()
653                if sourcematch(","):
654                    while source.next in DIGITS:
655                        hi += sourceget()
656                else:
657                    hi = lo
658                if not sourcematch("}"):
659                    subpatternappend((LITERAL, _ord(this)))
660                    source.seek(here)
661                    continue
662
663                if lo:
664                    min = int(lo)
665                    if min >= MAXREPEAT:
666                        raise OverflowError("the repetition number is too large")
667                if hi:
668                    max = int(hi)
669                    if max >= MAXREPEAT:
670                        raise OverflowError("the repetition number is too large")
671                    if max < min:
672                        raise source.error("min repeat greater than max repeat",
673                                           source.tell() - here)
674            else:
675                raise AssertionError("unsupported quantifier %r" % (char,))
676            # figure out which item to repeat
677            if subpattern:
678                item = subpattern[-1:]
679            else:
680                item = None
681            if not item or item[0][0] is AT:
682                raise source.error("nothing to repeat",
683                                   source.tell() - here + len(this))
684            if item[0][0] in _REPEATCODES:
685                raise source.error("multiple repeat",
686                                   source.tell() - here + len(this))
687            if item[0][0] is SUBPATTERN:
688                group, add_flags, del_flags, p = item[0][1]
689                if group is None and not add_flags and not del_flags:
690                    item = p
691            if sourcematch("?"):
692                # Non-Greedy Match
693                subpattern[-1] = (MIN_REPEAT, (min, max, item))
694            elif sourcematch("+"):
695                # Possessive Match (Always Greedy)
696                subpattern[-1] = (POSSESSIVE_REPEAT, (min, max, item))
697            else:
698                # Greedy Match
699                subpattern[-1] = (MAX_REPEAT, (min, max, item))
700
701        elif this == ".":
702            subpatternappend((ANY, None))
703
704        elif this == "(":
705            start = source.tell() - 1
706            capture = True
707            atomic = False
708            name = None
709            add_flags = 0
710            del_flags = 0
711            if sourcematch("?"):
712                # options
713                char = sourceget()
714                if char is None:
715                    raise source.error("unexpected end of pattern")
716                if char == "P":
717                    # python extensions
718                    if sourcematch("<"):
719                        # named group: skip forward to end of name
720                        name = source.getuntil(">", "group name")
721                        source.checkgroupname(name, 1, nested)
722                    elif sourcematch("="):
723                        # named backreference
724                        name = source.getuntil(")", "group name")
725                        source.checkgroupname(name, 1, nested)
726                        gid = state.groupdict.get(name)
727                        if gid is None:
728                            msg = "unknown group name %r" % name
729                            raise source.error(msg, len(name) + 1)
730                        if not state.checkgroup(gid):
731                            raise source.error("cannot refer to an open group",
732                                               len(name) + 1)
733                        state.checklookbehindgroup(gid, source)
734                        subpatternappend((GROUPREF, gid))
735                        continue
736
737                    else:
738                        char = sourceget()
739                        if char is None:
740                            raise source.error("unexpected end of pattern")
741                        raise source.error("unknown extension ?P" + char,
742                                           len(char) + 2)
743                elif char == ":":
744                    # non-capturing group
745                    capture = False
746                elif char == "#":
747                    # comment
748                    while True:
749                        if source.next is None:
750                            raise source.error("missing ), unterminated comment",
751                                               source.tell() - start)
752                        if sourceget() == ")":
753                            break
754                    continue
755
756                elif char in "=!<":
757                    # lookahead assertions
758                    dir = 1
759                    if char == "<":
760                        char = sourceget()
761                        if char is None:
762                            raise source.error("unexpected end of pattern")
763                        if char not in "=!":
764                            raise source.error("unknown extension ?<" + char,
765                                               len(char) + 2)
766                        dir = -1 # lookbehind
767                        lookbehindgroups = state.lookbehindgroups
768                        if lookbehindgroups is None:
769                            state.lookbehindgroups = state.groups
770                    p = _parse_sub(source, state, verbose, nested + 1)
771                    if dir < 0:
772                        if lookbehindgroups is None:
773                            state.lookbehindgroups = None
774                    if not sourcematch(")"):
775                        raise source.error("missing ), unterminated subpattern",
776                                           source.tell() - start)
777                    if char == "=":
778                        subpatternappend((ASSERT, (dir, p)))
779                    else:
780                        subpatternappend((ASSERT_NOT, (dir, p)))
781                    continue
782
783                elif char == "(":
784                    # conditional backreference group
785                    condname = source.getuntil(")", "group name")
786                    if condname.isidentifier():
787                        source.checkgroupname(condname, 1, nested)
788                        condgroup = state.groupdict.get(condname)
789                        if condgroup is None:
790                            msg = "unknown group name %r" % condname
791                            raise source.error(msg, len(condname) + 1)
792                    else:
793                        try:
794                            condgroup = int(condname)
795                            if condgroup < 0:
796                                raise ValueError
797                        except ValueError:
798                            msg = "bad character in group name %r" % condname
799                            raise source.error(msg, len(condname) + 1) from None
800                        if not condgroup:
801                            raise source.error("bad group number",
802                                               len(condname) + 1)
803                        if condgroup >= MAXGROUPS:
804                            msg = "invalid group reference %d" % condgroup
805                            raise source.error(msg, len(condname) + 1)
806                        if condgroup not in state.grouprefpos:
807                            state.grouprefpos[condgroup] = (
808                                source.tell() - len(condname) - 1
809                            )
810                        if not (condname.isdecimal() and condname.isascii()):
811                            import warnings
812                            warnings.warn(
813                                "bad character in group name %s at position %d" %
814                                (repr(condname) if source.istext else ascii(condname),
815                                 source.tell() - len(condname) - 1),
816                                DeprecationWarning, stacklevel=nested + 6
817                            )
818                    state.checklookbehindgroup(condgroup, source)
819                    item_yes = _parse(source, state, verbose, nested + 1)
820                    if source.match("|"):
821                        item_no = _parse(source, state, verbose, nested + 1)
822                        if source.next == "|":
823                            raise source.error("conditional backref with more than two branches")
824                    else:
825                        item_no = None
826                    if not source.match(")"):
827                        raise source.error("missing ), unterminated subpattern",
828                                           source.tell() - start)
829                    subpatternappend((GROUPREF_EXISTS, (condgroup, item_yes, item_no)))
830                    continue
831
832                elif char == ">":
833                    # non-capturing, atomic group
834                    capture = False
835                    atomic = True
836                elif char in FLAGS or char == "-":
837                    # flags
838                    flags = _parse_flags(source, state, char)
839                    if flags is None:  # global flags
840                        if not first or subpattern:
841                            raise source.error('global flags not at the start '
842                                               'of the expression',
843                                               source.tell() - start)
844                        verbose = state.flags & SRE_FLAG_VERBOSE
845                        continue
846
847                    add_flags, del_flags = flags
848                    capture = False
849                else:
850                    raise source.error("unknown extension ?" + char,
851                                       len(char) + 1)
852
853            # parse group contents
854            if capture:
855                try:
856                    group = state.opengroup(name)
857                except error as err:
858                    raise source.error(err.msg, len(name) + 1) from None
859            else:
860                group = None
861            sub_verbose = ((verbose or (add_flags & SRE_FLAG_VERBOSE)) and
862                           not (del_flags & SRE_FLAG_VERBOSE))
863            p = _parse_sub(source, state, sub_verbose, nested + 1)
864            if not source.match(")"):
865                raise source.error("missing ), unterminated subpattern",
866                                   source.tell() - start)
867            if group is not None:
868                state.closegroup(group, p)
869            if atomic:
870                assert group is None
871                subpatternappend((ATOMIC_GROUP, p))
872            else:
873                subpatternappend((SUBPATTERN, (group, add_flags, del_flags, p)))
874
875        elif this == "^":
876            subpatternappend((AT, AT_BEGINNING))
877
878        elif this == "$":
879            subpatternappend((AT, AT_END))
880
881        else:
882            raise AssertionError("unsupported special character %r" % (char,))
883
884    # unpack non-capturing groups
885    for i in range(len(subpattern))[::-1]:
886        op, av = subpattern[i]
887        if op is SUBPATTERN:
888            group, add_flags, del_flags, p = av
889            if group is None and not add_flags and not del_flags:
890                subpattern[i: i+1] = p
891
892    return subpattern
893
def _parse_flags(source, state, char):
    """Parse the inline-flag part of a "(?...)" construct.

    *char* is the first character after "(?", already consumed by the
    caller; it is either a flag letter from FLAGS or "-".  Flag letters to
    turn on are read until one of ")", "-" or ":" is reached.

    Returns None for the global form "(?flags)": in that case the flags are
    OR-ed into state.flags here.  Otherwise returns an (add_flags, del_flags)
    pair for the scoped form "(?flags:...)" / "(?flags-flags:...)", with the
    terminating ":" already consumed.

    Raises source.error on: a premature end of the pattern, an unknown
    flag letter, 'L' in a str pattern or 'u' in a bytes pattern, two
    different type flags (TYPE_FLAGS) turned on together, an attempt to
    turn a type flag off, a GLOBAL_FLAGS flag used in the scoped form, or a
    flag that is both turned on and off.
    """
    sourceget = source.get
    add_flags = 0
    del_flags = 0
    if char != "-":
        # Flags to turn on: consume letters until ")", "-" or ":".
        while True:
            flag = FLAGS[char]
            if source.istext:
                if char == 'L':
                    msg = "bad inline flags: cannot use 'L' flag with a str pattern"
                    raise source.error(msg)
            else:
                if char == 'u':
                    msg = "bad inline flags: cannot use 'u' flag with a bytes pattern"
                    raise source.error(msg)
            add_flags |= flag
            # At most one distinct type flag may be on.  Repeating the same
            # one keeps (add_flags & TYPE_FLAGS) == flag and is accepted.
            if (flag & TYPE_FLAGS) and (add_flags & TYPE_FLAGS) != flag:
                msg = "bad inline flags: flags 'a', 'u' and 'L' are incompatible"
                raise source.error(msg)
            char = sourceget()
            if char is None:
                raise source.error("missing -, : or )")
            if char in ")-:":
                break
            if char not in FLAGS:
                msg = "unknown flag" if char.isalpha() else "missing -, : or )"
                raise source.error(msg, len(char))
    if char == ")":
        # Global form "(?flags)": apply the flags to the whole pattern.
        state.flags |= add_flags
        return None
    # From here on this is the scoped form, where global flags are illegal.
    if add_flags & GLOBAL_FLAGS:
        raise source.error("bad inline flags: cannot turn on global flag", 1)
    if char == "-":
        # Flags to turn off: at least one letter is required, ended by ":".
        char = sourceget()
        if char is None:
            raise source.error("missing flag")
        if char not in FLAGS:
            msg = "unknown flag" if char.isalpha() else "missing flag"
            raise source.error(msg, len(char))
        while True:
            flag = FLAGS[char]
            if flag & TYPE_FLAGS:
                msg = "bad inline flags: cannot turn off flags 'a', 'u' and 'L'"
                raise source.error(msg)
            del_flags |= flag
            char = sourceget()
            if char is None:
                raise source.error("missing :")
            if char == ":":
                break
            if char not in FLAGS:
                msg = "unknown flag" if char.isalpha() else "missing :"
                raise source.error(msg, len(char))
    # Invariant: both loops above can only reach this point having read ":".
    assert char == ":"
    if del_flags & GLOBAL_FLAGS:
        raise source.error("bad inline flags: cannot turn off global flag", 1)
    if add_flags & del_flags:
        raise source.error("bad inline flags: flag turned on and off", 1)
    return add_flags, del_flags
953
954def fix_flags(src, flags):
955    # Check and fix flags according to the type of pattern (str or bytes)
956    if isinstance(src, str):
957        if flags & SRE_FLAG_LOCALE:
958            raise ValueError("cannot use LOCALE flag with a str pattern")
959        if not flags & SRE_FLAG_ASCII:
960            flags |= SRE_FLAG_UNICODE
961        elif flags & SRE_FLAG_UNICODE:
962            raise ValueError("ASCII and UNICODE flags are incompatible")
963    else:
964        if flags & SRE_FLAG_UNICODE:
965            raise ValueError("cannot use UNICODE flag with a bytes pattern")
966        if flags & SRE_FLAG_LOCALE and flags & SRE_FLAG_ASCII:
967            raise ValueError("ASCII and LOCALE flags are incompatible")
968    return flags
969
970def parse(str, flags=0, state=None):
971    # parse 're' pattern into list of (opcode, argument) tuples
972
973    source = Tokenizer(str)
974
975    if state is None:
976        state = State()
977    state.flags = flags
978    state.str = str
979
980    p = _parse_sub(source, state, flags & SRE_FLAG_VERBOSE, 0)
981    p.state.flags = fix_flags(str, p.state.flags)
982
983    if source.next is not None:
984        assert source.next == ")"
985        raise source.error("unbalanced parenthesis")
986
987    for g in p.state.grouprefpos:
988        if g >= p.state.groups:
989            msg = "invalid group reference %d" % g
990            raise error(msg, str, p.state.grouprefpos[g])
991
992    if flags & SRE_FLAG_DEBUG:
993        p.dump()
994
995    return p
996
def parse_template(source, state):
    r"""Parse a replacement template into group references and literals.

    *source* is the replacement string (str, or bytes: see the re-encoding
    at the end).  *state* is used only for its ``groups`` and
    ``groupindex`` attributes.  NOTE(review): presumably this is the
    compiled pattern the template will be applied to, so ``groups`` counts
    its capturing groups; confirm against the caller.

    Returns ``(groups, literals)``.  ``literals`` is a list of literal
    strings with a ``None`` placeholder wherever group text will be
    substituted.  ``groups`` is a list of ``(slot, group)`` pairs giving the
    index of each placeholder in ``literals`` and the group number to fill
    it with.

    Raises the tokenizer's error (via s.error) for a malformed \g<...>,
    an invalid group number, an octal escape above 0o377, or an unknown
    escape of an ASCII letter.  Raises IndexError for an unknown \g<name>
    group name.  Issues a DeprecationWarning for a numeric group name that
    int() accepts but that is not plain ASCII decimal digits.
    """
    s = Tokenizer(source)
    sget = s.get
    groups = []
    literals = []
    literal = []  # characters of the literal run currently being built
    lappend = literal.append
    # Close the current literal run (if any) and append a placeholder slot
    # for group `index`; `pos` is the error-position offset passed to s.error.
    def addgroup(index, pos):
        if index > state.groups:
            raise s.error("invalid group reference %d" % index, pos)
        if literal:
            literals.append(''.join(literal))
            del literal[:]
        groups.append((len(literals), index))
        literals.append(None)
    groupindex = state.groupindex
    while True:
        this = sget()
        if this is None:
            break # end of replacement string
        if this[0] == "\\":
            # Escape token (backslash plus the escaped character): a group
            # reference, an octal escape, or a character escape.
            c = this[1]
            if c == "g":
                # \g<name> or \g<number>
                if not s.match("<"):
                    raise s.error("missing <")
                name = s.getuntil(">", "group name")
                if name.isidentifier():
                    s.checkgroupname(name, 1, -1)
                    try:
                        index = groupindex[name]
                    except KeyError:
                        raise IndexError("unknown group name %r" % name) from None
                else:
                    # Otherwise the name must be a non-negative integer.
                    try:
                        index = int(name)
                        if index < 0:
                            raise ValueError
                    except ValueError:
                        raise s.error("bad character in group name %r" % name,
                                      len(name) + 1) from None
                    if index >= MAXGROUPS:
                        raise s.error("invalid group reference %d" % index,
                                      len(name) + 1)
                    # int() also accepts forms such as "+1" or non-ASCII
                    # digits; those are still allowed but deprecated.
                    if not (name.isdecimal() and name.isascii()):
                        import warnings
                        # NOTE(review): stacklevel=5 presumably targets the
                        # user's call site, so it depends on the call depth
                        # above this function; re-check it if that changes.
                        warnings.warn(
                            "bad character in group name %s at position %d" %
                            (repr(name) if s.istext else ascii(name),
                             s.tell() - len(name) - 1),
                            DeprecationWarning, stacklevel=5
                        )
                addgroup(index, len(name) + 1)
            elif c == "0":
                # \0 plus up to two more octal digits: an octal escape.
                if s.next in OCTDIGITS:
                    this += sget()
                    if s.next in OCTDIGITS:
                        this += sget()
                lappend(chr(int(this[1:], 8) & 0xff))
            elif c in DIGITS:
                # \1-\9, optionally followed by one more digit, is a group
                # reference (e.g. \12 is group 12).  Three octal digits form
                # an octal escape instead.
                isoctal = False
                if s.next in DIGITS:
                    this += sget()
                    if (c in OCTDIGITS and this[2] in OCTDIGITS and
                        s.next in OCTDIGITS):
                        this += sget()
                        isoctal = True
                        c = int(this[1:], 8)  # c is reused for the code point
                        if c > 0o377:
                            raise s.error('octal escape value %s outside of '
                                          'range 0-0o377' % this, len(this))
                        lappend(chr(c))
                if not isoctal:
                    addgroup(int(this[1:]), len(this) - 1)
            else:
                # Known character escapes (ESCAPES) become their character.
                # Unknown escapes of ASCII letters are errors.  Any other
                # unknown escape is kept verbatim, backslash included.
                try:
                    this = chr(ESCAPES[this][1])
                except KeyError:
                    if c in ASCIILETTERS:
                        raise s.error('bad escape %s' % this, len(this)) from None
                lappend(this)
        else:
            # plain character
            lappend(this)
    # Flush the trailing literal run.
    if literal:
        literals.append(''.join(literal))
    if not isinstance(source, str):
        # The tokenizer implicitly decodes bytes objects as latin-1, we must
        # therefore re-encode the final representation.
        # (`s` below is the comprehension's own variable, not the Tokenizer.)
        literals = [None if s is None else s.encode('latin-1') for s in literals]
    return groups, literals
1089
1090def expand_template(template, match):
1091    g = match.group
1092    empty = match.string[:0]
1093    groups, literals = template
1094    literals = literals[:]
1095    try:
1096        for index, group in groups:
1097            literals[index] = g(group) or empty
1098    except IndexError:
1099        raise error("invalid group reference %d" % index) from None
1100    return empty.join(literals)
1101