/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transpose_folding.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"

namespace xla {
namespace {

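// Returns the operand indices of `convolution` that point at transpose
// instructions, filtered through the `transposable_conv_operands` callback,
// which returns the subset it deems foldable. Returns an empty list if
// `convolution` is not actually a convolution.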
TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
    const HloInstruction& convolution,
    const TransposeFolding::TransposableConvOperandsFn&
        transposable_conv_operands) {
  if (HloOpcode::kConvolution != convolution.opcode()) {
    return {};
  }

  TransposeFolding::OperandIndices operand_set;
  for (int64_t i = 0; i < convolution.operand_count(); ++i) {
    auto& operand = *convolution.operand(i);
    if (operand.opcode() == HloOpcode::kTranspose) {
      operand_set.push_back(i);
    }
  }

  return transposable_conv_operands(convolution, operand_set);
}

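// Returns true if `instruction` is a transpose whose permutation actually
// reorders dimensions. For example, dimensions={0,2,1} is a non-identity
// transpose, while dimensions={0,1,2} is an identity permutation and is not
// worth folding.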
bool IsNonIdentityTranspose(const HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kTranspose) {
    for (int dim = 0; dim < instruction->dimensions().size(); ++dim) {
      if (dim != instruction->dimensions(dim)) {
        return true;
      }
    }
  }
  return false;
}

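// Rewrites every dimension number in `dims` through the permutation
// `transpose_dims`, so that a dimension referring to the transpose's output
// refers to the corresponding dimension of the transpose's input instead.
// For example, with transpose_dims={1,0}, a contracting dimension of 1
// becomes 0.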
void TransposeDims(tensorflow::protobuf::RepeatedField<int64_t>& dims,
                   absl::Span<const int64_t> transpose_dims) {
  for (auto& dim : dims) {
    dim = transpose_dims[dim];
  }
}

using InstructionOperandsPair =
    std::pair<HloInstruction*, TransposeFolding::OperandIndices>;

// Folds the operands of `dot` that are foldable transposes.
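//
// An illustrative (hypothetical) rewrite, folding the lhs transpose:
//
//   p = f32[2,3] parameter(0)
//   q = f32[2,3] parameter(1)
//   t = f32[3,2] transpose(p), dimensions={1,0}
//   d = f32[3,3] dot(t, q), lhs_contracting_dims={1},
//                           rhs_contracting_dims={0}
//
// becomes
//
//   d = f32[3,3] dot(p, q), lhs_contracting_dims={0},
//                           rhs_contracting_dims={0}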
Status FoldTransposeIntoDot(InstructionOperandsPair& pair) {
  HloInstruction* dot = pair.first;

  DotDimensionNumbers new_dot_dims = dot->dot_dimension_numbers();
  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);

  for (int64_t operand_index : pair.second) {
    if (operand_index == 0) {
      TransposeDims(*new_dot_dims.mutable_lhs_contracting_dimensions(),
                    lhs->dimensions());
      TransposeDims(*new_dot_dims.mutable_lhs_batch_dimensions(),
                    lhs->dimensions());
      lhs = lhs->mutable_operand(0);
    } else {
      CHECK_EQ(operand_index, 1);
      TransposeDims(*new_dot_dims.mutable_rhs_contracting_dimensions(),
                    rhs->dimensions());
      TransposeDims(*new_dot_dims.mutable_rhs_batch_dimensions(),
                    rhs->dimensions());
      rhs = rhs->mutable_operand(0);
    }
  }

  return dot->parent()->ReplaceWithNewInstruction(
      dot, HloInstruction::CreateDot(dot->shape(), lhs, rhs, new_dot_dims,
                                     dot->precision_config()));
}

// Folds the operands of `convolution` that are foldable transposes.
//
// Returns whether the module was changed.
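//
// For example (shapes hypothetical), if the lhs operand is a transpose with
// dimensions={0,3,1,2} and dnums.input_feature_dimension() == 1, the folded
// convolution reads its input directly from the transpose's operand with
// new_dnums.input_feature_dimension() == transpose_dimensions[1] == 3.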
bool FoldTransposeIntoConvolution(InstructionOperandsPair& pair) {
  auto& convolution = *pair.first;
  auto& operand_indices = pair.second;

  if (operand_indices.empty()) {
    return false;
  }

  const ConvolutionDimensionNumbers& dnums =
      convolution.convolution_dimension_numbers();
  ConvolutionDimensionNumbers new_dnums = dnums;

  HloInstruction* new_lhs;
  const int64_t kLhsIdx = 0;
  if (absl::c_linear_search(operand_indices, kLhsIdx)) {
    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the input/output dimension
    // numbers. We need to apply the transpose permutation to the original shape
    // to figure out what the new logical dimensions are.
    new_dnums.set_input_batch_dimension(
        transpose_dimensions[dnums.input_batch_dimension()]);
    new_dnums.set_input_feature_dimension(
        transpose_dimensions[dnums.input_feature_dimension()]);
    for (auto& input_spatial_dimension :
         *new_dnums.mutable_input_spatial_dimensions()) {
      input_spatial_dimension = transpose_dimensions[input_spatial_dimension];
    }
    new_lhs = &transpose_operand;
  } else {
    new_lhs = convolution.mutable_operand(kLhsIdx);
  }

  HloInstruction* new_rhs;
  const int64_t kRhsIdx = 1;
  if (absl::c_linear_search(operand_indices, kRhsIdx)) {
    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the kernel dimension numbers. We
    // need to apply the transpose permutation to the original shape to figure
    // out what the new logical dimensions are.
    new_dnums.set_kernel_input_feature_dimension(
        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
    new_dnums.set_kernel_output_feature_dimension(
        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
    for (auto& kernel_spatial_dimension :
         *new_dnums.mutable_kernel_spatial_dimensions()) {
      kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
    }
    new_rhs = &transpose_operand;
  } else {
    new_rhs = convolution.mutable_operand(kRhsIdx);
  }

  auto new_conv = HloInstruction::CreateConvolve(
      convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(),
      convolution.batch_group_count(), convolution.window(), new_dnums,
      convolution.precision_config());
  TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
      &convolution, std::move(new_conv)));

  return true;
}

}  // namespace

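// A minimal usage sketch (not part of this file), assuming the default
// constructor arguments declared in transpose_folding.h:
//
//   HloPassPipeline pipeline("transpose-folding");
//   pipeline.AddPass<TransposeFolding>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));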
TransposeFolding::TransposeFolding(
    CanFoldTransposeOperand dot_can_fold_transpose_operand,
    TransposableConvOperandsFn transposable_conv_operands)
    : dot_can_fold_transpose_operand_(
          std::move(dot_can_fold_transpose_operand)),
      transposable_conv_operands_(std::move(transposable_conv_operands)) {}

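// Collects foldable dots and convolutions in a first pass over all
// non-fusion computations, then rewrites them afterwards, since mutating
// the graph mid-traversal would invalidate the visitor.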
StatusOr<bool> TransposeFolding::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // Modifying the graph while traversing is dangerous, so we find all folding
  // opportunities before actually folding them.
  std::vector<InstructionOperandsPair> foldable_dots;
  std::vector<InstructionOperandsPair> foldable_convolutions;

  FunctionVisitor visit_fn([this, &foldable_dots, &foldable_convolutions](
                               HloInstruction* instruction) {
    if (instruction->opcode() == HloOpcode::kDot) {
      // Don't fold dots with a 1D operand.
      if ((instruction->operand(0)->shape().rank() < 2) ||
          (instruction->operand(1)->shape().rank() < 2)) {
        return OkStatus();
      }

      OperandIndices operand_indices;
      for (int64_t i = 0; i < 2; ++i) {
        if (!IsNonIdentityTranspose(instruction->operand(i))) {
          continue;
        }

        TF_ASSIGN_OR_RETURN(bool can_fold_operand,
                            dot_can_fold_transpose_operand_(*instruction, i));

        if (can_fold_operand) {
          operand_indices.push_back(i);
        }
      }

      if (!operand_indices.empty()) {
        foldable_dots.emplace_back(instruction, operand_indices);
      }
    }

    {
      OperandIndices operand_indices = CanFoldOperandsIntoConvolution(
          *instruction, transposable_conv_operands_);
      if (!operand_indices.empty()) {
        foldable_convolutions.emplace_back(instruction, operand_indices);
      }
    }
    return OkStatus();
  });

  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    TF_RETURN_IF_ERROR(comp->Accept(&visit_fn));
  }

  bool changed = false;
  for (InstructionOperandsPair& pair : foldable_dots) {
    TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair));
    changed = true;
  }
  for (InstructionOperandsPair& pair : foldable_convolutions) {
    changed |= FoldTransposeIntoConvolution(pair);
  }
  return changed;
}

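// Checks that the transpose feeding operand `operand_idx` leaves every batch
// dimension in place, so at most the two remaining row/column dimensions are
// swapped. For example, for a rank-3 operand with batch dimension {0},
// dimensions={0,2,1} qualifies while dimensions={1,0,2} does not.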
/*static*/ StatusOr<bool> TransposeFolding::IsRowColumnTransposeDotOperand(
    const HloInstruction& dot, int64_t operand_idx) {
  TF_RET_CHECK(dot.opcode() == HloOpcode::kDot);
  TF_RET_CHECK(dot.operand_count() > operand_idx);

  const HloInstruction& transpose = *dot.operand(operand_idx);
  TF_RET_CHECK(transpose.opcode() == HloOpcode::kTranspose);

  const DotDimensionNumbers& dot_dims = dot.dot_dimension_numbers();

  auto batch_dims = (operand_idx == 0) ? dot_dims.lhs_batch_dimensions()
                                       : dot_dims.rhs_batch_dimensions();

  auto contracting_dims = (operand_idx == 0)
                              ? dot_dims.lhs_contracting_dimensions()
                              : dot_dims.rhs_contracting_dimensions();

  return (batch_dims.size() == transpose.shape().rank() - 2) &&
         (contracting_dims.size() == 1) &&
         absl::c_all_of(batch_dims, [&](int64_t dim) {
           return transpose.dimensions(dim) == dim;
         });
}

}  // namespace xla