# xref: /aosp_15_r20/external/federated-compute/fcp/demo/plan_utils_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for plan_utils."""

import functools
import tempfile
from typing import Any, Optional

from absl.testing import absltest
import tensorflow as tf

from fcp.demo import plan_utils
from fcp.demo import test_utils
from fcp.protos import plan_pb2
from fcp.tensorflow import serve_slices

# Initial value of the server graph's client-checkpoint variable; it is what
# gets written to the client checkpoint when no other value was restored.
DEFAULT_INITIAL_CHECKPOINT = b'initial'
# Tensor names used as keys in the checkpoints exchanged by the tests.
CHECKPOINT_TENSOR_NAME = 'checkpoint'
INTERMEDIATE_TENSOR_NAME = 'intermediate_value'
FINAL_TENSOR_NAME = 'final_value'
# Number of slices served by the Federated Select ops (keys 0..NUM_SLICES-1).
NUM_SLICES = 3


def create_plan(log_file: Optional[str] = None) -> plan_pb2.Plan:
  """Creates a test Plan that sums inputs.

  The server graph accumulates the restored intermediate update into
  `final_acc`, exposes checkpoint ops for writing the client init checkpoint,
  reading the intermediate update and saving/restoring the server savepoint,
  and contains the ops needed to serve `NUM_SLICES` Federated Select slices.

  Args:
    log_file: Optional path of a file. When set, every logging op created by
      this function prints its name to that file when it runs, so tests can
      check the order in which ops were invoked. When unset, the logging ops
      are plain NoOps.

  Returns:
    A Plan with a single phase (client and server phase), a server savepoint
    CheckpointOp, placeholder TFLite bytes, version 1, and the client graph,
    server graph and TensorFlow ConfigProto packed into it.
  """

  def log_op(name: str) -> tf.Operation:
    """Helper function to log op invocations to a file."""
    if log_file:
      return tf.print(name, output_stream=f'file://{log_file}')
    return tf.raw_ops.NoOp()

  def create_checkpoint_op(
      name: str,
      filename_op: Any,
      save_op: Any = None,
      restore_op: Any = None,
      session_token_tensor_name: Optional[str] = None,
  ) -> plan_pb2.CheckpointOp:
    """Builds a CheckpointOp whose every stage is backed by a logging op.

    Args:
      name: Prefix for the names logged by the stage ops
        (`<name>/before_restore`, `<name>/save`, ...).
      filename_op: Tensor whose name becomes the SaverDef filename tensor.
      save_op: Optional op performing the actual save; a NoOp when None.
      restore_op: Optional op performing the actual restore; a NoOp when
        None.
      session_token_tensor_name: Optional value for the CheckpointOp field of
        the same name.

    Returns:
      The assembled CheckpointOp (SaverDef version V1).
    """
    before_restore = log_op(f'{name}/before_restore')
    after_restore = log_op(f'{name}/after_restore')
    before_save = log_op(f'{name}/before_save')
    after_save = log_op(f'{name}/after_save')
    # The logging op is created under a control dependency on the real
    # save/restore op, so running the op named in the SaverDef first runs the
    # real op and then records the invocation.
    with tf.control_dependencies(
        [save_op if save_op is not None else tf.raw_ops.NoOp()]):
      save_op = log_op(f'{name}/save')
    with tf.control_dependencies(
        [restore_op if restore_op is not None else tf.raw_ops.NoOp()]):
      restore_op = log_op(f'{name}/restore')
    return plan_pb2.CheckpointOp(
        saver_def=tf.compat.v1.train.SaverDef(
            filename_tensor_name=filename_op.name,
            restore_op_name=restore_op.name,
            save_tensor_name=save_op.name,
            version=tf.compat.v1.train.SaverDef.V1,
        ),
        before_restore_op=before_restore.name,
        after_restore_op=after_restore.name,
        before_save_op=before_save.name,
        after_save_op=after_save.name,
        session_token_tensor_name=session_token_tensor_name,
    )

  # The client graph only needs to be non-empty; a single constant suffices.
  with tf.compat.v1.Graph().as_default() as client_graph:
    tf.constant(0)

  with tf.compat.v1.Graph().as_default() as server_graph:
    # Initialization:
    # NOTE: `last_client_update` and `intermediate_acc` are only initialized
    # here; no other op in this function reads or writes them.
    last_client_update = tf.Variable(0, dtype=tf.int32)
    intermediate_acc = tf.Variable(0, dtype=tf.int32)
    last_intermediate_update = tf.Variable(0, dtype=tf.int32)
    final_acc = tf.Variable(0, dtype=tf.int32)
    with tf.control_dependencies([
        last_client_update.initializer, intermediate_acc.initializer,
        last_intermediate_update.initializer, final_acc.initializer
    ]):
      phase_init_op = log_op('phase_init')

    # Ops for Federated Select:
    select_fn_initialize_op = log_op('slices/initialize')
    select_fn_server_vals = [
        tf.constant(1234),
        tf.constant('asdf'),
        tf.constant([1, 2, 3]),
    ]
    select_fn_server_val_inputs = [
        tf.compat.v1.placeholder(v.dtype) for v in select_fn_server_vals
    ]
    select_fn_key_input = tf.compat.v1.placeholder(tf.int32, shape=())
    select_fn_filename_input = tf.compat.v1.placeholder(tf.string, shape=())
    # Each value fed into the select function must equal the corresponding
    # server value; otherwise saving a slice fails.
    assertions = [
        tf.debugging.assert_equal(placeholder, constant)
        for placeholder, constant in zip(
            select_fn_server_val_inputs, select_fn_server_vals
        )
    ]
    # A slice's content is simply its key rendered as a string.
    with tf.control_dependencies([log_op('slices/save_slice')] + assertions):
      select_fn_save_op = tf.io.write_file(
          select_fn_filename_input, tf.strings.as_string(select_fn_key_input)
      )
    # Some tests disable passing the callback token; set `served_at_id` to '-'
    # in that case.
    callback_token = tf.compat.v1.placeholder_with_default('', shape=())
    served_at_id = tf.cond(
        tf.equal(callback_token, ''),
        lambda: '-',
        functools.partial(
            serve_slices.serve_slices,
            callback_token=callback_token,
            server_val=select_fn_server_vals,
            max_key=NUM_SLICES - 1,
            select_fn_initialize_op=select_fn_initialize_op.name,
            select_fn_server_val_input_tensor_names=[
                v.name for v in select_fn_server_val_inputs
            ],
            select_fn_key_input_tensor_name=select_fn_key_input.name,
            select_fn_filename_input_tensor_name=select_fn_filename_input.name,
            select_fn_target_tensor_name=select_fn_save_op.name,
        ),
    )

    # Ops for L2 Aggregation:
    client_checkpoint_data = tf.Variable(
        DEFAULT_INITIAL_CHECKPOINT, dtype=tf.string)

    write_client_init_filename = tf.compat.v1.placeholder(tf.string, shape=())
    # Fall back to the variable's initial value when it was never initialized
    # (presumably the case when no server savepoint has been restored).
    client_checkpoint_data_value = tf.cond(
        tf.compat.v1.is_variable_initialized(client_checkpoint_data),
        client_checkpoint_data.read_value,
        lambda: client_checkpoint_data.initial_value,
    )
    # The client checkpoint file contains '<checkpoint data> <served_at_id>'.
    write_client_init_op = create_checkpoint_op(
        'write_client_init',
        write_client_init_filename,
        save_op=tf.io.write_file(
            write_client_init_filename,
            tf.strings.join(
                [client_checkpoint_data_value, served_at_id], separator=' '
            ),
        ),
        session_token_tensor_name=callback_token.name,
    )

    read_intermediate_update_filename = tf.compat.v1.placeholder(
        tf.string, shape=())
    read_intermediate_update_op = create_checkpoint_op(
        'read_intermediate_update',
        read_intermediate_update_filename,
        restore_op=last_intermediate_update.assign(
            tf.raw_ops.Restore(
                file_pattern=read_intermediate_update_filename,
                tensor_name=INTERMEDIATE_TENSOR_NAME,
                dt=tf.int32)))

    with tf.control_dependencies([log_op('apply_aggregated_updates')]):
      apply_aggregated_updates_op = final_acc.assign_add(
          last_intermediate_update)

    # The savepoint saves `final_acc` and restores the client checkpoint data.
    server_savepoint_filename = tf.compat.v1.placeholder(tf.string, shape=())
    server_savepoint_op = create_checkpoint_op(
        'server_savepoint',
        server_savepoint_filename,
        save_op=tf.raw_ops.Save(
            filename=server_savepoint_filename,
            tensor_names=[FINAL_TENSOR_NAME],
            data=[final_acc]),
        restore_op=client_checkpoint_data.assign(
            tf.raw_ops.Restore(
                file_pattern=server_savepoint_filename,
                tensor_name=CHECKPOINT_TENSOR_NAME,
                dt=tf.string)))

  config_proto = tf.compat.v1.ConfigProto(operation_timeout_in_ms=1234)

  plan = plan_pb2.Plan(
      phase=[
          plan_pb2.Plan.Phase(
              client_phase=plan_pb2.ClientPhase(name='ClientPhase'),
              server_phase=plan_pb2.ServerPhase(
                  phase_init_op=phase_init_op.name,
                  write_client_init=write_client_init_op,
                  read_intermediate_update=read_intermediate_update_op,
                  # The keyword spelling ("aggregrated") is reproduced as-is;
                  # it must match the proto field name.
                  apply_aggregrated_updates_op=(
                      apply_aggregated_updates_op.name)))
      ],
      server_savepoint=server_savepoint_op,
      client_tflite_graph_bytes=b'tflite-graph',
      version=1)
  plan.client_graph_bytes.Pack(client_graph.as_graph_def())
  plan.server_graph_bytes.Pack(server_graph.as_graph_def())
  plan.tensorflow_config_proto.Pack(config_proto)
  return plan


