# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for make_zip tests."""
import functools
import io
import itertools
import operator
import os
import re
import string
import tempfile
import traceback
import zipfile

import numpy as np
import tensorflow.compat.v1 as tf

from google.protobuf import text_format
from tensorflow.lite.testing import _pywrap_string_util
from tensorflow.lite.testing import generate_examples_report as report_lib
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.saved_model import signature_constants

# pylint: disable=g-import-not-at-top

# A map from names to functions which make test cases.
_MAKE_TEST_FUNCTIONS_MAP = {}


# A decorator to register the make test functions.
# Usage:
# All the make_*_test functions should be registered. Example:
#   @register_make_test_function()
#   def make_conv_tests(options):
#     # ...
# If a function is decorated by other decorators, the name must be specified
# explicitly. Example:
#   @register_make_test_function(name="make_unidirectional_sequence_lstm_tests")
#   @test_util.enable_control_flow_v2
#   def make_unidirectional_sequence_lstm_tests(options):
#     # ...
def register_make_test_function(name=None):

  def decorate(function, name=name):
    if name is None:
      name = function.__name__
    _MAKE_TEST_FUNCTIONS_MAP[name] = function
    # Return the function so the decorated module-level name stays bound to
    # it instead of becoming None.
    return function

  return decorate


def get_test_function(test_function_name):
  """Get the test function according to the test function name."""

  if test_function_name not in _MAKE_TEST_FUNCTIONS_MAP:
    return None
  return _MAKE_TEST_FUNCTIONS_MAP[test_function_name]

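# Example (illustrative sketch): registering a factory and looking it up
# again by name. `make_identity_tests` is a hypothetical factory name.
#
#   @register_make_test_function()
#   def make_identity_tests(options):
#     ...
#
#   fn = get_test_function("make_identity_tests")
#   assert fn is make_identity_tests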

RANDOM_SEED = 342

MAP_TF_TO_NUMPY_TYPE = {
    tf.float32: np.float32,
    tf.float16: np.float16,
    tf.float64: np.float64,
    tf.complex64: np.complex64,
    tf.complex128: np.complex128,
    tf.int32: np.int32,
    tf.uint32: np.uint32,
    tf.uint8: np.uint8,
    tf.int8: np.int8,
    tf.uint16: np.uint16,
    tf.int16: np.int16,
    tf.int64: np.int64,
    tf.bool: np.bool_,
    tf.string: np.string_,
}


class ExtraConvertOptions:
  """Additional options for conversion, besides input, output, and shape."""

  def __init__(self):
    # Whether to ignore control dependency nodes.
    self.drop_control_dependency = False
    # Allow custom ops in the conversion.
    self.allow_custom_ops = False
    # RNN states that are used to support RNN / LSTM cells.
    self.rnn_states = None
    # Split the LSTM inputs from 5 inputs to 18 inputs for TFLite.
    self.split_tflite_lstm_inputs = None
    # The inference input type passed to TFLiteConvert.
    self.inference_input_type = None
    # The inference output type passed to TFLiteConvert.
    self.inference_output_type = None


def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
  """Build tensor data spreading the range [min_value, max_value).

  Integer dtypes draw from [min_value, max_value] inclusive.
  """

  if dtype in MAP_TF_TO_NUMPY_TYPE:
    dtype = MAP_TF_TO_NUMPY_TYPE[dtype]

  # Note: `dtype` is now a NumPy type, but the membership checks against tf
  # dtypes below still work because a tf.DType compares equal to its NumPy
  # counterpart.
  if dtype in (tf.float32, tf.float16, tf.float64):
    value = (max_value - min_value) * np.random.random_sample(shape) + min_value
  elif dtype in (tf.complex64, tf.complex128):
    real = (max_value - min_value) * np.random.random_sample(shape) + min_value
    imag = (max_value - min_value) * np.random.random_sample(shape) + min_value
    value = real + imag * 1j
  elif dtype in (tf.uint32, tf.int32, tf.uint8, tf.int8, tf.int64, tf.uint16,
                 tf.int16):
    value = np.random.randint(min_value, max_value + 1, shape)
  elif dtype == tf.bool:
    value = np.random.choice([True, False], size=shape)
  elif dtype == np.string_:
    # Not the best strings, but they will do for some basic testing.
    letters = list(string.ascii_uppercase)
    return np.random.choice(letters, size=shape).astype(dtype)
  return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
      dtype)

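# Example (illustrative): drawing random test data with the helpers above and
# below.
#
#   floats = create_tensor_data(tf.float32, [2, 3])           # in [-100, 100)
#   ints = create_tensor_data(tf.int8, [4], min_value=-5, max_value=5)
#   scalar = create_scalar_data(tf.float32)   # 0-d np.ndarray, defined below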

def create_scalar_data(dtype, min_value=-100, max_value=100):
  """Build scalar tensor data in the range [min_value, max_value).

  Integer dtypes draw from [min_value, max_value] inclusive.
  """

  if dtype in MAP_TF_TO_NUMPY_TYPE:
    dtype = MAP_TF_TO_NUMPY_TYPE[dtype]

  if dtype in (tf.float32, tf.float16, tf.float64):
    value = (max_value - min_value) * np.random.random() + min_value
  elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
    value = np.random.randint(min_value, max_value + 1)
  elif dtype == tf.bool:
    value = np.random.choice([True, False])
  elif dtype == np.string_:
    length = np.random.randint(1, 6)
    value = "".join(
        np.random.choice(list(string.ascii_uppercase), size=length))
  return np.array(value, dtype=dtype)


def freeze_graph(session, outputs):
  """Freeze the current graph.

  Args:
    session: TensorFlow session containing the graph.
    outputs: List of output tensors.

  Returns:
    The frozen graph_def.
  """
  return tf_graph_util.convert_variables_to_constants(
      session, session.graph.as_graph_def(), [x.op.name for x in outputs])


def format_result(t):
  """Convert a tensor to a format that can be used in test specs."""
  if t.dtype.kind not in [np.dtype(np.string_).kind, np.dtype(np.object_).kind]:
    # Output 9 digits after the decimal point to ensure enough precision.
    values = ["{:.9f}".format(value) for value in list(t.flatten())]
    return ",".join(values)
  else:
    # SerializeAsHexString returns bytes in PY3, so decode if appropriate.
    return _pywrap_string_util.SerializeAsHexString(t.flatten()).decode("utf-8")


def write_examples(fp, examples):
  """Given a list `examples`, write a text format representation.

  The file format is CSV-like with a simple repeated pattern. We would like
  to use proto here, but we can't yet due to interfacing with the Android
  team using this format.

  Args:
    fp: File-like object to write to.
    examples: List of example dictionaries consisting of keys "inputs" and
      "outputs".
  """

  def write_tensor(fp, name, x):
    """Write a tensor in the file format supported by the TFLite example."""
    fp.write("name,%s\n" % name)
    fp.write("dtype,%s\n" % x.dtype)
    fp.write("shape," + ",".join(map(str, x.shape)) + "\n")
    fp.write("values," + format_result(x) + "\n")

  fp.write("test_cases,%d\n" % len(examples))
  for example in examples:
    fp.write("inputs,%d\n" % len(example["inputs"]))
    for name, value in example["inputs"].items():
      if value is not None:
        write_tensor(fp, name, value)
    fp.write("outputs,%d\n" % len(example["outputs"]))
    for name, value in example["outputs"].items():
      write_tensor(fp, name, value)

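# Illustrative output of write_examples for one example with a single float32
# input named "x" of shape [2] (output tensor elided):
#
#   test_cases,1
#   inputs,1
#   name,x
#   dtype,float32
#   shape,2
#   values,0.100000000,0.200000000
#   outputs,1
#   ...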

class TextFormatWriter:
  """Utility class for writing protobuf-like text messages."""

  def __init__(self, fp, name=None, parent=None):
    self.fp = fp
    self.indent = parent.indent if parent else 0
    self.name = name

  def __enter__(self):
    if self.name:
      self.write(self.name + " {")
      self.indent += 2
    return self

  def __exit__(self, *exc_info):
    if self.name:
      self.indent -= 2
      self.write("}")
    # Do not return True here: that would silently swallow any exception
    # raised inside the `with` block.
    return False

  def write(self, data):
    self.fp.write(" " * self.indent + data + "\n")

  def write_field(self, key, val):
    self.write(key + ": \"" + val + "\"")

  def sub_message(self, name):
    return TextFormatWriter(self.fp, name, self)

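# Example (illustrative): nested sub-messages produce indented, proto-text
# style output.
#
#   buf = io.StringIO()
#   writer = TextFormatWriter(buf)
#   writer.write_field("load_model", "model.bin")
#   with writer.sub_message("invoke") as invoke:
#     invoke.write_field("key", "input_0")
#
#   # buf.getvalue() is now:
#   #   load_model: "model.bin"
#   #   invoke {
#   #     key: "input_0"
#   #   }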

def write_test_cases(fp, model_name, examples):
  """Given a list of `examples`, write a text format representation.

  The file format is protocol-buffer-like, even though we don't use proto due
  to the needs of the Android team.

  Args:
    fp: File-like object to write to.
    model_name: Filename where the model was written to; only its basename is
      recorded in the "load_model" field.
    examples: List of example dictionaries consisting of keys "inputs" and
      "outputs".

  Raises:
    RuntimeError: An example has empty input / output names.
  """

  writer = TextFormatWriter(fp)
  writer.write_field("load_model", os.path.basename(model_name))
  for example in examples:
    inputs = []
    for name in example["inputs"].keys():
      if name:
        inputs.append(name)
    outputs = []
    for name in example["outputs"].keys():
      if name:
        outputs.append(name)
    if not (inputs and outputs):
      raise RuntimeError("Empty input / output names.")

    # Reshape message
    with writer.sub_message("reshape") as reshape:
      for name, value in example["inputs"].items():
        with reshape.sub_message("input") as input_msg:
          input_msg.write_field("key", name)
          input_msg.write_field("value", ",".join(map(str, value.shape)))

    # Invoke message
    with writer.sub_message("invoke") as invoke:
      for name, value in example["inputs"].items():
        with invoke.sub_message("input") as input_msg:
          input_msg.write_field("key", name)
          input_msg.write_field("value", format_result(value))
      # Expectations
      for name, value in example["outputs"].items():
        with invoke.sub_message("output") as output_msg:
          output_msg.write_field("key", name)
          output_msg.write_field("value", format_result(value))
        with invoke.sub_message("output_shape") as output_shape:
          output_shape.write_field("key", name)
          output_shape.write_field("value",
                                   ",".join([str(dim) for dim in value.shape]))

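# Illustrative output of write_test_cases for one example with input "input_0"
# of shape [1, 4] (tensor values and the output messages elided):
#
#   load_model: "model.bin"
#   reshape {
#     input {
#       key: "input_0"
#       value: "1,4"
#     }
#   }
#   invoke {
#     input {
#       key: "input_0"
#       value: "..."
#     }
#     ...
#   }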

def get_input_shapes_map(input_tensors):
  """Gets a map of input names to shapes.

  Args:
    input_tensors: List of input tensor tuples `(name, shape, type)`.

  Returns:
    {string : list of integers}.
  """
  input_arrays = [tensor[0] for tensor in input_tensors]
  input_shapes_list = []

  for _, shape, _ in input_tensors:
    dims = None
    if shape:
      dims = [dim.value for dim in shape.dims]
    input_shapes_list.append(dims)

  input_shapes = {
      name: shape
      for name, shape in zip(input_arrays, input_shapes_list)
      if shape
  }
  return input_shapes

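# Example (illustrative): for input_tensors = [("x", tf.TensorShape([1, 4]),
# tf.float32)], this returns {"x": [1, 4]}; entries whose shape is unknown
# are dropped.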

def _normalize_input_name(input_name):
  """Remove :i suffix from input tensor names."""
  return input_name.split(":")[0]


def _normalize_output_name(output_name):
  """Remove :0 suffix from output tensor names."""
  return output_name.split(":")[0] if output_name.endswith(
      ":0") else output_name

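# For example (illustrative):
#   _normalize_input_name("input_1:0")   -> "input_1"
#   _normalize_output_name("output:0")   -> "output"
#   _normalize_output_name("output:1")   -> "output:1"  (non-:0 suffixes kept)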

def _get_tensor_info(tensors, default_name_prefix, normalize_func):
  """Get the list of normalized tensor names and a name-to-TensorInfo map."""
  tensor_names = []
  tensor_info_map = {}
  for idx, tensor in enumerate(tensors):
    if not tensor.name:
      tensor.name = default_name_prefix + str(idx)
    tensor_info = tf.saved_model.utils.build_tensor_info(tensor)
    tensor_name = normalize_func(tensor.name)
    tensor_info_map[tensor_name] = tensor_info
    tensor_names.append(tensor_name)
  return tensor_names, tensor_info_map


# The maximum number of test cases we allow in a zip file. Too many test
# cases slow down the test data generation process.
_MAX_TESTS_PER_ZIP = 500


def make_zip_of_tests(options,
                      test_parameters,
                      make_graph,
                      make_test_inputs,
                      extra_convert_options=None,
                      use_frozen_graph=False,
                      expected_tf_failures=0):
  """Helper to make a zip file of a bunch of TensorFlow models.

  This does a cartesian product of the dictionary of test_parameters and
  calls make_graph() for each item in the cartesian product set.
  If the graph is built successfully, then make_test_inputs() is called to
  build expected input/output value pairs. The model is then converted to
  tflite, and the examples are serialized with the tflite model into a zip
  file (a model plus its example and test-spec files per item in the
  cartesian product set).

  Args:
    options: An Options instance.
    test_parameters: List of dictionaries mapping parameter names to lists of
      values, e.g. `[{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}]`.
    make_graph: function that takes current parameters and returns a tuple
      `[input1, input2, ...], [output1, output2, ...]`.
    make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
      `output_tensors` and returning a tuple `(input_values, output_values)`.
    extra_convert_options: Additional convert options; a fresh
      ExtraConvertOptions() is used if None.
    use_frozen_graph: Whether or not to freeze the graph before conversion.
    expected_tf_failures: Number of times TensorFlow is expected to fail in
      executing the input graphs. In some cases it is OK for TensorFlow to
      fail because one or more combinations of parameters is invalid.

  Raises:
    RuntimeError: if there are converter errors that can't be ignored.
  """
  if extra_convert_options is None:
    # Use a fresh instance per call; a shared mutable default instance would
    # leak option mutations (e.g. the edgetpu overrides below) across calls.
    extra_convert_options = ExtraConvertOptions()

  zip_path = os.path.join(options.output_path, options.zip_to_output)
  parameter_count = 0
  for parameters in test_parameters:
    parameter_count += functools.reduce(
        operator.mul, [len(values) for values in parameters.values()])

  all_parameter_count = parameter_count
  if options.multi_gen_state:
    all_parameter_count += options.multi_gen_state.parameter_count
  if not options.no_tests_limit and all_parameter_count > _MAX_TESTS_PER_ZIP:
    raise RuntimeError(
        "Too many parameter combinations for generating '%s'.\n"
        "There are at least %d combinations while the upper limit is %d.\n"
        "Having too many combinations will slow down the tests.\n"
        "Please consider splitting the test into multiple functions.\n" %
        (zip_path, all_parameter_count, _MAX_TESTS_PER_ZIP))
  if options.multi_gen_state:
    options.multi_gen_state.parameter_count = all_parameter_count

  # TODO(aselle): Make this allow multiple inputs outputs.
  if options.multi_gen_state:
    archive = options.multi_gen_state.archive
  else:
    archive = zipfile.PyZipFile(zip_path, "w")
  zip_manifest = []
  convert_report = []
  converter_errors = 0

  processed_labels = set()

  if options.make_tf_ptq_tests:
    # For cases where fully_quantize is True, also generate a case with
    # fully_quantize set to False, and mark these cases as suitable for
    # PTQ tests.
    parameter_count = 0
    for parameters in test_parameters:
      if True in parameters.get("fully_quantize", []):
        parameters.update({"fully_quantize": [True, False], "tf_ptq": [True]})
        # TODO(b/199054047): Support 16x8 quantization in TF Quantization.
        parameters.update({"quant_16x8": [False]})
        parameter_count += functools.reduce(
            operator.mul, [len(values) for values in parameters.values()])

  if options.make_edgetpu_tests:
    extra_convert_options.inference_input_type = tf.uint8
    extra_convert_options.inference_output_type = tf.uint8
    # Only count parameters when fully_quantize is True.
    parameter_count = 0
    for parameters in test_parameters:
      if True in parameters.get("fully_quantize",
                                []) and False in parameters.get(
                                    "quant_16x8", [False]):
        parameter_count += functools.reduce(operator.mul, [
            len(values)
            for key, values in parameters.items()
            if key != "fully_quantize" and key != "quant_16x8"
        ])

  label_base_path = zip_path
  if options.multi_gen_state:
    label_base_path = options.multi_gen_state.label_base_path

  i = 1
  for parameters in test_parameters:
    keys = parameters.keys()
    for curr in itertools.product(*parameters.values()):
      label = label_base_path.replace(".zip", "_") + (",".join(
          "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
      if label[0] == "/":
        label = label[1:]

      zip_path_label = label
      if len(os.path.basename(zip_path_label)) > 245:
        zip_path_label = label_base_path.replace(".zip", "_") + str(i)

      i += 1
      if label in processed_labels:
        # Do not populate data for the same label more than once. It will cause
        # errors when unzipping.
        continue
      processed_labels.add(label)

      param_dict = dict(zip(keys, curr))

      if options.make_tf_ptq_tests and not param_dict.get("tf_ptq", False):
        continue

      if options.make_edgetpu_tests and (not param_dict.get(
          "fully_quantize", False) or param_dict.get("quant_16x8", False)):
        continue

      def generate_inputs_outputs(tflite_model_binary,
                                  min_value=0,
                                  max_value=255):
        """Generate input values and output values of the given tflite model.

        Args:
          tflite_model_binary: A serialized flatbuffer as a string.
          min_value: min value for the input tensor.
          max_value: max value for the input tensor.

        Returns:
          (input_values, output_values): Maps of input values and output values
          built.
        """
        interpreter = tf.lite.Interpreter(model_content=tflite_model_binary)
        interpreter.allocate_tensors()

        input_details = interpreter.get_input_details()
        input_values = {}
        for input_detail in input_details:
          input_value = create_tensor_data(
              input_detail["dtype"],
              input_detail["shape"],
              min_value=min_value,
              max_value=max_value)
          interpreter.set_tensor(input_detail["index"], input_value)
          input_values.update(
              {_normalize_input_name(input_detail["name"]): input_value})

        interpreter.invoke()

        output_details = interpreter.get_output_details()
        output_values = {}
        for output_detail in output_details:
          output_values.update({
              _normalize_output_name(output_detail["name"]):
                  interpreter.get_tensor(output_detail["index"])
          })

        return input_values, output_values

      def build_example(label, param_dict_real, zip_path_label):
        """Build the model with parameter values set in param_dict_real.

        Args:
          label: Label of the model
          param_dict_real: Parameter dictionary (arguments to the factories
            make_graph and make_test_inputs)
          zip_path_label: Filename in the zip

        Returns:
          (tflite_model_binary, report) where tflite_model_binary is the
          serialized flatbuffer as a string and report is a dictionary with
          keys `tflite_converter_log` (log of conversion), `tf_log` (log of
          TF graph construction and execution), `tflite_converter` (success
          status of the conversion), and `tf` (success status of the TF run).
        """

        np.random.seed(RANDOM_SEED)
        report = {
            "tflite_converter": report_lib.NOTRUN,
            "tf": report_lib.FAILED
        }

        # Build graph
        report["tf_log"] = ""
        report["tflite_converter_log"] = ""
        tf.reset_default_graph()

        with tf.Graph().as_default():
          with tf.device("/cpu:0"):
            try:
              inputs, outputs = make_graph(param_dict_real)
              inputs = [x for x in inputs if x is not None]
            except (tf.errors.UnimplementedError,
                    tf.errors.InvalidArgumentError, ValueError):
              report["tf_log"] += traceback.format_exc()
              return None, report

          sess = tf.Session()
          try:
            baseline_inputs, baseline_outputs = (
                make_test_inputs(param_dict_real, sess, inputs, outputs))
            baseline_inputs = [x for x in baseline_inputs if x is not None]
            # Converts baseline inputs/outputs to maps. The signature input and
            # output names are set to be the same as the tensor names.
            input_names = [_normalize_input_name(x.name) for x in inputs]
            output_names = [_normalize_output_name(x.name) for x in outputs]
            baseline_input_map = dict(zip(input_names, baseline_inputs))
            baseline_output_map = dict(zip(output_names, baseline_outputs))
          except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
                  ValueError):
            report["tf_log"] += traceback.format_exc()
            return None, report
          report["tflite_converter"] = report_lib.FAILED
          report["tf"] = report_lib.SUCCESS

          # Builds a saved model with the default signature key.
          input_names, tensor_info_inputs = _get_tensor_info(
              inputs, "input_", _normalize_input_name)
          output_tensors, tensor_info_outputs = _get_tensor_info(
              outputs, "output_", _normalize_output_name)
          input_tensors = [
              (name, t.shape, t.dtype) for name, t in zip(input_names, inputs)
          ]

          inference_signature = (
              tf.saved_model.signature_def_utils.build_signature_def(
                  inputs=tensor_info_inputs,
                  outputs=tensor_info_outputs,
                  method_name="op_test"))
          saved_model_dir = tempfile.mkdtemp("op_test")
          saved_model_tags = [tf.saved_model.tag_constants.SERVING]
          signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
          builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
          builder.add_meta_graph_and_variables(
              sess,
              saved_model_tags,
              signature_def_map={
                  signature_key: inference_signature,
              },
              strip_default_attrs=True)
          builder.save(as_text=False)
          # pylint: disable=g-long-ternary
          graph_def = freeze_graph(
              sess,
              tf.global_variables() + inputs +
              outputs) if use_frozen_graph else sess.graph_def

        if "split_tflite_lstm_inputs" in param_dict_real:
          extra_convert_options.split_tflite_lstm_inputs = param_dict_real[
              "split_tflite_lstm_inputs"]
        tflite_model_binary, converter_log = options.tflite_convert_function(
            options,
            saved_model_dir,
            input_tensors,
            output_tensors,
            extra_convert_options=extra_convert_options,
            test_params=param_dict_real)
        report["tflite_converter"] = (
            report_lib.SUCCESS
            if tflite_model_binary is not None else report_lib.FAILED)
        report["tflite_converter_log"] = converter_log

        if options.save_graphdefs:
          zipinfo = zipfile.ZipInfo(zip_path_label + ".pbtxt")
          archive.writestr(zipinfo, text_format.MessageToString(graph_def),
                           zipfile.ZIP_DEFLATED)

        if tflite_model_binary:
          if options.make_edgetpu_tests:
            # Set proper min max values according to input dtype.
            baseline_input_map, baseline_output_map = generate_inputs_outputs(
                tflite_model_binary, min_value=0, max_value=255)
          zipinfo = zipfile.ZipInfo(zip_path_label + ".bin")
          archive.writestr(zipinfo, tflite_model_binary, zipfile.ZIP_DEFLATED)

          example = {
              "inputs": baseline_input_map,
              "outputs": baseline_output_map
          }

          example_fp = io.StringIO()
          write_examples(example_fp, [example])
          zipinfo = zipfile.ZipInfo(zip_path_label + ".inputs")
          archive.writestr(zipinfo, example_fp.getvalue(), zipfile.ZIP_DEFLATED)

          example_fp2 = io.StringIO()
          write_test_cases(example_fp2, zip_path_label + ".bin", [example])
          zipinfo = zipfile.ZipInfo(zip_path_label + "_tests.txt")
          archive.writestr(zipinfo, example_fp2.getvalue(),
                           zipfile.ZIP_DEFLATED)

          zip_manifest_label = zip_path_label + " " + label
          if zip_path_label == label:
            zip_manifest_label = zip_path_label

          zip_manifest.append(zip_manifest_label + "\n")

        return tflite_model_binary, report

      _, report = build_example(label, param_dict, zip_path_label)

      if report["tflite_converter"] == report_lib.FAILED:
        ignore_error = False
        if not options.known_bugs_are_errors:
          for pattern, bug_number in options.known_bugs.items():
            if re.search(pattern, label):
              print("Ignored converter error due to bug %s" % bug_number)
              ignore_error = True
        if not ignore_error:
          converter_errors += 1
          print("-----------------\nconverter error!\n%s\n-----------------\n" %
                report["tflite_converter_log"])

      convert_report.append((param_dict, report))

  if not options.no_conversion_report:
    report_io = io.StringIO()
    report_lib.make_report_table(report_io, zip_path, convert_report)
    if options.multi_gen_state:
      zipinfo = zipfile.ZipInfo("report_" + options.multi_gen_state.test_name +
                                ".html")
      archive.writestr(zipinfo, report_io.getvalue())
    else:
      zipinfo = zipfile.ZipInfo("report.html")
      archive.writestr(zipinfo, report_io.getvalue())

  if options.multi_gen_state:
    options.multi_gen_state.zip_manifest.extend(zip_manifest)
  else:
    zipinfo = zipfile.ZipInfo("manifest.txt")
    archive.writestr(zipinfo, "".join(zip_manifest), zipfile.ZIP_DEFLATED)

  # Log statistics of what succeeded
  total_conversions = len(convert_report)
  tf_success = sum(
      1 for x in convert_report if x[1]["tf"] == report_lib.SUCCESS)
  converter_success = sum(1 for x in convert_report
                          if x[1]["tflite_converter"] == report_lib.SUCCESS)
  percent = 0
  if tf_success > 0:
    percent = float(converter_success) / float(tf_success) * 100.
  tf.logging.info(("Archive %s Considered %d graphs, %d TF evaluated graphs "
                   " and %d converted graphs (%.1f%%)"), zip_path,
                  total_conversions, tf_success, converter_success, percent)

  tf_failures = parameter_count - tf_success

  # Guard against division by zero when no parameter combinations were counted.
  if parameter_count and tf_failures / parameter_count > 0.8:
    raise RuntimeError(("Test for '%s' is not very useful. "
                        "TensorFlow fails in %d percent of the cases.") %
                       (zip_path, int(100 * tf_failures / parameter_count)))

  if tf_failures != expected_tf_failures and not (options.make_edgetpu_tests or
                                                  options.make_tf_ptq_tests):
    raise RuntimeError(("Expected TF to fail %d times while generating '%s', "
                        "but that happened %d times") %
                       (expected_tf_failures, zip_path, tf_failures))

  if not options.ignore_converter_errors and converter_errors > 0:
    raise RuntimeError("Found %d errors while generating models" %
                       converter_errors)
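

# Example (illustrative sketch): a minimal test factory wiring these utilities
# together. The factory name and parameters are hypothetical; real factories
# live in the make_*_tests modules and receive an Options instance from the
# test harness.
#
#   @register_make_test_function()
#   def make_identity_tests(options):
#     test_parameters = [{"dtype": [tf.float32], "input_shape": [[1, 4]]}]
#
#     def build_graph(parameters):
#       input_tensor = tf.placeholder(
#           dtype=parameters["dtype"],
#           name="input",
#           shape=parameters["input_shape"])
#       out = tf.identity(input_tensor)
#       return [input_tensor], [out]
#
#     def build_inputs(parameters, sess, inputs, outputs):
#       values = [
#           create_tensor_data(parameters["dtype"], parameters["input_shape"])
#       ]
#       return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
#
#     make_zip_of_tests(options, test_parameters, build_graph, build_inputs)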
713