xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/checkpoint_utils_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for checkpoint_utils."""
15
16import collections
17import os
18import typing
19from typing import Any
20
21from absl.testing import absltest
22from absl.testing import parameterized
23
24import numpy as np
25import tensorflow as tf
26import tensorflow_federated as tff
27
28from google.protobuf import any_pb2
29from fcp.artifact_building import checkpoint_utils
30from fcp.protos import plan_pb2
31
32
class CheckpointUtilsTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the helpers exposed by `checkpoint_utils`."""

  def _assert_variable_functionality(
      self, test_vars: list[tf.Variable], test_value_to_save: Any = 10
  ):
    """Checks that each variable in `test_vars` can be initialized and set.

    Every caller in this file invokes this inside the `tf.Graph().as_default()`
    block that created `test_vars`, so the session runs against that graph.

    Args:
      test_vars: Variables to exercise; must be a Python `list`.
      test_value_to_save: Value assigned to each variable and then read back.
    """
    self.assertIsInstance(test_vars, list)
    initializer = tf.compat.v1.global_variables_initializer()
    # `cached_session()` is the documented replacement for the deprecated
    # `test_session()`. One session suffices: each variable is assigned and
    # read back independently of the others.
    with self.cached_session() as session:
      session.run(initializer)
      for test_variable in test_vars:
        session.run(test_variable.assign(test_value_to_save))
        self.assertEqual(session.run(test_variable), test_value_to_save)

  # `create_server_checkpoint_vars_and_savepoint` returns a 4-tuple of
  # (state vars, metrics vars, metadata vars, savepoint); each of the next
  # three tests exercises one of the variable groups.

  def test_create_server_checkpoint_vars_and_savepoint_succeeds_state_vars(
      self,
  ):
    with tf.Graph().as_default():
      state_vars, _, _, savepoint = (
          checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
              server_state_type=tff.to_type([('foo1', tf.int32)]),
              server_metrics_type=tff.to_type([('bar2', tf.int32)]),
              write_metrics_to_checkpoint=True,
          )
      )
      self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
      self._assert_variable_functionality(state_vars)

  def test_create_server_checkpoint_vars_and_savepoint_succeeds_metadata_vars(
      self,
  ):
    def additional_checkpoint_metadata_var_fn(
        state_vars, metrics_vars, write_metrics_to_checkpoint
    ):
      # The callback's inputs are irrelevant here; it always yields a single
      # string-valued metadata variable.
      del state_vars, metrics_vars, write_metrics_to_checkpoint
      return [tf.Variable(initial_value=b'dog', name='metadata')]

    with tf.Graph().as_default():
      _, _, metadata_vars, savepoint = (
          checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
              server_state_type=tff.to_type([('foo3', tf.int32)]),
              server_metrics_type=tff.to_type([('bar1', tf.int32)]),
              additional_checkpoint_metadata_var_fn=(
                  additional_checkpoint_metadata_var_fn
              ),
              write_metrics_to_checkpoint=True,
          )
      )
      self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
      self._assert_variable_functionality(
          metadata_vars, test_value_to_save=b'cat'
      )

  def test_create_server_checkpoint_vars_and_savepoint_succeeds_metrics_vars(
      self,
  ):
    with tf.Graph().as_default():
      _, metrics_vars, _, savepoint = (
          checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
              server_state_type=tff.to_type([('foo2', tf.int32)]),
              server_metrics_type=tff.to_type([('bar3', tf.int32)]),
              write_metrics_to_checkpoint=True,
          )
      )
      self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
      self._assert_variable_functionality(metrics_vars)

  def test_tff_type_to_dtype_list_as_expected(self):
    tff_type = tff.FederatedType(
        tff.StructType([('foo', tf.int32), ('bar', tf.string)]), tff.SERVER
    )
    expected_dtype_list = [tf.int32, tf.string]
    self.assertEqual(
        checkpoint_utils.tff_type_to_dtype_list(tff_type), expected_dtype_list
    )

  def test_tff_type_to_dtype_list_type_error(self):
    # A plain Python list of dtypes is not a TFF type and must be rejected.
    list_type = [tf.int32, tf.string]
    with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
      checkpoint_utils.tff_type_to_dtype_list(list_type)

  def test_tff_type_to_tensor_spec_list_as_expected(self):
    tff_type = tff.FederatedType(
        tff.StructType(
            [('foo', tf.int32), ('bar', tff.TensorType(tf.string, shape=[1]))]
        ),
        tff.SERVER,
    )
    expected_tensor_spec_list = [
        tf.TensorSpec([], tf.int32),
        tf.TensorSpec([1], tf.string),
    ]
    self.assertEqual(
        checkpoint_utils.tff_type_to_tensor_spec_list(tff_type),
        expected_tensor_spec_list,
    )

  def test_tff_type_to_tensor_spec_list_type_error(self):
    # A plain Python list of dtypes is not a TFF type and must be rejected.
    list_type = [tf.int32, tf.string]
    with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
      checkpoint_utils.tff_type_to_tensor_spec_list(list_type)

  def test_pack_tff_value_with_tensors_as_expected(self):
    tff_type = tff.StructType([('foo', tf.int32), ('bar', tf.string)])
    value_list = [
        tf.constant(1, dtype=tf.int32),
        tf.constant('bla', dtype=tf.string),
    ]
    expected_packed_structure = tff.structure.Struct([
        ('foo', tf.constant(1, dtype=tf.int32)),
        ('bar', tf.constant('bla', dtype=tf.string)),
    ])
    self.assertEqual(
        checkpoint_utils.pack_tff_value(tff_type, value_list),
        expected_packed_structure,
    )

  def test_pack_tff_value_with_federated_server_tensors_as_expected(self):
    # This test must create a type that has `StructType`s nested under the
    # `FederatedType` to cover testing that tff.structure.pack_sequence_as
    # package correctly descends through the entire type tree.
    tff_type = tff.to_type(
        collections.OrderedDict(
            foo=tff.FederatedType(tf.int32, tff.SERVER),
            # Some arbitrarily deep nesting to ensure full traversals.
            bar=tff.FederatedType([(), ([tf.int32], tf.int32)], tff.SERVER),
        )
    )
    # One flat value per leaf of `tff_type`, in traversal order.
    value_list = [tf.constant(1), tf.constant(2), tf.constant(3)]
    expected_packed_structure = tff.structure.from_container(
        collections.OrderedDict(
            foo=tf.constant(1), bar=[(), ([tf.constant(2)], tf.constant(3))]
        ),
        recursive=True,
    )
    self.assertEqual(
        checkpoint_utils.pack_tff_value(tff_type, value_list),
        expected_packed_structure,
    )

  def test_pack_tff_value_with_unmatched_input_sizes(self):
    # Two leaves in the type but only one value supplied.
    tff_type = tff.StructType([('foo', tf.int32), ('bar', tf.string)])
    value_list = [tf.constant(1, dtype=tf.int32)]
    with self.assertRaises(ValueError):
      checkpoint_utils.pack_tff_value(tff_type, value_list)

  def test_pack_tff_value_with_tff_type_error(self):
    @tff.federated_computation
    def fed_comp():
      return tff.federated_value(0, tff.SERVER)

    # A function type signature is a TFF type, but not one that can be packed.
    tff_function_type = fed_comp.type_signature
    value_list = [tf.constant(1, dtype=tf.int32)]
    with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
      checkpoint_utils.pack_tff_value(tff_function_type, value_list)

  def test_variable_names_from_structure_with_tensor_and_no_name(self):
    names = checkpoint_utils.variable_names_from_structure(tf.constant(1.0))
    self.assertEqual(names, ['v'])

  def test_variable_names_from_structure_with_tensor(self):
    names = checkpoint_utils.variable_names_from_structure(
        tf.constant(1.0), 'test_name'
    )
    self.assertEqual(names, ['test_name'])

  def test_variable_names_from_structure_with_named_tuple_type_and_no_name(
      self,
  ):
    names = checkpoint_utils.variable_names_from_structure(
        tff.structure.Struct([
            ('a', tf.constant(1.0)),
            (
                'b',
                tff.structure.Struct(
                    [('c', tf.constant(True)), ('d', tf.constant(0.0))]
                ),
            ),
        ])
    )
    self.assertEqual(names, ['v/a', 'v/b/c', 'v/b/d'])

  def test_variable_names_from_structure_with_named_struct(self):
    names = checkpoint_utils.variable_names_from_structure(
        tff.structure.Struct([
            ('a', tf.constant(1.0)),
            (
                'b',
                tff.structure.Struct(
                    [('c', tf.constant(True)), ('d', tf.constant(0.0))]
                ),
            ),
        ]),
        'test_name',
    )
    self.assertEqual(names, ['test_name/a', 'test_name/b/c', 'test_name/b/d'])

  def test_variable_names_from_structure_with_named_tuple_type_no_name_field(
      self,
  ):
    # Unnamed fields are expected to be named by their positional index.
    names = checkpoint_utils.variable_names_from_structure(
        tff.structure.Struct([
            (None, tf.constant(1.0)),
            (
                'b',
                tff.structure.Struct(
                    [(None, tf.constant(False)), ('d', tf.constant(0.0))]
                ),
            ),
        ]),
        'test_name',
    )
    self.assertEqual(names, ['test_name/0', 'test_name/b/0', 'test_name/b/d'])

  def test_save_tf_tensor_to_checkpoint_as_expected(self):
    temp_dir = self.create_tempdir()
    output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')

    tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])

    checkpoint_utils.save_tff_structure_to_checkpoint(
        tensor, ['v'], output_checkpoint_path=output_checkpoint_path
    )

    reader = tf.compat.v1.train.NewCheckpointReader(output_checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    self.assertLen(var_to_shape_map, 1)
    self.assertIn('v', var_to_shape_map)
    # numpy's signature is (actual, desired); keep that order so failure
    # messages label the values correctly.
    np.testing.assert_almost_equal(
        reader.get_tensor('v'), [[1.0, 2.0], [3.0, 4.0]]
    )

  def test_save_tff_struct_to_checkpoint_as_expected(self):
    temp_dir = self.create_tempdir()
    output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')

    struct = tff.structure.Struct([
        ('foo', tf.constant(1, dtype=tf.int32)),
        ('bar', tf.constant('bla', dtype=tf.string)),
    ])

    checkpoint_utils.save_tff_structure_to_checkpoint(
        struct,
        ordered_var_names=['v/foo', 'v/bar'],
        output_checkpoint_path=output_checkpoint_path,
    )

    reader = tf.compat.v1.train.NewCheckpointReader(output_checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    self.assertLen(var_to_shape_map, 2)
    self.assertIn('v/foo', var_to_shape_map)
    self.assertIn('v/bar', var_to_shape_map)
    self.assertEqual(1, reader.get_tensor('v/foo'))
    # String tensors are read back from the checkpoint as bytes.
    self.assertEqual(b'bla', reader.get_tensor('v/bar'))

  def test_save_tff_struct_to_checkpoint_fails_if_wrong_num_var_names(self):
    temp_dir = self.create_tempdir()
    output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')

    struct = tff.structure.Struct([
        ('foo', tf.constant(1, dtype=tf.int32)),
        ('bar', tf.constant('bla', dtype=tf.string)),
    ])

    # Two leaves in `struct` but only one name supplied.
    with self.assertRaisesRegex(ValueError, 'does not match the number'):
      checkpoint_utils.save_tff_structure_to_checkpoint(
          struct,
          ordered_var_names=['v/foo'],
          output_checkpoint_path=output_checkpoint_path,
      )

  @parameterized.named_parameters(
      ('tf.tensor', tf.constant(1.0)),
      ('ndarray', np.asarray([1.0, 2.0, 3.0])),
      ('npnumber', np.float64(1.0)),
      ('int', 1),
      ('float', 1.0),
      ('str', 'test'),
      ('bytes', b'test'),
  )
  def test_is_allowed(self, structure):
    self.assertTrue(checkpoint_utils.is_structure_of_allowed_types(structure))

  @parameterized.named_parameters(
      ('function', lambda x: x),
      ('any_proto', any_pb2.Any()),
  )
  def test_is_not_allowed(self, structure):
    self.assertFalse(checkpoint_utils.is_structure_of_allowed_types(structure))
321
322
class CreateDeterministicSaverTest(tf.test.TestCase):
  """Tests for `checkpoint_utils.create_deterministic_saver`."""

  def _save_and_list_variables(self, graph, saver):
    """Initializes `graph`'s variables, saves them with `saver`, lists them.

    Args:
      graph: The graph that owns the variables tracked by `saver`.
      saver: The saver under test.

    Returns:
      The `(name, shape)` pairs from `tf.train.list_variables` for the
      checkpoint that was just written.
    """
    checkpoint_path = self.create_tempfile().full_path
    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      saver.save(sess, save_path=checkpoint_path)
    return tf.train.list_variables(checkpoint_path)

  def test_failure_unknown_type(self):
    # The cast only placates static type checkers; at runtime the argument is
    # still the plain int 0, which the saver factory cannot handle.
    unsupported_input = typing.cast(list[tf.Variable], 0)
    with self.assertRaisesRegex(ValueError, 'Do not know how to make'):
      checkpoint_utils.create_deterministic_saver(unsupported_input)

  def test_creates_saver_for_list(self):
    with tf.Graph().as_default() as graph:
      variables = [
          tf.Variable(initial_value=1.0, name='z'),
          tf.Variable(initial_value=2.0, name='x'),
          tf.Variable(initial_value=3.0, name='y'),
      ]
      saver = checkpoint_utils.create_deterministic_saver(variables)
    self.assertIsInstance(saver, tf.compat.v1.train.Saver)
    # Checkpoint entries use the variable names and are expected in name
    # order regardless of the order they were created in.
    self.assertEqual(
        [('x', []), ('y', []), ('z', [])],
        self._save_and_list_variables(graph, saver),
    )

  def test_creates_saver_for_dict(self):
    with tf.Graph().as_default() as graph:
      variables_by_key = {
          'foo': tf.Variable(initial_value=1.0, name='z'),
          'baz': tf.Variable(initial_value=2.0, name='x'),
          'bar': tf.Variable(initial_value=3.0, name='y'),
      }
      saver = checkpoint_utils.create_deterministic_saver(variables_by_key)
    self.assertIsInstance(saver, tf.compat.v1.train.Saver)
    # For a dict, checkpoint entries are named after the keys rather than the
    # variables' own names, and are expected in key order.
    self.assertEqual(
        [('bar', []), ('baz', []), ('foo', [])],
        self._save_and_list_variables(graph, saver),
    )
361
362
if __name__ == '__main__':
  # Run this module's tests through absl's test runner.
  absltest.main()
365