def create_checkpoint(tensor_name: bytes = b'test'):
  """Creates a test initial checkpoint.

  Args:
    tensor_name: Despite its name, this is the *value* stored in the
      checkpoint; the key is always `CHECKPOINT_TENSOR_NAME`.

  Returns:
    The result of `test_utils.create_checkpoint` for that single-entry
    mapping (presumably the serialized checkpoint; confirm in test_utils).
  """
  return test_utils.create_checkpoint({CHECKPOINT_TENSOR_NAME: tensor_name})


class PlanUtilsTest(absltest.TestCase):
  """Tests for `plan_utils.Session` using the Plan built by `create_plan`."""

  def test_session_enter_exit(self):
    """A default TF session exists only inside the `with` block."""
    self.assertIsNone(tf.compat.v1.get_default_session())
    with plan_utils.Session(create_plan(), create_checkpoint()):
      self.assertIsNotNone(tf.compat.v1.get_default_session())
    self.assertIsNone(tf.compat.v1.get_default_session())

  def test_session_without_phase(self):
    """Constructing a Session from a Plan with no phase raises ValueError."""
    plan = create_plan()
    plan.ClearField('phase')
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_without_server_phase(self):
    """A phase lacking its server_phase is rejected with ValueError."""
    plan = create_plan()
    plan.phase[0].ClearField('server_phase')
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_with_multiple_phases(self):
    """A Plan with more than one phase is rejected with ValueError."""
    plan = create_plan()
    plan.phase.append(plan.phase[0])
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_client_plan(self):
    """`client_plan` serializes a ClientOnlyPlan built from the Plan."""
    plan = create_plan()
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertEqual(
          plan_pb2.ClientOnlyPlan.FromString(session.client_plan),
          plan_pb2.ClientOnlyPlan(
              phase=plan.phase[0].client_phase,
              graph=plan.client_graph_bytes.value,
              tflite_graph=plan.client_tflite_graph_bytes,
              tensorflow_config_proto=plan.tensorflow_config_proto))

  def test_session_client_plan_without_tensorflow_config(self):
    """The ClientOnlyPlan omits the config proto when the Plan has none."""
    plan = create_plan()
    plan.ClearField('tensorflow_config_proto')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertEqual(
          plan_pb2.ClientOnlyPlan.FromString(session.client_plan),
          plan_pb2.ClientOnlyPlan(
              phase=plan.phase[0].client_phase,
              graph=plan.client_graph_bytes.value,
              tflite_graph=plan.client_tflite_graph_bytes))

  def test_session_client_plan_without_tflite_graph(self):
    """The ClientOnlyPlan omits the TFLite graph when the Plan has none."""
    plan = create_plan()
    plan.ClearField('client_tflite_graph_bytes')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertEqual(
          plan_pb2.ClientOnlyPlan.FromString(session.client_plan),
          plan_pb2.ClientOnlyPlan(
              phase=plan.phase[0].client_phase,
              graph=plan.client_graph_bytes.value,
              tensorflow_config_proto=plan.tensorflow_config_proto))

  def test_session_client_checkpoint(self):
    """The client checkpoint is '<restored value> <served_at_id>'."""
    expected = b'test-client-checkpoint'
    with plan_utils.Session(
        create_plan(),
        test_utils.create_checkpoint({CHECKPOINT_TENSOR_NAME: expected
                                     })) as session:
      # The only key in `session.slices` is the served_at_id that the graph
      # appended to the checkpoint data.
      self.assertEqual(
          session.client_checkpoint,
          expected + b' ' + next(iter(session.slices)).encode(),
      )

  def test_session_client_checkpoint_without_server_savepoint(self):
    """Without a server savepoint the default checkpoint value is used."""
    plan = create_plan()
    # If server_savepoint isn't set, the checkpoint shouldn't be loaded.
    plan.ClearField('server_savepoint')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertStartsWith(
          session.client_checkpoint, DEFAULT_INITIAL_CHECKPOINT + b' '
      )

  def test_session_finalize(self):
    """Ops run in the expected order and the final value is the update."""
    with tempfile.NamedTemporaryFile('r') as tmpfile:
      with plan_utils.Session(create_plan(tmpfile.name),
                              create_checkpoint()) as session:
        checkpoint = session.finalize(
            test_utils.create_checkpoint({INTERMEDIATE_TENSOR_NAME: 3}))
      # Every logging op wrote its name to `tmpfile` when it ran, so the file
      # content is the op invocation order.
      self.assertSequenceEqual(
          tmpfile.read().splitlines(),
          [
              'server_savepoint/before_restore',
              'server_savepoint/restore',
              'server_savepoint/after_restore',
              'phase_init',
              'write_client_init/before_save',
              'write_client_init/save',
              'write_client_init/after_save',
          ]
          + ['slices/initialize', 'slices/save_slice'] * NUM_SLICES
          + [
              'read_intermediate_update/before_restore',
              'read_intermediate_update/restore',
              'read_intermediate_update/after_restore',
              'apply_aggregated_updates',
              'server_savepoint/before_save',
              'server_savepoint/save',
              'server_savepoint/after_save',
          ],
      )

    result = test_utils.read_tensor_from_checkpoint(checkpoint,
                                                    FINAL_TENSOR_NAME, tf.int32)
    # The value should be propagated from the intermediate aggregate.
    self.assertEqual(result, 3)

  def test_session_with_tensorflow_error(self):
    """A Plan naming a nonexistent op surfaces as ValueError."""
    plan = create_plan()
    plan.phase[0].server_phase.phase_init_op = 'does-not-exist'
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_slices(self):
    """Slices are keyed by served_at_id and contain each key as a string."""
    with plan_utils.Session(create_plan(), create_checkpoint()) as session:
      # The served_at_id should match the value in the client checkpoint.
      served_at_id = session.client_checkpoint.split(b' ')[1].decode()
      self.assertSameElements(session.slices.keys(), [served_at_id])
      self.assertListEqual(
          session.slices[served_at_id],
          [str(i).encode() for i in range(NUM_SLICES)],
      )

  def test_session_without_slices(self):
    """No slices are served when the session token tensor name is unset."""
    plan = create_plan()
    plan.phase[0].server_phase.write_client_init.ClearField(
        'session_token_tensor_name'
    )
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertEmpty(session.slices)


# Run the absltest test runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()