# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with Plan protos and TensorFlow.

See the field comments in plan.proto for more information about each operation
and when it should be run.
"""

import functools
import tempfile
from typing import Any, Optional
import uuid

import tensorflow as tf

from google.protobuf import message
from fcp.protos import plan_pb2
from fcp.tensorflow import serve_slices as serve_slices_registry


class Session:
  """A session for performing L2 Plan operations.

  This class only supports loading a single intermediate update.
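
  Example (a sketch, assuming `plan`, `checkpoint`, and an aggregated
  `intermediate_update` are provided by the caller):

    with Session(plan, checkpoint) as session:
      client_plan = session.client_plan  # serialized ClientOnlyPlan
      initial_checkpoint = session.client_checkpoint
      final_result = session.finalize(intermediate_update)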
36  """
37
38  def __init__(self, plan: plan_pb2.Plan, checkpoint: bytes):
39    if len(plan.phase) != 1:
40      raise ValueError('plan must contain exactly 1 phase.')
41    if not plan.phase[0].HasField('server_phase'):
42      raise ValueError('plan.phase[0] is missing server_phase.')
43
44    graph_def = tf.compat.v1.GraphDef()
45    try:
46      plan.server_graph_bytes.Unpack(graph_def)
47    except message.DecodeError as e:
48      raise ValueError('Unable to parse server graph.') from e
49
50    graph = tf.Graph()
51    with graph.as_default():
52      tf.import_graph_def(graph_def, name='')
53    self._session = tf.compat.v1.Session(graph=graph)
54    self._plan = plan
55    self._restore_state(plan.server_savepoint, checkpoint)
56    self._maybe_run(plan.phase[0].server_phase.phase_init_op)
57
58    serve_slices_calls = []
59
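    # Callback invoked by the ServeSlices op while the client init checkpoint
    # is written below; each call is recorded so the corresponding Federated
    # Select slices can be built once the checkpoint has been saved.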
    def record_serve_slices_call(*args):
      served_at_id = str(uuid.uuid4())
      serve_slices_calls.append((served_at_id, args))
      return served_at_id

    with serve_slices_registry.register_serve_slices_callback(
        record_serve_slices_call
    ) as token:
      self._client_checkpoint = self._save_state(
          plan.phase[0].server_phase.write_client_init, session_token=token
      )
    self._slices = {
        k: self._build_slices(*args) for k, args in serve_slices_calls
    }

  def __enter__(self) -> 'Session':
    self._session.__enter__()
    return self

  def __exit__(self, exc_type, exc_value, tb) -> None:
    self._session.__exit__(exc_type, exc_value, tb)

  def close(self) -> None:
    """Closes the session, releasing resources."""
    self._session.close()

  def _maybe_run(
      self, op: str, feed_dict: Optional[dict[str, Any]] = None
  ) -> None:
    """Runs an operation if it's non-empty."""
    if op:
      self._session.run(op, feed_dict=feed_dict)

  def _restore_state(self, checkpoint_op: plan_pb2.CheckpointOp,
                     checkpoint: bytes) -> None:
    """Restores state from a TensorFlow checkpoint."""
    self._maybe_run(checkpoint_op.before_restore_op)
    if checkpoint_op.HasField('saver_def'):
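      # The restore op reads from a file path rather than from memory, so the
      # raw checkpoint bytes are staged in a temporary file first.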
      with tempfile.NamedTemporaryFile('wb') as tmpfile:
        tmpfile.write(checkpoint)
        tmpfile.flush()
        self._session.run(
            checkpoint_op.saver_def.restore_op_name,
            {checkpoint_op.saver_def.filename_tensor_name: tmpfile.name})
    self._maybe_run(checkpoint_op.after_restore_op)

  def _save_state(
      self,
      checkpoint_op: plan_pb2.CheckpointOp,
      session_token: Optional[bytes] = None,
  ) -> bytes:
    """Saves state to a TensorFlow checkpoint."""
    before_and_after_inputs = {}
    if session_token and checkpoint_op.session_token_tensor_name:
      before_and_after_inputs[checkpoint_op.session_token_tensor_name] = (
          session_token
      )

    self._maybe_run(
        checkpoint_op.before_save_op, feed_dict=before_and_after_inputs
    )
    result = b''
    if checkpoint_op.HasField('saver_def'):
      with tempfile.NamedTemporaryFile() as tmpfile:
        save_tensor_inputs = before_and_after_inputs.copy()
        save_tensor_inputs[checkpoint_op.saver_def.filename_tensor_name] = (
            tmpfile.name
        )
        self._session.run(
            checkpoint_op.saver_def.save_tensor_name,
            feed_dict=save_tensor_inputs,
        )
        # TensorFlow overwrites (via move) the output file, so the data can't
        # be read from the filehandle. Deletion still works properly, though.
        with open(tmpfile.name, 'rb') as f:
          result = f.read()
    self._maybe_run(
        checkpoint_op.after_save_op, feed_dict=before_and_after_inputs
    )
    return result

  def _build_slices(
      self,
      callback_token: bytes,
      server_val: list[Any],
      max_key: int,
      select_fn_initialize_op: str,
      select_fn_server_val_input_tensor_names: list[str],
      select_fn_key_input_tensor_name: str,
      select_fn_filename_input_tensor_name: str,
      select_fn_target_tensor_name: str,
  ):
    """Builds the slices for a ServeSlices call."""
    del callback_token
    slices: list[bytes] = []
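    # Materialize one slice per key in [0, max_key]: run the optional
    # initialize op, feed select_fn the key and a temporary output filename,
    # then read back the file it wrote.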
    for i in range(0, max_key + 1):
      self._maybe_run(select_fn_initialize_op)
      with tempfile.NamedTemporaryFile() as tmpfile:
        feed_dict = dict(
            zip(select_fn_server_val_input_tensor_names, server_val)
        )
        feed_dict[select_fn_key_input_tensor_name] = i
        feed_dict[select_fn_filename_input_tensor_name] = tmpfile.name
        self._session.run(select_fn_target_tensor_name, feed_dict=feed_dict)
        # TensorFlow overwrites (via move) the output file, so the data can't
        # be read from the filehandle. Deletion still works properly, though.
        with open(tmpfile.name, 'rb') as f:
          slices.append(f.read())
    return slices

  @functools.cached_property
  def client_plan(self) -> bytes:
    """The serialized ClientOnlyPlan corresponding to the Plan proto."""
    client_only_plan = plan_pb2.ClientOnlyPlan(
        phase=self._plan.phase[0].client_phase,
        graph=self._plan.client_graph_bytes.value,
        tflite_graph=self._plan.client_tflite_graph_bytes)
    if self._plan.HasField('tensorflow_config_proto'):
      client_only_plan.tensorflow_config_proto.CopyFrom(
          self._plan.tensorflow_config_proto)
    return client_only_plan.SerializeToString()

  @property
  def client_checkpoint(self) -> bytes:
    """The initial checkpoint for use by clients."""
    return self._client_checkpoint

  def finalize(self, update: bytes) -> bytes:
    """Loads an intermediate update and returns the final result."""
    self._restore_state(
        self._plan.phase[0].server_phase.read_intermediate_update, update)
    self._maybe_run(self._plan.phase[0].server_phase
                    .intermediate_aggregate_into_accumulators_op)
    # write_accumulators and metrics are not needed by Federated Program
    # computations because all results are included in the server savepoint.
    self._maybe_run(
        self._plan.phase[0].server_phase.apply_aggregrated_updates_op)
    return self._save_state(self._plan.server_savepoint)

  @property
  def slices(self) -> dict[str, list[bytes]]:
    """The Federated Select slices, keyed by served_at_id."""
    # Return a copy to prevent mutations.
    return self._slices.copy()
204