# xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/plan_utils.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to plan protos."""
15
16from typing import TypeVar
17
18import tensorflow as tf
19from fcp.artifact_building import tensor_utils
20from fcp.protos import plan_pb2
21
22
# Type variable ranging over the two plan proto variants this module operates
# on: the full server-side `Plan` and the client-side `ClientOnlyPlan`.
_PlanT = TypeVar('_PlanT', plan_pb2.Plan, plan_pb2.ClientOnlyPlan)
24
25
# TODO(team): Remove in favor of save_from_checkpoint_op.
def write_checkpoint(sess, checkpoint_op, checkpoint_filename):
  """Writes from a CheckpointOp, without executing before/after restore ops.

  Args:
    sess: A `tf.compat.v1.Session` used to run the save tensor.
    checkpoint_op: A `plan_pb2.CheckpointOp` whose `saver_def` names the save
      tensor and the filename placeholder.
    checkpoint_filename: Path of the checkpoint file to write.

  Raises:
    ValueError: If `checkpoint_op` is not a `plan_pb2.CheckpointOp`.
  """
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')
  saver_def = checkpoint_op.saver_def
  # Only run the save when a save tensor is actually configured.
  if checkpoint_op and saver_def and saver_def.save_tensor_name:
    feed = {saver_def.filename_tensor_name: checkpoint_filename}
    sess.run(saver_def.save_tensor_name, feed)
40
41
# TODO(team): Remove in favor of restore_from_checkpoint_op.
def read_checkpoint(sess, checkpoint_op, checkpoint_filename):
  """Reads from a CheckpointOp, without executing before/after restore ops.

  Args:
    sess: A `tf.compat.v1.Session` used to run the restore op.
    checkpoint_op: A `plan_pb2.CheckpointOp` whose `saver_def` names the
      restore op and the filename placeholder.
    checkpoint_filename: Path of the checkpoint file to read.

  Raises:
    ValueError: If `checkpoint_op` is not a `plan_pb2.CheckpointOp`.
  """
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')
  saver_def = checkpoint_op.saver_def
  # Only run the restore when a restore op is actually configured.
  if checkpoint_op and saver_def and saver_def.restore_op_name:
    feed = {saver_def.filename_tensor_name: checkpoint_filename}
    sess.run(saver_def.restore_op_name, feed)
56
57
def convert_graphdef_to_flatbuffer(
    graph: tf.compat.v1.GraphDef,
    spec: plan_pb2.TensorflowSpec,
    guarantee_all_funcs_one_use: bool = False,
):
  """Converts a tf.Graph to a serialized TFLite model FlatBuffer.

  Args:
    graph: The `GraphDef` to convert.
    spec: A `plan_pb2.TensorflowSpec` naming the dataset-token input, the
      input/output tensors, and the target nodes of the graph.
    guarantee_all_funcs_one_use: If True, instructs the converter to duplicate
      functions so that each has a single use site.

  Returns:
    The serialized TFLite model (bytes).
  """
  # The dataset token is a scalar input that precedes the declared inputs.
  input_arrays = [(spec.dataset_token_tensor_name, [])]
  input_arrays.extend(
      (tensor.name, [dim.size for dim in tensor.shape.dim])
      for tensor in spec.input_tensor_specs
  )
  converter = tf.compat.v1.lite.TFLiteConverter(
      graph,
      input_tensors=None,
      output_tensors=None,
      input_arrays_with_shape=input_arrays,
      output_arrays=[out.name for out in spec.output_tensor_specs],
  )

  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,
      tf.lite.OpsSet.SELECT_TF_OPS,
  ]
  # Exclude conversion metadata generation to reduce conversion time.
  converter.exclude_conversion_metadata = True
  converter.experimental_enable_resource_variables = True
  # pylint: disable=protected-access
  # Sets the control output node names. This is used when converting a tf.Graph
  # with no output tensors.
  converter._control_output_arrays = spec.target_node_names
  # Unfolding large splat constants reduces the flatbuffer size.
  converter._experimental_unfold_large_splat_constant = True
  converter._experimental_allow_all_select_tf_ops = True
  converter._experimental_guarantee_all_funcs_one_use = (
      guarantee_all_funcs_one_use
  )
  # Keep Assert ops: the client code relies on them to verify result
  # correctness.
  converter._experimental_preserve_assert_op = True
  # pylint: enable=protected-access
  return converter.convert()
101
102
def generate_and_add_flat_buffer_to_plan(
    plan: _PlanT, forgive_tflite_conversion_failure=True
) -> _PlanT:
  """Generates and adds a TFLite model to the specified Plan.

  Note: This method mutates the plan argument.

  Args:
    plan: An input plan_pb2.Plan object.
    forgive_tflite_conversion_failure: If True, if TFLite conversion fails no
      exception will be raised and the Plan will be returned unmutated.

  Returns:
    The input Plan mutated to include a TFLite model when TFLite conversion
    succeeds, or the Plan without any mutation if TFLite conversion does not
    succeed.

  Raises:
    RuntimeError: if TFLite conversion fails and
      forgive_tflite_conversion_failure is set to False.
  """

  def _convert(graph_def, tensorflow_spec, guarantee_all_funcs_one_use=False):
    # Substring of the converter error that is recoverable by retrying with
    # one-use function duplication enabled.
    retryable_err = (
        "'tf.StatefulPartitionedCall' op is neither a custom op nor a flex op"
    )
    # Pack the TFLite flatbuffer into a BytesValue proto.
    try:
      return convert_graphdef_to_flatbuffer(
          graph_def, tensorflow_spec, guarantee_all_funcs_one_use
      )
    except Exception as e:  # pylint: disable=broad-except
      # Try to handle conversion errors and run converter again.
      if retryable_err in str(e) and not guarantee_all_funcs_one_use:
        return _convert(graph_def, tensorflow_spec, True)
      if forgive_tflite_conversion_failure:
        return b''
      raise RuntimeError(
          f'Failure during TFLite conversion of the client graph: {str(e)}'
      ) from e

  if isinstance(plan, plan_pb2.Plan):
    graph_def = tensor_utils.import_graph_def_from_any(plan.client_graph_bytes)
    spec = plan.phase[0].client_phase.tensorflow_spec
    plan.client_tflite_graph_bytes = _convert(graph_def, spec)
  elif isinstance(plan, plan_pb2.ClientOnlyPlan):
    graph_def = tf.compat.v1.GraphDef.FromString(plan.graph)
    plan.tflite_graph = _convert(graph_def, plan.phase.tensorflow_spec)
  else:
    raise NotImplementedError(f'Unsupported _PlanT {type(plan)}')
  return plan
162