xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/schema_matching.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <torch/csrc/Export.h>
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/ir/named_value.h>
5 
6 #include <ATen/core/function_schema.h>
7 
8 namespace torch::jit {
9 
10 // Try to match a list of inputs and keyword 'attributes' to this
11 // schema. Return the flat list of positional inputs to the call or
12 // `std::nullopt` on failure (`failure_messages` contains a good error
13 // report in this case)
14 
15 struct MatchedSchema {
16   std::vector<Value*> inputs;
17   std::vector<TypePtr> return_types;
18   c10::OptNameList return_field_names;
19   std::string schema_name;
20 };
21 
22 TORCH_API bool isBlockListedSchema(const FunctionSchema& schema);
23 
24 TORCH_API MatchedSchema matchSchema(
25     const ::c10::FunctionSchema& schema,
26     const SourceRange& loc,
27     Graph& graph,
28     at::ArrayRef<NamedValue> args,
29     at::ArrayRef<NamedValue> kwargs,
30     const std::optional<NamedValue>& self = std::nullopt);
31 
32 TORCH_API std::pair<size_t, MatchedSchema> matchSchemas(
33     const std::vector<const ::c10::FunctionSchema*>& schemas,
34     const SourceRange& loc,
35     Graph& graph,
36     at::ArrayRef<NamedValue> args,
37     at::ArrayRef<NamedValue> kwargs,
38     const std::optional<NamedValue>& self = std::nullopt,
39     bool render_errors = false);
40 
41 TORCH_API bool convertibleToList(
42     const TypePtr& type,
43     const TypePtr& list_type_);
44 
45 TORCH_API std::string getFullSchemaName(const ::c10::FunctionSchema& schema);
46 
47 TORCH_API Value* emitBuiltinCall(
48     const SourceRange& loc,
49     Graph& graph,
50     Symbol name,
51     at::ArrayRef<NamedValue> args,
52     at::ArrayRef<NamedValue> kwargs,
53     const std::optional<NamedValue>& self = std::nullopt);
54 
55 TORCH_API std::optional<size_t> findInputWithName(
56     const std::string& name,
57     at::ArrayRef<NamedValue> kwargs,
58     bool is_aten = false);
59 
60 // applies implicit conversion from value trying to turn it into type
61 // concrete_type it succeeds if the return_value->isSubtypeOf(concrete_type)
62 TORCH_API Value* tryConvertToType(
63     const SourceRange& loc,
64     Graph& graph,
65     const TypePtr& concrete_type,
66     Value* value,
67     bool allow_conversions);
68 } // namespace torch::jit
69