xref: /aosp_15_r20/external/federated-compute/fcp/demo/plan_utils.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Utilities for working with Plan protos and TensorFlow.
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard WorkerSee the field comments in plan.proto for more information about each operation
17*14675a02SAndroid Build Coastguard Workerand when it should be run.
18*14675a02SAndroid Build Coastguard Worker"""
19*14675a02SAndroid Build Coastguard Worker
20*14675a02SAndroid Build Coastguard Workerimport functools
21*14675a02SAndroid Build Coastguard Workerimport tempfile
22*14675a02SAndroid Build Coastguard Workerfrom typing import Any, Optional
23*14675a02SAndroid Build Coastguard Workerimport uuid
24*14675a02SAndroid Build Coastguard Worker
25*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
26*14675a02SAndroid Build Coastguard Worker
27*14675a02SAndroid Build Coastguard Workerfrom google.protobuf import message
28*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
29*14675a02SAndroid Build Coastguard Workerfrom fcp.tensorflow import serve_slices as serve_slices_registry
30*14675a02SAndroid Build Coastguard Worker
31*14675a02SAndroid Build Coastguard Worker
class Session:
  """A session for performing L2 Plan operations.

  Wraps a `tf.compat.v1.Session` over the plan's server graph. Construction
  restores the server savepoint, runs the server phase's init op, writes the
  initial client checkpoint, and builds any Federated Select slices requested
  while that checkpoint was being written.

  This class only supports loading a single intermediate update.
  """

  def __init__(self, plan: plan_pb2.Plan, checkpoint: bytes):
    """Creates the session and prepares client-facing artifacts.

    Args:
      plan: A Plan with exactly one phase, which must have a `server_phase`.
      checkpoint: Serialized TensorFlow checkpoint restored through
        `plan.server_savepoint`.

    Raises:
      ValueError: If the plan does not have exactly one phase, that phase has
        no `server_phase`, or the server graph cannot be parsed.
    """
    if len(plan.phase) != 1:
      raise ValueError('plan must contain exactly 1 phase.')
    if not plan.phase[0].HasField('server_phase'):
      raise ValueError('plan.phase[0] is missing server_phase.')

    graph_def = tf.compat.v1.GraphDef()
    try:
      unpacked = plan.server_graph_bytes.Unpack(graph_def)
    except message.DecodeError as e:
      raise ValueError('Unable to parse server graph.') from e
    # Any.Unpack signals a type mismatch by returning False rather than by
    # raising. An unset Any (empty type_url) still yields an empty graph.
    if not unpacked and plan.server_graph_bytes.type_url:
      raise ValueError('Unable to parse server graph.')

    graph = tf.Graph()
    with graph.as_default():
      tf.import_graph_def(graph_def, name='')
    self._session = tf.compat.v1.Session(graph=graph)
    self._plan = plan
    try:
      self._restore_state(plan.server_savepoint, checkpoint)
      self._maybe_run(plan.phase[0].server_phase.phase_init_op)

      # (served_at_id, args) for each ServeSlices callback invocation.
      serve_slices_calls = []

      def record_serve_slices_call(*args):
        """Records a ServeSlices call under a fresh random served_at_id."""
        served_at_id = str(uuid.uuid4())
        serve_slices_calls.append((served_at_id, args))
        return served_at_id

      with serve_slices_registry.register_serve_slices_callback(
          record_serve_slices_call
      ) as token:
        self._client_checkpoint = self._save_state(
            plan.phase[0].server_phase.write_client_init, session_token=token
        )
      # Slices are built after the callback is unregistered, using the
      # arguments captured while write_client_init ran.
      self._slices = {
          served_at_id: self._build_slices(*args)
          for served_at_id, args in serve_slices_calls
      }
    except BaseException:
      # If __init__ raises, the caller never receives this object and so
      # cannot close() it; release the TensorFlow session here instead.
      self._session.close()
      raise

  def __enter__(self) -> 'Session':
    """Enters the underlying TensorFlow session's context."""
    self._session.__enter__()
    return self

  def __exit__(self, exc_type, exc_value, tb) -> None:
    """Exits the underlying TensorFlow session's context."""
    self._session.__exit__(exc_type, exc_value, tb)

  def close(self) -> None:
    """Closes the session, releasing resources."""
    self._session.close()

  def _maybe_run(
      self, op: str, feed_dict: Optional[dict[str, Any]] = None
  ) -> None:
    """Runs an operation if its name is non-empty.

    Args:
      op: Name of the op/tensor to run; an empty string means "no op".
      feed_dict: Optional mapping of tensor names to values to feed.
    """
    if op:
      self._session.run(op, feed_dict=feed_dict)

  def _run_and_read_file(
      self,
      fetch: str,
      feed_dict: dict[str, Any],
      filename_tensor_name: str,
  ) -> bytes:
    """Runs `fetch`, feeding a temp file path, and returns that file's bytes.

    Args:
      fetch: Name of the op/tensor to run.
      feed_dict: Values to feed; not modified.
      filename_tensor_name: Tensor that receives the temporary file path.

    Returns:
      The contents of the temporary file after the run.
    """
    with tempfile.NamedTemporaryFile() as tmpfile:
      run_feed = {**feed_dict, filename_tensor_name: tmpfile.name}
      self._session.run(fetch, feed_dict=run_feed)
      # TensorFlow overwrites (via move) the output file, so the data can't be
      # read from the filehandle. Deletion still works properly, though.
      with open(tmpfile.name, 'rb') as f:
        return f.read()

  def _restore_state(self, checkpoint_op: plan_pb2.CheckpointOp,
                     checkpoint: bytes) -> None:
    """Restores state from a TensorFlow checkpoint.

    Runs `before_restore_op`, then (if `saver_def` is set) the saver's restore
    op on a temporary file containing `checkpoint`, then `after_restore_op`.

    Args:
      checkpoint_op: Describes the restore ops and optional saver.
      checkpoint: Serialized checkpoint contents.
    """
    self._maybe_run(checkpoint_op.before_restore_op)
    if checkpoint_op.HasField('saver_def'):
      with tempfile.NamedTemporaryFile('wb') as tmpfile:
        tmpfile.write(checkpoint)
        # Flush so TensorFlow sees the full contents when it opens the path.
        tmpfile.flush()
        self._session.run(
            checkpoint_op.saver_def.restore_op_name,
            {checkpoint_op.saver_def.filename_tensor_name: tmpfile.name})
    self._maybe_run(checkpoint_op.after_restore_op)

  def _save_state(
      self,
      checkpoint_op: plan_pb2.CheckpointOp,
      session_token: Optional[bytes] = None,
  ) -> bytes:
    """Saves state to a TensorFlow checkpoint.

    Args:
      checkpoint_op: Describes the save ops and optional saver.
      session_token: If truthy and the op declares a
        `session_token_tensor_name`, it is fed to the before/after save ops
        and to the save op itself.

    Returns:
      The serialized checkpoint, or b'' if `checkpoint_op` has no `saver_def`.
    """
    before_and_after_inputs = {}
    if session_token and checkpoint_op.session_token_tensor_name:
      before_and_after_inputs[checkpoint_op.session_token_tensor_name] = (
          session_token
      )

    self._maybe_run(
        checkpoint_op.before_save_op, feed_dict=before_and_after_inputs
    )
    result = b''
    if checkpoint_op.HasField('saver_def'):
      result = self._run_and_read_file(
          checkpoint_op.saver_def.save_tensor_name,
          before_and_after_inputs,
          checkpoint_op.saver_def.filename_tensor_name,
      )
    self._maybe_run(
        checkpoint_op.after_save_op, feed_dict=before_and_after_inputs
    )
    return result

  def _build_slices(
      self,
      callback_token: bytes,
      server_val: list[Any],
      max_key: int,
      select_fn_initialize_op: str,
      select_fn_server_val_input_tensor_names: list[str],
      select_fn_key_input_tensor_name: str,
      select_fn_filename_input_tensor_name: str,
      select_fn_target_tensor_name: str,
  ) -> list[bytes]:
    """Builds the slices for a ServeSlices call.

    For each key in [0, max_key], runs the select function with the server
    values and that key, and collects the file it writes.

    Returns:
      One serialized slice per key, in key order.
    """
    del callback_token  # Unused.
    # The server values fed to the select function are identical for every
    # key, so build that part of the feed once.
    server_val_feed = dict(
        zip(select_fn_server_val_input_tensor_names, server_val)
    )
    slices: list[bytes] = []
    for key in range(max_key + 1):
      self._maybe_run(select_fn_initialize_op)
      feed_dict = dict(server_val_feed)
      feed_dict[select_fn_key_input_tensor_name] = key
      slices.append(
          self._run_and_read_file(
              select_fn_target_tensor_name,
              feed_dict,
              select_fn_filename_input_tensor_name,
          )
      )
    return slices

  @functools.cached_property
  def client_plan(self) -> bytes:
    """The serialized ClientOnlyPlan corresponding to the Plan proto."""
    client_only_plan = plan_pb2.ClientOnlyPlan(
        phase=self._plan.phase[0].client_phase,
        graph=self._plan.client_graph_bytes.value,
        tflite_graph=self._plan.client_tflite_graph_bytes)
    if self._plan.HasField('tensorflow_config_proto'):
      client_only_plan.tensorflow_config_proto.CopyFrom(
          self._plan.tensorflow_config_proto)
    return client_only_plan.SerializeToString()

  @property
  def client_checkpoint(self) -> bytes:
    """The initial checkpoint for use by clients."""
    return self._client_checkpoint

  def finalize(self, update: bytes) -> bytes:
    """Loads an intermediate update and returns the final result.

    Args:
      update: Serialized checkpoint restored via the server phase's
        `read_intermediate_update`.

    Returns:
      The server savepoint written after applying the aggregated update
      (b'' if `server_savepoint` has no `saver_def`).
    """
    server_phase = self._plan.phase[0].server_phase
    self._restore_state(server_phase.read_intermediate_update, update)
    self._maybe_run(server_phase.intermediate_aggregate_into_accumulators_op)
    # write_accumulators and metrics are not needed by Federated Program
    # computations because all results are included in the server savepoint.
    # NOTE: "aggregrated" is the field's spelling as accessed here; confirm
    # against plan.proto before changing it.
    self._maybe_run(server_phase.apply_aggregrated_updates_op)
    return self._save_state(self._plan.server_savepoint)

  @property
  def slices(self) -> dict[str, list[bytes]]:
    """The Federated Select slices, keyed by served_at_id."""
    # Copy both the dict and each slice list so callers cannot mutate the
    # session's internal state.
    return {
        served_at_id: list(slice_list)
        for served_at_id, slice_list in self._slices.items()
    }
204