# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for plan_utils."""

import functools
import tempfile
from typing import Any, Optional

from absl.testing import absltest
import tensorflow as tf

from fcp.demo import plan_utils
from fcp.demo import test_utils
from fcp.protos import plan_pb2
from fcp.tensorflow import serve_slices

# Initial value of the server-side variable that holds the client checkpoint
# data; it is what gets written out when no server savepoint is restored.
DEFAULT_INITIAL_CHECKPOINT = b'initial'
# Tensor names used inside the checkpoints exchanged with the test plan.
CHECKPOINT_TENSOR_NAME = 'checkpoint'
INTERMEDIATE_TENSOR_NAME = 'intermediate_value'
FINAL_TENSOR_NAME = 'final_value'
# Number of slices served by the plan (keys 0 .. NUM_SLICES - 1).
NUM_SLICES = 3


def create_plan(log_file: Optional[str] = None) -> plan_pb2.Plan:
  """Creates a test Plan that sums inputs.

  Args:
    log_file: Optional path of a file. When set, every op created through the
      nested `log_op` helper prints its name to this file when it runs; when
      unset, those ops are plain NoOps.

  Returns:
    A `plan_pb2.Plan` with a single phase (client phase and server phase), a
    server savepoint, placeholder TFLite graph bytes, and the client graph,
    server graph and TensorFlow config proto packed into it.
  """

  def log_op(name: str) -> tf.Operation:
    """Helper function to log op invocations to a file."""
    if log_file:
      return tf.print(name, output_stream=f'file://{log_file}')
    return tf.raw_ops.NoOp()

  def create_checkpoint_op(
      name: str,
      filename_op: Any,
      save_op: Any = None,
      restore_op: Any = None,
      session_token_tensor_name: Optional[str] = None,
  ) -> plan_pb2.CheckpointOp:
    """Builds a CheckpointOp whose hooks and save/restore ops are logged.

    Args:
      name: Prefix used for the names logged by the created ops.
      filename_op: Graph element whose name becomes the saver's filename
        tensor name.
      save_op: Optional op to run before the logged `<name>/save` op; a NoOp
        is used when omitted.
      restore_op: Optional op to run before the logged `<name>/restore` op; a
        NoOp is used when omitted.
      session_token_tensor_name: Optional value for the CheckpointOp field of
        the same name.

    Returns:
      The populated `plan_pb2.CheckpointOp`.
    """
    before_restore = log_op(f'{name}/before_restore')
    after_restore = log_op(f'{name}/after_restore')
    before_save = log_op(f'{name}/before_save')
    after_save = log_op(f'{name}/after_save')
    # The SaverDef references the logging ops, which carry control
    # dependencies on the real save/restore ops, so running the logging op
    # runs the underlying op first.
    with tf.control_dependencies(
        [save_op if save_op is not None else tf.raw_ops.NoOp()]):
      save_op = log_op(f'{name}/save')
    with tf.control_dependencies(
        [restore_op if restore_op is not None else tf.raw_ops.NoOp()]):
      restore_op = log_op(f'{name}/restore')
    return plan_pb2.CheckpointOp(
        saver_def=tf.compat.v1.train.SaverDef(
            filename_tensor_name=filename_op.name,
            restore_op_name=restore_op.name,
            save_tensor_name=save_op.name,
            version=tf.compat.v1.train.SaverDef.V1,
        ),
        before_restore_op=before_restore.name,
        after_restore_op=after_restore.name,
        before_save_op=before_save.name,
        after_save_op=after_save.name,
        session_token_tensor_name=session_token_tensor_name,
    )

  # The client graph only needs to be non-empty for these tests.
  with tf.compat.v1.Graph().as_default() as client_graph:
    tf.constant(0)

  with tf.compat.v1.Graph().as_default() as server_graph:
    # Initialization:
    last_client_update = tf.Variable(0, dtype=tf.int32)
    intermediate_acc = tf.Variable(0, dtype=tf.int32)
    last_intermediate_update = tf.Variable(0, dtype=tf.int32)
    final_acc = tf.Variable(0, dtype=tf.int32)
    with tf.control_dependencies([
        last_client_update.initializer, intermediate_acc.initializer,
        last_intermediate_update.initializer, final_acc.initializer
    ]):
      phase_init_op = log_op('phase_init')

    # Ops for Federated Select:
    select_fn_initialize_op = log_op('slices/initialize')
    select_fn_server_vals = [
        tf.constant(1234),
        tf.constant('asdf'),
        tf.constant([1, 2, 3]),
    ]
    select_fn_server_val_inputs = [
        tf.compat.v1.placeholder(v.dtype) for v in select_fn_server_vals
    ]
    select_fn_key_input = tf.compat.v1.placeholder(tf.int32, shape=())
    select_fn_filename_input = tf.compat.v1.placeholder(tf.string, shape=())
    # Check that the values fed back into the select function match the
    # server values handed to serve_slices.
    assertions = [
        tf.debugging.assert_equal(placeholder, constant)
        for placeholder, constant in zip(
            select_fn_server_val_inputs, select_fn_server_vals
        )
    ]
    # Each slice's content is simply its key rendered as a string.
    with tf.control_dependencies([log_op('slices/save_slice')] + assertions):
      select_fn_save_op = tf.io.write_file(
          select_fn_filename_input, tf.strings.as_string(select_fn_key_input)
      )
    # Some tests disable passing the callback token; set `served_at_id` to '-'
    # in that case.
    callback_token = tf.compat.v1.placeholder_with_default('', shape=())
    served_at_id = tf.cond(
        tf.equal(callback_token, ''),
        lambda: '-',
        functools.partial(
            serve_slices.serve_slices,
            callback_token=callback_token,
            server_val=select_fn_server_vals,
            max_key=NUM_SLICES - 1,
            select_fn_initialize_op=select_fn_initialize_op.name,
            select_fn_server_val_input_tensor_names=[
                v.name for v in select_fn_server_val_inputs
            ],
            select_fn_key_input_tensor_name=select_fn_key_input.name,
            select_fn_filename_input_tensor_name=select_fn_filename_input.name,
            select_fn_target_tensor_name=select_fn_save_op.name,
        ),
    )

    # Ops for L2 Aggregation:
    client_checkpoint_data = tf.Variable(
        DEFAULT_INITIAL_CHECKPOINT, dtype=tf.string)

    write_client_init_filename = tf.compat.v1.placeholder(tf.string, shape=())
    # Fall back to the variable's initial value when it was never assigned
    # (i.e. when no server savepoint restore op has run).
    client_checkpoint_data_value = tf.cond(
        tf.compat.v1.is_variable_initialized(client_checkpoint_data),
        client_checkpoint_data.read_value,
        lambda: client_checkpoint_data.initial_value,
    )
    # The client checkpoint written here is "<checkpoint data> <served_at_id>".
    write_client_init_op = create_checkpoint_op(
        'write_client_init',
        write_client_init_filename,
        save_op=tf.io.write_file(
            write_client_init_filename,
            tf.strings.join(
                [client_checkpoint_data_value, served_at_id], separator=' '
            ),
        ),
        session_token_tensor_name=callback_token.name,
    )

    read_intermediate_update_filename = tf.compat.v1.placeholder(
        tf.string, shape=())
    read_intermediate_update_op = create_checkpoint_op(
        'read_intermediate_update',
        read_intermediate_update_filename,
        restore_op=last_intermediate_update.assign(
            tf.raw_ops.Restore(
                file_pattern=read_intermediate_update_filename,
                tensor_name=INTERMEDIATE_TENSOR_NAME,
                dt=tf.int32)))

    with tf.control_dependencies([log_op('apply_aggregated_updates')]):
      apply_aggregated_updates_op = final_acc.assign_add(
          last_intermediate_update)

    server_savepoint_filename = tf.compat.v1.placeholder(tf.string, shape=())
    # Saving writes the final accumulator; restoring loads the client
    # checkpoint data from the provided checkpoint.
    server_savepoint_op = create_checkpoint_op(
        'server_savepoint',
        server_savepoint_filename,
        save_op=tf.raw_ops.Save(
            filename=server_savepoint_filename,
            tensor_names=[FINAL_TENSOR_NAME],
            data=[final_acc]),
        restore_op=client_checkpoint_data.assign(
            tf.raw_ops.Restore(
                file_pattern=server_savepoint_filename,
                tensor_name=CHECKPOINT_TENSOR_NAME,
                dt=tf.string)))

  config_proto = tf.compat.v1.ConfigProto(operation_timeout_in_ms=1234)

  # NOTE: `apply_aggregrated_updates_op` (sic) is the field name used by the
  # proto; do not "fix" the spelling here.
  plan = plan_pb2.Plan(
      phase=[
          plan_pb2.Plan.Phase(
              client_phase=plan_pb2.ClientPhase(name='ClientPhase'),
              server_phase=plan_pb2.ServerPhase(
                  phase_init_op=phase_init_op.name,
                  write_client_init=write_client_init_op,
                  read_intermediate_update=read_intermediate_update_op,
                  apply_aggregrated_updates_op=(
                      apply_aggregated_updates_op.name)))
      ],
      server_savepoint=server_savepoint_op,
      client_tflite_graph_bytes=b'tflite-graph',
      version=1)
  plan.client_graph_bytes.Pack(client_graph.as_graph_def())
  plan.server_graph_bytes.Pack(server_graph.as_graph_def())
  plan.tensorflow_config_proto.Pack(config_proto)
  return plan


