# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to plan protos."""

from typing import TypeVar

import tensorflow as tf
from fcp.artifact_building import tensor_utils
from fcp.protos import plan_pb2


_PlanT = TypeVar('_PlanT', plan_pb2.Plan, plan_pb2.ClientOnlyPlan)


# TODO(team): Remove in favor of save_from_checkpoint_op.
def write_checkpoint(sess, checkpoint_op, checkpoint_filename):
  """Writes a checkpoint by running the save tensor of a CheckpointOp.

  Only `checkpoint_op.saver_def.save_tensor_name` is run; no other ops that
  may be configured on the CheckpointOp (e.g. before/after hooks) are
  executed. If the SaverDef has no save tensor name, nothing is run.

  Args:
    sess: Session used to run the save tensor. Presumably a
      `tf.compat.v1.Session` whose graph contains the SaverDef's tensors;
      confirm at call sites.
    checkpoint_op: A `plan_pb2.CheckpointOp` whose `saver_def` names the save
      tensor and the filename tensor.
    checkpoint_filename: Value fed to the SaverDef's filename tensor.

  Raises:
    ValueError: If `checkpoint_op` is not a `plan_pb2.CheckpointOp`.
  """
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')
  saver_def = checkpoint_op.saver_def
  # A non-empty save tensor name can only be present when saver_def (and hence
  # checkpoint_op) is populated, so this single field check is sufficient.
  if saver_def.save_tensor_name:
    sess.run(
        saver_def.save_tensor_name,
        {saver_def.filename_tensor_name: checkpoint_filename},
    )


# TODO(team): Remove in favor of restore_from_checkpoint_op.
def read_checkpoint(sess, checkpoint_op, checkpoint_filename):
  """Restores a checkpoint by running the restore op of a CheckpointOp.

  Only `checkpoint_op.saver_def.restore_op_name` is run; the CheckpointOp's
  before/after restore ops are not executed. If the SaverDef has no restore op
  name, nothing is run.

  Args:
    sess: Session used to run the restore op. Presumably a
      `tf.compat.v1.Session` whose graph contains the SaverDef's ops; confirm
      at call sites.
    checkpoint_op: A `plan_pb2.CheckpointOp` whose `saver_def` names the
      restore op and the filename tensor.
    checkpoint_filename: Value fed to the SaverDef's filename tensor.

  Raises:
    ValueError: If `checkpoint_op` is not a `plan_pb2.CheckpointOp`.
  """
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')
  saver_def = checkpoint_op.saver_def
  # A non-empty restore op name can only be present when saver_def (and hence
  # checkpoint_op) is populated, so this single field check is sufficient.
  if saver_def.restore_op_name:
    sess.run(
        saver_def.restore_op_name,
        {saver_def.filename_tensor_name: checkpoint_filename},
    )


def convert_graphdef_to_flatbuffer(
    graph: tf.compat.v1.GraphDef,
    spec: plan_pb2.TensorflowSpec,
    guarantee_all_funcs_one_use: bool = False,
) -> bytes:
  """Converts a GraphDef to a serialized TFLite model FlatBuffer.

  Args:
    graph: The GraphDef to convert.
    spec: Describes the model interface. The inputs are the dataset token
      tensor (with an empty shape list) followed by `input_tensor_specs`. The
      outputs are `output_tensor_specs`, and the control outputs are
      `target_node_names`.
    guarantee_all_funcs_one_use: Value for the converter's private
      `_experimental_guarantee_all_funcs_one_use` flag.

  Returns:
    The serialized model produced by `TFLiteConverter.convert()`.
  """

  def create_input(input_tensor):
    # (name, dims) pair in the form expected by `input_arrays_with_shape`.
    return (input_tensor.name, [dim.size for dim in input_tensor.shape.dim])

  inputs = [(spec.dataset_token_tensor_name, [])]
  inputs.extend(create_input(tensor) for tensor in spec.input_tensor_specs)
  converter = tf.compat.v1.lite.TFLiteConverter(
      graph,
      input_tensors=None,
      output_tensors=None,
      input_arrays_with_shape=inputs,
      output_arrays=[item.name for item in spec.output_tensor_specs],
  )

  # pylint: disable=protected-access
  # Sets the control output node names. This is used when converting a tf.Graph
  # with no output tensors.
  converter._control_output_arrays = spec.target_node_names
  # Set this flag to true so that flatbuffer size can be reduced.
  converter._experimental_unfold_large_splat_constant = True
  # Exclude conversion metadata generation to reduce conversion time.
  converter.exclude_conversion_metadata = True
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,
      tf.lite.OpsSet.SELECT_TF_OPS,
  ]
  converter._experimental_allow_all_select_tf_ops = True
  converter._experimental_guarantee_all_funcs_one_use = (
      guarantee_all_funcs_one_use
  )
  # Instructs the TF Lite converter to not eliminate Assert ops, since the
  # client code needs this op to verify result correctness.
  converter._experimental_preserve_assert_op = True
  # pylint: enable=protected-access
  converter.experimental_enable_resource_variables = True
  return converter.convert()


