1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
17
18 #include <algorithm>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
23 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/stream_executor/lib/statusor.h"
33
34 namespace xla {
35 namespace gpu {
36
// Rewrites each reduce instruction so that its inputs are consumed in
// physical (major-to-minor) order: every operand is bitcast to a shape whose
// logical dimension order equals its physical layout, the reduced-dimension
// indices are remapped into that coordinate system, and the result is bitcast
// back to the original output shape so callers are unaffected.
class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor {
  // Canonicalizes one (possibly variadic) reduce. Returns early with no
  // change when a unary reduce's operand already has the canonical shape.
  Status HandleReduce(HloInstruction *hlo) override {
    auto reduce = Cast<HloReduceInstruction>(hlo);
    VLOG(5) << "Input: " << reduce->ToString();

    // Index of the operand currently being processed (incremented at the top
    // of the loop below, so it starts at 0 inside the loop).
    int operand_idx = -1;

    // Per-operand canonical inputs (bitcast or pass-through) and the matching
    // per-operand output shapes for the rewritten reduce.
    absl::InlinedVector<HloInstruction *, 2> canonical_reduce_inputs;
    absl::InlinedVector<Shape, 2> new_reduce_shapes;

    // Reduced-dimension indices in the new (physical-order) coordinate
    // system; checked to be identical across all operands of a variadic
    // reduce.
    DimensionVector out_reduce_dimensions;
    const Shape &first_instruction_shape = reduce->inputs()[0]->shape();

    for (HloInstruction *operand : reduce->inputs()) {
      operand_idx++;

      // All inputs of a variadic reduce must share one layout: insert a copy
      // to relayout any operand that disagrees with the first input.
      if (operand_idx != 0 &&
          operand->shape().layout() != first_instruction_shape.layout()) {
        HloInstruction *copy =
            reduce->parent()->AddInstruction(HloInstruction::CreateUnary(
                operand->shape(), HloOpcode::kCopy, operand));

        LayoutUtil::ClearLayout(copy->mutable_shape());
        TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
            first_instruction_shape, copy->mutable_shape()));

        copy->set_metadata(operand->metadata());
        operand = copy;
        VLOG(3) << "Copying to establish consistent inputs layout: "
                << copy->ToString();
      }

      const Shape &operand_shape = operand->shape();
      const Layout &operand_layout = operand_shape.layout();

      // Output shape corresponding to this operand (tuple element for
      // variadic reductions).
      const Shape &reduce_shape =
          reduce->shape().IsTuple() ? reduce->shape().tuple_shapes(operand_idx)
                                    : reduce->shape();

      DimensionVector new_reduce_dimensions;
      DimensionVector new_operand_shape_data;
      DimensionVector new_reduce_shape_data;

      // The layout order of the reduction output can be different to the
      // ordering of kept dimensions in the input operand, thus we need to
      // calculate the new layout.
      DimensionVector new_reduce_shape_layout(reduce_shape.rank());
      std::vector<int64_t> reduce_shape_logical_to_physical =
          LayoutUtil::MakeLogicalToPhysical(reduce_shape.layout());

      // Maps a *kept* logical dimension of the operand to the corresponding
      // logical dimension of the reduce output by subtracting the number of
      // reduced dimensions that precede it. Must never be called on a
      // reduced dimension (enforced by the CHECK).
      auto to_reduce_logical_dim = [&](int64_t op_logical_dim) {
        return op_logical_dim -
               absl::c_count_if(reduce->dimensions(), [&](int64_t dim) {
                 CHECK(dim != op_logical_dim);
                 return dim < op_logical_dim;
               });
      };

      for (int i = 0; i < operand_shape.rank(); i++) {
        // Process the dimensions in the major-to-minor order in order to
        // enforce the default layout.
        int64_t major_to_minor_dim_idx = operand_shape.rank() - i - 1;
        int64_t logical_dim =
            operand_layout.minor_to_major(major_to_minor_dim_idx);
        int64_t dim_size = operand_shape.dimensions(logical_dim);
        VLOG(5) << "Processing logical dimension " << logical_dim << " of size "
                << dim_size;
        new_operand_shape_data.push_back(dim_size);

        if (absl::c_linear_search(reduce->dimensions(), logical_dim)) {
          // Since the new operand lists dimensions in major-to-minor order,
          // position `i` of the new operand is the reduced dimension.
          new_reduce_dimensions.push_back(i);
        } else {
          new_reduce_shape_data.push_back(dim_size);
          int64_t logical_reduce_dim = to_reduce_logical_dim(logical_dim);
          int64_t physical_reduce_dim =
              reduce_shape_logical_to_physical[logical_reduce_dim];
          VLOG(5) << "logical_reduce_dim = " << logical_reduce_dim << ", "
                  << "physical_reduce_dim = " << physical_reduce_dim;
          // Record where this kept dimension sits in the new output layout,
          // preserving the original output's physical ordering (index is
          // reversed because layouts are stored minor-to-major).
          new_reduce_shape_layout[reduce_shape.rank() - physical_reduce_dim -
                                  1] = new_reduce_shape_data.size() - 1;
        }
      }

      Shape new_operand_shape = ShapeUtil::MakeShape(
          operand_shape.element_type(), new_operand_shape_data);
      Shape new_reduce_shape = ShapeUtil::MakeShapeWithLayout(
          reduce_shape.element_type(), new_reduce_shape_data,
          new_reduce_shape_layout);

      // Fast path: a unary reduce whose operand is already canonical needs
      // no rewrite at all.
      if (new_operand_shape == operand_shape && reduce->inputs().size() == 1) {
        return OkStatus();
      }

      // Bitcast the operand into the canonical shape unless it already
      // matches.
      HloInstruction *canonical_reduce_input =
          new_operand_shape != operand_shape
              ? reduce->parent()->AddInstruction(
                    HloInstruction::CreateBitcast(new_operand_shape, operand))
              : operand;
      canonical_reduce_input->set_metadata(operand->metadata());
      VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString();

      new_reduce_shapes.push_back(new_reduce_shape);
      canonical_reduce_inputs.push_back(canonical_reduce_input);

      if (out_reduce_dimensions.empty()) {
        out_reduce_dimensions = new_reduce_dimensions;
      } else {
        // All operands share a layout (ensured above), so every operand must
        // yield the same remapped reduction dimensions.
        TF_RET_CHECK(out_reduce_dimensions == new_reduce_dimensions);
      }
    }

    Shape new_reduce_shape = ShapeUtil::MakeMaybeTupleShape(new_reduce_shapes);

    std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
        new_reduce_shape, canonical_reduce_inputs, reduce->init_values(),
        out_reduce_dimensions, reduce->to_apply());
    VLOG(5) << "Generated new reduction: " << new_reduce->ToString();
    const Shape &orig_reduce_shape = reduce->shape();

    // If the canonical output shape differs from the original one, bitcast
    // the result back so downstream users still see the original shape.
    if (new_reduce_shape != orig_reduce_shape) {
      HloInstruction *wrapped_reduce =
          reduce->parent()->AddInstruction(std::move(new_reduce));

      if (!new_reduce_shape.IsTuple()) {
        new_reduce =
            HloInstruction::CreateBitcast(reduce->shape(), wrapped_reduce);
      } else {
        // Bitcast each element of the tuple.
        absl::InlinedVector<HloInstruction *, 2> out;
        for (int oidx = 0; oidx < reduce->input_count(); oidx++) {
          HloInstruction *gte = reduce->parent()->AddInstruction(
              HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
          out.push_back(
              reduce->parent()->AddInstruction(HloInstruction::CreateBitcast(
                  orig_reduce_shape.tuple_shapes(oidx), gte)));
        }
        new_reduce = HloInstruction::CreateTuple(out);
      }
    }

    VLOG(5) << "Generated output: " << new_reduce->ToString();
    return ReplaceWithNewInstruction(reduce, std::move(new_reduce));
  }
};
181
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)182 StatusOr<bool> ReductionLayoutNormalizer::Run(
183 HloModule *module,
184 const absl::flat_hash_set<absl::string_view> &execution_threads) {
185 TF_ASSIGN_OR_RETURN(bool changed,
186 EnforceMinorToMajorReduceOpVisitor().RunOnModule(
187 module, execution_threads));
188 return changed;
189 }
190
191 } // namespace gpu
192 } // namespace xla
193