1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_ 18 19 #include <algorithm> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "absl/strings/str_cat.h" 25 #include "tensorflow/compiler/xla/service/compilation_stats.h" 26 #include "tensorflow/compiler/xla/service/hlo_module.h" 27 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 28 #include "tensorflow/compiler/xla/statusor.h" 29 #include "tensorflow/compiler/xla/types.h" 30 31 namespace xla { 32 33 class PhaseOrderPipeline; 34 35 // Pipeline of HLO passes. 36 class HloPassPipeline : public HloPassInterface { 37 public: 38 explicit HloPassPipeline(const std::string& name, 39 CompilationStats* compilation_stats = nullptr) name_(name)40 : name_(name), compilation_stats_(compilation_stats) { 41 if (compilation_stats == nullptr) { 42 empty_compilation_stats_ = CompilationStats::MakeNoopStats(); 43 compilation_stats_ = empty_compilation_stats_.get(); 44 } 45 } name()46 absl::string_view name() const override { return name_; } 47 48 // Add a pass to the pipeline. It should be called with the arguments for the 49 // pass constructor: 50 // 51 // pipeline.AddPass<FooPass>(constructor_arg1, constructor_arg2); 52 // 53 // Returns a reference to the added pass. 54 template <typename T, typename... Args> AddPass(Args &&...args)55 T& AddPass(Args&&... args) { 56 CHECK(!run_called_) << "AddPass cannot be called after Run"; 57 auto pass = new T(std::forward<Args>(args)...); 58 passes_.push_back(std::unique_ptr<T>(pass)); 59 return *pass; 60 } 61 62 // Add an invariant-checking pass to the pipeline. It will be run before and 63 // after each HLO pass. The invariant checking pass must not mutate the graph 64 // (it is required to always return "false" from its Run() method). 65 template <typename T, typename... Args> AddInvariantChecker(Args &&...args)66 T& AddInvariantChecker(Args&&... args) { 67 CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run"; 68 auto pass = new T(std::forward<Args>(args)...); 69 invariant_checkers_.push_back(std::unique_ptr<T>(pass)); 70 return *pass; 71 } 72 73 // Add an invariant-checking pass to the pipeline on debug builds only. 74 template <typename T, typename... Args> AddInvariantCheckerDebug(Args &&...args)75 void AddInvariantCheckerDebug(Args&&... args) { 76 #ifndef NDEBUG 77 AddInvariantChecker<T>(std::forward<Args>(args)...); 78 #endif // NDEBUG 79 } 80 81 using HloPassInterface::Run; 82 StatusOr<bool> Run( 83 HloModule* module, 84 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 85 using HloPassInterface::RunOnModuleGroup; 86 StatusOr<bool> RunOnModuleGroup( 87 HloModuleGroup* module_group, 88 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 89 IsPassPipeline()90 bool IsPassPipeline() override { return true; } 91 92 // Return size of passes_. PassesSize()93 int PassesSize() { return passes_.size(); } 94 // Return reference to pass specified by index. GetPass(int index)95 HloPassInterface& GetPass(int index) { return *passes_[index]; } 96 97 private: 98 // Returns the set of passes which are enabled. DebugOptions can selectively 99 // disable passes via --xla_disable_hlo_passes flag. 100 std::vector<HloPassInterface*> GetEnabledPasses( 101 const DebugOptions& debug_options); 102 103 // Maybe dumps the given module or module group depending on flag values 104 // contained in DebugOptions of module config. If it is dumped, saves the 105 // filenames of the dumps into module metadata. 106 void MaybeDumpHloAndSaveFilenames(HloModuleGroup& module_group, 107 absl::string_view after_pass_name, 108 absl::string_view before_pass_name); 109 void MaybeDumpHloAndSaveFilenames(HloModule& module, 110 absl::string_view after_pass_name, 111 absl::string_view before_pass_name); 112 113 // Runs the invariant checker on the given HLO for specified 114 // `execution_threads`. Empty `execution_threads` means all execution threads 115 // are included. HloT can be either HloModule or HloModuleGroup. 116 template <typename HloT> RunInvariantCheckers(HloT * hlo,absl::string_view after_pass_name)117 Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name) { 118 return RunInvariantCheckers(hlo, after_pass_name, /*execution_threads=*/{}); 119 } 120 template <typename HloT> 121 Status RunInvariantCheckers( 122 HloT* hlo, absl::string_view after_pass_name, 123 const absl::flat_hash_set<absl::string_view>& execution_threads); 124 125 // Helper which runs the given pass on the given HLO. HloT can be either 126 // HloModule or HloModuleGroup. 127 template <typename HloT> 128 StatusOr<bool> RunPassesInternal( 129 HloT* hlo, const DebugOptions& debug_options, 130 const absl::flat_hash_set<absl::string_view>& execution_threads); 131 132 // Helpers which run the given passes on the given HLO construct. Only 133 // computations with specified `execution_threads` are considered by the pass, 134 // empty thread list means all `execution_threads` are considered. These 135 // helpers enable templating of the core of the pipeline logic by providing 136 // HloModule and HloModuleGroup specific methods with the same name. RunHelper(HloPassInterface * pass,HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)137 static StatusOr<bool> RunHelper( 138 HloPassInterface* pass, HloModule* module, 139 const absl::flat_hash_set<absl::string_view>& execution_threads) { 140 TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module, execution_threads)); 141 module->Cleanup(); 142 return changed; 143 } RunHelper(HloPassInterface * pass,HloModuleGroup * module_group,const absl::flat_hash_set<absl::string_view> & execution_threads)144 static StatusOr<bool> RunHelper( 145 HloPassInterface* pass, HloModuleGroup* module_group, 146 const absl::flat_hash_set<absl::string_view>& execution_threads) { 147 TF_ASSIGN_OR_RETURN( 148 bool changed, pass->RunOnModuleGroup(module_group, execution_threads)); 149 module_group->Cleanup(); 150 return changed; 151 } 152 153 const std::string name_; 154 std::vector<std::unique_ptr<HloPassInterface>> passes_; 155 std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_; 156 bool run_called_ = false; 157 158 CompilationStats* compilation_stats_; 159 // Default stats instance for when one is not passed in the constructor. 160 // Use via compilation_stats_, not directly. 161 std::unique_ptr<CompilationStats> empty_compilation_stats_; 162 163 // Allow PhaseOrderPipeline to modify private passes_ member in order to 164 // perform PhaseOrdering. 165 friend class ::xla::PhaseOrderPipeline; 166 }; 167 168 } // namespace xla 169 170 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_ 171