xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/flatbuffer_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for FlatBuffers.
16
17All functions that are commonly used to work with FlatBuffers.
18
19Refer to the tensorflow lite flatbuffer schema here:
20tensorflow/lite/schema/schema.fbs
21
22"""
23
24import copy
25import random
26import re
27
28import flatbuffers
29from tensorflow.lite.python import schema_py_generated as schema_fb
30from tensorflow.python.platform import gfile
31
# Four-byte flatbuffer file identifier stamped into every serialized TFLite
# model produced by convert_object_to_bytearray().
_TFLITE_FILE_IDENTIFIER = b'TFL3'
33
34
def convert_bytearray_to_object(model_bytearray):
  """Parses a serialized tflite model into its object-API representation.

  Args:
    model_bytearray: The serialized model flatbuffer (`bytes` or `bytearray`).

  Returns:
    A `schema_fb.ModelT` populated from the buffer's root `Model` table.
  """
  return schema_fb.ModelT.InitFromObj(
      schema_fb.Model.GetRootAsModel(model_bytearray, 0))
39
40
def read_model(input_tflite_file):
  """Loads a tflite model file and parses it into a python object.

  Args:
    input_tflite_file: Full path name to the input tflite file.

  Raises:
    RuntimeError: If no file exists at input_tflite_file.
    IOError: If input_tflite_file exists but cannot be opened.

  Returns:
    A python object (`schema_fb.ModelT`) corresponding to the input tflite
    file.
  """
  if gfile.Exists(input_tflite_file):
    with gfile.GFile(input_tflite_file, 'rb') as model_file:
      raw_model = bytearray(model_file.read())
    return convert_bytearray_to_object(raw_model)
  raise RuntimeError('Input file not found at %r\n' % input_tflite_file)
59
60
def read_model_with_mutable_tensors(input_tflite_file):
  """Reads a tflite model as a python object whose tensors can be modified.

  Behaves like read_model(), except that the returned object is a deep copy of
  the parsed model, which gives it mutable tensors (read_model() returns an
  object with immutable tensors).

  Args:
    input_tflite_file: Full path name to the input tflite file.

  Raises:
    RuntimeError: If no file exists at input_tflite_file.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A mutable python object corresponding to the input tflite file.
  """
  parsed_model = read_model(input_tflite_file)
  return copy.deepcopy(parsed_model)
78
79
def convert_object_to_bytearray(model_object):
  """Serializes a tflite model object into an immutable `bytes` value.

  Args:
    model_object: A tflite model as a python object (e.g. `schema_fb.ModelT`).

  Returns:
    The finished model flatbuffer, tagged with the `TFL3` file identifier, as
    `bytes`.
  """
  # 1024 is only the builder's starting capacity; it grows on demand.
  builder = flatbuffers.Builder(1024)
  root_offset = model_object.Pack(builder)
  builder.Finish(root_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())
88
89
def write_model(model_object, output_tflite_file):
  """Serializes a tflite model object and writes it to a file.

  Args:
    model_object: A tflite model as a python object.
    output_tflite_file: Full path name to the output tflite file.

  Raises:
    IOError: If output_tflite_file path is invalid or cannot be opened.
  """
  # The model is fully serialized before the output file is opened.
  serialized_model = convert_object_to_bytearray(model_object)
  with gfile.GFile(output_tflite_file, 'wb') as output_file:
    output_file.write(serialized_model)
103
104
def strip_strings(model):
  """Strips all nonessential strings from the model to reduce model size.

  We remove the following strings:
  (find strings by searching ":string" in the tensorflow lite flatbuffer schema)
  1. Model description
  2. SubGraph name
  3. Tensor names
  We retain OperatorCode custom_code and Metadata name.

  The model's signature defs are also removed entirely, since without the
  names stripped above they are useless.

  Args:
    model: The model from which to remove nonessential strings. Modified in
      place.
  """
  model.description = None
  # Vector fields may be None when the flatbuffer omits them; treat a missing
  # vector as empty instead of failing with a TypeError.
  for subgraph in model.subgraphs or []:
    subgraph.name = None
    for tensor in subgraph.tensors or []:
      tensor.name = None
  # We clear all signature_def structure, since without names it is useless.
  model.signatureDefs = None
126
127
def randomize_weights(model, random_seed=0, buffers_to_skip=None):
  """Randomize weights in a model.

  Every byte of every buffer except buffer 0 and the skipped ones is
  overwritten in place with a value drawn uniformly from [0, 255].

  Args:
    model: The model in which to randomize weights. Modified in place.
    random_seed: The input to the random number generator (default value is
      0). A given seed always produces the same weights.
    buffers_to_skip: An iterable (e.g. a list) of buffer indices to skip. The
      weights in these buffers are left unmodified.
  """
  # Use a private generator. It yields exactly the sequence that seeding the
  # global `random` module would, without clobbering the caller's global RNG
  # state as a side effect.
  rng = random.Random(random_seed)

  # Build the skip set once so each membership test is O(1), not a list scan.
  skip = (frozenset(buffers_to_skip)
          if buffers_to_skip is not None else frozenset())

  # Parse model buffers which store the model weights. A missing buffers
  # vector is treated as empty.
  buffers = model.buffers or []
  # Ignore index 0 as it's always None (TFLite reserves buffer 0 as an empty
  # sentinel).
  for i in range(1, len(buffers)):
    if i in skip:
      continue
    buffer_i_data = buffers[i].data
    if buffer_i_data is None:
      continue

    # Raw data buffers are of type ubyte (or uint8) whose values lie in the
    # range [0, 255]. Those ubytes (or uint8s) are the underlying
    # representation of each datatype. For example, a bias tensor of type
    # int32 appears as a buffer 4 times its length of type ubyte (or uint8).
    # TODO(b/152324470): This does not work for float as randomized weights may
    # end up as denormalized or NaN/Inf floating point numbers.
    # len() (rather than numpy's `.size`) works for both numpy arrays and plain
    # lists of ints.
    for j in range(len(buffer_i_data)):
      buffer_i_data[j] = rng.randint(0, 255)
159
160
def rename_custom_ops(model, map_custom_op_renames):
  """Rename custom ops so they use the same naming style as builtin ops.

  Operator codes whose custom code is empty or absent, or whose name has no
  entry in the mapping, are left untouched. Names are decoded from, and
  re-encoded to, ASCII bytes.

  Args:
    model: The input tflite model. Modified in place.
    map_custom_op_renames: A mapping from old to new custom op names (`str`
      keys and values).
  """
  for operator_code in model.operatorCodes:
    raw_name = operator_code.customCode
    if not raw_name:
      continue  # No custom code to rename.
    old_name = raw_name.decode('ascii')
    if old_name not in map_custom_op_renames:
      continue
    operator_code.customCode = map_custom_op_renames[old_name].encode('ascii')
173
174
def xxd_output_to_bytes(input_cc_file):
  """Converts xxd output C++ source file to bytes (immutable).

  Only lines that begin (after optional non-word characters such as
  whitespace) with a `0x..` hex literal are parsed. Every other line, such as
  the array declaration, the closing brace and the length variable, is
  ignored.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd.

  Raises:
    IOError: If input_cc_file cannot be opened (including when it does not
      exist).
    ValueError: If a matched line contains a token that is not a valid hex
      literal, or a value outside [0, 255].

  Returns:
    A bytes object corresponding to the input cc file array.
  """
  # Match hex values in the string with comma as separator
  pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')

  model_bytearray = bytearray()

  # The file is C source and only its ASCII hex literals matter, so undecodable
  # bytes elsewhere (e.g. in comments) are replaced rather than being fatal.
  with open(input_cc_file, encoding='utf-8', errors='replace') as file_handle:
    for line in file_handle:
      values_match = pattern.match(line)
      if values_match is None:
        continue

      # Match in the parentheses (hex array only), e.g.
      # "0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,"
      list_text = values_match.group(1)

      # Split on commas and drop blank tokens. Stripping first matters: the
      # capture may end in ", ", which would otherwise leave a whitespace-only
      # token that int() rejects.
      values_text = (token.strip() for token in list_text.split(','))
      model_bytearray.extend(
          int(token, base=16) for token in values_text if token)

  return bytes(model_bytearray)
212
213
def xxd_output_to_object(input_cc_file):
  """Parses an xxd-dumped C++ source file into a tflite model object.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd.

  Raises:
    IOError: If input_cc_file cannot be opened.
    ValueError: If the hex array in input_cc_file is malformed.

  Returns:
    A python object corresponding to the tflite model embedded in the file.
  """
  return convert_bytearray_to_object(xxd_output_to_bytes(input_cc_file))
229