// xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/schema_info.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/frontend/function_schema_parser.h>
4 #include <unordered_set>
5 
6 namespace torch::utils {
7 
8 using SchemaSpecialCasePair =
9     std::pair<c10::FunctionSchema, std::unordered_set<std::string>>;
10 /**
11  * class SchemaInfo
12  *
13  * FunctionSchema wrapper that publicizes argument value specific operator
14  * behavior (mutation, aliasing, special cases, etc...)
15  */
16 
17 struct TORCH_API SchemaInfo {
18  public:
SchemaInfoSchemaInfo19   explicit SchemaInfo(c10::FunctionSchema schema)
20       : schema_(std::move(schema)),
21         alias_maps_current_(false),
22         has_init_(false) {}
SchemaInfoSchemaInfo23   explicit SchemaInfo(const char* signature)
24       : schema_(torch::jit::parseSchema(signature)),
25         alias_maps_current_(false),
26         has_init_(false) {}
27 
28   bool is_mutable();
29 
30   bool is_mutable(const c10::SchemaArgument& argument);
31 
32   bool is_mutable(c10::string_view name);
33 
34   bool has_argument(c10::string_view name);
35 
36   bool is_nondeterministic() const;
37 
38   // Returns whether lhs and rhs may alias directly.
39   // This does not account for cases where lhs or rhs are a container that
40   // may contain elements that alias the other argument.
41   // Besides the checks already included in FunctionSchema::may_alias, this
42   // method also accounts special aliasing cases causes by aliasing argument
43   // values supplied from addArgumentValue.
44   bool may_alias(
45       const c10::SchemaArgument& lhs,
46       const c10::SchemaArgument& rhs);
47 
48   // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a
49   // container that may contain elements that alias the other argument. Besides
50   // the checks already included in FunctionSchema::may_contain_alias, this
51   // method also accounts for special aliasing cases causes by aliasing argument
52   // values supplied from addArgumentValue. bidirectional = false only returns
53   // whether lhs may contain an alias of rhs while bidirectional = true returns
54   // both directions.
55   bool may_contain_alias(
56       const c10::SchemaArgument& lhs,
57       const c10::SchemaArgument& rhs,
58       bool bidirectional = true);
59 
60   void addArgumentValue(const std::string& name, const at::IValue& value);
61 
62   void addArgumentValues(
63       const std::vector<std::optional<at::IValue>>& value_list);
64 
65   void addArgumentValues(
66       const std::unordered_map<std::string, at::IValue>& values);
67 
68   bool hasInputArgumentNamed(const std::string& name) const;
69 
70  private:
71   // This function enforces more conservative results when the TORCH_WARN is
72   // triggered from above due to duplicates in an argument list
73   void ensureConservativity(
74       const std::unordered_set<at::Symbol>& duplicates,
75       const std::vector<c10::Argument>& arguments_list,
76       c10::SchemaArgType type);
77 
78   void initSchemaInfo();
79 
80   void generateAliasMaps();
81 
82   bool mayContainAliasImpl(
83       const c10::SchemaArgument& lhs,
84       const c10::SchemaArgument& rhs);
85 
86   static std::vector<c10::FunctionSchema> getNonDeterministicOps();
87 
88   static std::vector<SchemaSpecialCasePair> getTrainingOps();
89 
90   const std::unordered_set<c10::SchemaArgument>& wildcardSet();
91 
92   const std::unordered_set<c10::SchemaArgument>& containerSet();
93 
94   // Set of all wildcard arguments
95   std::unordered_set<c10::SchemaArgument> wildcard_set_;
96 
97   // Set of all container arguments
98   std::unordered_set<c10::SchemaArgument> container_set_;
99 
100   // Map of argument IValues
101   std::unordered_map<std::string, at::IValue> value_map_;
102 
103   // Alias map of inputs with each other
104   std::vector<std::unordered_set<size_t>> input_alias_map_;
105 
106   // Alias map of outputs to inputs
107   std::vector<std::unordered_set<size_t>> output_alias_map_;
108 
109   const c10::FunctionSchema schema_;
110 
111   bool alias_maps_current_;
112 
113   bool has_init_;
114 };
115 } // namespace torch::utils
116