/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"

#include <algorithm>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace gpu {

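// Canonicalizes reduction inputs by reinterpreting each operand in its
// physical (major-to-minor) dimension order, so that the rewritten reduce
// always sees operands with the default layout. The reinterpretation is a
// bitcast, so no data is moved; the original output shape is restored with
// another bitcast at the end.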
class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor {
  Status HandleReduce(HloInstruction *hlo) override {
    auto reduce = Cast<HloReduceInstruction>(hlo);
    VLOG(5) << "Input: " << reduce->ToString();

    int operand_idx = -1;

    absl::InlinedVector<HloInstruction *, 2> canonical_reduce_inputs;
    absl::InlinedVector<Shape, 2> new_reduce_shapes;

    DimensionVector out_reduce_dimensions;
    const Shape &first_instruction_shape = reduce->inputs()[0]->shape();

    for (HloInstruction *operand : reduce->inputs()) {
      operand_idx++;

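      // All inputs of a variadic reduction must share a layout: where an
      // operand disagrees with the first input, insert a physical copy into
      // the first input's layout.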
      if (operand_idx != 0 &&
          operand->shape().layout() != first_instruction_shape.layout()) {
        HloInstruction *copy =
            reduce->parent()->AddInstruction(HloInstruction::CreateUnary(
                operand->shape(), HloOpcode::kCopy, operand));

        LayoutUtil::ClearLayout(copy->mutable_shape());
        TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
            first_instruction_shape, copy->mutable_shape()));

        copy->set_metadata(operand->metadata());
        operand = copy;
        VLOG(3) << "Copying to establish consistent inputs layout: "
                << copy->ToString();
      }

      const Shape &operand_shape = operand->shape();
      const Layout &operand_layout = operand_shape.layout();

      const Shape &reduce_shape =
          reduce->shape().IsTuple() ? reduce->shape().tuple_shapes(operand_idx)
                                    : reduce->shape();

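      // Compute the operand dimensions, output dimensions, and reduced
      // dimension indices as they will appear once the operand is viewed in
      // physical (major-to-minor) order.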
      DimensionVector new_reduce_dimensions;
      DimensionVector new_operand_shape_data;
      DimensionVector new_reduce_shape_data;

      // The layout order of the reduction output can differ from the
      // ordering of kept dimensions in the input operand, so we need to
      // compute the new layout.
      DimensionVector new_reduce_shape_layout(reduce_shape.rank());
      std::vector<int64_t> reduce_shape_logical_to_physical =
          LayoutUtil::MakeLogicalToPhysical(reduce_shape.layout());

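      // Maps a kept logical dimension of the operand to the corresponding
      // logical dimension of the reduction output, by discounting the reduced
      // dimensions that precede it.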
      auto to_reduce_logical_dim = [&](int64_t op_logical_dim) {
        return op_logical_dim -
               absl::c_count_if(reduce->dimensions(), [&](int64_t dim) {
                 CHECK_NE(dim, op_logical_dim);
                 return dim < op_logical_dim;
               });
      };

      for (int i = 0; i < operand_shape.rank(); i++) {
        // Process the dimensions in major-to-minor order to enforce the
        // default layout.
        int64_t major_to_minor_dim_idx = operand_shape.rank() - i - 1;
        int64_t logical_dim =
            operand_layout.minor_to_major(major_to_minor_dim_idx);
        int64_t dim_size = operand_shape.dimensions(logical_dim);
        VLOG(5) << "Processing logical dimension " << logical_dim << " of size "
                << dim_size;
        new_operand_shape_data.push_back(dim_size);

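        // Dimensions are appended in major-to-minor order, so in the new
        // operand shape (which gets the default layout) the loop index `i` is
        // both the logical and the physical position of this dimension.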
        if (absl::c_linear_search(reduce->dimensions(), logical_dim)) {
          new_reduce_dimensions.push_back(i);
        } else {
          new_reduce_shape_data.push_back(dim_size);
          int64_t logical_reduce_dim = to_reduce_logical_dim(logical_dim);
          int64_t physical_reduce_dim =
              reduce_shape_logical_to_physical[logical_reduce_dim];
          VLOG(5) << "logical_reduce_dim = " << logical_reduce_dim << ", "
                  << "physical_reduce_dim = " << physical_reduce_dim;
          new_reduce_shape_layout[reduce_shape.rank() - physical_reduce_dim -
                                  1] = new_reduce_shape_data.size() - 1;
        }
      }

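      // The new operand shape carries the default layout by construction; the
      // new output shape gets an explicit layout that preserves the original
      // output's physical dimension order.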
      Shape new_operand_shape = ShapeUtil::MakeShape(
          operand_shape.element_type(), new_operand_shape_data);
      Shape new_reduce_shape = ShapeUtil::MakeShapeWithLayout(
          reduce_shape.element_type(), new_reduce_shape_data,
          new_reduce_shape_layout);

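      // A single-input reduction whose operand is already canonical needs no
      // rewrite.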
      if (new_operand_shape == operand_shape && reduce->inputs().size() == 1) {
        return OkStatus();
      }

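      // Reinterpret the operand in its physical order; the bitcast is free
      // since the underlying data is untouched.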
      HloInstruction *canonical_reduce_input =
          new_operand_shape != operand_shape
              ? reduce->parent()->AddInstruction(
                    HloInstruction::CreateBitcast(new_operand_shape, operand))
              : operand;
      canonical_reduce_input->set_metadata(operand->metadata());
      VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString();

      new_reduce_shapes.push_back(new_reduce_shape);
      canonical_reduce_inputs.push_back(canonical_reduce_input);

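      // Every operand of a variadic reduction must yield the same reduced
      // dimensions under the new indexing; record them once and verify that
      // the rest agree.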
      if (out_reduce_dimensions.empty()) {
        out_reduce_dimensions = new_reduce_dimensions;
      } else {
        TF_RET_CHECK(out_reduce_dimensions == new_reduce_dimensions);
      }
    }

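    // Assemble the normalized reduction; a variadic reduction produces a
    // tuple of the per-operand shapes computed above.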
    Shape new_reduce_shape = ShapeUtil::MakeMaybeTupleShape(new_reduce_shapes);

    std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
        new_reduce_shape, canonical_reduce_inputs, reduce->init_values(),
        out_reduce_dimensions, reduce->to_apply());
    VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
    const Shape &orig_reduce_shape = reduce->shape();

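    // If the output shape changed, bitcast the result back to the shape the
    // rest of the graph expects (element-wise for tuple outputs).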
    if (new_reduce_shape != orig_reduce_shape) {
      HloInstruction *wrapped_reduce =
          reduce->parent()->AddInstruction(std::move(new_reduce));

      if (!new_reduce_shape.IsTuple()) {
        new_reduce =
            HloInstruction::CreateBitcast(reduce->shape(), wrapped_reduce);
      } else {
        // Bitcast each element of the tuple.
        absl::InlinedVector<HloInstruction *, 2> out;
        for (int oidx = 0; oidx < reduce->input_count(); oidx++) {
          HloInstruction *gte = reduce->parent()->AddInstruction(
              HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
          out.push_back(
              reduce->parent()->AddInstruction(HloInstruction::CreateBitcast(
                  orig_reduce_shape.tuple_shapes(oidx), gte)));
        }
        new_reduce = HloInstruction::CreateTuple(out);
      }
    }

    VLOG(5) << "Generated output: " << new_reduce->ToString();
    return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
  }
};

StatusOr<bool> ReductionLayoutNormalizer::Run(
    HloModule *module,
    const absl::flat_hash_set<absl::string_view> &execution_threads) {
  TF_ASSIGN_OR_RETURN(bool changed,
                      EnforceMinorToMajorReduceOpVisitor().RunOnModule(
                          module, execution_threads));
  return changed;
}

}  // namespace gpu
}  // namespace xla