xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tools/hlo_control_flow_flattening.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TOOLS_HLO_CONTROL_FLOW_FLATTENING_H_
17 #define TENSORFLOW_COMPILER_XLA_TOOLS_HLO_CONTROL_FLOW_FLATTENING_H_
18 
#include <limits>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
26 
27 namespace xla {
28 
29 // An HLO pass that replaces while loop conditionals to execute a known constant
30 // number of iterations and remove operations that are difficult to run in
31 // standalone tests, such as infeed/outfeed and collective operations.
32 class HloControlFlowFlattening : public HloModulePass {
33  public:
34   // While execution count specifies how many times the while loops in the
35   // transformed graph will execute.
36   // If remove_comm = true, remove all communication operations.
37   // If remove_host_transfer = true, remove the host-transfer send and recv
38   // operations.
39   struct Options {
40     int while_execution_count = 1;
41     int max_outer_loop_count = std::numeric_limits<int>::max();
42     int max_loop_count = std::numeric_limits<int>::max();
43     bool remove_infeed_outfeed = true;
44     bool flatten_while_loop = true;
45     bool remove_comm = true;
46     bool remove_host_transfer = false;
47   };
HloControlFlowFlattening(const Options & options)48   explicit HloControlFlowFlattening(const Options& options)
49       : while_execution_count_(options.while_execution_count),
50         max_outer_loop_count_(options.max_outer_loop_count),
51         max_loop_count_(options.max_loop_count),
52         remove_infeed_outfeed_(options.remove_infeed_outfeed),
53         flatten_while_loop_(options.flatten_while_loop),
54         remove_host_transfer_(options.remove_host_transfer),
55         remove_comm_(options.remove_comm) {}
56   ~HloControlFlowFlattening() override = default;
name()57   absl::string_view name() const override { return "control-flow-flattening"; }
58   using HloPassInterface::Run;
59   StatusOr<bool> Run(
60       HloModule* module,
61       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
62 
63  private:
64   // Replaces an infeed with a custom call.
65   Status RemoveInfeed(HloInstruction* infeed_hlo) const;
66   // Removes outfeeds and replaces the outfeed HLO with a side-effecting custom
67   // call that ensures that XLA doesn't dead-code-eliminate the outfeeded values
68   // but lowers to a no-op.
69   Status RemoveOutfeed(HloInstruction* outfeed_hlo) const;
70   // Flattens the while loop. Precondition: while_hlo is a while instruction.
71   Status FlattenWhileLoop(HloInstruction* while_hlo,
72                           const CallGraph& call_graph) const;
73   // Replaces a partition-id or replica-id with a zero constant.
74   Status RemovePartitionOrReplicaId(HloInstruction* hlo) const;
75   // Removes send and send-done with a custom call.
76   Status RemoveSendDone(
77       HloInstruction* send_done,
78       absl::flat_hash_set<HloInstruction*>* additional_removed) const;
79   // Removes recv and recv-done with a custom call.
80   Status RemoveRecvDone(
81       HloInstruction* recv_done,
82       absl::flat_hash_set<HloInstruction*>* additional_removed) const;
83 
84   int while_execution_count_;
85   int max_outer_loop_count_;
86   int max_loop_count_;
87   bool remove_infeed_outfeed_;
88   bool flatten_while_loop_;
89   bool remove_host_transfer_;
90 
91  protected:
92   // Replaces a collective op with a custom call.
93   Status RemoveCollective(HloInstruction* hlo) const;
94 
95   bool remove_comm_;
96 };
97 
98 // Retrieves the original loop bound. If fail, return a default value. If bounds
99 // exceed a given max, returns the max. This function is more opportunistic than
100 // ComputeWhileLoopTripCount in the while loop analysis as it may return a
101 // constant found in a compare expression when it is not an actual bound.
102 int GetLoopBound(const HloInstruction& while_hlo, const int default_loop_count,
103                  const int max_loop_count);
104 
105 // Retrieves the loop bound determined by the original loop bound, the max
106 // outer loops count and max loop count.
107 int GetLoopBoundWithOuterLoopMax(const HloInstruction& while_hlo,
108                                  const CallGraph& call_graph,
109                                  const int default_loop_count,
110                                  const int max_outer_loop_count,
111                                  const int max_loop_count);
112 }  // namespace xla
113 
114 #endif  // TENSORFLOW_COMPILER_XLA_TOOLS_HLO_CONTROL_FLOW_FLATTENING_H_
115