1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_TOOLS_HLO_CONTROL_FLOW_FLATTENING_H_ 17 #define TENSORFLOW_COMPILER_XLA_TOOLS_HLO_CONTROL_FLOW_FLATTENING_H_ 18 19 #include <limits> 20 #include <string> 21 22 #include "absl/strings/string_view.h" 23 #include "tensorflow/compiler/xla/service/call_graph.h" 24 #include "tensorflow/compiler/xla/service/hlo_module.h" 25 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 26 27 namespace xla { 28 29 // An HLO pass that replaces while loop conditionals to execute a known constant 30 // number of iterations and remove operations that are difficult to run in 31 // standalone tests, such as infeed/outfeed and collective operations. 32 class HloControlFlowFlattening : public HloModulePass { 33 public: 34 // While execution count specifies how many times the while loops in the 35 // transformed graph will execute. 36 // If remove_comm = true, remove all communication operations. 37 // If remove_host_transfer = true, remove the host-transfer send and recv 38 // operations. 39 struct Options { 40 int while_execution_count = 1; 41 int max_outer_loop_count = std::numeric_limits<int>::max(); 42 int max_loop_count = std::numeric_limits<int>::max(); 43 bool remove_infeed_outfeed = true; 44 bool flatten_while_loop = true; 45 bool remove_comm = true; 46 bool remove_host_transfer = false; 47 }; HloControlFlowFlattening(const Options & options)48 explicit HloControlFlowFlattening(const Options& options) 49 : while_execution_count_(options.while_execution_count), 50 max_outer_loop_count_(options.max_outer_loop_count), 51 max_loop_count_(options.max_loop_count), 52 remove_infeed_outfeed_(options.remove_infeed_outfeed), 53 flatten_while_loop_(options.flatten_while_loop), 54 remove_host_transfer_(options.remove_host_transfer), 55 remove_comm_(options.remove_comm) {} 56 ~HloControlFlowFlattening() override = default; name()57 absl::string_view name() const override { return "control-flow-flattening"; } 58 using HloPassInterface::Run; 59 StatusOr<bool> Run( 60 HloModule* module, 61 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 62 63 private: 64 // Replaces an infeed with a custom call. 65 Status RemoveInfeed(HloInstruction* infeed_hlo) const; 66 // Removes outfeeds and replaces the outfeed HLO with a side-effecting custom 67 // call that ensures that XLA doesn't dead-code-eliminate the outfeeded values 68 // but lowers to a no-op. 69 Status RemoveOutfeed(HloInstruction* outfeed_hlo) const; 70 // Flattens the while loop. Precondition: while_hlo is a while instruction. 71 Status FlattenWhileLoop(HloInstruction* while_hlo, 72 const CallGraph& call_graph) const; 73 // Replaces a partition-id or replica-id with a zero constant. 74 Status RemovePartitionOrReplicaId(HloInstruction* hlo) const; 75 // Removes send and send-done with a custom call. 76 Status RemoveSendDone( 77 HloInstruction* send_done, 78 absl::flat_hash_set<HloInstruction*>* additional_removed) const; 79 // Removes recv and recv-done with a custom call. 80 Status RemoveRecvDone( 81 HloInstruction* recv_done, 82 absl::flat_hash_set<HloInstruction*>* additional_removed) const; 83 84 int while_execution_count_; 85 int max_outer_loop_count_; 86 int max_loop_count_; 87 bool remove_infeed_outfeed_; 88 bool flatten_while_loop_; 89 bool remove_host_transfer_; 90 91 protected: 92 // Replaces a collective op with a custom call. 93 Status RemoveCollective(HloInstruction* hlo) const; 94 95 bool remove_comm_; 96 }; 97 98 // Retrieves the original loop bound. If fail, return a default value. If bounds 99 // exceed a given max, returns the max. This function is more opportunistic than 100 // ComputeWhileLoopTripCount in the while loop analysis as it may return a 101 // constant found in a compare expression when it is not an actual bound. 102 int GetLoopBound(const HloInstruction& while_hlo, const int default_loop_count, 103 const int max_loop_count); 104 105 // Retrieves the loop bound determined by the original loop bound, the max 106 // outer loops count and max loop count. 107 int GetLoopBoundWithOuterLoopMax(const HloInstruction& while_hlo, 108 const CallGraph& call_graph, 109 const int default_loop_count, 110 const int max_outer_loop_count, 111 const int max_loop_count); 112 } // namespace xla 113 114 #endif // TENSORFLOW_COMPILER_XLA_TOOLS_HLO_CONTROL_FLOW_FLATTENING_H_ 115