xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/transpose_folding.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
18 
19 #include <functional>
20 
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
23 
24 namespace xla {
25 
26 // HLO pass that folds transpose operators into Dot operators, where the Dot
27 // operator is implemented by a GEMM kernel that can transpose its inputs.
28 class TransposeFolding : public HloModulePass {
29  public:
30   using OperandIndices = std::vector<int64_t>;
31 
32   // Returns the set of foldable operands for a given HLO and some candidate
33   // operands.
34   using TransposableConvOperandsFn = std::function<OperandIndices(
35       const HloInstruction&, const OperandIndices&)>;
36 
37   using CanFoldTransposeOperand = std::function<StatusOr<bool>(
38       const HloInstruction&, int64_t /*operand_idx*/)>;
39 
40   // Helper function to explicitly not fold transposes.
NeverFoldTranspose(const HloInstruction &,const OperandIndices &)41   static OperandIndices NeverFoldTranspose(const HloInstruction&,
42                                            const OperandIndices&) {
43     return {};
44   }
45 
46   // Helper function to always fold transposes.
AlwaysFoldTranspose(const HloInstruction &,const OperandIndices & ids)47   static OperandIndices AlwaysFoldTranspose(const HloInstruction&,
48                                             const OperandIndices& ids) {
49     return ids;
50   }
51 
52   // `dot_can_fold_transpose_operand` returns whether the dot operation can fold
53   // in the given transpose operand.
54   //
55   // transposable_conv_operands returns the set of operands it wants to fold if
56   // the instruction argument is implemented as a convolution that supports
57   // transposing its arguments.
58   explicit TransposeFolding(
59       CanFoldTransposeOperand dot_can_fold_transpose_operand =
60           IsRowColumnTransposeDotOperand,
61       TransposableConvOperandsFn transposable_conv_operands =
62           AlwaysFoldTranspose);
name()63   absl::string_view name() const override { return "transpose-folding"; }
64 
65   using HloPassInterface::Run;
66   StatusOr<bool> Run(
67       HloModule* module,
68       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
69 
70   static StatusOr<bool> IsRowColumnTransposeDotOperand(
71       const HloInstruction& dot, int64_t operand_idx);
72 
73  private:
74   CanFoldTransposeOperand dot_can_fold_transpose_operand_;
75   TransposableConvOperandsFn transposable_conv_operands_;
76 };
77 
78 }  // namespace xla
79 
80 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
81