xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/algorithm_selector.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA && GOOGLE_TENSORRT
16 #include "tensorflow/compiler/tf2tensorrt/convert/algorithm_selector.h"
17 
18 #include <utility>
19 
20 #include "absl/strings/str_format.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
23 #include "tensorflow/core/util/env_var.h"
24 #include "third_party/tensorrt/NvInfer.h"
25 
26 // getAlgorithmIOInfo is deprecated in TRT >= 8, replaced by
27 // getAlgorithmIOInfoByIndex.
28 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
29 #define ALGORITHM_IO_INFO_BY_IDX(alg, idx) *(alg).getAlgorithmIOInfoByIndex(idx)
30 #else
31 #define ALGORITHM_IO_INFO_BY_IDX(alg, idx) (alg).getAlgorithmIOInfo(idx)
32 #endif
33 
34 namespace nvinfer1 {
35 
operator <<(std::ostream & os,const nvinfer1::IAlgorithmContext & ctx)36 std::ostream& operator<<(std::ostream& os,
37                          const nvinfer1::IAlgorithmContext& ctx) {
38   os << "AlgorithmContext(name=" << ctx.getName()
39      << ",nbInputs=" << ctx.getNbInputs() << ",nbOutputs=" << ctx.getNbOutputs()
40      << ")";
41   return os;
42 }
43 
operator <<(std::ostream & os,const nvinfer1::IAlgorithm & alg)44 std::ostream& operator<<(std::ostream& os, const nvinfer1::IAlgorithm& alg) {
45   const nvinfer1::IAlgorithmVariant& variant = alg.getAlgorithmVariant();
46   os << "Algorithm("
47      << "variant.implementation=" << variant.getImplementation()
48      << ",variant.tactic=" << variant.getTactic()
49      << ",timingMSec=" << alg.getTimingMSec()
50      << ",workspaceSize=" << alg.getWorkspaceSize() << ")";
51   return os;
52 }
53 
operator <<(std::ostream & os,const nvinfer1::IAlgorithmIOInfo & info)54 std::ostream& operator<<(std::ostream& os,
55                          const nvinfer1::IAlgorithmIOInfo& info) {
56   os << "IOTensor(format=" << info.getTensorFormat()
57      << ",dtype=" << info.getDataType() << ",strides=" << info.getStrides()
58      << ")";
59   return os;
60 }
61 }  // namespace nvinfer1
62 
63 namespace tensorflow {
64 namespace tensorrt {
65 namespace convert {
66 
operator >=(const AlgorithmSelectorImpl::TRTVersion & lhs,const AlgorithmSelectorImpl::TRTVersion & rhs)67 bool operator>=(const AlgorithmSelectorImpl::TRTVersion& lhs,
68                 const AlgorithmSelectorImpl::TRTVersion& rhs) {
69   if (lhs[0] > rhs[0]) return true;
70   if (lhs[0] == rhs[0] && lhs[1] > rhs[1]) return true;
71   if (lhs[0] == rhs[0] && lhs[1] == rhs[1] && lhs[2] > rhs[2]) return true;
72   if (lhs[0] == rhs[0] && lhs[1] == rhs[1] && lhs[2] == rhs[2] &&
73       lhs[3] >= rhs[3]) {
74     return true;
75   }
76   return false;
77 }
78 
IsTrtVersionGE(const TRTVersion & version) const79 bool AlgorithmSelectorImpl::IsTrtVersionGE(const TRTVersion& version) const {
80   return version_ >= version;
81 }
82 
IsShuffleLayer(ImplementationID id) const83 bool AlgorithmSelectorImpl::IsShuffleLayer(ImplementationID id) const {
84   if (IsTrtVersionGE({8, 2, 0, 0})) {
85     return id == 0x80000000 + 13;
86   }
87   if (IsTrtVersionGE({8, 0, 0, 0})) {
88     return id == 0x80000000 + 14;
89   }
90   if (IsTrtVersionGE({7, 2, 0, 0})) {
91     return id == 0x80000000 + 16;
92   }
93   return id == 18;
94 }
95 
96 std::set<AlgorithmSelectorImpl::TacticID>
GetBannedTRT72TuringTactics()97 AlgorithmSelectorImpl::GetBannedTRT72TuringTactics() {
98   static const std::set<TacticID> banned_turing_72{
99       // turing_fp16_s1688cudnn_fp16_128x128_ldg8_relu_f2f_exp_medium_nhwc_gelu_tn_v1
100       -5927686925093575778,
101       // turing_fp16_s1688cudnn_fp16_128x128_ldg8_relu_f2f_exp_interior_nhwc_gelu_tn_v1
102       -3848538574386518527,
103       // turing_fp16_s1688cudnn_fp16_128x128_ldg8_relu_f2f_exp_small_nhwc_gelu_tn_v1
104       -959009792490796596};
105   return banned_turing_72;
106 }
107 
IsBannedTactic(TacticID id) const108 bool AlgorithmSelectorImpl::IsBannedTactic(TacticID id) const {
109   // Disable problematic FP16-Turing tactics in TensorRT 7.2.
110   if (IsTrtVersionGE({7, 2, 0, 0}) && !IsTrtVersionGE({8, 0, 0, 0})) {
111     auto banned_turing_72 = GetBannedTRT72TuringTactics();
112     return banned_turing_72.find(id) != banned_turing_72.end();
113   }
114   return false;
115 }
116 
AllowShuffleAlgorithm(TacticID tactic,nvinfer1::DataType input_dtype,nvinfer1::TensorFormat input_format) const117 bool AlgorithmSelectorImpl::AllowShuffleAlgorithm(
118     TacticID tactic, nvinfer1::DataType input_dtype,
119     nvinfer1::TensorFormat input_format) const {
120   if (IsTrtVersionGE({8, 0, 0, 0}) && !IsTrtVersionGE({8, 0, 3, 0})) {
121     // Reject shuffle node when input format is linear row major INT8
122     // format in TensorRT 8.0 GA.
123     return !(input_format == nvinfer1::TensorFormat::kLINEAR &&
124              input_dtype == nvinfer1::DataType::kINT8);
125   }
126 
127   if (IsTrtVersionGE({7, 2, 0, 0}) && !IsTrtVersionGE({8, 0, 0, 0})) {
128     // For TRT 7.2, accept shuffle node when input format is not 32-wide
129     // channel vectorized row major FP32 format
130     return !(input_format == nvinfer1::TensorFormat::kCHW32 &&
131              input_dtype == nvinfer1::DataType::kFLOAT);
132   }
133   return true;
134 }
135 
IsAlgorithmSelectorRequired() const136 bool AlgorithmSelectorImpl::IsAlgorithmSelectorRequired() const {
137   // If we are in turing for TensorRT 7.2, we need the  selector for shuffle and
138   // avoiding specfic Turing tactics.
139   if (IsTrtVersionGE({7, 2, 0, 0}) && !IsTrtVersionGE({8, 0, 0, 0})) {
140     return true;
141   }
142 
143   // If we are in TensorRT 8.0 GA, we want to reject certain types of shuffles.
144   if (IsTrtVersionGE({8, 0, 0, 0}) && !IsTrtVersionGE({8, 0, 3, 0})) {
145     return true;
146   }
147 
148   return false;
149 }
150 
151 namespace {
152 
FormatAlgorithmList(const nvinfer1::IAlgorithmContext & ctx,absl::Span<const nvinfer1::IAlgorithm * const> algs)153 string FormatAlgorithmList(const nvinfer1::IAlgorithmContext& ctx,
154                            absl::Span<const nvinfer1::IAlgorithm* const> algs) {
155   return absl::StrFormat(
156       "%s:\n\t%s", absl::FormatStreamed(ctx),
157       absl::StrJoin(
158           algs, "\n\t",
159           [&ctx](std::string* out, const nvinfer1::IAlgorithm* const alg) {
160             absl::StrAppendFormat(out, "%s", absl::FormatStreamed(*alg));
161             for (int i = 0; i < ctx.getNbInputs() + ctx.getNbOutputs(); i++) {
162               absl::StrAppendFormat(
163                   out, "\n\t\t%s",
164                   absl::FormatStreamed(ALGORITHM_IO_INFO_BY_IDX(*alg, i)));
165             }
166           }));
167 }
168 
169 }  // namespace
170 
TftrtAlgorithmSelector()171 TftrtAlgorithmSelector::TftrtAlgorithmSelector()
172     : fixed_algorithm_idx_(GetFixedAlgorithmID()),
173       selector_(AlgorithmSelectorImpl::CompileTimeTRTVersion()) {}
174 
GetFixedAlgorithmID()175 std::optional<int64_t> TftrtAlgorithmSelector::GetFixedAlgorithmID() {
176   int64_t trt_algorithm_idx = 0;
177   constexpr auto null_idx =
178       std::numeric_limits<decltype(trt_algorithm_idx)>::min();
179   Status status = tensorflow::ReadInt64FromEnvVar("TF_TRT_FIXED_ALGORITHM_ID",
180                                                   /*default_val=*/null_idx,
181                                                   &trt_algorithm_idx);
182   if (!status.ok()) {
183     LOG(ERROR) << status;
184     return std::nullopt;
185   }
186   if (trt_algorithm_idx != null_idx) {
187     return std::max(static_cast<int32_t>(trt_algorithm_idx), 0);
188   }
189   return std::nullopt;
190 }
191 
AlgorithmPolicy(const nvinfer1::IAlgorithmContext & context,const nvinfer1::IAlgorithm & alg) const192 bool TftrtAlgorithmSelector::AlgorithmPolicy(
193     const nvinfer1::IAlgorithmContext& context,
194     const nvinfer1::IAlgorithm& alg) const {
195   const nvinfer1::IAlgorithmVariant& variant = alg.getAlgorithmVariant();
196 
197   // Check if this tactic ID is banned.
198   TacticID tactic_id = variant.getTactic();
199   if (selector_.IsBannedTactic(tactic_id)) {
200     return false;
201   }
202 
203   if (selector_.IsShuffleLayer(variant.getImplementation())) {
204     return selector_.AllowShuffleAlgorithm(
205         tactic_id, alg.getAlgorithmIOInfo(0).getDataType(),
206         alg.getAlgorithmIOInfo(0).getTensorFormat());
207   }
208   return true;
209 }
210 
selectAlgorithms(const nvinfer1::IAlgorithmContext & algoContext,const nvinfer1::IAlgorithm * const * algoChoices,int32_t nbChoices,int32_t * selection)211 int32_t TftrtAlgorithmSelector::selectAlgorithms(
212     const nvinfer1::IAlgorithmContext& algoContext,
213     const nvinfer1::IAlgorithm* const* algoChoices, int32_t nbChoices,
214     int32_t* selection) noexcept {
215   if (fixed_algorithm_idx_) {
216     LOG(WARNING) << "Forcing TRT algorithm selection to: ID = "
217                  << *fixed_algorithm_idx_;
218     selection[0] = std::min(*fixed_algorithm_idx_, nbChoices - 1);
219     return 1;
220   }
221 
222   int num_selections = 0;
223 
224   VLOG(1) << "Algorithm selection choices: "
225           << FormatAlgorithmList(algoContext,
226                                  absl::MakeSpan(algoChoices, nbChoices));
227 
228   for (int i = 0; i < nbChoices; i++) {
229     const nvinfer1::IAlgorithm& alg = *algoChoices[i];
230 
231     // Check layer-specific issues.
232     if (!AlgorithmPolicy(algoContext, alg)) {
233       LOG(WARNING) << absl::StrFormat("Rejecting Algorithm: %s ",
234                                       absl::FormatStreamed(alg));
235       continue;
236     }
237     selection[num_selections++] = i;
238   }
239   return num_selections;
240 }
241 
242 // Called by TensorRT to report choices it made.
reportAlgorithms(const nvinfer1::IAlgorithmContext * const * algoContexts,const nvinfer1::IAlgorithm * const * algoChoices,int32_t nbAlgorithms)243 void TftrtAlgorithmSelector::reportAlgorithms(
244     const nvinfer1::IAlgorithmContext* const* algoContexts,
245     const nvinfer1::IAlgorithm* const* algoChoices,
246     int32_t nbAlgorithms) noexcept {
247   if (VLOG_IS_ON(1)) {
248     string selection_msg = "Algorithms selected:\n";
249     for (int i = 0; i < nbAlgorithms; i++) {
250       absl::StrAppend(&selection_msg,
251                       FormatAlgorithmList(*algoContexts[i],
252                                           absl::MakeSpan(algoChoices + i, 1)));
253     }
254     VLOG(1) << selection_msg;
255   }
256 }
257 
MaybeCreateAlgorithmSelector()258 std::unique_ptr<TftrtAlgorithmSelector> MaybeCreateAlgorithmSelector() {
259   auto selector = std::make_unique<TftrtAlgorithmSelector>();
260 
261   if (selector->IsRequired()) {
262     return selector;
263   }
264 
265   return nullptr;
266 }
267 
268 }  // namespace convert
269 }  // namespace tensorrt
270 }  // namespace tensorflow
271 
272 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
273