xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/trt_convert_api_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #if GOOGLE_CUDA && GOOGLE_TENSORRT
17 
18 #include "tensorflow/compiler/tf2tensorrt/trt_convert_api.h"
19 
20 #include "tensorflow/cc/ops/resource_variable_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/cc/ops/state_ops.h"
23 #include "tensorflow/cc/saved_model/loader.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/framework/tensor_testutil.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/protobuf/meta_graph.pb.h"
29 #include "tensorflow/core/public/session.h"
30 
31 namespace tensorflow {
32 namespace tensorrt {
33 
34 struct TestParam {
35   TfTrtConversionParams conv_params;
36   std::vector<std::vector<int64>> input_shapes;
37 };
38 
39 class TrtConverterTest
40     : public ::testing::TestWithParam<std::tuple<TestParam, bool, bool>> {
41  protected:
TrtConverterTest()42   TrtConverterTest() {
43     param_ = std::get<0>(GetParam());
44     use_variable_ = std::get<1>(GetParam());
45     use_function_ = std::get<2>(GetParam());
46     input_tensors_ = GetInputTensors();
47   }
48 
49   // Returns the following graph: output = input * [42, 137] + input
GetGraphDef(PartialTensorShape input_shape)50   GraphDef GetGraphDef(PartialTensorShape input_shape) {
51     Scope root = Scope::NewRootScope();
52     Output c;
53     c = ops::Const(root.WithOpName("my_const"), {{42.0f, 137.0f}});
54     Output v;
55     if (use_variable_) {
56       Output v_handle = ops::VarHandleOp(root.WithOpName("my_var"),
57                                          DataType::DT_FLOAT, {1, 2});
58       v = ops::ReadVariableOp(root.WithOpName("my_var/Read/ReadVariableOp"),
59                               v_handle, DataType::DT_FLOAT);
60       auto v_init =
61           ops::AssignVariableOp(root.WithOpName("my_var/init"), v_handle, c);
62     } else {
63       v = c;
64     }
65     const auto attrs = ops::Placeholder::Shape(input_shape);
66     auto x = ops::Placeholder(root.WithOpName("input"), DT_FLOAT, attrs);
67     auto y = ops::Mul(root.WithOpName("my_mul"), x, v);
68     auto z = ops::Add(root.WithOpName("my_add"), x, y);
69     auto q = ops::Identity(root.WithOpName("output"), z);
70 
71     GraphDef out;
72     TF_CHECK_OK(root.ToGraphDef(&out));
73     return out;
74   }
75 
GetGraphWithFunction(PartialTensorShape input_shape)76   GraphDef GetGraphWithFunction(PartialTensorShape input_shape) {
77     using ::tensorflow::test::function::GDef;
78     using ::tensorflow::test::function::NDef;
79     GraphConstructorOptions opts;
80     const Tensor kOne = test::AsScalar<float>(1.0f);
81     TensorShapeProto value_shape_proto;
82     kOne.shape().AsProto(&value_shape_proto);
83     TensorShapeProto input_shape_proto;
84     input_shape.AsProto(&input_shape_proto);
85     NodeDef value_node;
86     if (use_variable_) {
87       value_node =
88           NDef("my_value", "Identity", {"my_var:0"}, {{"T", DT_RESOURCE}});
89     } else {
90       value_node =
91           NDef("my_value", "Identity", {"my_const:0"}, {{"T", DT_FLOAT}});
92     }
93     GraphDef gdef = GDef(
94         {
95             NDef("input", "Placeholder", {},
96                  {{"dtype", DT_FLOAT}, {"shape", input_shape_proto}}),
97             NDef("my_const", "Const", {},
98                  {{"dtype", DT_FLOAT}, {"value", kOne}}),
99             value_node,
100             NDef("call", "StatefulPartitionedCall", {"input", "my_value"},
101                  {{"Tin", DataTypeSlice{DT_FLOAT, use_variable_ ? DT_RESOURCE
102                                                                 : DT_FLOAT}},
103                   {"Tout", DataTypeSlice{DT_FLOAT}},
104                   {"f", FunctionDefHelper::FunctionRef("f", {})}}),
105             NDef("output", "Identity", {"call:0"}, {{"T", DT_FLOAT}}),
106         },
107         {});
108     FunctionDef fdef;
109     if (use_variable_) {
110       gdef.add_node()->CopyFrom(
111           NDef("my_var", "VarHandleOp", {},
112                {{"dtype", DT_FLOAT}, {"shape", value_shape_proto}}));
113 
114       gdef.add_node()->CopyFrom(NDef("my_var/init", "AssignVariableOp",
115                                      {"my_var", "my_const"},
116                                      {{"dtype", DT_FLOAT}}));
117       gdef.add_node()->CopyFrom(NDef("my_var/Read/ReadVariableOp",
118                                      "ReadVariableOp", {"my_var"},
119                                      {{"dtype", DT_FLOAT}}));
120       // Define function f(x, v) = x * v + x, where v is a variable.
121       fdef = FunctionDefHelper::Define(
122           "f",                          // Name
123           {"x: float", "v: resource"},  // Args
124           {"q: float"},                 // Returns
125           {},                           // Attr def
126           // Nodes
127           {{{"my_var/Read/ReadVariableOp"},
128             "ReadVariableOp",
129             {"v"},
130             {{"dtype", DT_FLOAT}}},
131            {{"my_mul"},
132             "Mul",
133             {"x", "my_var/Read/ReadVariableOp"},
134             {{"T", DT_FLOAT}}},
135            {{"my_add"}, "AddV2", {"x", "my_mul"}, {{"T", DT_FLOAT}}},
136            {{"q"}, "Identity", {"my_add"}, {{"T", DT_FLOAT}}}});
137     } else {
138       // Define function f(x, v) = x * v + x, where v is const value.
139       fdef = FunctionDefHelper::Define(
140           "f",                       // Name
141           {"x: float", "v: float"},  // Args
142           {"q: float"},              // Returns
143           {},                        // Attr def
144           // Nodes
145           {{{"my_mul"}, "Mul", {"x", "v"}, {{"T", DT_FLOAT}}},
146            {{"my_add"}, "AddV2", {"x", "my_mul"}, {{"T", DT_FLOAT}}},
147            {{"q"}, "Identity", {"my_add"}, {{"T", DT_FLOAT}}}});
148     }
149     gdef.mutable_library()->add_function()->CopyFrom(fdef);
150 
151     return gdef;
152   }
153 
154   // Returns the following graph: output = input * [42, 137] + input
GetModel()155   MetaGraphDef GetModel() {
156     PartialTensorShape shape({-1, 2});
157     MetaGraphDef out;
158     if (use_function_) {
159       *(out.mutable_graph_def()) = GetGraphWithFunction(shape);
160     } else {
161       *(out.mutable_graph_def()) = GetGraphDef(shape);
162     }
163     VLOG(2) << out.graph_def().DebugString();
164     TensorShapeProto shape_proto;
165     shape.AsProto(&shape_proto);
166     SignatureDef signature_def;
167     (*signature_def.mutable_inputs())["input"].set_name("input:0");
168     (*signature_def.mutable_inputs())["input"].set_dtype(DT_FLOAT);
169     (*signature_def.mutable_inputs())["input"].mutable_tensor_shape()->CopyFrom(
170         shape_proto);
171     (*signature_def.mutable_outputs())["output"].set_name("output:0");
172     (*signature_def.mutable_outputs())["output"].set_dtype(DT_FLOAT);
173     (*signature_def.mutable_outputs())["output"]
174         .mutable_tensor_shape()
175         ->CopyFrom(shape_proto);
176     (*out.mutable_signature_def())["serving_default"] = signature_def;
177 
178     VLOG(2) << signature_def.DebugString();
179     return out;
180   }
181 
GetSavedModelBundle(SavedModelBundle * bundle)182   Status GetSavedModelBundle(SavedModelBundle* bundle) {
183     bundle->meta_graph_def = GetModel();
184     Session* session = nullptr;
185     TF_RETURN_IF_ERROR(NewSession(tensorflow::SessionOptions(), &session));
186     TF_RETURN_IF_ERROR(session->Create(bundle->meta_graph_def.graph_def()));
187     bundle->session.reset(session);
188     TF_RETURN_IF_ERROR(session->Run(/* inputs */ {}, /*outputs*/ {},
189                                     /*targets*/ {"my_var/init"}, nullptr));
190     return OkStatus();
191   }
192 
193   // Confirms that we have a TRT node with the correct attributes.
CheckTrtNode(const GraphDef & converted_graph_def)194   void CheckTrtNode(const GraphDef& converted_graph_def) {
195     int n_trt_ops = 0;
196     string op_name{"TRTEngineOp"};
197     for (const auto& node : converted_graph_def.node()) {
198       if (!op_name.compare(node.op())) {
199         n_trt_ops++;
200         const auto& attr = node.attr();
201         EXPECT_EQ(attr.at("static_engine").b(),
202                   param_.conv_params.convert_to_static_engine);
203         if (param_.conv_params.convert_to_static_engine) {
204           VLOG(2) << "Found serialized segment with size "
205                   << attr.at("serialized_segment").s().size();
206           EXPECT_GT(attr.at("serialized_segment").s().size(), 0);
207         }
208       }
209     }
210     EXPECT_EQ(n_trt_ops, 1);
211   }
212 
213   // Creates a list of input tensors, they will be used to build the engines.
GetInputTensors()214   std::vector<std::vector<Tensor>> GetInputTensors() {
215     std::vector<std::vector<Tensor>> input_tensors;
216     for (const std::vector<int64>& shape : param_.input_shapes) {
217       Tensor tensor(DT_FLOAT, TensorShape(shape));
218       test::FillIota(&tensor, 1.0f);
219       input_tensors.push_back({tensor});
220     }
221     return input_tensors;
222   }
223 
RunAndCompareResults(Session * session,const GraphDef & converted_graph_def)224   void RunAndCompareResults(Session* session,
225                             const GraphDef& converted_graph_def) {
226     // Create a session to execute the converted graph.
227     Session* p_session = nullptr;
228     TF_EXPECT_OK(NewSession(SessionOptions(), &p_session));
229     std::unique_ptr<tensorflow::Session> trt_session(p_session);
230     TF_EXPECT_OK(trt_session->Create(converted_graph_def));
231 
232     // Run models and compare the output.
233     for (const std::vector<Tensor>& input : input_tensors_) {
234       std::vector<Tensor> outputs;
235       TF_EXPECT_OK(
236           session->Run({{"input", input.at(0)}}, {"output"}, {}, &outputs));
237       std::cout << outputs.at(0).DebugString() << std::endl;
238 
239       std::vector<Tensor> trt_outputs;
240       TF_EXPECT_OK(trt_session->Run({{"input", input.at(0)}}, {"output"}, {},
241                                     &trt_outputs));
242       std::cout << trt_outputs.at(0).DebugString() << std::endl;
243       ASSERT_EQ(outputs.size(), 1);
244       ASSERT_EQ(trt_outputs.size(), 1);
245       tensorflow::test::ExpectEqual(outputs[0], trt_outputs[0]);
246     }
247   }
248 
ConvertAndRunFrozenGraph()249   void ConvertAndRunFrozenGraph() {
250     MetaGraphDef meta_graph_def = GetModel();
251 
252     StatusOr<GraphDef> result = tensorrt::ConvertAndBuild(
253         meta_graph_def.graph_def(), {"input"}, {"output"}, input_tensors_,
254         param_.conv_params);
255     TF_ASSERT_OK(result.status());
256     const GraphDef& converted_graph_def = result.ValueOrDie();
257     CheckTrtNode(converted_graph_def);
258 
259     // Create a session to execute the original graph.
260     Session* p_session = nullptr;
261     TF_EXPECT_OK(NewSession(SessionOptions(), &p_session));
262     std::unique_ptr<tensorflow::Session> session(p_session);
263     TF_EXPECT_OK(session->Create(meta_graph_def.graph_def()));
264 
265     RunAndCompareResults(session.get(), converted_graph_def);
266   }
267 
ConvertAndRunSavedModel()268   void ConvertAndRunSavedModel() {
269     SavedModelBundle bundle;
270     TF_CHECK_OK(GetSavedModelBundle(&bundle));
271 
272     StatusOr<GraphDef> result = tensorrt::ConvertAndBuild(
273         &bundle, "serving_default", input_tensors_, param_.conv_params);
274     TF_ASSERT_OK(result.status());
275     const GraphDef& converted_graph_def = result.ValueOrDie();
276     CheckTrtNode(converted_graph_def);
277 
278     RunAndCompareResults(bundle.GetSession(), converted_graph_def);
279   }
280 
281   TestParam param_;
282   bool use_variable_;
283   bool use_function_;
284   std::vector<std::vector<Tensor>> input_tensors_;
285 };
286 
287 INSTANTIATE_TEST_CASE_P(
288     TrtConverterTestInstantiation, TrtConverterTest,
289     ::testing::Combine(
290         ::testing::Values(
291             // Dynamic shape mode test with conver_to_static_engine=true.
292             TestParam{TfTrtConversionParams{
293                           1 << 20,  // max workspace size
294                           TrtPrecisionMode::FP32,
295                           3,      // minimum_segment_size
296                           1,      // max_cached_engines
297                           false,  // use_calibration
298                           true,   // use_dynamic_shape
299                           ProfileStrategy::kOptimal,
300                           true,  // allow_build_at_runtime
301                           true   // convert_to_static_engine
302                       },
303                       {{1, 2}, {4, 2}}},
304             // Implicit batch mode test with conver_to_static_engine=true.
305             TestParam{TfTrtConversionParams{
306                           1 << 20,  // max workspace size
307                           TrtPrecisionMode::FP16,
308                           3,      // minimum_segment_size
309                           1,      // max_cached_engines
310                           false,  // use_calibration
311                           false,  // use_dynamic_shape
312                           ProfileStrategy::kRange,
313                           true,  // allow_build_at_runtime
314                           true   // convert_to_static_engine
315                       },
316                       {{1, 2}}},
317             // Dynamic shape mode test convert_to_static_engine=false: we cannot
318             // save the engines, therefore we do not generate profiles. A single
319             // engine will be built during runtime, with profile that matches
320             // the first shape ({1,2}). The second shape will run as native
321             // segment.
322             TestParam{TfTrtConversionParams{
323                           1 << 20,  // max workspace size
324                           TrtPrecisionMode::FP32,
325                           3,      // minimum_segment_size
326                           1,      // max_cached_engines
327                           false,  // use_calibration
328                           true,   // use_dynamic_shape
329                           ProfileStrategy::kOptimal,
330                           true,  // allow_build_at_runtime
331                           false  // convert_to_static_engine
332                       },
333                       {{1, 2}, {4, 2}}},
334             // Implicit batch mode test with convert_to_static_engine=false.
335             // We will have two engines in the cache to handle the two shapes.
336             TestParam{TfTrtConversionParams{
337                           1 << 20,  // max workspace size
338                           TrtPrecisionMode::FP16,
339                           3,      // minimum_segment_size
340                           2,      // max_cached_engines
341                           false,  // use_calibration
342                           false,  // use_dynamic_shape
343                           ProfileStrategy::kRange,
344                           true,  // allow_build_at_runtime
345                           false  // convert_to_static_engine
346                       },
347                       {{1, 2}, {4, 2}}}),
348         ::testing::Values(false, true),    // use_variables
349         ::testing::Values(false, true)));  // use_function
350 
TEST_P(TrtConverterTest, Basic) {
  // Variables need a SavedModel session (to run their initializer); the
  // const-only variants exercise the frozen-graph conversion path instead.
  use_variable_ ? ConvertAndRunSavedModel() : ConvertAndRunFrozenGraph();
}
358 
359 }  // namespace tensorrt
360 }  // namespace tensorflow
361 
362 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
363