/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_ALGORITHM_SELECTOR_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_ALGORITHM_SELECTOR_H_
#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include <array>
#include <cstdint>
#include <memory>
#include <optional>
#include <set>

#include "absl/types/optional.h"
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

// Implements core algorithm selection logic in a testable manner. The policy
// implemented depends on the given TRT version. We have this class because TRT
// interfaces make it difficult to directly test an IAlgorithmSelector
// implementation.
33 class AlgorithmSelectorImpl { 34 public: 35 using TRTVersion = std::array<int, 4>; 36 using ImplementationID = int64_t; 37 using TacticID = int64_t; 38 CompileTimeTRTVersion()39 static constexpr TRTVersion CompileTimeTRTVersion() { 40 return TRTVersion{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH, 41 NV_TENSORRT_BUILD}; 42 } 43 44 explicit AlgorithmSelectorImpl( 45 const TRTVersion& version = CompileTimeTRTVersion()) version_(version)46 : version_(version) {} 47 48 bool IsShuffleLayer(ImplementationID id) const; 49 50 bool IsBannedTactic(TacticID id) const; 51 52 // Returns true if the algorithm implementing the IShuffleLayer is acceptable. 53 bool AllowShuffleAlgorithm(TacticID tactic, nvinfer1::DataType input_dtype, 54 nvinfer1::TensorFormat input_format) const; 55 56 bool IsTrtVersionGE(const TRTVersion& version) const; 57 58 // Returns true if we know at compile time that the algorithm selector 59 // should be required. This is a conservative estimate. 60 bool IsAlgorithmSelectorRequired() const; 61 62 static std::set<TacticID> GetBannedTRT72TuringTactics(); 63 64 private: 65 TRTVersion version_; 66 }; 67 68 // Impelements the TRT IAlgorithmSelector interface. The method 69 // "selectAlgorithms" selects allowable algorithms for each layer, and 70 // "reportAlgorithms" summarizes the algorithms selected by TensorRT. 71 class TftrtAlgorithmSelector : public nvinfer1::IAlgorithmSelector { 72 private: 73 using TacticID = AlgorithmSelectorImpl::TacticID; 74 75 // An index we should choose for all algorithms. Used for debugging. 76 std::optional<int32_t> fixed_algorithm_idx_; 77 78 AlgorithmSelectorImpl selector_; 79 80 public: 81 TftrtAlgorithmSelector(); 82 83 // If the environment variable TF_TRT_FIXED_ALGORITHM_ID is empty, this 84 // function returns nullopt. Otherwise, it returns the specified number. 85 static std::optional<int64_t> GetFixedAlgorithmID(); 86 87 // Returns true if the algorithm associated with context is acceptable. 
88 bool AlgorithmPolicy(const nvinfer1::IAlgorithmContext& context, 89 const nvinfer1::IAlgorithm& alg) const; 90 91 // This function fills the array "selection" with the indices of selected 92 // algorithm candidates from "algoChoices", each of which is an implementation 93 // for the kernel described by the given IAlgorithmContext. It should return a 94 // number in [0, nbChoices] indicating the number of selected indices. If 0 is 95 // returned, TensorRT will use its default selection mechanism. 96 int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext, 97 const nvinfer1::IAlgorithm* const* algoChoices, 98 int32_t nbChoices, 99 int32_t* selection) noexcept override; 100 101 // Called by TensorRT to report choices it made. 102 void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts, 103 const nvinfer1::IAlgorithm* const* algoChoices, 104 int32_t nbAlgorithms) noexcept override; 105 IsRequired()106 bool IsRequired() const { 107 return selector_.IsAlgorithmSelectorRequired() || 108 fixed_algorithm_idx_ != std::nullopt; 109 } 110 }; 111 112 // Returns an initialized AlgorithmSelector if an algorithm selector is required 113 // for the current TRT version. Otherwise, returns nullptr. 114 std::unique_ptr<TftrtAlgorithmSelector> MaybeCreateAlgorithmSelector(); 115 116 } // namespace convert 117 } // namespace tensorrt 118 } // namespace tensorflow 119 120 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT 121 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_ALGORITHM_SELECTOR_H_ 122