# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to plan protos.

Includes helpers that save/restore session state through the SaverDef of a
`CheckpointOp`, and helpers that convert the client TensorFlow graph stored in
a `Plan` or `ClientOnlyPlan` into a TFLite flatbuffer.
"""

from typing import TypeVar

import tensorflow as tf
from fcp.artifact_building import tensor_utils
from fcp.protos import plan_pb2


# Restricted to the two plan proto types handled by
# generate_and_add_flat_buffer_to_plan, so that function's declared return
# type matches the type of the plan passed in.
_PlanT = TypeVar('_PlanT', plan_pb2.Plan, plan_pb2.ClientOnlyPlan)


# TODO(team): Remove in favor of save_from_checkpoint_op.
def write_checkpoint(sess, checkpoint_op, checkpoint_filename):
  """Saves a checkpoint through the SaverDef held by `checkpoint_op`.

  Only the SaverDef's save tensor is run; the CheckpointOp's before/after
  restore ops are NOT executed. When no save tensor name is configured this is
  a no-op.

  Args:
    sess: Session used to run the save tensor (presumably a
      `tf.compat.v1.Session`).
    checkpoint_op: A `plan_pb2.CheckpointOp`.
    checkpoint_filename: Value fed to the SaverDef's filename tensor.

  Raises:
    ValueError: If `checkpoint_op` is not a `plan_pb2.CheckpointOp`.
  """
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')
  if not checkpoint_op:
    return
  saver_def = checkpoint_op.saver_def
  # An empty save_tensor_name means no save tensor is configured.
  if not saver_def or not saver_def.save_tensor_name:
    return
  sess.run(
      saver_def.save_tensor_name,
      {saver_def.filename_tensor_name: checkpoint_filename},
  )


# TODO(team): Remove in favor of restore_from_checkpoint_op.
def read_checkpoint(sess, checkpoint_op, checkpoint_filename):
  """Restores session state through the SaverDef held by `checkpoint_op`.

  Only the SaverDef's restore op is run; the CheckpointOp's before/after
  restore ops are NOT executed. When no restore op name is configured this is
  a no-op.

  Args:
    sess: Session used to run the restore op (presumably a
      `tf.compat.v1.Session`).
    checkpoint_op: A `plan_pb2.CheckpointOp`.
    checkpoint_filename: Value fed to the SaverDef's filename tensor.

  Raises:
    ValueError: If `checkpoint_op` is not a `plan_pb2.CheckpointOp`.
  """
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')
  # An empty restore_op_name means no restore op is configured.
  if (
      checkpoint_op
      and checkpoint_op.saver_def
      and checkpoint_op.saver_def.restore_op_name
  ):
    sess.run(
        checkpoint_op.saver_def.restore_op_name,
        {checkpoint_op.saver_def.filename_tensor_name: checkpoint_filename},
    )


def convert_graphdef_to_flatbuffer(
    graph: tf.compat.v1.GraphDef,
    spec: plan_pb2.TensorflowSpec,
    guarantee_all_funcs_one_use: bool = False,
):
  """Converts a GraphDef to a serialized TFLite model FlatBuffer.

  Args:
    graph: The `tf.compat.v1.GraphDef` to convert.
    spec: The `TensorflowSpec` describing the graph's interface. The dataset
      token tensor is registered as an input with an empty (scalar) shape,
      each of `input_tensor_specs` is registered with its declared dimension
      sizes, `output_tensor_specs` supply the output array names and
      `target_node_names` supply the control output node names.
    guarantee_all_funcs_one_use: Forwarded to the converter's
      `_experimental_guarantee_all_funcs_one_use` setting.

  Returns:
    The value returned by `TFLiteConverter.convert()`: the serialized TFLite
    model.
  """
  # The dataset token input always gets an empty shape; the remaining inputs
  # use the dimension sizes declared in their TensorSpecProto.
  inputs = [(spec.dataset_token_tensor_name, [])]
  inputs.extend(
      (tensor_spec.name, [dim.size for dim in tensor_spec.shape.dim])
      for tensor_spec in spec.input_tensor_specs
  )
  converter = tf.compat.v1.lite.TFLiteConverter(
      graph,
      input_tensors=None,
      output_tensors=None,
      input_arrays_with_shape=inputs,
      output_arrays=[item.name for item in spec.output_tensor_specs],
  )

  # pylint: disable=protected-access
  # Sets the control output node names. This is used when converting a tf.Graph
  # with no output tensors.
  converter._control_output_arrays = spec.target_node_names
  # Set this flag to true so that flatbuffer size can be reduced.
  converter._experimental_unfold_large_splat_constant = True
  # Exclude conversion metadata generation to reduce conversion time.
  converter.exclude_conversion_metadata = True
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,
      tf.lite.OpsSet.SELECT_TF_OPS,
  ]
  converter._experimental_allow_all_select_tf_ops = True
  converter._experimental_guarantee_all_funcs_one_use = (
      guarantee_all_funcs_one_use
  )
  # Instructs the TF Lite converter to not eliminate Assert ops, since the
  # client code needs this op to verify result correctness.
  converter._experimental_preserve_assert_op = True
  # pylint: enable=protected-access
  converter.experimental_enable_resource_variables = True
  return converter.convert()


# Substring of the converter error message that triggers one retry of the
# conversion with guarantee_all_funcs_one_use=True.
_STATEFUL_PARTITIONED_CALL_ERR = (
    "'tf.StatefulPartitionedCall' op is neither a custom op nor a flex op"
)


def generate_and_add_flat_buffer_to_plan(
    plan: _PlanT, forgive_tflite_conversion_failure=True
) -> _PlanT:
  """Generates and adds a TFLite model to the specified plan.

  Note: This method mutates the plan argument in place and also returns it.

  For a `plan_pb2.Plan`, the graph in `client_graph_bytes` is converted using
  `phase[0].client_phase.tensorflow_spec` and the result is stored in
  `client_tflite_graph_bytes`. For a `plan_pb2.ClientOnlyPlan`, `graph` is
  converted using `phase.tensorflow_spec` and the result is stored in
  `tflite_graph`.

  If conversion fails with an error mentioning an unsupported
  'tf.StatefulPartitionedCall' op, it is retried once with
  `guarantee_all_funcs_one_use=True`. Errors raised while parsing the client
  graph itself happen outside the conversion step and are not covered by
  `forgive_tflite_conversion_failure`.

  Args:
    plan: A `plan_pb2.Plan` or `plan_pb2.ClientOnlyPlan`.
    forgive_tflite_conversion_failure: If True and TFLite conversion fails, no
      exception is raised; the plan's TFLite field is instead set to empty
      bytes, overwriting any value it previously held.

  Returns:
    The same plan object, with its TFLite field set to the converted model, or
    to empty bytes if conversion failed and the failure was forgiven.

  Raises:
    RuntimeError: If TFLite conversion fails and
      `forgive_tflite_conversion_failure` is False.
    NotImplementedError: If `plan` is neither a `Plan` nor a `ClientOnlyPlan`.
  """

  def convert(graph_def, tensorflow_spec, guarantee_all_funcs_one_use=False):
    """Returns TFLite bytes, or b'' on a forgiven failure; else raises."""
    try:
      return convert_graphdef_to_flatbuffer(
          graph_def, tensorflow_spec, guarantee_all_funcs_one_use
      )
    except Exception as e:  # pylint: disable=broad-except
      # For this specific converter error, retry exactly once with
      # guarantee_all_funcs_one_use=True. Any other error, or a repeat
      # failure, is either forgiven or re-raised below.
      if (
          _STATEFUL_PARTITIONED_CALL_ERR in str(e)
          and not guarantee_all_funcs_one_use
      ):
        return convert(graph_def, tensorflow_spec, True)
      if forgive_tflite_conversion_failure:
        return b''
      raise RuntimeError(
          f'Failure during TFLite conversion of the client graph: {e}'
      ) from e

  if isinstance(plan, plan_pb2.Plan):
    client_graph_def = tensor_utils.import_graph_def_from_any(
        plan.client_graph_bytes
    )
    plan.client_tflite_graph_bytes = convert(
        client_graph_def, plan.phase[0].client_phase.tensorflow_spec
    )
  elif isinstance(plan, plan_pb2.ClientOnlyPlan):
    client_graph_def = tf.compat.v1.GraphDef.FromString(plan.graph)
    plan.tflite_graph = convert(client_graph_def, plan.phase.tensorflow_spec)
  else:
    raise NotImplementedError(f'Unsupported _PlanT {type(plan)}')
  return plan