xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/function_schema_parser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/function_schema_parser.h>
2 
3 #include <ATen/core/Reduction.h>
4 #include <ATen/core/jit_type.h>
5 #include <ATen/core/type_factory.h>
6 #include <torch/csrc/jit/frontend/lexer.h>
7 #include <torch/csrc/jit/frontend/parse_string_literal.h>
8 #include <torch/csrc/jit/frontend/schema_type_parser.h>
9 #include <optional>
10 
11 #include <memory>
12 #include <vector>
13 
14 using at::TypeKind;
15 using c10::Argument;
16 using c10::FunctionSchema;
17 using c10::IValue;
18 using c10::ListType;
19 using c10::OperatorName;
20 
21 namespace torch::jit {
22 
23 namespace {
24 struct SchemaParser {
SchemaParsertorch::jit::__anona00848e00111::SchemaParser25   explicit SchemaParser(const std::string& str, bool allow_typevars)
26       : L(std::make_shared<Source>(
27             c10::string_view(str),
28             std::nullopt,
29             0,
30             nullptr,
31             Source::DONT_COPY)),
32         type_parser(L, /*parse_complete_tensor_types*/ false, allow_typevars) {}
33 
parseDeclarationtorch::jit::__anona00848e00111::SchemaParser34   std::variant<OperatorName, FunctionSchema> parseDeclaration() {
35     OperatorName name = parseName();
36 
37     // If there is no parentheses coming, then this is just the operator name
38     // without an argument list
39     if (L.cur().kind != '(') {
40       return OperatorName(std::move(name));
41     }
42 
43     std::vector<Argument> arguments;
44     std::vector<Argument> returns;
45     bool kwarg_only = false;
46     bool is_vararg = false;
47     bool is_varret = false;
48     size_t idx = 0;
49     parseList('(', ',', ')', [&] {
50       if (is_vararg)
51         throw(
52             ErrorReport(L.cur())
53             << "... must be the last element of the argument list");
54       if (L.nextIf('*')) {
55         kwarg_only = true;
56       } else if (L.nextIf(TK_DOTS)) {
57         is_vararg = true;
58       } else {
59         arguments.push_back(parseArgument(
60             idx++, /*is_return=*/false, /*kwarg_only=*/kwarg_only));
61       }
62     });
63 
64     // check if all arguments are not-default for vararg schemas
65     if (is_vararg) {
66       for (const auto& arg : arguments) {
67         if (arg.default_value().has_value()) {
68           throw(
69               ErrorReport(L.cur())
70               << "schemas with vararg (...) can't have default value args");
71         }
72       }
73     }
74 
75     idx = 0;
76     L.expect(TK_ARROW);
77     if (L.nextIf(TK_DOTS)) {
78       is_varret = true;
79     } else if (L.cur().kind == '(') {
80       parseList('(', ',', ')', [&] {
81         if (is_varret) {
82           throw(
83               ErrorReport(L.cur())
84               << "... must be the last element of the return list");
85         }
86         if (L.nextIf(TK_DOTS)) {
87           is_varret = true;
88         } else {
89           returns.push_back(
90               parseArgument(idx++, /*is_return=*/true, /*kwarg_only=*/false));
91         }
92       });
93     } else {
94       returns.push_back(
95           parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false));
96     }
97 
98     return FunctionSchema(
99         std::move(name.name),
100         std::move(name.overload_name),
101         std::move(arguments),
102         std::move(returns),
103         is_vararg,
104         is_varret);
105   }
106 
parseNametorch::jit::__anona00848e00111::SchemaParser107   c10::OperatorName parseName() {
108     std::string name = L.expect(TK_IDENT).text();
109     if (L.nextIf(':')) {
110       L.expect(':');
111       name = name + "::" + L.expect(TK_IDENT).text();
112     }
113     std::string overload_name = "";
114     if (L.nextIf('.')) {
115       overload_name = L.expect(TK_IDENT).text();
116     }
117     // default is used as an attribute on the `OpOverloadPacket`
118     // (obtained using `torch.ops.aten.foo`) to get the operator
119     // overload with overload name as an empty string
120     // and so shouldn't be used as an overload name
121     // also disallow dunder attribute names to be overload names
122     bool is_a_valid_overload_name =
123         !((overload_name == "default") || (overload_name.rfind("__", 0) == 0));
124     TORCH_CHECK(
125         is_a_valid_overload_name,
126         overload_name,
127         " is not a legal overload name for aten operators");
128     return {name, overload_name};
129   }
130 
parseDeclarationstorch::jit::__anona00848e00111::SchemaParser131   std::vector<std::variant<OperatorName, FunctionSchema>> parseDeclarations() {
132     std::vector<std::variant<OperatorName, FunctionSchema>> results;
133     do {
134       results.emplace_back(parseDeclaration());
135     } while (L.nextIf(TK_NEWLINE));
136     L.expect(TK_EOF);
137     return results;
138   }
139 
parseExactlyOneDeclarationtorch::jit::__anona00848e00111::SchemaParser140   std::variant<OperatorName, FunctionSchema> parseExactlyOneDeclaration() {
141     auto result = parseDeclaration();
142     L.nextIf(TK_NEWLINE);
143     L.expect(TK_EOF);
144     return result;
145   }
146 
parseArgumenttorch::jit::__anona00848e00111::SchemaParser147   Argument parseArgument(size_t /*idx*/, bool is_return, bool kwarg_only) {
148     // fake and real type coincide except for Layout/MemoryFormat/ScalarType
149     // the fake type for these is Int instead
150     auto p = type_parser.parseFakeAndRealType();
151     auto fake_type = std::move(std::get<0>(p));
152     auto real_type = std::move(std::get<1>(p));
153     auto alias_info = std::move(std::get<2>(p));
154     std::optional<int32_t> N;
155     std::optional<IValue> default_value;
156     std::optional<std::string> alias_set;
157     std::string name;
158     if (L.nextIf('[')) {
159       // note: an array with a size hint can only occur at the Argument level
160       fake_type = ListType::create(std::move(fake_type));
161       real_type = ListType::create(std::move(real_type));
162       N = std::stoll(L.expect(TK_NUMBER).text());
163       L.expect(']');
164       auto container = type_parser.parseAliasAnnotation();
165       if (alias_info) {
166         if (!container) {
167           container = std::optional<at::AliasInfo>(at::AliasInfo());
168           container->setIsWrite(alias_info->isWrite());
169         }
170         container->addContainedType(std::move(*alias_info));
171       }
172       alias_info = std::move(container);
173       if (L.nextIf('?')) {
174         fake_type =
175             c10::TypeFactory::create<c10::OptionalType>(std::move(fake_type));
176         real_type =
177             c10::TypeFactory::create<c10::OptionalType>(std::move(real_type));
178       }
179     }
180     if (is_return) {
181       // optionally field names in return values
182       if (L.cur().kind == TK_IDENT) {
183         name = L.next().text();
184       } else {
185         name = "";
186       }
187     } else {
188       name = L.expect(TK_IDENT).text();
189       if (L.nextIf('=')) {
190         // NB: this means we have to unswizzle default too
191         default_value =
192             parseDefaultValue(*fake_type, fake_type->kind(), *real_type, N);
193       }
194     }
195     return Argument(
196         std::move(name),
197         std::move(fake_type),
198         std::move(real_type),
199         N,
200         std::move(default_value),
201         !is_return && kwarg_only,
202         std::move(alias_info));
203   }
204 
isPossiblyOptionalScalarTypetorch::jit::__anona00848e00111::SchemaParser205   bool isPossiblyOptionalScalarType(const c10::Type& type) {
206     if (type.kind() == at::ScalarTypeType::Kind) {
207       return true;
208     }
209     if (type.kind() == at::OptionalType::Kind) {
210       for (const auto& inner : type.containedTypes()) {
211         if (isPossiblyOptionalScalarType(*inner))
212           return true;
213       }
214     }
215     return false;
216   }
217 
parseSingleConstanttorch::jit::__anona00848e00111::SchemaParser218   IValue parseSingleConstant(
219       const c10::Type& type,
220       TypeKind kind,
221       const c10::Type& real_type) {
222     if (kind == c10::TypeKind::DynamicType) {
223       return parseSingleConstant(
224           type, type.expectRef<c10::DynamicType>().dynamicKind(), real_type);
225     }
226     const auto& str2dtype = c10::getStringToDtypeMap();
227     switch (L.cur().kind) {
228       case TK_TRUE:
229         L.next();
230         return true;
231       case TK_FALSE:
232         L.next();
233         return false;
234       case TK_NONE:
235         L.next();
236         return IValue();
237       case TK_STRINGLITERAL: {
238         auto token = L.next();
239         return parseStringLiteral(token.range, token.text());
240       }
241       case TK_IDENT: {
242         auto tok = L.next();
243         auto text = tok.text();
244         // NB: float/complex/long are here for BC purposes. Other dtypes
245         // are handled via str2dtype.
246         // Please don't add more cases to this if-else block.
247         if ("float" == text) {
248           return static_cast<int64_t>(at::kFloat);
249         } else if ("complex" == text) {
250           return static_cast<int64_t>(at::kComplexFloat);
251         } else if ("long" == text) {
252           return static_cast<int64_t>(at::kLong);
253         } else if ("strided" == text) {
254           return static_cast<int64_t>(at::kStrided);
255         } else if ("Mean" == text) {
256           return static_cast<int64_t>(at::Reduction::Mean);
257         } else if ("contiguous_format" == text) {
258           return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
259         } else if (
260             isPossiblyOptionalScalarType(real_type) &&
261             str2dtype.count(text) > 0) {
262           return static_cast<int64_t>(str2dtype.at(text));
263         } else {
264           throw(ErrorReport(L.cur().range) << "invalid numeric default value");
265         }
266       }
267       default:
268         std::string n;
269         if (L.nextIf('-'))
270           n = "-" + L.expect(TK_NUMBER).text();
271         else
272           n = L.expect(TK_NUMBER).text();
273 
274         if (kind == TypeKind::ComplexType || n.find('j') != std::string::npos) {
275           auto imag = std::stod(n.substr(0, n.size() - 1));
276           return c10::complex<double>(0, imag);
277         } else if (
278             kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
279             n.find('e') != std::string::npos) {
280           return std::stod(n);
281         } else {
282           int64_t v = std::stoll(n);
283           return v;
284         }
285     }
286   }
convertToListtorch::jit::__anona00848e00111::SchemaParser287   IValue convertToList(
288       const c10::Type& type,
289       TypeKind kind,
290       const SourceRange& range,
291       const std::vector<IValue>& vs) {
292     switch (kind) {
293       case TypeKind::ComplexType:
294         return fmap(vs, [](const IValue& v) { return v.toComplexDouble(); });
295       case TypeKind::FloatType:
296         return fmap(vs, [](const IValue& v) { return v.toDouble(); });
297       case TypeKind::IntType:
298         return fmap(vs, [](const IValue& v) { return v.toInt(); });
299       case TypeKind::BoolType:
300         return fmap(vs, [](const IValue& v) { return v.toBool(); });
301       case TypeKind::DynamicType:
302         return convertToList(
303             type, type.expectRef<c10::DynamicType>().dynamicKind(), range, vs);
304       default:
305         throw(
306             ErrorReport(range)
307             << "lists are only supported for float, int and complex types");
308     }
309   }
parseConstantListtorch::jit::__anona00848e00111::SchemaParser310   IValue parseConstantList(
311       const c10::Type& type,
312       TypeKind kind,
313       const c10::Type& real_type) {
314     auto tok = L.expect('[');
315     std::vector<IValue> vs;
316     if (L.cur().kind != ']') {
317       do {
318         vs.push_back(parseSingleConstant(type, kind, real_type));
319       } while (L.nextIf(','));
320     }
321     L.expect(']');
322     return convertToList(type, kind, tok.range, vs);
323   }
324 
parseTensorDefaulttorch::jit::__anona00848e00111::SchemaParser325   IValue parseTensorDefault(const SourceRange& /*range*/) {
326     L.expect(TK_NONE);
327     return IValue();
328   }
parseDefaultValuetorch::jit::__anona00848e00111::SchemaParser329   IValue parseDefaultValue(
330       const c10::Type& arg_type,
331       TypeKind kind,
332       const c10::Type& real_type,
333       std::optional<int32_t> arg_N) {
334     auto range = L.cur().range;
335     switch (kind) {
336       case TypeKind::TensorType:
337       case TypeKind::GeneratorType:
338       case TypeKind::QuantizerType: {
339         return parseTensorDefault(range);
340       } break;
341       case TypeKind::StringType:
342       case TypeKind::OptionalType:
343       case TypeKind::NumberType:
344       case TypeKind::IntType:
345       case TypeKind::BoolType:
346       case TypeKind::FloatType:
347       case TypeKind::ComplexType:
348         return parseSingleConstant(arg_type, kind, real_type);
349         break;
350       case TypeKind::DeviceObjType: {
351         auto device_text =
352             parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
353         return c10::Device(device_text);
354         break;
355       }
356       case TypeKind::ListType: {
357         auto elem_type = arg_type.containedType(0);
358         auto real_elem_type = real_type.containedType(0);
359         if (L.cur().kind == TK_IDENT) {
360           return parseTensorDefault(range);
361         } else if (arg_N && L.cur().kind != '[') {
362           IValue v = parseSingleConstant(
363               *elem_type, elem_type->kind(), *real_elem_type);
364           std::vector<IValue> repeated(*arg_N, v);
365           return convertToList(*elem_type, elem_type->kind(), range, repeated);
366         } else {
367           return parseConstantList(
368               *elem_type, elem_type->kind(), *real_elem_type);
369         }
370       } break;
371       case TypeKind::DynamicType:
372         return parseDefaultValue(
373             arg_type,
374             arg_type.expectRef<c10::DynamicType>().dynamicKind(),
375             real_type,
376             arg_N);
377       default:
378         throw(ErrorReport(range) << "unexpected type, file a bug report");
379     }
380     return IValue(); // silence warnings
381   }
382 
parseListtorch::jit::__anona00848e00111::SchemaParser383   void parseList(
384       int begin,
385       int sep,
386       int end,
387       c10::function_ref<void()> callback) {
388     auto r = L.cur().range;
389     if (begin != TK_NOTHING)
390       L.expect(begin);
391     if (L.cur().kind != end) {
392       do {
393         callback();
394       } while (L.nextIf(sep));
395     }
396     if (end != TK_NOTHING)
397       L.expect(end);
398   }
399   Lexer L;
400   SchemaTypeParser type_parser;
401 };
402 } // namespace
403 
parseSchemaOrName(const std::string & schemaOrName,bool allow_typevars)404 std::variant<OperatorName, FunctionSchema> parseSchemaOrName(
405     const std::string& schemaOrName,
406     bool allow_typevars) {
407   // We're ignoring aten and prim for BC reasons
408   if (schemaOrName.rfind("aten::", 0) == 0 ||
409       schemaOrName.rfind("prim::", 0) == 0) {
410     allow_typevars = true;
411   }
412   return SchemaParser(schemaOrName, allow_typevars)
413       .parseExactlyOneDeclaration();
414 }
415 
parseSchema(const std::string & schema,bool allow_typevars)416 FunctionSchema parseSchema(const std::string& schema, bool allow_typevars) {
417   auto parsed = parseSchemaOrName(schema, allow_typevars);
418   TORCH_CHECK(
419       std::holds_alternative<FunctionSchema>(parsed),
420       "Tried to parse a function schema but only the operator name was given");
421   return std::get<FunctionSchema>(std::move(parsed));
422 }
423 
parseName(const std::string & name)424 OperatorName parseName(const std::string& name) {
425   auto parsed = parseSchemaOrName(name);
426   TORCH_CHECK(
427       std::holds_alternative<OperatorName>(parsed),
428       "Tried to parse an operator name but function schema was given");
429   return std::get<OperatorName>(std::move(parsed));
430 }
431 
432 } // namespace torch::jit
433