xref: /aosp_15_r20/external/pytorch/torchgen/code_template.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import re
4from typing import Mapping, Sequence
5
6
# Match $identifier or ${identifier} and replace it with its value from env.
# If the identifier is preceded only by whitespace on its line, it is treated
# as a block substitution: the value (a list, or a single value treated as a
# one-element list) is emitted one element per line, each line indented to the
# identifier's depth.
# Otherwise, a list value is inserted inline as a comma-separated list;
# ${,foo} inserts a comma before the list and ${foo,} inserts one after it,
# in both cases only when the list is non-empty.
15
16
class CodeTemplate:
    """A small text template engine used for code generation.

    Placeholders are written ``$name`` or ``${name}`` and are replaced by the
    value bound to ``name`` when :meth:`substitute` is called:

    * If the placeholder is preceded only by whitespace (possibly none) on its
      line, it is a *block* substitution: the value (a list, or a single value
      treated as a one-element list) is rendered one line per element, each
      line prefixed with that leading whitespace. An element whose ``str()``
      spans several lines contributes each of those lines.
    * Otherwise a list value is joined inline with ``", "``. The forms
      ``${,name}`` and ``${name,}`` add ``", "`` before / after the joined
      list, but only when the list is non-empty.
    * Any other value is inserted inline as ``str(value)``; comma markers in
      the braced form are ignored in that case.
    """

    # Group 1: optional run of non-newline whitespace anchored at the start of
    # a line (it only participates when the placeholder begins its line).
    # Group 2: the identifier, either bare or braced with an optional leading
    # and/or trailing comma marker.
    substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    substitution = re.compile(substitution_str, re.MULTILINE)

    pattern: str  # raw template text
    filename: str  # path the template was loaded from, or "" if constructed directly

    @staticmethod
    def from_file(filename: str) -> CodeTemplate:
        """Load a template from *filename*, decoding the file as UTF-8.

        The encoding is given explicitly so the result does not depend on the
        platform's locale-default encoding.
        """
        with open(filename, encoding="utf-8") as f:
            return CodeTemplate(f.read(), filename)

    def __init__(self, pattern: str, filename: str = "") -> None:
        """Wrap *pattern* as a template; *filename* is recorded for reference."""
        self.pattern = pattern
        self.filename = filename

    def substitute(
        self, env: Mapping[str, object] | None = None, **kwargs: object
    ) -> str:
        """Render the template.

        Args:
            env: optional mapping from placeholder names to values.
            **kwargs: additional bindings; a keyword argument takes precedence
                over an ``env`` entry with the same name.

        Returns:
            The template text with every placeholder replaced.

        Raises:
            KeyError: if a placeholder name is bound neither in ``kwargs`` nor
                in ``env``.
        """
        if env is None:
            env = {}

        def lookup(v: str) -> object:
            assert env is not None  # narrows the Optional for type checkers
            return kwargs[v] if v in kwargs else env[v]

        def indent_lines(indent: str, v: Sequence[object]) -> str:
            # One output line per line of each element, all prefixed with
            # `indent`. The final rstrip() drops the trailing newline (the
            # template line supplies its own) and any trailing whitespace.
            return "".join(
                [indent + line + "\n" for e in v for line in str(e).splitlines()]
            ).rstrip()

        def replace(match: re.Match[str]) -> str:
            indent = match.group(1)
            key = match.group(2)
            comma_before = ""
            comma_after = ""
            if key[0] == "{":
                # Braced form: strip the braces, then any comma markers.
                key = key[1:-1]
                if key[0] == ",":
                    comma_before = ", "
                    key = key[1:]
                if key[-1] == ",":
                    comma_after = ", "
                    key = key[:-1]
            v = lookup(key)
            if indent is not None:
                # Block substitution: the placeholder begins its line.
                if not isinstance(v, list):
                    v = [v]
                return indent_lines(indent, v)
            elif isinstance(v, list):
                # Inline list: the optional surrounding commas are emitted
                # only when there is something to join.
                if not v:
                    return ""
                middle = ", ".join([str(x) for x in v])
                return comma_before + middle + comma_after
            else:
                return str(v)

        return self.substitution.sub(replace, self.pattern)
75
76
if __name__ == "__main__":
    # Self-demo: renders a template exercising block substitution ($bar at the
    # start of a line), inline list joining ($args), scalar substitution
    # ($a, $b) and the ${,name} / ${name,} comma forms, then prints it.
    demo_template = CodeTemplate(
        """\
    int foo($args) {

        $bar
            $bar
        $a+$b
    }
    int commatest(int a${,stuff})
    int notest(int a${,empty,})
    """
    )
    demo_env = {
        "args": ["hi", 8],
        "bar": ["what", 7],
        "a": 3,
        "b": 4,
        "stuff": ["things...", "others"],
        "empty": [],
    }
    rendered = demo_template.substitute(demo_env)
    print(rendered)
99    )
100