#include <torch/csrc/jit/frontend/function_schema_parser.h>

#include <ATen/core/Reduction.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/type_factory.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
#include <torch/csrc/jit/frontend/schema_type_parser.h>

#include <exception>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <vector>
13
14 using at::TypeKind;
15 using c10::Argument;
16 using c10::FunctionSchema;
17 using c10::IValue;
18 using c10::ListType;
19 using c10::OperatorName;
20
21 namespace torch::jit {
22
23 namespace {
24 struct SchemaParser {
SchemaParsertorch::jit::__anona00848e00111::SchemaParser25 explicit SchemaParser(const std::string& str, bool allow_typevars)
26 : L(std::make_shared<Source>(
27 c10::string_view(str),
28 std::nullopt,
29 0,
30 nullptr,
31 Source::DONT_COPY)),
32 type_parser(L, /*parse_complete_tensor_types*/ false, allow_typevars) {}
33
parseDeclarationtorch::jit::__anona00848e00111::SchemaParser34 std::variant<OperatorName, FunctionSchema> parseDeclaration() {
35 OperatorName name = parseName();
36
37 // If there is no parentheses coming, then this is just the operator name
38 // without an argument list
39 if (L.cur().kind != '(') {
40 return OperatorName(std::move(name));
41 }
42
43 std::vector<Argument> arguments;
44 std::vector<Argument> returns;
45 bool kwarg_only = false;
46 bool is_vararg = false;
47 bool is_varret = false;
48 size_t idx = 0;
49 parseList('(', ',', ')', [&] {
50 if (is_vararg)
51 throw(
52 ErrorReport(L.cur())
53 << "... must be the last element of the argument list");
54 if (L.nextIf('*')) {
55 kwarg_only = true;
56 } else if (L.nextIf(TK_DOTS)) {
57 is_vararg = true;
58 } else {
59 arguments.push_back(parseArgument(
60 idx++, /*is_return=*/false, /*kwarg_only=*/kwarg_only));
61 }
62 });
63
64 // check if all arguments are not-default for vararg schemas
65 if (is_vararg) {
66 for (const auto& arg : arguments) {
67 if (arg.default_value().has_value()) {
68 throw(
69 ErrorReport(L.cur())
70 << "schemas with vararg (...) can't have default value args");
71 }
72 }
73 }
74
75 idx = 0;
76 L.expect(TK_ARROW);
77 if (L.nextIf(TK_DOTS)) {
78 is_varret = true;
79 } else if (L.cur().kind == '(') {
80 parseList('(', ',', ')', [&] {
81 if (is_varret) {
82 throw(
83 ErrorReport(L.cur())
84 << "... must be the last element of the return list");
85 }
86 if (L.nextIf(TK_DOTS)) {
87 is_varret = true;
88 } else {
89 returns.push_back(
90 parseArgument(idx++, /*is_return=*/true, /*kwarg_only=*/false));
91 }
92 });
93 } else {
94 returns.push_back(
95 parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false));
96 }
97
98 return FunctionSchema(
99 std::move(name.name),
100 std::move(name.overload_name),
101 std::move(arguments),
102 std::move(returns),
103 is_vararg,
104 is_varret);
105 }
106
parseNametorch::jit::__anona00848e00111::SchemaParser107 c10::OperatorName parseName() {
108 std::string name = L.expect(TK_IDENT).text();
109 if (L.nextIf(':')) {
110 L.expect(':');
111 name = name + "::" + L.expect(TK_IDENT).text();
112 }
113 std::string overload_name = "";
114 if (L.nextIf('.')) {
115 overload_name = L.expect(TK_IDENT).text();
116 }
117 // default is used as an attribute on the `OpOverloadPacket`
118 // (obtained using `torch.ops.aten.foo`) to get the operator
119 // overload with overload name as an empty string
120 // and so shouldn't be used as an overload name
121 // also disallow dunder attribute names to be overload names
122 bool is_a_valid_overload_name =
123 !((overload_name == "default") || (overload_name.rfind("__", 0) == 0));
124 TORCH_CHECK(
125 is_a_valid_overload_name,
126 overload_name,
127 " is not a legal overload name for aten operators");
128 return {name, overload_name};
129 }
130
parseDeclarationstorch::jit::__anona00848e00111::SchemaParser131 std::vector<std::variant<OperatorName, FunctionSchema>> parseDeclarations() {
132 std::vector<std::variant<OperatorName, FunctionSchema>> results;
133 do {
134 results.emplace_back(parseDeclaration());
135 } while (L.nextIf(TK_NEWLINE));
136 L.expect(TK_EOF);
137 return results;
138 }
139
parseExactlyOneDeclarationtorch::jit::__anona00848e00111::SchemaParser140 std::variant<OperatorName, FunctionSchema> parseExactlyOneDeclaration() {
141 auto result = parseDeclaration();
142 L.nextIf(TK_NEWLINE);
143 L.expect(TK_EOF);
144 return result;
145 }
146
parseArgumenttorch::jit::__anona00848e00111::SchemaParser147 Argument parseArgument(size_t /*idx*/, bool is_return, bool kwarg_only) {
148 // fake and real type coincide except for Layout/MemoryFormat/ScalarType
149 // the fake type for these is Int instead
150 auto p = type_parser.parseFakeAndRealType();
151 auto fake_type = std::move(std::get<0>(p));
152 auto real_type = std::move(std::get<1>(p));
153 auto alias_info = std::move(std::get<2>(p));
154 std::optional<int32_t> N;
155 std::optional<IValue> default_value;
156 std::optional<std::string> alias_set;
157 std::string name;
158 if (L.nextIf('[')) {
159 // note: an array with a size hint can only occur at the Argument level
160 fake_type = ListType::create(std::move(fake_type));
161 real_type = ListType::create(std::move(real_type));
162 N = std::stoll(L.expect(TK_NUMBER).text());
163 L.expect(']');
164 auto container = type_parser.parseAliasAnnotation();
165 if (alias_info) {
166 if (!container) {
167 container = std::optional<at::AliasInfo>(at::AliasInfo());
168 container->setIsWrite(alias_info->isWrite());
169 }
170 container->addContainedType(std::move(*alias_info));
171 }
172 alias_info = std::move(container);
173 if (L.nextIf('?')) {
174 fake_type =
175 c10::TypeFactory::create<c10::OptionalType>(std::move(fake_type));
176 real_type =
177 c10::TypeFactory::create<c10::OptionalType>(std::move(real_type));
178 }
179 }
180 if (is_return) {
181 // optionally field names in return values
182 if (L.cur().kind == TK_IDENT) {
183 name = L.next().text();
184 } else {
185 name = "";
186 }
187 } else {
188 name = L.expect(TK_IDENT).text();
189 if (L.nextIf('=')) {
190 // NB: this means we have to unswizzle default too
191 default_value =
192 parseDefaultValue(*fake_type, fake_type->kind(), *real_type, N);
193 }
194 }
195 return Argument(
196 std::move(name),
197 std::move(fake_type),
198 std::move(real_type),
199 N,
200 std::move(default_value),
201 !is_return && kwarg_only,
202 std::move(alias_info));
203 }
204
isPossiblyOptionalScalarTypetorch::jit::__anona00848e00111::SchemaParser205 bool isPossiblyOptionalScalarType(const c10::Type& type) {
206 if (type.kind() == at::ScalarTypeType::Kind) {
207 return true;
208 }
209 if (type.kind() == at::OptionalType::Kind) {
210 for (const auto& inner : type.containedTypes()) {
211 if (isPossiblyOptionalScalarType(*inner))
212 return true;
213 }
214 }
215 return false;
216 }
217
parseSingleConstanttorch::jit::__anona00848e00111::SchemaParser218 IValue parseSingleConstant(
219 const c10::Type& type,
220 TypeKind kind,
221 const c10::Type& real_type) {
222 if (kind == c10::TypeKind::DynamicType) {
223 return parseSingleConstant(
224 type, type.expectRef<c10::DynamicType>().dynamicKind(), real_type);
225 }
226 const auto& str2dtype = c10::getStringToDtypeMap();
227 switch (L.cur().kind) {
228 case TK_TRUE:
229 L.next();
230 return true;
231 case TK_FALSE:
232 L.next();
233 return false;
234 case TK_NONE:
235 L.next();
236 return IValue();
237 case TK_STRINGLITERAL: {
238 auto token = L.next();
239 return parseStringLiteral(token.range, token.text());
240 }
241 case TK_IDENT: {
242 auto tok = L.next();
243 auto text = tok.text();
244 // NB: float/complex/long are here for BC purposes. Other dtypes
245 // are handled via str2dtype.
246 // Please don't add more cases to this if-else block.
247 if ("float" == text) {
248 return static_cast<int64_t>(at::kFloat);
249 } else if ("complex" == text) {
250 return static_cast<int64_t>(at::kComplexFloat);
251 } else if ("long" == text) {
252 return static_cast<int64_t>(at::kLong);
253 } else if ("strided" == text) {
254 return static_cast<int64_t>(at::kStrided);
255 } else if ("Mean" == text) {
256 return static_cast<int64_t>(at::Reduction::Mean);
257 } else if ("contiguous_format" == text) {
258 return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
259 } else if (
260 isPossiblyOptionalScalarType(real_type) &&
261 str2dtype.count(text) > 0) {
262 return static_cast<int64_t>(str2dtype.at(text));
263 } else {
264 throw(ErrorReport(L.cur().range) << "invalid numeric default value");
265 }
266 }
267 default:
268 std::string n;
269 if (L.nextIf('-'))
270 n = "-" + L.expect(TK_NUMBER).text();
271 else
272 n = L.expect(TK_NUMBER).text();
273
274 if (kind == TypeKind::ComplexType || n.find('j') != std::string::npos) {
275 auto imag = std::stod(n.substr(0, n.size() - 1));
276 return c10::complex<double>(0, imag);
277 } else if (
278 kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
279 n.find('e') != std::string::npos) {
280 return std::stod(n);
281 } else {
282 int64_t v = std::stoll(n);
283 return v;
284 }
285 }
286 }
convertToListtorch::jit::__anona00848e00111::SchemaParser287 IValue convertToList(
288 const c10::Type& type,
289 TypeKind kind,
290 const SourceRange& range,
291 const std::vector<IValue>& vs) {
292 switch (kind) {
293 case TypeKind::ComplexType:
294 return fmap(vs, [](const IValue& v) { return v.toComplexDouble(); });
295 case TypeKind::FloatType:
296 return fmap(vs, [](const IValue& v) { return v.toDouble(); });
297 case TypeKind::IntType:
298 return fmap(vs, [](const IValue& v) { return v.toInt(); });
299 case TypeKind::BoolType:
300 return fmap(vs, [](const IValue& v) { return v.toBool(); });
301 case TypeKind::DynamicType:
302 return convertToList(
303 type, type.expectRef<c10::DynamicType>().dynamicKind(), range, vs);
304 default:
305 throw(
306 ErrorReport(range)
307 << "lists are only supported for float, int and complex types");
308 }
309 }
parseConstantListtorch::jit::__anona00848e00111::SchemaParser310 IValue parseConstantList(
311 const c10::Type& type,
312 TypeKind kind,
313 const c10::Type& real_type) {
314 auto tok = L.expect('[');
315 std::vector<IValue> vs;
316 if (L.cur().kind != ']') {
317 do {
318 vs.push_back(parseSingleConstant(type, kind, real_type));
319 } while (L.nextIf(','));
320 }
321 L.expect(']');
322 return convertToList(type, kind, tok.range, vs);
323 }
324
parseTensorDefaulttorch::jit::__anona00848e00111::SchemaParser325 IValue parseTensorDefault(const SourceRange& /*range*/) {
326 L.expect(TK_NONE);
327 return IValue();
328 }
parseDefaultValuetorch::jit::__anona00848e00111::SchemaParser329 IValue parseDefaultValue(
330 const c10::Type& arg_type,
331 TypeKind kind,
332 const c10::Type& real_type,
333 std::optional<int32_t> arg_N) {
334 auto range = L.cur().range;
335 switch (kind) {
336 case TypeKind::TensorType:
337 case TypeKind::GeneratorType:
338 case TypeKind::QuantizerType: {
339 return parseTensorDefault(range);
340 } break;
341 case TypeKind::StringType:
342 case TypeKind::OptionalType:
343 case TypeKind::NumberType:
344 case TypeKind::IntType:
345 case TypeKind::BoolType:
346 case TypeKind::FloatType:
347 case TypeKind::ComplexType:
348 return parseSingleConstant(arg_type, kind, real_type);
349 break;
350 case TypeKind::DeviceObjType: {
351 auto device_text =
352 parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
353 return c10::Device(device_text);
354 break;
355 }
356 case TypeKind::ListType: {
357 auto elem_type = arg_type.containedType(0);
358 auto real_elem_type = real_type.containedType(0);
359 if (L.cur().kind == TK_IDENT) {
360 return parseTensorDefault(range);
361 } else if (arg_N && L.cur().kind != '[') {
362 IValue v = parseSingleConstant(
363 *elem_type, elem_type->kind(), *real_elem_type);
364 std::vector<IValue> repeated(*arg_N, v);
365 return convertToList(*elem_type, elem_type->kind(), range, repeated);
366 } else {
367 return parseConstantList(
368 *elem_type, elem_type->kind(), *real_elem_type);
369 }
370 } break;
371 case TypeKind::DynamicType:
372 return parseDefaultValue(
373 arg_type,
374 arg_type.expectRef<c10::DynamicType>().dynamicKind(),
375 real_type,
376 arg_N);
377 default:
378 throw(ErrorReport(range) << "unexpected type, file a bug report");
379 }
380 return IValue(); // silence warnings
381 }
382
parseListtorch::jit::__anona00848e00111::SchemaParser383 void parseList(
384 int begin,
385 int sep,
386 int end,
387 c10::function_ref<void()> callback) {
388 auto r = L.cur().range;
389 if (begin != TK_NOTHING)
390 L.expect(begin);
391 if (L.cur().kind != end) {
392 do {
393 callback();
394 } while (L.nextIf(sep));
395 }
396 if (end != TK_NOTHING)
397 L.expect(end);
398 }
399 Lexer L;
400 SchemaTypeParser type_parser;
401 };
402 } // namespace
403
parseSchemaOrName(const std::string & schemaOrName,bool allow_typevars)404 std::variant<OperatorName, FunctionSchema> parseSchemaOrName(
405 const std::string& schemaOrName,
406 bool allow_typevars) {
407 // We're ignoring aten and prim for BC reasons
408 if (schemaOrName.rfind("aten::", 0) == 0 ||
409 schemaOrName.rfind("prim::", 0) == 0) {
410 allow_typevars = true;
411 }
412 return SchemaParser(schemaOrName, allow_typevars)
413 .parseExactlyOneDeclaration();
414 }
415
parseSchema(const std::string & schema,bool allow_typevars)416 FunctionSchema parseSchema(const std::string& schema, bool allow_typevars) {
417 auto parsed = parseSchemaOrName(schema, allow_typevars);
418 TORCH_CHECK(
419 std::holds_alternative<FunctionSchema>(parsed),
420 "Tried to parse a function schema but only the operator name was given");
421 return std::get<FunctionSchema>(std::move(parsed));
422 }
423
parseName(const std::string & name)424 OperatorName parseName(const std::string& name) {
425 auto parsed = parseSchemaOrName(name);
426 TORCH_CHECK(
427 std::holds_alternative<OperatorName>(parsed),
428 "Tried to parse an operator name but function schema was given");
429 return std::get<OperatorName>(std::move(parsed));
430 }
431
432 } // namespace torch::jit
433