1 #pragma once 2 3 #include <ATen/core/symbol.h> 4 #include <c10/util/Exception.h> 5 #include <torch/csrc/Export.h> 6 #include <torch/csrc/jit/ir/alias_analysis.h> 7 #include <torch/csrc/jit/ir/ir.h> 8 9 namespace torch::jit { 10 11 // A map which stores if an activation operator can perform type promotion 12 const std::unordered_map<Symbol, bool> activation_type_promotion_mapping = { 13 {aten::sigmoid, true}, 14 {aten::tanh, true}, 15 {aten::celu, false}, 16 {aten::elu, false}, 17 {aten::gelu, false}, 18 {aten::glu, false}, 19 {aten::hardshrink, false}, 20 {aten::hardsigmoid, false}, 21 {aten::hardswish, false}, 22 {aten::hardtanh, false}, 23 {aten::leaky_relu, false}, 24 {aten::prelu, false}, 25 {aten::relu6, false}, 26 {aten::relu, false}, 27 {aten::rrelu, false}, 28 {aten::selu, false}, 29 {aten::silu, false}}; 30 31 class FunctionalToInplaceRewriter { 32 public: 33 FunctionalToInplaceRewriter(std::shared_ptr<Graph> graph); 34 35 bool FunctionalToInplace(Block* block); 36 37 private: getOrCreateAliasDb()38 AliasDb* getOrCreateAliasDb() { 39 if (!aliasDb_) { 40 aliasDb_ = std::make_unique<AliasDb>(graph_); 41 } 42 return aliasDb_.get(); 43 } 44 45 bool CanBeInplace(Node* node); 46 47 std::unique_ptr<AliasDb> aliasDb_ = nullptr; 48 std::shared_ptr<Graph> graph_; 49 }; 50 51 // A common application scenario is to apply InplaceToFunctionalActivation 52 // before some JIT optimization passes, so that those passes are less 53 // constrained by in-place ops. After those passes are done, we can call 54 // FunctionalToInplaceActivation to recover in-place activation ops, 55 // so that we won't lose the performance benefit coming from memory reduction. 56 57 // Replaces functional aten activation ops with their in-place equivalents 58 TORCH_API bool FunctionalToInplaceActivation( 59 const std::shared_ptr<Graph>& graph); 60 61 } // namespace torch::jit 62