xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/plan_utils.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Utilities related to plan protos."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerfrom typing import TypeVar
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
19*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import tensor_utils
20*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
21*14675a02SAndroid Build Coastguard Worker
22*14675a02SAndroid Build Coastguard Worker
# Type variable constrained to the two plan proto types handled in this module,
# so a function annotated with it returns the same plan type it was given.
_PlanT = TypeVar(
    '_PlanT',
    plan_pb2.Plan,
    plan_pb2.ClientOnlyPlan,
)
24*14675a02SAndroid Build Coastguard Worker
25*14675a02SAndroid Build Coastguard Worker
# TODO(team): Remove in favor of save_from_checkpoint_op.
def write_checkpoint(sess, checkpoint_op, checkpoint_filename):
  """Runs the save tensor named by a CheckpointOp's saver_def, if any.

  Only the saver's save tensor is run; the before/after restore ops are not
  executed. Nothing is run when the saver_def names no save tensor.

  Args:
    sess: Object whose `run(fetches, feed_dict)` method is invoked.
    checkpoint_op: A `plan_pb2.CheckpointOp`.
    checkpoint_filename: Value fed to the saver's filename tensor.

  Raises:
    ValueError: If `checkpoint_op` is not a `plan_pb2.CheckpointOp`.
  """
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')

  saver_def = checkpoint_op.saver_def if checkpoint_op else None
  save_tensor_name = saver_def.save_tensor_name if saver_def else None
  if not save_tensor_name:
    return

  feed_dict = {saver_def.filename_tensor_name: checkpoint_filename}
  sess.run(save_tensor_name, feed_dict)
40*14675a02SAndroid Build Coastguard Worker
41*14675a02SAndroid Build Coastguard Worker
# TODO(team): Remove in favor of restore_from_checkpoint_op.
def read_checkpoint(sess, checkpoint_op, checkpoint_filename):
  """Runs the restore op named by a CheckpointOp's saver_def, if any.

  Only the saver's restore op is run; the before/after restore ops are not
  executed. Nothing is run when the saver_def names no restore op.

  Args:
    sess: Object whose `run(fetches, feed_dict)` method is invoked.
    checkpoint_op: A `plan_pb2.CheckpointOp`.
    checkpoint_filename: Value fed to the saver's filename tensor.

  Raises:
    ValueError: If `checkpoint_op` is not a `plan_pb2.CheckpointOp`.
  """
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')

  saver_def = checkpoint_op.saver_def if checkpoint_op else None
  restore_op_name = saver_def.restore_op_name if saver_def else None
  if not restore_op_name:
    return

  feed_dict = {saver_def.filename_tensor_name: checkpoint_filename}
  sess.run(restore_op_name, feed_dict)
56*14675a02SAndroid Build Coastguard Worker
57*14675a02SAndroid Build Coastguard Worker
def convert_graphdef_to_flatbuffer(
    graph: tf.compat.v1.GraphDef,
    spec: plan_pb2.TensorflowSpec,
    guarantee_all_funcs_one_use: bool = False,
):
  """Converts a GraphDef to a serialized TFLite model FlatBuffer.

  Args:
    graph: The GraphDef handed to the TFLite converter.
    spec: TensorflowSpec supplying the dataset token tensor name, the input
      and output tensor specs, and the target node names.
    guarantee_all_funcs_one_use: Forwarded to the converter's
      `_experimental_guarantee_all_funcs_one_use` flag.

  Returns:
    Whatever `TFLiteConverter.convert()` returns for the configured converter.
  """
  # The dataset token tensor always comes first and is registered with an
  # empty shape list; every declared input follows with its per-dim sizes.
  input_arrays_with_shape = [(spec.dataset_token_tensor_name, [])]
  input_arrays_with_shape.extend(
      (tensor_spec.name, [dim.size for dim in tensor_spec.shape.dim])
      for tensor_spec in spec.input_tensor_specs
  )
  output_arrays = [out_spec.name for out_spec in spec.output_tensor_specs]

  converter = tf.compat.v1.lite.TFLiteConverter(
      graph,
      input_tensors=None,
      output_tensors=None,
      input_arrays_with_shape=input_arrays_with_shape,
      output_arrays=output_arrays,
  )

  # pylint: disable=protected-access
  # Control output node names; needed when the converted graph has no output
  # tensors.
  converter._control_output_arrays = spec.target_node_names
  # Reduces the size of the resulting flatbuffer.
  converter._experimental_unfold_large_splat_constant = True
  # Skipping conversion metadata generation shortens conversion time.
  converter.exclude_conversion_metadata = True
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,
      tf.lite.OpsSet.SELECT_TF_OPS,
  ]
  converter._experimental_allow_all_select_tf_ops = True
  converter._experimental_guarantee_all_funcs_one_use = (
      guarantee_all_funcs_one_use
  )
  # Assert ops must survive conversion: the client code relies on them to
  # verify result correctness.
  converter._experimental_preserve_assert_op = True
  # pylint: enable=protected-access
  converter.experimental_enable_resource_variables = True

  return converter.convert()
