# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for checkpoint_utils."""

import collections
import os
import typing
from typing import Any

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from google.protobuf import any_pb2
from fcp.artifact_building import checkpoint_utils
from fcp.protos import plan_pb2


class CheckpointUtilsTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the helper functions exposed by `checkpoint_utils`."""

  def _assert_variable_functionality(
      self, test_vars: list[tf.Variable], test_value_to_save: Any = 10
  ):
    """Asserts that `test_vars` is a list of assignable, readable variables.

    Must be called while the graph that owns `test_vars` is the default graph,
    because it creates an initializer op and assign ops.

    Args:
      test_vars: The variables to check; must be a `list`.
      test_value_to_save: The value assigned to each variable and then read
        back for comparison.
    """
    self.assertIsInstance(test_vars, list)
    initializer = tf.compat.v1.global_variables_initializer()
    for test_variable in test_vars:
      with self.test_session() as session:
        session.run(initializer)
        session.run(test_variable.assign(test_value_to_save))
        self.assertEqual(session.run(test_variable), test_value_to_save)

  def test_create_server_checkpoint_vars_and_savepoint_succeeds_state_vars(
      self,
  ):
    """Checks the savepoint type and that the state vars are usable."""
    with tf.Graph().as_default():
      state_vars, _, _, savepoint = (
          checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
              server_state_type=tff.to_type([('foo1', tf.int32)]),
              server_metrics_type=tff.to_type([('bar2', tf.int32)]),
              write_metrics_to_checkpoint=True,
          )
      )
      self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
      self._assert_variable_functionality(state_vars)

  def test_create_server_checkpoint_vars_and_savepoint_succeeds_metadata_vars(
      self,
  ):
    """Checks the savepoint type and that the metadata vars are usable."""

    def additional_checkpoint_metadata_var_fn(
        state_vars, metrics_vars, write_metrics_to_checkpoint
    ):
      # The arguments are irrelevant here; a fixed string variable is returned.
      del state_vars, metrics_vars, write_metrics_to_checkpoint
      return [tf.Variable(initial_value=b'dog', name='metadata')]

    with tf.Graph().as_default():
      _, _, metadata_vars, savepoint = (
          checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
              server_state_type=tff.to_type([('foo3', tf.int32)]),
              server_metrics_type=tff.to_type([('bar1', tf.int32)]),
              additional_checkpoint_metadata_var_fn=(
                  additional_checkpoint_metadata_var_fn
              ),
              write_metrics_to_checkpoint=True,
          )
      )
      self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
      # The metadata variable holds bytes, so a bytes value is assigned.
      self._assert_variable_functionality(
          metadata_vars, test_value_to_save=b'cat'
      )

  def test_create_server_checkpoint_vars_and_savepoint_succeeds_metrics_vars(
      self,
  ):
    """Checks the savepoint type and that the metrics vars are usable."""
    with tf.Graph().as_default():
      _, metrics_vars, _, savepoint = (
          checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
              server_state_type=tff.to_type([('foo2', tf.int32)]),
              server_metrics_type=tff.to_type([('bar3', tf.int32)]),
              write_metrics_to_checkpoint=True,
          )
      )
      self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
      self._assert_variable_functionality(metrics_vars)

  def test_tff_type_to_dtype_list_as_expected(self):
    """A server-placed struct of int32 and string yields those two dtypes."""
    tff_type = tff.FederatedType(
        tff.StructType([('foo', tf.int32), ('bar', tf.string)]), tff.SERVER
    )
    expected_dtype_list = [tf.int32, tf.string]
    self.assertEqual(
        checkpoint_utils.tff_type_to_dtype_list(tff_type), expected_dtype_list
    )

  def test_tff_type_to_dtype_list_type_error(self):
    """Passing a plain Python list instead of a TFF type raises TypeError."""
    list_type = [tf.int32, tf.string]
    with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
      checkpoint_utils.tff_type_to_dtype_list(list_type)

  def test_tff_type_to_tensor_spec_list_as_expected(self):
    """A server-placed struct yields one `tf.TensorSpec` per element."""
    tff_type = tff.FederatedType(
        tff.StructType(
            [('foo', tf.int32), ('bar', tff.TensorType(tf.string, shape=[1]))]
        ),
        tff.SERVER,
    )
    expected_tensor_spec_list = [
        tf.TensorSpec([], tf.int32),
        tf.TensorSpec([1], tf.string),
    ]
    self.assertEqual(
        checkpoint_utils.tff_type_to_tensor_spec_list(tff_type),
        expected_tensor_spec_list,
    )

  def test_tff_type_to_tensor_spec_list_type_error(self):
    """Passing a plain Python list instead of a TFF type raises TypeError."""
    list_type = [tf.int32, tf.string]
    with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
      checkpoint_utils.tff_type_to_tensor_spec_list(list_type)

  def test_pack_tff_value_with_tensors_as_expected(self):
    """A flat list of two tensors is packed into a two-field named struct."""
    tff_type = tff.StructType([('foo', tf.int32), ('bar', tf.string)])
    value_list = [
        tf.constant(1, dtype=tf.int32),
        tf.constant('bla', dtype=tf.string),
    ]
    expected_packed_structure = tff.structure.Struct([
        ('foo', tf.constant(1, dtype=tf.int32)),
        ('bar', tf.constant('bla', dtype=tf.string)),
    ])
    self.assertEqual(
        checkpoint_utils.pack_tff_value(tff_type, value_list),
        expected_packed_structure,
    )

  def test_pack_tff_value_with_federated_server_tensors_as_expected(self):
    """A flat list is packed into a nested struct of server-placed types."""
    # This test must create a type that has `StructType`s nested under the
    # `FederatedType` to cover testing that tff.structure.pack_sequence_as
    # package correctly descends through the entire type tree.
    tff_type = tff.to_type(
        collections.OrderedDict(
            foo=tff.FederatedType(tf.int32, tff.SERVER),
            # Some arbitrarily deep nesting to ensure full traversals.
            bar=tff.FederatedType([(), ([tf.int32], tf.int32)], tff.SERVER),
        )
    )
    value_list = [tf.constant(1), tf.constant(2), tf.constant(3)]
    expected_packed_structure = tff.structure.from_container(
        collections.OrderedDict(
            foo=tf.constant(1), bar=[(), ([tf.constant(2)], tf.constant(3))]
        ),
        recursive=True,
    )
    self.assertEqual(
        checkpoint_utils.pack_tff_value(tff_type, value_list),
        expected_packed_structure,
    )

  def test_pack_tff_value_with_unmatched_input_sizes(self):
    """One value for a two-field struct type raises ValueError."""
    tff_type = tff.StructType([('foo', tf.int32), ('bar', tf.string)])
    value_list = [tf.constant(1, dtype=tf.int32)]
    with self.assertRaises(ValueError):
      checkpoint_utils.pack_tff_value(tff_type, value_list)

  def test_pack_tff_value_with_tff_type_error(self):
    """A federated computation's type signature is rejected with TypeError."""

    @tff.federated_computation
    def fed_comp():
      return tff.federated_value(0, tff.SERVER)

    tff_function_type = fed_comp.type_signature
    value_list = [tf.constant(1, dtype=tf.int32)]
    with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
      checkpoint_utils.pack_tff_value(tff_function_type, value_list)

  def test_variable_names_from_structure_with_tensor_and_no_name(self):
    """A lone tensor with no name argument yields the single name 'v'."""
    names = checkpoint_utils.variable_names_from_structure(tf.constant(1.0))
    self.assertEqual(names, ['v'])

  def test_variable_names_from_structure_with_tensor(self):
    """A lone tensor with a name argument yields exactly that name."""
    names = checkpoint_utils.variable_names_from_structure(
        tf.constant(1.0), 'test_name'
    )
    self.assertEqual(names, ['test_name'])

  def test_variable_names_from_structure_with_named_tuple_type_and_no_name(
      self,
  ):
    """Nested named fields with no name argument are joined under 'v'."""
    names = checkpoint_utils.variable_names_from_structure(
        tff.structure.Struct([
            ('a', tf.constant(1.0)),
            (
                'b',
                tff.structure.Struct(
                    [('c', tf.constant(True)), ('d', tf.constant(0.0))]
                ),
            ),
        ])
    )
    self.assertEqual(names, ['v/a', 'v/b/c', 'v/b/d'])

  def test_variable_names_from_structure_with_named_struct(self):
    """Nested named fields are joined with '/' under the given name."""
    names = checkpoint_utils.variable_names_from_structure(
        tff.structure.Struct([
            ('a', tf.constant(1.0)),
            (
                'b',
                tff.structure.Struct(
                    [('c', tf.constant(True)), ('d', tf.constant(0.0))]
                ),
            ),
        ]),
        'test_name',
    )
    self.assertEqual(names, ['test_name/a', 'test_name/b/c', 'test_name/b/d'])

  def test_variable_names_from_structure_with_named_tuple_type_no_name_field(
      self,
  ):
    """Unnamed (`None`) fields appear as their position index in the name."""
    names = checkpoint_utils.variable_names_from_structure(
        tff.structure.Struct([
            (None, tf.constant(1.0)),
            (
                'b',
                tff.structure.Struct(
                    [(None, tf.constant(False)), ('d', tf.constant(0.0))]
                ),
            ),
        ]),
        'test_name',
    )
    self.assertEqual(names, ['test_name/0', 'test_name/b/0', 'test_name/b/d'])

  def test_save_tf_tensor_to_checkpoint_as_expected(self):
    """A single tensor is written to a checkpoint under the name 'v'."""
    temp_dir = self.create_tempdir()
    output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')

    tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])

    checkpoint_utils.save_tff_structure_to_checkpoint(
        tensor, ['v'], output_checkpoint_path=output_checkpoint_path
    )

    # Read the checkpoint back and verify both the name and the contents.
    reader = tf.compat.v1.train.NewCheckpointReader(output_checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    self.assertLen(var_to_shape_map, 1)
    self.assertIn('v', var_to_shape_map)
    np.testing.assert_almost_equal(
        [[1.0, 2.0], [3.0, 4.0]], reader.get_tensor('v')
    )

  def test_save_tff_struct_to_checkpoint_as_expected(self):
    """Each struct field is written under its corresponding variable name."""
    temp_dir = self.create_tempdir()
    output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')

    struct = tff.structure.Struct([
        ('foo', tf.constant(1, dtype=tf.int32)),
        ('bar', tf.constant('bla', dtype=tf.string)),
    ])

    checkpoint_utils.save_tff_structure_to_checkpoint(
        struct,
        ordered_var_names=['v/foo', 'v/bar'],
        output_checkpoint_path=output_checkpoint_path,
    )

    # Read the checkpoint back and verify the names and the contents.
    reader = tf.compat.v1.train.NewCheckpointReader(output_checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    self.assertLen(var_to_shape_map, 2)
    self.assertIn('v/foo', var_to_shape_map)
    self.assertIn('v/bar', var_to_shape_map)
    self.assertEqual(1, reader.get_tensor('v/foo'))
    self.assertEqual(b'bla', reader.get_tensor('v/bar'))

  def test_save_tff_struct_to_checkpoint_fails_if_wrong_num_var_names(self):
    """One variable name for a two-field struct raises ValueError."""
    temp_dir = self.create_tempdir()
    output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')

    struct = tff.structure.Struct([
        ('foo', tf.constant(1, dtype=tf.int32)),
        ('bar', tf.constant('bla', dtype=tf.string)),
    ])

    with self.assertRaisesRegex(ValueError, 'does not match the number'):
      checkpoint_utils.save_tff_structure_to_checkpoint(
          struct,
          ordered_var_names=['v/foo'],
          output_checkpoint_path=output_checkpoint_path,
      )

  @parameterized.named_parameters(
      ('tf.tensor', tf.constant(1.0)),
      ('ndarray', np.asarray([1.0, 2.0, 3.0])),
      ('npnumber', np.float64(1.0)),
      ('int', 1),
      ('float', 1.0),
      ('str', 'test'),
      ('bytes', b'test'),
  )
  def test_is_allowed(self, structure):
    """Tensors, numpy values and Python scalars/strings are accepted."""
    self.assertTrue(checkpoint_utils.is_structure_of_allowed_types(structure))

  @parameterized.named_parameters(
      ('function', lambda x: x),
      ('any_proto', any_pb2.Any()),
  )
  def test_is_not_allowed(self, structure):
    """A Python function and a protobuf message are rejected."""
    self.assertFalse(checkpoint_utils.is_structure_of_allowed_types(structure))


class CreateDeterministicSaverTest(tf.test.TestCase):
  """Tests for `checkpoint_utils.create_deterministic_saver`."""

  def test_failure_unknown_type(self):
    """An argument that is not a supported collection raises ValueError."""
    with self.assertRaisesRegex(ValueError, 'Do not know how to make'):
      # Using a cast in case the test is being run with static type checking.
      checkpoint_utils.create_deterministic_saver(
          typing.cast(list[tf.Variable], 0)
      )

  def test_creates_saver_for_list(self):
    """A list of variables is saved under the variables' own names."""
    with tf.Graph().as_default() as g:
      saver = checkpoint_utils.create_deterministic_saver([
          tf.Variable(initial_value=1.0, name='z'),
          tf.Variable(initial_value=2.0, name='x'),
          tf.Variable(initial_value=3.0, name='y'),
      ])
    self.assertIsInstance(saver, tf.compat.v1.train.Saver)
    test_filepath = self.create_tempfile().full_path
    with tf.compat.v1.Session(graph=g) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      saver.save(sess, save_path=test_filepath)
    variable_specs = tf.train.list_variables(test_filepath)
    self.assertEqual([('x', []), ('y', []), ('z', [])], variable_specs)

  def test_creates_saver_for_dict(self):
    """A dict of variables is saved under the dict keys, not variable names."""
    with tf.Graph().as_default() as g:
      saver = checkpoint_utils.create_deterministic_saver({
          'foo': tf.Variable(initial_value=1.0, name='z'),
          'baz': tf.Variable(initial_value=2.0, name='x'),
          'bar': tf.Variable(initial_value=3.0, name='y'),
      })
    self.assertIsInstance(saver, tf.compat.v1.train.Saver)
    test_filepath = self.create_tempfile().full_path
    with tf.compat.v1.Session(graph=g) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      saver.save(sess, save_path=test_filepath)
    variable_specs = tf.train.list_variables(test_filepath)
    self.assertEqual([('bar', []), ('baz', []), ('foo', [])], variable_specs)


if __name__ == '__main__':
  absltest.main()