1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA && GOOGLE_TENSORRT
16 #include "tensorflow/compiler/tf2tensorrt/convert/algorithm_selector.h"
17
18 #include <utility>
19
20 #include "absl/strings/str_format.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
23 #include "tensorflow/core/util/env_var.h"
24 #include "third_party/tensorrt/NvInfer.h"
25
26 // getAlgorithmIOInfo is deprecated in TRT >= 8, replaced by
27 // getAlgorithmIOInfoByIndex.
28 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
29 #define ALGORITHM_IO_INFO_BY_IDX(alg, idx) *(alg).getAlgorithmIOInfoByIndex(idx)
30 #else
31 #define ALGORITHM_IO_INFO_BY_IDX(alg, idx) (alg).getAlgorithmIOInfo(idx)
32 #endif
33
34 namespace nvinfer1 {
35
operator <<(std::ostream & os,const nvinfer1::IAlgorithmContext & ctx)36 std::ostream& operator<<(std::ostream& os,
37 const nvinfer1::IAlgorithmContext& ctx) {
38 os << "AlgorithmContext(name=" << ctx.getName()
39 << ",nbInputs=" << ctx.getNbInputs() << ",nbOutputs=" << ctx.getNbOutputs()
40 << ")";
41 return os;
42 }
43
operator <<(std::ostream & os,const nvinfer1::IAlgorithm & alg)44 std::ostream& operator<<(std::ostream& os, const nvinfer1::IAlgorithm& alg) {
45 const nvinfer1::IAlgorithmVariant& variant = alg.getAlgorithmVariant();
46 os << "Algorithm("
47 << "variant.implementation=" << variant.getImplementation()
48 << ",variant.tactic=" << variant.getTactic()
49 << ",timingMSec=" << alg.getTimingMSec()
50 << ",workspaceSize=" << alg.getWorkspaceSize() << ")";
51 return os;
52 }
53
operator <<(std::ostream & os,const nvinfer1::IAlgorithmIOInfo & info)54 std::ostream& operator<<(std::ostream& os,
55 const nvinfer1::IAlgorithmIOInfo& info) {
56 os << "IOTensor(format=" << info.getTensorFormat()
57 << ",dtype=" << info.getDataType() << ",strides=" << info.getStrides()
58 << ")";
59 return os;
60 }
61 } // namespace nvinfer1
62
63 namespace tensorflow {
64 namespace tensorrt {
65 namespace convert {
66
operator >=(const AlgorithmSelectorImpl::TRTVersion & lhs,const AlgorithmSelectorImpl::TRTVersion & rhs)67 bool operator>=(const AlgorithmSelectorImpl::TRTVersion& lhs,
68 const AlgorithmSelectorImpl::TRTVersion& rhs) {
69 if (lhs[0] > rhs[0]) return true;
70 if (lhs[0] == rhs[0] && lhs[1] > rhs[1]) return true;
71 if (lhs[0] == rhs[0] && lhs[1] == rhs[1] && lhs[2] > rhs[2]) return true;
72 if (lhs[0] == rhs[0] && lhs[1] == rhs[1] && lhs[2] == rhs[2] &&
73 lhs[3] >= rhs[3]) {
74 return true;
75 }
76 return false;
77 }
78
IsTrtVersionGE(const TRTVersion & version) const79 bool AlgorithmSelectorImpl::IsTrtVersionGE(const TRTVersion& version) const {
80 return version_ >= version;
81 }
82
IsShuffleLayer(ImplementationID id) const83 bool AlgorithmSelectorImpl::IsShuffleLayer(ImplementationID id) const {
84 if (IsTrtVersionGE({8, 2, 0, 0})) {
85 return id == 0x80000000 + 13;
86 }
87 if (IsTrtVersionGE({8, 0, 0, 0})) {
88 return id == 0x80000000 + 14;
89 }
90 if (IsTrtVersionGE({7, 2, 0, 0})) {
91 return id == 0x80000000 + 16;
92 }
93 return id == 18;
94 }
95
96 std::set<AlgorithmSelectorImpl::TacticID>
GetBannedTRT72TuringTactics()97 AlgorithmSelectorImpl::GetBannedTRT72TuringTactics() {
98 static const std::set<TacticID> banned_turing_72{
99 // turing_fp16_s1688cudnn_fp16_128x128_ldg8_relu_f2f_exp_medium_nhwc_gelu_tn_v1
100 -5927686925093575778,
101 // turing_fp16_s1688cudnn_fp16_128x128_ldg8_relu_f2f_exp_interior_nhwc_gelu_tn_v1
102 -3848538574386518527,
103 // turing_fp16_s1688cudnn_fp16_128x128_ldg8_relu_f2f_exp_small_nhwc_gelu_tn_v1
104 -959009792490796596};
105 return banned_turing_72;
106 }
107
IsBannedTactic(TacticID id) const108 bool AlgorithmSelectorImpl::IsBannedTactic(TacticID id) const {
109 // Disable problematic FP16-Turing tactics in TensorRT 7.2.
110 if (IsTrtVersionGE({7, 2, 0, 0}) && !IsTrtVersionGE({8, 0, 0, 0})) {
111 auto banned_turing_72 = GetBannedTRT72TuringTactics();
112 return banned_turing_72.find(id) != banned_turing_72.end();
113 }
114 return false;
115 }
116
AllowShuffleAlgorithm(TacticID tactic,nvinfer1::DataType input_dtype,nvinfer1::TensorFormat input_format) const117 bool AlgorithmSelectorImpl::AllowShuffleAlgorithm(
118 TacticID tactic, nvinfer1::DataType input_dtype,
119 nvinfer1::TensorFormat input_format) const {
120 if (IsTrtVersionGE({8, 0, 0, 0}) && !IsTrtVersionGE({8, 0, 3, 0})) {
121 // Reject shuffle node when input format is linear row major INT8
122 // format in TensorRT 8.0 GA.
123 return !(input_format == nvinfer1::TensorFormat::kLINEAR &&
124 input_dtype == nvinfer1::DataType::kINT8);
125 }
126
127 if (IsTrtVersionGE({7, 2, 0, 0}) && !IsTrtVersionGE({8, 0, 0, 0})) {
128 // For TRT 7.2, accept shuffle node when input format is not 32-wide
129 // channel vectorized row major FP32 format
130 return !(input_format == nvinfer1::TensorFormat::kCHW32 &&
131 input_dtype == nvinfer1::DataType::kFLOAT);
132 }
133 return true;
134 }
135
IsAlgorithmSelectorRequired() const136 bool AlgorithmSelectorImpl::IsAlgorithmSelectorRequired() const {
137 // If we are in turing for TensorRT 7.2, we need the selector for shuffle and
138 // avoiding specfic Turing tactics.
139 if (IsTrtVersionGE({7, 2, 0, 0}) && !IsTrtVersionGE({8, 0, 0, 0})) {
140 return true;
141 }
142
143 // If we are in TensorRT 8.0 GA, we want to reject certain types of shuffles.
144 if (IsTrtVersionGE({8, 0, 0, 0}) && !IsTrtVersionGE({8, 0, 3, 0})) {
145 return true;
146 }
147
148 return false;
149 }
150
151 namespace {
152
FormatAlgorithmList(const nvinfer1::IAlgorithmContext & ctx,absl::Span<const nvinfer1::IAlgorithm * const> algs)153 string FormatAlgorithmList(const nvinfer1::IAlgorithmContext& ctx,
154 absl::Span<const nvinfer1::IAlgorithm* const> algs) {
155 return absl::StrFormat(
156 "%s:\n\t%s", absl::FormatStreamed(ctx),
157 absl::StrJoin(
158 algs, "\n\t",
159 [&ctx](std::string* out, const nvinfer1::IAlgorithm* const alg) {
160 absl::StrAppendFormat(out, "%s", absl::FormatStreamed(*alg));
161 for (int i = 0; i < ctx.getNbInputs() + ctx.getNbOutputs(); i++) {
162 absl::StrAppendFormat(
163 out, "\n\t\t%s",
164 absl::FormatStreamed(ALGORITHM_IO_INFO_BY_IDX(*alg, i)));
165 }
166 }));
167 }
168
169 } // namespace
170
// Default constructor. Initializes the optional fixed algorithm index from
// GetFixedAlgorithmID() (which reads TF_TRT_FIXED_ALGORITHM_ID) and
// constructs the version-dependent selector from the value returned by
// AlgorithmSelectorImpl::CompileTimeTRTVersion().
TftrtAlgorithmSelector::TftrtAlgorithmSelector()
    : fixed_algorithm_idx_(GetFixedAlgorithmID()),
      selector_(AlgorithmSelectorImpl::CompileTimeTRTVersion()) {}
174
GetFixedAlgorithmID()175 std::optional<int64_t> TftrtAlgorithmSelector::GetFixedAlgorithmID() {
176 int64_t trt_algorithm_idx = 0;
177 constexpr auto null_idx =
178 std::numeric_limits<decltype(trt_algorithm_idx)>::min();
179 Status status = tensorflow::ReadInt64FromEnvVar("TF_TRT_FIXED_ALGORITHM_ID",
180 /*default_val=*/null_idx,
181 &trt_algorithm_idx);
182 if (!status.ok()) {
183 LOG(ERROR) << status;
184 return std::nullopt;
185 }
186 if (trt_algorithm_idx != null_idx) {
187 return std::max(static_cast<int32_t>(trt_algorithm_idx), 0);
188 }
189 return std::nullopt;
190 }
191
AlgorithmPolicy(const nvinfer1::IAlgorithmContext & context,const nvinfer1::IAlgorithm & alg) const192 bool TftrtAlgorithmSelector::AlgorithmPolicy(
193 const nvinfer1::IAlgorithmContext& context,
194 const nvinfer1::IAlgorithm& alg) const {
195 const nvinfer1::IAlgorithmVariant& variant = alg.getAlgorithmVariant();
196
197 // Check if this tactic ID is banned.
198 TacticID tactic_id = variant.getTactic();
199 if (selector_.IsBannedTactic(tactic_id)) {
200 return false;
201 }
202
203 if (selector_.IsShuffleLayer(variant.getImplementation())) {
204 return selector_.AllowShuffleAlgorithm(
205 tactic_id, alg.getAlgorithmIOInfo(0).getDataType(),
206 alg.getAlgorithmIOInfo(0).getTensorFormat());
207 }
208 return true;
209 }
210
selectAlgorithms(const nvinfer1::IAlgorithmContext & algoContext,const nvinfer1::IAlgorithm * const * algoChoices,int32_t nbChoices,int32_t * selection)211 int32_t TftrtAlgorithmSelector::selectAlgorithms(
212 const nvinfer1::IAlgorithmContext& algoContext,
213 const nvinfer1::IAlgorithm* const* algoChoices, int32_t nbChoices,
214 int32_t* selection) noexcept {
215 if (fixed_algorithm_idx_) {
216 LOG(WARNING) << "Forcing TRT algorithm selection to: ID = "
217 << *fixed_algorithm_idx_;
218 selection[0] = std::min(*fixed_algorithm_idx_, nbChoices - 1);
219 return 1;
220 }
221
222 int num_selections = 0;
223
224 VLOG(1) << "Algorithm selection choices: "
225 << FormatAlgorithmList(algoContext,
226 absl::MakeSpan(algoChoices, nbChoices));
227
228 for (int i = 0; i < nbChoices; i++) {
229 const nvinfer1::IAlgorithm& alg = *algoChoices[i];
230
231 // Check layer-specific issues.
232 if (!AlgorithmPolicy(algoContext, alg)) {
233 LOG(WARNING) << absl::StrFormat("Rejecting Algorithm: %s ",
234 absl::FormatStreamed(alg));
235 continue;
236 }
237 selection[num_selections++] = i;
238 }
239 return num_selections;
240 }
241
242 // Called by TensorRT to report choices it made.
reportAlgorithms(const nvinfer1::IAlgorithmContext * const * algoContexts,const nvinfer1::IAlgorithm * const * algoChoices,int32_t nbAlgorithms)243 void TftrtAlgorithmSelector::reportAlgorithms(
244 const nvinfer1::IAlgorithmContext* const* algoContexts,
245 const nvinfer1::IAlgorithm* const* algoChoices,
246 int32_t nbAlgorithms) noexcept {
247 if (VLOG_IS_ON(1)) {
248 string selection_msg = "Algorithms selected:\n";
249 for (int i = 0; i < nbAlgorithms; i++) {
250 absl::StrAppend(&selection_msg,
251 FormatAlgorithmList(*algoContexts[i],
252 absl::MakeSpan(algoChoices + i, 1)));
253 }
254 VLOG(1) << selection_msg;
255 }
256 }
257
MaybeCreateAlgorithmSelector()258 std::unique_ptr<TftrtAlgorithmSelector> MaybeCreateAlgorithmSelector() {
259 auto selector = std::make_unique<TftrtAlgorithmSelector>();
260
261 if (selector->IsRequired()) {
262 return selector;
263 }
264
265 return nullptr;
266 }
267
268 } // namespace convert
269 } // namespace tensorrt
270 } // namespace tensorflow
271
272 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
273