101*14675a02SAndroid Build Coastguard Worker
102*14675a02SAndroid Build Coastguard Worker
def generate_and_add_flat_buffer_to_plan(
    plan: _PlanT, forgive_tflite_conversion_failure: bool = True
) -> _PlanT:
  """Generates a TFLite model from the plan's client graph and stores it.

  Note: This method mutates the plan argument.

  Args:
    plan: A plan_pb2.Plan or plan_pb2.ClientOnlyPlan. For a Plan, the graph is
      read from `client_graph_bytes`, converted against the first phase's
      client TensorflowSpec, and stored in `client_tflite_graph_bytes`. For a
      ClientOnlyPlan, the graph is read from `graph`, converted against
      `phase.tensorflow_spec`, and stored in `tflite_graph`.
    forgive_tflite_conversion_failure: If True, a failed TFLite conversion
      raises no exception; the failure is logged as a warning and the plan's
      TFLite field is assigned empty bytes.

  Returns:
    The input plan object. Its TFLite field holds the converted model when
    conversion succeeds, or empty bytes when conversion fails and the failure
    is forgiven.

  Raises:
    RuntimeError: if TFLite conversion fails and
      forgive_tflite_conversion_failure is set to False.
    NotImplementedError: if `plan` is neither a plan_pb2.Plan nor a
      plan_pb2.ClientOnlyPlan.
  """
  # Imported locally so this function carries its own logging dependency.
  import logging  # pylint: disable=g-import-not-at-top

  # Converter error text that indicates a retry with
  # guarantee_all_funcs_one_use=True is worth attempting.
  stateful_partitioned_call_err = (
      "'tf.StatefulPartitionedCall' op is neither a custom op nor a flex op"
  )

  def convert(graph_def, tensorflow_spec, guarantee_all_funcs_one_use=False):
    """Returns the serialized flatbuffer, or b'' on a forgiven failure."""
    try:
      return convert_graphdef_to_flatbuffer(
          graph_def, tensorflow_spec, guarantee_all_funcs_one_use
      )
    except Exception as e:  # pylint: disable=broad-except
      # Retry once with guarantee_all_funcs_one_use enabled when the failure
      # is the known StatefulPartitionedCall conversion error.
      if (
          stateful_partitioned_call_err in str(e)
          and not guarantee_all_funcs_one_use
      ):
        return convert(graph_def, tensorflow_spec, True)
      if forgive_tflite_conversion_failure:
        # Best-effort mode: keep going, but leave a trace of the failure
        # instead of discarding it silently.
        logging.warning(
            'TFLite conversion of the client graph failed; the TFLite model '
            'is left empty: %s',
            e,
        )
        return b''
      raise RuntimeError(
          f'Failure during TFLite conversion of the client graph: {e}'
      ) from e

  if isinstance(plan, plan_pb2.Plan):
    client_graph_def = tensor_utils.import_graph_def_from_any(
        plan.client_graph_bytes
    )
    plan.client_tflite_graph_bytes = convert(
        client_graph_def, plan.phase[0].client_phase.tensorflow_spec
    )
  elif isinstance(plan, plan_pb2.ClientOnlyPlan):
    client_graph_def = tf.compat.v1.GraphDef.FromString(plan.graph)
    plan.tflite_graph = convert(client_graph_def, plan.phase.tensorflow_spec)
  else:
    raise NotImplementedError(f'Unsupported _PlanT {type(plan)}')
  return plan
162