/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #if GOOGLE_CUDA && GOOGLE_TENSORRT
17
#include "tensorflow/compiler/tf2tensorrt/trt_convert_api.h"

#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/state_ops.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
30
31 namespace tensorflow {
32 namespace tensorrt {
33
34 struct TestParam {
35 TfTrtConversionParams conv_params;
36 std::vector<std::vector<int64>> input_shapes;
37 };
38
39 class TrtConverterTest
40 : public ::testing::TestWithParam<std::tuple<TestParam, bool, bool>> {
41 protected:
TrtConverterTest()42 TrtConverterTest() {
43 param_ = std::get<0>(GetParam());
44 use_variable_ = std::get<1>(GetParam());
45 use_function_ = std::get<2>(GetParam());
46 input_tensors_ = GetInputTensors();
47 }
48
49 // Returns the following graph: output = input * [42, 137] + input
GetGraphDef(PartialTensorShape input_shape)50 GraphDef GetGraphDef(PartialTensorShape input_shape) {
51 Scope root = Scope::NewRootScope();
52 Output c;
53 c = ops::Const(root.WithOpName("my_const"), {{42.0f, 137.0f}});
54 Output v;
55 if (use_variable_) {
56 Output v_handle = ops::VarHandleOp(root.WithOpName("my_var"),
57 DataType::DT_FLOAT, {1, 2});
58 v = ops::ReadVariableOp(root.WithOpName("my_var/Read/ReadVariableOp"),
59 v_handle, DataType::DT_FLOAT);
60 auto v_init =
61 ops::AssignVariableOp(root.WithOpName("my_var/init"), v_handle, c);
62 } else {
63 v = c;
64 }
65 const auto attrs = ops::Placeholder::Shape(input_shape);
66 auto x = ops::Placeholder(root.WithOpName("input"), DT_FLOAT, attrs);
67 auto y = ops::Mul(root.WithOpName("my_mul"), x, v);
68 auto z = ops::Add(root.WithOpName("my_add"), x, y);
69 auto q = ops::Identity(root.WithOpName("output"), z);
70
71 GraphDef out;
72 TF_CHECK_OK(root.ToGraphDef(&out));
73 return out;
74 }
75
GetGraphWithFunction(PartialTensorShape input_shape)76 GraphDef GetGraphWithFunction(PartialTensorShape input_shape) {
77 using ::tensorflow::test::function::GDef;
78 using ::tensorflow::test::function::NDef;
79 GraphConstructorOptions opts;
80 const Tensor kOne = test::AsScalar<float>(1.0f);
81 TensorShapeProto value_shape_proto;
82 kOne.shape().AsProto(&value_shape_proto);
83 TensorShapeProto input_shape_proto;
84 input_shape.AsProto(&input_shape_proto);
85 NodeDef value_node;
86 if (use_variable_) {
87 value_node =
88 NDef("my_value", "Identity", {"my_var:0"}, {{"T", DT_RESOURCE}});
89 } else {
90 value_node =
91 NDef("my_value", "Identity", {"my_const:0"}, {{"T", DT_FLOAT}});
92 }
93 GraphDef gdef = GDef(
94 {
95 NDef("input", "Placeholder", {},
96 {{"dtype", DT_FLOAT}, {"shape", input_shape_proto}}),
97 NDef("my_const", "Const", {},
98 {{"dtype", DT_FLOAT}, {"value", kOne}}),
99 value_node,
100 NDef("call", "StatefulPartitionedCall", {"input", "my_value"},
101 {{"Tin", DataTypeSlice{DT_FLOAT, use_variable_ ? DT_RESOURCE
102 : DT_FLOAT}},
103 {"Tout", DataTypeSlice{DT_FLOAT}},
104 {"f", FunctionDefHelper::FunctionRef("f", {})}}),
105 NDef("output", "Identity", {"call:0"}, {{"T", DT_FLOAT}}),
106 },
107 {});
108 FunctionDef fdef;
109 if (use_variable_) {
110 gdef.add_node()->CopyFrom(
111 NDef("my_var", "VarHandleOp", {},
112 {{"dtype", DT_FLOAT}, {"shape", value_shape_proto}}));
113
114 gdef.add_node()->CopyFrom(NDef("my_var/init", "AssignVariableOp",
115 {"my_var", "my_const"},
116 {{"dtype", DT_FLOAT}}));
117 gdef.add_node()->CopyFrom(NDef("my_var/Read/ReadVariableOp",
118 "ReadVariableOp", {"my_var"},
119 {{"dtype", DT_FLOAT}}));
120 // Define function f(x, v) = x * v + x, where v is a variable.
121 fdef = FunctionDefHelper::Define(
122 "f", // Name
123 {"x: float", "v: resource"}, // Args
124 {"q: float"}, // Returns
125 {}, // Attr def
126 // Nodes
127 {{{"my_var/Read/ReadVariableOp"},
128 "ReadVariableOp",
129 {"v"},
130 {{"dtype", DT_FLOAT}}},
131 {{"my_mul"},
132 "Mul",
133 {"x", "my_var/Read/ReadVariableOp"},
134 {{"T", DT_FLOAT}}},
135 {{"my_add"}, "AddV2", {"x", "my_mul"}, {{"T", DT_FLOAT}}},
136 {{"q"}, "Identity", {"my_add"}, {{"T", DT_FLOAT}}}});
137 } else {
138 // Define function f(x, v) = x * v + x, where v is const value.
139 fdef = FunctionDefHelper::Define(
140 "f", // Name
141 {"x: float", "v: float"}, // Args
142 {"q: float"}, // Returns
143 {}, // Attr def
144 // Nodes
145 {{{"my_mul"}, "Mul", {"x", "v"}, {{"T", DT_FLOAT}}},
146 {{"my_add"}, "AddV2", {"x", "my_mul"}, {{"T", DT_FLOAT}}},
147 {{"q"}, "Identity", {"my_add"}, {{"T", DT_FLOAT}}}});
148 }
149 gdef.mutable_library()->add_function()->CopyFrom(fdef);
150
151 return gdef;
152 }
153
154 // Returns the following graph: output = input * [42, 137] + input
GetModel()155 MetaGraphDef GetModel() {
156 PartialTensorShape shape({-1, 2});
157 MetaGraphDef out;
158 if (use_function_) {
159 *(out.mutable_graph_def()) = GetGraphWithFunction(shape);
160 } else {
161 *(out.mutable_graph_def()) = GetGraphDef(shape);
162 }
163 VLOG(2) << out.graph_def().DebugString();
164 TensorShapeProto shape_proto;
165 shape.AsProto(&shape_proto);
166 SignatureDef signature_def;
167 (*signature_def.mutable_inputs())["input"].set_name("input:0");
168 (*signature_def.mutable_inputs())["input"].set_dtype(DT_FLOAT);
169 (*signature_def.mutable_inputs())["input"].mutable_tensor_shape()->CopyFrom(
170 shape_proto);
171 (*signature_def.mutable_outputs())["output"].set_name("output:0");
172 (*signature_def.mutable_outputs())["output"].set_dtype(DT_FLOAT);
173 (*signature_def.mutable_outputs())["output"]
174 .mutable_tensor_shape()
175 ->CopyFrom(shape_proto);
176 (*out.mutable_signature_def())["serving_default"] = signature_def;
177
178 VLOG(2) << signature_def.DebugString();
179 return out;
180 }
181
GetSavedModelBundle(SavedModelBundle * bundle)182 Status GetSavedModelBundle(SavedModelBundle* bundle) {
183 bundle->meta_graph_def = GetModel();
184 Session* session = nullptr;
185 TF_RETURN_IF_ERROR(NewSession(tensorflow::SessionOptions(), &session));
186 TF_RETURN_IF_ERROR(session->Create(bundle->meta_graph_def.graph_def()));
187 bundle->session.reset(session);
188 TF_RETURN_IF_ERROR(session->Run(/* inputs */ {}, /*outputs*/ {},
189 /*targets*/ {"my_var/init"}, nullptr));
190 return OkStatus();
191 }
192
193 // Confirms that we have a TRT node with the correct attributes.
CheckTrtNode(const GraphDef & converted_graph_def)194 void CheckTrtNode(const GraphDef& converted_graph_def) {
195 int n_trt_ops = 0;
196 string op_name{"TRTEngineOp"};
197 for (const auto& node : converted_graph_def.node()) {
198 if (!op_name.compare(node.op())) {
199 n_trt_ops++;
200 const auto& attr = node.attr();
201 EXPECT_EQ(attr.at("static_engine").b(),
202 param_.conv_params.convert_to_static_engine);
203 if (param_.conv_params.convert_to_static_engine) {
204 VLOG(2) << "Found serialized segment with size "
205 << attr.at("serialized_segment").s().size();
206 EXPECT_GT(attr.at("serialized_segment").s().size(), 0);
207 }
208 }
209 }
210 EXPECT_EQ(n_trt_ops, 1);
211 }
212
213 // Creates a list of input tensors, they will be used to build the engines.
GetInputTensors()214 std::vector<std::vector<Tensor>> GetInputTensors() {
215 std::vector<std::vector<Tensor>> input_tensors;
216 for (const std::vector<int64>& shape : param_.input_shapes) {
217 Tensor tensor(DT_FLOAT, TensorShape(shape));
218 test::FillIota(&tensor, 1.0f);
219 input_tensors.push_back({tensor});
220 }
221 return input_tensors;
222 }
223
RunAndCompareResults(Session * session,const GraphDef & converted_graph_def)224 void RunAndCompareResults(Session* session,
225 const GraphDef& converted_graph_def) {
226 // Create a session to execute the converted graph.
227 Session* p_session = nullptr;
228 TF_EXPECT_OK(NewSession(SessionOptions(), &p_session));
229 std::unique_ptr<tensorflow::Session> trt_session(p_session);
230 TF_EXPECT_OK(trt_session->Create(converted_graph_def));
231
232 // Run models and compare the output.
233 for (const std::vector<Tensor>& input : input_tensors_) {
234 std::vector<Tensor> outputs;
235 TF_EXPECT_OK(
236 session->Run({{"input", input.at(0)}}, {"output"}, {}, &outputs));
237 std::cout << outputs.at(0).DebugString() << std::endl;
238
239 std::vector<Tensor> trt_outputs;
240 TF_EXPECT_OK(trt_session->Run({{"input", input.at(0)}}, {"output"}, {},
241 &trt_outputs));
242 std::cout << trt_outputs.at(0).DebugString() << std::endl;
243 ASSERT_EQ(outputs.size(), 1);
244 ASSERT_EQ(trt_outputs.size(), 1);
245 tensorflow::test::ExpectEqual(outputs[0], trt_outputs[0]);
246 }
247 }
248
ConvertAndRunFrozenGraph()249 void ConvertAndRunFrozenGraph() {
250 MetaGraphDef meta_graph_def = GetModel();
251
252 StatusOr<GraphDef> result = tensorrt::ConvertAndBuild(
253 meta_graph_def.graph_def(), {"input"}, {"output"}, input_tensors_,
254 param_.conv_params);
255 TF_ASSERT_OK(result.status());
256 const GraphDef& converted_graph_def = result.ValueOrDie();
257 CheckTrtNode(converted_graph_def);
258
259 // Create a session to execute the original graph.
260 Session* p_session = nullptr;
261 TF_EXPECT_OK(NewSession(SessionOptions(), &p_session));
262 std::unique_ptr<tensorflow::Session> session(p_session);
263 TF_EXPECT_OK(session->Create(meta_graph_def.graph_def()));
264
265 RunAndCompareResults(session.get(), converted_graph_def);
266 }
267
ConvertAndRunSavedModel()268 void ConvertAndRunSavedModel() {
269 SavedModelBundle bundle;
270 TF_CHECK_OK(GetSavedModelBundle(&bundle));
271
272 StatusOr<GraphDef> result = tensorrt::ConvertAndBuild(
273 &bundle, "serving_default", input_tensors_, param_.conv_params);
274 TF_ASSERT_OK(result.status());
275 const GraphDef& converted_graph_def = result.ValueOrDie();
276 CheckTrtNode(converted_graph_def);
277
278 RunAndCompareResults(bundle.GetSession(), converted_graph_def);
279 }
280
281 TestParam param_;
282 bool use_variable_;
283 bool use_function_;
284 std::vector<std::vector<Tensor>> input_tensors_;
285 };
286
287 INSTANTIATE_TEST_CASE_P(
288 TrtConverterTestInstantiation, TrtConverterTest,
289 ::testing::Combine(
290 ::testing::Values(
291 // Dynamic shape mode test with conver_to_static_engine=true.
292 TestParam{TfTrtConversionParams{
293 1 << 20, // max workspace size
294 TrtPrecisionMode::FP32,
295 3, // minimum_segment_size
296 1, // max_cached_engines
297 false, // use_calibration
298 true, // use_dynamic_shape
299 ProfileStrategy::kOptimal,
300 true, // allow_build_at_runtime
301 true // convert_to_static_engine
302 },
303 {{1, 2}, {4, 2}}},
304 // Implicit batch mode test with conver_to_static_engine=true.
305 TestParam{TfTrtConversionParams{
306 1 << 20, // max workspace size
307 TrtPrecisionMode::FP16,
308 3, // minimum_segment_size
309 1, // max_cached_engines
310 false, // use_calibration
311 false, // use_dynamic_shape
312 ProfileStrategy::kRange,
313 true, // allow_build_at_runtime
314 true // convert_to_static_engine
315 },
316 {{1, 2}}},
317 // Dynamic shape mode test convert_to_static_engine=false: we cannot
318 // save the engines, therefore we do not generate profiles. A single
319 // engine will be built during runtime, with profile that matches
320 // the first shape ({1,2}). The second shape will run as native
321 // segment.
322 TestParam{TfTrtConversionParams{
323 1 << 20, // max workspace size
324 TrtPrecisionMode::FP32,
325 3, // minimum_segment_size
326 1, // max_cached_engines
327 false, // use_calibration
328 true, // use_dynamic_shape
329 ProfileStrategy::kOptimal,
330 true, // allow_build_at_runtime
331 false // convert_to_static_engine
332 },
333 {{1, 2}, {4, 2}}},
334 // Implicit batch mode test with convert_to_static_engine=false.
335 // We will have two engines in the cache to handle the two shapes.
336 TestParam{TfTrtConversionParams{
337 1 << 20, // max workspace size
338 TrtPrecisionMode::FP16,
339 3, // minimum_segment_size
340 2, // max_cached_engines
341 false, // use_calibration
342 false, // use_dynamic_shape
343 ProfileStrategy::kRange,
344 true, // allow_build_at_runtime
345 false // convert_to_static_engine
346 },
347 {{1, 2}, {4, 2}}}),
348 ::testing::Values(false, true), // use_variables
349 ::testing::Values(false, true))); // use_function
350
// The SavedModel path is required for the variable case (the bundle runs the
// variable initializer); the frozen-graph path covers the rest.
TEST_P(TrtConverterTest, Basic) {
  if (use_variable_) {
    ConvertAndRunSavedModel();
  } else {
    ConvertAndRunFrozenGraph();
  }
}
358
359 } // namespace tensorrt
360 } // namespace tensorflow
361
362 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
363