1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ 18 19 #include <optional> 20 21 #include "absl/time/time.h" 22 #include "tensorflow/compiler/xla/service/compiler.h" 23 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" 24 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 27 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 28 #include "tensorflow/core/protobuf/autotuning.pb.h" 29 #include "tensorflow/stream_executor/device_memory_allocator.h" 30 31 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) 32 #include "tensorflow/stream_executor/gpu/redzone_allocator.h" 33 #endif 34 35 namespace xla { 36 namespace gpu { 37 38 // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for 39 // each and adding explicit scratch space to the CustomCalls. 40 class GpuConvAlgorithmPicker : public HloModulePass { 41 public: 42 // If the `allocator` parameter is not null, we will use it to allocate temp 43 // memory while timing the various convolution algorithms. If it's null, 44 // we'll use the default allocator on the StreamExecutor. GpuConvAlgorithmPicker(se::StreamExecutor * stream_exec,se::DeviceMemoryAllocator * allocator)45 GpuConvAlgorithmPicker(se::StreamExecutor* stream_exec, 46 se::DeviceMemoryAllocator* allocator) 47 : stream_exec_(stream_exec), allocator_(allocator) {} 48 name()49 absl::string_view name() const override { 50 return "gpu-conv-algorithm-picker"; 51 } 52 53 using HloPassInterface::Run; 54 StatusOr<bool> Run( 55 HloModule* module, 56 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 57 58 private: 59 StatusOr<bool> RunOnComputation(HloComputation* computation); 60 StatusOr<bool> RunOnInstruction(HloInstruction* instr); 61 StatusOr<tensorflow::AutotuneResult> PickBestAlgorithm( 62 const HloCustomCallInstruction* instr); 63 64 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) 65 // Simple bundle of an algorithm and its output, for comparing results across 66 // autotuned algorithms. 67 struct ReferenceResult { 68 stream_executor::dnn::AlgorithmDesc algorithm; 69 stream_executor::DeviceMemoryBase buffer; 70 }; 71 72 StatusOr<tensorflow::AutotuneResult> AutotuneOneConvRunner( 73 const GpuConvConfig& config, const HloCustomCallInstruction* instr, 74 se::DeviceMemoryAllocator* allocator, 75 se::RedzoneAllocator* input_output_allocator, se::Stream* stream, 76 MaybeFusedConvRunner* const runner, 77 absl::Span<const stream_executor::DeviceMemoryBase> operand_buffers, 78 stream_executor::DeviceMemoryBase result_buffer, 79 std::optional<ReferenceResult>* reference_result, 80 absl::Span<const stream_executor::dnn::AlgorithmDesc> disabled_algos); 81 StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCacheCuda( 82 const HloCustomCallInstruction* instr, 83 se::DeviceMemoryAllocator* allocator, se::Stream* stream); 84 #endif 85 86 StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCacheRocm( 87 const HloCustomCallInstruction* instr, 88 se::DeviceMemoryAllocator* allocator, se::Stream* stream); 89 90 se::StreamExecutor* stream_exec_; // never null 91 se::DeviceMemoryAllocator* allocator_; // may be null 92 }; 93 94 } // namespace gpu 95 } // namespace xla 96 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ 97