1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ 18 19 #include <algorithm> 20 #include <type_traits> 21 22 #include "tensorflow/compiler/xla/service/hlo_module.h" 23 #include "tensorflow/compiler/xla/service/hlo_module_group.h" 24 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 25 #include "tensorflow/compiler/xla/status_macros.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/compiler/xla/types.h" 28 29 namespace xla { 30 31 // Do an HLO pass to a fix point. 32 template <typename Pass, int kIterationLimit = 25> 33 class HloPassFix : public Pass { 34 public: 35 static_assert(std::is_base_of<HloPassInterface, Pass>::value, 36 "Pass must be a subclass of HloPassInterface"); 37 using RunState = HloPassInterface::RunState; 38 template <typename... Args> HloPassFix(Args &&...args)39 explicit HloPassFix(Args&&... args) : Pass(args...) {} 40 RunOnChangedComputations(HloModule * module,RunState * outer_run_state,const absl::flat_hash_set<absl::string_view> & execution_threads)41 Status RunOnChangedComputations(HloModule* module, RunState* outer_run_state, 42 const absl::flat_hash_set<absl::string_view>& 43 execution_threads) override { 44 RunState run_state; 45 run_state.changed_last_iteration = outer_run_state->changed_last_iteration; 46 TF_RETURN_IF_ERROR(RunToFixPoint(module, &run_state, execution_threads)); 47 outer_run_state->changed_this_iteration.insert(run_state.changed.begin(), 48 run_state.changed.end()); 49 return OkStatus(); 50 } 51 52 using HloPassInterface::Run; Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)53 StatusOr<bool> Run(HloModule* module, 54 const absl::flat_hash_set<absl::string_view>& 55 execution_threads) override { 56 RunState run_state(module); 57 TF_RETURN_IF_ERROR(RunToFixPoint(module, &run_state, execution_threads)); 58 return !run_state.changed.empty(); 59 } 60 61 using HloPassInterface::RunOnModuleGroup; RunOnModuleGroup(HloModuleGroup * module_group,const absl::flat_hash_set<absl::string_view> & execution_threads)62 StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group, 63 const absl::flat_hash_set<absl::string_view>& 64 execution_threads) override { 65 bool changed = false; 66 bool changed_this_iteration = true; 67 int64_t iteration_count = 0; 68 VLOG(3) << "Running HloPassFix."; 69 while (changed_this_iteration) { 70 TF_ASSIGN_OR_RETURN( 71 changed_this_iteration, 72 Pass::RunOnModuleGroup(module_group, execution_threads)); 73 changed |= changed_this_iteration; 74 VLOG(3) << "changed_this_iteration: " << changed_this_iteration; 75 ++iteration_count; 76 if (iteration_count == kIterationLimit) { 77 VLOG(1) << "Unexpectedly high number of iterations in HLO passes, " 78 "exiting fixed point loop."; 79 // Return false in case this is fixed point is nested. 80 return false; 81 } 82 } 83 return changed; 84 } 85 86 private: RunToFixPoint(HloModule * module,RunState * run_state,const absl::flat_hash_set<absl::string_view> & execution_threads)87 Status RunToFixPoint( 88 HloModule* module, RunState* run_state, 89 const absl::flat_hash_set<absl::string_view>& execution_threads) { 90 VLOG(3) << "Running HloPassFix on " << Pass::name(); 91 while (!run_state->changed_last_iteration.empty()) { 92 TF_RETURN_IF_ERROR( 93 RunOnChangedComputationsOnce(module, run_state, execution_threads)); 94 VLOG(3) << Pass::name() << " iteration " << run_state->iteration 95 << " changed_this_iteration: " 96 << !run_state->changed_last_iteration.empty(); 97 run_state->IncrementIteration(); 98 if (run_state->iteration == kIterationLimit) { 99 VLOG(1) << "Unexpectedly high number of iterations in HLO passes '" 100 << Pass::name() << "' for module '" << module->name() 101 << "'. Exiting fixed point loop."; 102 // Clear changed and abort in case this is fixed point is nested. 103 run_state->changed.clear(); 104 break; 105 } 106 } 107 return OkStatus(); 108 } 109 RunOnChangedComputationsOnce(HloModule * module,RunState * run_state,const absl::flat_hash_set<absl::string_view> & execution_threads)110 Status RunOnChangedComputationsOnce( 111 HloModule* module, RunState* run_state, 112 const absl::flat_hash_set<absl::string_view>& execution_threads) { 113 // If Pass overrides RunOnChangedComputations, just forward to it. 114 if (!std::is_same<decltype(&HloPassInterface::RunOnChangedComputations), 115 decltype(&Pass::RunOnChangedComputations)>::value) { 116 return Pass::RunOnChangedComputations(module, run_state, 117 execution_threads); 118 } 119 // If Pass does not override the default 120 // HloPassInterface::RunOnChangedComputations that calls into 121 // HloPassFix<Pass>::Run, avoid infinite recursion. 122 TF_ASSIGN_OR_RETURN(bool changed, Pass::Run(module, execution_threads)); 123 if (changed) { 124 auto computations = module->computations(execution_threads); 125 run_state->changed_this_iteration.insert(computations.begin(), 126 computations.end()); 127 } 128 return OkStatus(); 129 } 130 }; 131 132 } // namespace xla 133 134 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ 135