1 #pragma once 2 3 #include <torch/csrc/jit/frontend/function_schema_parser.h> 4 #include <unordered_set> 5 6 namespace torch::utils { 7 8 using SchemaSpecialCasePair = 9 std::pair<c10::FunctionSchema, std::unordered_set<std::string>>; 10 /** 11 * class SchemaInfo 12 * 13 * FunctionSchema wrapper that publicizes argument value specific operator 14 * behavior (mutation, aliasing, special cases, etc...) 15 */ 16 17 struct TORCH_API SchemaInfo { 18 public: SchemaInfoSchemaInfo19 explicit SchemaInfo(c10::FunctionSchema schema) 20 : schema_(std::move(schema)), 21 alias_maps_current_(false), 22 has_init_(false) {} SchemaInfoSchemaInfo23 explicit SchemaInfo(const char* signature) 24 : schema_(torch::jit::parseSchema(signature)), 25 alias_maps_current_(false), 26 has_init_(false) {} 27 28 bool is_mutable(); 29 30 bool is_mutable(const c10::SchemaArgument& argument); 31 32 bool is_mutable(c10::string_view name); 33 34 bool has_argument(c10::string_view name); 35 36 bool is_nondeterministic() const; 37 38 // Returns whether lhs and rhs may alias directly. 39 // This does not account for cases where lhs or rhs are a container that 40 // may contain elements that alias the other argument. 41 // Besides the checks already included in FunctionSchema::may_alias, this 42 // method also accounts special aliasing cases causes by aliasing argument 43 // values supplied from addArgumentValue. 44 bool may_alias( 45 const c10::SchemaArgument& lhs, 46 const c10::SchemaArgument& rhs); 47 48 // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a 49 // container that may contain elements that alias the other argument. Besides 50 // the checks already included in FunctionSchema::may_contain_alias, this 51 // method also accounts for special aliasing cases causes by aliasing argument 52 // values supplied from addArgumentValue. bidirectional = false only returns 53 // whether lhs may contain an alias of rhs while bidirectional = true returns 54 // both directions. 55 bool may_contain_alias( 56 const c10::SchemaArgument& lhs, 57 const c10::SchemaArgument& rhs, 58 bool bidirectional = true); 59 60 void addArgumentValue(const std::string& name, const at::IValue& value); 61 62 void addArgumentValues( 63 const std::vector<std::optional<at::IValue>>& value_list); 64 65 void addArgumentValues( 66 const std::unordered_map<std::string, at::IValue>& values); 67 68 bool hasInputArgumentNamed(const std::string& name) const; 69 70 private: 71 // This function enforces more conservative results when the TORCH_WARN is 72 // triggered from above due to duplicates in an argument list 73 void ensureConservativity( 74 const std::unordered_set<at::Symbol>& duplicates, 75 const std::vector<c10::Argument>& arguments_list, 76 c10::SchemaArgType type); 77 78 void initSchemaInfo(); 79 80 void generateAliasMaps(); 81 82 bool mayContainAliasImpl( 83 const c10::SchemaArgument& lhs, 84 const c10::SchemaArgument& rhs); 85 86 static std::vector<c10::FunctionSchema> getNonDeterministicOps(); 87 88 static std::vector<SchemaSpecialCasePair> getTrainingOps(); 89 90 const std::unordered_set<c10::SchemaArgument>& wildcardSet(); 91 92 const std::unordered_set<c10::SchemaArgument>& containerSet(); 93 94 // Set of all wildcard arguments 95 std::unordered_set<c10::SchemaArgument> wildcard_set_; 96 97 // Set of all container arguments 98 std::unordered_set<c10::SchemaArgument> container_set_; 99 100 // Map of argument IValues 101 std::unordered_map<std::string, at::IValue> value_map_; 102 103 // Alias map of inputs with each other 104 std::vector<std::unordered_set<size_t>> input_alias_map_; 105 106 // Alias map of outputs to inputs 107 std::vector<std::unordered_set<size_t>> output_alias_map_; 108 109 const c10::FunctionSchema schema_; 110 111 bool alias_maps_current_; 112 113 bool has_init_; 114 }; 115 } // namespace torch::utils 116