xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/gemm_rewriter.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_REWRITER_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_REWRITER_H_
17 
18 #include <optional>
19 
20 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
23 
24 namespace xla {
25 namespace gpu {
26 
27 // cuBLAS GEMM in the most general form can run the following operation:
28 //
29 // (kAdd
30 //    (kMultiply (kDot A B) alpha)
31 //    (kMultiply C beta))
32 //
33 // where A, B, C are matrixes and `alpha` and `beta` are host constants.
34 // The additional requirement is that C has no other users (otherwise,
35 // it does not make sense to fuse it inside the custom call).
36 //
37 // Both multiplication and addition can be avoided (equivalent to setting
38 // `alpha` to one and `beta` to zero).
39 //
40 // This pass pattern-matches the most general form of this instruction
41 // (we assume transposes are already folded), and rewrites it into a custom call
42 // where (A, B, C) are three operands respectively, and `alpha` and `beta` are
43 // stored in the backend config.
44 class GemmRewriter : public HloModulePass {
45  public:
name()46   absl::string_view name() const override { return "cublas-gemm-rewriter"; }
47 
48   using HloPassInterface::Run;
49   StatusOr<bool> Run(
50       HloModule* module,
51       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
52 };
53 
54 }  // namespace gpu
55 }  // namespace xla
56 
57 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_REWRITER_H_
58