# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with Plan protos and TensorFlow.

See the field comments in plan.proto for more information about each operation
and when it should be run.
"""

import functools
import tempfile
from typing import Any, Optional
import uuid

import tensorflow as tf

from google.protobuf import message
from fcp.protos import plan_pb2
from fcp.tensorflow import serve_slices as serve_slices_registry


class Session:
  """A session for performing L2 Plan operations.

  This class only supports loading a single intermediate update.
  """

  def __init__(self, plan: plan_pb2.Plan, checkpoint: bytes):
    """Loads the server graph, restores state and prepares client outputs.

    Args:
      plan: The Plan to execute. It must contain exactly one phase, and that
        phase must have `server_phase` set.
      checkpoint: The serialized checkpoint that is restored through
        `plan.server_savepoint`.

    Raises:
      ValueError: If the plan does not contain exactly one phase, if that phase
        has no `server_phase`, or if the server graph cannot be parsed.
    """
    if len(plan.phase) != 1:
      raise ValueError('plan must contain exactly 1 phase.')
    if not plan.phase[0].HasField('server_phase'):
      raise ValueError('plan.phase[0] is missing server_phase.')

    graph_def = tf.compat.v1.GraphDef()
    try:
      plan.server_graph_bytes.Unpack(graph_def)
    except message.DecodeError as e:
      raise ValueError('Unable to parse server graph.') from e

    graph = tf.Graph()
    with graph.as_default():
      tf.import_graph_def(graph_def, name='')
    self._session = tf.compat.v1.Session(graph=graph)
    self._plan = plan
    try:
      self._restore_state(plan.server_savepoint, checkpoint)
      self._maybe_run(plan.phase[0].server_phase.phase_init_op)

      # Every invocation of the callback is recorded under a fresh id so the
      # corresponding slices can be built once the client checkpoint has been
      # written.
      serve_slices_calls = []

      def record_serve_slices_call(*args):
        served_at_id = str(uuid.uuid4())
        serve_slices_calls.append((served_at_id, args))
        return served_at_id

      with serve_slices_registry.register_serve_slices_callback(
          record_serve_slices_call
      ) as token:
        self._client_checkpoint = self._save_state(
            plan.phase[0].server_phase.write_client_init, session_token=token
        )
      self._slices = {
          k: self._build_slices(*args) for k, args in serve_slices_calls
      }
    except BaseException:
      # The caller never receives this object when construction fails, so the
      # TensorFlow session must be released here before re-raising.
      self._session.close()
      raise

  def __enter__(self) -> 'Session':
    """Enters the underlying TensorFlow session and returns this object."""
    self._session.__enter__()
    return self

  def __exit__(self, exc_type, exc_value, tb) -> None:
    """Exits the underlying TensorFlow session."""
    self._session.__exit__(exc_type, exc_value, tb)

  def close(self) -> None:
    """Closes the session, releasing resources."""
    self._session.close()

  def _maybe_run(
      self, op: str, feed_dict: Optional[dict[str, Any]] = None
  ) -> None:
    """Runs an operation if it's non-empty.

    Args:
      op: The name of the operation to run; nothing is run when it is empty.
      feed_dict: Optional feeds passed to the session when running `op`.
    """
    if op:
      self._session.run(op, feed_dict=feed_dict)

  def _restore_state(self, checkpoint_op: plan_pb2.CheckpointOp,
                     checkpoint: bytes) -> None:
    """Restores state from a TensorFlow checkpoint.

    Runs `before_restore_op`, then (only when `saver_def` is set) the saver's
    restore op, then `after_restore_op`.

    Args:
      checkpoint_op: The CheckpointOp describing how to restore.
      checkpoint: The serialized checkpoint contents.
    """
    self._maybe_run(checkpoint_op.before_restore_op)
    if checkpoint_op.HasField('saver_def'):
      # The restore op reads from a filename, so the bytes are staged in a
      # temporary file that is flushed before the op runs.
      with tempfile.NamedTemporaryFile('wb') as tmpfile:
        tmpfile.write(checkpoint)
        tmpfile.flush()
        self._session.run(
            checkpoint_op.saver_def.restore_op_name,
            {checkpoint_op.saver_def.filename_tensor_name: tmpfile.name})
    self._maybe_run(checkpoint_op.after_restore_op)

  def _save_state(
      self,
      checkpoint_op: plan_pb2.CheckpointOp,
      session_token: Optional[bytes] = None,
  ) -> bytes:
    """Saves state to a TensorFlow checkpoint.

    Runs `before_save_op`, then (only when `saver_def` is set) the saver's save
    tensor, then `after_save_op`.

    Args:
      checkpoint_op: The CheckpointOp describing how to save.
      session_token: Optional token fed to `session_token_tensor_name` for the
        before/save/after ops when both the token and the tensor name are
        non-empty.

    Returns:
      The serialized checkpoint, or `b''` when `checkpoint_op` has no
      `saver_def`.
    """
    before_and_after_inputs = {}
    if session_token and checkpoint_op.session_token_tensor_name:
      before_and_after_inputs[checkpoint_op.session_token_tensor_name] = (
          session_token
      )

    self._maybe_run(
        checkpoint_op.before_save_op, feed_dict=before_and_after_inputs
    )
    result = b''
    if checkpoint_op.HasField('saver_def'):
      with tempfile.NamedTemporaryFile() as tmpfile:
        save_tensor_inputs = before_and_after_inputs.copy()
        save_tensor_inputs[checkpoint_op.saver_def.filename_tensor_name] = (
            tmpfile.name
        )
        self._session.run(
            checkpoint_op.saver_def.save_tensor_name,
            feed_dict=save_tensor_inputs,
        )
        # TensorFlow overwrites (via move) the output file, so the data can't be
        # read from the filehandle. Deletion still works properly, though.
        with open(tmpfile.name, 'rb') as f:
          result = f.read()
    self._maybe_run(
        checkpoint_op.after_save_op, feed_dict=before_and_after_inputs
    )
    return result

  def _build_slices(
      self,
      callback_token: bytes,
      server_val: list[Any],
      max_key: int,
      select_fn_initialize_op: str,
      select_fn_server_val_input_tensor_names: list[str],
      select_fn_key_input_tensor_name: str,
      select_fn_filename_input_tensor_name: str,
      select_fn_target_tensor_name: str,
  ) -> list[bytes]:
    """Builds the slices for a ServeSlices call.

    Args:
      callback_token: Unused.
      server_val: Values fed, in order, to
        `select_fn_server_val_input_tensor_names`.
      max_key: The largest key to build a slice for; keys `0..max_key`
        (inclusive) are evaluated.
      select_fn_initialize_op: Op run (if non-empty) before each key.
      select_fn_server_val_input_tensor_names: Tensor names receiving
        `server_val`.
      select_fn_key_input_tensor_name: Tensor name receiving the current key.
      select_fn_filename_input_tensor_name: Tensor name receiving the path the
        slice is written to.
      select_fn_target_tensor_name: The target run to produce each slice.

    Returns:
      The contents of the file written for each key, indexed by key.
    """
    del callback_token
    # The server value feeds do not depend on the key, so build them once.
    server_val_feeds = dict(
        zip(select_fn_server_val_input_tensor_names, server_val)
    )
    slices: list[bytes] = []
    for key in range(max_key + 1):
      self._maybe_run(select_fn_initialize_op)
      with tempfile.NamedTemporaryFile() as tmpfile:
        feed_dict = dict(server_val_feeds)
        feed_dict[select_fn_key_input_tensor_name] = key
        feed_dict[select_fn_filename_input_tensor_name] = tmpfile.name
        self._session.run(select_fn_target_tensor_name, feed_dict=feed_dict)
        # TensorFlow overwrites (via move) the output file, so the data can't be
        # read from the filehandle. Deletion still works properly, though.
        with open(tmpfile.name, 'rb') as f:
          slices.append(f.read())
    return slices

  @functools.cached_property
  def client_plan(self) -> bytes:
    """The serialized ClientOnlyPlan corresponding to the Plan proto."""
    client_only_plan = plan_pb2.ClientOnlyPlan(
        phase=self._plan.phase[0].client_phase,
        graph=self._plan.client_graph_bytes.value,
        tflite_graph=self._plan.client_tflite_graph_bytes)
    if self._plan.HasField('tensorflow_config_proto'):
      client_only_plan.tensorflow_config_proto.CopyFrom(
          self._plan.tensorflow_config_proto)
    return client_only_plan.SerializeToString()

  @property
  def client_checkpoint(self) -> bytes:
    """The initial checkpoint for use by clients."""
    return self._client_checkpoint

  def finalize(self, update: bytes) -> bytes:
    """Loads an intermediate update and returns the final result.

    Args:
      update: The serialized intermediate update, restored through the server
        phase's `read_intermediate_update`.

    Returns:
      The checkpoint saved through `plan.server_savepoint` after the update has
      been aggregated and applied.
    """
    server_phase = self._plan.phase[0].server_phase
    self._restore_state(server_phase.read_intermediate_update, update)
    self._maybe_run(server_phase.intermediate_aggregate_into_accumulators_op)
    # write_accumulators and metrics are not needed by Federated Program
    # computations because all results are included in the server savepoint.
    self._maybe_run(server_phase.apply_aggregrated_updates_op)
    return self._save_state(self._plan.server_savepoint)

  @property
  def slices(self) -> dict[str, list[bytes]]:
    """The Federated Select slices, keyed by served_at_id."""
    # Return a copy to prevent mutations.
    return self._slices.copy()