# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for FlatBuffers.

Commonly used helpers for reading, writing and transforming TFLite models
stored as FlatBuffers.

Refer to the tensorflow lite flatbuffer schema here:
tensorflow/lite/schema/schema.fbs

"""

import copy
import random
import re

import flatbuffers
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.python.platform import gfile

_TFLITE_FILE_IDENTIFIER = b'TFL3'


def convert_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray to an object for parsing.

  Args:
    model_bytearray: The serialized tflite model (bytes or bytearray).

  Returns:
    A `schema_fb.ModelT` object built from the serialized model.
  """
  model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  return schema_fb.ModelT.InitFromObj(model_object)


def read_model(input_tflite_file):
  """Reads a tflite model as a python object.

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If input_tflite_file path is invalid.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A python object corresponding to the input tflite file.
  """
  if not gfile.Exists(input_tflite_file):
    raise RuntimeError('Input file not found at %r\n' % input_tflite_file)
  with gfile.GFile(input_tflite_file, 'rb') as input_file_handle:
    model_bytearray = bytearray(input_file_handle.read())
  return convert_bytearray_to_object(model_bytearray)


def read_model_with_mutable_tensors(input_tflite_file):
  """Reads a tflite model as a python object with mutable tensors.

  Similar to read_model() with the addition that the returned object has
  mutable tensors (read_model() returns an object with immutable tensors).

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If input_tflite_file path is invalid.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A mutable python object corresponding to the input tflite file.
  """
  return copy.deepcopy(read_model(input_tflite_file))


def convert_object_to_bytearray(model_object):
  """Converts a tflite model from an object to an immutable bytes object.

  Args:
    model_object: A tflite model as a python object (e.g. `ModelT`).

  Returns:
    The serialized model as `bytes`, tagged with the TFL3 file identifier.
  """
  # Initial size of the buffer, which will grow automatically if needed
  builder = flatbuffers.Builder(1024)
  model_offset = model_object.Pack(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  model_bytearray = bytes(builder.Output())
  return model_bytearray


def write_model(model_object, output_tflite_file):
  """Writes the tflite model, a python object, into the output file.

  Args:
    model_object: A tflite model as a python object
    output_tflite_file: Full path name to the output tflite file.

  Raises:
    IOError: If output_tflite_file path is invalid or cannot be opened.
  """
  model_bytearray = convert_object_to_bytearray(model_object)
  with gfile.GFile(output_tflite_file, 'wb') as output_file_handle:
    output_file_handle.write(model_bytearray)


def strip_strings(model):
  """Strips all nonessential strings from the model to reduce model size.

  We remove the following strings:
  (find strings by searching ":string" in the tensorflow lite flatbuffer schema)
  1. Model description
  2. SubGraph name
  3. Tensor names
  We retain OperatorCode custom_code and Metadata name.

  Args:
    model: The model from which to remove nonessential strings. It is
      modified in place.
  """
  model.description = None
  # Optional vector fields are None when absent from the flatbuffer, so treat
  # a missing list as empty rather than failing on iteration.
  for subgraph in model.subgraphs or []:
    subgraph.name = None
    for tensor in subgraph.tensors or []:
      tensor.name = None
  # We clear all signature_def structure, since without names it is useless.
  model.signatureDefs = None


def randomize_weights(model, random_seed=0, buffers_to_skip=None):
  """Randomize weights in a model.

  Every byte of every buffer except buffer 0 and those listed in
  `buffers_to_skip` is overwritten in place with a pseudo-random value in
  [0, 255]. The output is reproducible for a given `random_seed`.

  Args:
    model: The model in which to randomize weights.
    random_seed: The input to the random number generator (default value is 0).
    buffers_to_skip: The list of buffer indices to skip. The weights in these
      buffers are left unmodified.
  """
  # A dedicated generator yields the same sequence as seeding the global
  # `random` module with the same value, without clobbering the caller's
  # global RNG state.
  rng = random.Random(random_seed)

  # Parse model buffers which store the model weights
  buffers = model.buffers or []
  skip = (frozenset(buffers_to_skip)
          if buffers_to_skip is not None else frozenset())

  for i in range(1, len(buffers)):  # ignore index 0 as it's always None
    if i in skip:
      continue
    buffer_i_data = buffers[i].data
    if buffer_i_data is None:
      continue

    # Raw data buffers are of type ubyte (or uint8) whose values lie in the
    # range [0, 255]. Those ubytes (or uint8s) are the underlying
    # representation of each datatype. For example, a bias tensor of type
    # int32 appears as a buffer 4 times its length of type ubyte (or uint8).
    # TODO(b/152324470): This does not work for float as randomized weights may
    # end up as denormalized or NaN/Inf floating point numbers.
    # len() works whether the object API produced a numpy array or a list.
    for j in range(len(buffer_i_data)):
      buffer_i_data[j] = rng.randint(0, 255)


def rename_custom_ops(model, map_custom_op_renames):
  """Rename custom ops so they use the same naming style as builtin ops.

  Args:
    model: The input tflite model. It is modified in place.
    map_custom_op_renames: A mapping from old to new custom op names.
  """
  # operatorCodes is None when the model declares no operator codes.
  for op_code in model.operatorCodes or []:
    if op_code.customCode:
      op_code_str = op_code.customCode.decode('ascii')
      if op_code_str in map_custom_op_renames:
        op_code.customCode = map_custom_op_renames[op_code_str].encode('ascii')


def xxd_output_to_bytes(input_cc_file):
  """Converts xxd output C++ source file to bytes (immutable).

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file does not exist or cannot be opened.
    ValueError: If a matched line contains a token that is not valid hex.

  Returns:
    A bytes object corresponding to the input cc file array.
  """
  # Match hex values in the string with comma as separator
  pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')

  model_bytearray = bytearray()

  with open(input_cc_file) as file_handle:
    for line in file_handle:
      values_match = pattern.match(line)
      if values_match is None:
        continue

      # Match in the parentheses (hex array only)
      list_text = values_match.group(1)

      # Extract hex values (text) from the line
      # e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
      # Tokens are stripped before filtering so that whitespace after the
      # last comma (e.g. "0x1c, 0x00, ") does not yield a blank token that
      # int() would reject.
      tokens = (token.strip() for token in list_text.split(','))
      model_bytearray.extend(int(token, base=16) for token in tokens if token)

  return bytes(model_bytearray)


def xxd_output_to_object(input_cc_file):
  """Converts xxd output C++ source file to object.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file does not exist or cannot be opened.
    ValueError: If a matched line contains a token that is not valid hex.

  Returns:
    A python object corresponding to the input tflite file.
  """
  model_bytes = xxd_output_to_bytes(input_cc_file)
  return convert_bytearray_to_object(model_bytes)