xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/algorithm_selector.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_ALGORITHM_SELECTOR_H_
16 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_ALGORITHM_SELECTOR_H_
17 #if GOOGLE_CUDA && GOOGLE_TENSORRT
18 #include <array>
19 #include <memory>
20 #include <set>
21 
22 #include "absl/types/optional.h"
23 #include "third_party/tensorrt/NvInfer.h"
24 
25 namespace tensorflow {
26 namespace tensorrt {
27 namespace convert {
28 
29 // Implements core algorithm selection logic in a testable manner. The policy
30 // implemented depends on the given TRT version. We have this class because TRT
31 // interfaces make it difficult to directly test an IAlgorithmSelector
32 // implementation.
33 class AlgorithmSelectorImpl {
34  public:
35   using TRTVersion = std::array<int, 4>;
36   using ImplementationID = int64_t;
37   using TacticID = int64_t;
38 
CompileTimeTRTVersion()39   static constexpr TRTVersion CompileTimeTRTVersion() {
40     return TRTVersion{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH,
41                       NV_TENSORRT_BUILD};
42   }
43 
44   explicit AlgorithmSelectorImpl(
45       const TRTVersion& version = CompileTimeTRTVersion())
version_(version)46       : version_(version) {}
47 
48   bool IsShuffleLayer(ImplementationID id) const;
49 
50   bool IsBannedTactic(TacticID id) const;
51 
52   // Returns true if the algorithm implementing the IShuffleLayer is acceptable.
53   bool AllowShuffleAlgorithm(TacticID tactic, nvinfer1::DataType input_dtype,
54                              nvinfer1::TensorFormat input_format) const;
55 
56   bool IsTrtVersionGE(const TRTVersion& version) const;
57 
58   // Returns true if we know at compile time that the algorithm selector
59   // should be required. This is a conservative estimate.
60   bool IsAlgorithmSelectorRequired() const;
61 
62   static std::set<TacticID> GetBannedTRT72TuringTactics();
63 
64  private:
65   TRTVersion version_;
66 };
67 
68 // Impelements the TRT IAlgorithmSelector interface. The method
69 // "selectAlgorithms" selects allowable algorithms for each layer, and
70 // "reportAlgorithms" summarizes the algorithms selected by TensorRT.
71 class TftrtAlgorithmSelector : public nvinfer1::IAlgorithmSelector {
72  private:
73   using TacticID = AlgorithmSelectorImpl::TacticID;
74 
75   // An index we should choose for all algorithms. Used for debugging.
76   std::optional<int32_t> fixed_algorithm_idx_;
77 
78   AlgorithmSelectorImpl selector_;
79 
80  public:
81   TftrtAlgorithmSelector();
82 
83   // If the environment variable TF_TRT_FIXED_ALGORITHM_ID is empty, this
84   // function returns nullopt. Otherwise, it returns the specified number.
85   static std::optional<int64_t> GetFixedAlgorithmID();
86 
87   // Returns true if the algorithm associated with context is acceptable.
88   bool AlgorithmPolicy(const nvinfer1::IAlgorithmContext& context,
89                        const nvinfer1::IAlgorithm& alg) const;
90 
91   // This function fills the array "selection" with the indices of selected
92   // algorithm candidates from "algoChoices", each of which is an implementation
93   // for the kernel described by the given IAlgorithmContext. It should return a
94   // number in [0, nbChoices] indicating the number of selected indices. If 0 is
95   // returned, TensorRT will use its default selection mechanism.
96   int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext,
97                            const nvinfer1::IAlgorithm* const* algoChoices,
98                            int32_t nbChoices,
99                            int32_t* selection) noexcept override;
100 
101   // Called by TensorRT to report choices it made.
102   void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts,
103                         const nvinfer1::IAlgorithm* const* algoChoices,
104                         int32_t nbAlgorithms) noexcept override;
105 
IsRequired()106   bool IsRequired() const {
107     return selector_.IsAlgorithmSelectorRequired() ||
108            fixed_algorithm_idx_ != std::nullopt;
109   }
110 };
111 
112 // Returns an initialized AlgorithmSelector if an algorithm selector is required
113 // for the current TRT version. Otherwise, returns nullptr.
114 std::unique_ptr<TftrtAlgorithmSelector> MaybeCreateAlgorithmSelector();
115 
116 }  // namespace convert
117 }  // namespace tensorrt
118 }  // namespace tensorflow
119 
120 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
121 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_ALGORITHM_SELECTOR_H_
122