/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {
namespace gpu {

// Rewrites custom-calls targeting cudnnConvolutionForward to
// cudnnConvolutionBiasActivationForward by fusing operations following forward
// convolution. This transform must run after cudnn_conv_rewriter.
//
// Semantics of underlying cudnn ops:
//
// https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
// https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
//
// ## Floating-point convs
//
// A "complete" fused floating-point conv has the form
//
//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)),
//
// which we fuse to
//
//   cudnnConvolutionBiasActivationForward(x, w, bias, side_input).
//
// You can leave out side_input, bias, alpha1, alpha2, and max(x, 0) and still
// get a fused convolution. alpha1/2 must be broadcasts of scalar constants.
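//
// For illustration, the "complete" f32 pattern looks roughly like this in HLO
// text (a simplified sketch: real conv custom-calls return a (result, scratch)
// tuple, which is omitted here, and all names and shapes are hypothetical):
//
//   conv   = custom-call(x, w), custom_call_target="__cudnn$convForward"
//   scaled = multiply(broadcast(alpha1), conv)
//   side   = multiply(broadcast(alpha2), side_input)
//   summed = add(add(scaled, side), broadcast(bias))
//   result = maximum(broadcast(zero), summed)
//
// This pass rewrites the whole expression into a single custom-call, with
// alpha1/alpha2 recorded in the custom-call's backend config:
//
//   result = custom-call(x, w, bias, side_input),
//            custom_call_target="__cudnn$convBiasActivationForward"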
//
// f16 convs accumulate in f32. We represent this in HLO as an f32 convolution
// whose inputs can be converted to f16 without loss of precision and whose
// output is immediately converted to f16. A fused f16 conv must follow one of
// the following idioms.
//
//   1. convert_f16(conv_f32(x_f32, w_f32)) +
//      side_input_f16 + broadcast(bias_f16)
//
//   2. convert_f16(conv_f32(x_f32, w_f32) +
//      side_input_f32 + broadcast(bias_f32))
//
// (These are not strictly mathematically equivalent, but cudnn doesn't tell us
// which one it does, and we deem them "close enough".)
//
// The foo_f32 HLOs must all be losslessly-convertible to f16. Some valid
// examples:
//
//   - foo_f32 = convert_f32(foo_f16)
//   - foo_f32 = an f32 constant whose values all fit within f16
//   - foo_f32 = broadcast/transpose/reshape(one of the above)
//
// If you have a relu, it can appear before or after the convert_f16.
//
// Note that here `bias` must be losslessly-convertible to f16; this is
// different than for s8 convolutions, where bias is f32.
//
// ## Integer convs
//
// In pure HLO, a "complete" integer conv is spelled as one of the following
// `result`s.
//
//   base = alpha1_f32 * convert_f32(conv_s32(input_s32, filter_s32)) +
//          alpha2_f32 * side_input +
//          bias_f32
//
//   result_f32        = max(base, 0)
//   result_s8_option1 = max(convert_s8(clamp(-128, base, 127)), 0)
//   result_s8_option2 = convert_s8(clamp(-128, max(base, 0), 127))
//
// The foo_s32 HLOs must be losslessly-convertible to s8. In the `result_s8`
// case, side_input should be an f32 HLO that's losslessly-convertible to s8;
// otherwise, it should be losslessly-convertible to f32.
//
// In the `result_s8` case where there's no bias, side_input, or alpha1, you
// can skip the convert_f32 on the conv.
//
// If you have an integer convolution that doesn't fit one of these idioms,
// this pass returns an error -- cudnn will not be able to run it.
class CudnnFusedConvRewriter : public HloModulePass {
 public:
  absl::string_view name() const override {
    return "cudnn-fused-convolution-rewriter";
  }

  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads)
      override;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_