# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with Plan protos and TensorFlow.

See the field comments in plan.proto for more information about each operation
and when it should be run.
"""

import functools
import tempfile
from typing import Any, Optional
import uuid

import tensorflow as tf

from google.protobuf import message
from fcp.protos import plan_pb2
from fcp.tensorflow import serve_slices as serve_slices_registry


class Session:
  """A session for performing L2 Plan operations.

  This class only supports loading a single intermediate update.
  """

  def __init__(self, plan: plan_pb2.Plan, checkpoint: bytes):
    """Initializes the session from a Plan proto and a server checkpoint.

    Imports the plan's server graph into a new tf.compat.v1.Session, restores
    the initial server state from `checkpoint` via the plan's
    `server_savepoint`, runs the phase-init op (if any), and writes the
    initial client checkpoint. Any Federated Select `ServeSlices` calls made
    while writing the client checkpoint are recorded and their slices built
    eagerly; the results are exposed via the `slices` property.

    Args:
      plan: A Plan proto; must contain exactly one phase, and that phase must
        have a `server_phase`.
      checkpoint: Serialized TensorFlow checkpoint bytes used to restore the
        initial server state.

    Raises:
      ValueError: If the plan does not contain exactly one phase, the phase is
        missing `server_phase`, or `server_graph_bytes` cannot be unpacked
        into a GraphDef.
    """
    if len(plan.phase) != 1:
      raise ValueError('plan must contain exactly 1 phase.')
    if not plan.phase[0].HasField('server_phase'):
      raise ValueError('plan.phase[0] is missing server_phase.')

    graph_def = tf.compat.v1.GraphDef()
    try:
      # server_graph_bytes is an Any proto wrapping a GraphDef.
      plan.server_graph_bytes.Unpack(graph_def)
    except message.DecodeError as e:
      raise ValueError('Unable to parse server graph.') from e

    graph = tf.Graph()
    with graph.as_default():
      # name='' avoids prefixing imported node names, so the plan's tensor /
      # op names can be used verbatim in session.run calls below.
      tf.import_graph_def(graph_def, name='')
    self._session = tf.compat.v1.Session(graph=graph)
    self._plan = plan
    self._restore_state(plan.server_savepoint, checkpoint)
    self._maybe_run(plan.phase[0].server_phase.phase_init_op)

    # ServeSlices calls are recorded (not serviced) while the client
    # checkpoint is being written; the slices are built afterwards, outside
    # the callback, once the server state is final.
    serve_slices_calls = []

    def record_serve_slices_call(*args):
      # Each call gets a fresh id under which its slices will be served.
      served_at_id = str(uuid.uuid4())
      serve_slices_calls.append((served_at_id, args))
      return served_at_id

    # The callback is only registered for the duration of write_client_init;
    # the returned token is fed into the graph so the op can route its
    # ServeSlices calls to this session's callback.
    with serve_slices_registry.register_serve_slices_callback(
        record_serve_slices_call
    ) as token:
      self._client_checkpoint = self._save_state(
          plan.phase[0].server_phase.write_client_init, session_token=token
      )
    self._slices = {
        k: self._build_slices(*args) for k, args in serve_slices_calls
    }

  def __enter__(self) -> 'Session':
    self._session.__enter__()
    return self

  def __exit__(self, exc_type, exc_value, tb) -> None:
    self._session.__exit__(exc_type, exc_value, tb)

  def close(self) -> None:
    """Closes the session, releasing resources."""
    self._session.close()

  def _maybe_run(
      self, op: str, feed_dict: Optional[dict[str, Any]] = None
  ) -> None:
    """Runs an operation if it's non-empty."""
    # Plan ops are optional; an unset proto field is the empty string.
    if op:
      self._session.run(op, feed_dict=feed_dict)

  def _restore_state(self, checkpoint_op: plan_pb2.CheckpointOp,
                     checkpoint: bytes) -> None:
    """Restores state from a TensorFlow checkpoint.

    Runs the optional before/after restore ops and, if a saver_def is present,
    writes `checkpoint` to a temporary file and feeds its name to the saver's
    restore op.

    Args:
      checkpoint_op: The CheckpointOp describing how to restore.
      checkpoint: Serialized checkpoint bytes to restore from.
    """
    self._maybe_run(checkpoint_op.before_restore_op)
    if checkpoint_op.HasField('saver_def'):
      # TF savers read from a filename, so the in-memory bytes must be
      # round-tripped through a temp file.
      with tempfile.NamedTemporaryFile('wb') as tmpfile:
        tmpfile.write(checkpoint)
        tmpfile.flush()
        self._session.run(
            checkpoint_op.saver_def.restore_op_name,
            {checkpoint_op.saver_def.filename_tensor_name: tmpfile.name})
    self._maybe_run(checkpoint_op.after_restore_op)

  def _save_state(
      self,
      checkpoint_op: plan_pb2.CheckpointOp,
      session_token: Optional[bytes] = None,
  ) -> bytes:
    """Saves state to a TensorFlow checkpoint.

    Runs the optional before/after save ops (feeding `session_token` if the op
    declares a session-token tensor) and, if a saver_def is present, saves to
    a temporary file and reads the resulting bytes back.

    Args:
      checkpoint_op: The CheckpointOp describing how to save.
      session_token: Optional token fed to the before/save/after ops via
        `session_token_tensor_name`, when that field is set.

    Returns:
      The serialized checkpoint bytes, or b'' if there is no saver_def.
    """
    before_and_after_inputs = {}
    if session_token and checkpoint_op.session_token_tensor_name:
      before_and_after_inputs[checkpoint_op.session_token_tensor_name] = (
          session_token
      )

    self._maybe_run(
        checkpoint_op.before_save_op, feed_dict=before_and_after_inputs
    )
    result = b''
    if checkpoint_op.HasField('saver_def'):
      with tempfile.NamedTemporaryFile() as tmpfile:
        save_tensor_inputs = before_and_after_inputs.copy()
        save_tensor_inputs[checkpoint_op.saver_def.filename_tensor_name] = (
            tmpfile.name
        )
        self._session.run(
            checkpoint_op.saver_def.save_tensor_name,
            feed_dict=save_tensor_inputs,
        )
        # TensorFlow overwrites (via move) the output file, so the data can't be
        # read from the filehandle. Deletion still works properly, though.
        with open(tmpfile.name, 'rb') as f:
          result = f.read()
    self._maybe_run(
        checkpoint_op.after_save_op, feed_dict=before_and_after_inputs
    )
    return result

  def _build_slices(
      self,
      callback_token: bytes,
      server_val: list[Any],
      max_key: int,
      select_fn_initialize_op: str,
      select_fn_server_val_input_tensor_names: list[str],
      select_fn_key_input_tensor_name: str,
      select_fn_filename_input_tensor_name: str,
      select_fn_target_tensor_name: str,
  ):
    """Builds the slices for a ServeSlices call.

    Runs the select_fn once per key in [0, max_key], writing each slice to a
    temp file and collecting the file contents.

    Args:
      callback_token: Unused; part of the recorded ServeSlices call signature.
      server_val: Values fed to the select_fn's server-value input tensors.
      max_key: The largest slice key; keys 0..max_key inclusive are built.
      select_fn_initialize_op: Optional op run before each select_fn call.
      select_fn_server_val_input_tensor_names: Tensor names matching
        `server_val`, position for position.
      select_fn_key_input_tensor_name: Tensor fed the current slice key.
      select_fn_filename_input_tensor_name: Tensor fed the output filename.
      select_fn_target_tensor_name: The tensor/op to run to produce a slice.

    Returns:
      A list of serialized slices, indexed by key.
    """
    del callback_token
    slices: list[bytes] = []
    for i in range(0, max_key + 1):
      self._maybe_run(select_fn_initialize_op)
      with tempfile.NamedTemporaryFile() as tmpfile:
        feed_dict = dict(
            zip(select_fn_server_val_input_tensor_names, server_val)
        )
        feed_dict[select_fn_key_input_tensor_name] = i
        feed_dict[select_fn_filename_input_tensor_name] = tmpfile.name
        self._session.run(select_fn_target_tensor_name, feed_dict=feed_dict)
        # TensorFlow overwrites (via move) the output file, so the data can't be
        # read from the filehandle. Deletion still works properly, though.
        with open(tmpfile.name, 'rb') as f:
          slices.append(f.read())
    return slices

  @functools.cached_property
  def client_plan(self) -> bytes:
    """The serialized ClientOnlyPlan corresponding to the Plan proto."""
    client_only_plan = plan_pb2.ClientOnlyPlan(
        phase=self._plan.phase[0].client_phase,
        graph=self._plan.client_graph_bytes.value,
        tflite_graph=self._plan.client_tflite_graph_bytes)
    if self._plan.HasField('tensorflow_config_proto'):
      client_only_plan.tensorflow_config_proto.CopyFrom(
          self._plan.tensorflow_config_proto)
    return client_only_plan.SerializeToString()

  @property
  def client_checkpoint(self) -> bytes:
    """The initial checkpoint for use by clients."""
    return self._client_checkpoint

  def finalize(self, update: bytes) -> bytes:
    """Loads an intermediate update and returns the final result.

    Args:
      update: Serialized intermediate-update checkpoint bytes.

    Returns:
      The serialized final server state (server_savepoint checkpoint).
    """
    self._restore_state(
        self._plan.phase[0].server_phase.read_intermediate_update, update)
    self._maybe_run(self._plan.phase[0].server_phase
                    .intermediate_aggregate_into_accumulators_op)
    # write_accumulators and metrics are not needed by Federated Program
    # computations because all results are included in the server savepoint.
    self._maybe_run(
        self._plan.phase[0].server_phase.apply_aggregrated_updates_op)
    return self._save_state(self._plan.server_savepoint)

  @property
  def slices(self) -> dict[str, list[bytes]]:
    """The Federated Select slices, keyed by served_at_id."""
    # Return a copy to prevent mutations.
    return self._slices.copy()