xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
18 
19 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
20 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
21 
22 namespace xla {
23 namespace gpu {
24 
25 // Rewrites custom-calls targeting cudnnConvolutionForward to
26 // cudnnConvolutionBiasActivationForward by fusing operations following forward
27 // convolution.  This transform must run after cudnn_conv_rewriter.
28 //
29 // Semantics of underlying cudnn ops:
30 //
31 // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
32 // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
33 //
34 // ## Floating-point convs
35 //
36 // A "complete" fused floating-point conv has the form
37 //
38 //   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)),
39 //
40 // which we fuse to
41 //
42 //   cudnnConvolutionBiasActivationForward(x, w, bias, side_input).
43 //
44 // You can leave out side_input, bias, alpha1, alpha2, and max(x, 0) and still
45 // get a fused convolution.  alpha1/2 must be broadcasts of scalar constants.
46 //
47 // f16 convs accumulate in f32.  We represent this in HLO as an f32 convolution
48 // whose inputs can be converted to f16 without loss of precision and whose
49 // output is immediately converted to f16.  A fused f16 conv must follow one of
50 // the following idioms.
51 //
52 //   1. convert_f16(conv_f32(x_f32, w_f32)) +
53 //      side_input_f16 + broadcast(bias_f16)
54 //
55 //   2. convert_f16(conv_f32(x_f32, w_f32) +
56 //                  side_input_f32 + broadcast(bias_f32))
57 //
58 // (These are not strictly mathematically equivalent, but cudnn doesn't tell us
59 // which one it does, and we deem them "close enough".)
60 //
61 // The foo_f32 HLOs must all be losslessly-convertible to f16.  Some valid
62 // examples:
63 //
64 //   - foo_f32 = convert_f32(foo_f16)
65 //   - foo_f32 = an f32 constant whose values all fit within f16
66 //   - foo_f32 = broadcast/transpose/reshape(one of the above)
67 //
68 // If you have a relu, it can appear before or after the convert_f16.
69 //
70 // Note that here `bias` must be losslessly-convertible to f16; this is
71 // different from s8 convolutions, where bias is f32.
72 //
73 // ## Integer convs
74 //
75 // In pure HLO, a "complete" integer conv is spelled as one of the following
76 // `result`s.
77 //
78 //   base = alpha1_f32 * convert_f32(conv_s32(input_s32, filter_s32)) +
79 //          alpha2_f32 * side_input +
80 //          bias_f32
81 //
82 //   result_f32        = max(base, 0)
83 //   result_s8_option1 = max(convert_s8(clamp(-128, base, 127)), 0)
84 //   result_s8_option2 = convert_s8(clamp(-128, max(base, 0), 127))
85 //
86 // The foo_s32 HLOs must be losslessly-convertible to s8.  In the `result_s8`
87 // case, side_input should be an f32 HLO that's losslessly-convertible to s8;
88 // otherwise, it should be losslessly-convertible to f32.
89 //
90 // In the `result_s8` case where there's no bias, side-input, or alpha1, you can
91 // skip the convert_f32 on conv.
92 //
93 // If you have an integer convolution that doesn't fit one of these idioms, this
94 // pass returns an error -- cudnn will not be able to run it.
95 class CudnnFusedConvRewriter : public HloModulePass {
96  public:
name()97   absl::string_view name() const override {
98     return "cudnn-fused-convolution-rewriter";
99   }
100 
101   using HloPassInterface::Run;
102   StatusOr<bool> Run(
103       HloModule* module,
104       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
105 };
106 
107 }  // namespace gpu
108 }  // namespace xla
109 
110 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
111