1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Exposes the Python wrapper conversion to trt_graph."""
16
17import collections
18from functools import partial  # pylint: disable=g-importing-member
19import os
20import platform
21import sys
22import tempfile
23
24import numpy as np
25import six as _six
26
27from tensorflow.core.framework import variable_pb2
28from tensorflow.core.protobuf import config_pb2
29from tensorflow.core.protobuf import meta_graph_pb2
30from tensorflow.core.protobuf import rewriter_config_pb2
31from tensorflow.python.client import session
32from tensorflow.python.compiler.tensorrt import utils as trt_utils
33from tensorflow.python.eager import context
34from tensorflow.python.eager import wrap_function
35from tensorflow.python.framework import convert_to_constants
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import graph_util
39from tensorflow.python.framework import importer
40from tensorflow.python.framework import ops
41from tensorflow.python.grappler import tf_optimizer
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import gen_resource_variable_ops
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.saved_model import builder
46from tensorflow.python.saved_model import load
47from tensorflow.python.saved_model import loader
48from tensorflow.python.saved_model import save
49from tensorflow.python.saved_model import signature_constants
50from tensorflow.python.saved_model import tag_constants
51from tensorflow.python.trackable import asset
52from tensorflow.python.trackable import resource
53from tensorflow.python.training import saver
54from tensorflow.python.util import deprecation
55from tensorflow.python.util import nest
56from tensorflow.python.util.lazy_loader import LazyLoader
57from tensorflow.python.util.tf_export import tf_export
58
59# Lazily load the op, since it's not available in cpu-only builds. Importing
60# this at the top will cause tests that import TF-TRT to fail when they're
61# built and run without CUDA/GPU.
62gen_trt_ops = LazyLoader(
63    "gen_trt_ops", globals(),
64    "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")
65
66_pywrap_py_utils = LazyLoader(
67    "_pywrap_py_utils", globals(),
68    "tensorflow.compiler.tf2tensorrt._pywrap_py_utils")
69
70# Register TRT ops in python, so that when users import this module they can
71# execute a TRT-converted graph without calling any of the methods in this
72# module.
73#
74# This will call register_op_list() in
75# tensorflow/python/framework/op_def_registry.py, but it doesn't register
76# the op or the op kernel in C++ runtime.
77try:
78  gen_trt_ops.trt_engine_op  # pylint: disable=pointless-statement
79except AttributeError:
80  pass
81
82
83def _to_bytes(s):
84  """Encode s if it is a sequence of chars."""
85  if isinstance(s, _six.text_type):
86    return s.encode("utf-8", errors="surrogateescape")
87  return s
88
89
90def _to_string(s):
91  """Decode s if it is a sequence of bytes."""
92  if isinstance(s, _six.binary_type):
93    return s.decode("utf-8")
94  return s
95
96
97class TrtPrecisionMode(object):
98  FP32 = "FP32"
99  FP16 = "FP16"
100  INT8 = "INT8"
101
102  @staticmethod
103  def supported_precision_modes():
104    precisions = [
105        TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8
106    ]
107    return precisions + [p.lower() for p in precisions]
108
109
110# Use a large enough number as the default max_workspace_size for TRT engines,
111# so that engines can achieve reasonable performance with the default value.
112# For TRT >= 8.4, the recommendation is MAX_INT.
113if (_pywrap_py_utils.is_tensorrt_enabled() and
114    trt_utils.is_loaded_tensorrt_version_greater_equal(8, 4, 0)):
115  # We must use `sys.maxsize - 512` to avoid overflow during casting.
116  DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = sys.maxsize - 512
117else:
118  DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30  # 1,073,741,824
119
120PROFILE_STRATEGY_RANGE = "Range"
121PROFILE_STRATEGY_OPTIMAL = "Optimal"
122PROFILE_STRATEGY_RANGE_OPTIMAL = "Range+Optimal"
123PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE = "ImplicitBatchModeCompatible"
124
125
126def supported_profile_strategies():
127  return [
128      PROFILE_STRATEGY_RANGE, PROFILE_STRATEGY_OPTIMAL,
129      PROFILE_STRATEGY_RANGE_OPTIMAL,
130      PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE
131  ]
132
133
134@tf_export("experimental.tensorrt.ConversionParams", v1=[])
135class TrtConversionParams(
136    collections.namedtuple("TrtConversionParams", [
137        "max_workspace_size_bytes", "precision_mode", "minimum_segment_size",
138        "maximum_cached_engines", "use_calibration", "allow_build_at_runtime"
139    ])):
140  """Parameters that are used for TF-TRT conversion.
141
142  Fields:
143    max_workspace_size_bytes: the maximum GPU temporary memory that the TRT
144      engine can use at execution time. This corresponds to the
145      'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
146    precision_mode: one of the strings in
147      TrtPrecisionMode.supported_precision_modes().
148    minimum_segment_size: the minimum number of nodes required for a subgraph
149      to be replaced by TRTEngineOp.
150    maximum_cached_engines: max number of cached TRT engines for dynamic TRT
151      ops. Created TRT engines for a dynamic dimension are cached. If the
152      number of cached engines is already at max but none of them supports the
153      input shapes, the TRTEngineOp will fall back to run the original TF
154      subgraph that corresponds to the TRTEngineOp.
155    use_calibration: this argument is ignored if precision_mode is not INT8.
156      If set to True, a calibration graph will be created to calibrate the
157      missing ranges. The calibration graph must be converted to an inference
158      graph by running calibration with calibrate(). If set to False,
159      quantization nodes will be expected for every tensor in the graph
160      (excluding those which will be fused). If a range is missing, an error
161      will occur. Please note that accuracy may be negatively affected if
162      there is a mismatch between which tensors TRT quantizes and which
163      tensors were trained with fake quantization.
164    allow_build_at_runtime: whether to allow building TensorRT engines at
165      runtime. If no prebuilt TensorRT engine can handle the given inputs at
166      runtime, a new TensorRT engine is built at runtime when
167      allow_build_at_runtime=True; otherwise native TF is used.
168  """
169
170  def __new__(cls,
171              max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
172              precision_mode=TrtPrecisionMode.FP32,
173              minimum_segment_size=3,
174              maximum_cached_engines=1,
175              use_calibration=True,
176              allow_build_at_runtime=True):
177    return super(TrtConversionParams,
178                 cls).__new__(cls, max_workspace_size_bytes, precision_mode,
179                              minimum_segment_size, maximum_cached_engines,
180                              use_calibration, allow_build_at_runtime)
181
182
183DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams()
184
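
# A minimal illustrative sketch (not part of the TF-TRT API surface): how
# TrtConversionParams is typically constructed. The helper name and the
# parameter values below are arbitrary examples.
def _example_int8_conversion_params():
  """Returns a TrtConversionParams configured for calibrated INT8 (example)."""
  return TrtConversionParams(
      precision_mode=TrtPrecisionMode.INT8,
      maximum_cached_engines=1,
      use_calibration=True)
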
185_TRT_ENGINE_OP_NAME = "TRTEngineOp"
186
187
188def _check_conversion_params(conversion_params, is_v2=False):
189  """Validate the provided TrtConversionParams.
190
191  Args:
192    conversion_params: a TrtConversionParams instance.
193    is_v2: whether we're getting a RewriterConfig for TF 2.0.
194
195  Raises:
196    TypeError: if any of the parameters are of unexpected type.
197    ValueError: if any of the parameters are of unexpected value.
198  """
199  supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
200  if conversion_params.precision_mode not in supported_precision_modes:
201    raise ValueError(
202        ("precision mode '{}' is not supported."
203         "It should be one of {}").format(conversion_params.precision_mode,
204                                          supported_precision_modes))
205  if (conversion_params.minimum_segment_size <= 0 and
206      conversion_params.minimum_segment_size != -1):
207    raise ValueError("minimum segment size should be positive or -1 "
208                     "(to disable main graph conversion).")
209
210
211def _check_trt_version_compatibility():
212  """Check compatibility of TensorRT version.
213
214  Raises:
215    RuntimeError: if the TensorRT library version is incompatible.
216  """
217
218  if not _pywrap_py_utils.is_tensorrt_enabled():
219    logging.error(
220        "Tensorflow needs to be built with TensorRT support enabled to allow "
221        "TF-TRT to operate.")
222
223    raise RuntimeError("Tensorflow has not been built with TensorRT support.")
224
225  if platform.system() == "Windows":
226    logging.warn(
227        "Windows support is provided experimentally. No guarantee is made "
228        "regarding functionality or engineering support. Use at your own risk.")
229
230  linked_version = _pywrap_py_utils.get_linked_tensorrt_version()
231  loaded_version = _pywrap_py_utils.get_loaded_tensorrt_version()
232
233  logging.info("Linked TensorRT version: %s", str(linked_version))
234  logging.info("Loaded TensorRT version: %s", str(loaded_version))
235
236  def raise_trt_version_deprecated(version_type, trt_version):
237    assert version_type in [
238        "linked", "loaded"
239    ], ("Incorrect value received for version_type: %s. Accepted: ['linked', "
240        "'loaded']") % version_type
241
242    logging.error(
243        "The {version_type} version of TensorRT: `{trt_version}` has now "
244        "been removed. Please upgrade to TensorRT 7 or more recent.".format(
245            version_type=version_type,
246            trt_version=trt_utils.version_tuple_to_string(trt_version)))
247
248    raise RuntimeError("Incompatible %s TensorRT versions" % version_type)
249
250  if not trt_utils.is_linked_tensorrt_version_greater_equal(7, 0, 0):
251    raise_trt_version_deprecated("linked", linked_version)
252
253  if not trt_utils.is_loaded_tensorrt_version_greater_equal(7, 0, 0):
254    raise_trt_version_deprecated("loaded", loaded_version)
255
256  if (loaded_version[0] != linked_version[0] or
257      not trt_utils.is_loaded_tensorrt_version_greater_equal(*linked_version)):
258    logging.error(
259        "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. A few "
260        "requirements must be met:\n"
261        "\t-It is required to use the same major version of TensorRT during "
262        "compilation and runtime.\n"
263        "\t-TensorRT does not support forward compatibility. The loaded "
264        "version has to be equal or more recent than the linked version.",
265        trt_utils.version_tuple_to_string(loaded_version),
266        trt_utils.version_tuple_to_string(linked_version))
267    raise RuntimeError("Incompatible TensorRT major version")
268
269  elif loaded_version != linked_version:
270    logging.info(
271        "Loaded TensorRT %s and linked TensorFlow against TensorRT %s. This is "
272        "supported because TensorRT minor/patch upgrades are backward "
273        "compatible.", trt_utils.version_tuple_to_string(loaded_version),
274        trt_utils.version_tuple_to_string(linked_version))
275
276
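# Illustrative sketch: querying the linked/loaded TensorRT versions directly
# with the same lazy-loaded utilities used by the check above. This assumes a
# TensorFlow build with TensorRT enabled; the helper name is hypothetical.
def _example_log_tensorrt_versions():
  """Logs the linked and loaded TensorRT versions (example only)."""
  if not _pywrap_py_utils.is_tensorrt_enabled():
    logging.info("This TensorFlow build does not have TensorRT enabled.")
    return
  logging.info(
      "Linked TensorRT: %s, loaded TensorRT: %s",
      trt_utils.version_tuple_to_string(
          _pywrap_py_utils.get_linked_tensorrt_version()),
      trt_utils.version_tuple_to_string(
          _pywrap_py_utils.get_loaded_tensorrt_version()))

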
277def _get_tensorrt_rewriter_config(conversion_params,
278                                  is_dynamic_op=None,
279                                  max_batch_size=None,
280                                  is_v2=False,
281                                  disable_non_trt_optimizers=False,
282                                  use_implicit_batch=True,
283                                  profile_strategy=PROFILE_STRATEGY_RANGE):
284  """Returns a RewriterConfig proto for TRT transformation.
285
286  Args:
287    conversion_params: a TrtConversionParams instance.
288    is_dynamic_op: whether to use dynamic engines.
289    max_batch_size: maximum batch size for static engines.
290    is_v2: whether we're getting a RewriterConfig for TF 2.0.
291    disable_non_trt_optimizers: Turn off all default Grappler optimizers.
292    use_implicit_batch: Whether to use implicit batch or explicit batch.
293    profile_strategy: dynamic shape optimization profile strategy.
294
295  Returns:
296    A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
297
298  Raises:
299    TypeError: if any of the parameters are of unexpected type.
300    ValueError: if any of the parameters are of unexpected value.
301  """
302  _check_conversion_params(conversion_params, is_v2=is_v2)
303  if is_v2 and is_dynamic_op is not None and not is_dynamic_op:
304    raise ValueError("is_dynamic_op is either None or True for TF2")
305  if not is_v2 and is_dynamic_op is None:
306    raise ValueError("is_dynamic_op can't be None for TF1")
307
308  if (is_dynamic_op is None or is_dynamic_op) and max_batch_size is not None:
309    raise ValueError("max_batch_size has to be None for TF2"
310                     " or when is_dynamic_op == True in TF1")
311  if is_dynamic_op is not None and not is_dynamic_op and not isinstance(
312      max_batch_size, int):
313    raise ValueError(
314        "max_batch_size has to be an integer for is_dynamic_op==False in TF1")
315  rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
316  # Disable the Grappler Remapper to avoid fused ops that may not be
317  # beneficial to TF-TRT and that are not supported by TF-TRT.
318  rewriter_config_with_trt.remapping = False
319
320  # Prevent folding of Const->QDQ chains.
321  rewriter_config_with_trt. \
322    experimental_disable_folding_quantization_emulation = (
323      trt_utils.is_linked_tensorrt_version_greater_equal(8, 0, 0) or
324      trt_utils.is_loaded_tensorrt_version_greater_equal(8, 0, 0))
325
326  if not disable_non_trt_optimizers:
327    rewriter_config_with_trt.optimizers.extend([
328        "pruning", "debug_stripper", "layout", "dependency", "constfold",
329        "common_subgraph_elimination"
330    ])
331
332  rewriter_config_with_trt.meta_optimizer_iterations = (
333      rewriter_config_pb2.RewriterConfig.ONE)
334  optimizer = rewriter_config_with_trt.custom_optimizers.add()
335
336  if not disable_non_trt_optimizers:
337    # Add a constfold optimizer to clean up the unused Const nodes.
338    rewriter_config_with_trt.custom_optimizers.add().name = "constfold"
339
340  optimizer.name = "TensorRTOptimizer"
341  optimizer.parameter_map[
342      "minimum_segment_size"].i = conversion_params.minimum_segment_size
343  optimizer.parameter_map["max_workspace_size_bytes"].i = (
344      conversion_params.max_workspace_size_bytes)
345  optimizer.parameter_map["precision_mode"].s = _to_bytes(
346      conversion_params.precision_mode)
347  optimizer.parameter_map[
348      "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
349  optimizer.parameter_map[
350      "use_calibration"].b = conversion_params.use_calibration
351  optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
352  optimizer.parameter_map[
353      "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime
354  if max_batch_size is not None:
355    optimizer.parameter_map["max_batch_size"].i = max_batch_size
356  optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
357  # While we accept case-insensitive strings from the users, we only pass
358  # lowercase strings to the TF-TRT converter.
359  if not use_implicit_batch:
360    optimizer.parameter_map["profile_strategy"].s = _to_bytes(
361        profile_strategy.lower())
362
363  # Disabling optimizers should happen after defining the TF-TRT grappler pass,
364  # otherwise the template can overwrite the disablement.
365  if disable_non_trt_optimizers:
366    trt_utils.disable_non_trt_optimizers_in_rewriter_config(
367        rewriter_config_with_trt)
368
369  return rewriter_config_with_trt
370
371
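# Illustrative sketch: embedding the RewriterConfig returned above into a
# Grappler session ConfigProto, mirroring what the converters below do. The
# helper name and the parameter values are arbitrary examples, not a public
# API.
def _example_trt_session_config():
  """Builds a ConfigProto that carries the TF-TRT rewriter config (example)."""
  rewriter_config = _get_tensorrt_rewriter_config(
      conversion_params=TrtConversionParams(
          precision_mode=TrtPrecisionMode.FP16),
      is_dynamic_op=True,
      is_v2=True)
  session_config = config_pb2.ConfigProto()
  session_config.graph_options.rewrite_options.CopyFrom(rewriter_config)
  return session_config

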
372@deprecation.deprecated(
373    None, "You shouldn't need a rewriter_config with the current TF-TRT APIs.")
374def get_tensorrt_rewriter_config(conversion_params,
375                                 is_dynamic_op=None,
376                                 max_batch_size=None,
377                                 is_v2=False,
378                                 disable_non_trt_optimizers=False):
379  return _get_tensorrt_rewriter_config(conversion_params, is_dynamic_op,
380                                       max_batch_size, is_v2,
381                                       disable_non_trt_optimizers)
382
383
384# Remove all scope prefixes in the node name. In TF 2.0, the same concrete
385# function can be initialized multiple times with different prefixes, and
386# this will result in the same TRTEngineOp being initialized multiple times
387# with different caches and duplicate TRT engines.
388# TODO(laigd): this may be caused by the fact that TRTEngineOp is not
389# stateful, need to investigate.
390# TODO(laigd): we rely on the fact that all functions are fully inlined
391# before TF-TRT optimizer is called, as otherwise it may generate the same
392# name when optimizing a different function graph. Fix this.
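# For example, a node named "while/body/TRTEngineOp_0" (an illustrative name;
# actual prefixes depend on the model) is canonicalized to "TRTEngineOp_0".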
393def _get_canonical_engine_name(name):
394  return name.split("/")[-1]
395
396
397class TrtGraphConverter(object):
398  """A converter for TF-TRT transformation for TF 1.x GraphDef/SavedModels.
399
400  To run the conversion without quantization calibration (e.g. for FP32/FP16
401  precision modes):
402
403  ```python
404  converter = TrtGraphConverter(
405      input_saved_model_dir="my_dir",
406      precision_mode=TrtPrecisionMode.FP16)
407  converted_graph_def = converter.convert()
408  converter.save(output_saved_model_dir)
409  ```
410
411  To run the conversion with quantization calibration:
412
413  ```python
414  converter = TrtGraphConverter(
415      input_saved_model_dir="my_dir",
416      precision_mode=TrtPrecisionMode.INT8)
417  converter.convert()
418
419  # Run calibration 10 times.
420  converted_graph_def = converter.calibrate(
421      fetch_names=['output:0'],
422      num_runs=10,
423      feed_dict_fn=lambda: {'input:0': my_next_data()})
424
425  converter.save(output_saved_model_dir)
426  ```
427  """
428
429  def __init__(self,
430               input_saved_model_dir=None,
431               input_saved_model_tags=None,
432               input_saved_model_signature_key=None,
433               input_graph_def=None,
434               nodes_denylist=None,
435               max_batch_size=1,
436               max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
437               precision_mode=TrtPrecisionMode.FP32,
438               minimum_segment_size=3,
439               is_dynamic_op=False,
440               maximum_cached_engines=1,
441               use_calibration=True):
442    """Initializes the converter.
443
444    Args:
445      input_saved_model_dir: the directory to load the SavedModel which contains
446        the input graph to transform. Used only when input_graph_def is None.
447      input_saved_model_tags: list of tags to load the SavedModel.
448      input_saved_model_signature_key: the key of the signature to optimize the
449        graph for.
450      input_graph_def: a GraphDef object containing a model to be transformed.
451        If set to None, the graph will be read from the SavedModel loaded from
452        input_saved_model_dir.
453      nodes_denylist: list of node names to prevent the converter from touching.
454      max_batch_size: max size for the input batch.
455      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
456        engine can use at execution time. This corresponds to the
457        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
458      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
459      minimum_segment_size: the minimum number of nodes required for a subgraph
460        to be replaced by TRTEngineOp.
461      is_dynamic_op: whether to generate dynamic TRT ops which will build the
462        TRT network and engine at run time.
463      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
464        ops. If the number of cached engines is already at max but none of them
465        can serve the input, the TRTEngineOp will fall back to run the TF
466        function based on which the TRTEngineOp is created.
467      use_calibration: this argument is ignored if precision_mode is not INT8.
468        If set to True, a calibration graph will be created to calibrate the
469        missing ranges. The calibration graph must be converted to an inference
470        graph by running calibration with calibrate(). If set to False,
471        quantization nodes will be expected for every tensor in the graph
472        (excluding those which will be fused). If a range is missing, an error
473        will occur. Please note that accuracy may be negatively affected if
474        there is a mismatch between which tensors TRT quantizes and which
475        tensors were trained with fake quantization.
476
477    Raises:
478      ValueError: if the combination of the parameters is invalid.
479      RuntimeError: if this class is used in TF 2.0.
480    """
481    if context.executing_eagerly():
482      raise RuntimeError(
483          "Please use tf.experimental.tensorrt.Converter in TF 2.0.")
484
485    if input_graph_def and input_saved_model_dir:
486      raise ValueError(
487          "Can only specify one of input_graph_def and input_saved_model_dir")
488    if not input_graph_def and not input_saved_model_dir:
489      raise ValueError("Must specify one of input_graph_def and "
490                       "input_saved_model_dir")
491    _check_trt_version_compatibility()
492
493    self._input_graph_def = input_graph_def
494    self._nodes_denylist = nodes_denylist
495
496    self._input_saved_model_dir = input_saved_model_dir
497    self._converted = False
498    self._grappler_meta_graph_def = None
499
500    self._input_saved_model_tags = (
501        input_saved_model_tags or [tag_constants.SERVING])
502    self._input_saved_model_signature_key = (
503        input_saved_model_signature_key or
504        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
505
506    # For calibration usage.
507    self._calibration_graph = None
508    self._calibration_data_collected = False
509    self._need_calibration = (
510        ((precision_mode == TrtPrecisionMode.INT8) or
511         (precision_mode == TrtPrecisionMode.INT8.lower())) and use_calibration)
512    if self._need_calibration and not is_dynamic_op:
513      logging.warn(
514          "INT8 precision mode with calibration is supported with "
515          "dynamic TRT ops only. Disregarding is_dynamic_op parameter.")
516      is_dynamic_op = True
517
518    self._is_dynamic_op = is_dynamic_op
519    if is_dynamic_op:
520      self._max_batch_size = None
521      if max_batch_size is not None:
522        logging.warn("When is_dynamic_op==True max_batch_size should be None")
523    else:
524      if not isinstance(max_batch_size, int):
525        raise ValueError("When is_dynamic_op==False max_batch_size should be "
526                         "an integer")
527      self._max_batch_size = max_batch_size
528
529    self._conversion_params = TrtConversionParams(
530        max_workspace_size_bytes=max_workspace_size_bytes,
531        precision_mode=precision_mode,
532        minimum_segment_size=minimum_segment_size,
533        maximum_cached_engines=maximum_cached_engines,
534        use_calibration=use_calibration,
535        allow_build_at_runtime=True)
536    _check_conversion_params(self._conversion_params)
537
538    self._test_only_disable_non_trt_optimizers = False
539
540  def _run_conversion(self):
541    """Run Grappler's OptimizeGraph() tool to convert the graph."""
542    # Create custom ConfigProto for Grappler.
543    grappler_session_config = config_pb2.ConfigProto()
544    custom_rewriter_config = _get_tensorrt_rewriter_config(
545        conversion_params=self._conversion_params,
546        is_dynamic_op=self._is_dynamic_op,
547        max_batch_size=self._max_batch_size,
548        disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
549        use_implicit_batch=True)
550    grappler_session_config.graph_options.rewrite_options.CopyFrom(
551        custom_rewriter_config)
552
553    # Run Grappler.
554    self._converted_graph_def = tf_optimizer.OptimizeGraph(
555        grappler_session_config,
556        self._grappler_meta_graph_def,
557        graph_id=b"tf_graph")
558    self._converted = True
559
560  def _add_nodes_denylist(self):
561    if self._nodes_denylist:
562      collection_def = self._grappler_meta_graph_def.collection_def["train_op"]
563      denylist = collection_def.node_list.value
564      for i in self._nodes_denylist:
565        if isinstance(i, ops.Tensor):
566          denylist.append(_to_bytes(i.name))
567        else:
568          denylist.append(_to_bytes(i))
569
570  def _convert_graph_def(self):
571    """Convert the input GraphDef."""
572    graph = ops.Graph()
573    with graph.as_default():
574      importer.import_graph_def(self._input_graph_def, name="")
575    self._grappler_meta_graph_def = saver.export_meta_graph(
576        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
577    self._add_nodes_denylist()
578
579    self._run_conversion()
580
581  def _collections_to_keep(self, collection_keys):
582    # TODO(laigd): currently we use the collection key to filter out
583    # collections that depend on variable ops, but this may miss some
584    # other user-defined collections. A better way would be to use
585    # CollectionDef::NodeList for the filtering.
586    collections_to_remove = (
587        ops.GraphKeys._VARIABLE_COLLECTIONS + [
588            ops.GraphKeys.TRAIN_OP, ops.GraphKeys.WHILE_CONTEXT,
589            ops.GraphKeys.COND_CONTEXT
590        ])
591    return [key for key in collection_keys if key not in collections_to_remove]
592
593  def _convert_saved_model(self):
594    """Convert the input SavedModel."""
595    graph = ops.Graph()
596    with session.Session(graph=graph) as sess:
597      input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
598                                         self._input_saved_model_dir)
599      input_signature_def = input_meta_graph_def.signature_def[
600          self._input_saved_model_signature_key]
601
602      def _gather_names(tensor_info):
603        """Get the node names from a TensorInfo."""
604        return {tensor_info[key].name.split(":")[0] for key in tensor_info}
605
606      # Get input and outputs from all SignatureDef.
607      output_node_names = _gather_names(input_signature_def.inputs).union(
608          _gather_names(input_signature_def.outputs))
609
610      # Preserve nodes in collection
611      for collection_key in self._collections_to_keep(
612          input_meta_graph_def.collection_def):
613        for op in sess.graph.get_collection(collection_key):
614          if isinstance(op, ops.Operation):
615            output_node_names.add(op.name.split(":")[0])
616
617      # Freeze the variables in the SavedModel graph and copy the frozen
618      # graph over.
619      frozen_graph_def = graph_util.convert_variables_to_constants(
620          sess, sess.graph.as_graph_def(add_shapes=True),
621          list(output_node_names))
622      self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
623      self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
624
625      # Copy the collections that are not variables.
626      for collection_key in self._collections_to_keep(
627          input_meta_graph_def.collection_def):
628        self._grappler_meta_graph_def.collection_def[collection_key].CopyFrom(
629            input_meta_graph_def.collection_def[collection_key])
630
631      self._add_nodes_denylist()
632
633      # Copy other information.
634      self._grappler_meta_graph_def.meta_info_def.CopyFrom(
635          input_meta_graph_def.meta_info_def)
636      self._grappler_meta_graph_def.signature_def[
637          self._input_saved_model_signature_key].CopyFrom(input_signature_def)
638      # TODO(laigd): maybe add back AssetFileDef.
639
640    self._run_conversion()
641
642  def convert(self):
643    """Run the TF-TRT conversion.
644
645    Returns:
646      The converted GraphDef for TF 1.x.
647    """
648    assert not self._converted
649    if self._input_graph_def:
650      self._convert_graph_def()
651    else:
652      self._convert_saved_model()
653    return self._converted_graph_def
654
655  def calibrate(self,
656                fetch_names,
657                num_runs,
658                feed_dict_fn=None,
659                input_map_fn=None):
660    """Run the calibration and return the calibrated GraphDef.
661
662    Args:
663      fetch_names: a list of output tensor names to fetch during calibration.
664      num_runs: number of runs of the graph during calibration.
665      feed_dict_fn: a function that returns a dictionary mapping input names (as
666        strings) in the GraphDef to be calibrated to values (e.g. Python list,
667        numpy arrays, etc). One and only one of `feed_dict_fn` and
668        `input_map_fn` should be specified.
669      input_map_fn: a function that returns a dictionary mapping input names (as
670        strings) in the GraphDef to be calibrated to Tensor objects. The values
671        of the named input tensors in the GraphDef to be calibrated will be
672        re-mapped to the respective `Tensor` values during calibration. One and
673        only one of `feed_dict_fn` and `input_map_fn` should be specified.
674
675    Raises:
676      ValueError: if the input combination is invalid.
677      RuntimeError: if this method is called in eager mode.
678
679    Returns:
680      The GraphDef after the calibration.
681    """
682    assert self._converted
683    assert self._need_calibration
684    assert not self._calibration_data_collected
685
686    if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and
687                                           not input_map_fn):
688      raise ValueError(
689          "Should specify one and only one of feed_dict_fn and input_map_fn.")
690
691    if input_map_fn:
692      for k, v in input_map_fn().items():
693        if not isinstance(k, str):
694          raise ValueError("Keys of input_map_fn must be of type str")
695        if not isinstance(v, ops.Tensor):
696          raise ValueError("Values of input_map_fn must be of type tf.Tensor")
697
698    self._calibration_graph = ops.Graph()
699    with self._calibration_graph.as_default():
700      fetches = importer.import_graph_def(
701          self._converted_graph_def,
702          input_map=input_map_fn() if input_map_fn else None,
703          return_elements=fetch_names,
704          name="")
705
706    calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig()
707    if self._test_only_disable_non_trt_optimizers:
708      trt_utils.disable_non_trt_optimizers_in_rewriter_config(
709          calibrate_rewriter_cfg)
710
711    # Set allow_soft_placement=True to run the graph for calibration so that
712    # ops that are supported by TensorRT but don't have a GPU implementation
713    # are allowed to execute on CPU.
714    calibrate_config = config_pb2.ConfigProto(
715        allow_soft_placement=True,
716        graph_options=config_pb2.GraphOptions(
717            rewrite_options=calibrate_rewriter_cfg))
718
719    with session.Session(
720        graph=self._calibration_graph,
721        config=calibrate_config) as calibration_sess:
722      for _ in range(num_runs):
723        calibration_sess.run(
724            fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None)
725
726      # Maps device name to the corresponding get_calibration_data.
727      #
728      # TODO(laigd): a better way would be to use calibration_sess to list
729      # all the devices, add one get_calibration_data for each device, and
730      # fetch each such op for every resource until it's found. This can work
731      # even when the device of the TRTEngineOp is empty or not fully specified.
732      device_to_get_resource_op_map = {}
733
734      with self._calibration_graph.as_default():
735        resource_name_input = array_ops.placeholder(dtypes.string)
736
737        for node in self._converted_graph_def.node:
738          if node.op == _TRT_ENGINE_OP_NAME:
739            # Adds the get_calibration_data op for the device if not done
740            # before. We only add one such op for each device.
741            # TODO(laigd): What if the device is empty?????
742            if node.device not in device_to_get_resource_op_map:
743              with self._calibration_graph.device(node.device):
744                serialized_resources_output = (
745                    gen_trt_ops.get_calibration_data_op(resource_name_input))
746              device_to_get_resource_op_map[node.device] = (
747                  serialized_resources_output)
748
749            # Get the calibration resource.
750            calibration_result = calibration_sess.run(
751                device_to_get_resource_op_map[node.device],
752                feed_dict={
753                    resource_name_input: _get_canonical_engine_name(node.name)
754                })
755            node.attr["calibration_data"].s = calibration_result
756
757      self._calibration_data_collected = True
758
759    return self._converted_graph_def
760
761  def save(self, output_saved_model_dir):
762    """Save the converted graph as a SavedModel.
763
764    Args:
765      output_saved_model_dir: construct a SavedModel using the converted
766        GraphDef and save it to the specified directory. This option only works
767        when the input graph is loaded from a SavedModel, i.e. when
768        input_saved_model_dir is specified and input_graph_def is None in
769        __init__().
770
771    Raises:
772      ValueError: if the input to the converter is a GraphDef instead of a
773      SavedModel.
774    """
775    assert self._converted
776    if self._need_calibration:
777      assert self._calibration_data_collected
778    if self._input_graph_def:
779      raise ValueError(
780          "Not able to save to a SavedModel since input is a GraphDef")
781
782    def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
783      """Restores collections that we need to keep."""
784      scope = ""
785      for key in collection_keys:
786        collection_def = src_meta_graph_def.collection_def[key]
787        kind = collection_def.WhichOneof("kind")
788        if kind is None:
789          logging.error(
790              "Cannot identify data type for collection %s. Skipping.", key)
791          continue
792        from_proto = ops.get_from_proto_function(key)
793        if from_proto and kind == "bytes_list":
794          proto_type = ops.get_collection_proto_type(key)
795          # It is assumed that there are no variable keys in the collections.
796          for value in collection_def.bytes_list.value:
797            proto = proto_type()
798            proto.ParseFromString(value)
799            try:
800              new_value = from_proto(proto, import_scope=scope)
801            except:
802              continue
803            dest_graph.add_to_collection(key, new_value)
804        else:
805          field = getattr(collection_def, kind)
806          if kind == "node_list":
807            for value in field.value:
808              name = ops.prepend_name_scope(value, scope)
809              # Since the graph has been optimized, the node may no longer
810              # exist.
811              try:
812                col_op = dest_graph.as_graph_element(name)
813              except (TypeError, ValueError, KeyError):
814                continue
815              dest_graph.add_to_collection(key, col_op)
816          elif kind == "int64_list":
817            # NOTE(opensource): This force conversion is to work around the
818            # fact that Python2 distinguishes between int and long, while
819            # Python3 has only int.
820            for value in field.value:
821              dest_graph.add_to_collection(key, int(value))
822          else:
823            for value in field.value:
824              dest_graph.add_to_collection(key,
825                                           ops.prepend_name_scope(value, scope))
826
827    # Write the transformed graphdef as SavedModel.
828    saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
829    with ops.Graph().as_default():
830      importer.import_graph_def(self._converted_graph_def, name="")
831      _restore_collections(
832          ops.get_default_graph(), self._grappler_meta_graph_def,
833          self._collections_to_keep(
834              self._grappler_meta_graph_def.collection_def))
835      # We don't use any specific converter here.
836      with session.Session() as sess:
837        saved_model_builder.add_meta_graph_and_variables(
838            sess,
839            self._input_saved_model_tags,
840            signature_def_map=self._grappler_meta_graph_def.signature_def)
841    # Ignore other meta graphs from the input SavedModel.
842    saved_model_builder.save()
843
844
845def _get_resource_handle(name, device):
846  with ops.device(device):
847    return gen_trt_ops.create_trt_resource_handle(resource_name=name)
848
849
850class _TRTEngineResource(resource.TrackableResource):
851  """Class to track the serialized engines resource."""
852
853  def __init__(self,
854               resource_name,
855               filename,
856               maximum_cached_engines,
857               device="GPU"):
858    super(_TRTEngineResource, self).__init__(device=device)
859    self._resource_name = resource_name
860    # Track the serialized engine file in the SavedModel.
861    self._filename = self._track_trackable(
862        asset.Asset(filename), "_serialized_trt_resource_filename")
863    self._maximum_cached_engines = maximum_cached_engines
864
865  def _create_resource(self):
866    return _get_resource_handle(self._resource_name, self._resource_device)
867
868  def _initialize(self):
869    gen_trt_ops.initialize_trt_resource(
870        self.resource_handle,
871        self._filename,
872        max_cached_engines_count=self._maximum_cached_engines)
873
874  def _destroy_resource(self):
875    handle = _get_resource_handle(self._resource_name, self._resource_device)
876    with ops.device(self._resource_device):
877      gen_resource_variable_ops.destroy_resource_op(
878          handle, ignore_lookup_error=True)
879
880
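# Illustrative sketch: tracking a serialized engine file as a SavedModel asset
# via the trackable resource above. The resource name and filename below are
# hypothetical placeholders, not values produced by this module.
def _example_track_engine_resource():
  """Creates a _TRTEngineResource for a serialized engine file (example)."""
  return _TRTEngineResource(
      resource_name="TRTEngineOp_0",
      filename="/tmp/trt-serialized-engine.TRTEngineOp_0",
      maximum_cached_engines=1)

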
881def _print_row(fields, positions, print_fn):
882  """Prints a row."""
883  line = ""
884  for i, field in enumerate(fields):
885    field = str(field)
886    end_line_pos = positions[i]
887    if i > 0:
888      line = line + " "
889    line = "{0:{min_length}}".format(line + field, min_length=end_line_pos)
890
891    if len(line) > end_line_pos:
892      line = line[:(end_line_pos - 4)] + " ..."
893
894  print_fn(line)
895
896
897def _construct_function_from_graph_def(func, graph_def, frozen_func=None):
898  """Rebuild function from graph_def."""
899  if frozen_func is None:
900    frozen_func = func
901
902  # If a function is converted, then the TF context contains the original
903  # function while the converted_graph_def contains the converted function.
904  # Remove the original function from the TF context in this case.
905  for f in graph_def.library.function:
906    while context.context().has_function(f.signature.name):
907      context.context().remove_function(f.signature.name)
908
909  # pylint: disable = protected-access
910  captures = {
911      t2.name.split(":")[0]: t1
912      for _, (t1, t2) in frozen_func.graph._captures.items()
913  }
914  new_func = wrap_function.function_from_graph_def(
915      graph_def, [tensor.name for tensor in frozen_func.inputs],
916      [tensor.name for tensor in frozen_func.outputs], captures)
917  new_func.graph.structured_outputs = nest.pack_sequence_as(
918      func.graph.structured_outputs, new_func.graph.structured_outputs)
919
920  # Copy structured input signature from original function (used during
921  # serialization)
922  new_func.graph.structured_input_signature = (func.structured_input_signature)
923
924  return new_func
925
926
927def _apply_inlining(func):
928  """Apply an inlining optimization to the function's graph definition."""
929  graph_def = func.graph.as_graph_def()
930
931  # In some cases, a secondary implementation of the function (e.g. for GPU) is
932  # written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in
933  # TF2 produces a CuDNN-based RNN for GPU).
934  # This function is supposed to inline all function calls, but "api_implements"
935  # prevents this from happening. Removing the attribute solves the problem.
936  # To learn more about "api_implements", see:
937  #   tensorflow/core/grappler/optimizers/implementation_selector.h
938  for function in graph_def.library.function:
939    if "api_implements" in function.attr:
940      del function.attr["api_implements"]
941
942  meta_graph = saver.export_meta_graph(graph_def=graph_def, graph=func.graph)
943
944  # Clear the initializer_name for the variables collections, since they are not
945  # needed after saving to a SavedModel.
946  for name in [
947      "variables", "model_variables", "trainable_variables", "local_variables"
948  ]:
949    raw_list = []
950    for raw in meta_graph.collection_def["variables"].bytes_list.value:
951      variable = variable_pb2.VariableDef()
952      variable.ParseFromString(raw)
953      variable.ClearField("initializer_name")
954      raw_list.append(variable.SerializeToString())
955    meta_graph.collection_def[name].bytes_list.value[:] = raw_list
956
957  # Add a collection 'train_op' so that Grappler knows the outputs.
958  fetch_collection = meta_graph_pb2.CollectionDef()
959  for array in func.inputs + func.outputs:
960    fetch_collection.node_list.value.append(array.name)
961  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
962
963  # Initialize RewriterConfig with everything disabled except function inlining.
964  config = config_pb2.ConfigProto()
965  rewrite_options = config.graph_options.rewrite_options
966  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
967  rewrite_options.optimizers.append("function")
968
969  new_graph_def = tf_optimizer.OptimizeGraph(config, meta_graph)
970
971  return new_graph_def
972
973
974def _annotate_variable_ops(func, graph_def):
975  """Annotates variable operations with custom `_shape` attribute.
976
977  This is required for the converters and shape inference. The graph
978  definition is modified in-place.
979
980  Args:
981    func: Function represented by the graph definition.
982    graph_def: Graph definition to be annotated in-place.
983
984  Raises:
985    RuntimeError: if some shapes cannot be annotated.
986  """
987  ph_shape_map = {}
988  for ph, var in zip(func.graph.internal_captures, func.variables):
989    ph_shape_map[ph.name] = var.shape
990  # Construct a mapping of node names to nodes
991  name_to_node = {node.name: node for node in graph_def.node}
992  # Go through all the ReadVariableOp nodes in the graph def
993  for node in graph_def.node:
994    if node.op == "ReadVariableOp" or node.op == "ResourceGather":
995      node_ = node
996      # Go up the chain of identities to find a placeholder
997      while name_to_node[node_.input[0]].op == "Identity":
998        node_ = name_to_node[node_.input[0]]
999      ph_name = node_.input[0] + ":0"
1000      if ph_name in ph_shape_map:
1001        shape = ph_shape_map[ph_name]
1002        node.attr["_shape"].shape.CopyFrom(shape.as_proto())
1003      else:
1004        raise RuntimeError(
1005            "Not found in the function captures: {}".format(ph_name))
1006
1007
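# Illustrative sketch: combining the two helpers above to produce an inlined,
# shape-annotated GraphDef before handing it to the TF-TRT Grappler pass. The
# SavedModel directory is a hypothetical placeholder.
def _example_inline_and_annotate(saved_model_dir):
  """Inlines a SavedModel signature and annotates variable shapes (example)."""
  saved_model = load.load(saved_model_dir, tags=[tag_constants.SERVING])
  func = saved_model.signatures[
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  inlined_graph_def = _apply_inlining(func)
  _annotate_variable_ops(func, inlined_graph_def)
  return inlined_graph_def

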
1008def _save_calibration_table(node):
1009  try:
1010    calibration_table = gen_trt_ops.get_calibration_data_op(
1011        _get_canonical_engine_name(node.name))
1012    node.attr["calibration_data"].s = calibration_table.numpy()
1013  except errors.UnknownError:
1014    logging.warning("Warning calibration error for %s", node.name)
1015
1016
1017def _convert_to_tensor(inp):
1018  if isinstance(inp, dict):
1019    args = []
1020    kwargs = {k: ops.convert_to_tensor(v) for k, v in inp.items()}
1021  else:
1022    args = map(ops.convert_to_tensor, inp)
1023    kwargs = {}
1024  return args, kwargs
1025
1026
1027@tf_export("experimental.tensorrt.Converter", v1=[])
1028class TrtGraphConverterV2(object):
1029  """An offline converter for TF-TRT transformation for TF 2.0 SavedModels.
1030
1031  Windows support is provided experimentally. No guarantee is made regarding
1032  functionality or engineering support. Use at your own risk.
1033
1034  There are several ways to run the conversion:
1035
1036  1. FP32/FP16 precision
1037
1038     ```python
1039     params = tf.experimental.tensorrt.ConversionParams(
1040         precision_mode='FP16')
1041     converter = tf.experimental.tensorrt.Converter(
1042         input_saved_model_dir="my_dir", conversion_params=params)
1043     converter.convert()
1044     converter.save(output_saved_model_dir)
1045     ```
1046
1047     In this case, no TRT engines will be built or saved in the converted
1048     SavedModel. But if input data is available during conversion, we can still
1049     build and save the TRT engines to reduce the cost during inference (see
1050     option 2 below).
1051
1052  2. FP32/FP16 precision with pre-built engines
1053
1054     ```python
1055     params = tf.experimental.tensorrt.ConversionParams(
1056         precision_mode='FP16',
1057         # Set this to a large enough number so it can cache all the engines.
1058         maximum_cached_engines=16)
1059     converter = tf.experimental.tensorrt.Converter(
1060         input_saved_model_dir="my_dir", conversion_params=params)
1061     converter.convert()
1062
1063     # Define a generator function that yields input data, and use it to execute
1064     # the graph to build TRT engines.
1065     def my_input_fn():
1066       for _ in range(num_runs):
1067         inp1, inp2 = ...
1068         yield inp1, inp2
1069
1070     converter.build(input_fn=my_input_fn)  # Generate corresponding TRT engines
1071     converter.save(output_saved_model_dir)  # Generated engines will be saved.
1072     ```
1073
1074     In this way, one engine will be built/saved for each unique input shape of
1075     the TRTEngineOp. This is good for applications that cannot afford building
1076     engines during inference but have access to input data that is similar to
1077     the one used in production (for example, data that has the same input
1078     shapes). Also, the generated TRT engines are platform dependent, so we need
1079     to run `build()` in an environment that is similar to production (e.g. with
1080     the same type of GPU).
1081
1082  3. INT8 precision and calibration with pre-built engines
1083
1084     ```python
1085     params = tf.experimental.tensorrt.ConversionParams(
1086         precision_mode='INT8',
1087         # Currently only one INT8 engine is supported in this mode.
1088         maximum_cached_engines=1,
1089         use_calibration=True)
1090     converter = tf.experimental.tensorrt.Converter(
1091         input_saved_model_dir="my_dir", conversion_params=params)
1092
1093     # Define a generator function that yields input data, and run INT8
1094     # calibration with the data. All input data should have the same shape.
1095     # At the end of convert(), the calibration stats (e.g. range information)
1096     # will be saved and can be used to generate more TRT engines with different
1097     # shapes. Also, one TRT engine will be generated (with the same shape as
1098     # the calibration data) to be saved later.
1099     def my_calibration_input_fn():
1100       for _ in range(num_runs):
1101         inp1, inp2 = ...
1102         yield inp1, inp2
1103
1104     converter.convert(calibration_input_fn=my_calibration_input_fn)
1105
1106     # (Optional) Generate more TRT engines offline (same as the previous
1107     # option), to avoid the cost of generating them during inference.
1108     def my_input_fn():
1109       for _ in range(num_runs):
1110         inp1, inp2 = ...
1111         yield inp1, inp2
1112     converter.build(input_fn=my_input_fn)
1113
1114     # Save the converted SavedModel together with the generated TRT engines.
1115     converter.save(output_saved_model_dir)
1116     ```
1117  4. To use dynamic shape, we need to call the build method with an input
1118     function to generate profiles. This step is similar to the INT8 calibration
1119     step described above. The converter also needs to be created with
1120     use_dynamic_shape=True and one of the following profile_strategies for
1121     creating profiles based on the inputs produced by the input function:
1122     * `Range`: create one profile that works for inputs with dimension values
1123       in the range of [min_dims, max_dims] where min_dims and max_dims are
1124       derived from the provided inputs.
1125     * `Optimal`: create one profile for each input. The profile only works for
1126       inputs with the same dimensions as the input it is created for. The GPU
1127       engine will run with optimal performance for such inputs.
1128     * `Range+Optimal`: create the profiles for both `Range` and `Optimal`.
1129     * `ImplicitBatchModeCompatible`: create the profiles that will produce the
1130       same GPU engines as the implicit_batch_mode would produce.
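
     For example, a minimal sketch of a dynamic shape conversion (using the
     same placeholder input function as the options above):

     ```python
     params = tf.experimental.tensorrt.ConversionParams(
         precision_mode='FP16')
     converter = tf.experimental.tensorrt.Converter(
         input_saved_model_dir="my_dir",
         conversion_params=params,
         use_dynamic_shape=True,
         dynamic_shape_profile_strategy="Optimal")
     converter.convert()

     def my_input_fn():
       for _ in range(num_runs):
         inp1, inp2 = ...
         yield inp1, inp2

     converter.build(input_fn=my_input_fn)  # Generate optimization profiles.
     converter.save(output_saved_model_dir)
     ```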
1131  """
1132
1133  def _verify_profile_strategy(self, strategy):
1134    supported_strategies = [s.lower() for s in supported_profile_strategies()]
1135    if strategy.lower() not in supported_strategies:
1136      raise ValueError(
1137          ("profile_strategy '{}' is not supported. It should be one of {}"
1138          ).format(strategy, supported_profile_strategies()))
1139
1140  @deprecation.deprecated_args(None,
1141                               "Use individual converter parameters instead",
1142                               "conversion_params")
1143  def __init__(self,
1144               input_saved_model_dir=None,
1145               input_saved_model_tags=None,
1146               input_saved_model_signature_key=None,
1147               use_dynamic_shape=None,
1148               dynamic_shape_profile_strategy=None,
1149               max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
1150               precision_mode=TrtPrecisionMode.FP32,
1151               minimum_segment_size=3,
1152               maximum_cached_engines=1,
1153               use_calibration=True,
1154               allow_build_at_runtime=True,
1155               conversion_params=None):
1156    """Initialize the converter.
1157
1158    Args:
1159      input_saved_model_dir: the directory to load the SavedModel which contains
1160        the input graph to transform. Required.
1161      input_saved_model_tags: list of tags to load the SavedModel.
1162      input_saved_model_signature_key: the key of the signature to optimize the
1163        graph for.
1164      use_dynamic_shape: whether to enable dynamic shape support. None is
1165        equivalent to False in the current implementation.
1166      dynamic_shape_profile_strategy: one of the strings in
1167        supported_profile_strategies(). None is equivalent to
1168        ImplicitBatchModeCompatible in the current implementation.
1169      max_workspace_size_bytes: the maximum GPU temporary memory that the TRT
1170        engine can use at execution time. This corresponds to the
1171        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
1172      precision_mode: one of the strings in
1173        TrtPrecisionMode.supported_precision_modes().
1174      minimum_segment_size: the minimum number of nodes required for a subgraph
1175        to be replaced by TRTEngineOp.
1176      maximum_cached_engines: max number of cached TRT engines for dynamic TRT
1177        ops. Created TRT engines for a dynamic dimension are cached. If the
1178        number of cached engines is already at max but none of them supports the
1179        input shapes, the TRTEngineOp will fall back to run the original TF
1180        subgraph that corresponds to the TRTEngineOp.
1181      use_calibration: this argument is ignored if precision_mode is not INT8.
1182        If set to True, a calibration graph will be created to calibrate the
1183        missing ranges. The calibration graph must be converted to an inference
1184        graph by running calibration with calibrate(). If set to False,
1185        quantization nodes will be expected for every tensor in the graph
1186        (excluding those which will be fused). If a range is missing, an error
1187        will occur. Please note that accuracy may be negatively affected if
1188        there is a mismatch between which tensors TRT quantizes and which
1189        tensors were trained with fake quantization.
1190      allow_build_at_runtime: whether to allow building TensorRT engines at
1191        runtime. If no prebuilt TensorRT engine can handle the given inputs at
1192        runtime, a new TensorRT engine is built at runtime when
1193        allow_build_at_runtime=True; otherwise native TF is used.
1194      conversion_params: a TrtConversionParams instance (deprecated).
1195
1196    Raises:
1197      ValueError: if the combination of the parameters is invalid.
1198    """
1199    assert context.executing_eagerly()
1200    if conversion_params is None:
1201      conversion_params = TrtConversionParams(
1202          max_workspace_size_bytes=max_workspace_size_bytes,
1203          precision_mode=precision_mode,
1204          minimum_segment_size=minimum_segment_size,
1205          maximum_cached_engines=maximum_cached_engines,
1206          use_calibration=use_calibration,
1207          allow_build_at_runtime=allow_build_at_runtime)
1208
1209    _check_trt_version_compatibility()
1210    _check_conversion_params(conversion_params, is_v2=True)
1211
1212    self._conversion_params = conversion_params
1213    self._input_saved_model_dir = input_saved_model_dir
1214    self._input_saved_model_tags = (
1215        input_saved_model_tags or [tag_constants.SERVING])
1216    self._input_saved_model_signature_key = (
1217        input_saved_model_signature_key or
1218        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
1219    self.freeze = not trt_utils.is_experimental_feature_activated(
1220        "disable_graph_freezing")
1221
1222    self._need_calibration = ((
1223        (conversion_params.precision_mode == TrtPrecisionMode.INT8) or
1224        (conversion_params.precision_mode == TrtPrecisionMode.INT8.lower())) and
1225                              conversion_params.use_calibration)
1226
1227    self._calibration_input_fn = None
1228
1229    self._converted = False
1230    self._device = None
1231    self._build_called_once = False
1232    self._calibrated = False
1233
1234    if use_dynamic_shape is None:
1235      self._use_dynamic_shape = False
1236    else:
1237      self._use_dynamic_shape = use_dynamic_shape
1238
1239    if not self.freeze and not self._use_dynamic_shape:
1240      logging.warn(
1241          "Disabling graph freezing is only possible in dynamic shape mode."
1242          " The graph will be frozen.")
1243      self.freeze = True
1244
1245    self._profile_strategy = "Unknown"
1246    if self._use_dynamic_shape:
1247      if dynamic_shape_profile_strategy is None:
1248        self._profile_strategy = (
1249            PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE)
1250      else:
1251        self._verify_profile_strategy(dynamic_shape_profile_strategy)
1252        self._profile_strategy = dynamic_shape_profile_strategy
1253
1254    # Fields to support TF-TRT testing and shouldn't be used for other purpose.
1255    self._test_only_disable_non_trt_optimizers = False
1256
1257  def _need_trt_profiles(self):
1258    return self._use_dynamic_shape
1259
1260  def _run_conversion(self, meta_graph_def):
1261    """Run Grappler's OptimizeGraph() tool to convert the graph.
1262
1263    Args:
1264      meta_graph_def: the MetaGraphDef instance to run the optimizations on.
1265
1266    Returns:
1267      The optimized GraphDef.
1268    """
1269    grappler_session_config = config_pb2.ConfigProto()
1270    # Always set `allow_build_at_runtime` for offline TensorRT engine building.
1271    custom_rewriter_config = _get_tensorrt_rewriter_config(
1272        conversion_params=self._conversion_params._replace(
1273            allow_build_at_runtime=True),
1274        is_dynamic_op=True,
1275        max_batch_size=None,
1276        disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
1277        use_implicit_batch=not self._use_dynamic_shape,
1278        profile_strategy=self._profile_strategy)
1279    grappler_session_config.graph_options.rewrite_options.CopyFrom(
1280        custom_rewriter_config)
1281    return tf_optimizer.OptimizeGraph(
1282        grappler_session_config, meta_graph_def, graph_id=b"tf_graph")
1283
1284  def _for_each_trt_node(self, graph_def, fn):
1285    """Helper method to manipulate all TRTEngineOps in a GraphDef."""
1286    for node in graph_def.node:
1287      if node.op == _TRT_ENGINE_OP_NAME:
1288        fn(node)
1289    for func in graph_def.library.function:
1290      for node in func.node_def:
1291        if node.op == _TRT_ENGINE_OP_NAME:
1292          fn(node)
1293
1294  def _execute_calibration(self, calibration_input_fn):
1295    """Run INT8 calibration with the provided input generator function."""
1296    for inp in calibration_input_fn():
1297      args, kwargs = _convert_to_tensor(inp)
1298      self._converted_func(*args, **kwargs)
1299
1300    self._for_each_trt_node(self._converted_graph_def, _save_calibration_table)
1301
1302    # Rebuild the function since calibration has changed the graph.
1303    self._converted_func = _construct_function_from_graph_def(
1304        self._converted_func, self._converted_graph_def)
1305    self._calibrated = True
1306
1307  # TODO(laigd): provide a utility function to optimize a ConcreteFunction and
1308  # use it here (b/124792963).
1309  def convert(self, calibration_input_fn=None):
1310    """Convert the input SavedModel in 2.0 format.
1311
1312    Args:
1313      calibration_input_fn: a generator function that yields input data as a
1314        list or tuple or dict, which will be used to execute the converted
1315        signature for calibration. All the returned input data should have the
1316        same shape. Example: `def input_fn(): yield input1, input2, input3`
1317
1318        If dynamic_shape_mode==False (or if the graph has static input shapes),
1319        then we run calibration and build the calibrated engine during
1320        conversion.
1321
1322        If dynamic_shape_mode==True (and the graph has any unknown input
1323        shape), then the reference to calibration_input_fn is stored, and the
1324        calibration is actually performed when we build the engine (see
1325        build()).
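
        For illustration, a minimal calibration function for a model with a
        single input of shape [8, 224, 224, 3] (a placeholder shape) could be:

        `def calibration_input_fn():
             for _ in range(3):
               yield (np.zeros((8, 224, 224, 3), dtype=np.float32),)`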
1326
1327    Raises:
1328      ValueError: if the input combination is invalid.
1329
1330    Returns:
1331      The TF-TRT converted Function.
1332    """
1333    assert not self._converted
1334
1335    # Create an empty tensor to query the device requested for graph placement.
1336    device_requested = array_ops.zeros([]).device
1337
1338    if "gpu" not in device_requested.lower():
1339      raise ValueError(f"Specified device is not a GPU: {device_requested}")
1340
1341    if "gpu:0" not in device_requested.lower():
1342      self._device = device_requested
1343      logging.info(f"Placing imported graph from "
1344                   f"`{self._input_saved_model_dir}` on device: {self._device}")
1345
1346    if (self._need_calibration and not calibration_input_fn):
1347      raise ValueError("Should specify calibration_input_fn because INT8 "
1348                       "calibration is needed")
1349    if (not self._need_calibration and calibration_input_fn):
1350      raise ValueError("Should not specify calibration_input_fn because INT8 "
1351                       "calibration is not needed")
1352
1353    self._saved_model = load.load(self._input_saved_model_dir,
1354                                  self._input_saved_model_tags)
1355    func = self._saved_model.signatures[self._input_saved_model_signature_key]
1356    if self.freeze:
1357      frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
1358    else:
1359      inlined_graph_def = _apply_inlining(func)
1360      _annotate_variable_ops(func, inlined_graph_def)
1361      frozen_func = _construct_function_from_graph_def(func, inlined_graph_def)
1362    frozen_graph_def = frozen_func.graph.as_graph_def()
1363
1364    # Clear any prior device assignments
1365    logging.info("Clearing prior device assignments in loaded saved model")
1366    for node in frozen_graph_def.node:
1367      node.device = ""
1368
1369    if self._device is None:
1370      grappler_meta_graph_def = saver.export_meta_graph(
1371          graph_def=frozen_graph_def, graph=frozen_func.graph)
1372    else:
1373      with ops.Graph().as_default() as graph, ops.device(self._device):
1374        importer.import_graph_def(frozen_graph_def, name="")
1375        grappler_meta_graph_def = saver.export_meta_graph(
1376            graph_def=graph.as_graph_def(), graph=graph)
1377
1378    # Add a collection 'train_op' so that Grappler knows the outputs.
1379    fetch_collection = meta_graph_pb2.CollectionDef()
1380    for array in frozen_func.inputs + frozen_func.outputs:
1381      fetch_collection.node_list.value.append(array.name)
1382    grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
1383        fetch_collection)
1384
1385    # Run TRT optimizer in Grappler to convert the graph.
1386    self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
1387    self._converted_func = _construct_function_from_graph_def(
1388        func, self._converted_graph_def, frozen_func)
1389
1390    if self._need_calibration:
1391      # Execute calibration here only if not in dynamic shape mode.
1392      if not self._need_trt_profiles():
1393        self._execute_calibration(calibration_input_fn)
1394      else:
1395        self._calibration_input_fn = calibration_input_fn
1396
1397    self._converted = True
1398
1399    graphviz_path = os.environ.get("TF_TRT_EXPORT_GRAPH_VIZ_PATH", default=None)
1400    if graphviz_path is not None:
1401      try:
1402        trt_utils.draw_graphdef_as_graphviz(
1403            graphdef=self._converted_func.graph.as_graph_def(add_shapes=True),
1404            dot_output_filename=graphviz_path)
1405      except Exception as e:
1406        logging.error("An exception occurred during the export of the graph "
1407                      f"visualization: {e}")
1408
1409    return self._converted_func
1410
1411  def build(self, input_fn):
1412    """Run inference with converted graph in order to build TensorRT engines.
1413
1414    If the conversion requires INT8 calibration, then a reference to the
1415    calibration function was stored during the call to convert(). Calibration
1416    will be performed while we build the TensorRT engines.
1417
1418    Args:
1419      input_fn: a generator function that yields input data as a list or tuple
1420        or dict, which will be used to execute the converted signature to
1421        generate TRT engines. Example:
1422        `def input_fn():
1423             # Let's assume a network with 2 input tensors. We generate 3
1424             # sets of dummy input data:
1425             input_shapes = [[(1, 16), (2, 16)], [(2, 32), (4, 32)],
1426                             [(4, 32), (8, 32)]]
1427             for shapes in input_shapes:
1428               yield [np.zeros(x).astype(np.float32) for x in shapes]`
1429
1430    Raises:
1431      NotImplementedError: if build() has already been called.
1432      RuntimeError: if input_fn is None.
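
    For illustration, once convert() has returned, engines (and optimization
    profiles in dynamic shape mode) can be built offline by reusing an
    input_fn like the one sketched above (the converter name is a
    placeholder):

    `converter.build(input_fn=input_fn)`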
1433    """
1434    if self._build_called_once:
1435      raise NotImplementedError("build() is already called. It is not "
1436                                "supported to call build() more than once.")
1437    if not input_fn:
1438      raise RuntimeError("input_fn is None. Method build() needs input_fn "
1439                         "to be specified in order to build TensorRT engines")
1440    if not self._converted:
1441      raise RuntimeError("Need to call convert() before build()")
1442    if (self._need_calibration and not self._calibrated and
1443        self._calibration_input_fn is None):
1444      raise RuntimeError("Need to provide the calibration_input_fn arg while "
1445                         "calling convert().")
1446
1447    def _set_profile_generation_mode(value, node):
1448      node.attr["_profile_generation_mode"].b = value
1449
1450    if self._need_trt_profiles():
1451      # Enable profile generation.
1452      self._for_each_trt_node(self._converted_graph_def,
1453                              partial(_set_profile_generation_mode, True))
1454      # Profile generation is enabled using the _profile_generation_mode
1455      # attribute of the TRTEngineOps. We need to rebuild the function to
1456      # change this attribute.
1457      func = _construct_function_from_graph_def(self._converted_func,
1458                                                self._converted_graph_def)
1459    else:
1460      func = self._converted_func
1461
1462    first_input = None
1463    # Run inference:
1464    #   Builds TRT engines if self._need_trt_profiles is False.
1465    #   Builds TRT optimization profiles if self._need_trt_profiles is True.
1466    for inp in input_fn():
1467      if not first_input:
1468        first_input = inp
1469      args, kwargs = _convert_to_tensor(inp)
1470      func(*args, **kwargs)
1471
1472    if self._need_trt_profiles():
1473      # Disable profile generation.
1474      self._for_each_trt_node(self._converted_graph_def,
1475                              partial(_set_profile_generation_mode, False))
1476
1477    # Run calibration if required; this step would have been skipped during
1478    # convert() in dynamic shape mode.
1479    if self._need_calibration and not self._calibrated:
1480      self._execute_calibration(self._calibration_input_fn)
1481      # calibration also builds the engine
1482    else:
1483      # Use the first input in explicit batch mode to build TensorRT engines
1484      # after generating all the profiles. The first input is used but any of
1485      # the inputs can be used because the shape of this input does not
1486      # determine the engine and instead the shapes collected in profiles
1487      # determine the engine.
1488      if isinstance(first_input, dict):
1489        self._converted_func(
1490            **{k: ops.convert_to_tensor(v) for k, v in first_input.items()})
1491      else:
1492        self._converted_func(*map(ops.convert_to_tensor, first_input))
1493
1494    self._build_called_once = True
1495
1496  def save(self,
1497           output_saved_model_dir,
1498           save_gpu_specific_engines=True,
1499           options=None):
1500    """Save the converted SavedModel.
1501
1502    Args:
1503      output_saved_model_dir: directory to save the converted SavedModel.
1504      save_gpu_specific_engines: whether to save TRT engines that have been
1505        built. When True, all engines are saved and when False, the engines
1506        are not saved and will be rebuilt at inference time. Setting
1507        save_gpu_specific_engines=False after INT8 calibration allows
1508        inference to run on a different GPU than the one the model was
1509        calibrated and saved on.
1510      options: `tf.saved_model.SaveOptions` object for configuring save options.
1511    Raises:
1512      RuntimeError: if the needed calibration hasn't been done.
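
    For illustration, to calibrate on one GPU but rebuild engines on the
    deployment GPU, the model can be saved without GPU-specific engines (the
    path and converter name are placeholders):

    `converter.save("/tmp/trt_saved_model",
                    save_gpu_specific_engines=False)`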
1513    """
1514    assert self._converted
1515    if self._need_calibration and not self._calibrated:
1516      raise RuntimeError("A model that requires INT8 calibration has to be "
1517                         "built before saving it. Call build() to build and "
1518                         "calibrate the TensorRT engines.")
1519    # Serialize the TRT engines in the cache if any, and create trackable
1520    # resource to track them.
1521    engine_asset_dir = tempfile.mkdtemp()
1522    resource_map = {}
1523
1524    def _serialize_and_track_engine(node):
1525      """Serialize TRT engines in the cache and track them."""
1526      # Don't dump the same cache twice.
1527      canonical_engine_name = _get_canonical_engine_name(node.name)
1528      if canonical_engine_name in resource_map:
1529        return
1530
1531      filename = os.path.join(engine_asset_dir,
1532                              "trt-serialized-engine." + canonical_engine_name)
1533
1534      try:
1535        gen_trt_ops.serialize_trt_resource(
1536            resource_name=canonical_engine_name,
1537            filename=filename,
1538            delete_resource=True,
1539            save_gpu_specific_engines=save_gpu_specific_engines)
1540      except errors.NotFoundError:
1541        logging.info(
1542            "Could not find %s in TF-TRT cache. "
1543            "This can happen if build() is not called, "
1544            "which means TensorRT engines will be built "
1545            "and cached at runtime.", canonical_engine_name)
1546        return
1547
1548      # TODO(laigd): add an option for the user to choose the device.
1549      resource_map[canonical_engine_name] = _TRTEngineResource(
1550          canonical_engine_name, filename,
1551          self._conversion_params.maximum_cached_engines)
1552
1553    self._for_each_trt_node(self._converted_graph_def,
1554                            _serialize_and_track_engine)
1555    self._saved_model.trt_engine_resources = resource_map
1556
1557    # Rewrite the signature map using the optimized ConcreteFunction.
1558    signatures = {
1559        key: value for key, value in self._saved_model.signatures.items()
1560    }
1561
1562    # Set allow_build_at_runtime=False if asked by user.
1563    #
1564    # This attribute is set here because build() needs it to be True in order to
1565    # build engines.
1566    if not self._conversion_params.allow_build_at_runtime:
1567
1568      def _reset_allow_build_at_runtime(node):
1569        node.attr["_allow_build_at_runtime"].b = False
1570
1571      self._for_each_trt_node(self._converted_graph_def,
1572                              _reset_allow_build_at_runtime)
1573      # Rebuild the function since a node attribute changed above
1574      reset_converted_func = wrap_function.function_from_graph_def(
1575          self._converted_graph_def,
1576          [tensor.name for tensor in self._converted_func.inputs],
1577          [tensor.name for tensor in self._converted_func.outputs])
1578      reset_converted_func.graph.structured_outputs = nest.pack_sequence_as(
1579          self._converted_func.graph.structured_outputs,
1580          reset_converted_func.graph.structured_outputs)
1581      reset_converted_func.graph.structured_input_signature = (
1582          self._converted_func.structured_input_signature)
1583      self._converted_func = reset_converted_func
1584
1585    signatures[self._input_saved_model_signature_key] = self._converted_func
1586    save.save(
1587        self._saved_model, output_saved_model_dir, signatures, options=options)
1588
1589  def summary(self, line_length=160, detailed=True, print_fn=None):
1590    """This method describes the results of the conversion by TF-TRT.
1591
1592    It includes information such as the name of each engine, the number of
1593    nodes per engine, the input and output dtypes, and the input and output
1594    shapes of each TRTEngineOp.
1595
1596    Args:
1597      line_length: Default line length when printing on the console. Must be
1598        at least 160 characters.
1599      detailed: Whether or not to show the nodes inside each TRTEngineOp.
1600      print_fn: Print function to use. Defaults to `print`. It will be called on
1601        each line of the summary. You can set it to a custom function in order
1602        to capture the string summary.
1603
1604    Raises:
1605      RuntimeError: if the graph is not converted.
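
    For illustration, the summary can be captured rather than printed by
    passing a custom print_fn (the names below are placeholders):

    `lines = []
     converter.summary(print_fn=lines.append)`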
1606    """
1607    if not self._converted:
1608      raise RuntimeError(
1609          f"Impossible to call `{self.__class__.__name__}.summary()` before "
1610          f"calling `{self.__class__.__name__}.convert()`.")
1611
1612    if line_length < 160:
1613      raise ValueError(f"Invalid `line_length` value has been received: "
1614                       f"{line_length}. Minimum: 160.")
1615
1616    if print_fn is None:
1617      print_fn = print
1618
1619    # positions are percentage of `line_length`. positions[i]+1 is the starting
1620    # position for (i+1)th field. We also make sure that the last char printed
1621    # for each field is a space.
1622    columns = [
1623        # (column name, column size in % of line)
1624        ("TRTEngineOP Name", .20),  # 20%
1625        ("Device", .09),  # 29%
1626        ("# Nodes", .05),  # 34%
1627        ("# Inputs", .09),  # 43%
1628        ("# Outputs", .09),  # 52%
1629        ("Input DTypes", .12),  # 64%
1630        ("Output Dtypes", .12),  # 76%
1631        ("Input Shapes", .12),  # 88%
1632        ("Output Shapes", .12)  # 100%
1633    ]
1634
1635    positions = [int(line_length * p) for _, p in columns]
1636    positions = np.cumsum(positions).tolist()
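    # For example, with the default line_length=160 the column widths above are
    # [32, 14, 8, 14, 14, 19, 19, 19, 19] characters, so the cumulative column
    # boundaries become [32, 46, 54, 68, 82, 101, 120, 139, 158].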
1637    headers = [h for h, _ in columns]
1638
1639    _print_row(headers, positions, print_fn=print_fn)
1640    print_fn("=" * line_length)
1641
1642    n_engines = 0
1643    n_ops_converted = 0
1644    n_ops_not_converted = 0
1645
1646    graphdef = self._converted_func.graph.as_graph_def(add_shapes=True)
1647
1648    trtengineops_dict = dict()
1649    for node in graphdef.node:
1650      if node.op != "TRTEngineOp":
1651        n_ops_not_converted += 1
1652        continue
1653      else:
1654        trtengineops_dict[node.name] = node
1655        n_engines += 1
1656
1657    for name, node in sorted(trtengineops_dict.items()):
1658      node_device = node.device.split("/")[-1]
1659      in_shapes = trt_utils.get_node_io_shapes(node, "input_shapes")
1660      out_shapes = trt_utils.get_node_io_shapes(node, "_output_shapes")
1661      in_dtypes = trt_utils.get_trtengineop_io_dtypes(node, "InT")
1662      out_dtypes = trt_utils.get_trtengineop_io_dtypes(node, "OutT")
1663      in_nodes_count = trt_utils.get_trtengineop_io_nodes_count(node, "InT")
1664      out_nodes_count = trt_utils.get_trtengineop_io_nodes_count(node, "OutT")
1665      node_count, converted_ops_dict = trt_utils.get_trtengineop_node_op_count(
1666          graphdef, name)
1667
1668      n_ops_converted += node_count
1669
1670      if n_engines != 1:
1671        print_fn(f"\n{'-'*40}\n")
1672
1673      _print_row(
1674          fields=[
1675              name, node_device, node_count, in_nodes_count, out_nodes_count,
1676              in_dtypes, out_dtypes, in_shapes, out_shapes
1677          ],
1678          positions=positions,
1679          print_fn=print_fn)
1680
1681      if detailed:
1682        print_fn()
1683        for key, value in sorted(dict(converted_ops_dict).items()):
1684          print_fn(f"\t- {key}: {value}x")
1685
1686    print_fn(f"\n{'='*line_length}")
1687    print_fn(f"[*] Total number of TensorRT engines: {n_engines}")
1688    total_ops = n_ops_not_converted + n_ops_converted
1689    conversion_ratio = n_ops_converted / total_ops * 100
1690    print_fn(f"[*] % of OPs Converted: {conversion_ratio:.2f}% "
1691             f"[{n_ops_converted}/{total_ops}]\n")
1692
1693
1694# TODO(laigd): use TrtConversionParams here.
1695def create_inference_graph(
1696    input_graph_def,
1697    outputs,
1698    max_batch_size=1,
1699    max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
1700    precision_mode=TrtPrecisionMode.FP32,
1701    minimum_segment_size=3,
1702    is_dynamic_op=False,
1703    maximum_cached_engines=1,
1704    input_saved_model_dir=None,
1705    input_saved_model_tags=None,
1706    input_saved_model_signature_key=None,
1707    output_saved_model_dir=None):
1708  """Python wrapper for the TRT transformation.
1709
1710  Args:
1711    input_graph_def: a GraphDef object containing a model to be transformed. If
1712      set to None, the graph will be read from the SavedModel loaded from
1713      input_saved_model_dir.
1714    outputs: list of tensors or node names for the model outputs. Only used when
1715      input_graph_def is not None.
1716    max_batch_size: max size for the input batch.
1717    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
1718      engine can use at execution time. This corresponds to the 'workspaceSize'
1719      parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
1720    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
1721    minimum_segment_size: the minimum number of nodes required for a subgraph to
1722      be replaced by TRTEngineOp.
1723    is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
1724      network and engine at run time.
1725    maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
1726      If the number of cached engines is already at max but none of them can
1727      serve the input, the TRTEngineOp will fall back to run the TF function
1728      based on which the TRTEngineOp is created.
1729    input_saved_model_dir: the directory to load the SavedModel which contains
1730      the input graph to transform. Used only when input_graph_def is None.
1731    input_saved_model_tags: list of tags to load the SavedModel.
1732    input_saved_model_signature_key: the key of the signature to optimize the
1733      graph for.
1734    output_saved_model_dir: if not None, construct a SavedModel using the
1735      returned GraphDef and save it to the specified directory. This option only
1736      works when the input graph is loaded from a SavedModel, i.e. when
1737      input_saved_model_dir is specified and input_graph_def is None.
1738
1739  Returns:
1740    A GraphDef transformed from input_graph_def (or the SavedModel graph def
1741    loaded from input_saved_model_dir, if input_graph_def is not present), where
1742    all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
1743    function is added for each of the subgraphs.
1744
1745    If is_dynamic_op is True, each TRTEngineOp will contain a serialized
1746    subgraph GraphDef, which will be converted to a TRT engine at execution time
1747    and the TRT engine will be cached for future usage. A new TRT engine will be
1748    created each time none of the cached engines match the input shapes. If
1749    it fails to execute the TRT engine or the number of cached engines reaches
1750    maximum_cached_engines, the op will fall back to call the corresponding TF
1751    function.
1752
1753    If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
1754    engine created from the corresponding subgraph. No more engines will be
1755    created on the fly, and the op will fall back to call the corresponding TF
1756    function when it fails to execute the engine.
1757
1758  Raises:
1759    ValueError: if the combination of the parameters is invalid.
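
  For illustration, a minimal call converting a frozen GraphDef might look
  like this (the graph and output tensor names are placeholders):

  `trt_graph_def = create_inference_graph(
       input_graph_def=frozen_graph_def,
       outputs=["logits"],
       precision_mode=TrtPrecisionMode.FP16,
       is_dynamic_op=True,
       maximum_cached_engines=1)`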
1760  """
1761  trt_converter = TrtGraphConverter(
1762      input_saved_model_dir=input_saved_model_dir,
1763      input_saved_model_tags=input_saved_model_tags,
1764      input_saved_model_signature_key=input_saved_model_signature_key,
1765      input_graph_def=input_graph_def,
1766      nodes_denylist=outputs,
1767      max_batch_size=max_batch_size,
1768      max_workspace_size_bytes=max_workspace_size_bytes,
1769      precision_mode=precision_mode,
1770      minimum_segment_size=minimum_segment_size,
1771      is_dynamic_op=is_dynamic_op,
1772      maximum_cached_engines=maximum_cached_engines,
1773      use_calibration=False)
1774  converted_graph_def = trt_converter.convert()
1775  if output_saved_model_dir:
1776    trt_converter.save(output_saved_model_dir)
1777  return converted_graph_def
1778