xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
17 
#include <array>
#include <cstdint>
#include <list>
#include <memory>
#include <utility>

#include "tensorflow/core/common_runtime/eager/eager_operation.h"
19 
20 namespace tensorflow {
21 
22 // Eager op rewrites should inherit from this class and
23 // implement the Run method.
24 class EagerOpRewrite {
25  public:
EagerOpRewrite(string name,string file,string line)26   EagerOpRewrite(string name, string file, string line) {
27     debug_info_.name = name;
28     debug_info_.file = file;
29     debug_info_.line = line;
30   }
31 
~EagerOpRewrite()32   virtual ~EagerOpRewrite() {}
33 
34   // To be implemented by an Eager op rewrite pass.
35   virtual Status Run(EagerOperation* orig_op,
36                      std::unique_ptr<tensorflow::EagerOperation>* out_op) = 0;
37 
38   // Holds information about the rewrite registration.
39   struct DebugInfo {
40     string name, file, line;
41   };
42 
43   // Returns information about the registered Eager op rewrite.
GetDebugInfo()44   DebugInfo GetDebugInfo() const { return debug_info_; }
45 
46  private:
47   DebugInfo debug_info_;
48 };
49 
50 class EagerOpRewriteRegistry {
51  public:
52   // Phases at which the Eager op rewrite pass should run.
53   enum Phase {
54     PRE_EXECUTION = 0,  // right before executing an eager op
55     POST_PLACEMENT = 1  // after device placement
56   };
57 
58   // Add a rewrite pass to the registry.
59   void Register(Phase phase, int32_t ordinal,
60                 std::unique_ptr<EagerOpRewrite> pass);
61 
62   // Run the rewrite pass registered for a given phase.
63   Status RunRewrite(Phase phase, EagerOperation* orig_op,
64                     std::unique_ptr<tensorflow::EagerOperation>* out_op);
65 
66   // Returns the global registry of rewrite passes.
67   static EagerOpRewriteRegistry* Global();
68 
69  private:
70   static constexpr int32_t kNumPhases = 2;
71   // Holds all the registered Eager op rewrites and their ordinal numbers.
72   std::array<std::list<std::pair<std::unique_ptr<EagerOpRewrite>, int32>>,
73              kNumPhases>
74       rewrites_;
75 };
76 
77 namespace eager_rewrite_registration {
78 
79 // This class is used to register a new Eager Op rewrite.
80 class EagerRewriteRegistration {
81  public:
EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase,int32_t ordinal,std::unique_ptr<EagerOpRewrite> pass)82   EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase, int32_t ordinal,
83                            std::unique_ptr<EagerOpRewrite> pass) {
84     EagerOpRewriteRegistry::Global()->Register(phase, ordinal, std::move(pass));
85   }
86 };
87 
88 }  // namespace eager_rewrite_registration
89 
// Registers the Eager op rewrite class `rewrite` to run in `phase` with the
// given `ordinal`. Expands to the definition of a static
// EagerRewriteRegistration object. Its constructor hands a newly allocated
// `rewrite` to the global registry. The rewrite is constructed with its
// stringized class name, the registering source file, and the stringized
// line number.
//
// The extra REGISTER_REWRITE_UNIQ_HELPER layer makes the preprocessor expand
// __COUNTER__ and __LINE__ before they reach `##` / `#`. Otherwise the
// literal tokens would be pasted or stringized instead of their values.
#define REGISTER_REWRITE(phase, ordinal, rewrite)                      \
  REGISTER_REWRITE_UNIQ_HELPER(__COUNTER__, __FILE__, __LINE__, phase, \
                               ordinal, rewrite)

// Indirection layer: arguments not used with # or ## are fully macro-expanded
// before substitution, so the counter and line arrive here as numbers.
#define REGISTER_REWRITE_UNIQ_HELPER(counter, src_file, src_line, phase, \
                                     ordinal, rewrite)                   \
  REGISTER_REWRITE_UNIQ(counter, src_file, src_line, phase, ordinal, rewrite)

// Defines the uniquely named (register_rewrite_<counter>) registration object.
#define REGISTER_REWRITE_UNIQ(counter, src_file, src_line, phase, ordinal,  \
                              rewrite)                                      \
  static ::tensorflow::eager_rewrite_registration::EagerRewriteRegistration \
      register_rewrite_##counter(                                           \
          phase, ordinal,                                                   \
          ::std::unique_ptr<::tensorflow::EagerOpRewrite>(                  \
              new rewrite(#rewrite, src_file, #src_line)))
102 
103 }  // namespace tensorflow
104 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
105