1 #pragma once
2
#include <c10/util/irange.h>

#include <cctype>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
9
10 namespace at::jit {
11
// A template environment is a mapping from template variable names (e.g.,
// `identifier`, corresponding to $identifier) to their expansions.
//
// This template environment supports storing strings, numbers, and lists
// of strings, and can be chained together: lookup first proceeds in the
// top-level environment and then recurses into the parent environment if
// the key is not found.
struct TemplateEnv {
  TemplateEnv() = default;

  // Creates an empty environment chained to `parent`: lookups that miss here
  // fall through to `parent`.
  //
  // NOTE: because its parameter is `TemplateEnv&`, this is also the class's
  // copy constructor, so "copying" a TemplateEnv yields an empty child of the
  // source rather than a copy of its contents. Only a raw pointer to `parent`
  // is stored, so `parent` must outlive this environment.
  TemplateEnv(TemplateEnv& parent) : parent(&parent) {}
  TemplateEnv& operator=(const TemplateEnv& parent) = delete;

  using string_list = std::vector<std::string>;

  // Add a string 'v' to the map at key 'k' (replacing any list stored at 'k'
  // in this environment; a key holds either a string or a list, never both).
  void s(const std::string& k, const std::string& v) {
    strings_[k] = v;
    lists_.erase(k);
  }

  // Add a number 'v' to the map at key 'k', stored as its std::to_string
  // representation (replacing any list stored at 'k' in this environment).
  template <typename T>
  void d(const std::string& k, const T& v) {
    strings_[k] = std::to_string(v);
    lists_.erase(k);
  }

  // Retrieve the string stored at 'k', searching this environment first and
  // then the parent chain.
  // Throws std::logic_error("key not found: <k>") if no environment in the
  // chain holds a string at 'k'.
  const std::string& s(const std::string& k) const {
    // Single hash lookup instead of count() followed by at().
    const auto it = strings_.find(k);
    if (it != strings_.end()) {
      return it->second;
    }
    if (parent) {
      return parent->s(k);
    }
    notFound(k);
  }

  // Store a list of strings 'v' in the map at 'k' (replacing any string
  // stored at 'k' in this environment).
  void v(const std::string& k, const string_list& v) {
    lists_[k] = v;
    strings_.erase(k);
  }

  // Retrieve the list of strings stored at 'k', searching this environment
  // first and then the parent chain.
  // Throws std::logic_error("key not found: <k>") if no environment in the
  // chain holds a list at 'k'.
  const string_list& v(const std::string& k) const {
    const auto it = lists_.find(k);
    if (it != lists_.end()) {
      return it->second;
    }
    if (parent) {
      return parent->v(k);
    }
    notFound(k);
  }

  // Returns true if the nearest environment (this one first, then parents)
  // that defines 'k' holds a string, false if it holds a list.
  // Throws std::logic_error if 'k' is defined nowhere in the chain.
  bool keyIsString(const std::string& k) const {
    if (strings_.count(k) > 0) {
      return true;
    }
    if (lists_.count(k) > 0) {
      return false;
    }
    if (parent) {
      return parent->keyIsString(k);
    }
    notFound(k);
  }

 private:
  [[noreturn]] void notFound(const std::string& k) const {
    throw std::logic_error("key not found: " + k);
  }

  std::unordered_map<std::string, std::string> strings_;
  std::unordered_map<std::string, string_list> lists_;
  // Non-owning; nullptr for a root environment.
  TemplateEnv* parent{nullptr};
};
91
/*
# Match $identifier or ${identifier} and replace it with its value in env.
# If the identifier is preceded only by whitespace on its line, it is a
# block substitution: a list value has every line of every element
# indented to the identifier's column (a string value containing newlines
# is indented the same way).
# If the identifier is preceded by non-whitespace on its line and its value
# is a list, the elements are emitted comma-separated. ${,foo} inserts a
# comma before the list and ${foo,} inserts one after it, in both cases
# only if the list is not empty.
*/
101 struct CodeTemplate {
CodeTemplateCodeTemplate102 /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}
103
formatCodeTemplate104 std::string format(const TemplateEnv& env) const {
105 std::stringstream out;
106 size_t pos = 0;
107 size_t indent = 0;
108 bool all_whitespace = true;
109 while (pos < template_text.size()) {
110 char c = template_text[pos];
111 if (c == '$') {
112 std::stringstream kss;
113 bool comma_before = false;
114 bool comma_after = false;
115 size_t new_pos = parseKey(pos, kss, comma_before, comma_after);
116 std::string k = kss.str();
117 bool is_string = env.keyIsString(k);
118 if (all_whitespace) {
119 if (is_string)
120 emitStringWithIndents(out, indent, env.s(k));
121 else
122 emitLinesIndented(out, indent, env.v(k));
123 } else {
124 if (is_string)
125 out << env.s(k);
126 else
127 emitCommaSeparatedList(out, env.v(k), comma_before, comma_after);
128 }
129 all_whitespace = false;
130 pos = new_pos;
131 } else {
132 out << c;
133 if (!isspace(c))
134 all_whitespace = false;
135 indent++;
136 if (c == '\n') {
137 indent = 0;
138 all_whitespace = true;
139 }
140 pos++;
141 }
142 }
143 return out.str();
144 }
145
146 private:
147 using string_list = std::vector<std::string>;
charAtCodeTemplate148 char charAt(size_t p) const {
149 if (p >= template_text.size())
150 throw std::logic_error("EOS found in key");
151 return template_text[p];
152 }
parseKeyCodeTemplate153 size_t parseKey(
154 size_t pos,
155 std::ostream& k,
156 bool& comma_before,
157 bool& comma_after) const {
158 comma_before = false;
159 comma_after = false;
160 pos++;
161 if (charAt(pos) == '{') {
162 pos++;
163 if (charAt(pos) == ',') {
164 comma_before = true;
165 pos++;
166 }
167 pos = parseIdent(pos, k);
168 if (charAt(pos) == ',') {
169 comma_after = true;
170 pos++;
171 }
172 if (charAt(pos) != '}')
173 throw std::logic_error("missing terminating '}'");
174 pos++;
175 return pos;
176 } else {
177 return parseIdent(pos, k);
178 }
179 }
parseIdentCodeTemplate180 size_t parseIdent(size_t pos, std::ostream& k) const {
181 while (pos < template_text.size() &&
182 (isalnum(template_text[pos]) || template_text[pos] == '_')) {
183 k << template_text[pos];
184 pos++;
185 }
186 return pos;
187 }
emitCommaSeparatedListCodeTemplate188 void emitCommaSeparatedList(
189 std::ostream& out,
190 const string_list& strings,
191 bool comma_before,
192 bool comma_after) const {
193 if (comma_before && !strings.empty())
194 out << ", ";
195 for (const auto i : c10::irange(strings.size())) {
196 if (i > 0)
197 out << ", ";
198 out << strings[i];
199 }
200 if (comma_after && !strings.empty())
201 out << ", ";
202 }
203 // These indentation functions follow the convention that they never emit
204 // leading or trailing newlines when the input string does not have leading
205 // or trailing newlines. It's the responsibility of the calling function
206 // to indent correctly in the context.
emitIndentCodeTemplate207 void emitIndent(std::ostream& out, size_t indent) const {
208 for (C10_UNUSED const auto i : c10::irange(indent)) {
209 out << " ";
210 }
211 }
emitStringWithIndentsCodeTemplate212 void emitStringWithIndents(
213 std::ostream& out,
214 size_t indent,
215 const std::string& str) const {
216 for (auto c : str) {
217 out << c;
218 if (c == '\n') {
219 emitIndent(out, indent);
220 }
221 }
222 }
emitLinesIndentedCodeTemplate223 void emitLinesIndented(
224 std::stringstream& out,
225 size_t indent,
226 const string_list& strings) const {
227 for (const auto i : c10::irange(strings.size())) {
228 if (i > 0)
229 emitIndent(out, indent);
230 emitStringWithIndents(out, indent, strings[i]);
231 if (i + 1 != strings.size())
232 out << "\n";
233 }
234 }
235 std::string template_text;
236 };
237
format(const std::string & fmt,TemplateEnv & env)238 static inline std::string format(const std::string& fmt, TemplateEnv& env) {
239 return CodeTemplate(fmt).format(env);
240 }
241
242 } // namespace at::jit
243