1 #pragma once 2 3 #include <c10/util/Exception.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/csrc/jit/ir/alias_analysis.h> 6 #include <torch/csrc/jit/ir/ir.h> 7 8 #include <utility> 9 10 namespace torch::jit { 11 12 struct TORCH_API MutationRemover { 13 MutationRemover( 14 std::shared_ptr<Graph> graph, 15 std::optional<std::function<bool(Node*)>> mutation_filter = std::nullopt) mutation_filter_MutationRemover16 : mutation_filter_(std::move(mutation_filter)), 17 aliasDb_(nullptr), 18 graph_(std::move(graph)) {} 19 20 // return true if graph is modified 21 bool removeListMutation(); 22 23 // return true if graph is modified 24 bool removeTensorMutation(); 25 isSpecialMappedOpMutationRemover26 bool isSpecialMappedOp(Node* n) { 27 return n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)") || 28 n->matches( 29 "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)") || 30 n->matches( 31 "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)"); 32 } 33 34 bool inplaceOpVariant(Node* n); 35 36 static bool hasSideEffectOrAlias(Value* v, AliasDb* aliasDb); 37 38 private: 39 Node* createSpecialMappedOp(Node* n); 40 bool listMutationFollowingListConstruct(Node* n); 41 bool tryMakeCreationAndMutationAtomic( 42 Value* mutated_value, 43 Node* mutating_op); 44 bool tryMakeUnaliasedIfOutputAndMutationAtomic( 45 Value* mutated_value, 46 Node* mutating_op); 47 // return true if graph is modified 48 bool RemoveListMutation(Block* block); 49 // return true if graph is modified 50 bool RemoveTensorMutation(Block* block); 51 getOrCreateAliasDbMutationRemover52 AliasDb* getOrCreateAliasDb() { 53 if (!aliasDb_) { 54 aliasDb_ = std::make_unique<AliasDb>(graph_); 55 } 56 return aliasDb_.get(); 57 } 58 59 std::optional<std::function<bool(Node*)>> mutation_filter_; 60 std::unique_ptr<AliasDb> aliasDb_ = nullptr; 61 std::shared_ptr<Graph> graph_; 62 }; 63 64 // Removes list mutation with functional equivalents 65 // return true if graph is modified 66 TORCH_API bool RemoveListMutation(const std::shared_ptr<Graph>& graph); 67 68 // Replaces in-place aten ops with their functional equivalents 69 // when it can be proven that this does not change graph semantics 70 // if `mutation_filter` is present, the pass will only attempt to 71 // remove mutation on nodes which return true for the filter 72 // return true if graph is modified 73 TORCH_API bool RemoveTensorMutation( 74 const std::shared_ptr<Graph>& graph, 75 std::optional<std::function<bool(Node*)>> mutation_filter = std::nullopt); 76 77 // Replaces in-place aten activation ops with their functional equivalence 78 TORCH_API bool InplaceToFunctionalActivation( 79 const std::shared_ptr<Graph>& graph); 80 81 } // namespace torch::jit 82