def generate_and_add_flat_buffer_to_plan(
    plan: _PlanT, forgive_tflite_conversion_failure: bool = True
) -> _PlanT:
  """Generates and adds a TFLite model to the specified Plan.

  Note: This method mutates the plan argument.

  For a `plan_pb2.Plan`, `client_graph_bytes` is converted using the
  TensorflowSpec of the first phase (`phase[0]`; the plan is assumed to have at
  least one phase). The result is stored in `client_tflite_graph_bytes`. For a
  `plan_pb2.ClientOnlyPlan`, `graph` is converted using
  `phase.tensorflow_spec`, and the result is stored in `tflite_graph`.

  If conversion fails with an error about 'tf.StatefulPartitionedCall' not
  being a custom or flex op, it is retried once with
  `guarantee_all_funcs_one_use=True`.

  Args:
    plan: A `plan_pb2.Plan` or `plan_pb2.ClientOnlyPlan`.
    forgive_tflite_conversion_failure: If True, a TFLite conversion failure
      does not raise. Instead, the plan's TFLite field is set to empty bytes
      (b''), clearing any value it previously held.

  Returns:
    The input plan. Its TFLite field holds the converted model on success, or
    b'' if conversion failed and the failure was forgiven.

  Raises:
    RuntimeError: If TFLite conversion fails and
      forgive_tflite_conversion_failure is set to False.
    NotImplementedError: If `plan` is neither a `plan_pb2.Plan` nor a
      `plan_pb2.ClientOnlyPlan`.
  """
  # Substring of the converter error that triggers a retry with
  # guarantee_all_funcs_one_use=True.
  stateful_partitioned_call_err = (
      "'tf.StatefulPartitionedCall' op is"
      + ' neither a custom op nor a flex op'
  )

  def convert(graph_def, tensorflow_spec, guarantee_all_funcs_one_use=False):
    """Converts to TFLite bytes, applying the retry/forgive policy above."""
    try:
      return convert_graphdef_to_flatbuffer(
          graph_def, tensorflow_spec, guarantee_all_funcs_one_use
      )
    except Exception as e:  # pylint: disable=broad-except
      if (
          stateful_partitioned_call_err in str(e)
          and not guarantee_all_funcs_one_use
      ):
        # Retry at most once: the nested call passes True, so it cannot take
        # this branch again.
        return convert(graph_def, tensorflow_spec, True)
      if forgive_tflite_conversion_failure:
        return b''
      raise RuntimeError(
          f'Failure during TFLite conversion of the client graph: {e}'
      ) from e

  if isinstance(plan, plan_pb2.Plan):
    client_graph_def = tensor_utils.import_graph_def_from_any(
        plan.client_graph_bytes
    )
    plan.client_tflite_graph_bytes = convert(
        client_graph_def, plan.phase[0].client_phase.tensorflow_spec
    )
  elif isinstance(plan, plan_pb2.ClientOnlyPlan):
    client_graph_def = tf.compat.v1.GraphDef.FromString(plan.graph)
    plan.tflite_graph = convert(client_graph_def, plan.phase.tensorflow_spec)
  else:
    raise NotImplementedError(f'Unsupported _PlanT {type(plan)}')
  return plan