xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/restore_mutation.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/symbol.h>
4 #include <c10/util/Exception.h>
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/jit/ir/alias_analysis.h>
7 #include <torch/csrc/jit/ir/ir.h>
8 
9 namespace torch::jit {
10 
11 // A map which stores if an activation operator can perform type promotion
12 const std::unordered_map<Symbol, bool> activation_type_promotion_mapping = {
13     {aten::sigmoid, true},
14     {aten::tanh, true},
15     {aten::celu, false},
16     {aten::elu, false},
17     {aten::gelu, false},
18     {aten::glu, false},
19     {aten::hardshrink, false},
20     {aten::hardsigmoid, false},
21     {aten::hardswish, false},
22     {aten::hardtanh, false},
23     {aten::leaky_relu, false},
24     {aten::prelu, false},
25     {aten::relu6, false},
26     {aten::relu, false},
27     {aten::rrelu, false},
28     {aten::selu, false},
29     {aten::silu, false}};
30 
31 class FunctionalToInplaceRewriter {
32  public:
33   FunctionalToInplaceRewriter(std::shared_ptr<Graph> graph);
34 
35   bool FunctionalToInplace(Block* block);
36 
37  private:
getOrCreateAliasDb()38   AliasDb* getOrCreateAliasDb() {
39     if (!aliasDb_) {
40       aliasDb_ = std::make_unique<AliasDb>(graph_);
41     }
42     return aliasDb_.get();
43   }
44 
45   bool CanBeInplace(Node* node);
46 
47   std::unique_ptr<AliasDb> aliasDb_ = nullptr;
48   std::shared_ptr<Graph> graph_;
49 };
50 
51 // A common application scenario is to apply InplaceToFunctionalActivation
52 // before some JIT optimization passes, so that those passes are less
53 // constrained by in-place ops. After those passes are done, we can call
54 // FunctionalToInplaceActivation to recover in-place activation ops,
55 // so that we won't lose the performance benefit coming from memory reduction.
56 
57 // Replaces functional aten activation ops with their in-place equivalents
58 TORCH_API bool FunctionalToInplaceActivation(
59     const std::shared_ptr<Graph>& graph);
60 
61 } // namespace torch::jit
62