def create_checkpoint(tensor_name=b'test'):
  """Creates a test initial checkpoint.

  Args:
    tensor_name: Despite its name, this is the *value* stored in the
      checkpoint under the `CHECKPOINT_TENSOR_NAME` key.

  Returns:
    Whatever `test_utils.create_checkpoint` returns for that single-entry
    mapping.
  """
  return test_utils.create_checkpoint({CHECKPOINT_TENSOR_NAME: tensor_name})


class PlanUtilsTest(absltest.TestCase):
  """Tests for `plan_utils.Session` using the plan built by `create_plan`."""

  def test_session_enter_exit(self):
    # A default TensorFlow session exists only inside the `with` block.
    self.assertIsNone(tf.compat.v1.get_default_session())
    with plan_utils.Session(create_plan(), create_checkpoint()):
      self.assertIsNotNone(tf.compat.v1.get_default_session())
    self.assertIsNone(tf.compat.v1.get_default_session())

  def test_session_without_phase(self):
    plan = create_plan()
    plan.ClearField('phase')
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_without_server_phase(self):
    plan = create_plan()
    plan.phase[0].ClearField('server_phase')
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_with_multiple_phases(self):
    plan = create_plan()
    plan.phase.append(plan.phase[0])
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_client_plan(self):
    plan = create_plan()
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertEqual(
          plan_pb2.ClientOnlyPlan.FromString(session.client_plan),
          plan_pb2.ClientOnlyPlan(
              phase=plan.phase[0].client_phase,
              graph=plan.client_graph_bytes.value,
              tflite_graph=plan.client_tflite_graph_bytes,
              tensorflow_config_proto=plan.tensorflow_config_proto))

  def test_session_client_plan_without_tensorflow_config(self):
    plan = create_plan()
    plan.ClearField('tensorflow_config_proto')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertEqual(
          plan_pb2.ClientOnlyPlan.FromString(session.client_plan),
          plan_pb2.ClientOnlyPlan(
              phase=plan.phase[0].client_phase,
              graph=plan.client_graph_bytes.value,
              tflite_graph=plan.client_tflite_graph_bytes))

  def test_session_client_plan_without_tflite_graph(self):
    plan = create_plan()
    plan.ClearField('client_tflite_graph_bytes')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertEqual(
          plan_pb2.ClientOnlyPlan.FromString(session.client_plan),
          plan_pb2.ClientOnlyPlan(
              phase=plan.phase[0].client_phase,
              graph=plan.client_graph_bytes.value,
              tensorflow_config_proto=plan.tensorflow_config_proto))

  def test_session_client_checkpoint(self):
    expected = b'test-client-checkpoint'
    with plan_utils.Session(
        create_plan(),
        test_utils.create_checkpoint({CHECKPOINT_TENSOR_NAME: expected
                                     })) as session:
      # The test plan writes "<checkpoint data> <served_at_id>"; the id is
      # expected to be the (single) key of `session.slices`.
      self.assertEqual(
          session.client_checkpoint,
          expected + b' ' + next(iter(session.slices)).encode(),
      )

  def test_session_client_checkpoint_without_server_savepoint(self):
    plan = create_plan()
    # If server_savepoint isn't set, the checkpoint shouldn't be loaded.
    plan.ClearField('server_savepoint')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertStartsWith(
          session.client_checkpoint, DEFAULT_INITIAL_CHECKPOINT + b' '
      )

  def test_session_finalize(self):
    with tempfile.NamedTemporaryFile('r') as tmpfile:
      with plan_utils.Session(create_plan(tmpfile.name),
                              create_checkpoint()) as session:
        checkpoint = session.finalize(
            test_utils.create_checkpoint({INTERMEDIATE_TENSOR_NAME: 3}))
      # The log file records the order in which the plan's ops were run.
      self.assertSequenceEqual(
          tmpfile.read().splitlines(),
          [
              'server_savepoint/before_restore',
              'server_savepoint/restore',
              'server_savepoint/after_restore',
              'phase_init',
              'write_client_init/before_save',
              'write_client_init/save',
              'write_client_init/after_save',
          ]
          + ['slices/initialize', 'slices/save_slice'] * NUM_SLICES
          + [
              'read_intermediate_update/before_restore',
              'read_intermediate_update/restore',
              'read_intermediate_update/after_restore',
              'apply_aggregated_updates',
              'server_savepoint/before_save',
              'server_savepoint/save',
              'server_savepoint/after_save',
          ],
      )

    result = test_utils.read_tensor_from_checkpoint(checkpoint,
                                                    FINAL_TENSOR_NAME, tf.int32)
    # The value should be propagated from the intermediate aggregate.
    self.assertEqual(result, 3)

  def test_session_with_tensorflow_error(self):
    plan = create_plan()
    plan.phase[0].server_phase.phase_init_op = 'does-not-exist'
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_slices(self):
    with plan_utils.Session(create_plan(), create_checkpoint()) as session:
      # The served_at_id should match the value in the client checkpoint.
      served_at_id = session.client_checkpoint.split(b' ')[1].decode()
      self.assertSameElements(session.slices.keys(), [served_at_id])
      # Each slice holds its own key as a decimal string (see create_plan).
      self.assertListEqual(
          session.slices[served_at_id],
          [str(i).encode() for i in range(NUM_SLICES)],
      )

  def test_session_without_slices(self):
    plan = create_plan()
    # Without a session token tensor, no slices are expected to be served.
    plan.phase[0].server_phase.write_client_init.ClearField(
        'session_token_tensor_name'
    )
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertEmpty(session.slices)


if __name__ == '__main__':
  absltest.main()