xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <functional>
21 #include <iterator>
22 #include <memory>
23 #include <numeric>
24 #include <type_traits>
25 #include <unordered_map>
26 #include <vector>
27 
28 #if GOOGLE_CUDA && GOOGLE_TENSORRT
29 
30 #include <gmock/gmock.h>
31 #include <gtest/gtest.h>
32 #include "absl/algorithm/container.h"
33 #include "absl/base/call_once.h"
34 #include "absl/container/inlined_vector.h"
35 #include "absl/strings/match.h"
36 #include "absl/strings/numbers.h"
37 #include "absl/strings/str_cat.h"
38 #include "absl/strings/str_format.h"
39 #include "absl/strings/string_view.h"
40 #include "absl/types/span.h"
41 #include "third_party/eigen3/Eigen/Core"
42 #include "third_party/gpus/cuda/include/cuda.h"
43 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
44 #include "tensorflow/cc/framework/ops.h"
45 #include "tensorflow/cc/framework/scope.h"
46 #include "tensorflow/cc/ops/nn_ops_internal.h"
47 #include "tensorflow/cc/ops/standard_ops.h"
48 #include "tensorflow/compiler/tf2tensorrt/common/datavec.h"
49 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
50 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
51 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
52 #include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
53 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
54 #include "tensorflow/compiler/tf2tensorrt/utils/trt_testutils.h"
55 #include "tensorflow/core/common_runtime/device_mgr.h"
56 #include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
57 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
58 #include "tensorflow/core/framework/allocator.h"
59 #include "tensorflow/core/framework/device_factory.h"
60 #include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
61 #include "tensorflow/core/framework/resource_var.h"
62 #include "tensorflow/core/framework/tensor.h"
63 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
64 #include "tensorflow/core/framework/tensor_shape.h"
65 #include "tensorflow/core/framework/tensor_testutil.h"
66 #include "tensorflow/core/framework/types.h"
67 #include "tensorflow/core/grappler/costs/graph_properties.h"
68 #include "tensorflow/core/kernels/variable_ops.h"
69 #include "tensorflow/core/lib/core/status.h"
70 #include "tensorflow/core/lib/core/status_test_util.h"
71 #include "tensorflow/core/lib/strings/str_util.h"
72 #include "tensorflow/core/lib/strings/strcat.h"
73 #include "tensorflow/core/platform/protobuf.h"
74 #include "tensorflow/core/platform/status_matchers.h"
75 #include "tensorflow/core/platform/test.h"
76 #include "tensorflow/core/platform/threadpool.h"
77 #include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
78 #include "tensorflow/core/public/session.h"
79 #include "tensorflow/core/public/version.h"
80 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
81 #include "third_party/tensorrt/NvInfer.h"
82 
83 namespace tensorflow {
84 namespace tensorrt {
85 
86 // TensorRT modes for testing. We define the following three modes:
87 // 1. Implicit batch mode: The tensors have static (known) input shape and the
88 //    the batch dimension (first dim) is removed from the TRT tensor shape. In
89 //    a loose notation: trt_shape = tf_shape[1:].
90 // 2. Explicit batch mode: static (known) input shape, but the batch dimension
91 //    is part of the trt tensor shape. (trt_shape = tf_shape)
92 // 3. Dynamic shape mode allows unknown input shapes, and requires explicit
93 //    batch size definition (trt_shape = tf_shape).
94 //
95 // Note that the Converter only distinguishes between two modes:
96 // - use_implicit_batch == true, this corresponds to kImplicitBatch,
97 // - use_implicit_batch == false which includes both kExplicitBatch and
98 //   kDynamicShape.
99 //
100 // For the converter, the distinction between explicit batch or dynamic shape
101 // mode follows from the input tensors of the network: dynamic shape input
102 // implies dynamic shape mode, while static shape input tensors imply explicit
103 // batch mode. We want to test all these modes, therefore we define the
104 // TrtTestMode with the following three options.
105 enum class TrtTestMode {
106   kImplicitBatch = 0,
107   kExplicitBatch = 1,
108   kDynamicShape = 2
109 };
110 
DebugString(const TrtTestMode mode)111 string DebugString(const TrtTestMode mode) {
112   switch (mode) {
113     case TrtTestMode::kImplicitBatch:
114       return "kImplicitBatch";
115     case TrtTestMode::kExplicitBatch:
116       return "kExplicitBatch";
117     case TrtTestMode::kDynamicShape:
118       return "kDynamicShape";
119     default:
120       return "Invalid TrtTestMode";
121   }
122 }
123 
124 namespace convert {
125 
126 using absl::StrCat;
127 using ::testing::ElementsAre;
128 using ::testing::ElementsAreArray;
129 using ::testing::HasSubstr;
130 using ::testing::Matcher;
131 using ::testing::PrintToString;
132 
133 using ::tensorflow::testing::IsOk;
134 using ::tensorflow::testing::StatusIs;
135 
136 constexpr std::array<TrtTestMode, 3> ValidTrtModes = {
137     TrtTestMode::kImplicitBatch, TrtTestMode::kExplicitBatch,
138     TrtTestMode::kDynamicShape};
139 
TrtShapedWeightsEquals(const TRT_ShapedWeights & lhs,const TRT_ShapedWeights & rhs)140 bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs,
141                             const TRT_ShapedWeights& rhs) {
142   return lhs.Shape() == rhs.Shape() && lhs.TrtDType() == rhs.TrtDType() &&
143          lhs.GetPointer<int8>() == rhs.GetPointer<int8>();
144 }
145 
146 template <typename T>
ValidateWeights(const TRT_ShapedWeights & weights,const std::vector<int> & expected_dims,const std::vector<T> & expected_value)147 void ValidateWeights(const TRT_ShapedWeights& weights,
148                      const std::vector<int>& expected_dims,
149                      const std::vector<T>& expected_value) {
150   EXPECT_EQ(weights.Shape(), DimsAdapter(expected_dims));
151   ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString();
152   const T* actual_values = weights.GetPointer<T>();
153   for (int i = 0; i < expected_value.size(); ++i) {
154     EXPECT_EQ(expected_value[i], actual_values[i]);
155   }
156 }
157 
158 // TRT >= 8.2 optimizes memory management in the builder. When all builders
159 // are destroyed, it unloads many resources. This test fixture will create and
160 // destroy hundreds of builders when run sequentially for parameterized
161 // tests. We can hold open an IBuilder in order to prevent TRT from unloading
162 // shared resources between engine builds when using TRT shared library. This
163 // greatly speeds up unit tests and is safe to do.
PreventUnloadBuilderResources()164 void PreventUnloadBuilderResources() {
165 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
166   static thread_local absl::once_flag once;
167   static TrtUniquePtrType<nvinfer1::IBuilder> hold_builder = nullptr;
168   absl::call_once(
169       once,
170       [](TrtUniquePtrType<nvinfer1::IBuilder>& builder) {
171         if (!builder) {
172           builder.reset(nvinfer1::createInferBuilder(*Logger::GetLogger()));
173         }
174       },
175       hold_builder);
176 #endif
177 }
178 
TEST(TRT_ShapedWeights_Test, Basic) {
  // Expectations shared by weights that own no buffer and report FP32.
  auto expect_empty_fp32 = [](TRT_ShapedWeights* w) {
    nvinfer1::Weights trt_weights = w->GetTrtWeights();
    EXPECT_EQ(nvinfer1::DataType::kFLOAT, trt_weights.type);
    EXPECT_EQ(nullptr, trt_weights.values);
    EXPECT_EQ(0, trt_weights.count);

    EXPECT_EQ(nullptr, w->GetPointer<int8>());
    EXPECT_EQ(0, w->count());
    EXPECT_EQ(0, w->size_bytes());
  };

  // Default constructor, and a copy of it.
  {
    TRT_ShapedWeights weights;
    TRT_ShapedWeights copy(weights);
    expect_empty_fp32(&weights);
    expect_empty_fp32(&copy);
  }

  // DataType-only constructor, and a copy of it.
  {
    TRT_ShapedWeights weights(nvinfer1::DataType::kFLOAT);
    TRT_ShapedWeights copy(weights);
    expect_empty_fp32(&weights);
    expect_empty_fp32(&copy);
  }

  // Weights with a DataType and dims, allocated by a TrtWeightStore.
  {
    TrtWeightStore store;
    TRT_ShapedWeights weights =
        store.GetTempWeights(nvinfer1::DataType::kFLOAT, CreateDims({2, 5}))
            .ValueOrDie();
    TRT_ShapedWeights copy(weights);
    for (TRT_ShapedWeights* w : {&weights, &copy}) {
      nvinfer1::Weights trt_weights = w->GetTrtWeights();
      EXPECT_EQ(nvinfer1::DataType::kFLOAT, trt_weights.type);
      EXPECT_NE(nullptr, trt_weights.values);
      EXPECT_EQ(10, trt_weights.count);  // 2 * 5 elements.

      EXPECT_EQ(trt_weights.values, w->GetPointer<int8>());
      EXPECT_EQ(10, w->count());
      EXPECT_EQ(40, w->size_bytes());  // 10 elements * 4 bytes.
    }
    // Copying must share, not duplicate, the underlying buffer.
    EXPECT_EQ(weights.GetPointer<int8>(), copy.GetPointer<int8>());
  }
}
231 
TEST(TRT_TensorOrWeights_Test,Basic)232 TEST(TRT_TensorOrWeights_Test, Basic) {
233   // Test constructor with no arguments.
234   {
235     TRT_TensorOrWeights tw;
236     TRT_TensorOrWeights copy(tw);
237     TRT_TensorOrWeights assigned;
238     assigned = tw;
239     for (auto ptr : {&tw, &copy, &assigned}) {
240       EXPECT_EQ(false, ptr->is_tensor());
241       EXPECT_EQ(false, ptr->is_weights());
242       EXPECT_EQ(-1, ptr->batch_size());
243     }
244   }
245 
246   // Test constructor with ITensor and batch size argument.
247   {
248     nvinfer1::Dims dims;
249     dims.nbDims = 1;
250     dims.d[0] = 1;
251     ITensorProxyPtr itensor(dims);
252     TRT_TensorOrWeights tw(itensor);
253     TRT_TensorOrWeights tw1(itensor, /*batch_size=*/1);
254 
255     for (auto original_ptr : {&tw, &tw1}) {
256       TRT_TensorOrWeights copy(*original_ptr);
257       TRT_TensorOrWeights assigned;
258       assigned = *original_ptr;
259 
260       for (auto ptr : {original_ptr, &copy, &assigned}) {
261         ASSERT_TRUE(ptr->is_tensor());
262         EXPECT_EQ(false, ptr->is_weights());
263         if (original_ptr == &tw) {
264           EXPECT_EQ(-1, ptr->batch_size());
265         } else {
266           EXPECT_EQ(1, ptr->batch_size());
267         }
268         EXPECT_EQ(itensor->simple_tensor(), ptr->tensor()->simple_tensor());
269         EXPECT_THAT(ptr->GetTrtDims(), DimsAreArray({1}));
270       }
271     }
272   }
273   // Test constructor which creates and owns an ITensor.
274   {
275     nvinfer1::Dims dims;
276     dims.nbDims = 1;
277     dims.d[0] = 1;
278     TRT_TensorOrWeights tw(nvinfer1::DataType::kFLOAT, dims, /*batch_size=*/1);
279     TRT_TensorOrWeights copy(tw);
280     TRT_TensorOrWeights assigned;
281     assigned = tw;
282 
283     for (auto ptr : {&tw, &copy, &assigned}) {
284       ASSERT_TRUE(ptr->is_tensor());
285       EXPECT_EQ(false, ptr->is_weights());
286       EXPECT_EQ(1, ptr->batch_size());
287       EXPECT_NE(nullptr, ptr->tensor()->simple_tensor());
288       EXPECT_THAT(ptr->GetTrtDims(), DimsAreArray({1}));
289     }
290   }
291   // Test constructor with TRT_ShapedWeights argument.
292   {
293     TRT_ShapedWeights weights;
294     TRT_TensorOrWeights tw(weights);
295     TRT_TensorOrWeights copy(tw);
296     TRT_TensorOrWeights assigned;
297     assigned = tw;
298     for (auto ptr : {&tw, &copy, &assigned}) {
299       EXPECT_EQ(false, ptr->is_tensor());
300       EXPECT_EQ(true, ptr->is_weights());
301       EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights()));
302       std::vector<int> empty_dims;
303       EXPECT_THAT(ptr->GetTrtDims(), DimsAreArray(empty_dims));
304     }
305   }
306 }
307 
308 class ValidatorTest : public ::testing::Test {
309  public:
ValidatorTest()310   ValidatorTest() { PreventUnloadBuilderResources(); }
ConvertToTensorOrWeights(const Scope & scope,const Node * node,int output_port,TRT_TensorOrWeights * tensor_or_weights)311   Status ConvertToTensorOrWeights(const Scope& scope, const Node* node,
312                                   int output_port,
313                                   TRT_TensorOrWeights* tensor_or_weights) {
314     grappler::GrapplerItem item;
315     TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
316     grappler::GraphProperties graph_properties(item);
317     TF_EXPECT_OK(graph_properties.InferStatically(true));
318 
319     TrtNodeValidator validator(graph_properties, TrtPrecisionMode::FP32,
320                                /*use_calibration=*/false,
321                                /*use_implicit_batch=*/true,
322                                /*use_explicit_precision=*/false);
323     return validator.ConvertToTensorOrWeights(node->def(), output_port,
324                                               tensor_or_weights);
325   }
326 };
327 
TEST_F(ValidatorTest, ConvertToTensorOrWeights) {
  // A Const node converts to weights carrying its shape and values.
  {
    Scope s = Scope::NewRootScope();
    auto const_node =
        ops::Const(s.WithOpName("my_const"), {1.0f, 2.0f}, TensorShape({2}));
    TRT_TensorOrWeights output;
    EXPECT_THAT(ConvertToTensorOrWeights(s, const_node.op().node(),
                                         /*output_port=*/0, &output),
                IsOk());
    ValidateWeights<float>(output.weights(), {2}, {1.0, 2.0});
  }

  // Builds `add = feed + feed` with `feed` a float placeholder of shape
  // `dims`, then converts output 0 of the Add node.
  auto convert_add_of_feed = [this](const std::vector<int64_t>& dims,
                                    TRT_TensorOrWeights* output) {
    Scope s = Scope::NewRootScope();
    auto feed = ops::Placeholder(
        s.WithOpName("feed"), DT_FLOAT,
        ops::Placeholder::Shape(PartialTensorShape{dims}));
    auto add = ops::Add(s.WithOpName("add"), feed, feed);
    return this->ConvertToTensorOrWeights(s, add.operation.node(),
                                          /*output_port=*/0, output);
  };

  // Non-Const input with rank > nvinfer1::Dims::MAX_DIMS + 1 is rejected.
  {
    TRT_TensorOrWeights output;
    const std::vector<int64_t> too_many_dims(nvinfer1::Dims::MAX_DIMS + 2, 1);
    EXPECT_THAT(convert_add_of_feed(too_many_dims, &output),
                StatusIs(error::OUT_OF_RANGE,
                         HasSubstr("Input tensor rank is greater than 9")));
  }

  // Non-Const scalar (rank < 1) input is rejected.
  {
    TRT_TensorOrWeights output;
    EXPECT_THAT(convert_add_of_feed({}, &output),
                StatusIs(error::INVALID_ARGUMENT,
                         HasSubstr("Scalar input tensor is not supported since "
                                   "the first dimension "
                                   "is treated as batch dimension by TRT")));
  }

  // Non-Const input becomes a tensor whose TRT dims exclude the batch
  // dimension. An unknown (-1) non-batch dimension must be accepted too.
  constexpr int32 kBatchSize = 12;
  for (const int32 non_batch_dim : {-1, 2}) {
    TRT_TensorOrWeights output;
    EXPECT_THAT(convert_add_of_feed({kBatchSize, non_batch_dim}, &output),
                IsOk());
    ASSERT_TRUE(output.is_tensor());
    EXPECT_EQ(kBatchSize, output.batch_size());
    EXPECT_NE(nullptr, output.tensor()->simple_tensor());
    EXPECT_THAT(output.GetTrtDims(), DimsAreArray({non_batch_dim}));
  }
}
383 
TEST_F(ValidatorTest, IsTensorRTCandidate_Basics) {
  // Graph under test: add = const + const.
  Scope s = Scope::NewRootScope();
  auto input =
      ops::Const(s.WithOpName("const"), {1.0f, 2.0f}, TensorShape({2}));
  auto add = ops::Add(s.WithOpName("add"), input, input);
  const Node* add_node = add.operation.node();

  grappler::GrapplerItem item;
  TF_EXPECT_OK(s.ToGraphDef(&item.graph));
  grappler::GraphProperties graph_properties(item);
  TF_EXPECT_OK(graph_properties.InferStatically(true));
  TrtNodeValidator validator(graph_properties, TrtPrecisionMode::FP32,
                             /*use_calibration=*/false,
                             /*use_implicit_batch=*/true,
                             /*use_explicit_precision=*/false);

  // Fake Add converter:
  // - it fails when `fail_conversion` is set;
  // - it records in `conversion_started` whether it was ever called outside
  //   validation-only mode.
  // Both flags are captured by reference. The fake is removed from the
  // registry at the end of this test.
  bool conversion_started = false;
  bool fail_conversion = false;
  auto fake_add_converter = [&conversion_started, &fail_conversion](
                                OpConverterParams* params) -> Status {
    if (fail_conversion) return errors::InvalidArgument("");
    if (!params->validation_only) conversion_started = true;
    return Status::OK();
  };

  // Stash the real Add converter, then unregister it.
  auto original_op_converter = GetOpConverterRegistry()->LookUp("Add");
  ASSERT_TRUE(original_op_converter.ok());
  GetOpConverterRegistry()->Clear("Add");

  // No converter registered for Add: it is not a TRT candidate.
  EXPECT_THAT(validator.IsTensorRTCandidate(add_node),
              StatusIs(error::UNIMPLEMENTED,
                       HasSubstr("Op type Add is not supported.")));

  // Fake converter registered: validation succeeds without a real conversion.
  GetOpConverterRegistry()->Register("Add", kDefaultConverterPriority + 1,
                                     fake_add_converter);
  TF_EXPECT_OK(validator.IsTensorRTCandidate(add_node));
  EXPECT_EQ(false, conversion_started);

  // An error returned by the converter surfaces from the validator.
  fail_conversion = true;
  EXPECT_THAT(validator.IsTensorRTCandidate(add_node),
              StatusIs(error::INVALID_ARGUMENT));

  // Restore the real Add converter.
  GetOpConverterRegistry()->Clear("Add");
  GetOpConverterRegistry()->Register("Add", kDefaultConverterPriority,
                                     *original_op_converter);
}
430 
TEST(TrtNodeValidator,IsTensorRTCandidate)431 TEST(TrtNodeValidator, IsTensorRTCandidate) {
432   // Create a graph containing both TRT-compatible and TRT-incompatible nodes
433   // and use it to test TrtNodeValidator::IsTensorRTCandidate().
434   const std::vector<int32> input_shape_array{2, 2};
435   TensorShape input_shape;
436   TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_shape_array, &input_shape));
437 
438   Scope s = Scope::NewRootScope();
439   ops::Placeholder::Attrs feed_attrs;
440   TF_EXPECT_OK(
441       TensorShapeUtils::MakeShape(input_shape_array, &feed_attrs.shape_));
442 
443   // Compatible input.
444   auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, feed_attrs);
445   auto const_1 = ops::Const(s.WithOpName("const_1"), 1.0f, input_shape);
446 
447   // Compatible MatMul.
448   auto matmul = ops::MatMul(s.WithOpName("matmul"), feed, const_1);
449 
450   // Incompatible MatMul.
451   ops::MatMul::Attrs matmul_attrs;
452   matmul_attrs.transpose_a_ = true;
453   auto incompatible_matmul = ops::MatMul(s.WithOpName("incompatible_matmul"),
454                                          feed, const_1, matmul_attrs);
455 
456   // Unsupported op.
457   auto unsupported_op = ops::Erfc(s.WithOpName("sin"), feed);
458 
459   // Incompatible input.
460   auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE);
461   auto const_2 = ops::Const(s.WithOpName("const_2"), 1.0, input_shape);
462   // Compatible op with incompatible input.
463   auto matmul_with_incompatible_input =
464       ops::MatMul(s.WithOpName("matmul_with_incompatible_input"),
465                   incompatible_feed, const_2);
466 
467   // Quantize ops.
468   auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f);
469   auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("quantize"), feed,
470                                                quantize_attrs);
471 
472   // Get GrapplerItem and GraphProperties.
473   grappler::GrapplerItem item;
474   TF_EXPECT_OK(s.ToGraphDef(&item.graph));
475   Tensor feed_tensor(DT_FLOAT, input_shape);
476   item.feed.push_back(std::make_pair("feed", feed_tensor));
477   grappler::GraphProperties graph_properties(item);
478   TF_EXPECT_OK(graph_properties.InferStatically(true));
479 
480   for (const TrtPrecisionMode precision_mode :
481        {TrtPrecisionMode::FP32, TrtPrecisionMode::INT8}) {
482     TrtNodeValidator validator(graph_properties, precision_mode,
483                                /*use_calibration=*/false,
484                                /*use_implicit_batch=*/true,
485                                /*use_explicit_precision=*/false);
486     TF_EXPECT_OK(validator.IsTensorRTCandidate(matmul.operation.node()));
487     EXPECT_THAT(
488         validator.IsTensorRTCandidate(incompatible_matmul.operation.node()),
489         StatusIs(error::INVALID_ARGUMENT,
490                  HasSubstr("MatMul with 2D tensors requires explicit batch "
491                            "mode, or that tensor A "
492                            "is not transposed and B is a constant tensor.")));
493     EXPECT_THAT(validator.IsTensorRTCandidate(unsupported_op.operation.node()),
494                 StatusIs(error::UNIMPLEMENTED,
495                          HasSubstr("Op type Erfc is not supported")));
496     EXPECT_THAT(validator.IsTensorRTCandidate(
497                     matmul_with_incompatible_input.operation.node()),
498                 StatusIs(error::INTERNAL,
499                          HasSubstr("Failed to convert at least one input to a "
500                                    "TRT_TensorOrWeights:")));
501     if (precision_mode == TrtPrecisionMode::INT8) {
502       TF_EXPECT_OK(validator.IsTensorRTCandidate(quantize.operation.node()));
503     } else {
504       EXPECT_THAT(
505           validator.IsTensorRTCandidate(quantize.operation.node()),
506           StatusIs(
507               error::UNIMPLEMENTED,
508               HasSubstr("Op type FakeQuantWithMinMaxArgs is not supported")));
509     }
510   }
511 }
512 
513 class ConverterTest : public ::testing::Test {
514  public:
ConverterTest()515   ConverterTest() {
516     PreventUnloadBuilderResources();
517     Reset();
518   }
519 
Reset()520   void Reset() {
521     GetOpConverterRegistry()->Clear("MyOp");
522     GetOpConverterRegistry()->Clear("DummyOp");
523     converter_ =
524         std::move(Converter::Create(TrtPrecisionMode::FP32,
525                                     /*use_calibration=*/false, &logger_,
526                                     /*use_implicit_batch=*/true,
527                                     /*engine_name=*/"TRTEngineOp_000_000",
528                                     /*use_explicit_precision=*/false)
529                       .ValueOrDie());
530     weight_store_ = &converter_->weight_store_;
531   }
532 
533   // TODO(cbate): These should be removed or changed to public per black-box
534   // testing principle.
535   // Below we expose private methods of Converter for testing.
MaybeUpdateBatchSize(int batch_size)536   Status MaybeUpdateBatchSize(int batch_size) {
537     return converter_->MaybeUpdateBatchSize(batch_size);
538   }
539 
AddTensorOrWeights(const string & name,TRT_TensorOrWeights input)540   Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input) {
541     return converter_->AddTensorOrWeights(name, input);
542   }
543 
GetTensorOrWeights(const string & name,TRT_TensorOrWeights * output)544   Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output) {
545     return converter_->GetTensorOrWeights(name, output);
546   }
547 
GetInputs(const NodeDef & node_def,std::vector<TRT_TensorOrWeights> * inputs) const548   Status GetInputs(const NodeDef& node_def,
549                    std::vector<TRT_TensorOrWeights>* inputs) const {
550     return converter_->GetInputs(node_def, inputs);
551   }
552 
GetWeightRange(const TRT_ShapedWeights & weights,float * out_min,float * out_max) const553   Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
554                         float* out_max) const {
555     return converter_->GetWeightRange(weights, out_min, out_max);
556   }
557 
batch_size() const558   int batch_size() const { return converter_->batch_size_; }
559 
quantization_ranges_proxy()560   std::unordered_map<ITensorProxyPtr*, float>& quantization_ranges_proxy() {
561     return converter_->quantization_ranges_proxy_;
562   }
563 
quantization_ranges()564   std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
565     return converter_->quantization_ranges_;
566   }
567 
568  private:
569   Logger& logger_ = *Logger::GetLogger();
570 
571  protected:
572   std::unique_ptr<Converter> converter_;
573   TrtWeightStore* weight_store_;
574 };
575 
TEST_F(ConverterTest, ConvertNode) {
  // Fake MyOp converter: emits two outputs whose first dimension is the
  // input's first dimension plus 1 and plus 2, respectively.
  ITensorProxyPtr output_tensors[2];
  auto op_converter = [&output_tensors](OpConverterParams* params) -> Status {
    nvinfer1::Dims dims = params->inputs[0].tensor()->getDimensions();
    for (ITensorProxyPtr& out : output_tensors) {
      ++dims.d[0];
      out->setDimensions(dims);
      params->outputs->push_back(TRT_TensorOrWeights(out));
    }
    return Status::OK();
  };
  NodeDef node_def = MakeNodeDef("my_op", "MyOp", {"my_input"});

  TF_ASSERT_OK(converter_->AddInputTensor(
      "my_input", nvinfer1::DataType::kFLOAT, CreateDims({123}), 1));

  // No converter registered for MyOp yet.
  EXPECT_THAT(
      converter_->ConvertNode(node_def),
      StatusIs(error::NOT_FOUND, HasSubstr("No converter for op MyOp")));

  // Register the fake converter and retry.
  GetOpConverterRegistry()->Register("MyOp", kDefaultConverterPriority,
                                     op_converter);
  TF_ASSERT_OK(converter_->ConvertNode(node_def));

  // Output 0 is registered as "my_op" and output 1 as "my_op:1"; each must be
  // the tensor the converter produced, with d[0] == 123 + (i + 1).
  const char* const kOutputNames[] = {"my_op", "my_op:1"};
  for (int i = 0; i < 2; ++i) {
    TRT_TensorOrWeights actual_output;
    TF_EXPECT_OK(GetTensorOrWeights(kOutputNames[i], &actual_output));
    EXPECT_EQ(output_tensors[i]->simple_tensor(),
              actual_output.tensor()->simple_tensor());
    EXPECT_EQ(124 + i, actual_output.tensor()->getDimensions().d[0]);
  }

  EXPECT_THAT(converter_->network(), LayerNamesNonEmpty());
}
616 
TEST_F(ConverterTest, AddAndGetInputs) {
  // Five input entries, of which the control input ("^...") is not a data
  // input, and "input" / "input:0" name the same tensor.
  NodeDef node_def;
  node_def.add_input("^control_input");
  node_def.add_input("input");
  node_def.add_input("input:0");
  node_def.add_input("input:1");
  node_def.add_input("weird_input:2:3:4:0");

  TF_EXPECT_OK(converter_->AddInputTensor("input", nvinfer1::DataType::kFLOAT,
                                          CreateDims({1}), 1));
  TF_EXPECT_OK(converter_->AddInputTensor("input:1", nvinfer1::DataType::kINT32,
                                          CreateDims({2, 3}), 1));
  TF_EXPECT_OK(converter_->AddInputTensor(
      "weird_input:2:3:4", nvinfer1::DataType::kHALF, CreateDims({5, 3}), 1));

  std::vector<TRT_TensorOrWeights> inputs;
  TF_ASSERT_OK(GetInputs(node_def, &inputs));

  // Fatal: inputs[0..3] are indexed below, which would be out of bounds if
  // fewer inputs were returned.
  ASSERT_EQ(4, inputs.size());
  EXPECT_EQ(inputs[0].tensor()->trt_tensor(), inputs[1].tensor()->trt_tensor());

  EXPECT_EQ(nvinfer1::DataType::kFLOAT, inputs[0].tensor()->getType());
  EXPECT_EQ(nvinfer1::DataType::kINT32, inputs[2].tensor()->getType());
  EXPECT_EQ(nvinfer1::DataType::kHALF, inputs[3].tensor()->getType());
  EXPECT_THAT(inputs[0].tensor()->getDimensions(), DimsAreArray({1}));
  EXPECT_THAT(inputs[2].tensor()->getDimensions(), DimsAreArray({2, 3}));
  EXPECT_THAT(inputs[3].tensor()->getDimensions(), DimsAreArray({5, 3}));

  EXPECT_THAT(converter_->network(), LayerNamesNonEmpty());
}
647 
TEST_F(ConverterTest, RenameAndMarkOutputTensors) {
  // Test that the tensors are actually named and marked as output after
  // Converter::RenameAndMarkOutputTensors() is called.

  // Register a custom converter that emits two transposed (shuffled) copies of
  // its input followed by an FP32 weights output. We use it to build a TRT
  // network whose outputs will later be marked.
  std::vector<ITensorProxyPtr> output_tensors;
  auto op_converter = [&output_tensors](OpConverterParams* params) -> Status {
    nvinfer1::Permutation perm;
    perm.order[0] = 1;
    perm.order[1] = 0;
    for (int i = 0; i < 2; ++i) {
      ITensorProxyPtr input_tensor = params->inputs[0].tensor();
      nvinfer1::IShuffleLayer* layer =
          params->converter->network()->addShuffle(*input_tensor->trt_tensor());
      layer->setFirstTranspose(perm);
      ITensorProxyPtr output_tensor = layer->getOutput(0);
      params->outputs->emplace_back(output_tensor);
      output_tensors.push_back(output_tensor);
    }
    TRT_ShapedWeights output_weights(nvinfer1::DataType::kFLOAT);
    params->outputs->emplace_back(output_weights);
    return Status::OK();
  };
  GetOpConverterRegistry()->Register("MyOp", kDefaultConverterPriority,
                                     op_converter);

  // Run the conversion. Stop on failure: the checks below rely on the
  // tensors the converter records in output_tensors.
  NodeDef node_def = MakeNodeDef("my_op", "MyOp", {"my_input"});
  TF_ASSERT_OK(converter_->AddInputTensor(
      "my_input", nvinfer1::DataType::kFLOAT, CreateDims({1, 2}), 1));
  TF_ASSERT_OK(converter_->ConvertNode(node_def));

  // Mark a weight as output, should fail.
  EXPECT_THAT(
      converter_->RenameAndMarkOutputTensors({{"my_op:2", "my_output"}}),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Output my_op:2 is weights not tensor")));

  // Mark tensors as output, should pass.
  TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(
      {{"my_op", "my_output"}, {"my_op:1", "my_output_1"}}));
  // Fatal: output_tensors[0] and [1] are dereferenced below.
  ASSERT_EQ(2, output_tensors.size());
  for (auto& output_tensor : output_tensors) {
    EXPECT_THAT(output_tensor->getDimensions(), DimsAreArray({2, 1}));
  }
  EXPECT_EQ("my_output", string(output_tensors[0]->getName()));
  EXPECT_EQ("my_output_1", string(output_tensors[1]->getName()));

  EXPECT_THAT(converter_->network(), LayerNamesNonEmpty());
}
699 
TEST_F(ConverterTest, TransposeTensor) {
  // TRT input of shape (2, 3, 5). The fixture's converter uses implicit batch,
  // so the permutations below include the batch dimension as axis 0.
  ITensorProxyPtr input = converter_->network()->addInput(
      "", nvinfer1::DataType::kFLOAT, CreateDims({2, 3, 5}));
  ITensorProxyPtr transposed = nullptr;
  NodeDef node_def = MakeNodeDef("dummy_op", "DummyOp", {});

  // A permutation whose rank differs from the input's is rejected.
  EXPECT_THAT(converter_->TransposeTensor(input, {0, 1}, &transposed, node_def,
                                          "sub1"),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Rank of perm for transpose does not match "
                                 "with that of the input")));

  // Moving the batch dimension is not supported.
  EXPECT_THAT(
      converter_->TransposeTensor(input, {1, 0, 2, 3}, &transposed, node_def,
                                  "sub2"),
      StatusIs(error::UNIMPLEMENTED,
               HasSubstr("Transpose at batch dimension is not supported.")));

  // Valid permutation: the last non-batch axis moves to the front.
  TF_EXPECT_OK(converter_->TransposeTensor(input, {0, 3, 1, 2}, &transposed,
                                           node_def, "sub3"));
  EXPECT_THAT(transposed->getDimensions(), DimsAreArray({5, 2, 3}));
  // Only the successful call left a layer in the network.
  EXPECT_THAT(
      converter_->network(),
      LayerNamesAreArray({"TRTEngineOp_000_000/dummy_op-sub3:SHUFFLE"}));
}
727 
TestPrepareTensorForShape(const std::vector<int> & input_dims,const std::vector<int> & reshape_dims,const std::vector<int> & expected_tensor_dims,bool input_is_tensor,Converter * converter,TrtWeightStore * weight_store,error::Code expected_code=error::OK,const char * expected_error_msg_substr=nullptr)728 void TestPrepareTensorForShape(
729     const std::vector<int>& input_dims, const std::vector<int>& reshape_dims,
730     const std::vector<int>& expected_tensor_dims, bool input_is_tensor,
731     Converter* converter, TrtWeightStore* weight_store,
732     error::Code expected_code = error::OK,
733     const char* expected_error_msg_substr = nullptr) {
734   TRT_TensorOrWeights input;
735   if (input_is_tensor) {
736     input = TRT_TensorOrWeights(converter->network()->addInput(
737         "", nvinfer1::DataType::kFLOAT, CreateDims(input_dims)));
738   } else {
739     input = TRT_TensorOrWeights(
740         weight_store
741             ->GetTempWeights(nvinfer1::DataType::kFLOAT, CreateDims(input_dims))
742             .ValueOrDie());
743   }
744   ITensorProxyPtr output_tensor = nullptr;
745 
746   NodeDef dummy_node_def = MakeNodeDef("dummy_op", "DummyOp", {});
747   for (bool validation_only : {false, true}) {
748     const Status status =
749         PrepareTensorForShape(converter, input, DimsAdapter(reshape_dims),
750                               validation_only, &output_tensor, dummy_node_def);
751     if (expected_code == error::OK) {
752       TF_EXPECT_OK(status);
753       if (validation_only) {
754         EXPECT_EQ(nullptr, *output_tensor);
755       } else {
756         EXPECT_THAT(output_tensor->getDimensions(),
757                     DimsAreArray(expected_tensor_dims));
758       }
759     } else {
760       EXPECT_THAT(status, StatusIs(expected_code,
761                                    HasSubstr(expected_error_msg_substr)));
762     }
763   }
764 }
765 
// Exercises PrepareTensorForShape() with tensor and weight inputs: mismatched
// sizes, plain reshapes, rank-0 shapes, and shapes with an inferred (-1)
// dimension.
TEST_F(ConverterTest, PrepareTensorForShape) {
  for (bool input_is_tensor : {true, false}) {
    // Shape size doesn't match.
    Reset();
    TestPrepareTensorForShape({2, 3, 5}, {2, 3, 6}, {}, input_is_tensor,
                              converter_.get(), weight_store_,
                              error::INVALID_ARGUMENT, "Incompatible shapes");

    // Regular shape.
    Reset();
    TestPrepareTensorForShape({2, 3, 5}, {10, 3}, {10, 3}, input_is_tensor,
                              converter_.get(), weight_store_);

    // Reshape to zero rank.
    Reset();
    TestPrepareTensorForShape({1, 1}, {}, {}, input_is_tensor, converter_.get(),
                              weight_store_);
  }

  // Tensor input with zero rank.
  Reset();
  TestPrepareTensorForShape({}, {1, 1}, {1, 1}, /*input_is_tensor=*/true,
                            converter_.get(), weight_store_);

  // TODO(aaroey): we should check the case where uninferred dimensions are
  // not an exact divisor of input dimensions, e.g. for dims {-1, 7}.

  // Infer tensor shape, ok.
  Reset();
  TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
                            /*input_is_tensor=*/true, converter_.get(),
                            weight_store_);

  // Infer weight shape, should fail. (The expected dims argument is not
  // checked on the error path.)
  Reset();
  TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
                            /*input_is_tensor=*/false, converter_.get(),
                            weight_store_, error::INVALID_ARGUMENT,
                            "Shape is not fully defined");

  EXPECT_THAT(converter_->network(), LayerNamesNonEmpty());
}
808 
// Checks the converter's batch-size bookkeeping:
//  - it starts at -1, and passing -1 never changes it;
//  - the first positive value is adopted;
//  - repeating that value, or passing -1 again, is accepted and keeps it;
//  - a different positive value is rejected with INVALID_ARGUMENT.
TEST_F(ConverterTest,MaybeUpdateBatchSize)809 TEST_F(ConverterTest, MaybeUpdateBatchSize) {
810   EXPECT_EQ(-1, batch_size());
811 
812   TF_EXPECT_OK(MaybeUpdateBatchSize(-1));
813   EXPECT_EQ(-1, batch_size());
814 
815   TF_EXPECT_OK(MaybeUpdateBatchSize(123));
816   EXPECT_EQ(123, batch_size());
817 
818   TF_EXPECT_OK(MaybeUpdateBatchSize(123));
819   EXPECT_EQ(123, batch_size());
820 
821   TF_EXPECT_OK(MaybeUpdateBatchSize(-1));
822   EXPECT_EQ(123, batch_size());
823 
824   EXPECT_THAT(
825       MaybeUpdateBatchSize(124),
826       StatusIs(error::INVALID_ARGUMENT,
827                HasSubstr(
828                    "Provided batch size does not match converter batch size")));
829 }
830 
// Registers a tensor under "my_tensor" and reads it back. The registered
// tensor reports batch size -1, but the looked-up entry is expected to report
// the converter's batch size (123). Registering the same name a second time
// must fail with ALREADY_EXISTS.
TEST_F(ConverterTest,AddAndGetTensorOrWeights)831 TEST_F(ConverterTest, AddAndGetTensorOrWeights) {
832   // Add a tensor.
833   ITensorProxyPtr simple_tensor;
834   TRT_TensorOrWeights tensor(simple_tensor);
835   EXPECT_EQ(-1, tensor.batch_size());
836   TF_EXPECT_OK(MaybeUpdateBatchSize(123));
837   TF_EXPECT_OK(AddTensorOrWeights("my_tensor", tensor));
838 
839   // Get the added tensor.
840   TRT_TensorOrWeights added_tensor;
841   TF_EXPECT_OK(GetTensorOrWeights("my_tensor", &added_tensor));
842   EXPECT_EQ(123, added_tensor.batch_size());
843 
844   // Add the same tensor again.
845   EXPECT_THAT(AddTensorOrWeights("my_tensor", tensor),
846               StatusIs(error::ALREADY_EXISTS,
847                        HasSubstr("tensor/weights my_tensor already exist")));
848 }
849 
850 template <typename T>
TestGetWeightRange(ConverterTest * test,TrtWeightStore * weight_store)851 void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
852   nvinfer1::DataType trt_type;
853   TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &trt_type));
854   TRT_ShapedWeights weights =
855       weight_store->GetTempWeights(trt_type, CreateDims({2, 3})).ValueOrDie();
856   const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
857   absl::c_copy(values, weights.GetPointer<T>());
858   float out_min = 0.0f;
859   float out_max = 0.0f;
860   TF_EXPECT_OK(test->GetWeightRange(weights, &out_min, &out_max));
861   EXPECT_EQ(1.0f, out_min);
862   EXPECT_EQ(6.0f, out_max);
863 }
864 
// Runs the TestGetWeightRange check for float, half and int32 weights.
TEST_F(ConverterTest,GetWeightRange)865 TEST_F(ConverterTest, GetWeightRange) {
866   TestGetWeightRange<float>(this, weight_store_);
867   TestGetWeightRange<Eigen::half>(this, weight_store_);
868   TestGetWeightRange<int32>(this, weight_store_);
869 }
870 
// Checks that ProvideQuantizationRange() records, for the given tensor,
// max(|min|, |max|) of the most recent call. The recorded value is looked up
// by the address of the ITensorProxyPtr.
TEST_F(ConverterTest,ProvideQuantizationRange)871 TEST_F(ConverterTest, ProvideQuantizationRange) {
872   ITensorProxyPtr simple_tensor;
873   // Asymmetric range
874   converter_->ProvideQuantizationRange(&simple_tensor, 0.0f, 6.0f);
875   EXPECT_EQ(6.0f, quantization_ranges_proxy()[&simple_tensor]);
876   converter_->ProvideQuantizationRange(&simple_tensor, 1.0f, 6.0f);
877   EXPECT_EQ(6.0f, quantization_ranges_proxy()[&simple_tensor]);
878   converter_->ProvideQuantizationRange(&simple_tensor, -8.0f, 6.0f);
879   EXPECT_EQ(8.0f, quantization_ranges_proxy()[&simple_tensor]);
880   converter_->ProvideQuantizationRange(&simple_tensor, -8.123f, -6.123f);
881   EXPECT_EQ(8.123f, quantization_ranges_proxy()[&simple_tensor]);
882   // Symmetric range
// A later, smaller range replaces the earlier 8.123 value.
883   converter_->ProvideQuantizationRange(&simple_tensor, -6.123f, 6.123f);
884   EXPECT_EQ(6.123f, quantization_ranges_proxy()[&simple_tensor]);
885 
886   EXPECT_THAT(converter_->network(), LayerNamesNonEmpty());
887 }
888 
// Creates an INT8 converter with calibration enabled and records ranges for
// two tensors. After MaybeApplyQuantizationRanges(), each tensor's TensorRT
// dynamic range max is expected to equal the recorded range.
TEST_F(ConverterTest,MaybeApplyQuantizationRanges)889 TEST_F(ConverterTest, MaybeApplyQuantizationRanges) {
890   ITensorProxyPtr input;
891   ITensorProxyPtr not_infer;
892   Logger& logger = *Logger::GetLogger();
893   auto int8_converter = Converter::Create(TrtPrecisionMode::INT8,
894                                           /*use_calibration=*/true, &logger,
895                                           /*use_implicit_batch=*/true,
896                                           /*engine_name=*/"")
897                             .ValueOrDie();
898   int8_converter->ProvideQuantizationRange(&input, -5.0f, 5.0f);
899   int8_converter->ProvideQuantizationRange(&not_infer, -100.0f, 100.0f);
900 
901   int8_converter->MaybeApplyQuantizationRanges();
902   EXPECT_EQ(input->getDynamicRangeMax(), 5.0f);
903   EXPECT_EQ(not_infer->getDynamicRangeMax(), 100.0f);
904 
905   EXPECT_THAT(int8_converter->network(), LayerNamesNonEmpty());
906 }
907 
TEST_F(ConverterTest,GetTrtBroadcastShape)908 TEST_F(ConverterTest, GetTrtBroadcastShape) {
909   const bool kIsTensor = true;
910   const bool kIsNotTensor = false;
911   auto symmetric_test = [this](const std::vector<int>& operand_1_shape,
912                                const std::vector<int>& operand_2_shape,
913                                const bool operand_1_is_tensor,
914                                const bool operand_2_is_tensor,
915                                const std::vector<int>& expected_operand_1_shape,
916                                const std::vector<int>& expected_operand_2_shape,
917                                error::Code expected_code = error::OK,
918                                const char* expected_error_msg_substr = "",
919                                const int operand_1_batch_size = -1,
920                                const int operand_2_batch_size = -1) {
921     auto create_tensor_or_weights = [](const std::vector<int>& shape,
922                                        bool is_tensor, int batch_size = -1) {
923       if (is_tensor) {
924         return TRT_TensorOrWeights(nvinfer1::DataType::kFLOAT,
925                                    CreateDims(shape), batch_size);
926       }
927       TRT_ShapedWeights weights;
928       weights.Shape() = CreateDims(shape);
929       return TRT_TensorOrWeights(weights);
930     };
931 
932     nvinfer1::Dims operand_1_new_dims, operand_2_new_dims;
933     TRT_TensorOrWeights operand_1 = create_tensor_or_weights(
934         operand_1_shape, operand_1_is_tensor, operand_1_batch_size);
935     TRT_TensorOrWeights operand_2 = create_tensor_or_weights(
936         operand_2_shape, operand_2_is_tensor, operand_2_batch_size);
937 
938     // operand_1 broadcast operand_2
939     EXPECT_THAT(
940         GetTrtBroadcastShape(operand_1, operand_2, /*check_feasibility=*/true,
941                              /*use_implicit_batch=*/true, &operand_1_new_dims,
942                              &operand_2_new_dims),
943         StatusIs(expected_code, HasSubstr(expected_error_msg_substr)));
944     if (expected_code == error::OK) {
945       EXPECT_THAT(operand_1_new_dims, DimsAreArray(expected_operand_1_shape));
946       EXPECT_THAT(operand_2_new_dims, DimsAreArray(expected_operand_2_shape));
947     }
948     // operand_2 broadcast operand_1
949     EXPECT_THAT(
950         GetTrtBroadcastShape(operand_2, operand_1, /*check_feasibility=*/true,
951                              /*use_implicit_batch=*/true, &operand_2_new_dims,
952                              &operand_1_new_dims),
953         StatusIs(expected_code, HasSubstr(expected_error_msg_substr)));
954     if (expected_code == error::OK) {
955       EXPECT_THAT(operand_1_new_dims, DimsAreArray(expected_operand_1_shape));
956       EXPECT_THAT(operand_2_new_dims, DimsAreArray(expected_operand_2_shape));
957     }
958   };
959 
960   // Both inputs are weights.
961   symmetric_test(
962       {1}, {1}, kIsNotTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT,
963       "Broadcasting requires at least one of the operands be tensors");
964 
965   // One tensor and one weights.
966   symmetric_test({1, 1, 1}, {2}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 1, 2});
967   symmetric_test({1, 1, 2}, {2}, kIsTensor, kIsNotTensor, {1, 1, 2}, {1, 1, 2});
968   symmetric_test({1, 3, 2}, {1}, kIsTensor, kIsNotTensor, {1, 3, 2}, {1, 1, 1});
969   symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsNotTensor, {1, 1, 1},
970                  {1, 2, 3});
971   symmetric_test({1, 1, 1}, {2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1},
972                  {2, 3, 4});
973   symmetric_test({1, 1, 1}, {1, 2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1},
974                  {2, 3, 4});
975   symmetric_test({1, 3, 4}, {1, 2, 1, 4}, kIsTensor, kIsNotTensor, {1, 3, 4},
976                  {2, 1, 4});
977   symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
978                  error::INVALID_ARGUMENT, "Infeasible broadcast scheme");
979   symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
980                  error::INVALID_ARGUMENT, "Infeasible broadcast scheme",
981                  /*operand_1_batch_size=*/2);
982   symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
983                  error::INVALID_ARGUMENT,
984                  "Broadcasting beyond batch dimension is not supported "
985                  "(tensor #dims 4 vs broadcast #dims 5)");
986   symmetric_test({3}, {1, 1, 3}, kIsTensor, kIsNotTensor, {}, {},
987                  error::INVALID_ARGUMENT,
988                  "Broadcasting beyond batch dimension is not supported "
989                  "(tensor #dims 2 vs broadcast #dims 3)",
990                  /*operand_1_batch_size=*/2);
991 
992   // Both inputs are tensors.
993   symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {},
994                  error::INVALID_ARGUMENT,
995                  "Broadcasting beyond batch dimension is not supported "
996                  "(tensor #dims 3 vs broadcast #dims 4)");
997   symmetric_test({1, 3}, {3}, kIsTensor, kIsTensor, {}, {},
998                  error::INVALID_ARGUMENT,
999                  "Broadcasting beyond batch dimension is not supported "
1000                  "(tensor #dims 2 vs broadcast #dims 3)");
1001   symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4},
1002                  {2, 1, 4});
1003   symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {},
1004                  error::INVALID_ARGUMENT,
1005                  "Broadcasting beyond batch dimension is not supported "
1006                  "(tensor #dims 4 vs broadcast #dims 5)");
1007   symmetric_test({2, 3}, {7, 5}, kIsTensor, kIsTensor, {}, {},
1008                  error::INVALID_ARGUMENT, "Infeasible broadcast scheme");
1009 
1010   EXPECT_THAT(converter_->network(), LayerNamesNonEmpty());
1011 }
1012 
// For FP32 and INT32 weights of shape {2, 3, 5} (30 elements), checks that
// CreateConstantLayer() returns a tensor backed by a real TensorRT tensor.
// The tensor must have the weights' data type and the requested {3, 10} shape,
// which has the same element count.
TEST_F(ConverterTest,CreateConstantLayer)1013 TEST_F(ConverterTest, CreateConstantLayer) {
1014   for (auto dtype : {nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT32}) {
1015     TRT_ShapedWeights weights =
1016         weight_store_->GetTempWeights(dtype, CreateDims({2, 3, 5}))
1017             .ValueOrDie();
1018     ITensorProxyPtr tensor =
1019         converter_->CreateConstantLayer(weights, CreateDims({3, 10}));
1020     ASSERT_NE(nullptr, tensor->trt_tensor());
1021     EXPECT_EQ(dtype, tensor->getType())
1022         << "Expected " << DebugString(dtype) << " vs. actual "
1023         << DebugString(tensor->getType());
1024     EXPECT_THAT(tensor->getDimensions(), DimsAreArray({3, 10}));
1025   }
1026 
1027   EXPECT_THAT(converter_->network(), LayerNamesNonEmpty());
1028 }
1029 
1030 class ConvertGraphDefToEngineTest : public ::testing::Test {
1031  public:
RunConvertGraphDefToEngine(Scope * s)1032   Status RunConvertGraphDefToEngine(Scope* s) {
1033     GraphDef gdef;
1034     TF_EXPECT_OK(s->ToGraphDef(&gdef));
1035     std::vector<PartialTensorShape> input_shapes;
1036     int batch_size = -1;
1037     for (const NodeDef& node : gdef.node()) {
1038       absl::string_view node_name(node.name());
1039       if (absl::ConsumePrefix(&node_name, IONamePrefixes::kInputPHName)) {
1040         int port = -1;
1041         EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
1042         if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
1043         input_shapes[port] =
1044             PartialTensorShape(node.attr().at("shape").shape());
1045         if (batch_size == -1) {
1046           batch_size = input_shapes[port].dim_size(0);
1047         } else {
1048           EXPECT_EQ(batch_size, input_shapes[port].dim_size(0));
1049         }
1050       }
1051     }
1052     // TODO(laigd): execute the engine and get outputs.
1053     return ConvertGraphDefToEngine(
1054         gdef, /*ctx=*/nullptr, TrtPrecisionMode::FP32, /*max_batch_size=*/1,
1055         /*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_,
1056         /*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_,
1057         /*use_calibration=*/false, /*use_implicit_batch=*/true,
1058         /*convert_successfully=*/nullptr, /*profiles=*/nullptr,
1059         "TRTEngineOp_000_000", /*use_explicit_precision=*/false);
1060   }
1061 
1062  protected:
1063   TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
1064 
1065  private:
1066   Logger& logger_ = *Logger::GetLogger();
1067 };
1068 
// Builds a graph that is a chain of identities from the engine input
// placeholder (port 0, shape {1, 1}): identity1 -> identity2 -> the output
// node. Expects the graph to convert into an engine successfully.
TEST_F(ConvertGraphDefToEngineTest,IdentityGraph)1069 TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) {
1070   Scope s = Scope::NewRootScope();
1071   auto input =
1072       ops::Placeholder(s.WithOpName(StrCat(IONamePrefixes::kInputPHName, 0)),
1073                        DT_FLOAT, ops::Placeholder::Shape({1, 1}));
1074   auto output = ops::Identity(s.WithOpName("identity1"), input);
1075   output = ops::Identity(s.WithOpName("identity2"), output);
1076   output = ops::Identity(s.WithOpName(StrCat(IONamePrefixes::kOutputPHName, 0)),
1077                          output);
1078   // If the converter marks the input tensor as output tensor, the conversion
1079   // below will fail with:
1080   // > TensorRTOutputPH_0 cannot be both input and output
1081   // > Network must have at least one output
1082   TF_EXPECT_OK(RunConvertGraphDefToEngine(&s));
1083 }
1084 
1085 // Returns a vector of shapes from a vector of input tensors. This can be used
1086 // to create optimization profiles.
GetShapeFromDataVec(DataVec input_data,std::vector<TensorShape> * shape_vec)1087 Status GetShapeFromDataVec(DataVec input_data,
1088                            std::vector<TensorShape>* shape_vec) {
1089   shape_vec->reserve(input_data.size());
1090   std::transform(input_data.begin(), input_data.end(),
1091                  std::back_inserter(*shape_vec),
1092                  [](InputOutputData x) { return x.tensor.shape(); });
1093   return Status::OK();
1094 }
1095 
1096 template <typename T>
GetSpanForData(const InputOutputData & data)1097 inline absl::Span<const T> GetSpanForData(const InputOutputData& data) {
1098   const auto& tensor_map = data.tensor.flat<T>();
1099   return absl::Span<const T>(tensor_map.data(), tensor_map.size());
1100 }
1101 
GetDataAsFloat(InputOutputData & data)1102 std::vector<float> GetDataAsFloat(InputOutputData& data) {
1103   const auto dType = data.tensor.dtype();
1104   if (dType == DT_FLOAT) {
1105     auto span = GetSpanForData<float>(data);
1106     return std::vector<float>(span.begin(), span.end());
1107   }
1108   if (dType == DT_HALF) {
1109     return CastVector<Eigen::half, float>(GetSpanForData<Eigen::half>(data));
1110   }
1111   if (dType == DT_INT32) {
1112     return CastVector<int32, float>(GetSpanForData<int32>(data));
1113   }
1114 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
1115   if (dType == DT_BOOL) {
1116     return CastVector<bool, float>(GetSpanForData<bool>(data));
1117   }
1118 #endif
1119   LOG(FATAL) << "DataType not supported for testing " << DataTypeString(dType);
1120   return {};
1121 }
1122 
1123 // Class to test various op converters, using both a TrtNodeValidator and
1124 // Converter.
1125 class OpConverterTest : public ::testing::Test {
1126  public:
// Sets up the fixture. It creates the allocator that AsTensor() uses for
// Unified Memory tensors, creates the CUDA stream that BuildAndRun() enqueues
// engines on, and builds a fresh converter via Reset(). The order of the
// calls below is intentional: Reset() runs last.
OpConverterTest()1127   OpConverterTest()
1128       : tensor_buffer_allocator_(new GpuManagedAllocator()),
1129         scope_(Scope::NewRootScope()) {
1130     PreventUnloadBuilderResources();
// A failed QCHECK is fatal, so the fixture never runs without a stream.
1131     QCHECK_EQ(0, cudaStreamCreate(&stream_));
1132     Reset();
1133   }
1134 
// Destroys the CUDA stream created in the constructor. A failure is fatal.
~OpConverterTest()1135   ~OpConverterTest() noexcept override {
1136     QCHECK_EQ(0, cudaStreamDestroy(stream_));
1137   }
1138 
GetTensorOrWeights(const string & name,TRT_TensorOrWeights * output)1139   Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output) {
1140     return converter_->GetTensorOrWeights(name, output);
1141   }
1142 
Reset(TrtPrecisionMode precision_mode_to_test=TrtPrecisionMode::FP32,TrtTestMode trt_mode=TrtTestMode::kImplicitBatch,OpKernelContext * ctx=nullptr)1143   void Reset(TrtPrecisionMode precision_mode_to_test = TrtPrecisionMode::FP32,
1144              TrtTestMode trt_mode = TrtTestMode::kImplicitBatch,
1145              OpKernelContext* ctx = nullptr) {
1146     // Destroy existing TRT objects in a proper order.
1147     converter_.reset(nullptr);
1148     engine_.reset(nullptr);
1149 
1150     // Re-create them in proper order.
1151     converter_ =
1152         std::move(Converter::Create(precision_mode_to_test,
1153                                     /*use_calibration=*/false, &logger_,
1154                                     /*use_implicit_batch=*/trt_mode ==
1155                                         TrtTestMode::kImplicitBatch,
1156                                     /*engine_name=*/"",
1157                                     /*use_explicit_precision=*/false, ctx)
1158                       .ValueOrDie());
1159 
1160     // Reset other related artifacts.
1161     scope_ = Scope::NewRootScope();
1162   }
1163 
1164   // Constructs a flat tensor with 'vals' in Unified Memory.
1165   template <typename T>
AsTensor(gtl::ArraySlice<T> vals)1166   Tensor AsTensor(gtl::ArraySlice<T> vals) {  // non-absl ok
1167     Tensor ret(tensor_buffer_allocator_.get(), DataTypeToEnum<T>::value,
1168                {static_cast<int64_t>(vals.size())});
1169     std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
1170     return ret;
1171   }
1172 
1173   // Constructs a tensor of "shape" with values "vals" in Unified Memory.
1174   template <typename T>
AsTensor(gtl::ArraySlice<T> vals,const TensorShape & shape)1175   Tensor AsTensor(gtl::ArraySlice<T> vals,  // non-absl ok
1176                   const TensorShape& shape) {
1177     Tensor ret(tensor_buffer_allocator_.get(), DataTypeToEnum<T>::value,
1178                {static_cast<int64_t>(vals.size())});
1179     CHECK(ret.CopyFrom(AsTensor(vals), shape));
1180     return ret;
1181   }
1182 
1183   template <typename T, typename S>
transformTensor(const std::vector<T> & vals,Tensor & ret)1184   void transformTensor(const std::vector<T>& vals, Tensor& ret) {
1185     std::transform(vals.begin(), vals.end(), ret.flat<S>().data(),
1186                    [](const T in_val) -> S { return static_cast<S>(in_val); });
1187   }
1188 
1189   template <typename T, typename S>
transformWeights(const std::vector<T> & vals,TRT_ShapedWeights & weights)1190   void transformWeights(const std::vector<T>& vals,
1191                         TRT_ShapedWeights& weights) {
1192     std::transform(vals.begin(), vals.end(), weights.GetPointer<S>(),
1193                    [](const T in_val) -> S { return static_cast<S>(in_val); });
1194   }
1195 
1196   // Constructs a tensor with given values (vals). The tensor type is defined by
1197   // the tf_type argument, its shape is given by input_dims. The tensor is
1198   // constructed using the allocator of OpConverterTest in Unified Memory.
1199   template <typename T>
AsTensor(const std::vector<T> & vals,const std::vector<int> & input_dims,DataType tf_type)1200   Tensor AsTensor(const std::vector<T>& vals,
1201                   const std::vector<int>& input_dims, DataType tf_type) {
1202     Tensor ret(tensor_buffer_allocator_.get(), tf_type,
1203                {static_cast<int64_t>(vals.size())});
1204     if (tf_type == DT_FLOAT) {
1205       transformTensor<T, float>(vals, ret);
1206     } else if (tf_type == DT_HALF) {
1207       transformTensor<T, Eigen::half>(vals, ret);
1208     } else if (tf_type == DT_INT32) {
1209       transformTensor<T, int32>(vals, ret);
1210 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
1211     } else if (tf_type == DT_BOOL) {
1212       transformTensor<T, bool>(vals, ret);
1213 #endif
1214     } else {
1215       LOG(FATAL) << "Cannot create tensor with type "
1216                  << DataTypeString(tf_type);
1217     }
1218     TensorShape shape;
1219     TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_dims, &shape));
1220     CHECK(ret.CopyFrom(ret, shape));
1221     return ret;
1222   }
1223 
1224   template <typename T>
AsTensor(const std::vector<int> & vals,const std::vector<int> & input_dims,DataType tf_type)1225   Tensor AsTensor(const std::vector<int>& vals,
1226                   const std::vector<int>& input_dims, DataType tf_type) {
1227     const auto& conv_vals = CastVector<int, T>(vals);
1228     return AsTensor(conv_vals, input_dims, tf_type);
1229   }
1230 
1231   // Constructs a flat tensor in Unified Memory.
1232   template <typename T>
ConstructTensor(int data_size,const T & value=T ())1233   Tensor ConstructTensor(int data_size, const T& value = T()) {
1234     std::vector<T> values(data_size, value);
1235     return AsTensor<T>(values);
1236   }
1237 
1238   // Constructs a flat tensor in Unified Memory.
1239   template <typename T>
ConstructTensor(int data_size,const T & value,DataType tf_type)1240   Tensor ConstructTensor(int data_size, const T& value, DataType tf_type) {
1241     std::vector<T> values(data_size, value);
1242     return AsTensor<T>(values, {data_size}, tf_type);
1243   }
1244 
CheckDataTypeMatches(const DataVec & datas)1245   void CheckDataTypeMatches(const DataVec& datas) {
1246     if (VLOG_IS_ON(2)) {
1247       int nbBindings = engine_->getNbBindings();
1248       VLOG(2) << "Number of engine bindings: " << nbBindings;
1249       for (int i = 0; i < nbBindings; i++) {
1250         VLOG(2) << "Binding " << i << " name: " << engine_->getBindingName(i);
1251       }
1252     }
1253     for (const auto& data : datas) {
1254       VLOG(2) << "Checking if data type matches for tensor " << data.name;
1255       const int input_index = engine_->getBindingIndex(data.name.c_str());
1256       ASSERT_NE(-1, input_index);
1257       const nvinfer1::DataType trt_dtype =
1258           engine_->getBindingDataType(input_index);
1259       DataType tf_type;
1260       TF_ASSERT_OK(TrtTypeToTfType(trt_dtype, &tf_type));
1261       ASSERT_EQ(data.tensor.dtype(), tf_type)
1262           << DataTypeString(data.tensor.dtype()) << " vs. "
1263           << DataTypeString(tf_type);
1264     }
1265   }
1266 
BuildAndRun(const DataVec & input_data,DataVec * output_data,const int batch_size=1)1267   Status BuildAndRun(const DataVec& input_data, DataVec* output_data,
1268                      const int batch_size = 1) {
1269     // Mark the output tensor as TRT engine output.
1270     std::vector<Converter::EngineOutputInfo> output_info;
1271     for (const auto& data : *output_data) {
1272       nvinfer1::DataType trt_type;
1273       TF_RETURN_IF_ERROR(TfTypeToTrtType(data.tensor.dtype(), &trt_type));
1274       output_info.push_back({data.name, data.name, trt_type});
1275     }
1276     TF_RETURN_IF_ERROR(converter_->RenameAndMarkOutputTensors(output_info));
1277 
1278     // Build the TRT engine.
1279     if (engine_.get() != nullptr) {
1280       return errors::Internal("Engine already exists");
1281     }
1282     TrtShapeOptimizationProfile profiles;
1283     if (!converter_->use_implicit_batch()) {
1284       std::vector<bool> input_mask(input_data.size());
1285       for (int i = 0; i < input_data.size(); i++) {
1286         input_mask[i] = (input_data[i].tensor.dtype() != DataType::DT_RESOURCE);
1287       }
1288       profiles.SetInputMask(input_mask);
1289       profiles.SetShapeTensorMask(converter_->network());
1290       TF_RETURN_IF_ERROR(profiles.CollectShapeValues(input_data));
1291       // Create a single optimization profile for explicit batch mode
1292       std::vector<TensorShape> input_shapes;
1293       TF_RETURN_IF_ERROR(GetShapeFromDataVec(input_data, &input_shapes));
1294       profiles.AddShape(input_shapes);
1295       std::vector<PartialTensorShape> input_partial_shapes;
1296       TF_RETURN_IF_ERROR(
1297           GetNetworkInputShapes(converter_->network(), &input_partial_shapes));
1298       profiles.InitProfiles(input_partial_shapes,
1299                             ProfileStrategy::kImplicitBatchModeCompatible);
1300     }
1301     TF_RETURN_IF_ERROR(
1302         converter_->BuildCudaEngine(&engine_,
1303                                     /*max_batch_size=*/batch_size,
1304                                     /*max_workspace_size_bytes=*/1 << 26,
1305                                     /*allocator=*/nullptr,
1306                                     /*calibrator=*/nullptr,
1307                                     /*profiles=*/&profiles));
1308     CHECK_NOTNULL(engine_.get());
1309     CheckDataTypeMatches(input_data);
1310     CheckDataTypeMatches(*output_data);
1311 
1312     const int num_bindings = input_data.size() + output_data->size();
1313     std::vector<void*> buffers(num_bindings);
1314 
1315     if (engine_->getNbBindings() != num_bindings) {
1316       return errors::Internal("Number of bindings do not match");
1317     }
1318     // Since we have only 1 optimization profile (which is enabled by default)
1319     // it is fine to create execution context directly, instead of calling
1320     // profiles.CreateExecutionContexts()
1321     TrtUniquePtrType<nvinfer1::IExecutionContext> execution_context(
1322         engine_->createExecutionContext());
1323 
1324     // Prepare input bindings.
1325     TF_RETURN_IF_ERROR(
1326         SetTrtEngineInputs(engine_.get(), execution_context.get(), 0, buffers,
1327                            converter_->use_implicit_batch(), batch_size,
1328                            profiles, nullptr, &input_data));
1329     // Prepare output bindings.
1330     TF_RETURN_IF_ERROR(SetTrtEngineOutputs(
1331         engine_.get(), execution_context.get(), 0, buffers,
1332         converter_->use_implicit_batch(), batch_size, nullptr, output_data));
1333     // Execute the TRT engine.
1334     TF_RETURN_IF_ERROR(TrtEnqueue(execution_context.get(), buffers, stream_,
1335                                   converter_->use_implicit_batch(),
1336                                   batch_size));
1337     cudaStreamSynchronize(stream_);
1338     return Status::OK();
1339   }
1340 
1341   // Adds ITensor for both validation and conversion, assuming explicit batch
1342   // dimension is included in dims (ie for an NCHW tensor dims = {N, C, H, W}).
AddTestTensorWithTFDims(const string & name,const std::vector<int32> & dims,nvinfer1::DataType trt_type=nvinfer1::DataType::kFLOAT,Status add_input_status=Status::OK ())1343   void AddTestTensorWithTFDims(
1344       const string& name, const std::vector<int32>& dims,
1345       nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT,
1346       Status add_input_status = Status::OK()) {
1347     DataType tf_type;
1348     TF_ASSERT_OK(TrtTypeToTfType(trt_type, &tf_type));
1349     ops::Placeholder::Attrs attrs;
1350     TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_));
1351 
1352     auto input = ops::Placeholder(scope_.WithOpName(name), tf_type, attrs);
1353     node_inputs_[name] = input.output;
1354 
1355     // Add a real ITensor for conversion conditionally.
1356 
1357     auto dims_adap =
1358         DimsAdapter::Create(attrs.shape_, converter_->use_implicit_batch());
1359     if (converter_->use_implicit_batch() && !dims_adap.ok()) {
1360       ASSERT_EQ(add_input_status, dims_adap.status());
1361       return;
1362     } else {
1363       TF_EXPECT_OK(dims_adap.status());
1364     }
1365     if (!converter_->use_implicit_batch() || dims_adap->IsStatic()) {
1366       int batch_size = dims.size() > 0 ? dims[0] : 0;
1367       Status status = converter_->AddInputTensor(
1368           name, trt_type, dims_adap->AsTrtDims(), batch_size);
1369       ASSERT_EQ(add_input_status, status);
1370     }
1371   }
1372 
AddTensorOrWeights(const string & name,TRT_TensorOrWeights input)1373   Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input) {
1374     return converter_->AddTensorOrWeights(name, input);
1375   }
1376 
1377   // Adds ITensor for both validation and conversion. The difference compared to
1378   // AddTestTensorWithTFDims is in the meaning of the dims parameter. To define
1379   // a tensor with NCHW shape, here we set dims = {C,H,W} and batch_size = N.
1380   // TODO(tfeher) remove this function once all test are updated to use the
1381   // other version of AddTestTensor (defined by
1382   // ParameterizedOpConverterTestBase).
AddTestTensor(const string & name,const std::vector<int32> & dims,int batch_size=1,nvinfer1::DataType trt_dtype=nvinfer1::DataType::kFLOAT)1383   void AddTestTensor(
1384       const string& name, const std::vector<int32>& dims, int batch_size = 1,
1385       nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) {
1386     DimsAdapter adap(dims);
1387     std::vector<int32_t> dims_vec;
1388     TF_CHECK_OK(adap.Prepend(batch_size).Vector(&dims_vec));
1389     AddTestTensorWithTFDims(name, dims_vec, trt_dtype);
1390     if (adap.IsStatic()) {
1391       ASSERT_EQ(batch_size, converter_->batch_size_);
1392     }
1393   }
1394 
1395   // Adds weights for both validation and conversion. The type of the weight is
1396   // determined by tf_type. The initial value vector (values) can have any
1397   // type (T) that can be statically casted to tf_type.
1398   template <typename T = int32>
AddTestWeights(const string & name,const std::vector<int> & dims,const std::vector<T> & values_inp,DataType tf_type,bool fix_values=true)1399   void AddTestWeights(const string& name, const std::vector<int>& dims,
1400                       const std::vector<T>& values_inp, DataType tf_type,
1401                       bool fix_values = true) {
1402     const DimsAdapter dims_adap(dims);
1403     const int64_t num_elements = dims_adap.Volume();
1404 
1405     std::vector<T> values(values_inp);
1406     if (num_elements != values.size()) {
1407       if (fix_values) {
1408         AdjustVectorByDims<T>(values, num_elements, name, "AddTestWeights");
1409       } else {
1410         FAIL() << "Unable to create test weights: "
1411                << (num_elements > values.size() ? "not enough" : "to many")
1412                << " values specified: " << values.size() << " vs. "
1413                << num_elements << " defined by dims";
1414       }
1415     }
1416     // Add weights for validation.
1417     Tensor t = AsTensor<T>(values, dims, tf_type);
1418     node_inputs_[name] = ops::Const(scope_.WithOpName(name), t);
1419 
1420     // Add weights for conversion.
1421     nvinfer1::DataType dtype;
1422     TF_ASSERT_OK(TfTypeToTrtType(tf_type, &dtype));
1423     QCHECK_EQ(num_elements, values.size())
1424         << num_elements << " vs " << values.size();
1425     TRT_ShapedWeights weights(dtype);
1426     if (num_elements) {
1427       weights =
1428           converter_->weight_store_.GetTempWeights(dtype, dims_adap.AsTrtDims())
1429               .value();
1430 
1431       if (tf_type == DT_FLOAT) {
1432         transformWeights<T, float>(values, weights);
1433       } else if (tf_type == DT_HALF) {
1434         transformWeights<T, Eigen::half>(values, weights);
1435       } else if (tf_type == DT_INT32) {
1436         transformWeights<T, int32>(values, weights);
1437 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
1438       } else if (tf_type == DT_BOOL) {
1439         transformWeights<T, bool>(values, weights);
1440 #endif
1441       } else {
1442         LOG(FATAL) << "Cannot create tensor with type "
1443                    << DataTypeString(tf_type);
1444       }
1445     }
1446     TF_EXPECT_OK(
1447         converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights}));
1448   }
1449 
1450   // Adds test weight without specifying tf_type arg. In this case the initial
1451   // value type (T) will determine the type of the weights.
1452   template <typename T = int32>
AddTestWeights(const string & name,const std::vector<int> & dims,const std::vector<T> & value,bool fix_values=true)1453   void AddTestWeights(const string& name, const std::vector<int>& dims,
1454                       const std::vector<T>& value, bool fix_values = true) {
1455     AddTestWeights(name, dims, value, DataTypeToEnum<T>::value, fix_values);
1456   }
1457 
1458   // Test validation in validation-only mode.
RunValidation(const Node * node)1459   Status RunValidation(const Node* node) {
1460     grappler::GrapplerItem item;
1461     TF_EXPECT_OK(scope_.ToGraphDef(&item.graph));
1462     grappler::GraphProperties graph_properties(item);
1463     TF_EXPECT_OK(graph_properties.InferStatically(true));
1464 
1465     TrtNodeValidator validator(
1466         graph_properties, converter_->precision_mode(),
1467         /*use_calibration=*/false,
1468         /*use_implicit_batch=*/converter_->use_implicit_batch(),
1469         /*use_explicit_precision=*/false);
1470     return validator.IsTensorRTCandidate(node);
1471   }
1472 
RunConversion(const Node * node,error::Code expected_code=error::OK,const std::string & expected_msg_substr="")1473   void RunConversion(const Node* node, error::Code expected_code = error::OK,
1474                      const std::string& expected_msg_substr = "") {
1475     EXPECT_THAT(converter_->ConvertNode(node->def()),
1476                 StatusIs(expected_code, HasSubstr(expected_msg_substr)));
1477     if (expected_code == error::OK) {
1478       EXPECT_THAT(converter_->network(), LayerNamesNonEmpty());
1479     }
1480   }
1481 
1482   // Helper method to run both validation and conversion, when the expected
1483   // output are same.
RunValidationAndConversion(const NodeDef & node_def,error::Code expected_code=error::OK,const std::string & expected_msg_substr="",bool should_run_conversion=true)1484   void RunValidationAndConversion(const NodeDef& node_def,
1485                                   error::Code expected_code = error::OK,
1486                                   const std::string& expected_msg_substr = "",
1487                                   bool should_run_conversion = true) {
1488     // Add the node to the graph.
1489     // TODO(laigd): we should accept a function that adds the node using
1490     // `scope_`, so individual test case can reuse the scope object and we don't
1491     // need to add the edges here by ourselves.
1492     Graph* graph = scope_.graph();
1493     Status status;
1494     Node* node = graph->AddNode(std::move(node_def), &status);
1495     TF_EXPECT_OK(status);
1496     for (int i = 0; i < node_def.input().size(); ++i) {
1497       const string& input_name = node_def.input(i);
1498       const auto& itr = node_inputs_.find(input_name);
1499       QCHECK(itr != node_inputs_.end());
1500       const Output& input = itr->second;
1501       graph->AddEdge(input.node(), input.index(), node, i);
1502     }
1503 
1504     status = RunValidation(node);
1505     if (should_run_conversion && status.ok()) {
1506       RunConversion(node, expected_code, expected_msg_substr);
1507     } else {
1508       EXPECT_THAT(status,
1509                   StatusIs(expected_code, HasSubstr(expected_msg_substr)));
1510     }
1511   }
1512 
1513   // Helper method to run both validation and conversion, and check the output
1514   // shapes.
RunValidationAndConversion(const NodeDef & node_def,const Status & status,const std::string & output_name,const std::vector<std::vector<int>> & exp_out_dims)1515   void RunValidationAndConversion(
1516       const NodeDef& node_def, const Status& status,
1517       const std::string& output_name,
1518       const std::vector<std::vector<int>>& exp_out_dims) {
1519     RunValidationAndConversion(node_def, status.code(), status.error_message(),
1520                                true);
1521 
1522     if (status.ok()) {
1523       // TODO(tfeher): Enable this check in explicit_batch_mode.
1524       // In dynamic shape mode the output dims cannot be tested here. In that
1525       // case we need to wait for the concrate input shapes to be defined (by
1526       // setBindingDimensions before enqueue) before we can check the output
1527       // dims.
1528       if (converter_->use_implicit_batch()) {
1529         for (int i = 0; i < exp_out_dims.size(); i++) {
1530           TRT_TensorOrWeights output;
1531           string name = i == 0 ? output_name : StrCat(output_name, ":", i);
1532           TF_EXPECT_OK(GetTensorOrWeights(name.c_str(), &output));
1533           ASSERT_TRUE(output.is_tensor());
1534           if (!exp_out_dims[i].empty()) {
1535             // Removing batch dim.
1536             auto out_dims = std::vector<int>(exp_out_dims[i].begin() + 1,
1537                                              exp_out_dims[i].end());
1538             VLOG(2) << "Testing output shape for tensor " << name;
1539             EXPECT_THAT(output.tensor()->getDimensions(),
1540                         DimsAreArray(out_dims));
1541           }
1542         }
1543       }
1544     }
1545   }
1546 
1547   // Expose quantization_ranges_ for tests
quantization_ranges_proxy()1548   std::unordered_map<ITensorProxyPtr*, float>& quantization_ranges_proxy() {
1549     return converter_->quantization_ranges_proxy_;
1550   }
1551 
1552   // Expose quantization_ranges_ for tests
quantization_ranges()1553   std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
1554     return converter_->quantization_ranges_;
1555   }
1556 
1557  protected:
1558   template <typename T>
AdjustVectorByDims(std::vector<T> & values,size_t num_elements,const string & name,const char * callingFunc)1559   void AdjustVectorByDims(std::vector<T>& values, size_t num_elements,
1560                           const string& name, const char* callingFunc) {
1561     const auto old_size = values.size();
1562     if (num_elements > old_size) {
1563       // Expending vector with 0's.
1564       const std::vector<T> zeros(num_elements - old_size, 0);
1565       values.reserve(num_elements);
1566       values.insert(values.end(), zeros.begin(), zeros.end());
1567       VLOG(2) << "In function " << callingFunc << " the vector '" << name
1568               << "' was extended by " << num_elements - old_size << " zeros";
1569     } else {
1570       // Removing unnecessary elements.
1571       values.resize(num_elements);
1572       VLOG(2) << "Only first " << num_elements << " out of " << old_size
1573               << " elements of the vector '" << name
1574               << "' will be used in function" << callingFunc;
1575     }
1576   }
1577 
1578  public:
1579   std::unique_ptr<Converter> converter_;
1580 
1581  protected:
1582   Logger& logger_ = *Logger::GetLogger();
1583 
1584  private:
1585   TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
1586   cudaStream_t stream_;
1587   std::unique_ptr<Allocator> tensor_buffer_allocator_;
1588 
1589  public:
1590   // The scope that contains the graph being converted. Because
1591   // tensor_buffer_allocator_ provides the storage for tensor contents that are
1592   // represented as attributes for graph nodes within scope_,
1593   // tensor_buffer_allocator_ needs to be available when destructing scope_.
1594   // Therefore, scope_ comes after tensor_buffer_allocator_ in the class member
1595   // field list.
1596   Scope scope_;
1597 
1598  protected:
1599   std::unordered_map<string, Output> node_inputs_;
1600 };
1601 
1602 // Extends the OpConverterTest for variable converters which require a properly
1603 // setup context.
1604 class VariableOpConverterTest : public OpConverterTest {
1605  public:
Reset(TrtPrecisionMode precision_mode_to_test=TrtPrecisionMode::FP32,TrtTestMode trt_mode=TrtTestMode::kImplicitBatch)1606   void Reset(TrtPrecisionMode precision_mode_to_test = TrtPrecisionMode::FP32,
1607              TrtTestMode trt_mode = TrtTestMode::kImplicitBatch) {
1608     OpConverterTest::Reset(precision_mode_to_test, trt_mode, context_.get());
1609   }
1610 
CreateContext(const NodeDef & node_def,OpKernel ** kernel,OpKernelContext ** context)1611   void CreateContext(const NodeDef& node_def, OpKernel** kernel,
1612                      OpKernelContext** context) {
1613     std::unique_ptr<Device> device_(
1614         DeviceFactory::NewDevice("GPU", {}, "/job:a/replica:0/task:0"));
1615     Device* device_ptr = device_.get();
1616 
1617     device_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(device_));
1618 
1619     managed_allocator_ = std::make_unique<GpuManagedAllocator>();
1620     Allocator* allocator = managed_allocator_.get();
1621     step_container_ =
1622         std::make_unique<ScopedStepContainer>(0, [](const string&) {});
1623     slice_reader_cache_wrapper_ =
1624         std::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
1625 
1626     flib_def_ = std::make_unique<FunctionLibraryDefinition>(
1627         OpRegistry::Global(), FunctionDefLibrary{});
1628 
1629     thread_pool_ =
1630         std::make_unique<thread::ThreadPool>(Env::Default(), "default",
1631                                              /*num_threads=*/1);
1632     pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
1633         device_mgr_.get(), Env::Default(), /*config=*/nullptr,
1634         TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(),
1635         thread_pool_.get());
1636 
1637     FunctionLibraryRuntime* flib = pflr_->GetFLR(device_ptr->name());
1638     ResourceMgr* resource_mgr = device_ptr->resource_manager();
1639 
1640     TF_CHECK_OK(NodeProperties::CreateFromNodeDef(
1641         node_def, OpRegistry::Global(), &props_));
1642 
1643     OpKernel* kernel_ptr = nullptr;
1644     TF_CHECK_OK(CreateOpKernel(DEVICE_GPU, device_ptr, allocator, flib,
1645                                resource_mgr, props_, TF_GRAPH_DEF_VERSION,
1646                                &kernel_ptr));
1647     op_kernel_ = std::unique_ptr<OpKernel>(kernel_ptr);
1648 
1649     auto* dev_info = device_ptr->tensorflow_accelerator_device_info();
1650     CHECK_NOTNULL(dev_info);
1651     DeviceContext* device_context = dev_info->default_context;
1652 
1653     // Note: this setup is not exhaustive.
1654     params_.device = device_ptr;
1655     params_.op_kernel = op_kernel_.get();
1656     params_.resource_manager = resource_mgr;
1657     params_.frame_iter = FrameAndIter(0, 0);
1658     params_.inputs = inputs_;
1659     params_.step_container = step_container_.get();
1660     params_.function_library = flib;
1661     params_.slice_reader_cache = slice_reader_cache_wrapper_.get();
1662     params_.op_device_context = device_context;
1663 
1664     context_ = std::make_unique<OpKernelContext>(&params_);
1665 
1666     // Outputs.
1667     *kernel = op_kernel_.get();
1668     *context = context_.get();
1669   }
1670 
1671   // Adds resource for resource variable op converters.
AddTestResource(const string & name,const ResourceHandle & resource)1672   void AddTestResource(const string& name, const ResourceHandle& resource) {
1673     // Add resource for validation.
1674     node_inputs_[name] =
1675         ops::Placeholder(scope_.WithOpName("my_handle"), DT_RESOURCE);
1676 
1677     // Add resource for conversion.
1678     TF_EXPECT_OK(AddTensorOrWeights(name, TRT_TensorOrWeights{resource}));
1679   }
1680 
1681  private:
1682   // The following pointers manage the kernel context.
1683   std::unique_ptr<DeviceMgr> device_mgr_;
1684   std::unique_ptr<Allocator> managed_allocator_;
1685   std::unique_ptr<ScopedStepContainer> step_container_;
1686   std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
1687       slice_reader_cache_wrapper_;
1688   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
1689   std::unique_ptr<thread::ThreadPool> thread_pool_;
1690   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
1691   OpKernelContext::Params params_;
1692   std::unique_ptr<OpKernel> op_kernel_;
1693   std::unique_ptr<OpKernelContext> context_;
1694   std::shared_ptr<const NodeProperties> props_;
1695   absl::InlinedVector<TensorValue, 4> inputs_;
1696 };
1697 
1698 // General test parameters to be used with ops that take a single input tensor.
1699 struct TestParamBase {
1700   // Concrete input dimensions for the test (including the batch dim)
1701   std::vector<int> input_dims;
1702 
1703   // Dimensions to define an input with PartialTensorShape. This can be used to
1704   // define networks with dynamic input shape. It can be left empty, in that
1705   // case AddTestTensor sets partial shapes that are appropriate to TrtTestMode.
1706   std::vector<int> partial_input_dims;
1707 
1708   // Concrete (static) output dimensions, including batch size as first dim
1709   std::vector<int> expected_output_dims;
1710 
1711   // Parameter vector, has converter specific meaning.
1712   std::vector<int> param;
1713 
1714   // Expected status of conversion (with concrete error message)
1715   Status status;
1716 
1717   // Expected status of BuildAndRun
1718   Status runtime_status;
1719 };
1720 
operator <<(std::ostream & os,const TestParamBase & p)1721 std::ostream& operator<<(std::ostream& os, const TestParamBase& p) {
1722   os << "input_dims" << PrintToString(p.input_dims);
1723   if (!p.partial_input_dims.empty()) {
1724     os << ", partial_input_dims" << PrintToString(p.partial_input_dims);
1725   }
1726   if (!p.expected_output_dims.empty()) {
1727     os << ", exp_out_dims" << PrintToString(p.expected_output_dims);
1728   }
1729   if (!p.param.empty()) {
1730     os << ", param" << PrintToString(p.param);
1731   }
1732   os << ", " << p.status;
1733   return os;
1734 }
1735 
1736 // Printing vector with the numbers of type T which defines tensor or shape.
1737 template <typename T>
get_debug_string_for_vector(const std::vector<T> & vector,absl::string_view pComment,absl::string_view name,absl::string_view type="")1738 const std::string get_debug_string_for_vector(const std::vector<T>& vector,
1739                                               absl::string_view pComment,
1740                                               absl::string_view name,
1741                                               absl::string_view type = "") {
1742   const std::string t1 = absl::StrCat(pComment, name, " Dims(nbDims=");
1743   const std::string t2 = absl::StrJoin(vector, ",");
1744   const std::string t3 = type != "" ? absl::StrCat(") of type ", type) : ")";
1745   std::stringstream stream;
1746   stream << t1 << vector.size() << ", d=" << t2 << t3;
1747   return stream.str();
1748 }
1749 
1750 // Parameterized version of OpConverterTest. We have the following parameters:
1751 // 1. TrtTestMode: implicit batch, explicit batch, dynamic shape modes
1752 // 2. DataType of the input TF tensors: DT_FLOAT, DT_HALF, DT_INT32
1753 // 3. TrtPrecisionMode argument for the Converter: FP32, FP16, INT8
1754 // We will introduce subclasses that will be instantiated using different
1755 // combinations of the DataType and TrtPrecisionMode parameters.
1756 class ParameterizedOpConverterTestBase
1757     : public OpConverterTest,
1758       public ::testing::WithParamInterface<
1759           std::tuple<TrtTestMode, DataType, TrtPrecisionMode>> {
1760  public:
ParameterizedOpConverterTestBase()1761   ParameterizedOpConverterTestBase()
1762       : trt_mode_(std::get<0>(GetParam())),
1763         tf_type_(std::get<1>(GetParam())),
1764         converter_precision_(std::get<2>(GetParam())) {
1765     LOG(INFO) << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%";
1766     LOG(INFO) << "tf_type_: " << DebugString(tf_type_);
1767     LOG(INFO) << "trt_mode_: " << DebugString(trt_mode_);
1768     LOG(INFO) << "converter_precision_: " << DebugString(converter_precision_);
1769     LOG(INFO) << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%";
1770   }
1771 
Reset()1772   void Reset() {
1773     OpConverterTest::Reset(converter_precision_, trt_mode_);
1774     input_data_.clear();
1775   }
1776 
Reset(TrtPrecisionMode precision)1777   void Reset(TrtPrecisionMode precision) {
1778     OpConverterTest::Reset(precision, trt_mode_);
1779     input_data_.clear();
1780   }
1781 
1782   // Getters of protected attributes
get_tf_type()1783   DataType get_tf_type() { return tf_type_; }
get_trt_mode()1784   TrtTestMode get_trt_mode() { return trt_mode_; }
get_converter_precision()1785   TrtPrecisionMode get_converter_precision() { return converter_precision_; }
1786 
1787   // Adds an input ITensor for TRT network. Also creates the corresponding TF
1788   // tensor, and stores it in the list of inputs (input_data_).
1789   //
1790   // The TF tensor is always created with concrete static input shape given by
1791   // dims. The ITensor can have static or dynamic shape based on the trt_mode
1792   // attribute. The ITensor shape is set automatically according to the trt_mode
1793   // parameter, unless the user overrides it with an explicit
1794   // partial_input_shape_dims argument.
1795   //
1796   // Parameters:
1797   // - name of the input node
1798   // - dims actual dimensions of the tensor that we will use during the test
1799   //   (including explicit batch dim)
1800   // - values initial values for the TF tensor
1801   // - dtype data type of the tensor
1802   // - partial_input_shape dimensions which can include unknown shapes. This can
1803   //   be empty, in that case the partial_input_shape will be set automatically
1804   //   depending on the trt_mode argument. (This argument also includes explicit
1805   //   batch dim).
1806   // - add_input_status adding ITensor to the network can fail in implicit batch
1807   //   mode if the batch size is inconsistent. Using the add_input_status arg we
1808   //   can test such errors.
1809   //
1810   template <typename T = int>
AddTestTensor(const string & name,const std::vector<int32> & dims,DataType tf_type,const std::vector<T> & values_inp,const std::vector<int32> & partial_input_shape_dims={},Status add_input_status=Status::OK (),bool fix_values=true)1811   void AddTestTensor(const string& name, const std::vector<int32>& dims,
1812                      DataType tf_type, const std::vector<T>& values_inp,
1813                      const std::vector<int32>& partial_input_shape_dims = {},
1814                      Status add_input_status = Status::OK(),
1815                      bool fix_values = true) {
1816     std::vector<T> values(values_inp);
1817     VLOG(2) << "**** AddTestTensor for " << name
1818             << " ***** dims empty() = " << dims.empty()
1819             << "  tf_type = " << DebugString(tf_type);
1820     if (!dims.empty()) {
1821       const auto num_elements = std::accumulate(
1822           std::begin(dims), std::end(dims), 1, std::multiplies<double>());
1823       if (!values.empty() && num_elements != values.size()) {
1824         if (fix_values) {
1825           AdjustVectorByDims(values, num_elements, name, "AddTestTensor");
1826         } else {
1827           // Note: for conversion only tests, it is valid to have empty values,
1828           // otherwise the number of elements should match.
1829           LOG(WARNING) << "Expected Test Tensor Shape: " << DebugString(dims)
1830                        << ", Received Input Tensor: " << DebugString(values);
1831         }
1832       }
1833     }
1834 
1835     std::vector<int32> partial_shape;
1836     if (!partial_input_shape_dims.empty()) {
1837       partial_shape = partial_input_shape_dims;
1838     } else {
1839       if (trt_mode_ == TrtTestMode::kDynamicShape) {
1840         // In dynamic shape mode we make all dims unknown.
1841         partial_shape = std::vector<int32>(dims.size(), -1);
1842       } else {
1843         // Use static (known) input shapes.
1844         partial_shape = dims;
1845       }
1846       if (VLOG_IS_ON(2)) {
1847         VLOG(2) << get_debug_string_for_vector(
1848             partial_shape, "Using partial_shape: for ", name);
1849       }
1850     }
1851     nvinfer1::DataType trt_type;
1852     TF_ASSERT_OK(TfTypeToTrtType(tf_type, &trt_type));
1853     AddTestTensorWithTFDims(name, partial_shape, trt_type, add_input_status);
1854     if (!values.empty()) {
1855       if (VLOG_IS_ON(2)) {
1856         VLOG(2) << get_debug_string_for_vector(
1857             values, "Adding test tensor: for ", name, DataTypeString(tf_type));
1858       }
1859       InputOutputData data{name, AsTensor(values, dims, tf_type)};
1860       VLOG(2) << "Added tensor: " << data.name << " with dtype "
1861               << DataTypeString(data.tensor.dtype());
1862       input_data_.push_back(data);
1863     }
1864   }
1865 
1866   // Adds test tensor (same as above) but with the default tf_type defined by
1867   // the test params.
1868   template <typename T = int>
AddTestTensor(const string & name,const std::vector<int32> & dims,const std::vector<T> & values={},const std::vector<int32> & partial_input_shape_dims={})1869   void AddTestTensor(const string& name, const std::vector<int32>& dims,
1870                      const std::vector<T>& values = {},
1871                      const std::vector<int32>& partial_input_shape_dims = {}) {
1872     AddTestTensor<T>(name, dims, tf_type_, values, partial_input_shape_dims);
1873   }
1874 
1875   // Builds and runs the converted network. Checks output tensor shape. Tests
1876   // output values using a matcher. The network can have multiple input and
1877   // output tensors. The inputs are defined by the input_data_ member variable.
BuildAndRun(const string & name,const std::vector<std::vector<int>> & expected_output_dims,const Status & expected_runtime_status,const std::vector<Matcher<std::vector<float>>> & matcher,const std::vector<DataType> & out_tf_types={})1878   void BuildAndRun(const string& name,
1879                    const std::vector<std::vector<int>>& expected_output_dims,
1880                    const Status& expected_runtime_status,
1881                    const std::vector<Matcher<std::vector<float>>>& matcher,
1882                    const std::vector<DataType>& out_tf_types = {}) {
1883     TensorShape shape;
1884     const int n_output = expected_output_dims.size();
1885     ASSERT_EQ(n_output, matcher.size());
1886     DataVec output_data;
1887     for (int i = 0; i < n_output; i++) {
1888       TF_EXPECT_OK(
1889           TensorShapeUtils::MakeShape(expected_output_dims[i], &shape));
1890       string out_name = (i == 0) ? name : StrCat(name, ":", i);
1891       DataType out_tf_type =
1892           out_tf_types.size() > i ? out_tf_types[i] : tf_type_;
1893       InputOutputData data{
1894           out_name, ConstructTensor(shape.num_elements(), 0, out_tf_type)};
1895       output_data.push_back(data);
1896     }
1897     const int batch_size =
1898         input_data_.empty() ||
1899                 TensorShapeUtils::IsScalar(input_data_[0].tensor.shape())
1900             ? 1
1901             : input_data_[0].tensor.shape().dim_size(0);
1902     Status stat =
1903         OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size);
1904     ASSERT_EQ(expected_runtime_status.ok(), stat.ok())
1905         << "expected status: " << expected_runtime_status
1906         << ", actual status: " << stat;
1907     if (expected_runtime_status.ok() && stat.ok()) {
1908       for (int i = 0; i < n_output; i++) {
1909         // Check the shape of the actual output tensors
1910         TF_EXPECT_OK(
1911             TensorShapeUtils::MakeShape(expected_output_dims[i], &shape));
1912         EXPECT_TRUE(output_data[i].tensor.shape() == shape)
1913             << "Expected shape: " << shape.DebugString() << ", actual shape: "
1914             << output_data[i].tensor.shape().DebugString();
1915         EXPECT_THAT(GetDataAsFloat(output_data[i]), matcher[i]);
1916       }
1917     }
1918   }
1919 
1920   // Runs validation and conversion. If conversion is successfull then builds
1921   // the TRT network, executes it and checks the output. Handles multiple output
1922   // tensors.
TestOpConverterMultiOut(const string & name,const NodeDef node_def,const std::vector<std::vector<int>> & expected_output_dims,const Status & expected_conversion_status,const Status & expected_runtime_status,const std::vector<Matcher<std::vector<float>>> & matcher,const std::vector<DataType> & out_tf_type={})1923   void TestOpConverterMultiOut(
1924       const string& name, const NodeDef node_def,
1925       const std::vector<std::vector<int>>& expected_output_dims,
1926       const Status& expected_conversion_status,
1927       const Status& expected_runtime_status,
1928       const std::vector<Matcher<std::vector<float>>>& matcher,
1929       const std::vector<DataType>& out_tf_type = {}) {
1930     RunValidationAndConversion(node_def, expected_conversion_status, name,
1931                                expected_output_dims);
1932     if (expected_conversion_status.ok()) {
1933       BuildAndRun(name, expected_output_dims, expected_runtime_status, matcher,
1934                   out_tf_type);
1935     }
1936   }
1937 
1938   // Runs validation and conversion. If conversion is successfull then builds
1939   // the TRT network, executes it and checks the output.
TestOpConverter(const string & name,const NodeDef node_def,const std::vector<int> & expected_output_dims,const Status & expected_conversion_status,const Status & expected_runtime_status,const Matcher<std::vector<float>> & matcher,const std::vector<DataType> & out_tf_types={})1940   void TestOpConverter(const string& name, const NodeDef node_def,
1941                        const std::vector<int>& expected_output_dims,
1942                        const Status& expected_conversion_status,
1943                        const Status& expected_runtime_status,
1944                        const Matcher<std::vector<float>>& matcher,
1945                        const std::vector<DataType>& out_tf_types = {}) {
1946     TestOpConverterMultiOut(
1947         name, node_def, std::vector<std::vector<int>>({expected_output_dims}),
1948         expected_conversion_status, expected_runtime_status,
1949         std::vector<Matcher<std::vector<float>>>({matcher}), out_tf_types);
1950   }
1951 
1952  protected:
1953   const TrtTestMode trt_mode_;
1954   const DataType tf_type_;
1955   const TrtPrecisionMode converter_precision_;
1956   DataVec input_data_;
1957 };
1958 
1959 template <typename T>
1960 class OpConverter_UnaryTest : public ParameterizedOpConverterTestBase {
1961  public:
1962   template <typename S>
RunTests(const string & testName,const OperationMap<S> & map,std::map<std::string,std::pair<std::function<NodeDef (DataType)>,T (*)(T)>> & op_map,const std::vector<T> input_values,const std::string input_name="input",float max_abs_error=0.0001,bool nan_sensitive=true)1963   void RunTests(
1964       const string& testName, const OperationMap<S>& map,
1965       std::map<std::string,
1966                std::pair<std::function<NodeDef(DataType)>, T (*)(T)>>& op_map,
1967       const std::vector<T> input_values, const std::string input_name = "input",
1968       float max_abs_error = 0.0001, bool nan_sensitive = true) {
1969     // Prepare test parameters.
1970     auto p = TestParamBase{
1971         {1, 1, 2, 3},  // input dims
1972         {},            // input partial dims
1973         {1, 1, 2, 3},  // expected output dims
1974     };
1975 
1976     // Get list of ops to test.
1977     std::vector<string> ops_to_test;
1978     for (auto& pair : map) {
1979       ops_to_test.push_back(pair.first);
1980     }
1981 
1982     for (const string& op_name : ops_to_test) {
1983       SCOPED_TRACE(op_name);
1984       if (!op_map.count(op_name)) {
1985         FAIL() << testName << " op test map does not contain op " << op_name;
1986       }
1987 
1988       const DataType tf_type = get_tf_type();
1989       const NodeDef& node_def = op_map[op_name].first(tf_type);
1990       runExpectedToFailTest(node_def, input_name, input_values, op_name);
1991 
1992       Status conv_status = Status::OK();
1993       if (trt_mode_ == TrtTestMode::kImplicitBatch &&
1994           (op_name == "Sign" || op_name == "Round" ||
1995            op_name == "LogicalNot")) {
1996         conv_status =
1997             errors::Unimplemented("Unary op: '", op_name,
1998                                   "' is not supported in implicit batch mode");
1999       }
2000 
2001       Reset();
2002       const DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type;
2003       const DataType output_tf_type = op_name == "Cast" ? DT_FLOAT : tf_type;
2004 
2005       AddTestTensor("input", p.input_dims, input_tf_type, input_values);
2006 
2007       std::vector<float> output;
2008       std::transform(input_values.begin(), input_values.end(),
2009                      std::back_inserter(output), op_map[op_name].second);
2010 
2011       TestOpConverter("my_unary", node_def, p.expected_output_dims, conv_status,
2012                       Status::OK(),
2013                       ArrayFloatNear(output, max_abs_error, nan_sensitive),
2014                       {output_tf_type});
2015     }
2016   }
runExpectedToFailTest(const NodeDef & node_def,const std::string & input_name,const std::vector<T> & input_values,const std::string & op_name)2017   void runExpectedToFailTest(const NodeDef& node_def,
2018                              const std::string& input_name,
2019                              const std::vector<T>& input_values,
2020                              const std::string& op_name) {
2021     // Input is weights, should fail.
2022     Reset();
2023     std::string error =
2024         "The input \"" + input_name + "\" for " + op_name + " must be a tensor";
2025     AddTestWeights("input", {1, 2, 3}, input_values, get_tf_type());
2026     RunValidationAndConversion(node_def, error::UNIMPLEMENTED, error);
2027 
2028     // Input has 0 dimensions, should fail.
2029     Reset();
2030     std::vector<int32> dims = {};
2031     if (trt_mode_ == TrtTestMode::kImplicitBatch) {
2032       dims = {1};
2033     }
2034     error = "At least 1 dimension is required for UNARY operation '" + op_name +
2035             "'";
2036     AddTestTensor("input", dims);
2037     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, error);
2038   }
2039 };
2040 
2041 template <typename T>
2042 class OpConverter_BinaryTest : public ParameterizedOpConverterTestBase {
2043  public:
2044   template <typename S>
RunTests(const OperationMap<S> & map,std::map<std::string,std::pair<std::function<NodeDef (DataType)>,std::vector<T>>> & op_test_info,const std::vector<std::vector<T>> & data)2045   void RunTests(
2046       const OperationMap<S>& map,
2047       std::map<std::string,
2048                std::pair<std::function<NodeDef(DataType)>, std::vector<T>>>&
2049           op_test_info,
2050       const std::vector<std::vector<T>>& data) {
2051     const std::vector<DataType> bool_types{DT_BOOL}, default_types{};
2052     std::vector<string> logical_ops{"Greater", "Less", "Equal"};
2053     std::vector<string> combined_ops{"GreaterEqual", "LessEqual"};
2054     const DataType tf_type = get_tf_type();
2055     AttrValue dtype;
2056     dtype.set_type(tf_type);
2057     std::map<std::string, NodeDef> nodes;
2058     for (const auto op_name : combined_ops) {
2059       nodes[op_name] = MakeNodeDef("my_binary", op_name, {"input1", "input2"},
2060                                    {{"T", dtype}});
2061     }
2062 
2063     for (auto& iter : map) {
2064       const string& op_name = iter.first;
2065       if (!op_test_info.count(op_name)) {
2066         FAIL() << "Binary op test map does not contain op " << op_name;
2067       }
2068       const auto comb_op = find_name(op_name, combined_ops);
2069       const auto& node_def =
2070           comb_op ? nodes[op_name] : op_test_info[op_name].first(tf_type);
2071 
2072       for (const bool operand_1_is_tensor : {true, false}) {
2073         for (const bool operand_2_is_tensor : {true, false}) {
2074           SCOPED_TRACE(StrCat(op_name, "_", operand_1_is_tensor ? "T" : "W",
2075                               operand_2_is_tensor ? "T" : "W"));
2076           Reset();
2077           if (!operand_1_is_tensor && !operand_2_is_tensor) {
2078             // In that case the only test which should be launched is in
2079             // runExpectedToFailTest
2080             runExpectedToFailTest(op_name, node_def);
2081             continue;
2082           }
2083 
2084           const bool logical_op = comb_op || find_name(op_name, logical_ops);
2085           auto conv_status = Status::OK();
2086           if (tf_type == DT_BOOL || logical_op) {
2087             if (trt_mode_ == TrtTestMode::kImplicitBatch) {
2088               conv_status = errors::Unimplemented(
2089                   "Binary op: '", op_name,
2090                   "' is not supported in implicit batch mode");
2091             } else if (!logical_op &&
2092                        (!operand_1_is_tensor || !operand_2_is_tensor)) {
2093               conv_status = errors::InvalidArgument(
2094                   "Both inputs  of '", op_name, "' are expected to be tensors");
2095             }
2096           }
2097 
2098           if (operand_1_is_tensor) {
2099             AddTestTensor("input1", {2, 1, 2}, data[0]);
2100           } else {
2101             AddTestWeights("input1", {1, 2}, data[1], tf_type);
2102           }
2103           if (operand_2_is_tensor) {
2104             AddTestTensor("input2", {2, 2, 1}, data[2]);
2105           } else {
2106             AddTestWeights("input2", {2, 1}, data[3], tf_type);
2107           }
2108 
2109           TestOpConverter("my_binary", node_def, {2, 2, 2}, conv_status,
2110                           Status::OK(),
2111                           ElementsAreArray(op_test_info[op_name].second),
2112                           logical_op ? bool_types : default_types);
2113         }
2114       }
2115     }
2116   }
2117 
runExpectedToFailTest(const std::string & op_name,const NodeDef & node)2118   void runExpectedToFailTest(const std::string& op_name, const NodeDef& node) {
2119     AddTestWeights("input1", {1}, {1}, tf_type_);
2120     AddTestWeights("input2", {1}, {1}, tf_type_);
2121     const string error =
2122         "Constant folding is falled back to TensorFlow, "
2123         "binary op '" +
2124         op_name + "' received both input as constant";
2125     RunValidationAndConversion(node, error::UNIMPLEMENTED, error);
2126   }
2127 };
2128 
2129 // Op converter test in FP32 mode. While for debugging purposes it might make
2130 // sense to run over all possible combinations, normally a subset of them
2131 // would be sufficient:
2132 // - All valid options to TrtTestMode (implicit, explicit, dynamic shape)
2133 // - DataType: is the TF data type of the input tensors. This usually only
2134 //   influences the data type added by Converter::AddInputTensor. We test the
2135 //   valid combinations of input data types in AddAndGetInputs, therefore
2136 //   for most of the OpConverterTest its is sufficient to test for DT_FLOAT.
2137 // - TrtPrecisionMode: valid options are FP32, FP16 and INT8. This influences
2138 //   how TRT handles the precision inside the TRT network, but should not matter
2139 //   for the TF -> TRT conversion. Therefore it should be sufficient to test
2140 //   for FP32.
2141 typedef ParameterizedOpConverterTestBase OpConverter_FP32_Test;
2142 // Base class for tests that need to be tested for both FP32 and FP16.
2143 typedef ParameterizedOpConverterTestBase OpConverter_FP32_FP16_Test;
2144 // Base class for Binary tests that need to be tested
2145 typedef OpConverter_BinaryTest<float> OpConverter_FP32_FP16_BinaryTest;
2146 typedef OpConverter_BinaryTest<int> OpConverter_BOOL_BinaryTest;
2147 // Base class for tests that need to be tested for FP32, FP16, and INT32
2148 typedef ParameterizedOpConverterTestBase OpConverter_FP32_FP16_INT32_Test;
2149 // Base class for tests that need to be tested for INT32
2150 typedef ParameterizedOpConverterTestBase OpConverter_INT32_Test;
2151 // Base class for Unary tests that need to be tested
2152 typedef OpConverter_UnaryTest<float> OpConverter_FP32_UnaryTest;
2153 typedef OpConverter_UnaryTest<int> OpConverter_BOOL_Test;
2154 
// Instantiate parameter combinations to OpConverter_<DT_X...>_Test.
// Every fixture is combined with all modes in ValidTrtModes, the TF data types
// listed for it, and TrtPrecisionMode::FP32 only (see the rationale above the
// fixture aliases).
INSTANTIATE_TEST_CASE_P(
    OpConvTestInstantiation, OpConverter_FP32_Test,
    ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                       ::testing::Values(DT_FLOAT),
                       ::testing::Values(TrtPrecisionMode::FP32)));

INSTANTIATE_TEST_CASE_P(
    OpConvTestInstantiation, OpConverter_FP32_FP16_Test,
    ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                       ::testing::Values(DT_FLOAT, DT_HALF),
                       ::testing::Values(TrtPrecisionMode::FP32)));

INSTANTIATE_TEST_CASE_P(
    OpConvTestInstantiation, OpConverter_FP32_FP16_INT32_Test,
    ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                       ::testing::Values(DT_FLOAT, DT_HALF, DT_INT32),
                       ::testing::Values(TrtPrecisionMode::FP32)));

INSTANTIATE_TEST_CASE_P(
    OpConvTestInstantiation, OpConverter_INT32_Test,
    ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                       ::testing::Values(DT_INT32),
                       ::testing::Values(TrtPrecisionMode::FP32)));

// Unary op tests with float-valued input data.
INSTANTIATE_TEST_CASE_P(
    OpConvTestInstantiation, OpConverter_FP32_UnaryTest,
    ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                       ::testing::Values(DT_FLOAT),
                       ::testing::Values(TrtPrecisionMode::FP32)));

// Unary op tests on DT_BOOL tensors (the fixture holds int-valued data).
INSTANTIATE_TEST_CASE_P(
    OpConvTestInstantiation, OpConverter_BOOL_Test,
    ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                       ::testing::Values(DT_BOOL),
                       ::testing::Values(TrtPrecisionMode::FP32)));

// Binary op tests with float-valued data, run for both FP32 and FP16 inputs.
INSTANTIATE_TEST_CASE_P(
    OpConvTestInstantiation, OpConverter_FP32_FP16_BinaryTest,
    ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                       ::testing::Values(DT_FLOAT, DT_HALF),
                       ::testing::Values(TrtPrecisionMode::FP32)));

// Binary op tests on DT_BOOL tensors (the fixture holds int-valued data).
INSTANTIATE_TEST_CASE_P(
    OpConvTestInstantiation, OpConverter_BOOL_BinaryTest,
    ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
                       ::testing::Values(DT_BOOL),
                       ::testing::Values(TrtPrecisionMode::FP32)));
2203 
2204 template <typename T>
CopyTensorElements(const Tensor & tensor,protobuf::RepeatedField<T> * out)2205 void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
2206   out->Clear();
2207   if (tensor.NumElements() == 0) return;
2208 
2209   // TensorProto does not need to have all the elements present and can truncate
2210   // trailing elements with the same value for compressed representation. Such
2211   // elements are derived based on the tensor shape.
2212   const auto flat = tensor.flat<T>();
2213   int64 last_index = 0;
2214   for (int64 i = 0; i < tensor.NumElements(); ++i) {
2215     if (flat(i) != flat(last_index)) {
2216       last_index = i;
2217     }
2218   }
2219 
2220   int num_out_elements = last_index + 1;
2221   out->Reserve(num_out_elements);
2222   out->AddNAlreadyReserved(num_out_elements);
2223   const T* src = flat.data();
2224   T* dst = out->mutable_data();
2225   std::copy(src, src + num_out_elements, dst);
2226 }
2227 
2228 template <DataType dtype, typename CType>
TestConvertVariableV2(VariableOpConverterTest * test)2229 void TestConvertVariableV2(VariableOpConverterTest* test) {
2230   struct TestParam {
2231     string container;
2232     string shared_name;
2233     std::vector<int> dims;
2234     float epsilon;
2235     Status conversion_status;
2236   };
2237 
2238   std::vector<TestParam> test_param = {
2239       {"", "var0", {}, 0.001, Status::OK()},
2240       {"", "var0", {64}, 0.001, Status::OK()},
2241       {"", "var0", {8, 16}, 0.001, Status::OK()},
2242       {"box", "var", {8, 16}, 0.001, Status::OK()}};
2243   for (auto p : test_param) {
2244     // Create node definition.
2245     NodeDef node_def;
2246     std::vector<int64_t> dims_64(p.dims.begin(), p.dims.end());
2247     TensorShape shape = TensorShape(absl::Span<int64_t>(dims_64));
2248     TF_CHECK_OK(NodeDefBuilder("my_var", "VariableV2")
2249                     .Attr("dtype", dtype)
2250                     .Attr("shape", shape)
2251                     .Attr("container", p.container)
2252                     .Attr("shared_name", p.shared_name)
2253                     .Finalize(&node_def));
2254 
2255     OpKernel* kernel;
2256     OpKernelContext* context;
2257     test->CreateContext(node_def, &kernel, &context);
2258 
2259     test->Reset(TrtPrecisionMode::FP32, TrtTestMode::kDynamicShape);
2260 
2261     // Set the value of the variable according to p.dims.
2262     int var_size = std::accumulate(p.dims.begin(), p.dims.end(), 1,
2263                                    std::multiplies<int>());
2264     std::vector<CType> expected_value;
2265     expected_value.reserve(var_size);
2266     for (int i = 0; i < var_size; i++) {
2267       expected_value.push_back((CType)i);
2268     }
2269 
2270     // To set the variable, we get the tensor by executing the VariableV2 op
2271     // rather than creating the resource directly in the manager, because:
2272     // 1) LegacyVar defined in `variable_ops.cc` is not accessible.
2273     // 2) Tensor::set_shape is private, VariableOp is a friend class.
2274     kernel->Compute(context);
2275     Tensor* tensor_ptr = context->mutable_output(0);
2276     CHECK_NOTNULL(tensor_ptr);
2277     // We allocate the tensor in the temporary memory. Note that creating a
2278     // tensor in this scope and sharing the underlying storage by copy would
2279     // lead to double destruction.
2280     AllocatorAttributes attr;
2281     attr.set_gpu_compatible(true);
2282     attr.set_nic_compatible(true);
2283     OP_REQUIRES_OK(context,
2284                    context->allocate_temp(dtype, shape, tensor_ptr, attr));
2285     // The tensor is allocated on GPU. We copy the values from the CPU.
2286     auto tensor_flat = tensor_ptr->flat<CType>();
2287     CHECK_NOTNULL(tensor_flat.data());
2288     auto ret = cudaMemcpy(tensor_flat.data(), expected_value.data(),
2289                           expected_value.size() * sizeof(CType),
2290                           cudaMemcpyHostToDevice);
2291     CHECK_EQ(ret, 0);
2292 
2293     test->RunValidationAndConversion(node_def);
2294     TRT_TensorOrWeights output;
2295     TF_EXPECT_OK(test->GetTensorOrWeights("my_var", &output));
2296     EXPECT_THAT(output.weights(),
2297                 ShapedWeightsHasDimsAndValues<CType>(p.dims, expected_value));
2298   }
2299 }
2300 
TEST_F(VariableOpConverterTest,ConvertVariableV2)2301 TEST_F(VariableOpConverterTest, ConvertVariableV2) {
2302   TestConvertVariableV2<DT_FLOAT, float>(this);
2303   TestConvertVariableV2<DT_HALF, Eigen::half>(this);
2304 }
2305 
2306 template <DataType dtype, typename CType>
TestConvertReadVariableOp(VariableOpConverterTest * test)2307 void TestConvertReadVariableOp(VariableOpConverterTest* test) {
2308   struct TestParam {
2309     string container;
2310     string name;
2311     std::vector<int> dims;
2312     float epsilon;
2313     Status conversion_status;
2314   };
2315 
2316   std::vector<TestParam> test_param = {
2317       {"", "var0", {}, 0.001, Status::OK()},
2318       {"", "var0", {64}, 0.001, Status::OK()},
2319       {"", "var0", {8, 16}, 0.001, Status::OK()},
2320       {"box", "var", {8, 16}, 0.001, Status::OK()}};
2321   for (auto p : test_param) {
2322     // Create node definition.
2323     NodeDefBuilder::NodeOut rvo_input =
2324         NodeDefBuilder::NodeOut("my_handle", 0, DT_RESOURCE);
2325     NodeDef node_def;
2326     std::vector<int64_t> dims_64(p.dims.begin(), p.dims.end());
2327     TensorShape shape =
2328         TensorShape(gtl::ArraySlice<int64_t>(dims_64));  // non-absl ok
2329     TF_CHECK_OK(NodeDefBuilder("my_var", "ReadVariableOp")
2330                     .Attr("dtype", dtype)
2331                     .Attr("_shape", shape)
2332                     .Input(rvo_input)
2333                     .Finalize(&node_def));
2334 
2335     OpKernel* kernel;
2336     OpKernelContext* context;
2337     test->CreateContext(node_def, &kernel, &context);
2338 
2339     test->Reset(TrtPrecisionMode::FP32, TrtTestMode::kDynamicShape);
2340 
2341     // Set the value of the variable according to p.dims.
2342     int var_size = std::accumulate(p.dims.begin(), p.dims.end(), 1,
2343                                    std::multiplies<int>());
2344     std::vector<CType> expected_value;
2345     expected_value.reserve(var_size);
2346     for (int i = 0; i < var_size; i++) {
2347       // Set expected_value[i] = (cast)i.
2348       expected_value.push_back((CType)i);
2349     }
2350 
2351     // Create a resource handle.
2352     DtypeAndPartialTensorShape dtype_and_shape;
2353     dtype_and_shape.dtype = dtype;
2354     TF_CHECK_OK(PartialTensorShape::BuildPartialTensorShape(
2355         gtl::ArraySlice<int64_t>(dims_64),  // non-absl ok
2356         &dtype_and_shape.shape));
2357     ResourceHandle handle = MakeResourceHandle<Var>(
2358         context, p.container, p.name,
2359         std::vector<DtypeAndPartialTensorShape>{dtype_and_shape});
2360 
2361     // Create input resource with the handle.
2362     test->AddTestResource("my_handle", handle);
2363 
2364     // Create a resource with this handle.
2365     Var* resource = new Var(dtype);
2366     TF_EXPECT_OK(CreateResource(context, handle, resource));
2367 
2368     // Setup the tensor of the variable.
2369     // We allocate the tensor in the temporary memory. Note that creating a
2370     // tensor in this scope and sharing the underlying storage by copy would
2371     // lead to double destruction.
2372     AllocatorAttributes attr_value;
2373     attr_value.set_gpu_compatible(true);
2374     attr_value.set_nic_compatible(true);
2375     TF_EXPECT_OK(
2376         context->allocate_temp(dtype, shape, resource->tensor(), attr_value));
2377     // The tensor is allocated on GPU. We copy the values from the CPU.
2378     auto tensor_flat = resource->tensor()->flat<CType>();
2379     CHECK(tensor_flat.data());
2380     auto ret = cudaMemcpy(tensor_flat.data(), expected_value.data(),
2381                           expected_value.size() * sizeof(CType),
2382                           cudaMemcpyHostToDevice);
2383     CHECK_EQ(ret, 0);
2384 
2385     test->RunValidationAndConversion(node_def);
2386     TRT_TensorOrWeights output;
2387     TF_EXPECT_OK(test->GetTensorOrWeights("my_var", &output));
2388     EXPECT_THAT(output.weights(),
2389                 ShapedWeightsHasDimsAndValues<CType>(p.dims, expected_value));
2390   }
2391 }
2392 
TEST_F(VariableOpConverterTest,ConvertReadVariableOp)2393 TEST_F(VariableOpConverterTest, ConvertReadVariableOp) {
2394   TestConvertReadVariableOp<DT_FLOAT, float>(this);
2395   TestConvertReadVariableOp<DT_HALF, Eigen::half>(this);
2396 }
2397 
2398 template <DataType dtype, typename InputCType, typename OutputCType>
TestConvertConst(OpConverterTest * test)2399 void TestConvertConst(OpConverterTest* test) {
2400   NodeDef node_def;
2401   node_def.set_name("my_const");
2402   node_def.set_op("Const");
2403 
2404   auto reset_and_test = [&node_def, test](
2405                             const Tensor& tensor, const bool as_tensor_content,
2406                             const std::vector<int>& expected_dims,
2407                             const std::vector<OutputCType>& expected_value) {
2408     test->Reset();
2409 
2410     TensorProto* tensor_attr =
2411         (*node_def.mutable_attr())["value"].mutable_tensor();
2412     tensor_attr->Clear();
2413 
2414     if (as_tensor_content) {
2415       tensor.AsProtoTensorContent(tensor_attr);
2416     } else {
2417       tensor.shape().AsProto(tensor_attr->mutable_tensor_shape());
2418       tensor_attr->set_dtype(tensor.dtype());
2419 
2420       if (tensor.dtype() == DT_FLOAT) {
2421         CopyTensorElements<float>(tensor, tensor_attr->mutable_float_val());
2422       } else if (tensor.dtype() == DT_INT32) {
2423         CopyTensorElements<int32>(tensor, tensor_attr->mutable_int_val());
2424       } else {
2425         tensor.AsProtoField(tensor_attr);
2426       }
2427     }
2428     test->RunValidationAndConversion(node_def);
2429     TRT_TensorOrWeights output;
2430     TF_EXPECT_OK(test->GetTensorOrWeights("my_const", &output));
2431     EXPECT_THAT(output.weights(), ShapedWeightsHasDimsAndValues<OutputCType>(
2432                                       expected_dims, expected_value));
2433   };
2434 
2435   auto& attr = *node_def.mutable_attr();
2436   attr["dtype"].set_type(dtype);
2437   {
2438     // By default empty tensor will pick DT_FLOAT as data type and we fix it
2439     // here.
2440     Tensor t(dtype);  // Empty tensor.
2441     reset_and_test(t, false, {}, {});
2442   }
2443   {
2444     Tensor t = test::AsScalar<InputCType>(12);
2445     std::vector<int> expected_dims{1};
2446     // Scalars are represented as rank 0 tensors.
2447     expected_dims.clear();
2448     reset_and_test(t, false, expected_dims, {12});
2449     reset_and_test(t, true, expected_dims, {12});
2450   }
2451   {
2452     Tensor t = test->AsTensor<InputCType>({1, 2});
2453     reset_and_test(t, false, {2}, {1, 2});
2454     reset_and_test(t, true, {2}, {1, 2});
2455   }
2456   {
2457     Tensor t =
2458         test->AsTensor<InputCType>({1, 2, 3, 4, 5, 6}, TensorShape({2, 3}));
2459     reset_and_test(t, false, {2, 3}, {1, 2, 3, 4, 5, 6});
2460     reset_and_test(t, true, {2, 3}, {1, 2, 3, 4, 5, 6});
2461   }
2462   {
2463     // Set all tensor elements to the same value. Such tensors are encoded
2464     // using a single element list in tensor proto.
2465     Tensor t =
2466         test->AsTensor<InputCType>({1, 1, 1, 1, 1, 1}, TensorShape({2, 3}));
2467     reset_and_test(t, false, {2, 3}, {1, 1, 1, 1, 1, 1});
2468     reset_and_test(t, true, {2, 3}, {1, 1, 1, 1, 1, 1});
2469   }
2470   {
2471     // Set trailing tensor elements to the same value. Such tensors are
2472     // encoded by truncating all equal elements except the first one.
2473     Tensor t =
2474         test->AsTensor<InputCType>({2, 2, 1, 1, 1, 1}, TensorShape({2, 3}));
2475     reset_and_test(t, false, {2, 3}, {2, 2, 1, 1, 1, 1});
2476     reset_and_test(t, true, {2, 3}, {2, 2, 1, 1, 1, 1});
2477   }
2478 }
2479 
TEST_F(OpConverterTest, ConvertConst) {
  // A double constant is rejected as an unsupported data type.
  {
    Reset();
    NodeDef double_const = MakeConstNodeDef<double>("my_const", {});
    RunValidationAndConversion(double_const, error::INVALID_ARGUMENT,
                               "Unsupported tensorflow data type double");
  }
  // An int64 constant holding values outside the int32 range is rejected.
  {
    Reset();
    const int64_t kInt64Max = std::numeric_limits<int64_t>::max();
    const int64_t kInt64Lowest = std::numeric_limits<int64_t>::lowest();
    Tensor out_of_range = AsTensor<int64_t>(
        {1, kInt64Max, 1, 1, 1, kInt64Lowest}, TensorShape({2, 3}));

    NodeDef int64_const;
    int64_const.set_name("my_const");
    int64_const.set_op("Const");
    auto& attrs = *int64_const.mutable_attr();
    attrs["dtype"].set_type(DT_INT64);
    TensorProto* value_proto = attrs["value"].mutable_tensor();
    value_proto->Clear();
    out_of_range.AsProtoTensorContent(value_proto);
    RunValidationAndConversion(int64_const, error::INVALID_ARGUMENT,
                               "outside the range of int32");
  }

  // Supported element types. Float values stay float; every integer type is
  // expected to come out as int32 weights.
  TestConvertConst<DT_FLOAT, float, float>(this);
  TestConvertConst<DT_INT8, int8, int32>(this);
  TestConvertConst<DT_UINT8, uint8, int32>(this);
  TestConvertConst<DT_INT16, int16, int32>(this);
  TestConvertConst<DT_UINT16, uint16, int32>(this);
  TestConvertConst<DT_INT32, int32, int32>(this);
  TestConvertConst<DT_UINT32, uint32, int32>(this);
  TestConvertConst<DT_INT64, int64, int32>(this);
  TestConvertConst<DT_UINT64, uint64, int32>(this);
}
2515 
2516 template <typename T>
CreateFusedBatchNormOp(DataType tf_type,std::string data_format,bool is_training,float epsilon)2517 NodeDef CreateFusedBatchNormOp(DataType tf_type, std::string data_format,
2518                                bool is_training, float epsilon) {
2519   Scope s = Scope::NewRootScope();
2520   auto x = ops::Placeholder(s.WithOpName("x"), tf_type);
2521   auto scale = ops::Placeholder(s.WithOpName("scale"), tf_type);
2522   auto offset = ops::Placeholder(s.WithOpName("offset"), tf_type);
2523   auto mean = ops::Placeholder(s.WithOpName("mean"), tf_type);
2524   auto variance = ops::Placeholder(s.WithOpName("variance"), tf_type);
2525   typename T::Attrs attrs;
2526   attrs.data_format_ = data_format;
2527   attrs.is_training_ = is_training;
2528   if (epsilon > 0) {
2529     attrs.epsilon_ = epsilon;
2530   } else {
2531     EXPECT_GE(epsilon, 0);
2532   }
2533   return T(s.WithOpName("my_batchnorm"), x, scale, offset, mean, variance,
2534            attrs)
2535       .operation.node()
2536       ->def();
2537 }
2538 
TEST_P(OpConverter_FP32_Test,ConvertFusedBatchNorm)2539 TEST_P(OpConverter_FP32_Test, ConvertFusedBatchNorm) {
2540   using OpFunc = std::function<NodeDef(DataType, std::string, bool, float)>;
2541   std::vector<OpFunc> get_node_def_vec{
2542       CreateFusedBatchNormOp<ops::FusedBatchNorm>,
2543       CreateFusedBatchNormOp<ops::FusedBatchNormV2>,
2544       CreateFusedBatchNormOp<ops::FusedBatchNormV3>};
2545 
2546   struct TestParam {
2547     std::string data_format;
2548     int tensor_input_idx;  // Index of an input that will be provided as tensor.
2549     bool is_training;
2550     float epsilon;
2551     Status conversion_status;
2552     bool keep_channel_unknown;
2553   };
2554 
2555   struct NodeInput {
2556     std::string name;
2557     std::vector<int> dims;
2558     std::vector<float> val;
2559   };
2560   std::vector<NodeInput> node_input_nchw{
2561       {"x", {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}},
2562       {"scale", {3}, {7, 8, 9}},
2563       {"offset", {3}, {10, 20, 30}},
2564       {"mean", {3}, {1, 2, 3}},
2565       {"variance", {3}, {4, 5, 6}}};
2566 
2567   std::vector<NodeInput> node_input_nhwc{
2568       {"x", {2, 2, 1, 3}, {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}},
2569       {"scale", {3}, {7, 8, 9}},
2570       {"offset", {3}, {10, 20, 30}},
2571       {"mean", {3}, {1, 2, 3}},
2572       {"variance", {3}, {4, 5, 6}}};
2573 
2574   std::vector<float> expected_output_nchw{
2575       10.0,    13.495633, 23.574135, 27.148273, 37.342354, 41.013527,
2576       30.9738, 34.469433, 45.018955, 48.59309,  59.369415, 63.04059};
2577 
2578   std::vector<float> expected_output_nhwc{
2579       10.0,    23.574135, 37.342354, 13.495633, 27.148273, 41.013527,
2580       30.9738, 45.018955, 59.369415, 34.469433, 48.59309,  63.04059};
2581 
2582   for (auto get_node_def : get_node_def_vec) {
2583     NodeDef tmp_node_def = get_node_def(tf_type_, "NCHW", true, 0);
2584     std::string op_name = tmp_node_def.op();
2585     std::vector<TestParam> test_param{
2586         {"NCHW", 0, true, 0,
2587          errors::Unimplemented(
2588              StrCat(op_name, " only supports is_training=false"))},
2589         {"NCHW", 1, false, 0,
2590          errors::Unimplemented(StrCat("The input \"scale\" for ", op_name,
2591                                       " must be a constant"))},
2592         {"NCHW", 2, false, 0,
2593          errors::Unimplemented(StrCat("The input \"offset\" for ", op_name,
2594                                       " must be a constant"))},
2595         {"NCHW", 3, false, 0,
2596          errors::Unimplemented(StrCat("The input \"mean\" for ", op_name,
2597                                       " must be a constant"))},
2598         {"NCHW", 4, false, 0,
2599          errors::Unimplemented(StrCat("The input \"variance\" for ", op_name,
2600                                       " must be a constant"))},
2601         {"NCHW", 0, false, 0.01},
2602         {"NHWC", 0, false, 0.01}};
2603     if (trt_mode_ == TrtTestMode::kDynamicShape) {
2604       test_param.push_back(
2605           {"NCHW", 0, false, 0.01,
2606            errors::InvalidArgument("Channel dimension must be static"), true});
2607       test_param.push_back(
2608           {"NHWC", 0, false, 0.01,
2609            errors::InvalidArgument("Channel dimension must be static"), true});
2610     }
2611     for (auto p : test_param) {
2612       Reset();
2613       NodeDef node_def =
2614           get_node_def(tf_type_, p.data_format, p.is_training, p.epsilon);
2615       std::vector<NodeInput> node_input =
2616           p.data_format == "NCHW" ? node_input_nchw : node_input_nhwc;
2617       std::vector<float> expected_output =
2618           p.data_format == "NCHW" ? expected_output_nchw : expected_output_nhwc;
2619       for (int i = 0; i < node_input.size(); i++) {
2620         if (i == 0 || i == p.tensor_input_idx) {
2621           // The first input (x) is always added as a tensor, and it has shape
2622           // NCHW/NHWC. The other inputs are per channel values (1D, size C).
2623           //
2624           // In implicit batch mode, it is not possible to add any of the 1D
2625           // inputs as a tensor: the first dim is always treated as batch dim in
2626           // implicit batch mode, and that has to agree for all tensors. We have
2627           // two input tensors with shapes NCHW and C and in general N != C.
2628           // The converter already picked up N from the fist input, and reports
2629           // an error when we try to add any other tensors with not matching
2630           // first dim.
2631           //
2632           // This restriction does not apply in explicit batch mode: the tensors
2633           // can have different first dim. The converter still expects that only
2634           // the first arg is a tensor. TODO(tfeher) Check if one can relax this
2635           // restriction.
2636           Status expected_status =
2637               (i != 0 && trt_mode_ == TrtTestMode::kImplicitBatch)
2638                   ? errors::InvalidArgument(
2639                         StrCat("Batch size doesn't match for tensor ",
2640                                node_input[i].name,
2641                                ": Provided batch size does not match "
2642                                "converter batch size: 3 vs 2"))
2643                   : Status::OK();
2644           std::vector<int> partial_input_shape;
2645           if (i == 0 && trt_mode_ == TrtTestMode::kDynamicShape &&
2646               !p.keep_channel_unknown) {
2647             // keep channel dim static (known)
2648             partial_input_shape.resize(4, -1);
2649             int channel_dim = (p.data_format == "NCHW" ? 1 : 3);
2650             partial_input_shape[channel_dim] = node_input[i].dims[channel_dim];
2651           }
2652           AddTestTensor(node_input[i].name, node_input[i].dims, tf_type_,
2653                         node_input[i].val, partial_input_shape,
2654                         expected_status);
2655 
2656         } else {
2657           AddTestWeights(node_input[i].name, node_input[i].dims,
2658                          node_input[i].val, tf_type_);
2659         }
2660       }
2661       TestOpConverter("my_batchnorm", node_def, node_input[0].dims,
2662                       p.conversion_status, Status::OK(),
2663                       ArrayFloatNear(expected_output));
2664     }
2665   }
2666 }
2667 
TEST_P(OpConverter_FP32_Test,ConvertTranspose)2668 TEST_P(OpConverter_FP32_Test, ConvertTranspose) {
2669   // Get the NodeDef for Transpose.
2670   Scope s = Scope::NewRootScope();
2671   auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
2672   auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
2673   auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights);
2674   const NodeDef& node_def = transpose.operation.node()->def();
2675 
2676   std::vector<TestParamBase> test_params = {
2677       // For the first test we leave param empty. This signals to use a
2678       // input as weight which will be invalid
2679       TestParamBase{{3, 1, 2, 1},
2680                     {},
2681                     {},
2682                     {},
2683                     Status(error::UNIMPLEMENTED,
2684                            "The input \"perm\" for Transpose must be a "
2685                            "constant")},
2686       TestParamBase{{1, 1, 2, 3},
2687                     {},
2688                     {},
2689                     {0, 1, 2},
2690                     Status(error::INVALID_ARGUMENT,
2691                            "Rank of perm for transpose does not match with "
2692                            "that of the input.")},
2693       // Transpose batch dim
2694       TestParamBase{
2695           {1, 1, 2, 3},
2696           {},
2697           {3, 2, 1, 1},
2698           {3, 2, 1, 0},
2699           (trt_mode_ == TrtTestMode::kImplicitBatch)
2700               ? Status(error::UNIMPLEMENTED,
2701                        "Transpose at batch dimension is not supported")
2702               : Status::OK()},
2703       TestParamBase{{1, 1, 2, 3}, {}, {1, 3, 1, 2}, {0, 3, 1, 2}},
2704   };
2705   if (trt_mode_ == TrtTestMode::kDynamicShape) {
2706     // Dynamic shape tests where some shapes are known
2707     test_params.push_back(TestParamBase{
2708         {1, 1, 2, 3}, {-1, 1, 2, -1}, {1, 3, 1, 2}, {0, 3, 1, 2}});
2709   }
2710   std::vector<float> expected_values{1, 4, 2, 5, 3, 6};
2711   for (auto p : test_params) {
2712     SCOPED_TRACE(p);
2713     Reset();
2714     AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6},
2715                   p.partial_input_dims);
2716     if (p.param.empty()) {
2717       AddTestTensor("weights", {3});
2718     } else {
2719       AddTestWeights<int32>("weights", {static_cast<int>(p.param.size())},
2720                             p.param);
2721     }
2722     TestOpConverter("my_transpose", node_def, p.expected_output_dims, p.status,
2723                     p.runtime_status, ElementsAreArray(expected_values));
2724   }
2725 }
2726 
TEST_P(OpConverter_FP32_Test,ConvertTile)2727 TEST_P(OpConverter_FP32_Test, ConvertTile) {
2728   Scope s = Scope::NewRootScope();
2729   auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
2730   auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
2731   auto tile = ops::Tile(s.WithOpName("my_tile"), input, weights);
2732   const NodeDef& node_def = tile.operation.node()->def();
2733 
2734   struct TileParam {
2735     std::vector<int> input_dims;
2736     std::vector<int> multiplier;
2737     std::vector<float> tensor;
2738     // Concrete (static) output dimensions, including batch size as first dim.
2739     std::vector<int> expected_output_dims;
2740     std::vector<int> expected_results;
2741     int test_ID;
2742     // Expected status of conversion (with concrete error message).
2743     Status status;
2744   };
2745 
2746   std::vector<TileParam> test_params = {
2747       // Tests to be rejected by ConvertTile::Validate() for any trt_mode_.
2748       TileParam{{1, 2, 3},   // input_dims
2749                 {1, -2, 1},  // multiplier
2750                 {},          // tensor
2751                 {},          // expected_output_dims
2752                 {},          // expected_results
2753                 1,           // test_ID
2754                 Status(error::INVALID_ARGUMENT,
2755                        "All replications of the Tile operation in "
2756                        "'my_tile' should be positive, got (1, -2, 1).")},
2757       TileParam{{1, 2, 3},           // input_dims
2758                 {1, 2, 1, 3},        // multiplier
2759                 {0, 1, 2, 3, 4, 5},  // tensor
2760                 {},                  // expected_output_dims
2761                 {},                  // expected_results
2762                 2,                   // test_ID
2763                 Status(error::INVALID_ARGUMENT,
2764                        "The length of the replication vector (4) of the "
2765                        "Tile operation in 'my_tile' is expected to be equal "
2766                        "to the rank of the input vector (3).")},
2767       // Tests passed ConvertTile::Validate() for at least some trt_mode_.
2768       TileParam{{1, 2},                                 // input_dims
2769                 {1, 3},                                 // multiplier
2770                 {2, 3},                                 // tensor
2771                 {1, 6},                                 // expected_output_dims
2772                 {2, 3, 2, 3, 2, 3}},                    // out values
2773       TileParam{{1, 2, 3},                              // input_dims
2774                 {1, 2, 1},                              // multiplier
2775                 {0, 1, 2, 3, 4, 5},                     // tensor
2776                 {1, 4, 3},                              // output dims
2777                 {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}},  // expected_results
2778       TileParam{{1, 2, 3},                              // input_dims
2779                 {1, 1, 2},                              // multiplier
2780                 {0, 1, 2, 3, 4, 5},                     // tensor
2781                 {1, 2, 6},                              // expected_output_dims
2782                 {0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5}},  // expected_results
2783       TileParam{{1, 2, 3},                              // input_dims
2784                 {1, 2, 2},                              // multiplier
2785                 {0, 1, 2, 3, 4, 5},                     // tensor
2786                 {1, 4, 6},                              // expected_output_dims
2787                 {0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5,
2788                  0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5}},  // expected_results
2789       // Tests with non trivial batch size multiplier.
2790       TileParam{{1, 2},                                 // input_dims
2791                 {2, 3},                                 // multiplier
2792                 {2, 3},                                 // tensor
2793                 {2, 6},                                 // expected_output_dims
2794                 {2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3}},  // out values
2795       TileParam{{1, 2, 3},                              // input_dims
2796                 {2, 2, 1},                              // multiplier
2797                 {0, 1, 2, 3, 4, 5},                     // tensor
2798                 {2, 4, 3},                              // output dims
2799                 {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5,
2800                  0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}},  // expected_results
2801   };
2802 
2803   for (bool multiplier_is_tensor : {true, false}) {
2804     for (bool input_is_tensor : {true, false}) {
2805       for (auto p : test_params) {
2806         std::vector<int> num_mults = {static_cast<int>(p.multiplier.size())};
2807         std::vector<int> partial_input_dims = {};
2808         if (multiplier_is_tensor) {
2809           if (trt_mode_ == TrtTestMode::kImplicitBatch) {
2810             p.status =
2811                 Status(error::INVALID_ARGUMENT,
2812                        "Conversion for Tile is not implemented for multipliers "
2813                        "passed as a tensor in implicit batch mode");
2814             num_mults = {1, static_cast<int>(p.multiplier.size())};
2815           } else {
2816             if (p.test_ID == 1) {
2817               // Skip this test because in that situation it is impossible
2818               // to do a valid check for negative multipliers.
2819               continue;
2820             }
2821 
2822             if (trt_mode_ == TrtTestMode::kDynamicShape) {
2823               partial_input_dims = num_mults;
2824               p.status = Status::OK();
2825             }
2826 
2827             if (p.test_ID == 2) {
2828               p.status = Status(error::INVALID_ARGUMENT,
2829                                 "When replications are defined as a tensor, "
2830                                 "the number of its elements (4) must be equal "
2831                                 "to the rank of the input tensor (3).");
2832             }
2833           }
2834         } else {
2835           if (trt_mode_ == TrtTestMode::kImplicitBatch && p.multiplier[0] > 1) {
2836             p.status =
2837                 Status(error::UNIMPLEMENTED,
2838                        "The Tile operation along "
2839                        "the batch dimension in 'my_tile' is not implemented.");
2840           }
2841         }
2842 
2843         Reset();
2844         if (input_is_tensor) {
2845           AddTestTensor("input", p.input_dims, p.tensor);
2846         } else {
2847           AddTestWeights("input", p.input_dims, p.tensor, tf_type_);
2848         }
2849 
2850         if (multiplier_is_tensor) {
2851           AddTestTensor<int>("weights", num_mults, DT_INT32, p.multiplier,
2852                              partial_input_dims);
2853         } else {
2854           AddTestWeights<int32>("weights", num_mults, p.multiplier);
2855         }
2856 
2857         TestOpConverter("my_tile", node_def, p.expected_output_dims, p.status,
2858                         Status::OK(), ElementsAreArray(p.expected_results));
2859       }
2860     }
2861   }
2862 }
2863 
TEST_P(OpConverter_FP32_Test,ConvertReshape)2864 TEST_P(OpConverter_FP32_Test, ConvertReshape) {
2865   // Get the NodeDef for Reshape.
2866   Scope s = Scope::NewRootScope();
2867   auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
2868   auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
2869   auto reshape = ops::Reshape(s.WithOpName("my_reshape"), input, weights);
2870   const NodeDef& node_def = reshape.operation.node()->def();
2871 
2872   if (trt_mode_ == TrtTestMode::kImplicitBatch) {
2873     // Shape is a tensor, should fail in implicit batch mode.
2874     Reset();
2875     AddTestTensor("input", {3, 2, 1});
2876     AddTestTensor("weights", {3});
2877     RunValidationAndConversion(
2878         node_def, error::INVALID_ARGUMENT,
2879         "The input \"shape\" for Reshape must be a constant in implicit batch "
2880         "mode");
2881   } else if (!IS_TRT_VERSION_GE(7, 1, 3, 0)) {
2882     // Shape is a tensor, should fail before TRT 7.1.3 even in explicit batch /
2883     // dynamic shape mode.
2884     Reset();
2885     AddTestTensor("input", {3, 2, 1});
2886     AddTestTensor("weights", {3});
2887     RunValidationAndConversion(
2888         node_def, error::INVALID_ARGUMENT,
2889         "Non constant shape input tensor for Reshape requires minimum TRT "
2890         "7.1.3");
2891   }
2892 
2893   Status reshape_from_scalar_status =
2894       trt_mode_ == TrtTestMode::kImplicitBatch
2895           ? errors::Internal(
2896                 "Failed to convert at least one input to a TRT_TensorOrWeights:"
2897                 " Scalar input tensor is not supported since the first "
2898                 "dimension is treated as batch dimension by TRT")
2899           : Status::OK();
2900   Status add_scalar_tensor_status =
2901       trt_mode_ == TrtTestMode::kImplicitBatch
2902           ? errors::InvalidArgument(
2903                 "removing first dim requires explicit batch dimension")
2904           : Status::OK();
2905   Status reshape_to_scalar_status =
2906       trt_mode_ == TrtTestMode::kImplicitBatch
2907           ? errors::Unimplemented("Reshape to shape=[] is not supported")
2908           : Status::OK();
2909   Status reshape_batch_status =
2910       trt_mode_ == TrtTestMode::kImplicitBatch
2911           ? errors::Unimplemented("Reshape on batch dimension is not supported")
2912           : Status::OK();
2913 
2914   struct TestParams {
2915     std::vector<int> tensor_dims;
2916     std::vector<int> shape;
2917     std::vector<int> expected_shape;
2918     Status conversion_status;
2919     Status runtime_status;
2920     std::vector<int> shape_prof;  // needed concrete values if shape == -1.
2921     Status add_test_tensor_status;
2922   };
2923 
2924   std::vector<TestParams> params = {
2925       // Reshape scalar to tensor, should fail in implicit batch mode.
2926       TestParams{{},
2927                  {1, 1},
2928                  {},
2929                  reshape_from_scalar_status,
2930                  {},
2931                  {},
2932                  add_scalar_tensor_status},
2933       // Reshape tensor to scalar, should fail in implicit batch mode.
2934       // - In explicit batch mode if shape is set as weight it works.
2935       // - In explicit batch mode && using shape as tensor input it should
2936       //   fail. In that case we set the expected conversion status in the
2937       //   test loop.
2938       TestParams{{1, 1}, {}, {}, reshape_to_scalar_status},
2939       // Reshape at batch dimension, should fail in implicit batch mode.
2940       TestParams{{1, 1, 2, 3}, {3, 1, 1, 2}, {}, reshape_batch_status},
2941       TestParams{{2, 1, 2, 3}, {-1, 1, 4}, {3, 1, 4}, reshape_batch_status},
2942       // Tests that should succeed in every trt_mode.
2943       TestParams{{1, 1, 2, 3}, {-1, 1, 3, 2}, {1, 1, 3, 2}},
2944       TestParams{{1, 1, 2, 3}, {1, 1, -1}, {1, 1, 6}},
2945       TestParams{{1, 1, 2, 3}, {1, 1, 3, 2}},
2946       TestParams{{2, 1, 2, 3}, {2, 1, 3, 2}},
2947       TestParams{{1, 1, 1}, {1}},
2948       TestParams{{1}, {1, 1}},
2949       TestParams{{2, 1, 1}, {2}},
2950       TestParams{{2}, {2, 1}},
2951   };
2952   if (trt_mode_ == TrtTestMode::kImplicitBatch) {
2953     // Reshape tensor with zero rank using an empty shape tensor, should fail in
2954     // implicit batch mode. In explicit batch mode this is an identity operation
2955     // and does not add a reshape layer therefore we do not test it.
2956     params.push_back(TestParams{{},
2957                                 {},
2958                                 {},
2959                                 reshape_from_scalar_status,
2960                                 {},
2961                                 {},
2962                                 add_scalar_tensor_status});
2963   }
2964   // Testing the methods for representing the reshape shape for IShuffleLayer:
2965   // as a weight (true) or as a tensor (false).
2966   std::vector<bool> shape_input_options(1, true);
2967 
2968   if (trt_mode_ != TrtTestMode::kImplicitBatch &&
2969       IS_TRT_VERSION_GE(7, 1, 3, 0)) {
2970     shape_input_options.push_back(false);
2971   }
2972 
2973   for (auto p : params) {
2974     for (auto shape_as_weight : shape_input_options) {
2975       std::ostringstream oss;
2976       oss << "shape " << PrintToString(p.shape);
2977       SCOPED_TRACE(StrCat(oss.str(), shape_as_weight ? " weight" : " tensor"));
2978       if (!shape_as_weight && p.shape.empty()) {
2979         p.conversion_status = errors::Unimplemented(
2980             "Reshape with dynamic input requires 1D input tensor");
2981       }
2982       Reset();
2983       const int n_elements =
2984           std::accumulate(p.tensor_dims.begin(), p.tensor_dims.end(), 1,
2985                           std::multiplies<int>());
2986       std::vector<float> input_vec(n_elements);
2987       std::iota(input_vec.begin(), input_vec.end(), 1);
2988       AddTestTensor("input", p.tensor_dims, tf_type_, input_vec, {},
2989                     p.add_test_tensor_status);
2990       if (shape_as_weight) {
2991         AddTestWeights<int32>("weights", {static_cast<int>(p.shape.size())},
2992                               p.shape);
2993       } else {
2994         std::vector<int32> dims;
2995         std::vector<int32> values{p.shape};
2996         if (!p.shape.empty()) {
2997           dims.push_back(p.shape.size());
2998         } else {
2999           // If the shape is empty we use a dummy value to ensure that
3000           // AddTestTensor creates the corresponding entry in InputOutputData.
3001           values.push_back(1);
3002         }
3003         AddTestTensor("weights", dims, DT_INT32, values, dims);
3004       }
3005       std::vector<int> expected_shape =
3006           p.expected_shape.empty() ? p.shape : p.expected_shape;
3007       VLOG(2) << "Calling TestOpConverter";
3008       TestOpConverter("my_reshape", node_def, expected_shape,
3009                       p.conversion_status, p.runtime_status,
3010                       ElementsAreArray(input_vec));
3011     }
3012   }
3013 }
3014 
TEST_P(OpConverter_FP32_Test,ConvertShape)3015 TEST_P(OpConverter_FP32_Test, ConvertShape) {
3016   // Get the NodeDef for Shape op.
3017   Scope s = Scope::NewRootScope();
3018   auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
3019   auto shape = ops::Shape(s.WithOpName("my_shape"), input);
3020   const NodeDef& node_def = shape.operation.node()->def();
3021 
3022   Status conversion_status =
3023       (trt_mode_ == TrtTestMode::kImplicitBatch)
3024           ? errors::Unimplemented(
3025                 "Shape is only supported for explicit batch mode.")
3026           : Status::OK();
3027   std::vector<TestParamBase> test_params = {
3028 // TODO(b/166274212): Enable the test parameter for TensorRT 7.1.3.
3029 #if !IS_TRT_VERSION_GE(7, 1, 3, 0)
3030     TestParamBase{{1, 2, 3}, {}, {3}, {}, conversion_status},
3031 #endif
3032     // Add input as weight (we use non empty param ({1}) to trigger this).
3033     TestParamBase{{1, 2, 3}, {}, {3}, {1}, conversion_status},
3034   };
3035 
3036   auto input_is_weight = [](const TestParamBase p) { return !p.param.empty(); };
3037   for (auto p : test_params) {
3038     SCOPED_TRACE(p);
3039     Reset();
3040     // The number of elements of the input tensor. We leave it 0 in case we do
3041     // not need to add an input tensor. This happens in explicit batch mode: the
3042     // shape is known at conversion time and therefore the shape is added to the
3043     // network as a constant layer. In this case the single node network that
3044     // we use for the unit test have no actual input tensor when it is converted
3045     // to a TensorRT network.
3046     int n_elements = 0;
3047     if (input_is_weight(p) || trt_mode_ != TrtTestMode::kExplicitBatch) {
3048       // Calculate the number of elements for adding input data.
3049       n_elements = std::accumulate(p.input_dims.begin(), p.input_dims.end(), 1,
3050                                    std::multiplies<int>());
3051     }
3052     std::vector<float> input_val(n_elements, 1);
3053     if (!input_is_weight(p)) {
3054       AddTestTensor("input", p.input_dims, input_val);
3055     } else {
3056       AddTestWeights("input", p.input_dims, input_val, tf_type_);
3057     }
3058     TestOpConverter("my_shape", node_def, p.expected_output_dims, p.status,
3059                     p.runtime_status, ElementsAreArray(p.input_dims),
3060                     {DT_INT32});
3061   }
3062 }
3063 
// One MatMul / BatchMatMul test case: the two operands (shape, row-major
// values, and whether the operand is transposed) plus the expected result.
// The transpose flags default to false so that a default-constructed instance
// never holds indeterminate values.
struct MatMulTestParams {
  std::vector<int> shape_a;
  std::vector<int> values_a;
  bool transpose_a = false;
  std::vector<int> shape_b;
  std::vector<int> values_b;
  bool transpose_b = false;
  std::vector<int> expected_shape;
  std::vector<int> expected_output;
};
3074 
3075 // Helper function for testing MatMul and BatchMatMul. get_matmul is a function
3076 // used to generate the node. It accepts (DataType, transpose_a, transpose_b) as
3077 // parameters.
TestMatMulHelper(ParameterizedOpConverterTestBase * test,const std::function<NodeDef (DataType,bool,bool)> & get_matmul,const std::vector<MatMulTestParams> & params)3078 void TestMatMulHelper(
3079     ParameterizedOpConverterTestBase* test,
3080     const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
3081     const std::vector<MatMulTestParams>& params) {
3082   {
3083     // Unsupported data type.
3084     test->Reset();
3085     NodeDef node_def = get_matmul(DT_INT32, false, false);
3086     test->AddTestTensor("input", {1, 2}, DT_INT32, {});
3087     test->AddTestWeights<int32>("weights", {2, 1}, {3, 5});
3088     test->RunValidationAndConversion(
3089         node_def, error::UNIMPLEMENTED,
3090         StrCat("Data type int32 is not supported for ", node_def.op(),
3091                ", must be one of [float, half]"));
3092   }
3093 
3094   // FC conversion depends on whether the last dim of A is known or not. In
3095   // Dynamic shape mode, we will check whether A is handled correctly if it has
3096   // a partially known input shape (last dim known).
3097   std::vector<bool> a_test_partial_shape_values{false};
3098   if (test->get_trt_mode() == TrtTestMode::kDynamicShape) {
3099     a_test_partial_shape_values.push_back(true);
3100   }
3101 
3102   for (auto p : params) {
3103     for (bool a_is_tensor : {true, false}) {
3104       for (bool b_is_tensor : {true, false}) {
3105         for (bool a_partial_shape : a_test_partial_shape_values) {
3106           if (a_partial_shape && !a_is_tensor) {
3107             // Only tensors can have partial shape.
3108             continue;
3109           }
3110           if (!a_is_tensor && !b_is_tensor) {
3111             // Skip test when both args are weights. We do not convert this
3112             // since const folding eliminates this case.
3113             continue;
3114           }
3115           SCOPED_TRACE(StrCat("A", p.transpose_a ? ".T" : "", " is ",
3116                               a_is_tensor ? "tensor" : "weight", ", B",
3117                               p.transpose_b ? ".T" : "", " is ",
3118                               b_is_tensor ? "tensor " : "weight, rank A ",
3119                               p.shape_a.size(), ", rank B ", p.shape_b.size()));
3120           test->Reset();
3121 
3122           NodeDef node_def =
3123               get_matmul(test->get_tf_type(), p.transpose_a, p.transpose_b);
3124           const bool is_batch_matmul = node_def.op() == "BatchMatMul";
3125 
3126           if (a_is_tensor) {
3127             if (a_partial_shape) {
3128               // Prepare a partial shape for A where only the last dim is known.
3129               std::vector<int> partial_shape(p.shape_a.size(), -1);
3130               int k = p.shape_a.size() - 1;
3131               partial_shape.at(k) = p.shape_a.at(k);
3132               test->AddTestTensor("input", p.shape_a, test->get_tf_type(),
3133                                   p.values_a, partial_shape);
3134             } else {
3135               test->AddTestTensor("input", p.shape_a, p.values_a);
3136             }
3137           } else {
3138             test->AddTestWeights("input", p.shape_a, p.values_a,
3139                                  test->get_tf_type());
3140           }
3141           if (b_is_tensor) {
3142             if (a_is_tensor && p.shape_a[0] != p.shape_b[0] &&
3143                 test->get_trt_mode() == TrtTestMode::kImplicitBatch) {
3144               VLOG(2) << "Skipping test with inpcompatible batch dimensions";
3145               continue;
3146             }
3147             test->AddTestTensor("weights", p.shape_b, p.values_b);
3148           } else {
3149             test->AddTestWeights("weights", p.shape_b, p.values_b,
3150                                  test->get_tf_type());
3151           }
3152 
3153           Status conversion_status = Status::OK();
3154           if (test->get_trt_mode() == TrtTestMode::kImplicitBatch) {
3155             // Implicit batch mode has several restriction. We change conversion
3156             // status accordingly.
3157             if (is_batch_matmul) {
3158               if (a_is_tensor && p.shape_a.size() < p.shape_b.size()) {
3159                 conversion_status = errors::InvalidArgument(
3160                     "Broadcasting beyond batch dimension is not supported "
3161                     "(tensor #dims ",
3162                     p.shape_a.size(), " vs broadcast #dims ", p.shape_b.size(),
3163                     ")");
3164               }
3165               if (b_is_tensor && p.shape_b.size() < p.shape_a.size()) {
3166                 conversion_status = errors::InvalidArgument(
3167                     "Broadcasting beyond batch dimension is not supported "
3168                     "(tensor #dims ",
3169                     p.shape_b.size(), " vs broadcast #dims ", p.shape_a.size(),
3170                     ")");
3171               }
3172               if ((!a_is_tensor || !b_is_tensor) && p.shape_a[0] != 1) {
3173                 conversion_status = errors::Unimplemented(
3174                     "TensorRT does not support batched constants in implicit "
3175                     "batch mode.");
3176               }
3177             } else if ((a_is_tensor && p.shape_a.size() <= 2 &&
3178                         (p.transpose_a || b_is_tensor)) ||
3179                        (b_is_tensor && p.shape_b.size() <= 2)) {
3180               conversion_status = errors::InvalidArgument(
3181                   "MatMul with 2D tensors requires explicit batch mode, or that"
3182                   " tensor A is not transposed and B is a constant tensor.");
3183             }
3184           }
3185 
3186           test->TestOpConverter("my_matmul", node_def, p.expected_shape,
3187                                 conversion_status, Status::OK(),
3188                                 ElementsAreArray(p.expected_output));
3189           if (!conversion_status.ok()) {
3190             VLOG(2) << "Converted with status " << conversion_status;
3191           }
3192           VLOG(2) << "== Finished test iteration ==";
3193         }
3194       }
3195     }
3196   }
3197 }
3198 
3199 template <typename LayerType>
CheckAddedLayers(OpConverterTest * test,bool expect_found)3200 void CheckAddedLayers(OpConverterTest* test, bool expect_found) {
3201   bool layer_found = false;
3202   for (int i = 0; i < test->converter_->network()->getNbLayers(); i++) {
3203     nvinfer1::ILayer* layer = test->converter_->network()->getLayer(i);
3204     if (dynamic_cast<LayerType*>(layer)) {
3205       layer_found = true;
3206     }
3207   }
3208   EXPECT_EQ(expect_found, layer_found);
3209 }
3210 
GetMatMulTestParams()3211 std::vector<MatMulTestParams> GetMatMulTestParams() {
3212   std::vector<MatMulTestParams> params{
3213       // clang-format off
3214       MatMulTestParams{{2, 2}, {0, 1, 2, 3}, false,  // A (shape, val, T?)
3215                        {2, 2}, {0, 1, 2, 3}, false,  // B (shape, val, T?)
3216                        {2, 2}, {2, 3, 6, 11}},       // result (shape, val)
3217       MatMulTestParams{{2, 2}, {0, 1, 2, 3}, false,
3218                        {2, 2}, {0, 1, 2, 3},  true,
3219                        {2, 2}, {1, 3, 3, 13}},
3220       MatMulTestParams{{2, 2}, {0, 1, 2, 3},  true,
3221                        {2, 2}, {0, 1, 2, 3}, false,
3222                        {2, 2}, {4, 6, 6, 10}},
3223       MatMulTestParams{{2, 2}, {0, 1, 2, 3}, true,
3224                        {2, 2}, {0, 1, 2, 3}, true,
3225                        {2, 2}, {2, 6, 3, 11}},
3226       MatMulTestParams{{2, 3}, {0, 1, 2, 3, 4, 5}, false,
3227                        {2, 3}, {1, 2, 3, 4, 5, 6}, true,
3228                        {2, 2}, {8, 17, 26, 62}},
3229       MatMulTestParams{{2, 3}, {0, 1, 2, 3, 4, 5}, true,
3230                        {2, 3}, {1, 2, 3, 4, 5, 6}, false,
3231                        {3, 3}, {12, 15, 18, 17, 22, 27, 22, 29, 36}},
3232       MatMulTestParams{{3, 2}, {0, 1, 2, 3, 4, 5}, false,
3233                        {2, 3}, {1, 2, 3, 4, 5, 6}, false,
3234                        {3, 3}, {4, 5, 6, 14, 19, 24, 24, 33, 42}},
3235       MatMulTestParams{{3, 2}, {0, 1, 2, 3, 4, 5}, true,
3236                        {2, 3}, {1, 2, 3, 4, 5, 6}, true,
3237                        {2, 2}, {16, 34, 22, 49}},
3238       // clang-format on
3239   };
3240   return params;
3241 }
3242 
TEST_P(OpConverter_FP32_Test, ConvertMatMul) {
  // Builds a MatMul NodeDef named "my_matmul" that multiplies the placeholders
  // "input" and "weights" of the requested dtype, with optional transposes.
  const auto make_matmul_node = [](DataType dtype, bool transpose_a,
                                   bool transpose_b) -> NodeDef {
    Scope root = Scope::NewRootScope();
    auto lhs = ops::Placeholder(root.WithOpName("input"), dtype);
    auto rhs = ops::Placeholder(root.WithOpName("weights"), dtype);
    auto product = ops::MatMul(
        root.WithOpName("my_matmul"), lhs, rhs,
        ops::MatMul::TransposeA(transpose_a).TransposeB(transpose_b));
    return product.operation.node()->def();
  };

  const std::vector<MatMulTestParams> cases = GetMatMulTestParams();
  TestMatMulHelper(this, make_matmul_node, cases);
}
3259 
TEST_P(OpConverter_FP32_Test,ConvertBatchMatMul)3260 TEST_P(OpConverter_FP32_Test, ConvertBatchMatMul) {
3261   // Get the NodeDef for BatchMatMul.
3262   auto get_batch_matmul_nodedef = [](DataType dtype, bool transpose_a,
3263                                      bool transpose_b) -> NodeDef {
3264     Scope s = Scope::NewRootScope();
3265     auto input = ops::Placeholder(s.WithOpName("input"), dtype);
3266     auto weights = ops::Placeholder(s.WithOpName("weights"), dtype);
3267     const auto matmul_attrs =
3268         ops::BatchMatMul::AdjX(transpose_a).AdjY(transpose_b);
3269     auto matmul = ops::BatchMatMul(s.WithOpName("my_matmul"), input, weights,
3270                                    matmul_attrs);
3271     return matmul.operation.node()->def();
3272   };
3273 
3274   // We derive test data from the MatMul test params by adding extra leading
3275   // dimensions.
3276   std::vector<MatMulTestParams> params_2d = GetMatMulTestParams();
3277   std::vector<MatMulTestParams> params;
3278   params.reserve(params_2d.size() * 3 + 1);
3279 
3280   auto insert_ones = [](std::vector<int> v, int n) {
3281     std::vector<int> ones(n, 1);
3282     ones.insert(ones.end(), v.begin(), v.end());
3283     return ones;
3284   };
3285 
3286   // Add a leading 1 dimension to A, B and result.
3287   std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
3288                  [](MatMulTestParams p) {
3289                    p.shape_a.insert(p.shape_a.begin(), 1);
3290                    p.shape_b.insert(p.shape_b.begin(), 1);
3291                    p.expected_shape.insert(p.expected_shape.begin(), 1);
3292                    return p;
3293                  });
3294 
3295   // Test with N > 1: weights cannot be batched in implicit batch mode.
3296   // clang-format off
3297   params.push_back(
3298       MatMulTestParams{{2, 2, 2}, {0, 1, 2, 3, 0, 1, 2, 3}, false,  // A
3299                        {2, 2, 2}, {0, 1, 2, 3, 0, 1, 2, 3}, false,  // B
3300                        {2, 2, 2}, {2, 3, 6, 11, 2, 3, 6, 11}}       // result
3301   );
3302 
3303   params.push_back(
3304       MatMulTestParams{{2, 2, 3}, {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5},
3305       false,
3306                        {2, 2, 3}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, true,
3307                        {2, 2, 2}, {8, 17, 26, 62, 8, 17, 26, 62}});
3308   // clang-format on
3309 
3310   // Add two leading 1 dimensions to A, B and result.
3311   std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
3312                  [insert_ones](MatMulTestParams p) {
3313                    p.shape_a = insert_ones(p.shape_a, 2);
3314                    p.shape_b = insert_ones(p.shape_b, 2);
3315                    p.expected_shape = insert_ones(p.expected_shape, 2);
3316                    return p;
3317                  });
3318 
3319   // Test broadcast: add two leading 1 dimensions to A, but not to B.
3320   std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
3321                  [insert_ones](MatMulTestParams p) {
3322                    p.shape_a = insert_ones(p.shape_a, 2);
3323                    p.expected_shape = insert_ones(p.expected_shape, 2);
3324                    return p;
3325                  });
3326 
3327   // Test broadcast: add a leading 1 dimension to A and two leading 1s to B.
3328   // Broadcasting A need a dynamic brodacast which will be incompatible with
3329   // FC layer.
3330   std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
3331                  [insert_ones](MatMulTestParams p) {
3332                    p.shape_a = insert_ones(p.shape_a, 1);
3333                    p.shape_b = insert_ones(p.shape_b, 2);
3334                    p.expected_shape = insert_ones(p.expected_shape, 2);
3335                    return p;
3336                  });
3337 
3338   // Test with N > 1: since weights cannot be batched in implicit batch mode.
3339   // We tests with batch size 2.
3340   std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
3341                  [insert_ones](MatMulTestParams p) {
3342                    p.shape_a.insert(p.shape_a.begin(), 2);
3343                    p.values_a.reserve(p.values_a.size() * 2);
3344                    p.values_a.insert(p.values_a.end(), p.values_a.begin(),
3345                                      p.values_a.end());
3346 
3347                    p.shape_b.insert(p.shape_b.begin(), 2);
3348                    p.values_b.reserve(p.values_b.size() * 2);
3349                    p.values_b.insert(p.values_b.end(), p.values_b.begin(),
3350                                      p.values_b.end());
3351 
3352                    p.expected_shape.insert(p.expected_shape.begin(), 2);
3353                    p.expected_output.reserve(p.expected_output.size() * 2);
3354                    p.expected_output.insert(p.expected_output.end(),
3355                                             p.expected_output.begin(),
3356                                             p.expected_output.end());
3357                    return p;
3358                  });
3359 
3360   // 4D tensor where the second "batch dim" is not 1
3361   params.push_back(MatMulTestParams{
3362       {1, 2, 4, 5},
3363       {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
3364        14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
3365        28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
3366       false,  // A
3367       {1, 2, 3, 5},
3368       {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
3369        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30},
3370       true,  // B
3371       {1, 2, 4, 3},
3372       {40,   90,   140,  115,  290,  465,  190,  490,
3373        790,  265,  690,  1115, 1990, 2540, 3090, 2440,
3374        3115, 3790, 2890, 3690, 4490, 3340, 4265, 5190}});  // result
3375 
3376   TestMatMulHelper(this, get_batch_matmul_nodedef, params);
3377 }
3378 
3379 #if IS_TRT_VERSION_GE(7, 1, 3, 0)
TEST_P(OpConverter_FP32_Test,ConvertEinsum)3380 TEST_P(OpConverter_FP32_Test, ConvertEinsum) {
3381   // Get the NodeDef for Einsum.
3382   auto get_einsum_nodedef = [](DataType dtype, std::string eq,
3383                                int n_inputs = 2) -> NodeDef {
3384     Scope s = Scope::NewRootScope();
3385     auto a = ops::Placeholder(s.WithOpName("input_a"), dtype);
3386     std::vector<Input> input_vec{a};
3387     if (n_inputs > 1) {
3388       auto b = ops::Placeholder(s.WithOpName("input_b"), dtype);
3389       input_vec.push_back(b);
3390     }
3391     InputList inputs(input_vec);
3392     auto einsum = ops::Einsum(s.WithOpName("my_einsum"), inputs, eq);
3393     return einsum.operation.node()->def();
3394   };
3395 
3396   if (trt_mode_ == TrtTestMode::kImplicitBatch) {
3397     Reset();
3398     NodeDef node_def = get_einsum_nodedef(tf_type_, "ab,cb->ac");
3399     AddTestTensor("input_a", {2, 3});
3400     AddTestTensor("input_b", {2, 3});
3401     TestOpConverter(
3402         "my_einsum", node_def, {2, 2},
3403         errors::Unimplemented("Einsum converter requires dynamic shape mode"),
3404         Status::OK(), ElementsAreArray({13, 16, 40, 52}));
3405     // No further tests.
3406     return;
3407   }
3408 
3409   struct TestParams {
3410     std::string equation;
3411     std::vector<int> shape_a;
3412     std::vector<int> values_a;
3413     std::vector<int> shape_b;
3414     std::vector<int> values_b;
3415     std::vector<int> expected_shape;
3416     std::vector<int> expected_output;
3417     Status conv_status;
3418   };
3419 
3420   Status unimplemented_eq = errors::Unimplemented("");
3421   Status internal_err = errors::Internal("");
3422   Status internal_err_before_TRT82 =
3423       IS_TRT_VERSION_GE(8, 2, 0, 0) ? Status::OK() : internal_err;
3424   Status unimplemented_before_TRT82 =
3425       IS_TRT_VERSION_GE(8, 2, 0, 0) ? Status::OK() : unimplemented_eq;
3426 
3427   Status diagonal_error = unimplemented_eq;
3428   // The old converter only accepts 2 inputs, and the validator returns
3429   // internal_err if only 1 input is used.
3430   Status diagonal_error_1_input =
3431       IS_TRT_VERSION_GE(8, 2, 0, 0) ? unimplemented_eq : internal_err;
3432 
3433   std::vector<TestParams> params{
3434       // Dot product.
3435       TestParams{"i,i->", {2}, {2, 3}, {2}, {1, 2}, {}, {8}, unimplemented_eq},
3436       TestParams{"ik,ik->",
3437                  {2, 2},
3438                  {2, 3, 4, 1},
3439                  {2, 2},
3440                  {1, 2, 1, 3},
3441                  {},
3442                  {15},
3443                  unimplemented_eq},
3444       // Outer product.
3445       TestParams{"i,k->ik",
3446                  {2},
3447                  {1, 2},
3448                  {3},
3449                  {1, 2, 3},
3450                  {2, 3},
3451                  {1, 2, 3, 2, 4, 6},
3452                  unimplemented_eq},
3453       TestParams{"ij,kl->ijkl",
3454                  {2, 1},
3455                  {1, 2},
3456                  {3, 1},
3457                  {1, 2, 3},
3458                  {2, 1, 3, 1},
3459                  {1, 2, 3, 2, 4, 6},
3460                  unimplemented_before_TRT82},
3461       // Transpose.
3462       TestParams{"ik->ki",
3463                  {2, 3},
3464                  {0, 1, 2, 3, 4, 5},
3465                  {},
3466                  {},
3467                  {3, 2},
3468                  {0, 3, 1, 4, 2, 5},
3469                  internal_err_before_TRT82},
3470       // Diag.
3471       TestParams{"ii->i",
3472                  {3, 3},
3473                  {0, 1, 2, 3, 4, 5, 6, 7, 8},
3474                  {},
3475                  {},
3476                  {3},
3477                  {0, 4, 8},
3478                  diagonal_error_1_input},
3479       // Trace.
3480       TestParams{"ii->",  // Note TF einsum op always has '->'.
3481                  {3, 3},
3482                  {0, 1, 2, 3, 4, 5, 6, 7, 8},
3483                  {},
3484                  {},
3485                  {},
3486                  {12},
3487                  diagonal_error_1_input},
3488       // MatMul with reduction.
3489       TestParams{"abbc,dc->ad",
3490                  {1, 2, 2, 3},
3491                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3492                  {2, 3},
3493                  {1, 2, 3, 4, 5, 6},
3494                  {2, 3},
3495                  {1, 2, 3, 2, 4, 6},
3496                  diagonal_error},
3497       // Ellipsis with broadcast.
3498       TestParams{"...ik,...jk->...ij",
3499                  {1, 3, 1, 4},
3500                  {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
3501                  {2, 1, 1, 4},
3502                  {1, 2, 3, 4, 5, 6, 7, 8},
3503                  {2, 3, 1, 1},
3504                  {20, 60, 100, 44, 148, 252},
3505                  unimplemented_eq},
3506       // MatMul.
3507       TestParams{"ab,bc->ac",
3508                  {2, 3},
3509                  {0, 1, 2, 3, 4, 5},
3510                  {3, 2},
3511                  {1, 2, 3, 4, 5, 6},
3512                  {2, 2},
3513                  {13, 16, 40, 52}},
3514       // Batched MatMul.
3515       TestParams{"abc,cde->abde",
3516                  /*shape_a=*/{1, 2, 3},
3517                  /*values_a=*/{0, 1, 2, 3, 4, 5},
3518                  /*shape_b=*/{3, 2, 2},
3519                  /*values_v=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3520                  /*expected_shape=*/{1, 2, 2, 2},
3521                  /*expected_output=*/{23, 26, 29, 32, 68, 80, 92, 104}},
3522       TestParams{"abcd,cde->abe",
3523                  {1, 2, 2, 3},
3524                  {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
3525                  {2, 3, 2},
3526                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3527                  {1, 2, 2},
3528                  {125, 140, 341, 392}},
3529       // TF assumes case sensitive labels.
3530       TestParams{"aBAE,AEe->aBe",
3531                  {1, 2, 2, 3},
3532                  {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
3533                  {2, 3, 2},
3534                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3535                  {1, 2, 2},
3536                  {125, 140, 341, 392}},
3537       TestParams{"abc,cd->abd",
3538                  {1, 2, 3},
3539                  {0, 1, 2, 3, 4, 5},
3540                  {3, 2},
3541                  {1, 2, 3, 4, 5, 6},
3542                  {1, 2, 2},
3543                  {13, 16, 40, 52}},
3544       TestParams{"acbe,aecd->abcd",
3545                  {1, 2, 3, 4},
3546                  {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
3547                   12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
3548                  {1, 4, 2, 3},
3549                  {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
3550                   13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24},
3551                  {1, 3, 2, 3},
3552                  {90, 96, 102, 732, 786, 840, 250, 272, 294, 940, 1010, 1080,
3553                   410, 448, 486, 1148, 1234, 1320}},
3554       TestParams{"aecd,abcd->acbe",
3555                  {1, 2, 3, 4},
3556                  {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
3557                   12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
3558                  {1, 2, 3, 4},
3559                  {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
3560                   13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24},
3561                  {1, 3, 2, 2},
3562                  {20, 140, 92, 788, 148, 460, 412, 1300, 404, 908, 860, 1940}},
3563       TestParams{"acd,dce->ae",
3564                  {1, 2, 3},
3565                  {0, 1, 2, 3, 4, 5},
3566                  {3, 2, 2},
3567                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3568                  {1, 2},
3569                  {115, 130}},
3570       TestParams{"abcd,bace->bade",
3571                  {2, 3, 2, 1},
3572                  {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
3573                  {3, 2, 2, 1},
3574                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
3575                  {3, 2, 1, 1},
3576                  {2, 46, 28, 128, 86, 242}},
3577       TestParams{
3578           "cebfad,fageb->abcdg",
3579           {1, 1, 3, 3, 2, 2},
3580           {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
3581            12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
3582            24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35},
3583           {3, 2, 2, 1, 3},
3584           {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
3585            13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
3586            25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36},
3587           {2, 3, 1, 2, 2},
3588           {252, 288, 291, 336, 768,  912,  810,  963,  1356, 1608, 1401, 1662,
3589            438, 492, 495, 558, 1176, 1338, 1236, 1407, 1986, 2256, 2049, 2328}},
3590   };
3591 
3592   for (auto p : params) {
3593     for (bool a_is_tensor : {true, false}) {
3594       for (bool b_is_tensor : {true, false}) {
3595         if (!a_is_tensor && !b_is_tensor) {
3596           // Skip test when both args are weights. We do not convert this
3597           // since const folding eliminates this case.
3598           continue;
3599         }
3600         Reset();
3601         int n_inputs = p.shape_b.empty() ? 1 : 2;
3602         NodeDef node_def = get_einsum_nodedef(tf_type_, p.equation, n_inputs);
3603         if (a_is_tensor) {
3604           AddTestTensor("input_a", p.shape_a, p.values_a);
3605         } else {
3606           AddTestWeights("input_a", p.shape_a, p.values_a, tf_type_);
3607         }
3608         if (!p.shape_b.empty()) {
3609           if (b_is_tensor) {
3610             AddTestTensor("input_b", p.shape_b, p.values_b);
3611           } else {
3612             AddTestWeights("input_b", p.shape_b, p.values_b, tf_type_);
3613           }
3614         }
3615         TestOpConverter("my_einsum", node_def, p.expected_shape, p.conv_status,
3616                         Status::OK(), ElementsAreArray(p.expected_output));
3617       }
3618     }
3619   }
3620 }
3621 #endif  // IS_TRT_VERSION_GE(7, 1, 3, 0)
3622 
TEST_P(OpConverter_FP32_FP16_Test,ConvertBiasAdd)3623 TEST_P(OpConverter_FP32_FP16_Test, ConvertBiasAdd) {
3624   // Note that kINT32 is not supported by IScaleLayer, so we don't test
3625   // DT_INT32 type here. DT_FLOAT and DT_HALF are tested.
3626   // Get the NodeDef for BiasAdd.
3627   auto get_biasadd_nodedef = [](const string& data_format,
3628                                 DataType tf_type) -> NodeDef {
3629     Scope s = Scope::NewRootScope();
3630     auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
3631     auto weights = ops::Placeholder(s.WithOpName("weights"), tf_type);
3632     const auto biasadd_attrs = ops::BiasAdd::DataFormat(data_format);
3633     auto biasadd =
3634         ops::BiasAdd(s.WithOpName("my_biasadd"), input, weights, biasadd_attrs);
3635     return biasadd.operation.node()->def();
3636   };
3637 
3638   for (const string& data_format : {"NHWC", "NCHW"}) {
3639     for (const int trt_input_rank : {1, 2, 3, 4}) {
3640       Reset();
3641       NodeDef node_def = get_biasadd_nodedef(data_format, tf_type_);
3642 
3643       // Add input, dims_array will be like {2, 1, ..., 1, 3}
3644       std::vector<int32> dims_array(trt_input_rank + 1, 1);
3645       if (trt_input_rank == 1) {
3646         dims_array[1] = (data_format == "NHWC" ? 3 : 2);
3647       } else {
3648         dims_array[1] = 2;
3649         dims_array[trt_input_rank] = 3;
3650       }
3651       const int64_t num_input = DimsAdapter(dims_array).Volume();
3652       ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2),
3653                 num_input);
3654       std::vector<float> input_data(num_input, 0);
3655 
3656       AddTestTensor("input", dims_array, input_data);
3657 
3658       const int channel_size = (data_format == "NHWC" ? 3 : 2);
3659       std::vector<float> bias(channel_size);
3660       for (int i = 0; i < channel_size; ++i) {
3661         bias[i] = i + 1;  // bias will be {1, 2, 3, ...}
3662       }
3663       AddTestWeights("weights", {channel_size}, bias, tf_type_);
3664 
3665       // Build and run the engine.
3666       std::vector<float> output_data;
3667 
3668       if (trt_input_rank == 1) {
3669         if (data_format == "NHWC") {
3670           output_data = {1, 2, 3};
3671         } else {
3672           output_data = {1, 2};
3673         }
3674       } else {
3675         if (data_format == "NHWC") {
3676           output_data = {1, 2, 3, 1, 2, 3};
3677         } else {
3678           output_data = {1, 1, 1, 2, 2, 2};
3679         }
3680       }
3681       TestOpConverter("my_biasadd", node_def, dims_array, Status::OK(),
3682                       Status::OK(), ElementsAreArray(output_data));
3683     }
3684   }
3685 }
3686 
3687 template <typename OpType>
GetBinaryOpNodeDef(DataType dtype)3688 NodeDef GetBinaryOpNodeDef(DataType dtype) {
3689   Scope s = Scope::NewRootScope();
3690   auto input_l = ops::Placeholder(s.WithOpName("input1"), dtype);
3691   auto input_r = ops::Placeholder(s.WithOpName("input2"), dtype);
3692   auto op = OpType(s.WithOpName("my_binary"), input_l, input_r);
3693   return op.operation.node()->def();
3694 }
3695 
TEST_P(OpConverter_FP32_FP16_BinaryTest,ConvertBinary)3696 TEST_P(OpConverter_FP32_FP16_BinaryTest, ConvertBinary) {
3697   using OpFunc = std::function<NodeDef(DataType)>;
3698   std::map<std::string, std::pair<OpFunc, std::vector<float>>> op_test_info;
3699 #define ADD_OP(name, op, v1, v2, v3, v4, v5, v6, v7, v8) \
3700   op_test_info[name] =                                   \
3701       std::make_pair(GetBinaryOpNodeDef<op>,             \
3702                      std::vector<float>(v1, v2, v3, v4, v5, v6, v7, v8))
3703   ADD_OP("Add", ops::Add, {5, 8, 6, 9, 5, 8, 6, 9});
3704   ADD_OP("AddV2", ops::AddV2, {5, 8, 6, 9, 5, 8, 6, 9});
3705   ADD_OP("Sub", ops::Sub, {1, 4, 0, 3, 1, 4, 0, 3});
3706   ADD_OP("Mul", ops::Mul, {6, 12, 9, 18, 6, 12, 9, 18});
3707   ADD_OP("Div", ops::Div, {1.5, 3, 1, 2, 1.5, 3, 1, 2});
3708   ADD_OP("RealDiv", ops::RealDiv, {1.5, 3, 1, 2, 1.5, 3, 1, 2});
3709   ADD_OP("FloorDiv", ops::FloorDiv, {1, 3, 1, 2, 1, 3, 1, 2});
3710   ADD_OP("Minimum", ops::Minimum, {2, 2, 3, 3, 2, 2, 3, 3});
3711   ADD_OP("Maximum", ops::Maximum, {3, 6, 3, 6, 3, 6, 3, 6});
3712   ADD_OP("Pow", ops::Pow, {9, 36, 27, 216, 9, 36, 27, 216});
3713 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
3714   ADD_OP("Greater", ops::Greater, {1, 1, 0, 1, 1, 1, 0, 1});
3715   ADD_OP("Less", ops::Less, {0, 0, 0, 0, 0, 0, 0, 0});
3716   ADD_OP("Equal", ops::Equal, {0, 0, 1, 0, 0, 0, 1, 0});
3717   ADD_OP("GreaterEqual", ops::Less, {1, 1, 1, 1, 1, 1, 1, 1});
3718   ADD_OP("LessEqual", ops::Greater, {0, 0, 1, 0, 0, 0, 1, 0});
3719 #endif
3720 #undef ADD_OP
3721   std::vector<std::vector<float>> data = {
3722       {3, 6, 3, 6}, {3, 6}, {2, 3, 2, 3}, {2, 3}};
3723   RunTests(*BinaryOperationMap(), op_test_info, data);
3724 }
3725 
TEST_P(OpConverter_BOOL_BinaryTest,ConvertBooleanBinary)3726 TEST_P(OpConverter_BOOL_BinaryTest, ConvertBooleanBinary) {
3727   using OpFunc = std::function<NodeDef(DataType)>;
3728   std::map<std::string, std::pair<OpFunc, std::vector<int>>> op_test_info;
3729 #define ADD_OP(name, op, v1, v2, v3, v4, v5, v6, v7, v8) \
3730   op_test_info[name] =                                   \
3731       std::make_pair(GetBinaryOpNodeDef<op>,             \
3732                      std::vector<int>(v1, v2, v3, v4, v5, v6, v7, v8))
3733   ADD_OP("LogicalOr", ops::LogicalOr, {1, 1, 0, 1, 1, 1, 0, 1});
3734   ADD_OP("LogicalAnd", ops::LogicalAnd, {0, 1, 0, 0, 0, 1, 0, 0});
3735 #undef ADD_OP
3736 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
3737   std::vector<std::vector<int>> data = {
3738       {0, 1, 0, 1}, {0, 1}, {1, 0, 1, 0}, {1, 0}};
3739   RunTests(*BinaryBooleanOperationMap(), op_test_info, data);
3740 #endif
3741 }
3742 
GetAddNNodeDef(const std::vector<string> & input_names,DataType dtype)3743 NodeDef GetAddNNodeDef(const std::vector<string>& input_names, DataType dtype) {
3744   Scope s = Scope::NewRootScope();
3745   OutputList inputs;
3746   for (const string& name : input_names) {
3747     inputs.push_back(ops::Placeholder(s.WithOpName(name), dtype));
3748   }
3749   auto op = ops::AddN(s.WithOpName("my_addn"), inputs);
3750   return op.operation.node()->def();
3751 }
3752 
3753 struct AddNTestParams {
3754   std::vector<float> input_values;
3755   std::vector<string> input_names;
3756   std::vector<int> dimensions;
3757   std::vector<float> expected_output;
3758   Status status;
3759 };
3760 
TestAddN(ParameterizedOpConverterTestBase * test,AddNTestParams & p)3761 void TestAddN(ParameterizedOpConverterTestBase* test, AddNTestParams& p) {
3762   // All inputs are tensors.
3763   test->Reset();
3764   const NodeDef node_def = GetAddNNodeDef(p.input_names, test->get_tf_type());
3765 
3766   if (p.input_values.size() % p.input_names.size() != 0) {
3767     LOG(ERROR) << "The number of input values: `" << p.input_values.size()
3768                << "` is not a multiple of the number of inputs: `"
3769                << p.input_names.size() << "`";
3770     ASSERT_TRUE(false);
3771   }
3772 
3773   DataVec input_data;
3774   int input_offset = 0;
3775   const int window_size = p.input_values.size() / p.input_names.size();
3776   for (const string& name : p.input_names) {
3777     std::vector<float>::const_iterator start_pos =
3778         p.input_values.begin() + input_offset;
3779     std::vector<float>::const_iterator end_pos = start_pos + window_size;
3780     std::vector<float> sub_input_val(start_pos, end_pos);
3781     input_offset += window_size;
3782 
3783     test->AddTestTensor(name, p.dimensions, test->get_tf_type(), sub_input_val);
3784   }
3785 
3786   test->TestOpConverter("my_addn", node_def, p.dimensions,
3787                         /*expected_conversion_status=*/p.status,
3788                         /*expected_runtime_status=*/p.status,
3789                         /*matcher=*/ElementsAreArray(p.expected_output),
3790                         /*out_tf_types=*/{test->get_tf_type()});
3791 }
3792 
TEST_P(OpConverter_FP32_FP16_Test,ConvertAddN)3793 TEST_P(OpConverter_FP32_FP16_Test, ConvertAddN) {
3794   {
3795     // Weights with batch dim that is not 1.
3796     Reset();
3797     const NodeDef node_def = GetAddNNodeDef({"tensor", "weights"}, tf_type_);
3798     AddTestTensor("tensor", /*dims=*/{1, 2});
3799     AddTestWeights<float>("weights", {2, 1, 2}, {0, 1, 2, 3});
3800     RunValidationAndConversion(
3801         node_def, error::INVALID_ARGUMENT,
3802         "Weights input to AddN is required to have batch dimension 1.");
3803   }
3804 
3805   const std::vector<float> common_input = CreateVectorIota<float>(6);
3806 
3807   std::vector<AddNTestParams> params = {
3808       {/*input_values=*/common_input,
3809        /*input_names=*/{"inp1", "inp2", "inp3"},
3810        /*dimensions=*/{1, 1, 2, 1, 1},
3811        /*expected_output=*/{6, 9},
3812        /*status=*/Status::OK()},
3813       {/*input_values=*/common_input,
3814        /*input_names=*/{"inp1", "inp2"},
3815        /*dimensions=*/{1, 1, 3, 1, 1},
3816        /*expected_output=*/{3, 5, 7},
3817        /*status=*/Status::OK()},
3818       {/*input_values=*/common_input,
3819        /*input_names=*/{"inp1", "inp2", "inp3"},
3820        /*dimensions=*/{1, 2, 1, 1},
3821        /*expected_output=*/{6, 9},
3822        /*status=*/Status::OK()},
3823       {/*input_values=*/common_input,
3824        /*input_names=*/{"inp1", "inp2"},
3825        /*dimensions=*/{1, 1, 3, 1},
3826        /*expected_output=*/{3, 5, 7},
3827        /*status=*/Status::OK()},
3828       {/*input_values=*/common_input,
3829        /*input_names=*/{"inp1", "inp2", "inp3"},
3830        /*dimensions=*/{1, 2, 1},
3831        /*expected_output=*/{6, 9},
3832        /*status=*/Status::OK()},
3833       {/*input_values=*/common_input,
3834        /*input_names=*/{"inp1", "inp2"},
3835        /*dimensions=*/{1, 1, 3},
3836        /*expected_output=*/{3, 5, 7},
3837        /*status=*/Status::OK()},
3838       {/*input_value=*/common_input,
3839        /*input_names=*/{"inp1", "inp2", "inp3"},
3840        /*dimensions=*/{2, 1},
3841        /*expected_output=*/{6, 9},
3842        /*status=*/Status::OK()},
3843       {/*input_values=*/common_input,
3844        /*input_names=*/{"inp1", "inp2"},
3845        /*dimensions=*/{1, 3},
3846        /*expected_output=*/{3, 5, 7},
3847        /*status=*/Status::OK()},
3848       {/*input_values=*/common_input,
3849        /*input_names=*/{"inp1", "inp2", "inp3"},
3850        /*dimensions=*/{2},
3851        /*expected_output=*/{6, 9},
3852        /*status=*/Status::OK()},
3853       {/*input_values=*/common_input,
3854        /*input_names=*/{"inp1", "inp2"},
3855        /*dimensions=*/{3},
3856        /*expected_output=*/{3, 5, 7},
3857        /*status=*/Status::OK()},
3858       {/*input_values=*/common_input,
3859        /*input_names=*/{"inp1", "inp2", "inp3", "inp4", "inp5", "inp6"},
3860        /*dimensions=*/{1},
3861        /*expected_output=*/{15},
3862        /*status=*/Status::OK()},
3863   };
3864 
3865   for (auto p : params) {
3866     TestAddN(this, p);
3867   }
3868 }
3869 
TEST_P(OpConverter_FP32_Test,ConvertQDQDynamicRangeMode)3870 TEST_P(OpConverter_FP32_Test, ConvertQDQDynamicRangeMode) {
3871   {
3872     // FakeQuantWithMinMaxArgs attributes are empty, should fail.
3873     Reset(TrtPrecisionMode::INT8);
3874     NodeDef node_def =
3875         MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"});
3876     AddTestTensor("input", {1, 2, 3});
3877     RunValidationAndConversion(node_def, error::NOT_FOUND,
3878                                "No attr named 'min'");
3879   }
3880   {
3881     // FakeQuantWithMinMaxArgs ranges set via attributes, ok.
3882     Reset(TrtPrecisionMode::INT8);
3883     Scope s = Scope::NewRootScope();
3884     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3885     auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f);
3886     auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("my_quantize"),
3887                                                  input, quantize_attrs);
3888     const NodeDef& node_def = quantize.operation.node()->def();
3889     AddTestTensor("input", {1, 2, 3});
3890     RunValidationAndConversion(node_def);
3891     TRT_TensorOrWeights output;
3892     TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
3893     ASSERT_TRUE(output.is_tensor());
3894     auto ranges = quantization_ranges();
3895     EXPECT_EQ(1, ranges.count(output.tensor()->trt_tensor()));
3896     EXPECT_EQ(6.0f, ranges[output.tensor()->trt_tensor()]);
3897   }
3898   {
3899     // FakeQuantWithMinMaxVars ranges set via inputs, ok.
3900     Reset(TrtPrecisionMode::INT8);
3901     Scope s = Scope::NewRootScope();
3902     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3903     auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
3904     auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
3905     auto quantize = ops::FakeQuantWithMinMaxVars(
3906         s.WithOpName("my_quantize"), input, weights_min, weights_max);
3907     const NodeDef& node_def = quantize.operation.node()->def();
3908     AddTestTensor("input", {1, 2, 3});
3909     AddTestWeights<float>("weights_min", {1}, {-6.0f});
3910     AddTestWeights<float>("weights_max", {1}, {6.0f});
3911     RunValidationAndConversion(node_def);
3912     TRT_TensorOrWeights output;
3913     TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
3914     ASSERT_TRUE(output.is_tensor());
3915     auto ranges = quantization_ranges();
3916     EXPECT_EQ(1, ranges.count(output.tensor()->trt_tensor()));
3917     EXPECT_EQ(6.0f, ranges[output.tensor()->trt_tensor()]);
3918   }
3919   {
3920     // QuantizeAndDequantizeV2 ranges set via inputs, ok.
3921     Reset(TrtPrecisionMode::INT8);
3922     Scope s = Scope::NewRootScope();
3923     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3924     auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
3925     auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
3926     auto quantize = ops::QuantizeAndDequantizeV2(
3927         s.WithOpName("my_quantize"), input, weights_min, weights_max);
3928     const NodeDef& node_def = quantize.operation.node()->def();
3929     AddTestTensor("input", {1, 2, 3});
3930     AddTestWeights<float>("weights_min", {1}, {-6.0f});
3931     AddTestWeights<float>("weights_max", {1}, {6.0f});
3932     RunValidationAndConversion(node_def);
3933     TRT_TensorOrWeights output;
3934     TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
3935     ASSERT_TRUE(output.is_tensor());
3936     auto ranges = quantization_ranges();
3937     EXPECT_EQ(1, ranges.count(output.tensor()->trt_tensor()));
3938     EXPECT_EQ(6.0f, ranges[output.tensor()->trt_tensor()]);
3939   }
3940   {
3941     // QuantizeAndDequantizeV2 Range inputs are tensors, should fail.
3942     Reset(TrtPrecisionMode::INT8);
3943     Scope s = Scope::NewRootScope();
3944     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3945     auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
3946     auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
3947     auto quantize = ops::QuantizeAndDequantizeV2(
3948         s.WithOpName("my_quantize"), input, weights_min, weights_max);
3949     const NodeDef& node_def = quantize.operation.node()->def();
3950     AddTestTensor("input", {1, 2, 3});
3951     AddTestTensor("weights_min", {1});
3952     AddTestTensor("weights_max", {1});
3953     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
3954                                "The input \"input_min\" for "
3955                                "QuantizeAndDequantizeV2 must be a constant");
3956   }
3957   {
3958     // QuantizeAndDequantizeV3 ranges set via inputs, ok.
3959     Reset(TrtPrecisionMode::INT8);
3960     Scope s = Scope::NewRootScope();
3961     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
3962     auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
3963     auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
3964     auto num_bits = ops::Placeholder(s.WithOpName("num_bits"), DT_INT32);
3965     auto quantize = ops::QuantizeAndDequantizeV3(
3966         s.WithOpName("my_quantize"), input, weights_min, weights_max, num_bits);
3967     const NodeDef& node_def = quantize.operation.node()->def();
3968     AddTestTensor("input", {1, 2, 3});
3969     AddTestWeights<float>("weights_min", {1}, {-6.0f});
3970     AddTestWeights<float>("weights_max", {1}, {6.0f});
3971     AddTestWeights<int>("num_bits", {1}, {8});
3972     RunValidationAndConversion(node_def);
3973     TRT_TensorOrWeights output;
3974     TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
3975     ASSERT_TRUE(output.is_tensor());
3976     auto ranges = quantization_ranges();
3977     EXPECT_EQ(1, ranges.count(output.tensor()->trt_tensor()));
3978     EXPECT_EQ(6.0f, ranges[output.tensor()->trt_tensor()]);
3979   }
3980 }
3981 
TEST_P(OpConverter_FP32_FP16_Test, ConvertSquare) {
  {
    // Square must reject an input that is weights rather than a tensor.
    Reset();
    Scope root = Scope::NewRootScope();
    auto x = ops::Placeholder(root.WithOpName("input"), tf_type_);
    auto square_op = ops::Square(root.WithOpName("my_square"), x);
    NodeDef node_def = square_op.operation.node()->def();
    AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type_);
    RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                               "The input \"x\" for Square must be a tensor");
  }

  // Square of a tensor input holding -9, -8, ..., 10.
  Reset();

  Scope root = Scope::NewRootScope();
  auto x = ops::Placeholder(root.WithOpName("input"), tf_type_);
  auto square_op = ops::Square(root.WithOpName("my_square"), x);
  NodeDef node_def = square_op.operation.node()->def();

  constexpr int kNumInputs = 20;
  std::vector<float> inputs(kNumInputs);
  std::iota(inputs.begin(), inputs.end(), -9.0f);
  std::vector<float> expected_outputs(kNumInputs);
  std::transform(inputs.begin(), inputs.end(), expected_outputs.begin(),
                 [](float v) { return v * v; });
  AddTestTensor("input", {1, 1, kNumInputs}, tf_type_, inputs);

  TestOpConverter("my_square", node_def, {1, 1, kNumInputs}, Status::OK(),
                  Status::OK(), ArrayFloatNear(expected_outputs, 0));
}
4016 
4017 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
4018 
TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertFill) {
  Scope s = Scope::NewRootScope();
  auto dims = ops::Placeholder(s.WithOpName("dims"), DT_INT32);
  auto value = ops::Placeholder(s.WithOpName("value"), tf_type_);
  auto fill = ops::Fill(s.WithOpName("my_fill"), dims, value);
  const NodeDef& node_def = fill.operation.node()->def();

  if (trt_mode_ == TrtTestMode::kImplicitBatch) {
    Reset();
    // Arbitrary inputs: Fill must be rejected in implicit batch mode.
    AddTestWeights("dims", {2}, {2, 2}, DT_INT32);
    AddTestWeights("value", {1}, {42.0}, tf_type_);
    RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                               "Conversion for Fill is not implemented in "
                               "implicit batch mode");
    return;
  }

  const std::vector<std::vector<int>> output_dims_params = {
      {8}, {8, 2, 4}, {32, 32, 3200}};
  // `value` is exercised with empty dims ({}) and with dims {1}.
  const std::vector<std::vector<int>> value_dims_params = {{}, {1}};

  const float val = 42.0;
  const Status status = Status::OK();
  // Cover every combination of `dims` / `value` passed as tensor or weight.
  for (bool dims_is_tensor : {true, false}) {
    for (bool value_is_tensor : {true, false}) {
      for (const auto& output_dims : output_dims_params) {
        // `dims` is a 1-D input holding the requested output shape, so its own
        // shape is {rank}. The explicit cast is required: brace-initializing an
        // int32 from a non-constant size_t is an ill-formed narrowing
        // conversion.
        const std::vector<int32> dims_dims = {
            static_cast<int32>(output_dims.size())};
        // Every element of the output is expected to equal `val`; this only
        // depends on output_dims, so it is computed once per shape.
        const size_t nb_el =
            std::accumulate(output_dims.begin(), output_dims.end(), size_t{1},
                            std::multiplies<size_t>());
        const std::vector<float> expected_output(nb_el, val);
        for (const auto& value_dims : value_dims_params) {
          Reset();
          if (dims_is_tensor) {
            AddTestTensor("dims", dims_dims, DT_INT32, output_dims, dims_dims);
          } else {
            AddTestWeights("dims", dims_dims, output_dims, DT_INT32);
          }
          if (value_is_tensor) {
            AddTestTensor("value", value_dims, tf_type_, {val});
          } else {
            AddTestWeights("value", value_dims, {val}, tf_type_);
          }
          TestOpConverter("my_fill", node_def, output_dims, status, status,
                          ElementsAreArray(expected_output));
        }
      }
    }
  }
}
4071 
TEST_P(OpConverter_FP32_FP16_INT32_Test,ConvertRange)4072 TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertRange) {
4073   auto get_casted_value = [this](const float value, const DataType dtype) {
4074     return dtype == DT_INT32 ? static_cast<int32>(value) : value;
4075   };
4076 
4077   // A function that builds the next lexicographically greater configuration
4078   // for the current one. The configuration is described as a (0,1)-vector
4079   // config, where config[i] is 0 or 1 when the i-th parameter is passed as
4080   // a weight or tensor, respectively. The function returns TRUE if such
4081   // a configuration is built, or FALSE otherwise.
4082   auto nextTensorWeigtConfiguration = [this](std::vector<int>& config) {
4083     for (int i = config.size(); i-- > 0;) {
4084       if (config[i] = 1 - config[i]) return true;
4085     }
4086     return false;
4087   };
4088 
4089   auto set_parameters = [this](const std::array<const char*, 3>& name,
4090                                const std::array<std::vector<float>, 3>& value,
4091                                const std::array<DataType, 3>& type,
4092                                const std::vector<int>& config,
4093                                int shape_idx = -1) {
4094     Reset();
4095     for (int i = 0; i < 3; i++) {
4096       if (config[i]) {
4097         std::vector<int32> partial_shape_dims = {};
4098         // The correct partial shape will be provided
4099         // (a) for all parameters, when shape_idx > 3
4100         // (b) for all parameters, except shape_idx, when shape_idx >= 0
4101         // (c) for none of the shape_idx < 0
4102         if (shape_idx > 3 || shape_idx >= 0 && shape_idx != i) {
4103           partial_shape_dims = {1};
4104         }
4105         AddTestTensor(name[i], {1}, type[i], value[i], partial_shape_dims);
4106       } else {
4107         AddTestWeights(name[i], {1}, value[i], type[i]);
4108       }
4109     }
4110   };
4111 
4112   const float start = 1.0;
4113   const float limit = 43.0;
4114   const float delta = 2.0;
4115 
4116   const std::array<const char*, 3> param_name = {"start", "limit", "delta"};
4117   std::array<std::vector<float>, 3> param_value;
4118   param_value[0] = {start};
4119   param_value[1] = {limit};
4120   param_value[2] = {delta};
4121   const auto start_type = tf_type_;
4122   std::array<DataType, 3> param_type = {tf_type_, tf_type_, tf_type_};
4123 
4124   Scope s = Scope::NewRootScope();
4125   const auto range =
4126       ops::Range(s.WithOpName("my_range"),
4127                  ops::Placeholder(s.WithOpName(param_name[0]), param_type[0]),
4128                  ops::Placeholder(s.WithOpName(param_name[1]), param_type[1]),
4129                  ops::Placeholder(s.WithOpName(param_name[2]), param_type[2]));
4130 
4131   const NodeDef& ndef = range.operation.node()->def();
4132   const std::vector<DataType> param_types{DT_FLOAT, DT_HALF, DT_INT32};
4133 
4134   // ConverterRange is not implemented for Implicite batch mode.
4135   std::vector<int> config(3, 0);
4136   if (trt_mode_ == TrtTestMode::kImplicitBatch) {
4137     do {
4138       set_parameters(param_name, param_value, param_type, config);
4139       RunValidationAndConversion(ndef, error::UNIMPLEMENTED,
4140                                  "Conversion for Range is not implemented in "
4141                                  "implicit batch mode");
4142     } while (nextTensorWeigtConfiguration(config));
4143 
4144     return;
4145   }
4146 
4147   const std::string expect_msg = convert_range_expected_msg(ndef);
4148   bool all_weights = true;
4149   do {
4150     for (auto limit_type : param_types) {
4151       param_type[1] = limit_type;
4152       for (auto delta_type : param_types) {
4153         param_type[2] = delta_type;
4154 
4155         const auto all_integers = start_type == DT_INT32 &&
4156                                   limit_type == DT_INT32 &&
4157                                   delta_type == DT_INT32;
4158 
4159         if (all_weights || all_integers && !config[2]) {
4160           // Reject invalid parameters if delta = 0 and it's passed as a weight.
4161           param_value[2] = {0};
4162           set_parameters(param_name, param_value, param_type, config);
4163           RunValidationAndConversion(
4164               ndef, error::INVALID_ARGUMENT,
4165               "The delta parameter of Range operation cannot be equal to 0");
4166 
4167           if (!all_weights && !config[2]) {
4168             param_value[2] = {-1};
4169             set_parameters(param_name, param_value, param_type, config);
4170             const string err = StrCat(
4171                 "The delta parameter of Range operation "
4172                 "cannot be negative, when one of (start, limit) is passed as "
4173                 "a tensor, but got ",
4174                 param_value[2][0]);
4175             RunValidationAndConversion(ndef, error::INVALID_ARGUMENT, err);
4176           }
4177         }
4178 
4179         if (all_weights) {
4180           // Reject invalid parameters preventing the limit from
4181           // being reached for fixed values of start and delta.
4182           for (int j = 0; j <= 1; j++) {
4183             param_value[j] = {get_casted_value(start, tf_type_)};
4184             param_value[1 - j] = {get_casted_value(limit, limit_type)};
4185             param_value[2] = {(2 * j - 1) *
4186                               get_casted_value(delta, delta_type)};
4187             set_parameters(param_name, param_value, param_type, config);
4188             const auto error = convert_range_error_msg(
4189                 param_value[0][0], param_value[1][0], param_value[2][0]);
4190             RunValidationAndConversion(ndef, error::INVALID_ARGUMENT, error);
4191           }
4192         }
4193 
4194         param_value[0] = {start};
4195         param_value[2] = {delta};
4196         if (all_integers) {
4197           if (trt_mode_ == TrtTestMode::kDynamicShape) {
4198             // Wrong dimension for the parameter passed as a tensor.
4199             for (int j = 0; j < 3; j++) {
4200               if (!config[j]) continue;
4201 
4202               const string err =
4203                   StrCat("Dimension for '", param_name[j],
4204                          "' of Range operator should be equal to 1");
4205               set_parameters(param_name, param_value, param_type, config, j);
4206               RunValidationAndConversion(ndef, error::INVALID_ARGUMENT, err);
4207             }
4208           }
4209         } else {
4210           if (!all_weights) {
4211             // The following test should fail, when
4212             //    (a) at least one parameter is passed as a tensor;
4213             //    (b) at least one parameter is not of type DT_INT32.
4214             set_parameters(param_name, param_value, param_type, config);
4215             RunValidationAndConversion(ndef, error::UNIMPLEMENTED, expect_msg);
4216           }
4217         }
4218       }
4219     }
4220     // All other configs will be set so that at least one parameter
4221     // will be passed as a tensor
4222     all_weights = false;
4223   } while (nextTensorWeigtConfiguration(config));
4224 
4225   nvinfer1::DataType trt_type;
4226   TF_ASSERT_OK(TfTypeToTrtType(DT_BOOL, &trt_type));
4227   const std::string error_msg =
4228       "Unsupported data type " + DebugString(trt_type) + " used for '";
4229   do {
4230     for (auto limit_type : param_types) {
4231       param_type[1] = limit_type;
4232       for (auto delta_type : param_types) {
4233         param_type[2] = delta_type;
4234 
4235         for (int i = 0; i < 3; i++) {
4236           if (!config[i]) {
4237             const auto saved_type = param_type[i];
4238             param_type[i] = DT_BOOL;
4239             set_parameters(param_name, param_value, param_type, config);
4240             param_type[i] = saved_type;
4241             RunValidationAndConversion(ndef, error::INVALID_ARGUMENT,
4242                                        error_msg + param_name[i] + "'");
4243           }
4244         }
4245       }
4246     }
4247   } while (nextTensorWeigtConfiguration(config));
4248 
4249   // The tests that pass all checks in ConvertRange::Validate().
4250   const Status status = Status::OK();
4251   const std::vector<DataType> int_type{DT_INT32};
4252   int partial_shape_idx = -1;
4253   all_weights = true;
4254   do {
4255     // For now when at least one of (start, limit, delta) is passed as a tensor
4256     //    (a) all these parameters should be of DT_INT32 type;
4257     //    (b) only positive delta could be used.
4258     const auto& types = all_weights ? param_types : int_type;
4259     const auto jEnd = all_weights ? 1 : 0;
4260     for (auto limit_type : types) {
4261       param_type[1] = limit_type;
4262       for (auto delta_type : types) {
4263         param_type[2] = delta_type;
4264         // Loop for positive and negative deltas.
4265         for (int j = 0; j <= jEnd; j++) {
4266           // Define the expected result which should match the usage
4267           // of DT_INT32 for one of (start, limit, delta).
4268           const int mult = (1 - 2 * j);
4269           param_value[j] = {get_casted_value(start, tf_type_)};
4270           param_value[1 - j] = {get_casted_value(limit, limit_type)};
4271           param_value[2] = {mult * get_casted_value(delta, delta_type)};
4272 
4273           // Create expected output.
4274           std::vector<float> expected_output;
4275           const float limit_curr = param_value[1][0];
4276           const float delta_curr = param_value[2][0];
4277           float value = param_value[0][0];
4278           int num_values = 0;
4279           while (mult * (limit_curr - value) > 0) {
4280             num_values++;
4281             expected_output.push_back(value);
4282             value += delta_curr;
4283           }
4284 
4285           set_parameters(param_name, param_value, param_type, config,
4286                          partial_shape_idx);
4287           const std::vector<int> output_dims = {num_values};
4288           TestOpConverter("my_range", ndef, output_dims, status, status,
4289                           ElementsAreArray(expected_output));
4290         }
4291       }
4292     }
4293 
4294     if (all_weights) {
4295       if (start_type != DT_INT32) break;
4296       if (trt_mode_ == TrtTestMode::kDynamicShape) partial_shape_idx = 3;
4297 
4298       // All other configs will be set so that at least one parameter
4299       // will be passed as a tensor
4300       all_weights = false;
4301     }
4302   } while (nextTensorWeigtConfiguration(config));
4303 }
4304 
TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertLikeOps) {
  // Builds a ZerosLike node (value == 0) or a OnesLike node (otherwise) that
  // reads from a placeholder named "input" of type tf_type_.
  auto get_node = [&](int value) -> NodeDef {
    Scope s = Scope::NewRootScope();
    auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
    if (value == 0) {
      auto zeros_like = ops::ZerosLike(s.WithOpName("Zeros"), input);
      return zeros_like.operation.node()->def();
    }
    auto ones_like = ops::OnesLike(s.WithOpName("Ones"), input);
    return ones_like.operation.node()->def();
  };

  // Loop-invariant test parameters, shared by the ZerosLike and OnesLike runs.
  const std::vector<std::vector<int>> output_dims_params = {
      {8}, {8, 2, 4}, {32, 32, 3200}};
  const float val = 42.0;
  const Status status = Status::OK();

  for (int value : {0, 1}) {
    Reset();
    const NodeDef node_def = get_node(value);
    const std::string name = value ? "Ones" : "Zeros";

    if (trt_mode_ == TrtTestMode::kImplicitBatch) {
      std::vector<float> input_data(8, 42.0f);
      AddTestTensor("input", {8}, tf_type_, input_data);
      RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                                 "Conversion for " + name + "Like is not " +
                                     "implemented in implicit batch mode");
      continue;
    }

    for (bool input_is_tensor : {true, false}) {
      for (const auto& output_dims : output_dims_params) {
        Reset();
        const size_t nb_el =
            std::accumulate(output_dims.begin(), output_dims.end(), size_t{1},
                            std::multiplies<size_t>());
        // The input is filled with `val`, but the test expects an output of
        // the input's shape whose elements all equal `value` (0 or 1).
        const std::vector<float> input_data(nb_el, val);
        if (input_is_tensor) {
          AddTestTensor("input", output_dims, tf_type_, input_data);
        } else {
          AddTestWeights("input", output_dims, input_data, tf_type_);
        }
        const std::vector<float> expected_output(nb_el,
                                                 static_cast<float>(value));
        TestOpConverter(name, node_def, output_dims, status, status,
                        ElementsAreArray(expected_output));
      }
    }
  }
}
4356 
4357 #endif  // IS_TRT_VERSION_GE(8, 2, 0, 0)
4358 
4359 #if IS_TRT_VERSION_GE(8, 2, 1, 6) || defined(TF_TRT_USE_EFFICIENT_NMS_PLUGIN)
4360 
TEST_P(OpConverter_FP32_Test,ConvertCombinedNMS)4361 TEST_P(OpConverter_FP32_Test, ConvertCombinedNMS) {
4362   // Get the NodeDef for CombinedNMS.
4363   auto get_nms_nodedef = [](DataType tf_type, bool clip_boxes = true,
4364                             bool pad_per_class = false) -> NodeDef {
4365     Scope s = Scope::NewRootScope();
4366     auto boxes_tensor = ops::Placeholder(s.WithOpName("boxes"), tf_type);
4367     auto scores_tensor = ops::Placeholder(s.WithOpName("scores"), tf_type);
4368     auto max_output_size_per_class =
4369         ops::Placeholder(s.WithOpName("max_output_size_per_class"), DT_INT32);
4370     auto max_total_size =
4371         ops::Placeholder(s.WithOpName("max_total_size"), DT_INT32);
4372     auto iou_threshold =
4373         ops::Placeholder(s.WithOpName("iou_threshold"), tf_type);
4374     auto score_threshold =
4375         ops::Placeholder(s.WithOpName("score_threshold"), tf_type);
4376     auto nms_attrs = ops::CombinedNonMaxSuppression::Attrs()
4377                          .PadPerClass(pad_per_class)
4378                          .ClipBoxes(clip_boxes);
4379 
4380     auto nms_op = ops::CombinedNonMaxSuppression(
4381         s.WithOpName("my_nms"), boxes_tensor, scores_tensor,
4382         max_output_size_per_class, max_total_size, iou_threshold,
4383         score_threshold, nms_attrs);
4384     return nms_op.operation.node()->def();
4385   };
4386 
4387   struct TestParams {
4388     const std::string description;
4389     const std::vector<int32> boxes_tensor_dims;
4390     const std::vector<int32> scores_tensor_dims;
4391     const std::vector<float> boxes_values;
4392     const std::vector<float> scores_values;
4393     const int32 max_output_size_per_class;
4394     const int32 max_total_size;
4395     const float iou_threshold;
4396     const float score_threshold;
4397     bool pad_per_class;
4398     bool clip_boxes;
4399     const std::vector<std::vector<int32>> expected_output_dims;
4400     const std::vector<float> exp_boxes;
4401     const std::vector<float> exp_scores;
4402     const std::vector<float> exp_classes;
4403     const std::vector<float> exp_num_detections;
4404     Status conversion_status;
4405     Status runtime_status;
4406   };
4407 
4408   Status conv_status =
4409       trt_mode_ == TrtTestMode::kImplicitBatch
4410           ? errors::Unimplemented(
4411                 "Implict batch mode not supported with CombinedNMS")
4412           : Status::OK();
4413 
4414   std::vector<TestParams> params = {
4415       TestParams{"Test 1: clip boxes",
4416                  {1, 1, 3, 4},  // boxes dims
4417                  {1, 1, 3},     // scores dims
4418                                 // boxes values:
4419                  {0, 0, 0.3, 1.4, 0, 0, 0.3, 1.4, 0, 0, 0.3, 1.4},
4420                  {0.4, 0.7, 0.3},  // scores values
4421                  3,                // max_output_size_per_class
4422                  2,                // max_total_size
4423                  0.1,              // IOU threshold
4424                  0,                // score_threshold
4425                  false,            // pad_per_class
4426                  true,             // clip_boxes
4427                  {{1, 2, 4},       // expected_nmsed_boxes_dims
4428                   {1, 2},          // expected_nmsed_scores_dims
4429                   {1, 2},          // expected_nmsed_classes_dims
4430                   {1}},            // expected_valid_detections_dims
4431                                    // exp_boxes_values:
4432                  {0, 0, 0.3, 1.0, 0, 0, 0.3, 1.0},
4433                  {0.7, 0.4},  // exp_scores
4434                  {1, 0},      // exp_classes
4435                  {2},         // exp_num_detections
4436                  conv_status},
4437       TestParams{
4438           "Test 2: iou threshold",
4439           {1, 5, 1, 4},  // boxes dims
4440           {1, 5, 1},     // scores dims
4441                          // boxes values:
4442           {0, 0, 5, 10, 0, 1, 5, 11, 8, 0, 12, 4, 6, 2, 10, 6, 8, 9, 11, 12},
4443           {5, 4, 3, 2, 1},  // scores values
4444           4,                // max_output_size_per_class
4445           4,                // max_total_size
4446           0.7,              // IOU threshold
4447           0,                // score threshold
4448           false,            // pad_per_class
4449           false,            // clip_boxes
4450           {{1, 4, 4},       // expected nmsed_boxes_dims
4451            {1, 4},          // expected nmsed_scores_dims
4452            {1, 4},          // expected_nmsed_classes_dims
4453            {1}},            // expected_valid_detections_dims
4454                             // exp_boxes_values:
4455           {0, 0, 5, 10, 8, 0, 12, 4, 6, 2, 10, 6, 8, 9, 11, 12},
4456           {5, 3, 2, 1},  // exp_scores
4457           {0, 0, 0, 0},  // exp_classes
4458           {4},           // exp_num_detections
4459           conv_status},
4460       TestParams{
4461           "Test 3: score threshold",
4462           {1, 5, 1, 4},  // boxes dims
4463           {1, 5, 1},     // scores dims
4464                          // boxes values:
4465           {0, 0, 5, 10, 0, 1, 5, 11, 8, 0, 12, 4, 6, 2, 10, 6, 8, 9, 11, 12},
4466           {5, 4, 3, 2, 1},  // scores values
4467           4,                // max_output_size_per_class
4468           4,                // max_total_size
4469           0.1,              // IOU threshold
4470           2,                // score threshold
4471           false,            // pad_per_class
4472           false,            // clip_boxes
4473           {{1, 4, 4},       // expected nmsed_boxes_dims
4474            {1, 4},          // expected nmsed_scores_dims
4475            {1, 4},          // expected_nmsed_classes_dims
4476            {1}},            // expected_valid_detections_dims
4477                             // exp_boxes_values:
4478           {0, 0, 5, 10, 8, 0, 12, 4, 0, 0, 0, 0, 0, 0, 0, 0},
4479           {5, 3, 0, 0},  // exp_scores
4480           {0, 0, 0, 0},  // exp_classes
4481           {2},           // exp_num_detections
4482           conv_status},
4483       TestParams{
4484           "Test 4: per class size and pad",
4485           {1, 5, 1, 4},  // boxes dims
4486           {1, 5, 2},     // scores dims
4487                          // boxes values:
4488           {0, 0, 5, 10, 0, 1, 5, 11, 8, 0, 12, 4, 6, 2, 10, 6, 8, 9, 11, 12},
4489           // scores values:
4490           {5, 0, 0, 4, 3, 0, 2, 0, 1, 0},
4491           1,           // max_output_size_per_class
4492           4,           // max_total_size
4493           0.1,         // IOU threshold
4494           0,           // score threshold
4495           true,        // pad_per_class
4496           false,       // clip_boxes
4497           {{1, 2, 4},  // expected nmsed_boxes_dims
4498            {1, 2},     // expected nmsed_scores_dims
4499            {1, 2},     // expected_nmsed_classes_dims
4500            {1}},       // expected_valid_detections_dims
4501                        // exp_boxes_values:
4502           {0, 0, 5, 10, 0, 1, 5, 11},
4503           {5, 4},  // exp_scores
4504           {0, 1},  // exp_classes
4505           {2},     // exp_num_detections
4506           conv_status},
4507       TestParams{
4508           "Test 5: different box coordinate order",
4509           {1, 5, 1, 4},  // boxes dims
4510           {1, 5, 2},     // scores dims
4511                          // boxes values:
4512           {5, 10, 0, 0, 5, 11, 0, 1, 12, 4, 8, 0, 10, 6, 6, 2, 11, 12, 8, 9},
4513           // scores values:
4514           {5, 0, 0, 4, 3, 0, 2, 0, 1, 0},
4515           1,           // max_output_size_per_class
4516           4,           // max_total_size
4517           0.1,         // IOU threshold
4518           0,           // score threshold
4519           true,        // pad_per_class
4520           false,       // clip_boxes
4521           {{1, 2, 4},  // expected nmsed_boxes_dims
4522            {1, 2},     // expected nmsed_scores_dims
4523            {1, 2},     // expected_nmsed_classes_dims
4524            {1}},       // expected_valid_detections_dims
4525                        // exp_boxes_values:
4526           {5, 10, 0, 0, 5, 11, 0, 1},
4527           {5, 4},  // exp_scores
4528           {0, 1},  // exp_classes
4529           {2},     // exp_num_detections
4530           conv_status},
4531   };
4532 
4533   for (auto p : params) {
4534     Reset();
4535     SCOPED_TRACE(p.description);
4536     AddTestTensor("boxes", p.boxes_tensor_dims, p.boxes_values);
4537     AddTestTensor("scores", p.scores_tensor_dims, p.scores_values);
4538     AddTestWeights<int32>("max_output_size_per_class", {1},
4539                           {p.max_output_size_per_class});
4540     AddTestWeights<int32>("max_total_size", {1}, {p.max_total_size});
4541     AddTestWeights<float>("iou_threshold", {1}, {p.iou_threshold}, tf_type_);
4542     AddTestWeights<float>("score_threshold", {1}, {p.score_threshold},
4543                           tf_type_);
4544 
4545     auto node_def = get_nms_nodedef(tf_type_, p.clip_boxes, p.pad_per_class);
4546 
4547     TestOpConverterMultiOut("my_nms", node_def, p.expected_output_dims,
4548                             p.conversion_status, p.runtime_status,
4549                             {
4550                                 ElementsAreArray(p.exp_boxes),
4551                                 ElementsAreArray(p.exp_scores),
4552                                 ElementsAreArray(p.exp_classes),
4553                                 ElementsAreArray(p.exp_num_detections),
4554                             },
4555                             {tf_type_, tf_type_, tf_type_, DT_INT32});
4556   }
4557 }
4558 
4559 #elif IS_TRT_VERSION_GE(7, 1, 3, 0)
4560 
TEST_P(OpConverter_FP32_Test,ConvertCombinedNMS)4561 TEST_P(OpConverter_FP32_Test, ConvertCombinedNMS) {
4562   // Get the NodeDef for CombinedNMS.
4563   auto get_nms_nodedef = [](DataType tf_type, bool clip_boxes = true,
4564                             bool pad_per_class = false) -> NodeDef {
4565     Scope s = Scope::NewRootScope();
4566     auto boxes_tensor = ops::Placeholder(s.WithOpName("boxes"), tf_type);
4567     auto scores_tensor = ops::Placeholder(s.WithOpName("scores"), tf_type);
4568     auto max_output_size_per_class =
4569         ops::Placeholder(s.WithOpName("max_output_size_per_class"), DT_INT32);
4570     auto max_total_size =
4571         ops::Placeholder(s.WithOpName("max_total_size"), DT_INT32);
4572     auto iou_threshold =
4573         ops::Placeholder(s.WithOpName("iou_threshold"), tf_type);
4574     auto score_threshold =
4575         ops::Placeholder(s.WithOpName("score_threshold"), tf_type);
4576     auto nms_attrs = ops::CombinedNonMaxSuppression::Attrs()
4577                          .PadPerClass(pad_per_class)
4578                          .ClipBoxes(clip_boxes);
4579 
4580     auto nms_op = ops::CombinedNonMaxSuppression(
4581         s.WithOpName("my_nms"), boxes_tensor, scores_tensor,
4582         max_output_size_per_class, max_total_size, iou_threshold,
4583         score_threshold, nms_attrs);
4584     return nms_op.operation.node()->def();
4585   };
4586 
4587   struct TestParams {
4588     const std::string description;
4589     const std::vector<int32> boxes_tensor_dims;
4590     const std::vector<int32> scores_tensor_dims;
4591     const std::vector<float> boxes_values;
4592     const std::vector<float> scores_values;
4593     const int32 max_output_size_per_class;
4594     const int32 max_total_size;
4595     const float iou_threshold;
4596     const float score_threshold;
4597     bool pad_per_class;
4598     bool clip_boxes;
4599     const std::vector<std::vector<int32>> expected_output_dims;
4600     const std::vector<float> exp_boxes;
4601     const std::vector<float> exp_scores;
4602     const std::vector<float> exp_classes;
4603     const std::vector<float> exp_num_detections;
4604     Status conversion_status;
4605     Status runtime_status;
4606   };
4607 
4608   Status conv_status =
4609       trt_mode_ == TrtTestMode::kDynamicShape
4610           ? errors::Unimplemented(
4611                 "TensorRT BatchedNMS Plugin requires input with static shape")
4612           : Status::OK();
4613 
4614   std::vector<TestParams> params = {
4615       // TODO(aaroey): there is a bug in TRT's CombinedNonMaxSuppression
4616       // implementation that, the extra output classes that are outside of the
4617       // range specified by valid_detections[i] are not zeros but -1s.
4618       TestParams{
4619           "Test 1: Original test",
4620           {1, 1, 3, 4},                                      // boxes dims
4621           {1, 1, 3},                                         // scores dims
4622           {0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4},  // boxes values
4623           {0.4, 0.7, 0.3},                                   // scores values
4624           3,                                 // max_output_size_per_class
4625           2,                                 // max_total_size
4626           .5f,                               // IOU threshold
4627           0,                                 // score_threshold
4628           false,                             // pad_per_class
4629           true,                              // clip_boxes
4630           {{1, 2, 4},                        // expected_nmsed_boxes_dims
4631            {1, 2},                           // expected_nmsed_scores_dims
4632            {1, 2},                           // expected_nmsed_classes_dims
4633            {1}},                             // expected_valid_detections_dims
4634           {0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4},  // exp_boxes_values
4635           {0.7, 0.4},                        // exp_scores
4636           {1, 0},                            // exp_classes
4637           {2},                               // exp_num_detections
4638           conv_status},
4639       // Test with clip_boxes = False
4640       TestParams{
4641           "Test 2: clip_boxes",
4642           {1, 5, 1, 4},  // boxes dims
4643           {1, 5, 1},     // scores dims
4644           // boxes values:
4645           {0, 0, 5, 10, 0, 4, 5, 14, 8, 0, 12, 4, 6, 2, 10, 6, 8, 9, 11, 12},
4646           {5, 4, 3, 2, 1},  // scores values
4647           4,                // max_output_size_per_class
4648           4,                // max_total_size
4649           0.1,              // IOU threshold
4650           0,                // score threshold
4651           false,            // pad_per_class
4652           false,            // clip_boxes
4653           {{1, 4, 4},       // expected nmsed_boxes_dims
4654            {1, 4},          // expected nmsed_scores_dims
4655            {1, 4},          // expected_nmsed_classes_dims
4656            {1}},            // expected_valid_detections_dims
4657                             // exp_boxes_values:
4658           {0, 0, 5, 10, 8, 0, 12, 4, 8, 9, 11, 12, 0, 0, 0, 0},
4659           {5, 3, 1, 0},   // exp_scores
4660           {0, 0, 0, -1},  // exp_classes
4661           {3},            // exp_num_detections
4662           conv_status},
4663       // Test with clip_boxes = False, and nonzero score threshold
4664       TestParams{
4665           "Test 3: score threshold",
4666           {1, 5, 1, 4},  // boxes dims
4667           {1, 5, 1},     // scores dims
4668           // boxes values:
4669           {0, 0, 5, 10, 0, 4, 5, 14, 8, 0, 12, 4, 6, 2, 10, 6, 8, 9, 11, 12},
4670           {5, 4, 3, 2, 1},  // scores values
4671           4,                // max_output_size_per_class
4672           4,                // max_total_size
4673           0.1,              // IOU threshold
4674           2,                // score threshold
4675           false,            // pad_per_class
4676           false,            // clip_boxes
4677           {{1, 4, 4},       // expected nmsed_boxes_dims
4678            {1, 4},          // expected nmsed_scores_dims
4679            {1, 4},          // expected_nmsed_classes_dims
4680            {1}},            // expected_valid_detections_dims
4681                             // exp_boxes_values:
4682           {0, 0, 5, 10, 8, 0, 12, 4, 0, 0, 0, 0, 0, 0, 0, 0},
4683           {5, 3, 0, 0},    // exp_scores
4684           {0, 0, -1, -1},  // exp_classes
4685           {2},             // exp_num_detections
4686           conv_status},
4687       // Test where the boxes are defined as with max value first for the box
4688       // coordinates. This test fails before TRT 7.1.3.
4689       TestParams{
4690           "Test 4: max coord first",
4691           {1, 5, 1, 4},  // boxes dims
4692           {1, 5, 1},     // scores dims
4693                          // boxes values:
4694           {5, 10, 0, 0, 5, 14, 0, 4, 12, 4, 8, 0, 10, 6, 6, 2, 11, 12, 8, 9},
4695           {5, 4, 3, 2, 1},  // scores values
4696           4,                // max_output_size_per_class
4697           4,                // max_total_size
4698           0.1,              // IOU threshold
4699           0,                // score threshold
4700           false,            // pad_per_class
4701           false,            // clip_boxes
4702           {{1, 4, 4},       // expected nmsed_boxes_dims
4703            {1, 4},          // expected nmsed_scores_dims
4704            {1, 4},          // expected_nmsed_classes_dims
4705            {1}},            // expected_valid_detections_dims
4706                             // exp_boxes_values:
4707           {5, 10, 0, 0, 12, 4, 8, 0, 11, 12, 8, 9, 0, 0, 0, 0},
4708           {5, 3, 1, 0},   // exp_scores
4709           {0, 0, 0, -1},  // exp_classes
4710           {3},            // exp_num_detections
4711           conv_status},
4712       TestParams{"Test 5: TopK error",
4713                  {1, 5000, 1, 4},  // boxes dims
4714                  {1, 5000, 1},     // scores dims
4715                  {},               // boxes values:
4716                  {},               // scores values
4717                  4,                // max_output_size_per_class
4718                  4,                // max_total_size
4719                  0.1,              // IOU threshold
4720                  0,                // score threshold
4721                  false,            // pad_per_class
4722                  false,            // clip_boxes
4723                  {},               // expected_valid_detections_dims
4724                  {},               // exp_boxes_values
4725                  {},               // exp_scores
4726                  {},               // exp_classes
4727                  {},               // exp_num_detections
4728                  conv_status.ok()
4729                      ? errors::InvalidArgument(
4730                            "TRT NMS plugin allow top_k<=4096, where top_k = "
4731                            "max(num_boxes, max_total_size). You can override "
4732                            "this by setting TF_TRT_ALLOW_NMS_TOPK_OVERRIDE=1 "
4733                            "environment variable, but this can result in a "
4734                            "loss of accuracy.")
4735                      : conv_status},
4736   };
4737 
4738   for (auto p : params) {
4739     Reset();
4740     SCOPED_TRACE(p.description);
4741     AddTestTensor("boxes", p.boxes_tensor_dims, p.boxes_values);
4742     AddTestTensor("scores", p.scores_tensor_dims, p.scores_values);
4743     AddTestWeights<int32>("max_output_size_per_class", {1},
4744                           {p.max_output_size_per_class});
4745     AddTestWeights<int32>("max_total_size", {1}, {p.max_total_size});
4746     AddTestWeights<float>("iou_threshold", {1}, {p.iou_threshold}, tf_type_);
4747     AddTestWeights<float>("score_threshold", {1}, {p.score_threshold},
4748                           tf_type_);
4749 
4750     auto node_def = get_nms_nodedef(tf_type_, p.clip_boxes, p.pad_per_class);
4751 
4752     TestOpConverterMultiOut("my_nms", node_def, p.expected_output_dims,
4753                             p.conversion_status, p.runtime_status,
4754                             {
4755                                 ElementsAreArray(p.exp_boxes),
4756                                 ElementsAreArray(p.exp_scores),
4757                                 ElementsAreArray(p.exp_classes),
4758                                 ElementsAreArray(p.exp_num_detections),
4759                             },
4760                             {tf_type_, tf_type_, tf_type_, DT_INT32});
4761   }
4762 }
4763 
4764 #endif  // IS_TRT_VERSION_GE(7, 1, 3, 0)
4765 
4766 template <typename T>
CreateUnaryOp(DataType tf_type)4767 NodeDef CreateUnaryOp(DataType tf_type) {
4768   Scope s = Scope::NewRootScope();
4769   auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
4770   return T(s.WithOpName("my_unary"), input).operation.node()->def();
4771 }
4772 
4773 constexpr float kLeakyReluAlpha = 0.2f;
4774 template <>
CreateUnaryOp(DataType tf_type)4775 NodeDef CreateUnaryOp<ops::internal::LeakyRelu>(DataType tf_type) {
4776   Scope s = Scope::NewRootScope();
4777   auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
4778   return ops::internal::LeakyRelu(
4779              s.WithOpName("my_unary"), input,
4780              ops::internal::LeakyRelu::Alpha(kLeakyReluAlpha))
4781       .operation.node()
4782       ->def();
4783 }
4784 
TEST_P(OpConverter_FP32_UnaryTest,ConvertActivation)4785 TEST_P(OpConverter_FP32_UnaryTest, ConvertActivation) {
4786   constexpr float kSeluAlpha = 1.7580993408473768599402175208123f;
4787   constexpr float kSeluScale = 1.0507009873554804934193349852946f;
4788   using OpFunc = std::function<NodeDef(DataType)>;
4789   using ValFunc = float (*)(float);
4790   std::map<std::string, std::pair<OpFunc, ValFunc>> op_map;
4791 
4792 #define ADD_OP(name, op, compute) \
4793   op_map[name] = std::make_pair(CreateUnaryOp<op>, compute)
4794   ADD_OP("LeakyRelu", ops::internal::LeakyRelu,
4795          [](float x) { return (x > 0.0f) ? x : x * kLeakyReluAlpha; });
4796   ADD_OP("Relu", ops::Relu, [](float x) { return (x > 0.0f) ? x : 0.0f; });
4797   ADD_OP("Relu6", ops::Relu6,
4798          [](float x) { return std::min(std::max(x, 0.0f), 6.0f); });
4799   ADD_OP("Sigmoid", ops::Sigmoid,
4800          [](float x) { return 1.0f / (1.0f + std::exp(-x)); });
4801   ADD_OP("Tanh", ops::Tanh, static_cast<ValFunc>(std::tanh));
4802   ADD_OP("Elu", ops::Elu,
4803          [](float x) { return (x > 0.0f) ? x : std::exp(x) - 1; });
4804   ADD_OP("Selu", ops::Selu, [](float x) {
4805     return (x > 0.0f) ? kSeluScale * x
4806                       : kSeluScale * kSeluAlpha * (std::exp(x) - 1);
4807   });
4808   ADD_OP("Softsign", ops::Softsign,
4809          [](float x) { return x / (std::abs(x) + 1); });
4810   ADD_OP("Softplus", ops::Softplus,
4811          [](float x) { return std::log(std::exp(x) + 1); });
4812 #undef ADD_OP
4813 
4814   // std::exp in Softplus will overflow for input > 88
4815   const std::vector<float> input = {-100, -2, -1, 0, 1, 88};
4816   const bool nan_sensitive = false;
4817 
4818 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
4819   // NVBug # 3322482 - Known bug with TRT 8.0 on specific GPU architectures
4820   const float max_abs_error = 1e-4;
4821 #else
4822   const float max_abs_error = 0.;
4823 #endif
4824   RunTests("Activation", *ActivationTypeMap(), op_map, input, "input",
4825            max_abs_error, nan_sensitive);
4826 }
4827 
TEST_P(OpConverter_FP32_Test,ConvertExpandDims)4828 TEST_P(OpConverter_FP32_Test, ConvertExpandDims) {
4829   // Get the NodeDef for ExpandDims.
4830   Scope s = Scope::NewRootScope();
4831   auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
4832   auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
4833   auto expanddims =
4834       ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights);
4835   const NodeDef& node_def = expanddims.operation.node()->def();
4836   {
4837     // Input is weights, should fail.
4838     Reset();
4839     AddTestWeights<int32>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
4840     AddTestWeights<int32>("weights", {1}, {1});
4841     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
4842                                "The input \"input\" for ExpandDims must be a "
4843                                "tensor");
4844   }
4845   {
4846     // Axis is a tensor, should fail.
4847     Reset();
4848     AddTestTensor("input", {3, 2, 1});
4849     AddTestTensor("weights", {3});
4850     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
4851                                "The input \"axis\" for ExpandDims must be a "
4852                                "constant");
4853   }
4854   std::vector<TestParamBase> test_params = {
4855       TestParamBase{{1, 1, 2, 3},
4856                     {},
4857                     {1, 1, 1, 2, 3},
4858                     {0},
4859                     trt_mode_ == TrtTestMode::kImplicitBatch
4860                         ? Status(error::UNIMPLEMENTED,
4861                                  "TensorRT does not allow manipulation of the "
4862                                  "batch dimension")
4863                         : Status::OK()},
4864       TestParamBase{{1, 1, 2, 3},
4865                     {},
4866                     {1, 1, 1, 2, 3},
4867                     {-5},
4868                     trt_mode_ == TrtTestMode::kImplicitBatch
4869                         ? Status(error::UNIMPLEMENTED,
4870                                  "TensorRT does not allow manipulation of the "
4871                                  "batch dimension")
4872                         : Status::OK()},
4873       TestParamBase{{1, 1, 2, 3},
4874                     {},
4875                     {},
4876                     {5},
4877                     Status(error::INVALID_ARGUMENT,
4878                            "Axis value of 5 is out of bounds, must be in range"
4879                            " [-5, 5)")},
4880       TestParamBase{{1, 1, 2, 3},
4881                     {},
4882                     {},
4883                     {-6},
4884                     Status(error::INVALID_ARGUMENT,
4885                            "Axis value of -6 is out of bounds, must be in range"
4886                            " [-5, 5)")},
4887       TestParamBase{{1, 2, 3}, {}, {1, 1, 2, 3}, {1}},
4888       TestParamBase{{1, 2, 3}, {}, {1, 1, 2, 3}, {-3}},
4889       TestParamBase{{1, 2, 3}, {}, {1, 2, 3, 1}, {3}},
4890       TestParamBase{{1, 2, 3}, {}, {1, 2, 3, 1}, {-1}},
4891       TestParamBase{{1, 2, 3}, {}, {1, 2, 1, 3}, {2}},
4892       TestParamBase{{1, 2, 3}, {}, {1, 2, 1, 3}, {-2}},
4893       TestParamBase{{1, 6}, {}, {1, 1, 6}, {1}},
4894       TestParamBase{{1, 6}, {}, {1, 6, 1}, {-1}},
4895   };
4896   for (auto p : test_params) {
4897     Reset();
4898     AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6});
4899     AddTestWeights<int32>("weights", {1}, {p.param[0]});
4900     TestOpConverter("my_expanddims", node_def, p.expected_output_dims, p.status,
4901                     p.runtime_status, ElementsAreArray({1, 2, 3, 4, 5, 6}));
4902   }
4903 }
4904 
// Tests conversion of Softmax for several 2-D shapes. Every case is fed the
// values 1..6; the expected values correspond to a softmax over the innermost
// dimension.
TEST_P(OpConverter_FP32_FP16_Test, ConvertSoftmax) {
  // Get the NodeDef for SoftMax.
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName("logits"), tf_type_);
  auto softmax = ops::Softmax(s.WithOpName("my_softmax"), input);
  const NodeDef& node_def = softmax.operation.node()->def();

  struct TestParams {
    std::vector<int> input_dims;
    std::vector<float> expected_values;
  };
  std::vector<TestParams> test_params = {
      // Rows [1, 2, 3] and [4, 5, 6]: softmax is shift-invariant, so both rows
      // yield the same probabilities.
      TestParams{/*input_dims=*/{2, 3},
                 /*expected_values=*/{0.09003057, 0.24472848, 0.66524094,
                                      0.09003057, 0.24472848, 0.66524094}},
      // Each row holds a single element, so every output is exactly 1.
      TestParams{/*input_dims=*/{6, 1},
                 /*expected_values=*/{1, 1, 1, 1, 1, 1}},
      // A single row holding all of 1..6.
      TestParams{/*input_dims=*/{1, 6},
                 /*expected_values=*/{0.00426978, 0.01160646, 0.03154963,
                                      0.08576079, 0.23312202, 0.6336913}}};
  std::vector<float> input_values{1, 2, 3, 4, 5, 6};
  for (const TestParams& p : test_params) {
    // Label any failure with the input shape being tested.
    SCOPED_TRACE("input_dims: " + ::testing::PrintToString(p.input_dims));
    Reset();
    AddTestTensor("logits", p.input_dims, input_values);
    TestOpConverter("my_softmax", node_def, p.input_dims, Status::OK(),
                    Status::OK(), ArrayFloatNear(p.expected_values, 1e-3));
  }
}
4933 
// Tests conversion of LogSoftmax for several 2-D shapes. Every case is fed the
// values 1..6; the expected values correspond to a log-softmax over the
// innermost dimension.
TEST_P(OpConverter_FP32_FP16_Test, ConvertLogSoftmax) {
  // Get the NodeDef for LogSoftMax.
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName("logits"), tf_type_);
  auto logsoftmax = ops::LogSoftmax(s.WithOpName("my_logsoftmax"), input);
  const NodeDef& node_def = logsoftmax.operation.node()->def();

  struct TestParams {
    std::vector<int> input_dims;
    std::vector<float> expected_values;
  };

  std::vector<TestParams> test_params = {
      // Rows [1, 2, 3] and [4, 5, 6]: log-softmax is shift-invariant, so both
      // rows yield the same values.
      TestParams{/*input_dims=*/{2, 3},
                 /*expected_values=*/{-2.4076061, -1.407606, -0.40760604,
                                      -2.4076061, -1.407606, -0.40760604}},
      // A single row holding all of 1..6.
      TestParams{/*input_dims=*/{1, 6},
                 /*expected_values=*/{-5.4561934, -4.4561934, -3.4561934,
                                      -2.4561934, -1.4561933, -0.45619333}},
      // Each row holds a single element, so every output is log(1) = 0.
      TestParams{/*input_dims=*/{6, 1},
                 /*expected_values=*/{0, 0, 0, 0, 0, 0}}};
  std::vector<float> input_values{1, 2, 3, 4, 5, 6};
  for (const TestParams& p : test_params) {
    // Label any failure with the input shape being tested.
    SCOPED_TRACE("input_dims: " + ::testing::PrintToString(p.input_dims));
    Reset();
    AddTestTensor("logits", p.input_dims, input_values);
    TestOpConverter("my_logsoftmax", node_def, p.input_dims, Status::OK(),
                    Status::OK(), ArrayFloatNear(p.expected_values, 1e-3));
  }
}
4963 
TEST_P(OpConverter_FP32_Test,ConvertSqueeze)4964 TEST_P(OpConverter_FP32_Test, ConvertSqueeze) {
4965   const bool use_implicit_batch = (trt_mode_ == TrtTestMode::kImplicitBatch);
4966   // Get the NodeDef for Squeeze.
4967   auto get_squeeze_nodedef = [](std::vector<int> axes,
4968                                 DataType tf_type) -> NodeDef {
4969     Scope s = Scope::NewRootScope();
4970     auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
4971     if (!axes.empty()) {
4972       ops::Squeeze::Attrs squeeze_attrs;
4973       squeeze_attrs.axis_ = gtl::ArraySlice<int>(axes);  // non-absl ok
4974       auto squeeze =
4975           ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs);
4976       return squeeze.operation.node()->def();
4977     } else {
4978       auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input);
4979       return squeeze.operation.node()->def();
4980     }
4981   };
4982   std::vector<TestParamBase> test_params = {
4983       TestParamBase{
4984           {1, 2, 1, 3},  // input dims
4985           {},            // input partial dims
4986           {2, 3},        // expected output dims
4987           {},            // axis
4988           trt_mode_ == TrtTestMode::kExplicitBatch
4989               ? Status::OK()
4990               : Status{error::UNIMPLEMENTED,
4991                        "Squeeze is not implemented for empty squeeze_dims"}},
4992       TestParamBase{{1, 2, 1, 3},
4993                     {},
4994                     {2, 1, 3},
4995                     {0},
4996                     use_implicit_batch
4997                         ? Status{error::UNIMPLEMENTED,
4998                                  "TensorRT does not allow manipulation of the "
4999                                  "batch dimension"}
5000                         : Status::OK()},
5001       TestParamBase{{1, 2, 1, 3},
5002                     {},
5003                     {2, 1, 3},
5004                     {-4},
5005                     use_implicit_batch
5006                         ? Status{error::UNIMPLEMENTED,
5007                                  "TensorRT does not allow manipulation of the "
5008                                  "batch dimension"}
5009                         : Status::OK()},
5010       TestParamBase{
5011           {1, 1, 2, 3},
5012           {},
5013           {},
5014           {4},
5015           Status{error::INVALID_ARGUMENT,
5016                  "Axis value of 4 is out of bounds, must be in range [-4, 4)"}},
5017       TestParamBase{
5018           {1, 1, 2, 3},
5019           {},
5020           {},
5021           {-5},
5022           Status{
5023               error::INVALID_ARGUMENT,
5024               "Axis value of -5 is out of bounds, must be in range [-4, 4)"}},
5025       TestParamBase{{1, 1, 2, 3}, {}, {1, 2, 3}, {1}},
5026       TestParamBase{{1, 1, 2, 3}, {}, {1, 2, 3}, {-3}},
5027       TestParamBase{{1, 2, 3, 1}, {}, {1, 2, 3}, {3}},
5028       TestParamBase{{1, 2, 3, 1}, {}, {1, 2, 3}, {-1}},
5029       TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {1, 3, 5}},
5030       TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {3, 1, 5}},
5031       TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {-1, -3, -5}},
5032       TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {1, -3, 5}},
5033       TestParamBase{{1, 1, 6}, {}, {1, 6}, {1}},
5034       TestParamBase{{1, 6, 1}, {}, {1, 6}, {2}},
5035   };
5036   auto squeeze_non_singleton = TestParamBase{
5037       {1, 1, 2, 3},
5038       {},
5039       {},
5040       {2},
5041       Status{error::INVALID_ARGUMENT,
5042              "Dimension 2 with size 2 cannot be squeezed because it must be "
5043              "size 1"}};
5044 
5045   if (trt_mode_ == TrtTestMode::kDynamicShape) {
5046     // In this test we try to squeeze axis=2 which has size > 1. In dynamic
5047     // shape mode the converter sees only -1, so it cannot catch this error.
5048     squeeze_non_singleton.status = Status::OK();  // conversion status
5049     squeeze_non_singleton.runtime_status =
5050         errors::InvalidArgument("Negative number of dimensions -1");
5051     // Dynamic shape tests with partially known input shape
5052     test_params.push_back(TestParamBase{{2, 1, 3}, {2, -1, 3}, {2, 3}, {1}});
5053     test_params.push_back(TestParamBase{{2, 1, 3}, {2, 1, -1}, {2, 3}, {1}});
5054   }
5055   test_params.push_back(squeeze_non_singleton);
5056 
5057   for (TestParamBase p : test_params) {
5058     SCOPED_TRACE(p);
5059     Reset();
5060     NodeDef node_def = get_squeeze_nodedef(p.param, tf_type_);
5061     AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6},
5062                   p.partial_input_dims);
5063     TestOpConverter("my_squeeze", node_def, p.expected_output_dims, p.status,
5064                     p.runtime_status, ElementsAreArray({1, 2, 3, 4, 5, 6}));
5065   }
5066 }
5067 
TEST_P(OpConverter_FP32_FP16_INT32_Test,ConvertStridedSlice)5068 TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertStridedSlice) {
5069   // Get nodedef for StridedSlice layer.
5070   auto get_strided_slice_nodedef =
5071       [](DataType tf_type, int64 begin_mask = 0, int64 end_mask = 0,
5072          int64 ellipsis_mask = 0, int64 new_axis_mask = 0,
5073          int64 shrink_axis_mask = 0) -> NodeDef {
5074     Scope s = Scope::NewRootScope();
5075     auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
5076     auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32);
5077     auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32);
5078     auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32);
5079     ops::StridedSlice::Attrs attrs = ops::StridedSlice::Attrs()
5080                                          .BeginMask(begin_mask)
5081                                          .EndMask(end_mask)
5082                                          .EllipsisMask(ellipsis_mask)
5083                                          .NewAxisMask(new_axis_mask)
5084                                          .ShrinkAxisMask(shrink_axis_mask);
5085     auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"),
5086                                            input, begin, end, strides, attrs);
5087     return strided_slice.operation.node()->def();
5088   };
5089 
5090   {
5091     // Input is weights, should fail.
5092     Reset();
5093     NodeDef node_def = get_strided_slice_nodedef(tf_type_);
5094     AddTestWeights<int32>("input", {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
5095     AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
5096     AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
5097     AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
5098     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
5099                                "The input \"input\" for StridedSlice must "
5100                                "be a tensor");
5101   }
5102   {
5103     // Begin, end, strides are tensors, should fail.
5104     Reset();
5105     NodeDef node_def = get_strided_slice_nodedef(tf_type_);
5106     AddTestTensor("input", {4, 1, 1, 1});
5107     AddTestTensor("begin", {4});
5108     AddTestTensor("end", {4});
5109     AddTestTensor("strides", {4});
5110     RunValidationAndConversion(
5111         node_def, error::UNIMPLEMENTED,
5112         "The input \"begin\" for StridedSlice must be a constant");
5113   }
5114 
5115   struct TestParams {
5116     std::vector<int> input_dims;
5117     std::vector<int> begin;
5118     std::vector<int> end;
5119     std::vector<int> strides;
5120     int begin_mask;
5121     int end_mask;
5122     int ellipsis_mask;
5123     int new_axis_mask;
5124     int shrink_axis_mask;
5125     std::vector<int> expected_output_dims;
5126     std::vector<float> expected_output;
5127     Status conversion_status;
5128     Status runtime_status;
5129     std::vector<int> partial_input_dims;
5130   };
5131 
5132   auto get_mask = [](const std::vector<int>& mask) {
5133     int result = 0;
5134     for (int i = 0; i < mask.size(); i++) {
5135       if (mask[i]) result += (1 << i);
5136     }
5137     return result;
5138   };
5139 
5140   // Same input is used for all tests.
5141   const std::vector<float> ok_input = {1, 2, 3, 4, 5, 6};
5142 
5143   Status modified_batch_dim_status =
5144       (trt_mode_ == TrtTestMode::kImplicitBatch)
5145           ? errors::Unimplemented(
5146                 "TensorRT does not allow modifications to "
5147                 "the batch dimension")
5148           : Status::OK();
5149   std::vector<TestParams> params = {
5150       // Modify batch dim, should fail in implicit batch mode.
5151       TestParams{/*input_dims=*/{2, 1, 1, 3},
5152                  /*begin=*/{0, 0, 0, 0},
5153                  /*end=*/{1, 1, 1, 2},
5154                  /*strides=*/{1, 1, 1, 1},
5155                  /*begin_mask=*/get_mask({0, 0, 0, 0}),
5156                  /*end_mask=*/get_mask({0, 0, 0, 0}),
5157                  /*ellipsis_mask=*/0,
5158                  /*new_axis_mask=*/0,
5159                  /*shrink_axis_mask=*/0,
5160                  /*expected_output_dims=*/{1, 1, 1, 2},
5161                  /*expected_output=*/{1, 2},
5162                  /*conversion_status=*/modified_batch_dim_status,
5163                  /*runtime_status=*/Status::OK(),
5164                  /*partial_input_dims=*/{}},
5165       // Unknown batch size without end_mask.
5166       TestParams{
5167           /*input_dims=*/{2, 1, 1, 3},
5168           /*begin=*/{0, 0, 0, 0},
5169           /*end=*/{1, 1, 1, 2},
5170           /*strides=*/{1, 1, 1, 1},
5171           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5172           /*end_mask=*/get_mask({0, 0, 0, 0}),
5173           /*ellipsis_mask=*/0,
5174           /*new_axis_mask=*/0,
5175           /*shrink_axis_mask=*/0,
5176           /*expected_output_dims=*/{1, 1, 1, 2},
5177           /*expected_output=*/{1, 2},
5178           modified_batch_dim_status,
5179           Status::OK(),
5180           /*partial_input_dims=*/{-1, 1, 1, 3},
5181       },
5182       // Test Case 2: Unknown batch size with end_mask.
5183       TestParams{
5184           /*input_dims=*/{2, 1, 1, 3},
5185           /*begin=*/{0, 0, 0, 0},
5186           /*end=*/{0, 1, 1, 2},
5187           /*strides=*/{1, 1, 1, 1},
5188           /*begin_mask=*/get_mask({1, 0, 0, 0}),
5189           /*end_mask=*/get_mask({1, 0, 0, 0}),
5190           /*ellipsis_mask=*/0,
5191           /*new_axis_mask=*/0,
5192           /*shrink_axis_mask=*/0,
5193           /*expected_output_dims=*/{2, 1, 1, 2},
5194           /*expected_output=*/{1, 2, 4, 5},
5195           Status::OK(),
5196           Status::OK(),
5197           /*partial_input_dims=*/{-1, 1, 1, 3},
5198       },
5199       // Invalid parameters: end[2] < begin[2]
5200       TestParams{/*input_dims=*/{1, 1, 2, 3},
5201                  /*begin=*/{0, 0, 2, 0},
5202                  /*end=*/{1, 1, 0, 3},
5203                  /*strides=*/{1, 1, 1, 1},
5204                  /*begin_mask=*/0,
5205                  /*end_mask=*/0,
5206                  /*ellipsis_mask=*/0,
5207                  /*new_axis_mask=*/0,
5208                  /*shrink_axis_mask=*/0,
5209                  /*expected_output_dims=*/{},
5210                  /*expected_output=*/{},
5211                  errors::InvalidArgument("\"size\" cannot be negative for "
5212                                          "StridedSlice"),
5213                  Status::OK(),
5214                  /*partial_input_dims=*/{}},
5215       // Slice on the last two dimensions. All dimensions are static.
5216       TestParams{
5217           /*input_dims=*/{1, 1, 2, 3},
5218           /*begin=*/{0, 0, 0, 0},
5219           /*end=*/{0, 0, 1, 2},
5220           /*strides=*/{1, 1, 1, 1},
5221           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5222           /*end_mask=*/get_mask({1, 1, 0, 0}),
5223           /*ellipsis_mask=*/0,
5224           /*new_axis_mask=*/0,
5225           /*shrink_axis_mask=*/0,
5226           /*expected_output_dims=*/{1, 1, 1, 2},
5227           /*expected_output=*/{1, 2},
5228       },
5229       // Slice on the last two dimensions. The slice is fully
5230       // specified for the dynamic dimensions.
5231       TestParams{
5232           /*input_dims=*/{1, 1, 2, 3},
5233           /*begin=*/{0, 0, 0, 0},
5234           /*end=*/{0, 0, 1, 2},
5235           /*strides=*/{1, 1, 1, 1},
5236           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5237           /*end_mask=*/get_mask({1, 1, 0, 0}),
5238           /*ellipsis_mask=*/0,
5239           /*new_axis_mask=*/0,
5240           /*shrink_axis_mask=*/0,
5241           /*expected_output_dims=*/{1, 1, 1, 2},
5242           /*expected_output=*/{1, 2},
5243           Status::OK(),
5244           Status::OK(),
5245           /*partial_input_dims=*/{1, 1, -1, -1},
5246       },
5247       // End mask is provided on all dimensions. This should override the fact
5248       // that the end value is 0. For dynamic shape, it tests
5249       // that we can infer tensor size when "end mask" is provided.
5250       TestParams{
5251           /*input_dims=*/{1, 1, 2, 3},
5252           /*begin=*/{0, 0, 1, 1},
5253           /*end=*/{0, 0, 0, 0},
5254           /*strides=*/{1, 1, 1, 1},
5255           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5256           /*end_mask=*/get_mask({1, 1, 1, 1}),
5257           /*ellipsis_mask=*/0,
5258           /*new_axis_mask=*/0,
5259           /*shrink_axis_mask=*/0,
5260           /*expected_output_dims=*/{1, 1, 1, 2},
5261           /*expected_output=*/{5, 6},
5262           Status::OK(),
5263           Status::OK(),
5264           /*partial_input_dims=*/{1, 1, -1, -1},
5265       },
5266       // End mask is provided for the batch dimension to overwrite the end value
5267       // 0 for that dimension.
5268       TestParams{
5269           /*input_dims=*/{1, 1, 2, 3},
5270           /*begin=*/{0, 0, 1, 1},
5271           /*end=*/{0, 1, 2, 3},
5272           /*strides=*/{1, 1, 1, 1},
5273           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5274           /*end_mask=*/get_mask({1, 1, 0, 0}),
5275           /*ellipsis_mask=*/0,
5276           /*new_axis_mask=*/0,
5277           /*shrink_axis_mask=*/0,
5278           /*expected_output_dims=*/{1, 1, 1, 2},
5279           /*expected_output=*/{5, 6},
5280       },
5281       // Test slice on two dimensions with negative stride, without end_mask set
5282       // on crop dimensions.
5283       TestParams{/*input_dims=*/{1, 1, 2, 3},
5284                  /*begin=*/{0, 0, 1, 2},
5285                  /*end=*/{0, 0, 0, 0},
5286                  /*strides=*/{1, 1, -1, -1},
5287                  /*begin_mask=*/get_mask({0, 0, 0, 0}),
5288                  /*end_mask=*/get_mask({1, 1, 0, 0}),
5289                  /*ellipsis_mask=*/0,
5290                  /*new_axis_mask=*/0,
5291                  /*shrink_axis_mask=*/0,
5292                  /*expected_output_dims=*/{1, 1, 1, 2},
5293                  /*expected_output=*/{6, 5},
5294                  /*conversion_status=*/Status::OK(),
5295                  /*runtime_status=*/Status::OK(),
5296                  /*partial_input_dims=*/{1, 1, -1, -1}},
5297       // Test slice on two dimensions with negative stride, with end_mask set on
5298       // crop dimensions. In dynamic shape mode, this tests the runtime size
5299       // computation.
5300       TestParams{/*input_dims=*/{1, 1, 2, 3},
5301                  /*begin=*/{0, 0, 1, 1},
5302                  /*end=*/{0, 0, 0, 0},
5303                  /*strides=*/{1, 1, -1, -1},
5304                  /*begin_mask=*/get_mask({0, 0, 0, 0}),
5305                  /*end_mask=*/get_mask({1, 1, 1, 1}),
5306                  /*ellipsis_mask=*/0,
5307                  /*new_axis_mask=*/0,
5308                  /*shrink_axis_mask=*/0,
5309                  /*expected_output_dims=*/{1, 1, 2, 2},
5310                  /*expected_output=*/{5, 4, 2, 1},
5311                  /*conversion_status=*/Status::OK(),
5312                  /*runtime_status=*/Status::OK(),
5313                  /*partial_input_dims=*/{1, 1, -1, -1}},
5314       // Test slice on two dimensions with negative stride, with begin_mask set
5315       // on the crop dimensions. In dynamic shape mode, this tests the runtime
5316       // size computation.
5317       TestParams{/*input_dims=*/{1, 1, 2, 3},
5318                  /*begin=*/{0, 0, 0, 0},
5319                  /*end=*/{0, 0, 0, 0},
5320                  /*strides=*/{1, 1, -1, -1},
5321                  /*begin_mask=*/get_mask({0, 0, 1, 1}),
5322                  /*end_mask=*/get_mask({1, 1, 0, 0}),
5323                  /*ellipsis_mask=*/0,
5324                  /*new_axis_mask=*/0,
5325                  /*shrink_axis_mask=*/0,
5326                  /*expected_output_dims=*/{1, 1, 1, 2},
5327                  /*expected_output=*/{6, 5},
5328                  /*conversion_status=*/Status::OK(),
5329                  /*runtime_status=*/Status::OK(),
5330                  /*partial_input_dims=*/{1, 1, -1, -1}},
5331       // Test the reversal of all non-batch dimensions by providing the begin
5332       // masks, end masks, and -1 as strides.
5333       TestParams{/*input_dims=*/{1, 1, 2, 3},
5334                  /*begin=*/{0, 0, 0, 0},
5335                  /*end=*/{0, 0, 0, 0},
5336                  /*strides=*/{1, -1, -1, -1},
5337                  /*begin_mask=*/get_mask({1, 1, 1, 1}),
5338                  /*end_mask=*/get_mask({1, 1, 1, 1}),
5339                  /*ellipsis_mask=*/0,
5340                  /*new_axis_mask=*/0,
5341                  /*shrink_axis_mask=*/0,
5342                  /*expected_output_dims=*/{1, 1, 2, 3},
5343                  /*expected_output=*/{6, 5, 4, 3, 2, 1},
5344                  /*conversion_status=*/Status::OK(),
5345                  /*runtime_status=*/Status::OK(),
5346                  /*partial_input_dims=*/{1, -1, -1, -1}},
5347       // Slice on dimensions 1 and 2.
5348       TestParams{
5349           /*input_dims=*/{1, 2, 3, 1},
5350           /*begin=*/{0, 0, 0, 0},
5351           /*end=*/{0, 1, 2, 1},
5352           /*strides=*/{1, 1, 1, 1},
5353           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5354           /*end_mask=*/get_mask({1, 0, 0, 0}),
5355           /*ellipsis_mask=*/0,
5356           /*new_axis_mask=*/0,
5357           /*shrink_axis_mask=*/0,
5358           /*expected_output_dims=*/{1, 1, 2, 1},
5359           /*expected_output=*/{1, 2},
5360       },
5361       // Slice on dimensions 1 and 2.
5362       TestParams{
5363           /*input_dims=*/{1, 2, 3, 1},
5364           /*begin=*/{0, 1, 1, 0},
5365           /*end=*/{0, 2, 3, 1},
5366           /*strides=*/{1, 1, 1, 1},
5367           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5368           /*end_mask=*/get_mask({1, 0, 0, 0}),
5369           /*ellipsis_mask=*/0,
5370           /*new_axis_mask=*/0,
5371           /*shrink_axis_mask=*/0,
5372           /*expected_output_dims=*/{1, 1, 2, 1},
5373           /*expected_output=*/{5, 6},
5374       },
5375       // Slice on dimensions 1 and 3.
5376       TestParams{
5377           /*input_dims=*/{1, 2, 1, 3},
5378           /*begin=*/{0, 0, 0, 0},
5379           /*end=*/{0, 1, 1, 2},
5380           /*strides=*/{1, 1, 1, 1},
5381           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5382           /*end_mask=*/get_mask({1, 0, 0, 0}),
5383           /*ellipsis_mask=*/0,
5384           /*new_axis_mask=*/0,
5385           /*shrink_axis_mask=*/0,
5386           /*expected_output_dims=*/{1, 1, 1, 2},
5387           /*expected_output=*/{1, 2},
5388       },
5389       // Slice on dimensions 1 and 3 with non-zero slice start.
5390       TestParams{
5391           /*input_dims=*/{1, 2, 1, 3},
5392           /*begin=*/{0, 1, 0, 1},
5393           /*end=*/{0, 2, 1, 3},
5394           /*strides=*/{1, 1, 1, 1},
5395           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5396           /*end_mask=*/get_mask({1, 0, 0, 0}),
5397           /*ellipsis_mask=*/0,
5398           /*new_axis_mask=*/0,
5399           /*shrink_axis_mask=*/0,
5400           /*expected_output_dims=*/{1, 1, 1, 2},
5401           /*expected_output=*/{5, 6},
5402       },
5403       // Slice on 3D tensor.
5404       TestParams{
5405           /*input_dims=*/{1, 2, 3},
5406           /*begin=*/{0, 0, 0},
5407           /*end=*/{0, 1, 2},
5408           /*strides=*/{1, 1, 1},
5409           /*begin_mask=*/get_mask({0, 0, 0}),
5410           /*end_mask=*/get_mask({1, 0, 0}),
5411           /*ellipsis_mask=*/0,
5412           /*new_axis_mask=*/0,
5413           /*shrink_axis_mask=*/0,
5414           /*expected_output_dims=*/{1, 1, 2},
5415           /*expected_output=*/{1, 2},
5416       },
5417       // Slice on 3D tensor using end_mask. For dynamic shape, all
5418       // dimensions are dynamic.
5419       TestParams{/*input_dims=*/{1, 2, 3},
5420                  /*begin=*/{0, 1, 1},
5421                  /*end=*/{0, 0, 0},
5422                  /*strides=*/{1, 1, 1},
5423                  /*begin_mask=*/get_mask({0, 0, 0}),
5424                  /*end_mask=*/get_mask({1, 1, 1}),
5425                  /*ellipsis_mask=*/0,
5426                  /*new_axis_mask=*/0,
5427                  /*shrink_axis_mask=*/0,
5428                  /*expected_output_dims=*/{1, 1, 2},
5429                  /*expected_output=*/{5, 6},
5430                  /*conversion_status=*/Status::OK(),
5431                  /*runtime_status=*/Status::OK(),
5432                  /*partial_input_dims=*/{-1, -1, -1}},
5433       // Slice on 3D tensor using end_mask. For dynamic shape, all
5434       // dimensions are dynamic.
5435       TestParams{/*input_dims=*/{1, 1, 2, 3},
5436                  /*begin=*/{0, 0, 0, 0},
5437                  /*end=*/{0, 0, 0, 2},
5438                  /*strides=*/{1, 1, 1, 1},
5439                  /*begin_mask=*/get_mask({0, 0, 0, 0}),
5440                  /*end_mask=*/get_mask({1, 1, 1, 0}),
5441                  /*ellipsis_mask=*/0,
5442                  /*new_axis_mask=*/0,
5443                  /*shrink_axis_mask=*/0,
5444                  /*expected_output_dims=*/{1, 1, 2, 2},
5445                  /*expected_output=*/{1, 2, 4, 5},
5446                  /*conversion_status=*/Status::OK(),
5447                  /*runtime_status=*/Status::OK(),
5448                  /*partial_input_dims=*/{-1, -1, -1, -1}},
5449       TestParams{
5450           /*input_dims=*/{1, 1, 2, 3},
5451           /*begin=*/{0, 0, 1, 0},
5452           /*end=*/{0, 0, 0, 0},
5453           /*strides=*/{1, 1, 1, 1},
5454           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5455           /*end_mask=*/get_mask({1, 1, 1, 1}),
5456           /*ellipsis_mask=*/0,
5457           /*new_axis_mask=*/0,
5458           /*shrink_axis_mask=*/0,
5459           /*expected_output_dims=*/{1, 1, 1, 3},
5460           /*expected_output=*/{4, 5, 6},
5461       },
5462       // 1D simple slice.
5463       TestParams{/*input_dims=*/{1, 2, 3, 1},
5464                  /*begin=*/{0, 0, 0, 0},
5465                  /*end=*/{0, 1, 0, 0},
5466                  /*strides=*/{1, 1, 1, 1},
5467                  /*begin_mask=*/get_mask({0, 0, 0, 0}),
5468                  /*end_mask=*/get_mask({1, 0, 1, 1}),
5469                  /*ellipsis_mask=*/0,
5470                  /*new_axis_mask=*/0,
5471                  /*shrink_axis_mask=*/0,
5472                  /*expected_output_dims=*/{1, 1, 3, 1},
5473                  /*expected_output=*/{1, 2, 3},
5474                  /*conversion_status=*/Status::OK(),
5475                  /*runtime_status=*/Status::OK(),
5476                  /*partial_input_dims=*/{-1, -1, -1, -1}},
5477       TestParams{
5478           /*input_dims=*/{1, 2, 3, 1},
5479           /*begin=*/{0, 1, 0, 0},
5480           /*end=*/{0, 0, 0, 0},
5481           /*strides=*/{1, 1, 1, 1},
5482           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5483           /*end_mask=*/get_mask({1, 1, 1, 1}),
5484           /*ellipsis_mask=*/0,
5485           /*new_axis_mask=*/0,
5486           /*shrink_axis_mask=*/0,
5487           /*expected_output_dims=*/{1, 1, 3, 1},
5488           /*expected_output=*/{4, 5, 6},
5489       },
5490       // Simple 1D slice on 2D input.
5491       TestParams{/*input_dims=*/{1, 6},
5492                  /*begin=*/{0, 0},
5493                  /*end=*/{0, 3},
5494                  /*strides=*/{1, 1},
5495                  /*begin_mask=*/get_mask({0, 0}),
5496                  /*end_mask=*/get_mask({1, 0}),
5497                  /*ellipsis_mask=*/0,
5498                  /*new_axis_mask=*/0,
5499                  /*shrink_axis_mask=*/0,
5500                  /*expected_output_dims=*/{1, 3},
5501                  /*expected_output=*/{1, 2, 3},
5502                  /*conversion_status=*/Status::OK(),
5503                  /*runtime_status=*/Status::OK(),
5504                  /*partial_input_dims=*/{-1, -1}},
5505       TestParams{
5506           /*input_dims=*/{1, 1, 6},
5507           /*begin=*/{0, 0, 2},
5508           /*end=*/{0, 0, 5},
5509           /*strides=*/{1, 1, 1},
5510           /*begin_mask=*/get_mask({0, 0, 0}),
5511           /*end_mask=*/get_mask({1, 1, 0}),
5512           /*ellipsis_mask=*/0,
5513           /*new_axis_mask=*/0,
5514           /*shrink_axis_mask=*/0,
5515           /*expected_output_dims=*/{1, 1, 3},
5516           /*expected_output=*/{3, 4, 5},
5517       },
5518       TestParams{
5519           /*input_dims=*/{1, 6, 1},
5520           /*begin=*/{0, 2, 0},
5521           /*end=*/{0, 5, 0},
5522           /*strides=*/{1, 1, 1},
5523           /*begin_mask=*/get_mask({0, 0, 0}),
5524           /*end_mask=*/get_mask({1, 0, 1}),
5525           /*ellipsis_mask=*/0,
5526           /*new_axis_mask=*/0,
5527           /*shrink_axis_mask=*/0,
5528           /*expected_output_dims=*/{1, 3, 1},
5529           /*expected_output=*/{3, 4, 5},
5530       },
5531       // Negative axis.
5532       TestParams{
5533           /*input_dims=*/{1, 6, 1},
5534           /*begin=*/{0, -6, 0},
5535           /*end=*/{0, -3, 0},
5536           /*strides=*/{1, 1, 1},
5537           /*begin_mask=*/get_mask({0, 0, 0}),
5538           /*end_mask=*/get_mask({1, 0, 1}),
5539           /*ellipsis_mask=*/0,
5540           /*new_axis_mask=*/0,
5541           /*shrink_axis_mask=*/0,
5542           /*expected_output_dims=*/{1, 3, 1},
5543           /*expected_output=*/{1, 2, 3},
5544       },
5545       TestParams{
5546           /*input_dims=*/{1, 6, 1},
5547           /*begin=*/{0, 0, 0},
5548           /*end=*/{0, -1, 0},
5549           /*strides=*/{1, 1, 1},
5550           /*begin_mask=*/get_mask({0, 0, 0}),
5551           /*end_mask=*/get_mask({1, 0, 1}),
5552           /*ellipsis_mask=*/0,
5553           /*new_axis_mask=*/0,
5554           /*shrink_axis_mask=*/0,
5555           /*expected_output_dims=*/{1, 5, 1},
5556           /*expected_output=*/{1, 2, 3, 4, 5},
5557       },
5558       // Clamp out of bounds begin and end.
5559       TestParams{
5560           /*input_dims=*/{1, 1, 2, 3},
5561           /*begin=*/{0, 0, -9999, -9},
5562           /*end=*/{0, 1, 1000, 4},
5563           /*strides=*/{1, 1, 1, 1},
5564           /*begin_mask=*/get_mask({0, 0, 0, 0}),
5565           /*end_mask=*/get_mask({1, 0, 0, 0}),
5566           /*ellipsis_mask=*/0,
5567           /*new_axis_mask=*/0,
5568           /*shrink_axis_mask=*/0,
5569           /*expected_output_dims=*/{1, 1, 2, 3},
5570           /*expected_output=*/{1, 2, 3, 4, 5, 6},
5571       },
5572       // Stride values >= 2.
5573       TestParams{/*input_dims=*/{1, 6},
5574                  /*begin=*/{0, 0},
5575                  /*end=*/{0, 5},
5576                  /*strides=*/{1, 2},
5577                  /*begin_mask=*/get_mask({0, 0}),
5578                  /*end_mask=*/get_mask({1, 0}),
5579                  /*ellipsis_mask=*/0,
5580                  /*new_axis_mask=*/0,
5581                  /*shrink_axis_mask=*/0,
5582                  /*expected_output_dims=*/{1, 3},
5583                  /*expected_output=*/{1, 3, 5},
5584                  /*conversion_status=*/Status::OK(),
5585                  /*runtime_status=*/Status::OK(),
5586                  /*partial_input_dims=*/{-1, -1}},
5587       TestParams{/*input_dims=*/{1, 6},
5588                  /*begin=*/{0, 0},
5589                  /*end=*/{0, 6},
5590                  /*strides=*/{1, 2},
5591                  /*begin_mask=*/get_mask({0, 0}),
5592                  /*end_mask=*/get_mask({1, 0}),
5593                  /*ellipsis_mask=*/0,
5594                  /*new_axis_mask=*/0,
5595                  /*shrink_axis_mask=*/0,
5596                  /*expected_output_dims=*/{1, 3},
5597                  /*expected_output=*/{1, 3, 5},
5598                  /*conversion_status=*/Status::OK(),
5599                  /*runtime_status=*/Status::OK(),
5600                  /*partial_input_dims=*/{-1, -1}},
5601       TestParams{/*input_dims=*/{1, 6},
5602                  /*begin=*/{0, 1},
5603                  /*end=*/{0, 6},
5604                  /*strides=*/{1, 2},
5605                  /*begin_mask=*/get_mask({0, 0}),
5606                  /*end_mask=*/get_mask({1, 0}),
5607                  /*ellipsis_mask=*/0,
5608                  /*new_axis_mask=*/0,
5609                  /*shrink_axis_mask=*/0,
5610                  /*expected_output_dims=*/{1, 3},
5611                  /*expected_output=*/{2, 4, 6},
5612                  /*conversion_status=*/Status::OK(),
5613                  /*runtime_status=*/Status::OK(),
5614                  /*partial_input_dims=*/{-1, -1}},
5615       TestParams{/*input_dims=*/{1, 6},
5616                  /*begin=*/{0, 2},
5617                  /*end=*/{0, 6},
5618                  /*strides=*/{1, 3},
5619                  /*begin_mask=*/get_mask({0, 0}),
5620                  /*end_mask=*/get_mask({1, 0}),
5621                  /*ellipsis_mask=*/0,
5622                  /*new_axis_mask=*/0,
5623                  /*shrink_axis_mask=*/0,
5624                  /*expected_output_dims=*/{1, 2},
5625                  /*expected_output=*/{3, 6},
5626                  /*conversion_status=*/Status::OK(),
5627                  /*runtime_status=*/Status::OK(),
5628                  /*partial_input_dims=*/{-1, -1}},
5629       // Stride values <= -2.
5630       TestParams{/*input_dims=*/{1, 6},
5631                  /*begin=*/{0, 5},
5632                  /*end=*/{0, 0},
5633                  /*strides=*/{1, -2},
5634                  /*begin_mask=*/get_mask({0, 0}),
5635                  /*end_mask=*/get_mask({1, 1}),
5636                  /*ellipsis_mask=*/0,
5637                  /*new_axis_mask=*/0,
5638                  /*shrink_axis_mask=*/0,
5639                  /*expected_output_dims=*/{1, 3},
5640                  /*expected_output=*/{6, 4, 2},
5641                  /*conversion_status=*/Status::OK(),
5642                  /*runtime_status=*/Status::OK(),
5643                  /*partial_input_dims=*/{-1, -1}},
5644       TestParams{/*input_dims=*/{1, 6},
5645                  /*begin=*/{0, 5},
5646                  /*end=*/{0, 0},
5647                  /*strides=*/{1, -2},
5648                  /*begin_mask=*/get_mask({0, 0}),
5649                  /*end_mask=*/get_mask({1, 0}),
5650                  /*ellipsis_mask=*/0,
5651                  /*new_axis_mask=*/0,
5652                  /*shrink_axis_mask=*/0,
5653                  /*expected_output_dims=*/{1, 3},
5654                  /*expected_output=*/{6, 4, 2},
5655                  /*conversion_status=*/Status::OK(),
5656                  /*runtime_status=*/Status::OK(),
5657                  /*partial_input_dims=*/{-1, -1}},
5658       TestParams{/*input_dims=*/{1, 6},
5659                  /*begin=*/{0, 5},
5660                  /*end=*/{0, 1},
5661                  /*strides=*/{1, -3},
5662                  /*begin_mask=*/get_mask({0, 0}),
5663                  /*end_mask=*/get_mask({1, 0}),
5664                  /*ellipsis_mask=*/0,
5665                  /*new_axis_mask=*/0,
5666                  /*shrink_axis_mask=*/0,
5667                  /*expected_output_dims=*/{1, 2},
5668                  /*expected_output=*/{6, 3},
5669                  /*conversion_status=*/Status::OK(),
5670                  /*runtime_status=*/Status::OK(),
5671                  /*partial_input_dims=*/{-1, -1}},
5672       // Ellipsis_mask causes leading dimensions to be ignored. Begin, end,
5673       // stride, and mask values of size 2 should be interpreted as applying to
5674       // the last 2 dimensions, while the ellipsis applies to the first 2 (for a
5675       // 4D input tensor).
5676       TestParams{/*input_dims=*/{1, 1, 2, 3},
5677                  /*begin=*/{0, 1},
5678                  /*end=*/{0, 2},
5679                  /*strides=*/{1, 1},
5680                  /*begin_mask=*/get_mask({0, 0}),
5681                  /*end_mask=*/get_mask({0, 0}),
5682                  /*ellipsis_mask=*/get_mask({1, 0, 0}),
5683                  /*new_axis_mask=*/0,
5684                  /*shrink_axis_mask=*/0,
5685                  /*expected_output_dims=*/{1, 1, 2, 1},
5686                  /*expected_output=*/{2, 5},
5687                  /*conversion_status=*/Status::OK(),
5688                  /*runtime_status=*/Status::OK(),
5689                  /*partial_input_dims=*/{-1, -1, -1, -1}},
5690       // Ellipsis_mask on single inner dimension.
5691       TestParams{
5692           /*input_dims=*/{1, 1, 2, 3},
5693           /*begin=*/{0, 0, 1},
5694           /*end=*/{0, 0, 2},
5695           /*strides=*/{1, 1, 1},
5696           /*begin_mask=*/get_mask({1, 0, 0, 0}),
5697           /*end_mask=*/get_mask({1, 0, 0, 0}),
5698           /*ellipsis_mask=*/get_mask({0, 1, 0, 0}),
5699           /*new_axis_mask=*/0,
5700           /*shrink_axis_mask=*/0,
5701           /*expected_output_dims=*/{1, 1, 2, 1},
5702           /*expected_output=*/{2, 5},
5703       },
5704       // Ellipsis_mask on single leading dimension.
5705       TestParams{/*input_dims=*/{1, 1, 2, 3},
5706                  /*begin=*/{0, 0, 0, 1},
5707                  /*end=*/{0, 1, 2, 2},
5708                  /*strides=*/{1, 1, 1, 1},
5709                  /*begin_mask=*/get_mask({0, 0, 0, 0}),
5710                  /*end_mask=*/get_mask({0, 0, 0, 0}),
5711                  /*ellipsis_mask=*/get_mask({1, 0, 0, 0}),
5712                  /*new_axis_mask=*/0,
5713                  /*shrink_axis_mask=*/0,
5714                  /*expected_output_dims=*/{1, 1, 2, 1},
5715                  /*expected_output=*/{2, 5},
5716                  /*conversion_status=*/Status::OK(),
5717                  /*runtime_status=*/Status::OK(),
5718                  /*partial_input_dims=*/{-1, -1, -1, -1}},
5719       // Ellipsis_mask on single inner dimension overrides that dimensions'
5720       // begin/end values.
5721       TestParams{/*input_dims=*/{1, 1, 2, 3},
5722                  /*begin=*/{0, 1, 0, 1},
5723                  /*end=*/{1, 1, 2, 2},
5724                  /*strides=*/{1, 1, 1, 1},
5725                  /*begin_mask=*/get_mask({0, 0, 0, 0}),
5726                  /*end_mask=*/get_mask({0, 0, 0, 0}),
5727                  /*ellipsis_mask=*/get_mask({0, 1, 0, 0}),
5728                  /*new_axis_mask=*/0,
5729                  /*shrink_axis_mask=*/0,
5730                  /*expected_output_dims=*/{1, 1, 2, 1},
5731                  /*expected_output=*/{2, 5},
5732                  /*conversion_status=*/Status::OK(),
5733                  /*runtime_status=*/Status::OK(),
5734                  /*partial_input_dims=*/{-1, -1, -1, -1}},
5735       // Ellipsis mask on single leading dimension should throw out extra
5736       // leading values of begin/end vectors so that only the last N-1 values of
5737       // each remain.
5738       TestParams{/*input_dims=*/{1, 1, 2, 3},
5739                  /*begin=*/{0, 0, 0, 0, 1},
5740                  /*end=*/{0, 1, 1, 2, 2},
5741                  /*strides=*/{1, 1, 1, 1, 1},
5742                  /*begin_mask=*/get_mask({0, 0, 0, 0}),
5743                  /*end_mask=*/get_mask({0, 0, 0, 0}),
5744                  /*ellipsis_mask=*/get_mask({1, 0, 0, 0}),
5745                  /*new_axis_mask=*/0,
5746                  /*shrink_axis_mask=*/0,
5747                  /*expected_output_dims=*/{1, 1, 2, 1},
5748                  /*expected_output=*/{2, 5},
5749                  /*conversion_status=*/Status::OK(),
5750                  /*runtime_status=*/Status::OK(),
5751                  /*partial_input_dims=*/{-1, -1, -1, -1}},
5752       // Shrink-axis mask set for the final dimension of final size 1 should
5753       // remove that dimension from the final shape.
5754       TestParams{/*input_dims=*/{1, 1, 2, 3},
5755                  /*begin=*/{0, 0, 0, 1},
5756                  /*end=*/{0, 0, 0, 2},
5757                  /*strides=*/{1, 1, 1, 1},
5758                  /*begin_mask=*/get_mask({1, 1, 1, 0}),
5759                  /*end_mask=*/get_mask({1, 1, 1, 0}),
5760                  /*ellipsis_mask=*/0,
5761                  /*new_axis_mask=*/0,
5762                  /*shrink_axis_mask=*/get_mask({0, 0, 0, 1}),
5763                  /*expected_output_dims=*/{1, 1, 2},
5764                  /*expected_output=*/{2, 5},
5765                  /*conversion_status=*/Status::OK(),
5766                  /*runtime_status=*/Status::OK(),
5767                  /*partial_input_dims=*/{1, 1, 2, -1}},
5768       // Shrink-axis mask set for multiple dimensions that have a final size of
5769       // 1 should remove those dimensions from the final shape.
5770       TestParams{/*input_dims=*/{1, 1, 2, 3},
5771                  /*begin=*/{0, 0, 0, 1},
5772                  /*end=*/{0, 1, 2, 2},
5773                  /*strides=*/{1, 1, 1, 1},
5774                  /*begin_mask=*/get_mask({1, 0, 0, 0}),
5775                  /*end_mask=*/get_mask({1, 0, 0, 0}),
5776                  /*ellipsis_mask=*/0,
5777                  /*new_axis_mask=*/0,
5778                  /*shrink_axis_mask=*/get_mask({0, 1, 0, 1}),
5779                  /*expected_output_dims=*/{1, 2},
5780                  /*expected_output=*/{2, 5},
5781                  /*conversion_status=*/Status::OK(),
5782                  /*runtime_status=*/Status::OK(),
5783                  /*partial_input_dims=*/{1, 1, 2, -1}},
5784       // Shrink-axis mask set for multiple sequential dimensions of final size 1
5785       // should
5786       // remove those dimensions from the final shape.
5787       TestParams{/*input_dims=*/{6, 1, 1},
5788                  /*begin=*/{0, 0, 0},
5789                  /*end=*/{0, 0, 0},
5790                  /*strides=*/{1, 1, 1},
5791                  /*begin_mask=*/get_mask({1, 1, 1}),
5792                  /*end_mask=*/get_mask({1, 1, 1}),
5793                  /*ellipsis_mask=*/0,
5794                  /*new_axis_mask=*/0,
5795                  /*shrink_axis_mask=*/get_mask({0, 1, 1}),
5796                  /*expected_output_dims=*/{6},
5797                  /*expected_output=*/{1, 2, 3, 4, 5, 6},
5798                  /*conversion_status=*/Status::OK(),
5799                  /*runtime_status=*/Status::OK(),
5800                  /*partial_input_dims=*/{-1, -1, -1}},
5801       // The new_axis_mask parameter is not supported.
5802       TestParams{/*input_dims=*/{1, 6},
5803                  /*begin=*/{0, 0, 0},
5804                  /*end=*/{0, 0, 0},
5805                  /*strides=*/{1, 1, 1},
5806                  /*begin_mask=*/
5807                  get_mask({0, 1, 1}),
5808                  /*end_mask=*/get_mask({0, 1, 1}),
5809                  /*ellipsis_mask=*/0,
5810                  /*new_axis_mask=*/get_mask({1, 0, 0}),
5811                  /*shrink_axis_mask=*/get_mask({0, 0, 0}),
5812                  /*expected_output_dims=*/{1, 1, 6},
5813                  /*expected_output=*/{1, 1, 6},
5814                  /*conversion_status=*/
5815                  errors::Unimplemented(
5816                      "new_axis_mask is not supported for StridedSlice"),
5817                  /*runtime_status=*/Status::OK(),
5818                  /*partial_input_dims=*/{1, 6}},
5819   };
5820 
5821   int i = 0;
5822   for (auto p : params) {
5823     Reset();
5824     NodeDef node_def = get_strided_slice_nodedef(
5825         tf_type_, p.begin_mask, p.end_mask, p.ellipsis_mask, p.new_axis_mask,
5826         p.shrink_axis_mask);
5827 
5828     VLOG(2) << "Preparing test case " << i++ << " with dims "
5829             << DebugString(p.input_dims);
5830 
5831     switch (trt_mode_) {
5832       case TrtTestMode::kImplicitBatch: {
5833         AddTestTensor("input", p.input_dims, ok_input);
5834         break;
5835       }
5836       case TrtTestMode::kExplicitBatch: {
5837         AddTestTensor("input", p.input_dims, ok_input);
5838         break;
5839       }
5840       case TrtTestMode::kDynamicShape: {
5841         if (p.partial_input_dims.size() > 0) {
5842           AddTestTensor("input", p.input_dims, tf_type_, ok_input,
5843                         p.partial_input_dims);
5844 
5845         } else {
5846           AddTestTensor("input", p.input_dims, tf_type_, ok_input,
5847                         p.input_dims);
5848         }
5849         break;
5850       }
5851     }
5852 
5853     VLOG(2) << "Adding weights begin: " << DebugString(p.begin)
5854             << ", end: " << DebugString(p.end)
5855             << ", strides: " << DebugString(p.strides);
5856     AddTestWeights<int32>("begin", {static_cast<int>(p.begin.size())}, p.begin);
5857     AddTestWeights<int32>("end", {static_cast<int>(p.end.size())}, p.end);
5858     AddTestWeights<int32>("strides", {static_cast<int>(p.strides.size())},
5859                           p.strides);
5860 
5861     TestOpConverter("my_strided_slice", node_def, p.expected_output_dims,
5862                     p.conversion_status, p.runtime_status,
5863                     ElementsAreArray(p.expected_output));
5864   }
5865 }
5866 
TEST_P(OpConverter_FP32_FP16_INT32_Test,ConvertSlice)5867 TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertSlice) {
5868   // Get nodedef for Slice layer.
5869   auto get_slice_nodedef = [](DataType tf_type) -> NodeDef {
5870     Scope s = Scope::NewRootScope();
5871     auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
5872     auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32);
5873     auto size = ops::Placeholder(s.WithOpName("size"), DT_INT32);
5874     auto slice = ops::Slice(s.WithOpName("my_slice"), input, begin, size);
5875     return slice.operation.node()->def();
5876   };
5877 
5878   struct TestParams {
5879     std::vector<int> input_dims;
5880     std::vector<int>
5881         partial_input_dims;  // Symbolic shape in dynamic shape mode.
5882     std::vector<int> begin;
5883     std::vector<int> size;
5884     std::vector<int> expected_output_dims;
5885     std::vector<int> expected_output;
5886     Status conversion_status;
5887     Status runtime_status;
5888   };
5889 
5890   std::vector<TestParams> params = {
5891       // Slice start points must always be >= 0.
5892       TestParams{/*input_dims=*/{1, 1, 2, 3},
5893                  /*partial_input_dims=*/{-1, -1, -1, -1},
5894                  /*begin=*/{0, 0, -1, 0},
5895                  /*size=*/{1, 1, 2, 3},
5896                  /*expected_output_dims=*/{},
5897                  /*expected_output=*/{},
5898                  /*conversion_status=*/
5899                  errors::InvalidArgument("\"begin\" in Slice "
5900                                          "is out of range")},
5901       // In implicit batch mode, slicing the batch dimension is not allowed.
5902       TestParams{/*input_dims=*/{2, 1, 1, 3},
5903                  /*partial_input_dims=*/{-1, -1, -1, -1},
5904                  /*begin=*/{0, 0, 0, 0},
5905                  /*size=*/{1, 1, 1, 3},
5906                  /*expected_output_dims=*/{1, 1, 1, 3},
5907                  /*expected_output=*/{1, 2, 3},
5908                  /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
5909                      ? errors::Unimplemented(
5910                            "TensorRT does not allow modifications to the batch "
5911                            "dimension in implicit batch mode")
5912                      : Status::OK()},
5913       // Dynamic batch size but using size[0] of -1, ok.
5914       TestParams{{1, 1, 2, 3},
5915                  /*partial_input_dims=*/{-1, -1, -1, -1},
5916                  {0, 0, 0, 0},
5917                  {-1, 1, 2, 2},
5918                  {1, 1, 2, 2},
5919                  {1, 2, 4, 5},
5920                  Status::OK()},
5921       TestParams{{1, 1, 2, 3},
5922                  /*partial_input_dims=*/{-1, -1, -1, -1},
5923                  {0, 0, 0, 0},
5924                  {-1, -1, -1, -1},
5925                  {1, 1, 2, 3},
5926                  {1, 2, 3, 4, 5, 6},
5927                  Status::OK()},
5928       TestParams{{1, 1, 2, 3},
5929                  /*partial_input_dims=*/{-1, -1, -1, -1},
5930                  {0, 0, 0, 0},
5931                  {1, 1, 2, 3},
5932                  {1, 1, 2, 3},
5933                  {1, 2, 3, 4, 5, 6}},
5934       TestParams{{1, 1, 2, 3},
5935                  /*partial_input_dims=*/{-1, -1, -1, -1},
5936                  /*begin=*/{0, 0, 0, 0},
5937                  /*size=*/{1, -1, 2, 2},
5938                  /*expected_output_dims=*/{1, 1, 2, 2},
5939                  /*expected_output=*/{1, 2, 4, 5},
5940                  Status::OK()},
5941       TestParams{/*input_dims=*/{1, 6},
5942                  /*partial_input_dims=*/{-1, -1},
5943                  /*being=*/{0, 1},
5944                  /*size=*/{1, 5},
5945                  /*expected_output_dims=*/{1, 5},
5946                  /*expected_output=*/{2, 3, 4, 5, 6}},
5947       TestParams{/*input_dims=*/{1, 6},
5948                  /*partial_input_dims=*/{-1, -1},
5949                  /*begin=*/{0, 1},
5950                  /*size=*/{-1, 3},
5951                  /*expected_output_dims=*/{1, 3},
5952                  /*expected_output=*/{2, 3, 4}, Status::OK()},
5953       // In dynamic shape mode we do not know the input shape during
5954       // conversion, therfore we cannot check out of bound access.
5955       TestParams{
5956           {1, 1, 2, 3},
5957           /*partial_input_dims=*/{-1, -1, -1, -1},
5958           /*begin=*/{0, 0, 3, 0},
5959           /*end=*/{1, 1, 2, 3},
5960           {},
5961           {},
5962           trt_mode_ == TrtTestMode::kDynamicShape
5963               ? Status::OK()
5964               : errors::InvalidArgument("\"begin\" + \"size\" for dimension "
5965                                         "2 in Slice is out of range"),
5966           errors::Internal("Internal: Failed to build TensorRT engine")},
5967       // The slice operation should expect that the "size[i]" values are not
5968       // less than -1.
5969       TestParams{/*input_dims=*/{1, 1, 2, 3},
5970                  /*partial_input_dims=*/{-1, -1, -1, -1},
5971                  /*begin=*/{0, 0, 0, 0},
5972                  /*size=*/{1, 1, 2, -2},
5973                  {},
5974                  {},
5975                  errors::InvalidArgument("\"size\" in Slice is out of range")},
5976       TestParams{
5977           /*input_dims=*/{1, 1, 2, 3},
5978           /*partial_input_dims=*/{-1, -1, -1, -1},
5979           /*begin=*/{0, 0, 0, 0},
5980           /*size=*/{1, 1, 3, 2},
5981           /*expected_output_dims=*/{},
5982           /*expected_output=*/{},
5983           /*conversion_status=*/trt_mode_ == TrtTestMode::kDynamicShape
5984               ? Status::OK()
5985               : errors::InvalidArgument("\"begin\" + \"size\" for dimension "
5986                                         "2 in Slice is out of range"),
5987           errors::Internal("Internal: Failed to build TensorRT engine")},
5988   };
5989 
5990   logger_.unsuppressAllLoggerMsgs();
5991   int i = 0;
5992   for (auto p : params) {
5993     Reset();
5994     NodeDef node_def = get_slice_nodedef(tf_type_);
5995 
5996     VLOG(2) << "Preparing test case " << i++ << " with dims "
5997             << DebugString(p.input_dims);
5998 
5999     // The input tensor always has size 6.
6000     std::vector<int> input_vals = {1, 2, 3, 4, 5, 6};
6001 
6002     switch (trt_mode_) {
6003       case TrtTestMode::kImplicitBatch: {
6004         AddTestTensor("input", p.input_dims, input_vals);
6005         break;
6006       }
6007       case TrtTestMode::kExplicitBatch: {
6008         AddTestTensor("input", p.input_dims, input_vals);
6009         break;
6010       }
6011       case TrtTestMode::kDynamicShape: {
6012         if (p.partial_input_dims.size() > 0) {
6013           AddTestTensor("input", p.input_dims, tf_type_, input_vals,
6014                         p.partial_input_dims);
6015 
6016         } else {
6017           AddTestTensor("input", p.input_dims, tf_type_, input_vals,
6018                         p.input_dims);
6019         }
6020         break;
6021       }
6022     }
6023 
6024     AddTestWeights<int32>("begin", {static_cast<int>(p.begin.size())}, p.begin);
6025     AddTestWeights<int32>("size", {static_cast<int>(p.size.size())}, p.size);
6026 
6027     const bool flag =
6028         trt_mode_ == TrtTestMode::kDynamicShape && (i == 9 || i == 11);
6029     if (flag) logger_.suppressLoggerMsgs(nvinfer1::ILogger::Severity::kERROR);
6030 
6031     TestOpConverter("my_slice", node_def, p.expected_output_dims,
6032                     p.conversion_status, p.runtime_status,
6033                     ElementsAreArray(p.expected_output));
6034     if (flag) logger_.unsuppressLoggerMsgs(nvinfer1::ILogger::Severity::kERROR);
6035   }
6036 }
6037 
TEST_P(OpConverter_FP32_Test,ConvertConv2D)6038 TEST_P(OpConverter_FP32_Test, ConvertConv2D) {
6039   // Get nodedef for Conv2D layer.
6040   DataType tf_type = tf_type_;
6041   auto get_conv2d_nodedef =
6042       [tf_type](std::vector<int> strides = {1, 1, 1, 1},
6043                 string padding = "SAME", string data_format = "NCHW",
6044                 std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
6045     Scope s = Scope::NewRootScope();
6046     auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
6047     auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type);
6048     ops::Conv2D::Attrs attrs =
6049         ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
6050     auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides,
6051                               padding, attrs);
6052     return conv2d.operation.node()->def();
6053   };
6054 
6055   {
6056     // Input is weights, should fail.
6057     Reset();
6058     NodeDef node_def = get_conv2d_nodedef();
6059     AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
6060     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
6061     RunValidationAndConversion(
6062         node_def, error::UNIMPLEMENTED,
6063         "The input \"input\" for Conv2D must be a tensor");
6064   }
6065   {
6066     // Filter is tensor, should fail.
6067     Reset();
6068     NodeDef node_def = get_conv2d_nodedef();
6069     AddTestTensor("input", {3, 1, 2, 1});
6070     AddTestTensor("weights", {3, 3, 1, 1});
6071     RunValidationAndConversion(
6072         node_def, error::UNIMPLEMENTED,
6073         "The input \"filter\" for Conv2D must be a constant");
6074   }
6075   {
6076     // Filter is not 4D, should fail.
6077     Reset();
6078     NodeDef node_def = get_conv2d_nodedef();
6079     AddTestTensor("input", {1, 1, 2, 3});
6080     AddTestWeights<float>("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
6081     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
6082                                "Conv2D expects kernel of dimension 4");
6083   }
6084   {
6085     // Dilations is not 4D, should fail.
6086     Reset();
6087     NodeDef node_def =
6088         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1});
6089     AddTestTensor("input", {1, 1, 2, 3});
6090     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
6091     RunValidationAndConversion(
6092         node_def, error::INVALID_ARGUMENT,
6093         "Convolution dilations field must specify 4 dimensions");
6094   }
6095   {
6096     // Dilation value is not 1 for channel, should fail.
6097     Reset();
6098     NodeDef node_def =
6099         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1});
6100     AddTestTensor("input", {1, 1, 2, 3});
6101     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
6102     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
6103                                "Dilation rate must be 1 for batch and channel "
6104                                "dimensions");
6105   }
6106   {
6107     // Dilation value is not 1 for channel (NHWC), should fail.
6108     Reset();
6109     NodeDef node_def =
6110         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2});
6111     AddTestTensor("input", {1, 2, 3, 1});
6112     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
6113     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
6114                                "Dilation rate must be 1 for batch and channel "
6115                                "dimensions");
6116   }
6117   {
6118     // Strides is not 4D, should fail.
6119     Reset();
6120     NodeDef node_def =
6121         get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
6122     AddTestTensor("input", {1, 1, 2, 3});
6123     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
6124     RunValidationAndConversion(
6125         node_def, error::INVALID_ARGUMENT,
6126         "Convolution strides field must specify 4 dimensions");
6127   }
6128   {
6129     // Stride value is not 1 for channel, should fail.
6130     Reset();
6131     NodeDef node_def =
6132         get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
6133     AddTestTensor("input", {1, 1, 2, 3});
6134     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
6135     RunValidationAndConversion(
6136         node_def, error::UNIMPLEMENTED,
6137         "Stride must be 1 for batch and channel dimensions");
6138   }
6139   if (trt_mode_ == TrtTestMode::kDynamicShape) {
6140     Reset();
6141     NodeDef node_def = get_conv2d_nodedef();
6142     // Channel dim unknown, should fail.
6143     nvinfer1::DataType trt_type;
6144     TF_ASSERT_OK(TfTypeToTrtType(tf_type_, &trt_type));
6145     AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, trt_type);
6146     AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1});
6147     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
6148                                "Channel dimension must be static");
6149   }
6150 
6151   struct TestParams {
6152     std::vector<int> input_dims;
6153     std::vector<float> input;
6154     std::vector<int> filter_dims;
6155     std::vector<float> filter;
6156     std::vector<int> strides;
6157     string padding;
6158     string data_format;
6159     std::vector<int> dilations;
6160     std::vector<int> expected_output_dims;
6161     std::vector<float> expected_output;
6162   };
6163 
6164   // Ok.
6165   std::vector<TestParams> ok_params = {
6166       // Basic
6167       TestParams{/*input_dims=*/{1, 1, 2, 3},
6168                  /*input=*/{0, 1, 2, 3, 3, 4},
6169                  /*filter_dims=*/{1, 2, 1, 1},
6170                  /*filter=*/{-1, 1},
6171                  /*strides=*/{1, 1, 1, 1},
6172                  /*padding=*/"VALID",
6173                  /*data_format=*/"NCHW",
6174                  /*dilations=*/{1, 1, 1, 1},
6175                  /*expected_output_dims=*/{1, 1, 2, 2},
6176                  /*expected_output=*/{1, 1, 0, 1}},
6177       // SAME padding (Asymmetric)
6178       TestParams{/*input_dims=*/{1, 1, 2, 3},
6179                  /*input=*/{0, 1, 2, 3, 3, 4},
6180                  /*filter_dims=*/{1, 2, 1, 1},
6181                  /*filter=*/{-1, 1},
6182                  /*strides=*/{1, 1, 1, 1},
6183                  /*padding=*/"SAME",
6184                  /*data_format=*/"NCHW",
6185                  /*dilations=*/{1, 1, 1, 1},
6186                  /*expected_output_dims=*/{1, 1, 2, 3},
6187                  /*expected_output=*/{1, 1, -2, 0, 1, -4}},
6188       // SAME padding (Symmetric)
6189       TestParams{/*input_dims=*/{1, 1, 2, 3},
6190                  /*input=*/{0, 1, 2, 3, 3, 4},
6191                  /*filter_dims=*/{1, 3, 1, 1},
6192                  /*filter=*/{-1, 0, 1},
6193                  /*strides=*/{1, 1, 1, 1},
6194                  /*padding=*/"SAME",
6195                  /*data_format=*/"NCHW",
6196                  /*dilations=*/{1, 1, 1, 1},
6197                  /*expected_output_dims=*/{1, 1, 2, 3},
6198                  /*expected_output=*/{1, 2, -1, 3, 1, -3}},
6199       // NHWC
6200       TestParams{/*input_dims=*/{1, 2, 3, 1},
6201                  /*input=*/{0, 1, 2, 3, 3, 4},
6202                  /*filter_dims=*/{1, 2, 1, 1},
6203                  /*filter=*/{-1, 1},
6204                  /*strides=*/{1, 1, 1, 1},
6205                  /*padding=*/"VALID",
6206                  /*data_format=*/"NHWC",
6207                  /*dilations=*/{1, 1, 1, 1},
6208                  /*expected_output_dims=*/{1, 2, 2, 1},
6209                  /*expected_output=*/{1, 1, 0, 1}},
6210       // Dilated
6211       TestParams{/*input_dims=*/{1, 1, 2, 3},
6212                  /*input=*/{0, 1, 2, 3, 3, 4},
6213                  /*filter_dims=*/{1, 2, 1, 1},
6214                  /*filter=*/{-1, 1},
6215                  /*strides=*/{1, 1, 1, 1},
6216                  /*padding=*/"VALID",
6217                  /*data_format=*/"NCHW",
6218                  /*dilations=*/{1, 1, 1, 2},
6219                  /*expected_output_dims=*/{1, 1, 2, 1},
6220                  /*expected_output=*/{2, 1}},
6221       // Strided
6222       TestParams{/*input_dims=*/{1, 1, 2, 4},
6223                  /*input=*/{0, 1, 2, 2, 3, 4, 4, 7},
6224                  /*filter_dims=*/{1, 2, 1, 1},
6225                  /*filter=*/{-1, 1},
6226                  /*strides=*/{1, 1, 1, 2},
6227                  /*padding=*/"VALID",
6228                  /*data_format=*/"NCHW",
6229                  /*dilations=*/{1, 1, 1, 1},
6230                  /*expected_output_dims=*/{1, 1, 2, 2},
6231                  /*expected_output=*/{1, 0, 1, 3}},
6232   };
6233 
6234   for (int i = 0; i < ok_params.size(); i++) {
6235     Reset();
6236     NodeDef node_def =
6237         get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding,
6238                            ok_params[i].data_format, ok_params[i].dilations);
6239     std::vector<int> partial_input_shape;
6240     if (trt_mode_ == TrtTestMode::kDynamicShape) {
6241       // The channel dim cannot have unknown size, fix that.
6242       partial_input_shape.resize(ok_params[i].input_dims.size(), -1);
6243       int channel_id = (ok_params[i].data_format == "NCHW") ? 1 : 3;
6244       partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id];
6245     }
6246 
6247     AddTestTensor("input", ok_params[i].input_dims, tf_type_,
6248                   ok_params[i].input, partial_input_shape);
6249     AddTestWeights<float>("weights", ok_params[i].filter_dims,
6250                           ok_params[i].filter);
6251 
6252     TestOpConverter("my_conv2d", node_def, ok_params[i].expected_output_dims,
6253                     Status::OK(), Status::OK(),
6254                     ElementsAreArray(ok_params[i].expected_output));
6255   }
6256 }
6257 
TEST_P(OpConverter_FP32_Test,ConvertConv2DBackpropInput)6258 TEST_P(OpConverter_FP32_Test, ConvertConv2DBackpropInput) {
6259   // Get nodedef for Conv2D layer.
6260   auto get_conv2d_backprop_input_nodedef =
6261       [](DataType tf_type, std::vector<int> strides = {1, 1, 1, 1},
6262          string padding = "SAME", string data_format = "NCHW",
6263          std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
6264     Scope s = Scope::NewRootScope();
6265     auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
6266     auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type);
6267     auto input_sizes = ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32);
6268     ops::Conv2DBackpropInput::Attrs attrs = ops::Conv2DBackpropInput::Attrs()
6269                                                 .DataFormat(data_format)
6270                                                 .Dilations(dilations);
6271     auto conv2d = ops::Conv2DBackpropInput(
6272         s.WithOpName("my_conv2d_backprop_input"), input_sizes, filter, input,
6273         strides, padding, attrs);
6274     return conv2d.operation.node()->def();
6275   };
6276 
6277   struct TestParams {
6278     std::vector<int> input_dims;
6279     std::vector<float> input;
6280     std::vector<int> filter_dims;
6281     std::vector<float> filter;
6282     std::vector<int> strides;
6283     string padding;
6284     string data_format;
6285     std::vector<int> dilations;
6286     std::vector<int> expected_output_dims;
6287     std::vector<float> expected_output;
6288     Status conversion_status;
6289     // For dynamic shape mode, we must use the partial_input_dims for
6290     // creating the test tensor if any of the input_dims are -1.
6291     std::vector<int> partial_input_dims;
6292   };
6293 
6294   // Ok.
6295   std::vector<TestParams> params = {
6296       // Transpose Strided
6297       TestParams{/*input_dims=*/{1, 1, 2, 2},
6298                  /*input=*/{0, 1, 2, 3},
6299                  /*filter_dims=*/{1, 2, 1, 1},
6300                  /*filter=*/{-1, 1},
6301                  /*strides=*/{1, 1, 1, 2},
6302                  /*padding=*/"SAME",
6303                  /*data_format=*/"NCHW",
6304                  /*dilations=*/{1, 1, 1, 1},
6305                  /*expected_output_dims=*/{1, 1, 2, 4},
6306                  /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}},
6307       // Transpose Strided NHWC
6308       TestParams{/*input_dims=*/{1, 2, 2, 1},
6309                  /*input=*/{0, 1, 2, 3},
6310                  /*filter_dims=*/{1, 2, 1, 1},
6311                  /*filter=*/{-1, 1},
6312                  /*strides=*/{1, 1, 2, 1},
6313                  /*padding=*/"SAME",
6314                  /*data_format=*/"NHWC",
6315                  /*dilations=*/{1, 1, 1, 1},
6316                  /*expected_output_dims=*/{1, 2, 4, 1},
6317                  /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}},
6318       // Transpose Strided NHWC with VALID padding
6319       TestParams{/*input_dims=*/{1, 3, 1, 1},
6320                  /*input=*/{0, 1, 2},
6321                  /*filter_dims=*/{2, 1, 1, 1},
6322                  /*filter=*/{-1, 1},
6323                  /*strides=*/{1, 2, 1, 1},
6324                  /*padding=*/"VALID",
6325                  /*data_format=*/"NHWC",
6326                  /*dilations=*/{1, 1, 1, 1},
6327                  /*expected_output_dims=*/{1, 7, 1, 1},
6328                  /*expected_output=*/{0, 0, -1, 1, -2, 2, 0}},
6329       TestParams{/*input_dims=*/{1, 1, 2, 2},
6330                  /*input=*/{0, 1, 2, 3},
6331                  /*filter_dims=*/{1, 2, 1, 1},
6332                  /*filter=*/{-1, 1},
6333                  /*strides=*/{1, 1, 1, 2},
6334                  /*padding=*/"EXPLICIT",
6335                  /*data_format=*/"NCHW",
6336                  /*dilations=*/{1, 1, 1, 1},
6337                  /*expected_output_dims=*/{1, 1, 2, 4},
6338                  /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3},
6339                  errors::Unimplemented("EXPLICIT padding type not "
6340                                        "implemented, only VALID and SAME are"
6341                                        " supported")},
6342       // Dilation + Conv2DBackpropInput, should fail.
6343       TestParams{/*input_dims=*/{1, 1, 2, 2},
6344                  /*input=*/{0, 1, 2, 3},
6345                  /*filter_dims=*/{1, 2, 1, 1},
6346                  /*filter=*/{-1, 1},
6347                  /*strides=*/{1, 1, 1, 1},
6348                  /*padding=*/"SAME",
6349                  /*data_format=*/"NCHW",
6350                  /*dilations=*/{1, 1, 1, 2},
6351                  {1, 1, 2, 2},
6352                  {},
6353                  errors::Unimplemented("Dilation with Conv2DBackpropInput "
6354                                        "(conv2d_transpose) is not supported")},
6355   };
6356   if (trt_mode_ == TrtTestMode::kDynamicShape) {
6357     params.push_back(
6358         TestParams{/*input_dims=*/{1, 1, 2, 2},
6359                    /*input=*/{0, 1, 2, 3},
6360                    /*filter_dims=*/{1, 2, 1, 1},
6361                    /*filter=*/{-1, 1},
6362                    /*strides=*/{1, 1, 1, 2},
6363                    /*padding=*/"SAME",
6364                    /*data_format=*/"NCHW",
6365                    /*dilations=*/{1, 1, 1, 1},
6366                    /*expected_output_dims=*/{1, 1, 2, 4},
6367                    /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3},
6368                    errors::InvalidArgument("Channel dimension must be static"),
6369                    /*partial input dims=*/{1, -1, 2, 2}});
6370     // Test dynamic  batch dimension.
6371     params.push_back(
6372         TestParams{/*input_dims=*/{2, 1, 2, 2},
6373                    /*input=*/
6374                    // clang-format off
6375                       {0, 1, 2, 3,
6376                        3, 2, 1, 0},
6377                    // clang-format on
6378                    /*filter_dims=*/{1, 2, 1, 1},
6379                    /*filter=*/{-1, 1},
6380                    /*strides=*/{1, 1, 1, 2},
6381                    /*padding=*/"SAME",
6382                    /*data_format=*/"NCHW",
6383                    /*dilations=*/{1, 1, 1, 1},
6384                    /*expected_output_dims=*/{2, 1, 2, 4},
6385                    /*expected_output=*/
6386                    // clang-format off
6387                    { 0, 0, -1, 1, -2, 2, -3, 3,
6388                     -3, 3, -2, 2, -1, 1, 0, 0},
6389                    // clang-format on
6390                    /*conversion_status=*/Status::OK(),
6391                    /*partial input dims=*/{-1, 1, 2, 2}});
6392 
6393     // Test dynamic height and width.
6394     params.push_back(TestParams{
6395         /*input_dims=*/{1, 1, 2, 2},
6396         /*input=*/{0, 1, 2, 3},
6397         /*filter_dims=*/{1, 2, 1, 1},
6398         /*filter=*/{-1, 1},
6399         /*strides=*/{1, 1, 1, 2},
6400         /*padding=*/"SAME",
6401         /*data_format=*/"NCHW",
6402         /*dilations=*/{1, 1, 1, 1},
6403         /*expected_output_dims=*/{1, 1, 2, 4},
6404         /*expected_output=*/
6405         {0, 0, -1, 1, -2, 2, -3, 3},
6406         /*conversion_status=*/
6407         errors::Unimplemented(
6408             "Conv2dBackpropInput does not support input with unknown spatial "
6409             "shape"),
6410         /*partial input dims=*/{1, 1, -1, -1}});
6411   }
6412   for (auto p : params) {
6413     for (int input_sizes_length : {2, 4}) {
6414       Reset();
6415       NodeDef node_def = get_conv2d_backprop_input_nodedef(
6416           tf_type_, p.strides, p.padding, p.data_format, p.dilations);
6417 
6418       switch (trt_mode_) {
6419         case TrtTestMode::kImplicitBatch: {
6420           AddTestTensor("input", p.input_dims, p.input);
6421           break;
6422         }
6423         case TrtTestMode::kExplicitBatch: {
6424           AddTestTensor("input", p.input_dims, p.input);
6425           break;
6426         }
6427         case TrtTestMode::kDynamicShape: {
6428           AddTestTensor("input", p.input_dims, tf_type_, p.input,
6429                         p.partial_input_dims.size() > 0 ? p.partial_input_dims
6430                                                         : p.input_dims);
6431           break;
6432         }
6433         default: {
6434           ASSERT_TRUE(false) << "unknown test mode";
6435         }
6436       }
6437 
6438       AddTestWeights<float>("weights", p.filter_dims, p.filter, tf_type_);
6439 
6440       if (input_sizes_length == 4) {
6441         AddTestWeights<int>("input_sizes", {4}, p.expected_output_dims);
6442       } else {
6443         std::vector<int> tf_input_sizes(2);
6444         // Remove the channel and batch dimensions.
6445         if (p.data_format == "NHWC") {
6446           std::copy(p.expected_output_dims.begin() + 1,
6447                     p.expected_output_dims.end() - 1, tf_input_sizes.begin());
6448         } else {
6449           std::copy(p.expected_output_dims.begin() + 2,
6450                     p.expected_output_dims.end(), tf_input_sizes.begin());
6451         }
6452         QCHECK_EQ(2, tf_input_sizes.size());
6453         AddTestWeights<int>("input_sizes", {2}, tf_input_sizes);
6454       }
6455 
6456       TestOpConverter("my_conv2d_backprop_input", node_def,
6457                       p.expected_output_dims, p.conversion_status, Status::OK(),
6458                       ElementsAreArray(p.expected_output));
6459     }
6460   }
6461 }
6462 
6463 // Get the NodeDef for Pack.
GetConv3DNodeDef(std::vector<int> strides={1, 1, 1, 1, 1},string padding="SAME",string data_format="NCDHW",std::vector<int> dilations={1, 1, 1, 1, 1},bool is_conv3d_backprop_input=false)6464 NodeDef GetConv3DNodeDef(std::vector<int> strides = {1, 1, 1, 1, 1},
6465                          string padding = "SAME", string data_format = "NCDHW",
6466                          std::vector<int> dilations = {1, 1, 1, 1, 1},
6467                          bool is_conv3d_backprop_input = false) {
6468   Scope s = Scope::NewRootScope();
6469   auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
6470   auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT);
6471 
6472   if (is_conv3d_backprop_input) {
6473     auto input_sizes = ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32);
6474     ops::Conv3DBackpropInputV2::Attrs attrs =
6475         ops::Conv3DBackpropInputV2::Attrs()
6476             .DataFormat(data_format)
6477             .Dilations(dilations);
6478     auto conv3d =
6479         ops::Conv3DBackpropInputV2(s.WithOpName("my_conv3d"), input_sizes,
6480                                    filter, input, strides, padding, attrs);
6481     return conv3d.operation.node()->def();
6482   } else {
6483     ops::Conv3D::Attrs attrs =
6484         ops::Conv3D::Attrs().DataFormat(data_format).Dilations(dilations);
6485     auto conv3d = ops::Conv3D(s.WithOpName("my_conv3d"), input, filter, strides,
6486                               padding, attrs);
6487     return conv3d.operation.node()->def();
6488   }
6489 }
6490 
6491 struct Conv3DTestParams {
6492   std::vector<int> input_dims;
6493   std::vector<float> input;
6494   std::vector<int> filter_dims;
6495   std::vector<float> filter;
6496   std::vector<int> strides;
6497   string padding;
6498   string data_format;
6499   std::vector<int> dilations;
6500   bool is_conv3d_backprop;
6501   std::vector<int> expected_output_dims;
6502   std::vector<float> expected_output;
6503   bool allow_dynamic_channel_dim;
6504   Status validation_status;
6505 };
6506 
TestConv3D(ParameterizedOpConverterTestBase * test,Conv3DTestParams & p)6507 void TestConv3D(ParameterizedOpConverterTestBase* test, Conv3DTestParams& p) {
6508   test->Reset();
6509   NodeDef node_def = GetConv3DNodeDef(p.strides, p.padding, p.data_format,
6510                                       p.dilations, p.is_conv3d_backprop);
6511 
6512   std::vector<int> partial_input_shape;
6513   if (!p.allow_dynamic_channel_dim &&
6514       test->get_trt_mode() == TrtTestMode::kDynamicShape) {
6515     // The channel dim cannot have unknown size, fix that.
6516     partial_input_shape.resize(p.input_dims.size(), -1);
6517     int channel_id = (p.data_format == "NCDHW") ? 1 : 4;
6518     partial_input_shape[channel_id] = p.input_dims[channel_id];
6519   }
6520 
6521   test->AddTestTensor("input", p.input_dims, test->get_tf_type(), p.input,
6522                       partial_input_shape);
6523   test->AddTestWeights<float>("weights", p.filter_dims, p.filter);
6524 
6525   if (p.is_conv3d_backprop) {
6526     test->AddTestWeights<float>("input_sizes",
6527                                 {static_cast<int>(p.expected_output.size())},
6528                                 p.expected_output);
6529   }
6530 
6531   test->TestOpConverter("my_conv3d", node_def, p.expected_output_dims,
6532                         /*expected_conversion_status=*/p.validation_status,
6533                         /*expected_runtime_status=*/Status::OK(),
6534                         /*matcher=*/ElementsAreArray(p.expected_output),
6535                         /*out_tf_types=*/{test->get_tf_type()});
6536 }
6537 
TEST_P(OpConverter_FP32_FP16_Test,ConvertConv3D)6538 TEST_P(OpConverter_FP32_FP16_Test, ConvertConv3D) {
6539   {
6540     // Input is weights, should fail.
6541     Reset();
6542     NodeDef node_def = GetConv3DNodeDef();
6543 
6544     AddTestWeights<float>("input", {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
6545     AddTestWeights<float>("weights", {1, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
6546     RunValidationAndConversion(
6547         node_def, error::UNIMPLEMENTED,
6548         "The input \"input\" for Conv3D must be a tensor");
6549   }
6550   {
6551     // Filter is tensor, should fail.
6552     Reset();
6553     NodeDef node_def = GetConv3DNodeDef();
6554     AddTestTensor("input", {1, 1, 2, 3}, tf_type_, CreateVectorIota<float>(6));
6555     AddTestTensor("weights", {1, 3, 3, 1}, tf_type_,
6556                   CreateVectorIota<float>(9));
6557     RunValidationAndConversion(
6558         node_def, error::UNIMPLEMENTED,
6559         "The input \"filter\" for Conv3D must be a constant");
6560   }
6561   {
6562     // Filter is not 5D, should fail.
6563     Reset();
6564     NodeDef node_def = GetConv3DNodeDef();
6565     AddTestTensor("input", {1, 1, 2, 3}, tf_type_, CreateVectorIota<float>(6));
6566     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
6567     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
6568                                "Conv3D expects kernel of dimension 5");
6569   }
6570   {
6571     // Dilations is not 5D, should fail.
6572     Reset();
6573     NodeDef node_def =
6574         GetConv3DNodeDef({1, 1, 1, 1, 1}, "SAME", "NCDHW", {1, 1, 1, 1});
6575     AddTestTensor("input", {1, 1, 2, 3}, tf_type_, CreateVectorIota<float>(6));
6576     AddTestWeights<float>(
6577         "weights", {3, 3, 1, 1, 1},
6578         {1, 2, 3, 4, 5, 6, 7, 8, 9});  // Dimensions, then values
6579     RunValidationAndConversion(
6580         node_def, error::INVALID_ARGUMENT,
6581         "Convolution dilations field must specify 5 dimensions");
6582   }
6583   {
6584     // Dilation value is not 1 for channel, should fail.
6585     Reset();
6586     NodeDef node_def =
6587         GetConv3DNodeDef({1, 1, 1, 1, 1}, "SAME", "NCDHW", {1, 2, 1, 1, 1});
6588     AddTestTensor("input", {1, 1, 2, 3}, tf_type_, CreateVectorIota<float>(6));
6589     AddTestWeights<float>("weights", {3, 3, 1, 1, 1},
6590                           {1, 2, 3, 4, 5, 6, 7, 8, 9});
6591     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
6592                                "Dilation rate must be 1 for batch and channel "
6593                                "dimensions");
6594   }
6595   {
6596     // Dilation value is not 1 for channel (NDHWC), should fail.
6597     Reset();
6598     NodeDef node_def =
6599         GetConv3DNodeDef({1, 1, 1, 1, 1}, "SAME", "NDHWC", {1, 1, 1, 1, 2});
6600     AddTestTensor("input", {1, 2, 3, 1}, tf_type_, CreateVectorIota<float>(6));
6601     AddTestWeights<float>("weights", {3, 3, 1, 1, 1},
6602                           {1, 2, 3, 4, 5, 6, 7, 8, 9});
6603     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
6604                                "Dilation rate must be 1 for batch and channel "
6605                                "dimensions");
6606   }
6607   {
6608     // Dilation + Conv3DBackpropInputV2, should fail.
6609     Reset();
6610     NodeDef node_def = GetConv3DNodeDef({1, 1, 1, 1, 1}, "SAME", "NDHWC",
6611                                         {1, 1, 2, 1, 1}, true);
6612     AddTestTensor("input", {1, 2, 3, 1}, tf_type_, CreateVectorIota<float>(6));
6613     AddTestWeights<float>("weights", {3, 3, 1, 1, 1},
6614                           {1, 2, 3, 4, 5, 6, 7, 8, 9});
6615     AddTestWeights<int>("input_sizes", {4}, {1, 2, 3, 1});
6616     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
6617                                "Dilation with Conv3DBackpropInputV2 "
6618                                "(conv3d_transpose) is not supported");
6619   }
6620   {
6621     // Asymmetric+ Conv3DBackpropInputV2, should fail.
6622     Reset();
6623     NodeDef node_def = GetConv3DNodeDef({1, 1, 1, 1, 1}, "SAME", "NDHWC",
6624                                         {1, 1, 1, 1, 1}, true);
6625     AddTestTensor("input", {1, 2, 2, 2}, tf_type_, CreateVectorIota<float>(8));
6626     AddTestWeights<float>("weights", {1, 1, 2, 1, 1}, {1, 1});
6627     AddTestWeights<int>("input_sizes", {8}, {1, 2, 3, 4, 5, 6, 7, 8});
6628     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
6629                                "Asymmetric padding with Conv3DBackpropInputV2 "
6630                                "(conv3d_transpose) is not supported");
6631   }
6632   {
6633     // Strides is not 5D, should fail.
6634     Reset();
6635     NodeDef node_def =
6636         GetConv3DNodeDef({1, 1, 1, 1, 1, 1}, "SAME", "NCDHW", {1, 1, 1, 1, 1});
6637     AddTestTensor("input", {1, 2, 2, 2}, tf_type_, CreateVectorIota<float>(8));
6638     AddTestWeights<float>("weights", {1, 1, 2, 1, 1}, {1, 1});
6639     RunValidationAndConversion(
6640         node_def, error::INVALID_ARGUMENT,
6641         "Convolution strides field must specify 5 dimensions");
6642   }
6643   {
6644     // Stride value is not 1 for channel, should fail.
6645     Reset();
6646     NodeDef node_def =
6647         GetConv3DNodeDef({1, 2, 1, 1, 1}, "SAME", "NCDHW", {1, 1, 1, 1, 1});
6648     AddTestTensor("input", {1, 1, 2, 3}, tf_type_, CreateVectorIota<float>(6));
6649     AddTestWeights<float>("weights", {3, 3, 1, 1, 1},
6650                           {1, 2, 3, 4, 5, 6, 7, 8, 9});
6651     RunValidationAndConversion(
6652         node_def, error::UNIMPLEMENTED,
6653         "Stride must be 1 for batch and channel dimensions");
6654   }
6655 
6656   // Start here
6657   std::vector<Conv3DTestParams> ok_params = {
6658       // Basic - just 1x1 conv - input = output
6659       {/*input_dims=*/{1, 1, 3, 3, 3},  // CDHW
6660        /*input=*/{1, 2,  15,  3, 6,  -3, 22, 1, 88, 56, 36, 1,  1, 105,
6661                   1, 16, -28, 1, 42, 9,  3,  1, 7,  1,  11, 61, 5},
6662        /*filter_dims=*/{1, 1, 1, 1, 1},  // DRSCK
6663        /*filter=*/{1},
6664        /*strides=*/{1, 1, 1, 1, 1},
6665        /*padding=*/"VALID",
6666        /*data_format=*/"NCDHW",
6667        /*dilations=*/{1, 1, 1, 1, 1},
6668        /*is_conv3d_backprop=*/false,
6669        /*expected_output_dims=*/{1, 1, 3, 3, 3},
6670        /*expected_output=*/{1,  2,  15, 3, 6,   -3, 22, 1,   88,
6671                             56, 36, 1,  1, 105, 1,  16, -28, 1,
6672                             42, 9,  3,  1, 7,   1,  11, 61,  5},
6673        /*allow_dynamic_channel_dim=*/false,
6674        /*validation_status=*/Status::OK()},
6675       // Basic - 2x1 filter
6676       {/*input_dims=*/{1, 1, 3, 3, 3},  // CDHW
6677        /*input=*/{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
6678                   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6},
6679        /*filter_dims=*/{2, 1, 1, 1, 1},  // DRSCK
6680        /*filter=*/{1, 1},
6681        /*strides=*/{1, 1, 1, 1, 1},
6682        /*padding=*/"VALID",
6683        /*data_format=*/"NCDHW",
6684        /*dilations=*/{1, 1, 1, 1, 1},
6685        /*is_conv3d_backprop=*/false,
6686        /*expected_output_dims=*/{1, 1, 2, 3, 3},
6687        /*expected_output=*/
6688        {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 7},
6689        /*allow_dynamic_channel_dim=*/false,
6690        /*validation_status=*/Status::OK()},
6691       // SAME padding (Asymmetric)
6692       {/*input_dims=*/{1, 1, 2, 3, 2},  // CDHW
6693        /*input=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
6694        /*filter_dims=*/{2, 1, 1, 1, 1},  // DRSCK
6695        /*filter=*/{-1, 1},
6696        /*strides=*/{1, 1, 1, 1, 1},
6697        /*padding=*/"SAME",
6698        /*data_format=*/"NCDHW",
6699        /*dilations=*/{1, 1, 1, 1, 1},
6700        /*is_conv3d_backprop=*/false,
6701        /*expected_output_dims=*/{1, 1, 2, 3, 2},
6702        // Diff in first 2 depths is const 6.
6703        /*expected_output=*/{6, 6, 6, 6, 6, 6, -6, -7, -8, -9, -10, -11},
6704        /*allow_dynamic_channel_dim=*/false,
6705        /*validation_status=*/Status::OK()},
6706       // SAME padding (Symmetric)
6707       {/*input_dims=*/{1, 1, 2, 3, 2},  // CDHW
6708        /*input=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
6709        /*filter_dims=*/{3, 1, 1, 1, 1},  // DRSCK
6710        /*filter=*/{-1, 0, 1},
6711        /*strides=*/{1, 1, 1, 1, 1},
6712        /*padding=*/"SAME",
6713        /*data_format=*/"NCDHW",
6714        /*dilations=*/{1, 1, 1, 1, 1},
6715        /*is_conv3d_backprop=*/false,
6716        /*expected_output_dims=*/{1, 1, 2, 3, 2},
6717        // Swaps front two depths, negates
6718        /*expected_output=*/{6, 7, 8, 9, 10, 11, 0, -1, -2, -3, -4, -5},
6719        /*allow_dynamic_channel_dim=*/false,
6720        /*validation_status=*/Status::OK()
6721 
6722       },
6723       // NDHWC (multi-channel)
6724       {/*input_dims=*/{1, 2, 3, 2, 2},  // DHWC
6725        /*input=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
6726                   0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
6727        /*filter_dims=*/{2, 1, 1, 2, 1},  // DRSCK
6728        /*filter=*/{-1, 1, 1, -1},
6729        /*strides=*/{1, 1, 1, 1, 1},
6730        /*padding=*/"VALID",
6731        /*data_format=*/"NDHWC",
6732        /*dilations=*/{1, 1, 1, 1, 1},
6733        /*is_conv3d_backprop=*/false,
6734        /*expected_output_dims=*/{1, 1, 3, 2, 1},
6735        /*expected_output=*/{0, 0, 0, 0, 0, 0},  // Filters oppose each-other
6736        /*allow_dynamic_channel_dim=*/false,
6737        /*validation_status=*/Status::OK()},
6738       // Dilated
6739       {/*input_dims=*/{1, 1, 3, 3, 3},  // CDHW
6740        /*input=*/{1,   1,   1,   1,   1, 1, 1, 1, 1, -10, -10, -10, -10, -10,
6741                   -10, -10, -10, -10, 7, 7, 7, 7, 7, 7,   7,   7,   7},
6742        /*filter_dims=*/{2, 1, 1, 1, 1},  // DRSCK
6743        /*filter=*/{1, 1},
6744        /*strides=*/{1, 1, 1, 1, 1},
6745        /*padding=*/"VALID",
6746        /*data_format=*/"NCDHW",
6747        /*dilations=*/{1, 1, 2, 1, 1},
6748        /*is_conv3d_backprop=*/false,
6749        /*expected_output_dims=*/{1, 1, 1, 3, 3},
6750        // Only front depth is valid, skips neg values
6751        /*expected_output=*/{8, 8, 8, 8, 8, 8, 8, 8, 8},
6752        /*allow_dynamic_channel_dim=*/false,
6753        /*validation_status=*/Status::OK()},
6754       // Strided
6755       {/*input_dims=*/{1, 1, 3, 3, 3},
6756        /*input=*/{1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0,
6757                   0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8},
6758        /*filter_dims=*/{1, 1, 1, 1, 1},
6759        /*filter=*/{1},
6760        /*strides=*/{1, 1, 2, 2, 2},
6761        /*padding=*/"VALID",
6762        /*data_format=*/"NCDHW",
6763        /*dilations=*/{1, 1, 1, 1, 1},
6764        /*is_conv3d_backprop=*/false,
6765        /*expected_output_dims=*/{1, 1, 2, 2, 2},
6766        // Should only pick up the corners
6767        /*expected_output=*/{1, 2, 3, 4, 5, 6, 7, 8},
6768        /*allow_dynamic_channel_dim=*/false,
6769        /*validation_status=*/Status::OK()},
6770       // Transpose Strided
6771       {/*input_dims=*/{1, 1, 2, 2, 2},  // CDHW
6772        /*input=*/{1, 2, 3, 4, 5, 6, 7, 8},
6773        /*filter_dims=*/{1, 1, 1, 1, 1},
6774        /*filter=*/{1},
6775        /*strides=*/{1, 1, 2, 2, 2},
6776        /*padding=*/"VALID",
6777        /*data_format=*/"NCDHW",
6778        /*dilations=*/{1, 1, 1, 1, 1},
6779        /*is_conv3d_backprop=*/true,
6780        /*expected_output_dims=*/{1, 1, 3, 3, 3},
6781        /*expected_output=*/{1, 0, 2, 0, 0, 0, 3, 0, 4,   // Cube expands and
6782                             0, 0, 0, 0, 0, 0, 0, 0, 0,   // fills center
6783                             5, 0, 6, 0, 0, 0, 7, 0, 8},  // with zeroes
6784        /*allow_dynamic_channel_dim=*/false,
6785        /*validation_status=*/Status::OK()},
6786   };
6787 
6788   if (trt_mode_ == TrtTestMode::kDynamicShape) {
6789     ok_params.reserve(ok_params.size() + 2);
6790     const std::vector<float> common_input = CreateVectorIota<float>(3 * 3 * 3);
6791     // NCDHW - Dynamic Channel - Should fail in kDynamicShape
6792     ok_params.push_back(Conv3DTestParams{
6793         /*input_dims=*/{1, 1, 3, 3, 3},
6794         /*input=*/common_input,
6795         /*filter_dims=*/{1, 1, 1, 1, 1},
6796         /*filter=*/{1},
6797         /*strides=*/{1, 1, 2, 2, 2},
6798         /*padding=*/"VALID",
6799         /*data_format=*/"NCDHW",
6800         /*dilations=*/{1, 1, 1, 1, 1},
6801         /*is_conv3d_backprop=*/false,
6802         /*expected_output_dims=*/{},  // ignore, will fail anyway
6803         /*expected_output=*/{},       // ignore, will fail anyway
6804         /*allow_dynamic_channel_dim=*/true,
6805         /*validation_status=*/
6806         Status{error::INVALID_ARGUMENT, "Channel dimension must be static"}});
6807     // NDHWC - Dynamic Channel - Should fail in kDynamicShape
6808     ok_params.push_back(Conv3DTestParams{
6809         /*input_dims=*/{1, 3, 3, 3, 1},
6810         /*input=*/common_input,
6811         /*filter_dims=*/{1, 1, 1, 1, 1},
6812         /*filter=*/{1},
6813         /*strides=*/{1, 2, 2, 2, 1},
6814         /*padding=*/"VALID",
6815         /*data_format=*/"NDHWC",
6816         /*dilations=*/{1, 1, 1, 1, 1},
6817         /*is_conv3d_backprop=*/false,
6818         /*expected_output_dims=*/{},  // ignore, will fail anyway
6819         /*expected_output=*/{},       // ignore, will fail anyway
6820         /*allow_dynamic_channel_dim=*/true,
6821         /*validation_status=*/
6822         Status{error::INVALID_ARGUMENT, "Channel dimension must be static"}});
6823   }
6824 
6825   for (auto p : ok_params) {
6826     TestConv3D(this, p);
6827   }
6828 }
6829 
6830 template <typename T>
CreatePoolOp(DataType tf_type,std::vector<int> ksize,std::vector<int> strides,string padding,string data_format)6831 NodeDef CreatePoolOp(DataType tf_type, std::vector<int> ksize,
6832                      std::vector<int> strides, string padding,
6833                      string data_format) {
6834   Scope s = Scope::NewRootScope();
6835   auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
6836   typename T::Attrs attrs;
6837   attrs.data_format_ = data_format;
6838   return T(s.WithOpName("my_pool"), input, ksize, strides, padding, attrs)
6839       .operation.node()
6840       ->def();
6841 }
TEST_P(OpConverter_FP32_Test,ConvertPool)6842 TEST_P(OpConverter_FP32_Test, ConvertPool) {
6843   // Get nodedef for MaxPool and AvgPool layers (2D or 3D).
6844   auto get_pool_nodedef =
6845       [](DataType tf_type, int nDim, std::vector<int> ksize = {},
6846          std::vector<int> strides = {}, string padding = "SAME",
6847          string data_format = "", const bool is_max_pooling = true) -> NodeDef {
6848     if (ksize.empty()) {
6849       ksize = nDim == 2 ? std::vector<int>{1, 1, 1, 1}
6850                         : std::vector<int>{1, 1, 1, 1, 1};
6851     }
6852     if (strides.empty()) {
6853       strides = nDim == 2 ? std::vector<int>{1, 1, 1, 1}
6854                           : std::vector<int>{1, 1, 1, 1, 1};
6855     }
6856     if (data_format == "") {
6857       data_format = nDim == 2 ? "NCHW" : "NCDHW";
6858     }
6859     if (is_max_pooling) {
6860       if (nDim == 3) {
6861         return CreatePoolOp<ops::MaxPool3D>(tf_type, ksize, strides, padding,
6862                                             data_format);
6863       } else {
6864         return CreatePoolOp<ops::MaxPool>(tf_type, ksize, strides, padding,
6865                                           data_format);
6866       }
6867     } else {
6868       if (nDim == 3) {
6869         return CreatePoolOp<ops::AvgPool3D>(tf_type, ksize, strides, padding,
6870                                             data_format);
6871       } else {
6872         return CreatePoolOp<ops::AvgPool>(tf_type, ksize, strides, padding,
6873                                           data_format);
6874       }
6875     }
6876   };
6877 
6878   std::vector<int> test_nDims{2, 3};
6879 
6880   for (int nDim : test_nDims) {
6881     // Input is weights, should fail.
6882     Reset();
6883     NodeDef node_def = get_pool_nodedef(tf_type_, nDim);
6884 
6885     AddTestWeights<float>("input", {1, 1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
6886     RunValidationAndConversion(
6887         node_def, error::UNIMPLEMENTED,
6888         StrCat("The input \"input\" for ", node_def.op(), " must be a tensor"));
6889   }
6890 
6891   struct TestParams {
6892     std::vector<int> input_dims;
6893     std::vector<float> input;
6894     std::vector<int> ksize;
6895     std::vector<int> strides;
6896     string padding;
6897     string data_format;
6898     std::vector<int> expected_output_dims;
6899     // The expected outputs for the following operations: MaxPool2D, AvgPool2D,
6900     // MaxPool3D, AvgPool3D
6901     std::vector<std::vector<float>> expected_outputs;
6902   };
6903 
6904   // We use common_input as the input to test both 2D and 3D pooling operations,
6905   // to simplify TestParams. For 2D operations, only the first 1/3 of the values
6906   // are used.
6907   const std::vector<float> common_input{-4, 2,  15, 3, 6,   -3, 22, 1,   88,
6908                                         56, 36, 1,  1, 105, 1,  16, -28, 1,
6909                                         42, 9,  3,  1, 7,   1,  11, 61,  5};
6910   // The output of 2D ops for the case where the op is equivalent to the
6911   // identity op.
6912   const std::vector<float> common_2d_output{-4, 2, 15, 3, 6, -3, 22, 1, 88};
6913   std::vector<TestParams> ok_params = {
6914       // Basic - just 1x1 max pooling - input = output
6915       TestParams{
6916           /*input_dims=*/{1, 1, 3, 3, 3},
6917           /*input=*/common_input,
6918           /*ksize=*/{1, 1, 1, 1, 1},
6919           /*strides=*/{1, 1, 1, 1, 1},
6920           /*padding=*/"VALID",
6921           /*data_format=*/"NCDHW",
6922           /*expected_output_dims=*/{1, 1, 3, 3, 3},
6923           /*expected_outputs=*/
6924           {common_2d_output, common_2d_output, common_input, common_input}},
6925       // Basic - just 1x1 max pooling - input = output, SAME padding
6926       TestParams{
6927           /*input_dims=*/{1, 1, 3, 3, 3},
6928           /*input=*/common_input,
6929           /*ksize=*/{1, 1, 1, 1, 1},
6930           /*strides=*/{1, 1, 1, 1, 1},
6931           /*padding=*/"SAME",
6932           /*data_format=*/"NCDHW",
6933           /*expected_output_dims=*/{1, 1, 3, 3, 3},
6934           /*expected_outputs=*/
6935           {common_2d_output, common_2d_output, common_input, common_input}},
6936       // 3x3 pooling NCDHW
6937       TestParams{/*input_dims=*/{1, 1, 3, 3, 3},
6938                  /*input=*/common_input,
6939                  /*ksize=*/{1, 1, 3, 3, 3},
6940                  /*strides=*/{1, 1, 1, 1, 1},
6941                  /*padding=*/"VALID",
6942                  /*data_format=*/"NCDHW",
6943                  /*expected_output_dims=*/{1, 1, 1, 1, 1},
6944                  /*expected_outputs=*/{{88}, {14.444445}, {105}, {17}}},
6945       // 3x3 pooling, NDHWC
6946       TestParams{/*input_dims=*/{1, 3, 3, 3, 1},
6947                  /*input=*/common_input,
6948                  /*ksize=*/{1, 3, 3, 3, 1},
6949                  /*strides=*/{1, 1, 1, 1, 1},
6950                  /*padding=*/"VALID",
6951                  /*data_format=*/"NDHWC",
6952                  /*expected_output_dims=*/{1, 1, 1, 1, 1},
6953                  /*expected_outputs=*/{{88}, {14.444445}, {105}, {17}}},
6954       // Strided
6955       TestParams{/*input_dims=*/{1, 1, 3, 3, 3},
6956                  /*input=*/{1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0,
6957                             0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8},
6958                  /*ksize=*/{1, 1, 1, 1, 1},
6959                  /*strides=*/{1, 1, 2, 2, 2},
6960                  /*padding=*/"VALID",
6961                  /*data_format=*/"NCDHW",
6962                  /*expected_output_dims=*/{1, 1, 2, 2, 2},
6963                  /*expected_outputs=*/
6964                  {{1, 2, 3, 4},  // Should only pick up the corners
6965                   {1, 2, 3, 4},
6966                   {1, 2, 3, 4, 5, 6, 7, 8},
6967                   {1, 2, 3, 4, 5, 6, 7, 8}}},
6968   };
6969 
6970   for (auto p : ok_params) {
6971     int test_counter = 0;
6972     for (int nDim : test_nDims) {
6973       auto input = p.input;
6974       auto input_dims = p.input_dims;
6975       auto ksize = p.ksize;
6976       auto strides = p.strides;
6977       auto expected_output_dims = p.expected_output_dims;
6978       std::string data_format = p.data_format;
6979       if (nDim == 2) {
6980         input.resize(9);
6981         data_format = p.data_format == "NDHWC" ? "NHWC" : "NCHW";
6982         // Remove one of the spatial dimensions
6983         input_dims.erase(input_dims.begin() + 2);
6984         ksize.erase(ksize.begin() + 2);
6985         strides.erase(strides.begin() + 2);
6986         expected_output_dims.erase(expected_output_dims.begin() + 2);
6987       }
6988       for (bool is_max_pooling : {true, false}) {
6989         Reset();
6990         NodeDef node_def =
6991             get_pool_nodedef(tf_type_, nDim, ksize, strides, p.padding,
6992                              data_format, is_max_pooling);
6993         AddTestTensor("input", input_dims, input);
6994         TestOpConverter("my_pool", node_def, expected_output_dims, Status::OK(),
6995                         Status::OK(),
6996                         ElementsAreArray(p.expected_outputs.at(test_counter)));
6997         test_counter++;
6998       }
6999     }
7000   }
7001 }
7002 
TEST_P(OpConverter_FP32_FP16_Test, ConvertTopK) {
  // TopKV2 node: values come from placeholder "input", `k` from placeholder
  // "weights".
  Scope root = Scope::NewRootScope();
  auto input_ph = ops::Placeholder(root.WithOpName("input"), tf_type_);
  auto k_ph = ops::Placeholder(root.WithOpName("weights"), DT_INT32);
  auto top_k = ops::TopK(root.WithOpName("my_topk"), input_ph, k_ph);
  const NodeDef& node_def = top_k.operation.node()->def();

  // `k` supplied as a tensor: conversion must be rejected.
  {
    Reset();
    AddTestTensor("input", {1, 1, 2, 3});
    AddTestTensor("weights", {1}, DT_INT32, {});
    RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                               "The input \"k\" for TopKV2 must be a constant");
  }

  // `k` = 2 supplied as weights. Output 0 holds the two largest values of
  // each row, output 1 their indices within the row.
  {
    Reset();
    AddTestTensor("input", {1, 1, 2, 5}, {-9, 3, 5, 1, 6, -5, 7, 1, 0, -1});
    AddTestWeights<int32>("weights", {1}, {2});
    const std::vector<int> per_output_dims{1, 1, 2, 2};
    std::vector<std::vector<int>> expected_output_dims(2, per_output_dims);
    TestOpConverterMultiOut("my_topk", node_def, expected_output_dims,
                            Status::OK(), Status::OK(),
                            {ElementsAre(6, 5, 7, 1), ElementsAre(4, 2, 1, 2)},
                            {tf_type_, DT_INT32});
  }
}
7031 
7032 struct DataFormatVecPermuteTestParams {
7033   string dst_format;
7034   string src_format;
7035   std::vector<int> x_shape;
7036   std::vector<int> x;
7037   bool x_is_tensor;
7038   std::vector<int> expected_output;
7039   Status conversion_status;
7040 };
7041 
GetDataFormatVecPermuteNodeDef(string dst_format,string src_format,std::vector<int> & x_shape)7042 NodeDef GetDataFormatVecPermuteNodeDef(string dst_format, string src_format,
7043                                        std::vector<int>& x_shape) {
7044   Scope s = Scope::NewRootScope();
7045   PartialTensorShape tensor_shape;
7046   auto x = ops::Placeholder(s.WithOpName("x"), DT_INT32);
7047   const auto attrs = ops::DataFormatVecPermute::Attrs()
7048                          .DstFormat(dst_format)
7049                          .SrcFormat(src_format);
7050   auto dfvp = ops::DataFormatVecPermute(s.WithOpName("my_dfvp"), x, attrs);
7051   return dfvp.operation.node()->def();
7052 }
7053 
TEST_P(OpConverter_INT32_Test,ConvertDataFormatVecPermute)7054 TEST_P(OpConverter_INT32_Test, ConvertDataFormatVecPermute) {
7055   Status implicit_error = Status{
7056       error::UNIMPLEMENTED, "Implicit batch mode not supported, at my_dfvp"};
7057 
7058   std::vector<DataFormatVecPermuteTestParams> test_params = {
7059       // 1D case with tensor.
7060       DataFormatVecPermuteTestParams{
7061           /*dst_format=*/"NCHW",
7062           /*src_format=*/"NHWC",
7063           /*x_shape=*/{4},
7064           /*x=*/{1, 2, 3, 4},
7065           /*x_is_tensor=*/true,
7066           /*expected_output=*/{1, 4, 2, 3},
7067           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7068               ? implicit_error
7069               : Status::OK()},
7070       // 1D case with weights.
7071       DataFormatVecPermuteTestParams{
7072           /*dst_format=*/"NCHW",
7073           /*src_format=*/"NHWC",
7074           /*x_shape=*/{4},
7075           /*x=*/{1, 2, 3, 4},
7076           /*x_is_tensor=*/false,
7077           /*expected_output=*/{1, 4, 2, 3},
7078           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7079               ? implicit_error
7080               : Status::OK()},
7081       // 2D case with tensor.
7082       DataFormatVecPermuteTestParams{
7083           /*dst_format=*/"NCHW",
7084           /*src_format=*/"NHWC",
7085           /*x_shape=*/{4, 2},
7086           /*x=*/{1, 2, 3, 4, 5, 6, 7, 8},
7087           /*x_is_tensor=*/true,
7088           /*expected_output=*/{1, 2, 7, 8, 3, 4, 5, 6},
7089           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7090               ? implicit_error
7091               : Status::OK()},
7092       // 2D case with weights.
7093       DataFormatVecPermuteTestParams{
7094           /*dst_format=*/"NCHW",
7095           /*src_format=*/"NHWC",
7096           /*x_shape=*/{4, 2},
7097           /*x=*/{1, 2, 3, 4, 5, 6, 7, 8},
7098           /*x_is_tensor=*/false,
7099           /*expected_output=*/{1, 2, 7, 8, 3, 4, 5, 6},
7100           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7101               ? implicit_error
7102               : Status::OK()},
7103       // Format of size 5.
7104       DataFormatVecPermuteTestParams{
7105           /*dst_format=*/"NCDHW",
7106           /*src_format=*/"NDHWC",
7107           /*x_shape=*/{5, 2},
7108           /*x=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
7109           /*x_is_tensor=*/true,
7110           /*expected_output=*/{1, 2, 9, 10, 3, 4, 5, 6, 7, 8},
7111           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7112               ? implicit_error
7113               : Status::OK()},
7114       // Input of size 2: treat the elements as spatial dimensions.
7115       DataFormatVecPermuteTestParams{
7116           /*dst_format=*/"NCWH",
7117           /*src_format=*/"NHWC",
7118           /*x_shape=*/{2, 2},
7119           /*x=*/{1, 2, 3, 4},
7120           /*x_is_tensor=*/true,
7121           /*expected_output=*/{3, 4, 1, 2},
7122           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7123               ? implicit_error
7124               : Status::OK()},
7125       // Input of size 3: treat the elements as spatial dimensions.
7126       DataFormatVecPermuteTestParams{
7127           /*dst_format=*/"NCHWD",
7128           /*src_format=*/"NDHWC",
7129           /*x_shape=*/{3},
7130           /*x=*/{1, 2, 3},
7131           /*x_is_tensor=*/true,
7132           /*expected_output=*/{2, 3, 1},
7133           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7134               ? implicit_error
7135               : Status::OK()},
7136       // Invalid rank, should fail.
7137       DataFormatVecPermuteTestParams{
7138           /*dst_format=*/"NCHW",
7139           /*src_format=*/"NHWC",
7140           /*x_shape=*/{2, 2, 2},
7141           /*x=*/{1, 2, 3, 4, 5, 6, 7, 8},
7142           /*x_is_tensor=*/true,
7143           /*expected_output=*/{},
7144           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7145               ? implicit_error
7146               : Status{error::INVALID_ARGUMENT,
7147                        "Input must be a vector or matrix, but got rank 3, at "
7148                        "my_dfvp"}},
7149       // Invalid size for 1D input, should fail.
7150       DataFormatVecPermuteTestParams{
7151           /*dst_format=*/"NCHW",
7152           /*src_format=*/"NHWC",
7153           /*x_shape=*/{3},
7154           /*x=*/{1, 2, 3},
7155           /*x_is_tensor=*/true,
7156           /*expected_output=*/{},
7157           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7158               ? implicit_error
7159               : Status{error::INVALID_ARGUMENT,
7160                        "1D input must be of size 2 or 4, but got size 3, at "
7161                        "my_dfvp"}},
7162       // Invalid first dim for 2D input, should fail.
7163       DataFormatVecPermuteTestParams{
7164           /*dst_format=*/"NCDHW",
7165           /*src_format=*/"NDHWC",
7166           /*x_shape=*/{4, 2},
7167           /*x=*/{1, 2, 3, 4, 5, 6, 7, 8},
7168           /*x_is_tensor=*/true,
7169           /*expected_output=*/{},
7170           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7171               ? implicit_error
7172               : Status{error::INVALID_ARGUMENT,
7173                        "First dimension of 2D input must be of size 3 or 5, "
7174                        "but got shape (4, 2), at my_dfvp"}},
7175       // Invalid second dim for 2D input, should fail.
7176       DataFormatVecPermuteTestParams{
7177           /*dst_format=*/"NCHW",
7178           /*src_format=*/"NHWC",
7179           /*x_shape=*/{4, 3},
7180           /*x=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
7181           /*x_is_tensor=*/true,
7182           /*expected_output=*/{},
7183           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7184               ? implicit_error
7185               : Status{error::INVALID_ARGUMENT,
7186                        "Second dimension of 2D input must be of size 2, but "
7187                        "got shape (4, 3), at my_dfvp"}},
7188   };
7189 
7190   for (auto p : test_params) {
7191     Reset();
7192     const NodeDef node_def =
7193         GetDataFormatVecPermuteNodeDef(p.dst_format, p.src_format, p.x_shape);
7194 
7195     if (p.x_is_tensor) {
7196       AddTestTensor("x", p.x_shape, DT_INT32, p.x, p.x_shape);
7197     } else {
7198       AddTestWeights("x", p.x_shape, p.x, DT_INT32);
7199     }
7200 
7201     TestOpConverter("my_dfvp", node_def, p.x_shape, p.conversion_status,
7202                     Status::OK(), ElementsAreArray(p.expected_output));
7203   }
7204 }
7205 
TEST_P(OpConverter_FP32_FP16_INT32_Test,ConvertGather)7206 TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertGather) {
7207   // Get the NodeDef for GatherV2.
7208   Scope s = Scope::NewRootScope();
7209   auto params = ops::Placeholder(s.WithOpName("params"), tf_type_);
7210   auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
7211   auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
7212   auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
7213   const NodeDef& node_def = gather.operation.node()->def();
7214   {
7215     // Axis is a tensor, should fail.
7216     Reset();
7217     AddTestTensor("params", {1, 1, 2, 3}, tf_type_, {});
7218     AddTestTensor("indices", {1, 2}, DT_INT32, {});
7219     AddTestTensor("axis", {1}, DT_INT32, {});
7220     RunValidationAndConversion(
7221         node_def, error::UNIMPLEMENTED,
7222         "The input \"axis\" for GatherV2 must be a constant");
7223   }
7224   {
7225     // Axis is out of bounds, should fail.
7226     Reset();
7227     AddTestTensor("params", {1, 1, 2, 3});
7228     AddTestTensor("indices", {1, 2}, DT_INT32, {});
7229     AddTestWeights<int32>("axis", {1}, {4});
7230     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
7231                                "Axis value of 4 is out of bounds, must be in "
7232                                "range [-4, 4)");
7233   }
7234 
7235   struct TestParams {
7236     // TF shape of the input 'params' (including batch dimension).
7237     std::vector<int> params_shape;
7238     // TF shape of the input 'indices' (including batch dimension).
7239     std::vector<int> indices_shape;
7240     std::vector<int> indices;
7241     int axis;
7242     // Expected TF shape of the output (including batch dimension).
7243     std::vector<int> expected_output_shape;
7244     std::vector<int> expected_output;
7245     bool params_is_tensor;
7246     bool indices_is_tensor;
7247     Status conversion_status;
7248     Status runtime_status;
7249     Status add_index_status;
7250   };
7251 
7252   // Input is the same {1, 2, 3, 4, 5, 6} for all cases.
7253   const std::vector<int> params_input = {1, 2, 3, 4, 5, 6};
7254 
7255   std::vector<TestParams> test_params = {
7256       // Axis is batch dimension, should fail in implicit batch mode.
7257       TestParams{/*params_shape=*/{2, 1, 1, 3},
7258                  /*indices_shape=*/{2},
7259                  /*indices=*/{1, 0},
7260                  /*axis=*/0,
7261                  /*expected_output_shape=*/{2, 1, 1, 3},
7262                  /*expected_output=*/{4, 5, 6, 1, 2, 3},
7263                  /*params_is_tensor=*/true,
7264                  /*indices_is_tensor=*/true,
7265                  /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7266                      ? Status{error::UNIMPLEMENTED,
7267                               "TensorRT does not allow "
7268                               "manipulation of the batch dimension"}
7269                      : Status::OK()},
7270       // Batch size of indices is not 1 when params and indices are tensors.
7271       TestParams{/*params_shape=*/{2, 1, 3},
7272                  /*indices_shape=*/{2, 1},
7273                  /*indices=*/{2, 0},
7274                  /*axis=*/2,
7275                  /*expected_output_shape=*/{2, 1, 2, 1},
7276                  /*expected_output=*/{3, 1, 6, 4},
7277                  /*params_is_tensor=*/true,
7278                  /*indices_is_tensor=*/true,
7279                  /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7280                      ? Status{error::UNIMPLEMENTED,
7281                               "Params and indices must have a"
7282                               " batch size of 1 when params and indices are "
7283                               "both tensors or both"
7284                               " constants."}
7285                      : Status::OK()},
7286       // Batch size of indices is not 1 when params is tensor and indices are
7287       // constant.
7288       TestParams{/*params_shape=*/{2, 1, 3},
7289                  /*indices_shape=*/{2, 1},
7290                  /*indices=*/{2, 0},
7291                  /*axis=*/2,
7292                  /*expected_output_shape=*/{2, 1, 2, 1},
7293                  /*expected_output=*/{3, 1, 6, 4},
7294                  /*params_is_tensor=*/true,
7295                  /*indices_is_tensor=*/false,
7296                  /*conversion_status=*/Status::OK()},
7297       // Axis is not zero when params is a weight, should fail in implicit batch
7298       // mode.
7299       TestParams{/*params_shape=*/{2, 1, 3},
7300                  /*indices_shape=*/{2},
7301                  /*indices=*/{1, 2},
7302                  /*axis=*/2,
7303                  /*expected_output_shape=*/{2, 1, 2},
7304                  /*expected_output=*/{2, 3, 5, 6},
7305                  /*params_is_tensor=*/false,
7306                  /*indices_is_tensor=*/true,
7307                  /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7308                      ? Status{error::UNIMPLEMENTED,
7309                               "The input axis must be zero when "
7310                               "params is a weight."}
7311                      : Status::OK()},
7312       // Params with only batch dimension.
7313       TestParams{
7314           /*params_shape=*/{6},
7315           /*indices_shape=*/{2},
7316           /*indices=*/{1, 3},
7317           /*axis=*/0,
7318           /*expected_output_shape=*/{2},
7319           /*expected_output=*/{2, 4},
7320           /*params_is_tensor=*/true,
7321           /*indices_is_tensor=*/true,
7322           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7323               ? Status{error::UNIMPLEMENTED,
7324                        "TensorRT does not allow "
7325                        "manipulation of the batch dimension"}
7326               : Status::OK(),
7327           /*runtime_status=*/Status::OK(),
7328           /*add_index_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7329               ? Status{error::INVALID_ARGUMENT,
7330                        "Batch size doesn't match for "
7331                        "tensor indices: Provided batch size does not match "
7332                        "converter batch size: 2 vs 6"}
7333               : Status::OK()},
7334       // Vector indices, and output rank is rank(params).
7335       TestParams{
7336           /*params_shape=*/{1, 1, 2, 3},
7337           /*indices_shape=*/{1},
7338           /*indices=*/{0},
7339           /*axis=*/3,
7340           /*expected_output_shape=*/{1, 1, 2, 1},
7341           /*expected_output=*/{1, 4},
7342           /*params_is_tensor=*/true,
7343           /*indices_is_tensor=*/true,
7344       },
7345       TestParams{
7346           /*params_shape=*/{1, 1, 2, 3},
7347           /*indices_shape=*/{1},
7348           /*indices=*/{1},
7349           /*axis=*/2,
7350           /*expected_output_shape=*/{1, 1, 1, 3},
7351           /*expected_output=*/{4, 5, 6},
7352           /*params_is_tensor=*/true,
7353           /*indices_is_tensor=*/true,
7354       },
7355       // Indices with rank>1, and output rank is rank(params) + rank(indices) -
7356       // 1
7357       TestParams{
7358           /*params_shape=*/{1, 1, 2, 3},
7359           /*indices_shape=*/{1, 1},
7360           /*indices=*/{0},
7361           /*axis=*/3,
7362           /*expected_output_shape=*/{1, 1, 2, 1, 1},
7363           /*expected_output=*/{1, 4},
7364           /*params_is_tensor=*/true,
7365           /*indices_is_tensor=*/true,
7366       },
7367       TestParams{
7368           /*params_shape=*/{1, 1, 2, 3},
7369           /*indices_shape=*/{1, 1},
7370           /*indices=*/{1},
7371           /*axis=*/3,
7372           /*expected_output_shape=*/{1, 1, 2, 1, 1},
7373           /*expected_output=*/{2, 5},
7374           /*params_is_tensor=*/true,
7375           /*indices_is_tensor=*/true,
7376       },
7377       TestParams{
7378           /*params_shape=*/{1, 1, 2, 3},
7379           /*indices_shape=*/{1, 1},
7380           /*indices=*/{2},
7381           /*axis=*/-1,
7382           /*expected_output_shape=*/{1, 1, 2, 1, 1},
7383           /*expected_output=*/{3, 6},
7384           /*params_is_tensor=*/true,
7385           /*indices_is_tensor=*/true,
7386       },
7387       TestParams{
7388           /*params_shape=*/{1, 1, 2, 3},
7389           /*indices_shape=*/{1, 3},
7390           /*indices=*/{2, 0, 1},
7391           /*axis=*/3,
7392           /*expected_output_shape=*/{1, 1, 2, 1, 3},
7393           /*expected_output=*/{3, 1, 2, 6, 4, 5},
7394           /*params_is_tensor=*/true,
7395           /*indices_is_tensor=*/true,
7396       },
7397       TestParams{
7398           /*params_shape=*/{1, 3, 2},
7399           /*indices_shape=*/{1, 2, 2},
7400           /*indices=*/{0, 0, 1, 0},
7401           /*axis=*/2,
7402           /*expected_output_shape=*/{1, 3, 1, 2, 2},
7403           /*expected_output=*/{1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5},
7404           /*params_is_tensor=*/true,
7405           /*indices_is_tensor=*/true,
7406       },
7407       TestParams{
7408           /*params_shape=*/{1, 2, 3},
7409           /*indices_shape=*/{1},
7410           /*indices=*/{0},
7411           /*axis=*/0,
7412           /*expected_output_shape=*/{1, 2, 3},
7413           /*expected_output=*/{1, 2, 3, 4, 5, 6},
7414           /*params_is_tensor=*/false,
7415           /*indices_is_tensor=*/true,
7416       },
7417       TestParams{
7418           /*params_shape=*/{3, 2},
7419           /*indices_shape=*/{1, 2},
7420           /*indices=*/{0, 1},
7421           /*axis=*/0,
7422           /*expected_output_shape=*/{1, 2, 2},
7423           /*expected_output=*/{1, 2, 3, 4},
7424           /*params_is_tensor=*/false,
7425           /*indices_is_tensor=*/true,
7426       },
7427       TestParams{
7428           /*params_shape=*/{2, 3},
7429           /*indices_shape=*/{1, 1, 2},
7430           /*indices=*/{0, 1},
7431           /*axis=*/0,
7432           /*expected_output_shape=*/{1, 1, 2, 3},
7433           /*expected_output=*/{1, 2, 3, 4, 5, 6},
7434           /*params_is_tensor=*/false,
7435           /*indices_is_tensor=*/true,
7436       },
7437       TestParams{
7438           /*params_shape=*/{3, 2},
7439           /*indices_shape=*/{2, 2},
7440           /*indices=*/{0, 2, 1, 0},
7441           /*axis=*/0,
7442           /*expected_output_shape=*/{2, 2, 2},
7443           /*expected_output=*/{1, 2, 5, 6, 3, 4, 1, 2},
7444           /*params_is_tensor=*/false,
7445           /*indices_is_tensor=*/true,
7446       },
7447       // Test cases in which indices constant
7448       TestParams{
7449           /*params_shape=*/{1, 1, 2, 3},
7450           /*indices_shape=*/{1, 1},
7451           /*indices=*/{0},
7452           /*axis=*/3,
7453           /*expected_output_shape=*/{1, 1, 2, 1, 1},
7454           /*expected_output=*/{1, 4},
7455           /*params_is_tensor=*/true,
7456           /*indices_is_tensor=*/false,
7457       },
7458       // Test cases in which both input and indices constant
7459       TestParams{/*params_shape=*/{1, 2, 3},
7460                  /*indices_shape=*/{1},
7461                  /*indices=*/{0},
7462                  /*axis=*/0,
7463                  /*expected_output_shape=*/{1, 2, 3},
7464                  /*expected_output=*/{1, 2, 3, 4, 5, 6},
7465                  /*params_is_tensor=*/false,
7466                  /*indices_is_tensor=*/false,
7467                  /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7468                      ? Status{error::UNIMPLEMENTED,
7469                               "Params and indices must have a"
7470                               " batch size of 1 when params and indices are "
7471                               "both tensors or both"
7472                               " constants."}
7473                      : Status::OK()},
7474       TestParams{/*params_shape=*/{3, 2},
7475                  /*indices_shape=*/{2, 2},
7476                  /*indices=*/{0, 2, 1, 0},
7477                  /*axis=*/0,
7478                  /*expected_output_shape=*/{2, 2, 2},
7479                  /*expected_output=*/{1, 2, 5, 6, 3, 4, 1, 2},
7480                  /*params_is_tensor=*/false,
7481                  /*indices_is_tensor=*/false,
7482                  /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7483                      ? Status{error::UNIMPLEMENTED,
7484                               "Params and indices must have a"
7485                               " batch size of 1 when params and indices are "
7486                               "both tensors or both"
7487                               " constants."}
7488                      : Status::OK()},
7489   };
7490 
7491   for (auto p : test_params) {
7492     Reset();
7493 
7494     if (p.params_is_tensor) {
7495       AddTestTensor("params", p.params_shape, params_input);
7496     } else {
7497       AddTestWeights("params", p.params_shape, params_input, tf_type_);
7498     }
7499 
7500     if (p.indices_is_tensor) {
7501       AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {},
7502                     p.add_index_status);
7503     } else {
7504       std::vector<int> indices_shape(p.indices_shape);
7505       AddTestWeights("indices", indices_shape, p.indices, DT_INT32);
7506     }
7507 
7508     AddTestWeights<int32>("axis", {1}, {p.axis});
7509     TestOpConverter("my_gather", node_def, p.expected_output_shape,
7510                     p.conversion_status, p.runtime_status,
7511                     ElementsAreArray(p.expected_output));
7512   }
7513 }
7514 
7515 template <typename OpType>
CreateReduceOp(DataType tf_type,bool keep_dims)7516 NodeDef CreateReduceOp(DataType tf_type, bool keep_dims) {
7517   Scope s = Scope::NewRootScope();
7518   auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
7519   auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
7520   typename OpType::Attrs op_attrs;
7521   op_attrs.keep_dims_ = keep_dims;
7522   auto op = OpType(s.WithOpName("my_reduce"), input, axis, op_attrs);
7523   return op.operation.node()->def();
7524 }
7525 
// Applies a reduction to consecutive, non-overlapping chunks of `input`:
//   output[i] = reduce(input[m * i : m * (i + 1)])
// where `reduce` folds a chunk with `op`, starting from `init`. When `op_name`
// is "Mean", each folded value is additionally divided by the chunk size `m`.
// Trailing elements that do not fill a whole chunk are ignored.
// Returns an empty vector when `m` is not positive. Previously m == 0 divided
// by zero.
std::vector<float> CalcReduce(const std::string& op_name,
                              const std::vector<float>& input, int m,
                              float (*op)(float, float), float init) {
  if (m <= 0) {
    return {};
  }
  const auto chunk = static_cast<std::ptrdiff_t>(m);
  std::vector<float> output(input.size() / static_cast<std::size_t>(m));
  const bool is_mean = op_name == "Mean";
  for (std::size_t i = 0; i < output.size(); ++i) {
    const auto begin = input.begin() + static_cast<std::ptrdiff_t>(i) * chunk;
    output[i] = std::accumulate(begin, begin + chunk, init, op);
    if (is_mean) {
      output[i] /= m;
    }
  }
  return output;
}
TEST_P(OpConverter_FP32_FP16_INT32_Test,ConvertReduce)7541 TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertReduce) {
7542   {
7543     // Input is weights, should fail.
7544     Reset();
7545     const NodeDef node_def = CreateReduceOp<ops::Sum>(tf_type_, false);
7546     AddTestWeights<float>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
7547     AddTestWeights<int32>("axis", {1}, {1});
7548     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
7549                                "The input \"input\" for Sum must be a tensor");
7550   }
7551   {
7552     // Axis is weights, should fail.
7553     Reset();
7554     const NodeDef node_def = CreateReduceOp<ops::Sum>(tf_type_, false);
7555     AddTestTensor("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
7556     AddTestTensor("axis", {1}, DT_INT32, {1});
7557     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
7558                                "The input \"axis\" for Sum must be a constant");
7559   }
7560   using OpFunc = std::function<NodeDef(DataType, bool)>;
7561   using ValFunc = float (*)(float, float);
7562   struct ReduceTestDescriptor {
7563     string name;
7564     OpFunc get_node;
7565     ValFunc val_func;
7566     float init_val;
7567   };
7568   std::vector<ReduceTestDescriptor> op_test_info{
7569       {"Sum", CreateReduceOp<ops::Sum>, [](float x, float y) { return x + y; },
7570        0},
7571       {"Prod", CreateReduceOp<ops::Prod>,
7572        [](float x, float y) { return x * y; }, 1},
7573       {"Mean", CreateReduceOp<ops::Mean>,
7574        [](float x, float y) { return x + y; }, 0},
7575       {"Min", CreateReduceOp<ops::Min>,
7576        [](float x, float y) { return y < x ? y : x; }, 1000},
7577       {"Max", CreateReduceOp<ops::Max>,
7578        [](float x, float y) { return x < y ? y : x; }, -1000}};
7579 
7580   std::vector<float> input_values{1, 2, 3, 4, 5, 6};
7581   struct TestParams {
7582     std::vector<int> input_dims;
7583     std::vector<float> input_values;
7584     // Helper array contains the same elements as input but permuted in a way
7585     // that the reduction can be calculated over contiguous elements using
7586     // CalcReduce
7587     std::vector<float> helper_array;
7588     std::vector<int> axis;
7589     int stride;  // product of input_dims along axis
7590     Status conversion_status;
7591   };
7592   std::vector<TestParams> params{
7593       // Out of range tests
7594       TestParams{{2, 3, 1}, input_values, input_values, {3}, 3},
7595       TestParams{{2, 3, 1}, input_values, input_values, {-4}, 3},
7596       // Ok tests
7597       TestParams{{2, 3, 1}, input_values, {1, 4, 2, 5, 3, 6}, {0}, 2},
7598       TestParams{{2, 3, 1}, input_values, input_values, {1}, 3},
7599       TestParams{{2, 3, 1}, input_values, input_values, {2}, 1},
7600       TestParams{{2, 3, 1}, input_values, input_values, {0, 1}, 6},
7601       // Ok tests with negative axis values
7602       TestParams{{2, 3, 1}, input_values, {1, 4, 2, 5, 3, 6}, {-3}, 2},
7603       TestParams{{2, 3, 1}, input_values, input_values, {-2}, 3},
7604       TestParams{{2, 3, 1}, input_values, input_values, {-1}, 1},
7605       TestParams{{2, 3, 1}, input_values, input_values, {-3, 1}, 6},
7606   };
7607 
7608   for (bool keep_dims : {false, true}) {
7609     for (auto& op : op_test_info) {
7610       VLOG(2) << "Processing " << op.name << " with keep_dims=" << keep_dims;
7611       for (auto p : params) {
7612         SCOPED_TRACE(StrCat(op.name, keep_dims ? " & keep_dims" : ""));
7613         Reset();
7614         NodeDef node_def = op.get_node(tf_type_, keep_dims);
7615 
7616         AddTestTensor("input", p.input_dims, p.input_values);
7617         AddTestWeights<int32>("axis", {static_cast<int>(p.axis.size())},
7618                               p.axis);
7619         std::vector<int> expected_output_dims(p.input_dims);
7620 
7621         // Set expected output dim and conversion error messages
7622         for (int ax : p.axis) {
7623           int rank = p.input_dims.size();
7624           if (ax >= rank || ax < -rank) {
7625             p.conversion_status =
7626                 errors::InvalidArgument("Axis value of ", ax,
7627                                         " is out of bounds, must be in "
7628                                         "range [",
7629                                         -rank, ", ", rank, ")");
7630           } else {
7631             int ax_positive = ax >= 0 ? ax : ax + rank;
7632             // Zero marks elements that we will remove later.
7633             expected_output_dims[ax_positive] = keep_dims ? 1 : 0;
7634             if (trt_mode_ == TrtTestMode::kImplicitBatch &&
7635                 (ax == 0 || ax == -rank)) {
7636               p.conversion_status = errors::Unimplemented(
7637                   "TensorRT does not allow manipulation of the batch "
7638                   "dimension");
7639             }
7640           }
7641         }
7642         expected_output_dims.erase(std::remove(expected_output_dims.begin(),
7643                                                expected_output_dims.end(), 0),
7644                                    expected_output_dims.end());
7645         VLOG(2) << "out dims "
7646                 << absl::StrCat("[", absl::StrJoin(expected_output_dims, ","),
7647                                 "]");
7648         std::vector<float> expected_values = CalcReduce(
7649             op.name, p.helper_array, p.stride, op.val_func, op.init_val);
7650 
7651         if (tf_type_ == DT_INT32) {
7652           // We need to floor the float values in the `expected_values` vector.
7653           std::for_each(expected_values.begin(), expected_values.end(),
7654                         [](float& _n) { _n = std::floor(_n); });
7655         }
7656 
7657         TestOpConverter("my_reduce", node_def, expected_output_dims,
7658                         p.conversion_status, Status::OK(),
7659                         ArrayFloatNear(expected_values));
7660       }
7661     }
7662   }
7663 }
7664 
CreateCastOp(DataType tf_type)7665 NodeDef CreateCastOp(DataType tf_type) {
7666   Scope s = Scope::NewRootScope();
7667   auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF);
7668   return ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT)
7669       .operation.node()
7670       ->def();
7671 }
7672 
TEST_P(OpConverter_FP32_UnaryTest,ConvertUnary)7673 TEST_P(OpConverter_FP32_UnaryTest, ConvertUnary) {
7674   using OpFunc = std::function<NodeDef(DataType)>;
7675   using ValFunc = float (*)(float);
7676   std::map<std::string, std::pair<OpFunc, ValFunc>> op_map;
7677 #define ADD_OP(name, op, compute) \
7678   op_map[name] =                  \
7679       std::make_pair(CreateUnaryOp<op>, static_cast<ValFunc>(compute))
7680   ADD_OP("Abs", ops::Abs, std::abs);
7681   ADD_OP("Acos", ops::Acos, std::acos);
7682   ADD_OP("Acosh", ops::Acosh, std::acosh);
7683   ADD_OP("Asin", ops::Asin, std::asin);
7684   ADD_OP("Asinh", ops::Asinh, std::asinh);
7685   ADD_OP("Atan", ops::Atan, std::atan);
7686   ADD_OP("Atanh", ops::Atanh, std::atanh);
7687   op_map["Cast"] = std::make_pair(CreateCastOp, [](float x) { return x; });
7688   ADD_OP("Ceil", ops::Ceil, std::ceil);
7689   ADD_OP("Cos", ops::Cos, std::cos);
7690   ADD_OP("Cosh", ops::Cosh, std::cosh);
7691   ADD_OP("Exp", ops::Exp, std::exp);
7692   ADD_OP("Erf", ops::Erf, std::erf);
7693   ADD_OP("Floor", ops::Floor, std::floor);
7694   ADD_OP("Log", ops::Log, std::log);
7695   ADD_OP("Neg", ops::Neg, [](float x) { return -x; });
7696   ADD_OP("Reciprocal", ops::Reciprocal, [](float x) { return 1.0f / x; });
7697 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
7698   ADD_OP("Round", ops::Round, [](float x) { return (float)std::round(x); });
7699   ADD_OP("Sign", ops::Sign,
7700          [](float x) { return x > 0 ? 1.0f : (x < 0 ? -1.0f : 0.0f); });
7701 #endif
7702   ADD_OP("Rsqrt", ops::Rsqrt, [](float x) { return 1.0f / std::sqrt(x); });
7703   ADD_OP("Sin", ops::Sin, std::sin);
7704   ADD_OP("Sinh", ops::Sinh, std::sinh);
7705   ADD_OP("Sqrt", ops::Sqrt, std::sqrt);
7706   ADD_OP("Tan", ops::Tan, std::tan);
7707 #undef ADD_OP
7708 
7709   std::vector<float> input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f};
7710   RunTests("Unary", *UnaryOperationMap(), op_map, input_values, "x");
7711 }
7712 
TEST_P(OpConverter_BOOL_Test,ConvertBoolean)7713 TEST_P(OpConverter_BOOL_Test, ConvertBoolean) {
7714   std::vector<int> input_values{1, 0, 1, 0, 0, 1};
7715   using OpFunc = std::function<NodeDef(DataType)>;
7716 
7717   using ValFunc = int (*)(int);
7718   std::map<std::string, std::pair<OpFunc, ValFunc>> op_map;
7719 #define ADD_OP(name, op, compute) \
7720   op_map[name] =                  \
7721       std::make_pair(CreateUnaryOp<op>, static_cast<ValFunc>(compute))
7722   ADD_OP("LogicalNot", ops::LogicalNot, [](int x) { return 1 - x; });
7723 #undef ADD_OP
7724 
7725 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
7726   // The test does not actually run for TPT versions less than 8.2
7727   RunTests("LogicalUnary", *UnaryBooleanOperationMap(), op_map, input_values,
7728            "x");
7729 #endif
7730 }
7731 
7732 // Get the NodeDef for ConcatV2.
7733 // TODO(hinsu): Consider switching this to static function.
__anond617510d3a02(DataType dtype, int num_inputs) 7734 auto get_concat_nodedef = [](DataType dtype, int num_inputs) -> NodeDef {
7735   Scope s = Scope::NewRootScope();
7736   std::vector<Input> values;
7737   values.reserve(num_inputs);
7738   for (int i = 0; i < num_inputs; ++i) {
7739     const string input_name = StrCat("values_", i);
7740     values.push_back(ops::Placeholder(s.WithOpName(input_name), dtype));
7741   }
7742   auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
7743   auto concat = ops::Concat(s.WithOpName("my_concat"),
7744                             absl::Span<const Input>(values), axis);
7745   return concat.operation.node()->def();
7746 };
7747 
TEST_P(OpConverter_FP32_FP16_INT32_Test,ConvertConcat)7748 TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertConcat) {
7749   {
7750     // Axis is a tensor, should fail.
7751     Reset();
7752     NodeDef node_def = get_concat_nodedef(tf_type_, 2);
7753     AddTestTensor("values_0", {1, 1, 2, 3});
7754     AddTestTensor("values_1", {1, 1, 2, 3});
7755     AddTestTensor("axis", {1});
7756     RunValidationAndConversion(
7757         node_def, error::UNIMPLEMENTED,
7758         "The input \"axis\" for ConcatV2 must be a constant");
7759   }
7760   {
7761     // Axis is out of bounds, should fail.
7762     Reset();
7763     NodeDef node_def = get_concat_nodedef(tf_type_, 2);
7764     AddTestTensor("values_0", {1, 1, 2, 3});
7765     AddTestTensor("values_1", {1, 1, 2, 3});
7766     AddTestWeights<int32>("axis", {1}, {4});
7767     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
7768                                "Axis value of 4 is out of bounds, must be in "
7769                                "range [-4, 4)");
7770   }
7771   {
7772     // Inputs have inconsistent ranks, should fail.
7773     Reset();
7774     NodeDef node_def = get_concat_nodedef(tf_type_, 2);
7775     AddTestTensor("values_0", {1, 1, 2, 3});
7776     AddTestTensor("values_1", {1, 1, 6});
7777     AddTestWeights<int32>("axis", {1}, {1});
7778     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
7779                                "Received inputs with inconsistent rank");
7780   }
7781 
7782   struct TestParams {
7783     std::vector<std::vector<int>> input_shapes;
7784     std::vector<std::vector<int>> input_values;
7785     std::vector<bool> inputs_are_tensors;
7786     int axis;
7787     std::vector<int> expected_output_dims;
7788     std::vector<int> expected_output;
7789     Status conversion_status;
7790     Status run_status;
7791   };
7792 
7793   const std::vector<std::vector<int>> common_input{CreateVectorIota<int>(6),
7794                                                    CreateVectorIota<int>(6, 6)};
7795 
7796   std::vector<TestParams> params = {
7797       {
7798           /*input_shapes=*/{{1, 1, 2, 3}, {1, 1, 2, 3}},
7799           /*input_values=*/common_input,
7800           /*inputs_are_tensors=*/{true, true},
7801           /*axis=*/1,
7802           /*expected_output_dims=*/{1, 2, 2, 3},
7803           /*expected_output=*/CreateVectorIota<int>(12),
7804       },
7805       {
7806           /*input_shapes=*/{{1, 1, 2, 3}, {1, 1, 2, 3}},
7807           /*input_values=*/common_input,
7808           /*inputs_are_tensors=*/{true, true},
7809           /*axis=*/2,
7810           /*expected_output_dims=*/{1, 1, 4, 3},
7811           /*expected_output=*/CreateVectorIota<int>(12),
7812       },
7813       {
7814           /*input_shapes=*/{{1, 1, 2, 3}, {1, 1, 2, 3}},
7815           /*input_values=*/common_input,
7816           /*inputs_are_tensors=*/{true, true},
7817           /*axis=*/3,
7818           /*expected_output_dims=*/{1, 1, 2, 6},
7819           /*expected_output=*/
7820           {0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11},
7821       },
7822       {
7823           /*input_shapes=*/{{1, 1}, {1, 2}, {1, 3}, {1, 1}, {1, 1}, {1, 2}},
7824           /*input_values=*/
7825           {{1}, {2, 3}, {4, 5, 6}, {7}, {8}, {9, 10}},
7826           /*inputs_are_tensors=*/{true, true, true, true, true, true},
7827           /*axis=*/1,
7828           /*expected_output_dims=*/{1, 10},
7829           /*expected_output=*/
7830           CreateVectorIota<int>(10, /*start_value=*/1),
7831       },
7832       {
7833           // An input is a weight
7834           /*input_shapes=*/{{1, 1, 2, 3}, {1, 1, 2, 3}},
7835           /*input_values=*/common_input,
7836           /*inputs_are_tensors=*/{true, false},
7837           /*axis=*/1,
7838           /*expected_output_dims=*/{1, 2, 2, 3},
7839           /*expected_output=*/CreateVectorIota<int>(12),
7840           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7841               ? errors::Unimplemented(
7842                     "The input \"values_1\" for ConcatV2 must be a tensor")
7843               : Status::OK(),
7844           /*run_status=*/Status::OK(),
7845       },
7846       {
7847           // An input is a weight
7848           /*input_shapes=*/{{1, 1, 2, 3}, {1, 1, 2, 3}},
7849           /*input_values=*/common_input,
7850           /*inputs_are_tensors=*/{false, false},
7851           /*axis=*/1,
7852           /*expected_output_dims=*/{1, 2, 2, 3},
7853           /*expected_output=*/CreateVectorIota<int>(12),
7854           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7855               ? errors::Unimplemented(
7856                     "The input \"values_0\" for ConcatV2 must be a tensor")
7857               : Status::OK(),
7858           /*run_status=*/Status::OK(),
7859       },
7860       {
7861           // Axis is batch dimension, should fail in implicit batch mode.
7862           /*input_shapes=*/{{1, 1, 2, 3}, {1, 1, 2, 3}},
7863           /*input_values=*/common_input,
7864           /*inputs_are_tensors=*/{true, true},
7865           /*axis=*/0,
7866           /*expected_output_dims=*/{2, 1, 2, 3},
7867           /*expected_output=*/CreateVectorIota<int>(12),
7868           /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
7869               ? errors::Unimplemented(
7870                     "TensorRT does not allow manipulation of the "
7871                     "batch dimension")
7872               : Status::OK(),
7873       },
7874       {
7875           // Inconsistent input shape, runtime error in dynamic shape mode.
7876           /*input_shapes=*/{{1, 1, 2, 3}, {1, 1, 3, 2}},
7877           /*input_values=*/common_input,
7878           /*inputs_are_tensors=*/{true, true},
7879           /*axis=*/1,
7880           /*expected_output_dims=*/{2, 1, 2, 3},
7881           /*expected_output=*/CreateVectorIota<int>(12),
7882           trt_mode_ != TrtTestMode::kDynamicShape
7883               ? errors::InvalidArgument(
7884                     "Received inputs with inconsistent shape")
7885               : Status::OK(),
7886           errors::InvalidArgument(""),
7887       }};
7888 
7889   for (auto p : params) {
7890     Reset();
7891     const int num_inputs = p.input_shapes.size();
7892     EXPECT_EQ(num_inputs, p.input_values.size());
7893 
7894     NodeDef node_def = get_concat_nodedef(tf_type_, num_inputs);
7895 
7896     // Create inputs.
7897     for (int j = 0; j < num_inputs; ++j) {
7898       string name = StrCat("values_", j);
7899 
7900       if (!p.inputs_are_tensors[j]) {
7901         AddTestWeights(name, p.input_shapes[j], p.input_values[j], tf_type_);
7902       } else {
7903         AddTestTensor(name, p.input_shapes[j], p.input_values[j]);
7904       }
7905     }
7906     AddTestWeights<int32>("axis", {1}, {p.axis});
7907 
7908     TestOpConverter("my_concat", node_def, p.expected_output_dims,
7909                     p.conversion_status, p.run_status,
7910                     ElementsAreArray(p.expected_output));
7911   }
7912 }
7913 
7914 // Get the NodeDef for Split.
__anond617510d3b02(DataType dtype, int num_split) 7915 auto get_split_nodedef = [](DataType dtype, int num_split) -> NodeDef {
7916   Scope s = Scope::NewRootScope();
7917   auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
7918   auto value = ops::Placeholder(s.WithOpName("value"), dtype);
7919   auto split = ops::Split(s.WithOpName("my_split"), axis, value, num_split);
7920   return split.operation.node()->def();
7921 };
7922 
7923 template <DataType dtype>
TestConvertSplit(OpConverterTest * test)7924 void TestConvertSplit(OpConverterTest* test) {
7925   typedef typename EnumToDataType<dtype>::Type CType;
7926 
7927   struct TestParams {
7928     std::vector<int> input_shape;
7929     std::vector<CType> value;
7930     int axis;
7931     int num_split;
7932     std::vector<int> expected_output_dims;
7933     std::vector<std::vector<CType>> expected_outputs;
7934   };
7935 
7936   const std::vector<CType> common_input = CreateVectorIota<CType>(6);
7937   std::vector<TestParams> ok_params = {
7938       // Identity (num_split = 1)
7939       {/*input_shape=*/{1, 2, 3}, /*value=*/common_input, /*axis=*/1,
7940        /*num_split=*/1, /*expected_output_dims=*/{1, 2, 3},
7941        /*expected_outputs=*/{CreateVectorIota<CType>(6)}},
7942       {/*input_shape=*/{1, 2, 3},
7943        /*value=*/common_input,
7944        /*axis=*/3,
7945        /*num_split=*/3,
7946        /*expected_output_dims=*/{1, 2, 1},
7947        /*expected_outputs=*/
7948        {{CType(0), CType(3)}, {CType(1), CType(4)}, {CType(2), CType(5)}}},
7949       {/*input_shape=*/{1, 6},
7950        /*value=*/common_input,
7951        /*axis=*/2,
7952        /*num_split=*/6,
7953        /*expected_output_dims=*/{1, 1},
7954        /*expected_outputs=*/
7955        {{CType(0)},
7956         {CType(1)},
7957         {CType(2)},
7958         {CType(3)},
7959         {CType(4)},
7960         {CType(5)}}},
7961       {/*input_shape=*/{1, 6},
7962        /*value=*/common_input,
7963        /*axis=*/-1,
7964        /*num_split=*/2,
7965        /*expected_output_dims=*/{1, 3},
7966        /*expected_outputs=*/
7967        {CreateVectorIota<CType>(3), CreateVectorIota<CType>(3, CType(3))}},
7968   };
7969 
7970   for (int i = 0; i < ok_params.size(); ++i) {
7971     test->Reset();
7972     NodeDef node_def = get_split_nodedef(dtype, ok_params[i].num_split);
7973     // Create inputs.
7974     test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
7975     nvinfer1::DataType trt_type;
7976     TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
7977     test->AddTestTensor("value", ok_params[i].input_shape, 1, trt_type);
7978     // Convert.
7979     test->RunValidationAndConversion(node_def);
7980 
7981     // Get output tensors and verify output dims.
7982     EXPECT_EQ(ok_params[i].expected_outputs.size(), ok_params[i].num_split);
7983     std::vector<TRT_TensorOrWeights> outputs(ok_params[i].num_split);
7984     DataVec output_data;
7985     for (int j = 0; j < outputs.size(); ++j) {
7986       const string name = j == 0 ? StrCat("my_split") : StrCat("my_split:", j);
7987       TF_EXPECT_OK(test->GetTensorOrWeights(name, &outputs[j]));
7988       EXPECT_TRUE(outputs[j].is_tensor());
7989       EXPECT_THAT(outputs[j].tensor()->getDimensions(),
7990                   DimsAreArray(ok_params[i].expected_output_dims));
7991       // Create buffer to store output.
7992       output_data.push_back(
7993           {name, test->ConstructTensor<CType>(
7994                      ok_params[i].expected_outputs[j].size())});
7995     }
7996 
7997     // Verify output values are correct.
7998     const DataVec input_data{
7999         {"value", test->AsTensor<CType>(ok_params[i].value)}};
8000     TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
8001     for (int j = 0; j < outputs.size(); ++j) {
8002       EXPECT_THAT(GetSpanForData<CType>(output_data[j]),
8003                   ElementsAreArray(ok_params[i].expected_outputs[j]));
8004     }
8005   }
8006 }
8007 
TEST_F(OpConverterTest,ConvertSplit)8008 TEST_F(OpConverterTest, ConvertSplit) {
8009   {
8010     // Axis is a tensor, should fail.
8011     Reset();
8012     NodeDef node_def = get_split_nodedef(DT_FLOAT, 1);
8013     AddTestTensor("axis", {1});
8014     AddTestTensor("value", {1, 2, 3});
8015     RunValidationAndConversion(
8016         node_def, error::UNIMPLEMENTED,
8017         "The input \"axis\" for Split must be a constant");
8018   }
8019   {
8020     // Axis is out of bounds, should fail.
8021     Reset();
8022     NodeDef node_def = get_split_nodedef(DT_FLOAT, 1);
8023     AddTestWeights<int32>("axis", {1}, {4});
8024     AddTestTensor("value", {1, 2, 3});
8025     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
8026                                "Axis value of 4 is out of bounds, must be in "
8027                                "range [-4, 4)");
8028   }
8029   {
8030     // Axis is out of bounds (negative), should fail.
8031     Reset();
8032     NodeDef node_def = get_split_nodedef(DT_FLOAT, 1);
8033     AddTestWeights<int32>("axis", {1}, {-5});
8034     AddTestTensor("value", {1, 2, 3});
8035     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
8036                                "Axis value of -5 is out of bounds, must be in "
8037                                "range [-4, 4)");
8038   }
8039   {
8040     // Axis is batch dimension, should fail.
8041     Reset();
8042     NodeDef node_def = get_split_nodedef(DT_FLOAT, 1);
8043     AddTestWeights<int32>("axis", {1}, {0});
8044     AddTestTensor("value", {1, 2, 3});
8045     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
8046                                "TensorRT does not allow manipulation of the "
8047                                "batch dimension");
8048   }
8049   {
8050     // Value is a weight, should fail.
8051     Reset();
8052     NodeDef node_def = get_split_nodedef(DT_FLOAT, 1);
8053     AddTestWeights<int32>("axis", {1}, {1});
8054     AddTestWeights<float>("value", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
8055     RunValidationAndConversion(
8056         node_def, error::UNIMPLEMENTED,
8057         "The input \"value\" for Split must be a tensor");
8058   }
8059   {
8060     // Dim is not evenly divisibly by num_split, should fail.
8061     Reset();
8062     NodeDef node_def = get_split_nodedef(DT_FLOAT, 2);
8063     AddTestWeights<int32>("axis", {1}, {3});
8064     AddTestTensor("value", {1, 2, 3});
8065     RunValidationAndConversion(
8066         node_def, error::INVALID_ARGUMENT,
8067         "Dimension 3 of size 3 is not evenly divisible by 2");
8068   }
8069   {
8070     // num_split > dim size, should fail.
8071     Reset();
8072     NodeDef node_def = get_split_nodedef(DT_FLOAT, 4);
8073     AddTestWeights<int32>("axis", {1}, {3});
8074     AddTestTensor("value", {1, 2, 3});
8075     RunValidationAndConversion(
8076         node_def, error::INVALID_ARGUMENT,
8077         "Dimension 3 of size 3 is not evenly divisible by 4");
8078   }
8079 
8080   TestConvertSplit<DT_FLOAT>(this);
8081   TestConvertSplit<DT_HALF>(this);
8082   TestConvertSplit<DT_INT32>(this);
8083 }
8084 
8085 // Get the NodeDef for Unpack (Unstack in TF API).
__anond617510d3c02(DataType dtype, int num, int axis) 8086 auto get_unpack_nodedef = [](DataType dtype, int num, int axis) -> NodeDef {
8087   Scope s = Scope::NewRootScope();
8088   auto value = ops::Placeholder(s.WithOpName("value"), dtype);
8089   auto unstack_attrs = ops::Unstack::Axis(axis);
8090   auto unstack =
8091       ops::Unstack(s.WithOpName("my_unpack"), value, num, unstack_attrs);
8092   return unstack.operation.node()->def();
8093 };
8094 
8095 struct UnpackTestParams {
8096   std::vector<int> input_shape;
8097   std::vector<float> input_value;
8098   int axis;
8099   int num;
8100   std::vector<int> expected_output_dims;
8101   std::vector<std::vector<float>> expected_outputs;
8102   Status run_status;
8103 };
8104 
TestConvertUnpack(ParameterizedOpConverterTestBase * test,UnpackTestParams & p)8105 void TestConvertUnpack(ParameterizedOpConverterTestBase* test,
8106                        UnpackTestParams& p) {
8107   test->Reset();
8108   NodeDef node_def = get_unpack_nodedef(test->get_tf_type(), p.num, p.axis);
8109   // Create inputs.
8110   test->AddTestTensor("value", p.input_shape, test->get_tf_type(),
8111                       p.input_value);
8112 
8113   std::vector<Matcher<std::vector<float>>> matcher_vec;
8114   std::vector<DataType> datatype_vec;
8115   std::vector<std::vector<int>> expected_output_dims;
8116 
8117   for (int j = 0; j < p.expected_outputs.size(); ++j) {
8118     matcher_vec.push_back(ElementsAreArray(p.expected_outputs[j]));
8119     datatype_vec.push_back(test->get_tf_type());
8120     expected_output_dims.push_back(p.expected_output_dims);
8121   }
8122 
8123   test->TestOpConverterMultiOut(/*name=*/"my_unpack",
8124                                 /*node_def=*/node_def,
8125                                 /*expected_output_dims=*/expected_output_dims,
8126                                 /*expected_conversion_status=*/p.run_status,
8127                                 /*expected_runtime_status=*/p.run_status,
8128                                 /*matcher=*/matcher_vec,
8129                                 /*out_tf_type=*/datatype_vec);
8130 }
8131 
8132 // TODO: Reactivate when INT32 Segfault fixed
TEST_P(OpConverter_FP32_FP16_INT32_Test,ConvertUnpack)8133 TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertUnpack) {
8134   // We need to skip error testing for Dynamic Shape mode, as it is impossible
8135   // to convert Unpack in Dynamic Shape Mode.
8136   if (trt_mode_ != TrtTestMode::kDynamicShape) {
8137     {
8138       // Value is weights, should fail.
8139       Reset();
8140       NodeDef node_def = get_unpack_nodedef(tf_type_, /*num=*/3, /*axis=*/3);
8141       AddTestWeights<float>("value", {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
8142       RunValidationAndConversion(
8143           node_def, error::UNIMPLEMENTED,
8144           "The input \"value\" for Unpack must be a tensor");
8145     }
8146     {
8147       // Axis is out of bounds, should fail.
8148       Reset();
8149       NodeDef node_def = get_unpack_nodedef(tf_type_, /*num=*/1, /*axis=*/4);
8150       AddTestTensor("value", {1, 1, 2, 3});
8151       RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
8152                                  "Axis value of 4 is out of bounds, must be in "
8153                                  "range [-4, 4)");
8154     }
8155     {
8156       // Axis is out of bounds (negative), should fail.
8157       Reset();
8158       NodeDef node_def = get_unpack_nodedef(tf_type_, /*num=*/1, /*axis=*/-5);
8159       AddTestTensor("value", {1, 1, 2, 3});
8160       RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
8161                                  "Axis value of -5 is out of bounds, must be "
8162                                  "in range [-4, 4)");
8163     }
8164     {
8165       if (trt_mode_ != TrtTestMode::kExplicitBatch) {
8166         // Axis is batch dimension, should fail.
8167         Reset();
8168         NodeDef node_def = get_unpack_nodedef(tf_type_, /*num=*/1, /*axis=*/0);
8169         AddTestTensor("value", {1, 2, 3});
8170         RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
8171                                    "TensorRT does not allow manipulation of "
8172                                    "the batch dimension");
8173       }
8174     }
8175     {
8176       // Dim size does not match num, should fail.
8177       Reset();
8178       NodeDef node_def = get_unpack_nodedef(tf_type_, /*num=*/5, /*axis=*/2);
8179       AddTestTensor("value", {1, 1, 6});
8180       RunValidationAndConversion(
8181           node_def, error::INVALID_ARGUMENT,
8182           "Dimension 2 has size 6 which is not equal to num of 5");
8183     }
8184     {
8185       // Output would be TF scalar, should fail.
8186       Reset();
8187       NodeDef node_def = get_unpack_nodedef(tf_type_, /*num=*/1, /*axis=*/0);
8188       AddTestTensor(
8189           "value", {}, tf_type_, {}, {},
8190           trt_mode_ == TrtTestMode::kImplicitBatch
8191               ? errors::InvalidArgument(
8192                     "removing first dim requires explicit batch dimension")
8193               : Status::OK());
8194       if (trt_mode_ == TrtTestMode::kImplicitBatch) {
8195         RunValidationAndConversion(
8196             node_def, error::INTERNAL,
8197             "Failed to convert at least one input to a TRT_TensorOrWeights: "
8198             "Scalar input tensor is not supported since the first dimension is "
8199             "treated as batch dimension by TRT");
8200       } else {
8201         RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
8202                                    "Input \"value\" for Unpack must be rank 2 "
8203                                    "or greater");
8204       }
8205     }
8206   }
8207 
8208   const std::vector<float> common_input = CreateVectorIota<float>(6);
8209 
8210   Status run_status =
8211       trt_mode_ == TrtTestMode::kDynamicShape
8212           ? errors::InvalidArgument(
8213                 "The argument `strided_slice_spec` is "
8214                 "`std::nullopt` with `dynamic_input_size_indices` non empty.")
8215           : Status::OK();
8216 
8217   std::vector<UnpackTestParams> params = {
8218       {/*input_shape=*/{1, 1, 2, 1, 3, 1},
8219        /*input_value=*/common_input,
8220        /*axis=*/4,
8221        /*num=*/3,
8222        /*expected_output_dims=*/{1, 1, 2, 1, 1},
8223        /*expected_outputs=*/{{0, 3}, {1, 4}, {2, 5}},
8224        /*run_status=*/run_status},
8225       {/*input_shape=*/{1, 1, 2, 1, 3},
8226        /*input_value=*/common_input,
8227        /*axis=*/4,
8228        /*num=*/3,
8229        /*expected_output_dims=*/{1, 1, 2, 1},
8230        /*expected_outputs=*/{{0, 3}, {1, 4}, {2, 5}},
8231        /*run_status=*/run_status},
8232       {/*input_shape=*/{1, 1, 2, 3},
8233        /*input_value=*/common_input,
8234        /*axis=*/1,
8235        /*num=*/1,
8236        /*expected_output_dims=*/{1, 2, 3},
8237        /*expected_outputs=*/{CreateVectorIota<float>(6)},
8238        /*run_status=*/run_status},
8239       {/*input_shape=*/{1, 6, 1},
8240        /*input_value=*/common_input,
8241        /*axis=*/-2,
8242        /*num=*/6,
8243        /*expected_output_dims=*/{1, 1},
8244        /*expected_outputs=*/{{0}, {1}, {2}, {3}, {4}, {5}},
8245        /*run_status=*/run_status},
8246       {/*input_shape=*/{1, 6},
8247        /*input_value=*/common_input,
8248        /*axis=*/1,
8249        /*num=*/6,
8250        /*expected_output_dims=*/{1},
8251        /*expected_outputs=*/{{0}, {1}, {2}, {3}, {4}, {5}},
8252        /*run_status=*/run_status},
8253   };
8254   for (auto p : params) {
8255     TestConvertUnpack(this, p);
8256   }
8257 }
8258 
8259 // Get the NodeDef for Pack.
GetPackNodeDef(DataType dtype,int num_inputs,int axis)8260 NodeDef GetPackNodeDef(DataType dtype, int num_inputs, int axis) {
8261   Scope s = Scope::NewRootScope();
8262   std::vector<Input> values;
8263   values.reserve(num_inputs);
8264   for (int i = 0; i < num_inputs; ++i) {
8265     const string input_name = StrCat("values_", i);
8266     values.push_back(ops::Placeholder(s.WithOpName(input_name), dtype));
8267   }
8268   // Pack op is renamed to Stack in APIs.
8269   auto pack =
8270       ops::Stack(s.WithOpName("my_pack"), absl::Span<const Input>(values),
8271                  ops::Stack::Axis(axis));
8272   return pack.operation.node()->def();
8273 }
8274 
TEST_P(OpConverter_FP32_FP16_INT32_Test,ConvertPack)8275 TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertPack) {
8276   struct TestParams {
8277     std::vector<std::vector<int>> input_shapes;
8278     std::vector<std::vector<int>> partial_input_shapes;
8279     std::vector<std::vector<float>> input_values;
8280     int axis;
8281     std::vector<int> expected_output_dims;
8282     std::vector<float> expected_output;
8283     Status conversion_status;
8284     Status runtime_status;
8285     bool input_1_is_weight;
8286   };
8287 
8288   const std::vector<std::vector<float>> common_input{
8289       CreateVectorIota<float>(6),
8290       CreateVectorIota<float>(6, /*start_value=*/6)};
8291   std::vector<TestParams> params = {
8292       // Second input is weight, should fail in implicit batch mode
8293       {/*input_shapes=*/{{1, 2, 3}, {1, 2, 3}},
8294        /*partial_input_shapes=*/{{}, {}},
8295        /*input_values=*/common_input,
8296        /*axis=*/1,
8297        /*expected_output_dims=*/{1, 2, 2, 3},
8298        /*expected_output=*/CreateVectorIota<float>(12),
8299        trt_mode_ == TrtTestMode::kImplicitBatch
8300            ? Status{error::UNIMPLEMENTED,
8301                     "The input \"values_1\" for Pack must be a tensor"}
8302            : Status::OK(),
8303        /*runtime_status*/ Status::OK(),
8304        /*weight_input*/ true},
8305       // Axis is out of bounds, should fail.
8306       {
8307           /*input_shapes=*/{{1, 2, 3}, {1, 2, 3}},
8308           /*partial_input_shapes=*/{{}, {}},
8309           /*input_values=*/common_input,
8310           /*axis=*/-5,
8311           /*expected_output_dims=*/{},
8312           /*expected_output=*/{},
8313           Status{error::INVALID_ARGUMENT,
8314                  "Axis value of -5 is out of bounds, must be in"
8315                  " range [-4, 4)"},
8316       },
8317       // Axis is batch dimension, should fail in implicit batch mode.
8318       {/*input_shapes=*/{{1, 2, 3}, {1, 2, 3}},
8319        /*partial_input_shapes=*/{{}, {}},
8320        /*input_values=*/common_input,
8321        /*axis=*/-4,
8322        /*expected_output_dims=*/{2, 1, 2, 3},
8323        /*expected_output=*/CreateVectorIota<float>(12),
8324        trt_mode_ == TrtTestMode::kImplicitBatch
8325            ? Status{error::UNIMPLEMENTED,
8326                     "TensorRT does not allow manipulation of the batch "
8327                     "dimension"}
8328            : Status::OK()},
8329       // Inconsistent rank, should fail.
8330       {
8331           /*input_shapes=*/{{1, 2, 3}, {1, 6}},
8332           /*partial_input_shapes=*/{{}, {}},
8333           /*input_values=*/common_input,
8334           /*axis=*/1,
8335           /*expected_output_dims=*/{},
8336           /*expected_output=*/{},
8337           Status{error::INVALID_ARGUMENT,
8338                  "Received inputs with inconsistent rank"},
8339       },
8340       {
8341           /*input_shapes=*/{{1, 2, 3}, {1, 2, 3}},
8342           /*partial_input_shapes=*/{{}, {}},
8343           /*input_values=*/common_input,
8344           /*axis=*/1,
8345           /*expected_output_dims=*/{1, 2, 2, 3},
8346           /*expected_output=*/CreateVectorIota<float>(12),
8347       },
8348       {
8349           /*input_shapes=*/{{1, 2, 3}, {1, 2, 3}},
8350           /*partial_input_shapes=*/{{}, {}},
8351           /*input_values=*/common_input,
8352           /*axis=*/2,
8353           /*expected_output_dims=*/{1, 2, 2, 3},
8354           /*expected_output=*/
8355           {0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11},
8356       },
8357       {
8358           /*input_shapes=*/{{1, 2, 3}, {1, 2, 3}},
8359           /*partial_input_shapes=*/{{}, {}},
8360           /*input_values=*/common_input,
8361           /*axis=*/3,
8362           /*expected_output_dims=*/{1, 2, 3, 2},
8363           /*expected_output=*/
8364           {0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11},
8365       },
8366       {
8367           /*input_shapes=*/{{1, 2, 3}},
8368           /*partial_input_shapes=*/{{}},
8369           /*input_values=*/{CreateVectorIota<float>(6)},
8370           /*axis=*/1,
8371           /*expected_output_dims=*/{1, 1, 2, 3},
8372           /*expected_output=*/CreateVectorIota<float>(6),
8373       },
8374       {
8375           /*input_shapes=*/{{1, 2, 3}},
8376           /*partial_input_shapes=*/{{}},
8377           /*input_values=*/{CreateVectorIota<float>(6)},
8378           /*axis=*/2,
8379           /*expected_output_dims=*/{1, 2, 1, 3},
8380           /*expected_output=*/CreateVectorIota<float>(6),
8381       },
8382   };
8383   // Inputs have inconsistent shapes, should fail.
8384   if (trt_mode_ != TrtTestMode::kDynamicShape) {
8385     params.push_back(
8386         TestParams{/*input_shapes=*/{{1, 2, 3}, {1, 3, 2}},
8387                    /*partial_input_shapes=*/{{}, {}},
8388                    /*input_values=*/common_input,
8389                    /*axis=*/1,
8390                    /*expected_output_dims=*/{},
8391                    /*expected_output=*/CreateVectorIota<float>(12),
8392                    Status{error::INVALID_ARGUMENT,
8393                           "Received inputs with inconsistent shape"}});
8394   } else {
8395     // In dynamic shape mode we cannot catch inconsistent shapes at conversion
8396     // time, only during runtime. But TensorRT does not raise a proper runtime
8397     // error, instead it aborts the program with the following message:
8398     //  Assertion failed: t->start.d[i] + t->extent.d[i] <= r.dims.d[i]
8399     // ../builder/cudnnBuilderGraph.cpp:862
8400     // Aborting...
8401     // TODO(tfeher) Add dynamic shapes test once TRT handles shape error
8402     // decently
8403   }
8404   if (trt_mode_ == TrtTestMode::kDynamicShape) {
8405     // Test with mixed dynamic / static shape input tensors
8406     params.push_back(
8407         TestParams{/*input_shapes=*/{{1, 2, 3}, {1, 2, 3}},
8408                    /*partial_input_shapes=*/{{-1, -1, -1}, {1, 2, 3}},
8409                    /*input_values=*/common_input,
8410                    /*axis=*/2,
8411                    /*expected_output_dims=*/{1, 2, 2, 3},
8412                    /*expected_output=*/
8413                    {0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11}});
8414   }
8415   for (auto p : params) {
8416     Reset();
8417     const int num_inputs = p.input_shapes.size();
8418     EXPECT_EQ(num_inputs, p.input_values.size());
8419 
8420     NodeDef node_def = GetPackNodeDef(tf_type_, num_inputs, p.axis);
8421     // Create inputs.
8422     for (int j = 0; j < num_inputs; ++j) {
8423       if (j == 1 && p.input_1_is_weight) {
8424         AddTestWeights(StrCat("values_", j), p.input_shapes[j],
8425                        p.input_values[j], tf_type_);
8426       } else {
8427         AddTestTensor(StrCat("values_", j), p.input_shapes[j], tf_type_,
8428                       p.input_values[j], p.partial_input_shapes[j]);
8429       }
8430     }
8431     TestOpConverter("my_pack", node_def, p.expected_output_dims,
8432                     p.conversion_status, p.runtime_status,
8433                     ElementsAreArray(p.expected_output));
8434   }
8435 }
8436 
8437 // Get the NodeDef for ArgMin or ArgMax.
8438 template <typename OpType>
GetArgMinMaxNodeDef(DataType input_dtype,DataType output_dtype)8439 NodeDef GetArgMinMaxNodeDef(DataType input_dtype, DataType output_dtype) {
8440   Scope s = Scope::NewRootScope();
8441   auto input = ops::Placeholder(s.WithOpName("input"), input_dtype);
8442   auto dimension = ops::Placeholder(s.WithOpName("dimension"), DT_INT32);
8443   auto attrs = OpType::OutputType(output_dtype);
8444   auto arg = OpType(s.WithOpName("my_arg"), input, dimension, attrs);
8445   return arg.operation.node()->def();
8446 }
8447 
8448 struct ArgMinMaxTestParams {
8449   std::vector<int> input_shape;
8450   std::vector<float> input_value;
8451   int axis;
8452   std::vector<int> expected_output_dims;
8453   std::vector<int> expected_argmax_output;
8454   std::vector<int> expected_argmin_output;
8455   Status status;
8456 };
8457 
8458 template <typename OpType>
TestConvertArgMinMax(ParameterizedOpConverterTestBase * test,DataType _tf_type,ArgMinMaxTestParams & p)8459 void TestConvertArgMinMax(ParameterizedOpConverterTestBase* test,
8460                           DataType _tf_type, ArgMinMaxTestParams& p) {
8461   test->Reset();
8462 
8463   NodeDef node_def = GetArgMinMaxNodeDef<OpType>(_tf_type,
8464                                                  /*output_dtype=*/DT_INT32);
8465 
8466   std::vector<int> expected_out;
8467   if (node_def.op() == "ArgMax") {
8468     expected_out = p.expected_argmax_output;
8469   } else if (node_def.op() == "ArgMin") {
8470     expected_out = p.expected_argmin_output;
8471   } else {
8472     ASSERT_TRUE(false);
8473   }
8474 
8475   test->AddTestTensor("input", p.input_shape, _tf_type, p.input_value);
8476   test->AddTestWeights("dimension", {1}, {p.axis}, DT_INT32);
8477 
8478   test->TestOpConverter("my_arg", node_def, p.expected_output_dims,
8479                         /*expected_conversion_status=*/p.status,
8480                         /*expected_runtime_status=*/Status::OK(),
8481                         /*matcher=*/ElementsAreArray(expected_out), {DT_INT32});
8482 }
8483 
TEST_P(OpConverter_FP32_FP16_Test,ConvertArgMinMax)8484 TEST_P(OpConverter_FP32_FP16_Test, ConvertArgMinMax) {
8485   {
8486     // Dimension is a tensor, should fail.
8487     Reset();
8488     NodeDef node_def =
8489         GetArgMinMaxNodeDef<ops::ArgMax>(tf_type_,
8490                                          /*output_dtype=*/DT_INT32);
8491     AddTestTensor("input", {1, 2, 3});
8492     AddTestTensor("dimension", {1});
8493     RunValidationAndConversion(
8494         node_def, error::UNIMPLEMENTED,
8495         "The input \"dimension\" for ArgMax must be a constant");
8496   }
8497   {
8498     // Output type is INT64, should fail.
8499     Reset();
8500     NodeDef node_def =
8501         GetArgMinMaxNodeDef<ops::ArgMax>(tf_type_,
8502                                          /*output_dtype=*/DT_INT64);
8503     AddTestTensor("input", {1, 2, 3});
8504     AddTestWeights("dimension", {1}, {3}, DT_INT32);
8505     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
8506                                "Output type int64 is not supported");
8507   }
8508 
8509   const std::vector<float> common_input = CreateVectorIota<float>(6);
8510   std::vector<ArgMinMaxTestParams> params = {
8511       {/*input_shape=*/{2, 3},
8512        /*input_value=*/common_input,
8513        /*axis=*/0,
8514        /*expected_output_dims=*/{3},
8515        /*expected_argmax_output=*/{1, 1, 1},
8516        /*expected_argmin_output=*/{0, 0, 0},
8517        trt_mode_ == TrtTestMode::kImplicitBatch
8518            ? errors::Unimplemented("TensorRT does not allow manipulation of "
8519                                    "the batch dimension")
8520            : Status::OK()},
8521       {
8522           /*input_shape=*/{1, 6},
8523           /*input_value=*/common_input,
8524           /*axis=*/1,
8525           /*expected_output_dims=*/{1},
8526           /*expected_argmax_output=*/{5},
8527           /*expected_argmin_output=*/{0},
8528       },
8529       {
8530           /*input_shape=*/{1, 10},
8531           /*input_value=*/
8532           {-5.0f, 3.0f, 5.0f, 1.0f, 6.0f, -9.0f, 7.0f, 1.0f, 0.0f, -1.0f},
8533           /*axis=*/-1,
8534           /*expected_output_dims=*/{1},
8535           /*expected_argmax_output=*/{6},
8536           /*expected_argmin_output=*/{5},
8537       },
8538       {
8539           /*input_shape=*/{1, 2, 3},
8540           /*input_value=*/common_input,
8541           /*axis=*/2,
8542           /*expected_output_dims=*/{1, 2},
8543           /*expected_argmax_output=*/{2, 2},
8544           /*expected_argmin_output=*/{0, 0},
8545       },
8546       {
8547           /*input_shape=*/{1, 2, 3},
8548           /*input_value=*/common_input,
8549           /*axis=*/-2,
8550           /*expected_output_dims=*/{1, 3},
8551           /*expected_argmax_output=*/{1, 1, 1},
8552           /*expected_argmin_output=*/{0, 0, 0},
8553       },
8554       {
8555           /*input_shape=*/{1, 2, 1, 3},
8556           /*input_value=*/common_input,
8557           /*axis=*/3,
8558           /*expected_output_dims=*/{1, 2, 1},
8559           /*expected_argmax_output=*/{2, 2},
8560           /*expected_argmin_output=*/{0, 0},
8561       },
8562       {
8563           /*input_shape=*/{1, 2, 1, 3},
8564           /*input_value=*/common_input,
8565           /*axis=*/-3,
8566           /*expected_output_dims=*/{1, 1, 3},
8567           /*expected_argmax_output=*/{1, 1, 1},
8568           /*expected_argmin_output=*/{0, 0, 0},
8569       },
8570       {/*input_shape=*/{1, 2, 1, 1, 3},
8571        /*input_value=*/common_input,
8572        /*axis=*/4,
8573        /*expected_output_dims=*/{1, 2, 1, 1},
8574        /*expected_argmax_output=*/{2, 2},
8575        /*expected_argmin_output=*/{0, 0},
8576 #if !IS_TRT_VERSION_GE(7, 0, 0, 11)
8577        errors::Unimplemented("op is not able to support tensors with 4+"
8578                              " dimensions (excluding batch size)")
8579 #else
8580        Status::OK()
8581 #endif
8582       },
8583       {/*input_shape=*/{1, 2, 1, 1, 3},
8584        /*input_value=*/common_input,
8585        /*axis=*/-4,
8586        /*expected_output_dims=*/{1, 1, 1, 3},
8587        /*expected_argmax_output=*/{1, 1, 1},
8588        /*expected_argmin_output=*/{0, 0, 0},
8589 #if !IS_TRT_VERSION_GE(7, 0, 0, 11)
8590        errors::Unimplemented("op is not able to support tensors with 4+"
8591                              " dimensions (excluding batch size)")
8592 #else
8593        Status::OK()
8594 #endif
8595       },
8596   };
8597 
8598   for (auto p : params) {
8599     TestConvertArgMinMax<ops::ArgMin>(this, tf_type_, p);
8600     TestConvertArgMinMax<ops::ArgMax>(this, tf_type_, p);
8601   }
8602 }
8603 
8604 // Get the NodeDef for DepthToSpace or SpaceToSpace.
8605 template <typename OpType>
GetDepthSpaceShuffleNodeDef(DataType dtype,int block_size,string data_format)8606 NodeDef GetDepthSpaceShuffleNodeDef(DataType dtype, int block_size,
8607                                     string data_format) {
8608   Scope s = Scope::NewRootScope();
8609   auto input = ops::Placeholder(s.WithOpName("input"), dtype);
8610   auto attrs = OpType::DataFormat(data_format);
8611   auto shuffle = OpType(s.WithOpName("my_shuffle"), input, block_size, attrs);
8612   return shuffle.operation.node()->def();
8613 }
8614 
8615 struct DepthSpaceShuffleTestParams {
8616   std::vector<int> input_dims;
8617   std::vector<int> input_value;
8618   int block_size;
8619   string data_format;
8620   std::vector<int> expected_output_dims;
8621   std::vector<int> expected_output;
8622 };
8623 
8624 template <typename OpType>
TestConvertDepthSpaceShuffle(ParameterizedOpConverterTestBase * test,const std::vector<DepthSpaceShuffleTestParams> & params)8625 void TestConvertDepthSpaceShuffle(
8626     ParameterizedOpConverterTestBase* test,
8627     const std::vector<DepthSpaceShuffleTestParams>& params) {
8628   Status status = Status::OK();
8629 
8630   {
8631     // Input is a weight, should fail.
8632     test->Reset();
8633     NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(
8634         test->get_tf_type(), 2, "NCHW");
8635     test->AddTestWeights<float>("input", {1, 4, 1, 1}, {1, 2, 3, 4});
8636     test->RunValidationAndConversion(
8637         node_def, error::UNIMPLEMENTED,
8638         StrCat("The input \"input\" for ", node_def.op(), " must be a tensor"));
8639   }
8640   {
8641     // Input rank != 4
8642     test->Reset();
8643     NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(
8644         test->get_tf_type(), 2, "NCHW");
8645     test->AddTestTensor("input", {1, 16, 32});
8646     test->RunValidationAndConversion(
8647         node_def, error::INVALID_ARGUMENT,
8648         StrCat("The input to ", node_def.op(), " must be rank 4"));
8649   }
8650   {
8651     // Unsupported format, should fail.
8652     test->Reset();
8653     NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(
8654         test->get_tf_type(), 2, "NCHW_VECT_C");
8655     test->AddTestTensor("input", {1, 16, 32, 32});
8656     test->RunValidationAndConversion(
8657         node_def, error::UNIMPLEMENTED,
8658         "Data format NCHW_VECT_C is not supported");
8659   }
8660   if (test->get_trt_mode() != TrtTestMode::kDynamicShape) {
8661     // In dynamic shape mode, we cannot check input dimension values at
8662     // conversion time therefore we cannot confirm block_size vs input dim
8663     // consistency. We rely on the user to provide a valid TF graph. Otherwise
8664     // TRT will fail with a runtime error.
8665     if (std::is_same<OpType, ops::DepthToSpace>::value) {
8666       // Channels not divisible by block_size, should fail.
8667       test->Reset();
8668       NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(
8669           test->get_tf_type(), 3, "NCHW");
8670       test->AddTestTensor("input", {1, 16, 32, 32});
8671       test->RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
8672                                        "Number of channels must be divisible by"
8673                                        " block_size*block_size");
8674     } else {
8675       {  // Width not divisible by block_size, should fail.
8676         test->Reset();
8677         NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(
8678             test->get_tf_type(), 3, "NCHW");
8679         test->AddTestTensor("input", {1, 16, 9, 32});
8680         test->RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
8681                                          "Width and height must be divisible by"
8682                                          " block_size");
8683       }
8684       {
8685         // Height not divisible by block_size, should fail.
8686         test->Reset();
8687         NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(
8688             test->get_tf_type(), 3, "NCHW");
8689         test->AddTestTensor("input", {1, 16, 32, 9});
8690         test->RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
8691                                          "Width and height must be divisible by"
8692                                          " block_size");
8693       }
8694     }
8695   }
8696 
8697   for (auto p : params) {
8698     test->Reset();
8699     NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>(
8700         test->get_tf_type(), p.block_size, p.data_format);
8701     test->AddTestTensor("input", p.input_dims, p.input_value);
8702     test->TestOpConverter("my_shuffle", node_def, p.expected_output_dims,
8703                           status, Status::OK(),
8704                           ElementsAreArray(p.expected_output));
8705   }
8706 }
8707 
// Tests DepthToSpace conversion for NCHW and NHWC layouts and block sizes 2
// and 4.
TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertDepthToSpace) {
  const std::vector<int> common_input = CreateVectorIota<int>(16);
  std::vector<DepthSpaceShuffleTestParams> params = {
      {
          /*input_dims=*/{1, 4, 2, 2},
          /*input_value=*/common_input,
          /*block_size=*/2,
          /*data_format=*/"NCHW",
          /*expected_output_dims=*/{1, 1, 4, 4},
          /*expected_output=*/
          {0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15},
      },
      {
          /*input_dims=*/{1, 2, 2, 4},
          /*input_value=*/common_input,
          /*block_size=*/2,
          /*data_format=*/"NHWC",
          /*expected_output_dims=*/{1, 4, 4, 1},
          /*expected_output=*/
          {0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15},
      },
      {
          /*input_dims=*/{1, 16, 1, 1},
          /*input_value=*/common_input,
          /*block_size=*/4,
          /*data_format=*/"NCHW",
          /*expected_output_dims=*/{1, 1, 4, 4},
          /*expected_output=*/CreateVectorIota<int>(16),
      },
      {
          /*input_dims=*/{1, 2, 2, 8},
          /*input_value=*/CreateVectorIota<int>(32),
          /*block_size=*/2,
          /*data_format=*/"NHWC",
          /*expected_output_dims=*/{1, 4, 4, 2},
          /*expected_output=*/{0,  1,  2,  3,  8,  9,  10, 11, 4,  5,  6,
                               7,  12, 13, 14, 15, 16, 17, 18, 19, 24, 25,
                               26, 27, 20, 21, 22, 23, 28, 29, 30, 31},
      }};

  TestConvertDepthSpaceShuffle<ops::DepthToSpace>(this, params);
}
8750 
// Tests SpaceToDepth conversion for NCHW and NHWC layouts and block sizes 2
// and 4.
TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertSpaceToDepth) {
  const std::vector<int> common_input = CreateVectorIota<int>(16);
  std::vector<DepthSpaceShuffleTestParams> params = {
      {
          /*input_dims=*/{1, 1, 4, 4},
          /*input_value=*/common_input,
          /*block_size=*/2,
          /*data_format=*/"NCHW",
          /*expected_output_dims=*/{1, 4, 2, 2},
          /*expected_output=*/
          {0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15},
      },
      {
          /*input_dims=*/{1, 4, 4, 1},
          /*input_value=*/common_input,
          /*block_size=*/2,
          /*data_format=*/"NHWC",
          /*expected_output_dims=*/{1, 2, 2, 4},
          /*expected_output=*/
          {0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15},
      },
      {
          /*input_dims=*/{1, 1, 4, 4},
          /*input_value=*/common_input,
          /*block_size=*/4,
          /*data_format=*/"NCHW",
          /*expected_output_dims=*/{1, 16, 1, 1},
          /*expected_output=*/CreateVectorIota<int>(16),
      },
      {
          /*input_dims=*/{1, 4, 4, 2},
          /*input_value=*/CreateVectorIota<int>(32),
          /*block_size=*/2,
          /*data_format=*/"NHWC",
          /*expected_output_dims=*/{1, 2, 2, 8},
          /*expected_output=*/{0,  1,  2,  3,  8,  9,  10, 11, 4,  5,  6,
                               7,  12, 13, 14, 15, 16, 17, 18, 19, 24, 25,
                               26, 27, 20, 21, 22, 23, 28, 29, 30, 31},
      },
  };
  TestConvertDepthSpaceShuffle<ops::SpaceToDepth>(this, params);
}
8793 
TEST_P(OpConverter_FP32_FP16_Test,ConvertClipByValue)8794 TEST_P(OpConverter_FP32_FP16_Test, ConvertClipByValue) {
8795   Scope s = Scope::NewRootScope();
8796   auto t = ops::Placeholder(s.WithOpName("t"), tf_type_);
8797   auto clip_value_min =
8798       ops::Placeholder(s.WithOpName("clip_value_min"), tf_type_);
8799   auto clip_value_max =
8800       ops::Placeholder(s.WithOpName("clip_value_max"), tf_type_);
8801   auto clip = ops::ClipByValue(s.WithOpName("my_clip"), t, clip_value_min,
8802                                clip_value_max);
8803   const NodeDef& node_def = clip.operation.node()->def();
8804 
8805   nvinfer1::DataType trt_type_;
8806   TF_ASSERT_OK(TfTypeToTrtType(tf_type_, &trt_type_));
8807 
8808   {
8809     // Input is a weight, should fail.
8810     Reset();
8811     AddTestWeights("t", {1, 2, 3}, {1, 2, 3, 4, 5, 6}, tf_type_);
8812     AddTestWeights("clip_value_min", {1}, {1}, tf_type_);
8813     AddTestWeights("clip_value_max", {1}, {5}, tf_type_);
8814     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
8815                                "The input \"t\" for ClipByValue must be a "
8816                                "tensor");
8817   }
8818   {
8819     // Clip min is a tensor, should fail.
8820     Reset();
8821     AddTestTensor("t", {1, 2, 3});
8822     AddTestTensor("clip_value_min", {1});
8823     AddTestWeights("clip_value_max", {1}, {1}, tf_type_);
8824     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
8825                                "The input \"clip_value_min\" for ClipByValue "
8826                                "must be a constant");
8827   }
8828   {
8829     // Clip max is a tensor, should fail.
8830     Reset();
8831     AddTestTensor("t", {1, 2, 3});
8832     AddTestWeights("clip_value_min", {1}, {1}, tf_type_);
8833     AddTestTensor("clip_value_max", {1});
8834     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
8835                                "The input \"clip_value_max\" for ClipByValue "
8836                                "must be a constant");
8837   }
8838 
8839   struct TestParams {
8840     std::vector<int> dims;
8841     int clip_value_min;
8842     int clip_value_max;
8843     std::vector<float> expected_output;
8844   };
8845 
8846   const std::vector<float> common_input = CreateVectorIota<float>(6);
8847 
8848   std::vector<TestParams> params = {{
8849                                         /*dims=*/{6},
8850                                         /*clip_value_min=*/2,
8851                                         /*clip_value_max=*/4,
8852                                         /*expected_output=*/{2, 2, 2, 3, 4, 4},
8853                                     },
8854                                     {
8855                                         /*dims=*/{1, 6},
8856                                         /*clip_value_min=*/2,
8857                                         /*clip_value_max=*/4,
8858                                         /*expected_output=*/{2, 2, 2, 3, 4, 4},
8859                                     },
8860                                     {
8861                                         /*dims=*/{1, 2, 3},
8862                                         /*clip_value_min=*/2,
8863                                         /*clip_value_max=*/4,
8864                                         /*expected_output=*/{2, 2, 2, 3, 4, 4},
8865                                     },
8866                                     {
8867                                         /*dims=*/{1, 2, 3, 1},
8868                                         /*clip_value_min=*/2,
8869                                         /*clip_value_max=*/4,
8870                                         /*expected_output=*/{2, 2, 2, 3, 4, 4},
8871                                     },
8872                                     {
8873                                         /*dims=*/{1, 1, 3, 1, 2},
8874                                         /*clip_value_min=*/2,
8875                                         /*clip_value_max=*/4,
8876                                         /*expected_output=*/{2, 2, 2, 3, 4, 4},
8877                                     },
8878                                     {
8879                                         /*dims=*/{1, 1, 3, 1, 2, 1},
8880                                         /*clip_value_min=*/2,
8881                                         /*clip_value_max=*/4,
8882                                         /*expected_output=*/{2, 2, 2, 3, 4, 4},
8883                                     },
8884                                     {
8885                                         /*dims=*/{2, 1, 3},
8886                                         /*clip_value_min=*/-1,
8887                                         /*clip_value_max=*/8,
8888                                         /*expected_output=*/common_input,
8889                                     }};
8890 
8891   for (auto p : params) {
8892     Reset();
8893 
8894     AddTestTensor("t", p.dims, tf_type_, common_input);
8895     AddTestWeights("clip_value_min", {1}, {p.clip_value_min}, tf_type_);
8896     AddTestWeights("clip_value_max", {1}, {p.clip_value_max}, tf_type_);
8897 
8898     TestOpConverter("my_clip", node_def, p.dims,
8899                     /*expected_conversion_status=*/Status::OK(),
8900                     /*expected_runtime_status=*/Status::OK(),
8901                     /*matcher=*/ElementsAreArray(p.expected_output));
8902   }
8903 }
8904 
8905 // Get the NodeDef for SquaredDifference.
GetSquaredDifferenceNodeDef(DataType dtype)8906 NodeDef GetSquaredDifferenceNodeDef(DataType dtype) {
8907   Scope s = Scope::NewRootScope();
8908   auto x = ops::Placeholder(s.WithOpName("x"), dtype);
8909   auto y = ops::Placeholder(s.WithOpName("y"), dtype);
8910   auto squared_diff =
8911       ops::SquaredDifference(s.WithOpName("my_squared_diff"), x, y);
8912   return squared_diff.operation.node()->def();
8913 }
8914 
TEST_P(OpConverter_FP32_FP16_Test,ConvertSquaredDifference)8915 TEST_P(OpConverter_FP32_FP16_Test, ConvertSquaredDifference) {
8916   {
8917     // Input is a weight, should fail.
8918     Reset();
8919     NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_);
8920     AddTestWeights<float>("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
8921     AddTestTensor("y", {1, 1, 2, 3});
8922     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
8923                                "The input \"x\" for SquaredDifference must be "
8924                                "a tensor");
8925   }
8926 
8927   struct TestParams {
8928     std::vector<int> dims_x;
8929     std::vector<int> dims_y;
8930     std::vector<float> value_x;
8931     std::vector<float> value_y;
8932     std::vector<int> expected_output_dims;
8933     std::vector<float> expected_output;
8934     Status status;
8935     Status runtime_status;
8936   };
8937 
8938   const std::vector<float> common_input = CreateVectorIota<float>(6);
8939   std::vector<TestParams> params = {
8940       {/*dims_x=*/{1, 2, 3},
8941        /*dims_y=*/{1, 7, 5},
8942        /*value_x=*/common_input,
8943        /*value_y=*/std::vector<float>(7 * 5, 0),
8944        /*expected_output_dims=*/{1, 1, 2, 3},
8945        /*expected_output=*/common_input,
8946        trt_mode_ == TrtTestMode::kDynamicShape
8947            ? Status::OK()
8948            : errors::InvalidArgument("Infeasible broadcast scheme"),
8949        errors::Internal(
8950            "Binding index out of range. This can happen if profile is not set, "
8951            "or the network is invalid for the current profile.")},
8952       {
8953           /*dims_x=*/{1, 1, 2, 3},
8954           /*dims_y=*/{1, 1, 2, 3},
8955           /*value_x=*/common_input,
8956           /*value_y=*/{0, -1, 3, 0, 10, -7},
8957           /*expected_output_dims=*/{1, 1, 2, 3},
8958           /*expected_output=*/{0, 4, 1, 9, 36, 144},
8959       },
8960       {
8961           /*dims_x=*/{1, 1, 2, 3},
8962           /*dims_y=*/{1, 1, 1, 3},
8963           /*value_x=*/common_input,
8964           /*value_y=*/{0, 1, 2},
8965           /*expected_output_dims=*/{1, 1, 2, 3},
8966           /*expected_output=*/{0, 0, 0, 9, 9, 9},
8967       },
8968   };
8969 
8970   for (auto p : params) {
8971     Reset();
8972     NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_);
8973     AddTestTensor("x", p.dims_x, p.value_x);
8974     AddTestTensor("y", p.dims_y, p.value_y);
8975     TestOpConverter("my_squared_diff", node_def, p.expected_output_dims,
8976                     p.status, p.runtime_status,
8977                     ElementsAreArray(p.expected_output));
8978   }
8979 }
8980 
8981 template <typename OpType>
MakeResizeNodeDef(DataType dtype,bool align_corners)8982 NodeDef MakeResizeNodeDef(DataType dtype, bool align_corners) {
8983   Scope s = Scope::NewRootScope();
8984   auto input = ops::Placeholder(s.WithOpName("input"), dtype);
8985   auto size = ops::Placeholder(s.WithOpName("size"), DT_INT32);
8986   auto attrs = typename OpType::Attrs().AlignCorners(align_corners);
8987   auto resize = OpType(s.WithOpName("my_resize"), input, size, attrs);
8988   return resize.operation.node()->def();
8989 }
8990 
8991 struct ResizeTestParams {
8992   std::vector<int> input_dims;
8993   std::vector<int> output_resize_dims;
8994   std::vector<float> input_value;
8995   bool size_as_tensor;
8996   bool align_corners;
8997   std::vector<int> expected_output_dims;
8998   std::vector<float> expected_nearest_output_values;
8999   std::vector<float> expected_bilinear_output_values;
9000   Status status;
9001 };
9002 
9003 template <typename OpType>
TestConvertResize(ParameterizedOpConverterTestBase * test,ResizeTestParams & p)9004 void TestConvertResize(ParameterizedOpConverterTestBase* test,
9005                        ResizeTestParams& p) {
9006   test->Reset();
9007   // Create resize node.
9008   NodeDef node_def =
9009       MakeResizeNodeDef<OpType>(test->get_tf_type(), p.align_corners);
9010 
9011   test->AddTestTensor("input", p.input_dims, test->get_tf_type(),
9012                       p.input_value);
9013   // Create output size.
9014   if (p.size_as_tensor) {
9015     std::vector<int32> size_dims{2};
9016     std::vector<int32> size_values{p.output_resize_dims};
9017     test->AddTestTensor("size", size_dims, DT_INT32, size_values, size_dims);
9018   } else {
9019     test->AddTestWeights("size", {2}, p.output_resize_dims, DT_INT32);
9020   }
9021 
9022   std::vector<float> expected_out;
9023 
9024   if (node_def.op() == "ResizeBilinear") {
9025     expected_out = p.expected_bilinear_output_values;
9026   } else if (node_def.op() == "ResizeNearestNeighbor") {
9027     expected_out = p.expected_nearest_output_values;
9028   } else {
9029     ASSERT_TRUE(false);
9030   }
9031 
9032   test->TestOpConverter("my_resize", node_def, p.expected_output_dims,
9033                         /*expected_conversion_status=*/p.status,
9034                         /*expected_runtime_status=*/p.status,
9035                         /*matcher=*/ElementsAreArray(expected_out),
9036                         /*out_tf_types=*/{DT_FLOAT});
9037 }
9038 
TEST_P(OpConverter_FP32_FP16_Test,ConvertResize)9039 TEST_P(OpConverter_FP32_FP16_Test, ConvertResize) {
9040   {
9041     // First input is weight, should fail.
9042     Reset();
9043     NodeDef node_def = MakeResizeNodeDef<ops::ResizeBilinear>(tf_type_,
9044                                                               /*align_corners=*/
9045                                                               true);
9046     AddTestWeights<float>("input", {1, 2}, {1, 2});
9047     AddTestWeights<int>("size", {1, 2}, {1, 2});
9048     RunValidationAndConversion(
9049         node_def, error::UNIMPLEMENTED,
9050         "The input \"input\" for ResizeBilinear must be a "
9051         "tensor");
9052   }
9053 
9054   std::vector<ResizeTestParams> params{
9055       {/*input_dims=*/{1, 1, 2, 1},    // N, H, W, C
9056        /*output_resize_dims=*/{2, 3},  // H_out, W_out
9057        /*input_values=*/{2.0f, -1.0f},
9058        /*size_as_tensor=*/false,
9059        /*align_corners=*/false,
9060        /*expected_output_dims=*/{1, 2, 3, 1},  // N, H, W, C
9061        /*expected_nearest_output_values=*/
9062        {2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f},
9063        /*expected_bilinear_output_values=*/
9064        {2.0f, 0.f, -1.0f, 2.0f, 0.f, -1.0f},
9065        /*status=*/Status::OK()},
9066       {/*input_dims=*/{1, 1, 2, 1},    // N, H, W, C
9067        /*output_resize_dims=*/{2, 3},  // H_out, W_out
9068        /*input_values=*/{2.0f, -1.0f},
9069        /*size_as_tensor=*/false,
9070        /*align_corners=*/true,
9071        /*expected_output_dims=*/{1, 2, 3, 1},  // N, H, W, C
9072        /*expected_nearest_output_values=*/
9073        {2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f},
9074        /*expected_bilinear_output_values=*/
9075        {2.0f, 0.5f, -1.0f, 2.0f, 0.5f, -1.0f},
9076        /*status=*/Status::OK()}};
9077 
9078   if (trt_mode_ != TrtTestMode::kImplicitBatch) {
9079     // Size as a tensor is not supported in implicit batch mode.
9080     params.push_back({/*input_dims=*/{1, 1, 2, 1},    // N, H, W, C
9081                       /*output_resize_dims=*/{2, 3},  // H_out, W_out
9082                       /*input_values=*/{2.0f, -1.0f},
9083                       /*size_as_tensor=*/true,
9084                       /*align_corners=*/true,
9085                       /*expected_output_dims=*/{1, 2, 3, 1},  // N, H, W, C
9086                       /*expected_nearest_output_values=*/
9087                       {2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f},
9088                       /*expected_bilinear_output_values=*/
9089                       {2.0f, 0.5f, -1.0f, 2.0f, 0.5f, -1.0f},
9090                       /*status=*/Status::OK()});
9091   }
9092 
9093   for (auto p : params) {
9094     TestConvertResize<ops::ResizeNearestNeighbor>(this, p);
9095 
9096 // This use case is not supported as of TRT version 7.1
9097 #if IS_TRT_VERSION_GE(7, 1, 0, 0)
9098     if (!p.align_corners) {
9099       p.status = errors::InvalidArgument(
9100           "Cannot Convert Bilinear Resize when align_corners=False");
9101     }
9102 #endif
9103 
9104     TestConvertResize<ops::ResizeBilinear>(this, p);
9105   }
9106 }
9107 
MakePadNodeDef(std::string name,DataType dtype)9108 NodeDef MakePadNodeDef(std::string name, DataType dtype) {
9109   Scope s = Scope::NewRootScope();
9110   auto input = ops::Placeholder(s.WithOpName("input"), dtype);
9111   auto padding = ops::Placeholder(s.WithOpName("padding"), DT_INT32);
9112   auto pad = ops::Pad(s.WithOpName(name), input, padding);
9113   return pad.operation.node()->def();
9114 }
9115 
9116 struct PadTestParams {
9117   std::vector<int> input_dims;
9118   std::vector<int> pad_dims;
9119   std::vector<int> pad_values;
9120   std::vector<float> input_values;
9121   std::vector<int> expected_output_dims;
9122   std::vector<float> expected_output_values;
9123   Status status;
9124 };
9125 
TEST_P(OpConverter_FP32_FP16_Test,ConvertPad)9126 TEST_P(OpConverter_FP32_FP16_Test, ConvertPad) {
9127   {
9128     // First input is weight, should fail.
9129     Reset();
9130     NodeDef node_def = MakePadNodeDef("my_pad", tf_type_);
9131     AddTestWeights("input", {1, 2}, {1, 2}, tf_type_);
9132     AddTestWeights<int>("padding", {1, 2}, {1, 2});
9133     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
9134                                "The input \"tensor\" for Pad must be a "
9135                                "tensor");
9136   }
9137   {
9138     // padding is a tensor, should fail.
9139     Reset();
9140     NodeDef node_def = MakePadNodeDef("my_pad", tf_type_);
9141     AddTestTensor("input", {1, 2});
9142     AddTestTensor("padding", {1, 2});
9143     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
9144                                "The input \"paddings\" for Pad must be a "
9145                                "constant");
9146   }
9147   {
9148     // Make sure that ranges are inferred across a Pad.
9149     Reset();
9150     NodeDef node_def = MakePadNodeDef("my_pad", tf_type_);
9151     AddTestTensor("input", {1, 1, 2, 1});
9152     AddTestWeights<int>("padding", {4, 2}, {0, 0, 1, 0, 0, 1, 0, 0});
9153     TRT_TensorOrWeights input;
9154     TRT_TensorOrWeights output;
9155     RunValidationAndConversion(node_def);
9156     TF_EXPECT_OK(GetTensorOrWeights("input", &input));
9157     TF_EXPECT_OK(GetTensorOrWeights("my_pad", &output));
9158     ITensorProxyPtr input_tensor = input.tensor();
9159     converter_->ProvideQuantizationRange(&input_tensor, -5.0f, 5.0f);
9160     auto ranges = quantization_ranges();
9161     EXPECT_EQ(5.0f, ranges[input.tensor()->trt_tensor()]);
9162   }
9163 
9164   std::vector<PadTestParams> params{
9165       // 1 padding dim
9166       {
9167           /*input_dims=*/{1, 1, 3, 2},  // N, H, W, C
9168           /*pad_dims=*/{4, 2},          // #dims, {pad_before, pad_after}
9169           /*pad_values*/ {0, 0, 0, 0, 0, 1, 0, 0},
9170           /*input_values=*/{1, 2, 3, 4, 5, 6},
9171           /*expected_output_dims=*/{1, 1, 4, 2},  // N, H, W, C
9172           /*expected_output_values=*/
9173           {1, 2, 3, 4, 5, 6, 0, 0},
9174       },
9175       {
9176           /*input_dims=*/{1, 1, 3, 2},  // N, H, W, C
9177           /*pad_dims=*/{4, 2},          // #dims, {pad_before, pad_after}
9178           /*pad_values*/ {0, 0, 0, 0, 0, 0, 0, 1},
9179           /*input_values=*/{1, 2, 3, 4, 5, 6},
9180           /*expected_output_dims=*/{1, 1, 3, 3},  // N, H, W, C
9181           /*expected_output_values=*/
9182           {1, 2, 0, 3, 4, 0, 5, 6, 0},
9183       },
9184       {
9185           /*input_dims=*/{1, 1, 3, 2},  // N, H, W, C
9186           /*pad_dims=*/{4, 2},          // #dims, {pad_before, pad_after}
9187           /*pad_values*/ {0, 0, 1, 0, 0, 0, 0, 0},
9188           /*input_values=*/{1, 2, 3, 4, 5, 6},
9189           /*expected_output_dims=*/{1, 2, 3, 2},  // N, H, W, C
9190           /*expected_output_values=*/
9191           {0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6},
9192       },
9193       // 2 padding dims
9194       {
9195           /*input_dims=*/{1, 1, 2, 1},  // N, H, W, C
9196           /*pad_dims=*/{4, 2},          // #dims, {pad_before, pad_after}
9197           /*pad_values*/ {0, 0, 1, 0, 0, 1, 0, 0},
9198           /*input_values=*/{2.0f, -1.0f},
9199           /*expected_output_dims=*/{1, 2, 3, 1},  // N, H, W, C
9200           /*expected_output_values=*/
9201           {0.0, 0.0, 0.0, 2.0f, -1.0f, 0.0},
9202       },
9203       PadTestParams{
9204           /*input_dims=*/{1, 1, 2, 2},  // N, H, W, C
9205           /*pad_dims=*/{4, 2},          // #dims, {pad_before, pad_after}
9206           /*pad_values*/ {0, 0, 1, 0, 0, 1, 0, 0},
9207           /*input_values=*/{2, -1, 3., 4},
9208           /*expected_output_dims=*/{1, 2, 3, 2},  // N, H, W, C
9209           /*expected_output_values=*/
9210           {0, 0, 0, 0, 0, 0, 2, -1, 3, 4, 0, 0},
9211       },
9212       PadTestParams{
9213           /*input_dims=*/{1, 1, 2, 1, 2},  // N, C, H, W, D
9214           /*pad_dims=*/{5, 2},             // #dims, {pad_before, pad_after}
9215           /*pad_values*/ {0, 0, 1, 0, 0, 1, 0, 0, 0, 0},
9216           /*input_values=*/{2, -1, 3., 4},
9217           /*expected_output_dims=*/{1, 2, 3, 1, 2},  // N, H, W, C
9218           /*expected_output_values=*/
9219           {0, 0, 0, 0, 0, 0, 2, -1, 3, 4, 0, 0},
9220       },
9221       PadTestParams{
9222           /*input_dims=*/{1, 1, 2, 1, 2},  // N, C, H, W, D
9223           /*pad_dims=*/{5, 2},             // #dims, {pad_before, pad_after}
9224           /*pad_values*/ {0, 0, 0, 1, 0, 0, 1, 1, 0, 0},
9225           /*input_values=*/{2, -1, 3., 4},
9226           /*expected_output_dims=*/{1, 2, 2, 3, 2},  // N, H, W, C
9227           /*expected_output_values=*/
9228           {0., 0., 2., -1., 0., 0., 0., 0., 3., 4., 0., 0.,
9229            0., 0., 0., 0.,  0., 0., 0., 0., 0., 0., 0., 0},
9230       },
9231       PadTestParams{
9232           /*input_dims=*/{1, 1, 2, 1},  // N, H, W, C
9233           /*pad_dims=*/{4, 2},          // #dims, {pad_before, pad_after}
9234           /*pad_values*/ {1, 0, 0, 0, 0, 1, 0, 0},
9235           /*input_values=*/{2.0f, -1.0f},
9236           /*expected_output_dims=*/{2, 1, 3, 1},  // N, H, W, C
9237           /*expected_output_values=*/{0.0, 0.0, 0.0, 2.0f, -1.0f, 0.0},
9238           trt_mode_ == TrtTestMode::kImplicitBatch
9239               ? errors::InvalidArgument("Padding layer does not support "
9240                                         "padding on batch dimension")
9241               : Status::OK()},
9242       PadTestParams{
9243           /*input_dims=*/{1, 1, 2, 1},  // N, H, W, C
9244           /*pad_dims=*/{4, 2},          // #dims, {pad_before, pad_after}
9245           /*pad_values*/ {0, 0, 1, 0, 0, 1, 1, 1},
9246           /*input_values=*/{2.0f, -1.0f},
9247           /*expected_output_dims=*/{},  // N, H, W, C
9248           /*expected_output_values=*/{},
9249           errors::InvalidArgument("Padding layer does not support padding on "
9250                                   "> 2")},
9251       PadTestParams{
9252           /*input_dims=*/{1, 2, 2},  // N, H, W
9253           /*pad_dims=*/{3, 2},       // #dims, {pad_before, pad_after}
9254           /*pad_values*/ {0, 0, 1, 0, 0, 1},
9255           /*input_values=*/{2, -1, 3., 4},
9256           /*expected_output_dims=*/{1, 3, 3},  // N, H, W, C
9257           /*expected_output_values=*/
9258           {0., 0., 0., 2., -1., 0., 3., 4., 0.},
9259           errors::InvalidArgument("Convertpad requires at least 4D input")}};
9260 
9261   for (auto p : params) {
9262     Reset();
9263     // Create pad node.
9264     NodeDef node_def = MakePadNodeDef("my_pad", tf_type_);
9265     // Create input tensor.
9266     AddTestTensor("input", p.input_dims, p.input_values);
9267     // Create output size.
9268     AddTestWeights<int32>("padding", p.pad_dims, p.pad_values);
9269     TestOpConverter("my_pad", node_def, p.expected_output_dims, p.status,
9270                     p.status, ElementsAreArray(p.expected_output_values));
9271   }
9272 }
9273 }  // namespace convert
9274 }  // namespace tensorrt
9275 }  // namespace tensorflow
9276 
9277 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
9278