// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/code_template.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/util/irange.h>

#include <cctype>
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
9 
10 namespace at::jit {
11 
// A template environment is a mapping from template variable names, e.g.,
// identifier (corresponding to $identifier) to their expansions.
//
// This template environment supports storing strings, numbers and lists
// of strings, and can be chained together (so that lookup proceeds in
// the top level environment, and then recurses into a parent
// environment if the key is not found).
//
// Within a single environment a key holds either a string or a list, never
// both: storing one kind erases the other kind stored under the same key.
struct TemplateEnv {
  TemplateEnv() = default;

  // Creates an empty environment chained onto 'parent'. Only the address of
  // 'parent' is stored, so 'parent' must stay alive for as long as lookups
  // are performed on this environment.
  //
  // Note: because the parameter is a non-const lvalue reference, this is the
  // constructor selected when constructing from a non-const TemplateEnv
  // lvalue, i.e. `TemplateEnv child(env);` chains onto 'env' instead of
  // copying its contents.
  TemplateEnv(TemplateEnv& parent) : parent(&parent) {}
  TemplateEnv& operator=(const TemplateEnv& parent) = delete;

  using string_list = std::vector<std::string>;

  // Add a string 'v' to the map at key 'k', replacing any string or list
  // already stored at 'k' in this environment.
  void s(const std::string& k, const std::string& v) {
    strings_[k] = v;
    lists_.erase(k);
  }

  // Add a number 'v' to the map at key 'k'. The number is stored as the
  // string produced by std::to_string(v), replacing any string or list
  // already stored at 'k' in this environment.
  template <typename T>
  void d(const std::string& k, const T& v) {
    strings_[k] = std::to_string(v);
    lists_.erase(k);
  }

  // Retrieve the string representation of the value stored at 'k' from the map.
  // If this environment holds no string at 'k', the parent chain is searched.
  // Throws std::logic_error if no environment in the chain has a string at 'k'.
  const std::string& s(const std::string& k) const {
    const auto it = strings_.find(k);
    if (it != strings_.end()) {
      return it->second;
    }
    if (parent) {
      return parent->s(k);
    }
    notFound(k);
  }

  // Store a list of strings 'v' in the map at 'k', replacing any string or
  // list already stored at 'k' in this environment.
  void v(const std::string& k, const string_list& v) {
    lists_[k] = v;
    strings_.erase(k);
  }

  // Retrieve a list of strings stored at 'k' from the map.
  // If this environment holds no list at 'k', the parent chain is searched.
  // Throws std::logic_error if no environment in the chain has a list at 'k'.
  const string_list& v(const std::string& k) const {
    const auto it = lists_.find(k);
    if (it != lists_.end()) {
      return it->second;
    }
    if (parent) {
      return parent->v(k);
    }
    notFound(k);
  }

  // Test if the value stored at key 'k' is a string (as opposed to a list.)
  // The nearest environment in the chain that defines 'k' decides the answer.
  // Throws std::logic_error if 'k' is not defined anywhere in the chain.
  bool keyIsString(const std::string& k) const {
    if (strings_.find(k) != strings_.end()) {
      return true;
    }
    if (lists_.find(k) != lists_.end()) {
      return false;
    }
    if (parent) {
      return parent->keyIsString(k);
    }
    notFound(k);
  }

 private:
  // Reports a failed lookup of 'k'; never returns.
  [[noreturn]] void notFound(const std::string& k) const {
    throw std::logic_error("key not found: " + k);
  }

  std::unordered_map<std::string, std::string> strings_;
  std::unordered_map<std::string, string_list> lists_;
  // Enclosing environment consulted when a key is missing here; not owned.
  const TemplateEnv* parent{nullptr};
};
91 
92 /*
93 # Match $identifier or ${identifier} and replace with the value in env.
94 # If this identifier is at the beginning of whitespace on a line
95 # and its value is a list then it is treated as
96 # block substitution by indenting all lines of all elements.
97 # If the identifier is on a line starting with non-whitespace and a list
98 # then it is comma separated. ${,foo} will insert a comma before the list
99 # if this list is not empty and ${foo,} will insert one after.
100 */
101 struct CodeTemplate {
CodeTemplateCodeTemplate102   /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}
103 
formatCodeTemplate104   std::string format(const TemplateEnv& env) const {
105     std::stringstream out;
106     size_t pos = 0;
107     size_t indent = 0;
108     bool all_whitespace = true;
109     while (pos < template_text.size()) {
110       char c = template_text[pos];
111       if (c == '$') {
112         std::stringstream kss;
113         bool comma_before = false;
114         bool comma_after = false;
115         size_t new_pos = parseKey(pos, kss, comma_before, comma_after);
116         std::string k = kss.str();
117         bool is_string = env.keyIsString(k);
118         if (all_whitespace) {
119           if (is_string)
120             emitStringWithIndents(out, indent, env.s(k));
121           else
122             emitLinesIndented(out, indent, env.v(k));
123         } else {
124           if (is_string)
125             out << env.s(k);
126           else
127             emitCommaSeparatedList(out, env.v(k), comma_before, comma_after);
128         }
129         all_whitespace = false;
130         pos = new_pos;
131       } else {
132         out << c;
133         if (!isspace(c))
134           all_whitespace = false;
135         indent++;
136         if (c == '\n') {
137           indent = 0;
138           all_whitespace = true;
139         }
140         pos++;
141       }
142     }
143     return out.str();
144   }
145 
146  private:
147   using string_list = std::vector<std::string>;
charAtCodeTemplate148   char charAt(size_t p) const {
149     if (p >= template_text.size())
150       throw std::logic_error("EOS found in key");
151     return template_text[p];
152   }
parseKeyCodeTemplate153   size_t parseKey(
154       size_t pos,
155       std::ostream& k,
156       bool& comma_before,
157       bool& comma_after) const {
158     comma_before = false;
159     comma_after = false;
160     pos++;
161     if (charAt(pos) == '{') {
162       pos++;
163       if (charAt(pos) == ',') {
164         comma_before = true;
165         pos++;
166       }
167       pos = parseIdent(pos, k);
168       if (charAt(pos) == ',') {
169         comma_after = true;
170         pos++;
171       }
172       if (charAt(pos) != '}')
173         throw std::logic_error("missing terminating '}'");
174       pos++;
175       return pos;
176     } else {
177       return parseIdent(pos, k);
178     }
179   }
parseIdentCodeTemplate180   size_t parseIdent(size_t pos, std::ostream& k) const {
181     while (pos < template_text.size() &&
182            (isalnum(template_text[pos]) || template_text[pos] == '_')) {
183       k << template_text[pos];
184       pos++;
185     }
186     return pos;
187   }
emitCommaSeparatedListCodeTemplate188   void emitCommaSeparatedList(
189       std::ostream& out,
190       const string_list& strings,
191       bool comma_before,
192       bool comma_after) const {
193     if (comma_before && !strings.empty())
194       out << ", ";
195     for (const auto i : c10::irange(strings.size())) {
196       if (i > 0)
197         out << ", ";
198       out << strings[i];
199     }
200     if (comma_after && !strings.empty())
201       out << ", ";
202   }
203   // These indentation functions follow the convention that they never emit
204   // leading or trailing newlines when the input string does not have leading
205   // or trailing newlines. It's the responsibility of the calling function
206   // to indent correctly in the context.
emitIndentCodeTemplate207   void emitIndent(std::ostream& out, size_t indent) const {
208     for (C10_UNUSED const auto i : c10::irange(indent)) {
209       out << " ";
210     }
211   }
emitStringWithIndentsCodeTemplate212   void emitStringWithIndents(
213       std::ostream& out,
214       size_t indent,
215       const std::string& str) const {
216     for (auto c : str) {
217       out << c;
218       if (c == '\n') {
219         emitIndent(out, indent);
220       }
221     }
222   }
emitLinesIndentedCodeTemplate223   void emitLinesIndented(
224       std::stringstream& out,
225       size_t indent,
226       const string_list& strings) const {
227     for (const auto i : c10::irange(strings.size())) {
228       if (i > 0)
229         emitIndent(out, indent);
230       emitStringWithIndents(out, indent, strings[i]);
231       if (i + 1 != strings.size())
232         out << "\n";
233     }
234   }
235   std::string template_text;
236 };
237 
format(const std::string & fmt,TemplateEnv & env)238 static inline std::string format(const std::string& fmt, TemplateEnv& env) {
239   return CodeTemplate(fmt).format(env);
240 }
241 
242 } // namespace at::jit
243