xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/tuple_simplifier.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
17 
18 #include <queue>
19 
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/types.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/logging.h"
29 
30 namespace xla {
31 
TupleSimplifier(bool exclude_entry_computation)32 TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
33     : exclude_entry_computation_(exclude_entry_computation) {}
34 
RemoveWholeTuple(HloInstruction * tuple)35 StatusOr<bool> TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) {
36   HloInstruction* top_tuple = nullptr;
37   for (int64_t operand_number = 0; operand_number < tuple->operand_count();
38        ++operand_number) {
39     HloInstruction* operand = tuple->mutable_operand(operand_number);
40     if (operand->opcode() != HloOpcode::kGetTupleElement ||
41         operand->tuple_index() != operand_number) {
42       return false;
43     }
44     if (top_tuple == nullptr) {
45       top_tuple = operand->mutable_operand(0);
46       if (!ShapeUtil::Compatible(top_tuple->shape(), tuple->shape())) {
47         return false;
48       }
49     } else if (top_tuple != operand->operand(0)) {
50       return false;
51     }
52   }
53   if (top_tuple == nullptr) {
54     return false;
55   }
56   TF_ASSIGN_OR_RETURN(bool changed,
57                       tuple->parent()->ReplaceInstruction(
58                           tuple, top_tuple, /*preserve_sharding=*/true));
59   return changed;
60 }
61 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)62 StatusOr<bool> TupleSimplifier::Run(
63     HloModule* module,
64     const absl::flat_hash_set<absl::string_view>& execution_threads) {
65   // Initially add all GTE and Tuple instructions to the worklist.
66   bool changed = false;
67   for (auto* computation : module->computations(execution_threads)) {
68     if (exclude_entry_computation_ &&
69         computation == module->entry_computation()) {
70       continue;
71     }
72     for (auto* instruction : computation->MakeInstructionPostOrder()) {
73       if (instruction->opcode() == HloOpcode::kTuple) {
74         TF_ASSIGN_OR_RETURN(bool c, RemoveWholeTuple(instruction));
75         changed |= c;
76       } else {
77         auto ancestor = instruction->LatestNonGteAncestorAndIndex();
78         if (ancestor.first == instruction) {
79           continue;
80         }
81         // If possible replace a chain of GTE with the operation which produces
82         // the element. For example, replace uses of GTE with below with just
83         // 'Op' (assuming 'Op' is at the index of the GTE instruction):
84         //
85         //     ...  Op ...
86         //       \  |   /
87         //        Tuple
88         //          |
89         //         GTE
90         //         ...
91         //          |
92         //         GTE
93         //          |
94         //         GTE
95         //
96         // Note that this deletes the Tuple instruction altogether. In addition,
97         // if only a subset of tuple's elements are used, this transform
98         // optimizes them one at a time, and after the last use is optimized,
99         // the Tuple will also be deleted.
100         HloInstruction* replacement = nullptr;
101         if (ShapeUtil::Compatible(ancestor.first->shape(),
102                                   instruction->shape())) {
103           replacement = ancestor.first;
104         } else if (ancestor.first->opcode() == HloOpcode::kTuple) {
105           replacement = ancestor.first->mutable_operand(ancestor.second[0]);
106         }
107 
108         if (replacement) {
109           TF_ASSIGN_OR_RETURN(bool replaced, computation->ReplaceInstruction(
110                                                  instruction, replacement,
111                                                  /*preserve_sharding=*/true));
112           changed |= replaced;
113         }
114       }
115     }
116   }
117   return changed;
118 }
119 
120 }  // namespace xla
121