xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_pass_pipeline.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_
18 
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/compiler/xla/service/compilation_stats.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/types.h"
30 
31 namespace xla {
32 
33 class PhaseOrderPipeline;
34 
35 // Pipeline of HLO passes.
36 class HloPassPipeline : public HloPassInterface {
37  public:
38   explicit HloPassPipeline(const std::string& name,
39                            CompilationStats* compilation_stats = nullptr)
name_(name)40       : name_(name), compilation_stats_(compilation_stats) {
41     if (compilation_stats == nullptr) {
42       empty_compilation_stats_ = CompilationStats::MakeNoopStats();
43       compilation_stats_ = empty_compilation_stats_.get();
44     }
45   }
name()46   absl::string_view name() const override { return name_; }
47 
48   // Add a pass to the pipeline. It should be called with the arguments for the
49   // pass constructor:
50   //
51   //   pipeline.AddPass<FooPass>(constructor_arg1, constructor_arg2);
52   //
53   // Returns a reference to the added pass.
54   template <typename T, typename... Args>
AddPass(Args &&...args)55   T& AddPass(Args&&... args) {
56     CHECK(!run_called_) << "AddPass cannot be called after Run";
57     auto pass = new T(std::forward<Args>(args)...);
58     passes_.push_back(std::unique_ptr<T>(pass));
59     return *pass;
60   }
61 
62   // Add an invariant-checking pass to the pipeline. It will be run before and
63   // after each HLO pass. The invariant checking pass must not mutate the graph
64   // (it is required to always return "false" from its Run() method).
65   template <typename T, typename... Args>
AddInvariantChecker(Args &&...args)66   T& AddInvariantChecker(Args&&... args) {
67     CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run";
68     auto pass = new T(std::forward<Args>(args)...);
69     invariant_checkers_.push_back(std::unique_ptr<T>(pass));
70     return *pass;
71   }
72 
73   // Add an invariant-checking pass to the pipeline on debug builds only.
74   template <typename T, typename... Args>
AddInvariantCheckerDebug(Args &&...args)75   void AddInvariantCheckerDebug(Args&&... args) {
76 #ifndef NDEBUG
77     AddInvariantChecker<T>(std::forward<Args>(args)...);
78 #endif  // NDEBUG
79   }
80 
81   using HloPassInterface::Run;
82   StatusOr<bool> Run(
83       HloModule* module,
84       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
85   using HloPassInterface::RunOnModuleGroup;
86   StatusOr<bool> RunOnModuleGroup(
87       HloModuleGroup* module_group,
88       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
89 
IsPassPipeline()90   bool IsPassPipeline() override { return true; }
91 
92   // Return size of passes_.
PassesSize()93   int PassesSize() { return passes_.size(); }
94   // Return reference to pass specified by index.
GetPass(int index)95   HloPassInterface& GetPass(int index) { return *passes_[index]; }
96 
97  private:
98   // Returns the set of passes which are enabled. DebugOptions can selectively
99   // disable passes via --xla_disable_hlo_passes flag.
100   std::vector<HloPassInterface*> GetEnabledPasses(
101       const DebugOptions& debug_options);
102 
103   // Maybe dumps the given module or module group depending on flag values
104   // contained in DebugOptions of module config. If it is dumped, saves the
105   // filenames of the dumps into module metadata.
106   void MaybeDumpHloAndSaveFilenames(HloModuleGroup& module_group,
107                                     absl::string_view after_pass_name,
108                                     absl::string_view before_pass_name);
109   void MaybeDumpHloAndSaveFilenames(HloModule& module,
110                                     absl::string_view after_pass_name,
111                                     absl::string_view before_pass_name);
112 
113   // Runs the invariant checker on the given HLO for specified
114   // `execution_threads`. Empty `execution_threads` means all execution threads
115   // are included. HloT can be either HloModule or HloModuleGroup.
116   template <typename HloT>
RunInvariantCheckers(HloT * hlo,absl::string_view after_pass_name)117   Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name) {
118     return RunInvariantCheckers(hlo, after_pass_name, /*execution_threads=*/{});
119   }
120   template <typename HloT>
121   Status RunInvariantCheckers(
122       HloT* hlo, absl::string_view after_pass_name,
123       const absl::flat_hash_set<absl::string_view>& execution_threads);
124 
125   // Helper which runs the given pass on the given HLO. HloT can be either
126   // HloModule or HloModuleGroup.
127   template <typename HloT>
128   StatusOr<bool> RunPassesInternal(
129       HloT* hlo, const DebugOptions& debug_options,
130       const absl::flat_hash_set<absl::string_view>& execution_threads);
131 
132   // Helpers which run the given passes on the given HLO construct. Only
133   // computations with specified `execution_threads` are considered by the pass,
134   // empty thread list means all `execution_threads` are considered. These
135   // helpers enable templating of the core of the pipeline logic by providing
136   // HloModule and HloModuleGroup specific methods with the same name.
RunHelper(HloPassInterface * pass,HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)137   static StatusOr<bool> RunHelper(
138       HloPassInterface* pass, HloModule* module,
139       const absl::flat_hash_set<absl::string_view>& execution_threads) {
140     TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module, execution_threads));
141     module->Cleanup();
142     return changed;
143   }
RunHelper(HloPassInterface * pass,HloModuleGroup * module_group,const absl::flat_hash_set<absl::string_view> & execution_threads)144   static StatusOr<bool> RunHelper(
145       HloPassInterface* pass, HloModuleGroup* module_group,
146       const absl::flat_hash_set<absl::string_view>& execution_threads) {
147     TF_ASSIGN_OR_RETURN(
148         bool changed, pass->RunOnModuleGroup(module_group, execution_threads));
149     module_group->Cleanup();
150     return changed;
151   }
152 
153   const std::string name_;
154   std::vector<std::unique_ptr<HloPassInterface>> passes_;
155   std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
156   bool run_called_ = false;
157 
158   CompilationStats* compilation_stats_;
159   // Default stats instance for when one is not passed in the constructor.
160   // Use via compilation_stats_, not directly.
161   std::unique_ptr<CompilationStats> empty_compilation_stats_;
162 
163   // Allow PhaseOrderPipeline to modify private passes_ member in order to
164   // perform PhaseOrdering.
165   friend class ::xla::PhaseOrderPipeline;
166 };
167 
168 }  // namespace xla
169 
170 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_
171