/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transpose_folding.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"

namespace xla {
namespace {

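// Returns the indices of the operands of `convolution` that are kTranspose
// instructions and that the `transposable_conv_operands` callback approves
// for folding. Returns an empty vector if `convolution` is not a convolution.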
TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
    const HloInstruction& convolution,
    const TransposeFolding::TransposableConvOperandsFn&
        transposable_conv_operands) {
  if (HloOpcode::kConvolution != convolution.opcode()) {
    return {};
  }

  TransposeFolding::OperandIndices operand_set;
  for (int64_t i = 0; i < convolution.operand_count(); ++i) {
    auto& operand = *convolution.operand(i);
    if (operand.opcode() == HloOpcode::kTranspose) {
      operand_set.push_back(i);
    }
  }

  return transposable_conv_operands(convolution, operand_set);
}

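// Returns true if `instruction` is a transpose with a non-identity
// permutation, i.e. a transpose that folding would actually eliminate.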
bool IsNonIdentityTranspose(const HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kTranspose) {
    for (int dim = 0; dim < instruction->dimensions().size(); ++dim) {
      if (dim != instruction->dimensions(dim)) {
        return true;
      }
    }
  }
  return false;
}

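// Remaps each dimension number in `dims` through the permutation of a
// transpose: dimension `d` of the transpose's output is dimension
// `transpose_dims[d]` of the transpose's operand.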
void TransposeDims(tensorflow::protobuf::RepeatedField<int64_t>& dims,
                   absl::Span<const int64_t> transpose_dims) {
  for (auto& dim : dims) {
    dim = transpose_dims[dim];
  }
}

using InstructionOperandsPair =
    std::pair<HloInstruction*, TransposeFolding::OperandIndices>;

// Folds the operands of `dot` that are foldable transposes.
Status FoldTransposeIntoDot(InstructionOperandsPair& pair) {
  HloInstruction* dot = pair.first;

  DotDimensionNumbers new_dot_dims = dot->dot_dimension_numbers();
  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);

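  // For each foldable operand, remap the dot dimension numbers through the
  // transpose's permutation and reattach the dot to the transpose's operand.
  // E.g. dot(transpose(A, {1,0}), B) with lhs_contracting_dimensions = {1}
  // becomes dot(A, B) with lhs_contracting_dimensions = {0}.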
  for (int64_t operand_index : pair.second) {
    if (operand_index == 0) {
      TransposeDims(*new_dot_dims.mutable_lhs_contracting_dimensions(),
                    lhs->dimensions());
      TransposeDims(*new_dot_dims.mutable_lhs_batch_dimensions(),
                    lhs->dimensions());
      lhs = lhs->mutable_operand(0);
    } else {
      CHECK_EQ(operand_index, 1);
      TransposeDims(*new_dot_dims.mutable_rhs_contracting_dimensions(),
                    rhs->dimensions());
      TransposeDims(*new_dot_dims.mutable_rhs_batch_dimensions(),
                    rhs->dimensions());
      rhs = rhs->mutable_operand(0);
    }
  }

  return dot->parent()->ReplaceWithNewInstruction(
      dot, HloInstruction::CreateDot(dot->shape(), lhs, rhs, new_dot_dims,
                                     dot->precision_config()));
}

// Folds the transposed operands of `convolution` (`pair.first`) named by the
// indices in `pair.second`.
//
// Returns whether the module is changed.
bool FoldTransposeIntoConvolution(InstructionOperandsPair& pair) {
  auto& convolution = *pair.first;
  auto& operand_indices = pair.second;

  if (operand_indices.empty()) {
    return false;
  }

  const ConvolutionDimensionNumbers& dnums =
      convolution.convolution_dimension_numbers();
  ConvolutionDimensionNumbers new_dnums = dnums;

  HloInstruction* new_lhs;
  const int64_t kLhsIdx = 0;
  if (absl::c_linear_search(operand_indices, kLhsIdx)) {
    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the input/output dimension
    // numbers. We need to apply the transpose permutation to the original
    // shape to figure out what the new logical dimensions are.
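    // E.g. if the lhs transpose has dimensions {0, 3, 1, 2} and
    // input_feature_dimension() is 1, the folded convolution reads its input
    // features from dimension transpose_dimensions[1] == 3 of the transpose's
    // operand.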
    new_dnums.set_input_batch_dimension(
        transpose_dimensions[dnums.input_batch_dimension()]);
    new_dnums.set_input_feature_dimension(
        transpose_dimensions[dnums.input_feature_dimension()]);
    for (auto& input_spatial_dimension :
         *new_dnums.mutable_input_spatial_dimensions()) {
      input_spatial_dimension = transpose_dimensions[input_spatial_dimension];
    }
    new_lhs = &transpose_operand;
  } else {
    new_lhs = convolution.mutable_operand(kLhsIdx);
  }

  HloInstruction* new_rhs;
  const int64_t kRhsIdx = 1;
  if (absl::c_linear_search(operand_indices, kRhsIdx)) {
    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the kernel dimension numbers. We
    // need to apply the transpose permutation to the original shape to figure
    // out what the new logical dimensions are.
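    // E.g. if the rhs transpose has dimensions {3, 2, 0, 1} and
    // kernel_output_feature_dimension() is 0, the folded convolution reads
    // its output features from dimension transpose_dimensions[0] == 3 of the
    // transpose's operand.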
    new_dnums.set_kernel_input_feature_dimension(
        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
    new_dnums.set_kernel_output_feature_dimension(
        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
    for (auto& kernel_spatial_dimension :
         *new_dnums.mutable_kernel_spatial_dimensions()) {
      kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
    }
    new_rhs = &transpose_operand;
  } else {
    new_rhs = convolution.mutable_operand(kRhsIdx);
  }

  auto new_conv = HloInstruction::CreateConvolve(
      convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(),
      convolution.batch_group_count(), convolution.window(), new_dnums,
      convolution.precision_config());
  TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
      &convolution, std::move(new_conv)));

  return true;
}

}  // namespace

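// Example of wiring this pass into a pipeline (a sketch; assumes an
// HloPassPipeline named `pipeline` and the AlwaysFoldTranspose helper
// declared alongside this class in transpose_folding.h):
//
//   pipeline.AddPass<TransposeFolding>(
//       TransposeFolding::IsRowColumnTransposeDotOperand,
//       TransposeFolding::AlwaysFoldTranspose);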
TransposeFolding::TransposeFolding(
    CanFoldTransposeOperand dot_can_fold_transpose_operand,
    TransposableConvOperandsFn transposable_conv_operands)
    : dot_can_fold_transpose_operand_(
          std::move(dot_can_fold_transpose_operand)),
      transposable_conv_operands_(std::move(transposable_conv_operands)) {}

StatusOr<bool> TransposeFolding::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // Modifying the graph while traversing is dangerous, so we find all folding
  // opportunities before actually folding them.
  std::vector<InstructionOperandsPair> foldable_dots;
  std::vector<InstructionOperandsPair> foldable_convolutions;

  FunctionVisitor visit_fn([this, &foldable_dots, &foldable_convolutions](
                               HloInstruction* instruction) {
    if (instruction->opcode() == HloOpcode::kDot) {
      // Don't fold dots with a 1D operand.
      if ((instruction->operand(0)->shape().rank() < 2) ||
          (instruction->operand(1)->shape().rank() < 2)) {
        return OkStatus();
      }

      OperandIndices operand_indices;
      for (int64_t i = 0; i < 2; ++i) {
        if (!IsNonIdentityTranspose(instruction->operand(i))) {
          continue;
        }

        TF_ASSIGN_OR_RETURN(bool can_fold_operand,
                            dot_can_fold_transpose_operand_(*instruction, i));

        if (can_fold_operand) {
          operand_indices.push_back(i);
        }
      }

      if (!operand_indices.empty()) {
        foldable_dots.emplace_back(instruction, operand_indices);
      }
    }

    {
      OperandIndices operand_indices = CanFoldOperandsIntoConvolution(
          *instruction, transposable_conv_operands_);
      if (!operand_indices.empty()) {
        foldable_convolutions.emplace_back(instruction, operand_indices);
      }
    }
    return OkStatus();
  });

  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    TF_RETURN_IF_ERROR(comp->Accept(&visit_fn));
  }

  bool changed = false;
  for (InstructionOperandsPair& pair : foldable_dots) {
    TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair));
    changed = true;
  }
  for (InstructionOperandsPair& pair : foldable_convolutions) {
    changed |= FoldTransposeIntoConvolution(pair);
  }
  return changed;
}

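// Checks that the transpose feeding `dot`'s operand `operand_idx` only
// permutes the two minor (row/column) dimensions: every batch dimension must
// map to itself, all but two dimensions must be batch dimensions, and there
// must be exactly one contracting dimension.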
/*static*/ StatusOr<bool> TransposeFolding::IsRowColumnTransposeDotOperand(
    const HloInstruction& dot, int64_t operand_idx) {
  TF_RET_CHECK(dot.opcode() == HloOpcode::kDot);
  TF_RET_CHECK(dot.operand_count() > operand_idx);

  const HloInstruction& transpose = *dot.operand(operand_idx);
  TF_RET_CHECK(transpose.opcode() == HloOpcode::kTranspose);

  const DotDimensionNumbers& dot_dims = dot.dot_dimension_numbers();

  auto batch_dims = (operand_idx == 0) ? dot_dims.lhs_batch_dimensions()
                                       : dot_dims.rhs_batch_dimensions();

  auto contracting_dims = (operand_idx == 0)
                              ? dot_dims.lhs_contracting_dimensions()
                              : dot_dims.rhs_contracting_dimensions();

  return (batch_dims.size() == transpose.shape().rank() - 2) &&
         (contracting_dims.size() == 1) &&
         absl::c_all_of(batch_dims, [&](int64_t dim) {
           return transpose.dimensions(dim) == dim;
         });
}

}  // namespace xla