# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tensorflow.python.checkpoint.checkpoint_management."""

import contextlib
import os
import pathlib
import shutil
import tempfile

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState


class LatestCheckpointWithRelativePaths(test.TestCase):
  """Tests saving and locating checkpoints through relative paths."""

  @staticmethod
  @contextlib.contextmanager
  def tempWorkingDir(temppath):
    """Context manager that chdir()s into `temppath` and back on exit.

    Args:
      temppath: Directory to make the current working directory while the
        `with` body runs.

    Yields:
      Nothing; the previous working directory is restored on exit, even if
      the body raises.
    """
    cwd = os.getcwd()
    os.chdir(temppath)
    try:
      yield
    finally:
      os.chdir(cwd)

  @staticmethod
  @contextlib.contextmanager
  def tempDir():
    """Context manager that creates a temp directory and removes it on exit.

    Yields:
      Path of the freshly created temporary directory.
    """
    tempdir = tempfile.mkdtemp()
    try:
      yield tempdir
    finally:
      shutil.rmtree(tempdir)

  @test_util.run_deprecated_v1
  def testNameCollision(self):
    # Saving to a prefix literally named "checkpoint" must be rejected only
    # when the resulting file name would equal the checkpoint state file.
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:
      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train"
        os.mkdir(traindir)
        # Collides with the default name of the checkpoint state file.
        filepath = os.path.join(traindir, "checkpoint")

        with self.cached_session() as sess:
          unused_a = variables.Variable(0.0)  # So that Saver saves something.
          self.evaluate(variables.global_variables_initializer())

          # Should fail.
          saver = saver_module.Saver(sharded=False)
          with self.assertRaisesRegex(ValueError, "collides with"):
            saver.save(sess, filepath)

          # Succeeds: the file will be named "checkpoint-<step>".
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

  @test_util.run_deprecated_v1
  def testRelativePath(self):
    # Saves three steps under a relative directory, then restores the latest
    # one into a freshly built variable.
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:

      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):

        # Save training snapshots to a relative path.
        traindir = "train"
        os.mkdir(traindir)

        filename = "snapshot"
        filepath = os.path.join(traindir, filename)

        with self.cached_session() as sess:
          # Build a simple graph.
          v0 = variables.Variable(0.0)
          inc = v0.assign_add(1.0)

          save = saver_module.Saver({"v0": v0})

          # Record a short training history.
          self.evaluate(variables.global_variables_initializer())
          save.save(sess, filepath, global_step=0)
          self.evaluate(inc)
          save.save(sess, filepath, global_step=1)
          self.evaluate(inc)
          save.save(sess, filepath, global_step=2)

        with self.cached_session() as sess:
          # Build a new graph with different initialization.
          v0 = variables.Variable(-1.0)

          # Create a new saver.
          save = saver_module.Saver({"v0": v0})
          self.evaluate(variables.global_variables_initializer())

          # Get the most recent checkpoint name from the training history file.
          name = checkpoint_management.latest_checkpoint(traindir)
          self.assertIsNotNone(name)

          # Restore "v0" from that checkpoint.
          save.restore(sess, name)
          self.assertEqual(v0.eval(), 2.0)


class CheckpointStateTest(test.TestCase):
  """Tests generating, writing and reading the CheckpointState proto."""

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns `dirname` under the test temp dir."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def _chdir_for_test(self, path):
    """Changes into `path` and restores the old cwd when the test ends.

    Without the cleanup the directory change would leak into every test that
    runs afterwards in the same process.

    Args:
      path: Directory to make the current working directory.
    """
    self.addCleanup(os.chdir, os.getcwd())
    os.chdir(path)

  def testAbsPath(self):
    # An absolute model path outside a relative context is stored unchanged.
    save_dir = self._get_test_dir("abs_paths")
    abs_path = os.path.join(save_dir, "model-0")
    ckpt = checkpoint_management.generate_checkpoint_state_proto(
        save_dir, abs_path)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path)
    self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testRelPath(self):
    # A relative model path is rewritten relative to the save directory.
    train_dir = "train"
    model = os.path.join(train_dir, "model-0")
    # model_checkpoint_path should have no "train" directory part.
    new_rel_path = "model-0"
    ckpt = checkpoint_management.generate_checkpoint_state_proto(
        train_dir, model)
    self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)

  def testAllModelCheckpointPaths(self):
    # The latest model path must always be the last entry of
    # all_model_checkpoint_paths, whatever list the caller passes in.
    save_dir = self._get_test_dir("all_models_test")
    abs_path = os.path.join(save_dir, "model-0")
    for paths in [None, [], ["model-2"]]:
      ckpt = checkpoint_management.generate_checkpoint_state_proto(
          save_dir, abs_path, all_model_checkpoint_paths=paths)
      self.assertEqual(ckpt.model_checkpoint_path, abs_path)
      self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
      self.assertEqual(
          len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
      self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testUpdateCheckpointState(self):
    # Round-trips a state file containing one absolute and one relative path.
    save_dir = self._get_test_dir("update_checkpoint_state")
    self._chdir_for_test(save_dir)
    # Make a temporary train directory.
    train_dir = "train"
    os.mkdir(train_dir)
    abs_path = os.path.join(save_dir, "model-0")
    rel_path = os.path.join("train", "model-2")
    checkpoint_management.update_checkpoint_state(
        train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
    ckpt = checkpoint_management.get_checkpoint_state(train_dir)
    self.assertEqual(ckpt.model_checkpoint_path, rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)

  def testFSPath(self):
    # get_checkpoint_state must accept a pathlib.Path directory.
    save_dir = self._get_test_dir("fspath")
    self._chdir_for_test(save_dir)
    # Make a temporary train directory.
    train_dir = "train"
    os.mkdir(train_dir)
    abs_path = os.path.join(save_dir, "model-0")
    rel_path = os.path.join("train", "model-2")
    checkpoint_management.update_checkpoint_state(
        train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
    ckpt = checkpoint_management.get_checkpoint_state(pathlib.Path(train_dir))
    self.assertEqual(ckpt.model_checkpoint_path, rel_path)

  def testUpdateCheckpointStateSaveRelativePaths(self):
    # With save_relative_paths=True the file holds relative paths, while
    # get_checkpoint_state hands back absolute ones.
    save_dir = self._get_test_dir("update_checkpoint_state")
    self._chdir_for_test(save_dir)
    abs_path2 = os.path.join(save_dir, "model-2")
    rel_path2 = "model-2"
    abs_path0 = os.path.join(save_dir, "model-0")
    rel_path0 = "model-0"
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=save_dir,
        model_checkpoint_path=abs_path2,
        all_model_checkpoint_paths=[rel_path0, abs_path2],
        save_relative_paths=True)

    # File should contain relative paths.
    file_content = file_io.read_file_to_string(
        os.path.join(save_dir, "checkpoint"))
    ckpt = CheckpointState()
    text_format.Merge(file_content, ckpt)
    self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)

    # get_checkpoint_state should return absolute paths.
    ckpt = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)

  def testCheckPointStateFailsWhenIncomplete(self):
    # An empty "checkpoint" state file must be reported as a ValueError.
    save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
    self._chdir_for_test(save_dir)
    ckpt_path = os.path.join(save_dir, "checkpoint")
    # `with` guarantees the handle is closed even if the write raises.
    with open(ckpt_path, "w") as ckpt_file:
      ckpt_file.write("")
    with self.assertRaises(ValueError):
      checkpoint_management.get_checkpoint_state(save_dir)

  def testCheckPointCompletesRelativePaths(self):
    # Relative entries in a hand-written state file are joined with the
    # directory the file was read from.
    save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
    self._chdir_for_test(save_dir)
    ckpt_path = os.path.join(save_dir, "checkpoint")
    # `with` guarantees the handle is closed even if the write raises.
    with open(ckpt_path, "w") as ckpt_file:
      ckpt_file.write("""
        model_checkpoint_path: "./model.ckpt-687529"
        all_model_checkpoint_paths: "./model.ckpt-687500"
        all_model_checkpoint_paths: "./model.ckpt-687529"
        """)
    ckpt = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(ckpt.model_checkpoint_path,
                     os.path.join(save_dir, "./model.ckpt-687529"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[0],
                     os.path.join(save_dir, "./model.ckpt-687500"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[1],
                     os.path.join(save_dir, "./model.ckpt-687529"))


class SaverUtilsTest(test.TestCase):
  """Tests checkpoint_exists, get_checkpoint_mtimes and remove_checkpoint."""

  def setUp(self):
    # Every test writes its checkpoints below this directory.
    self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
    gfile.MakeDirs(self._base_dir)

  def tearDown(self):
    gfile.DeleteRecursively(self._base_dir)

  @test_util.run_deprecated_v1
  def testCheckpointExists(self):
    # Covers every combination of sharded/unsharded and V1/V2 formats.
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          self.evaluate(variables.global_variables_initializer())
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          self.assertFalse(
              checkpoint_management.checkpoint_exists(path))  # Not saved yet.

          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

          ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

  @test_util.run_deprecated_v1
  def testGetCheckpointMtimes(self):
    # Saves one V2 and then one V1 checkpoint and compares their mtimes.
    prefixes = []
    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
      with self.session(graph=ops_lib.Graph()) as sess:
        unused_v = variables.Variable(1.0, name="v")
        self.evaluate(variables.global_variables_initializer())
        saver = saver_module.Saver(write_version=version)
        prefixes.append(
            saver.save(sess, os.path.join(self._base_dir, str(version))))

    mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes)
    self.assertEqual(2, len(mtimes))
    # The second checkpoint was written after the first one.
    self.assertGreaterEqual(mtimes[1], mtimes[0])

  @test_util.run_deprecated_v1
  def testRemoveCheckpoint(self):
    # Covers every combination of sharded/unsharded and V1/V2 formats.
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          self.evaluate(variables.global_variables_initializer())
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
          checkpoint_management.remove_checkpoint(ckpt_prefix, version)
          self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))


class CheckpointManagerTest(test.TestCase):
  """Tests CheckpointManager retention, numbering, restore and intervals."""

  @test_util.run_in_graph_and_eager_modes
  def testDeletion(self):
    # With max_to_keep=3 the fourth save removes the oldest checkpoint.
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, self.get_temp_dir(), max_to_keep=3)
    first_path = manager.save()
    second_path = manager.save()
    third_path = manager.save()
    fourth_path = manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))

  @test_util.run_in_graph_and_eager_modes
  def testKeepAll(self):
    # max_to_keep=None keeps everything; a later manager with max_to_keep=3
    # only prunes once it saves.
    checkpoint = util.Checkpoint()
    directory = os.path.join(
        self.get_temp_dir(),
        # Avoid sharing directories between eager and graph
        # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
        str(context.executing_eagerly()))
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=None)
    first_path = manager.save()
    second_path = manager.save()
    third_path = manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    self.assertEqual(third_path, manager.latest_checkpoint)
    self.assertEqual([first_path, second_path, third_path],
                     manager.checkpoints)
    del manager
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=None)
    fourth_path = manager.save()
    self.assertEqual([first_path, second_path, third_path, fourth_path],
                     manager.checkpoints)
    del manager
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=3)
    self.assertEqual([first_path, second_path, third_path, fourth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    fifth_path = manager.save()
    self.assertEqual([third_path, fourth_path, fifth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testSaveRestoreState(self, mock_time):
    # Drives a mocked clock through three successive managers and checks the
    # kept checkpoints and last_preserved_timestamp after each save.
    directory = self.get_temp_dir()
    mock_time.time.return_value = 3.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    first_time = 10000.
    first_name = os.path.join(directory, "ckpt-1")
    mock_time.time.return_value = first_time
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    second_time = first_time + 3610.
    second_name = os.path.join(directory, "ckpt-2")
    mock_time.time.return_value = second_time
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual([first_time, second_time],
                     state.all_model_checkpoint_timestamps)
    self.assertEqual([first_name, second_name], first_manager.checkpoints)
    self.assertEqual(second_name, first_manager.latest_checkpoint)
    del first_manager

    second_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual([first_name, second_name], second_manager.checkpoints)
    self.assertEqual(second_name, second_manager.latest_checkpoint)
    third_name = os.path.join(directory, "ckpt-3")
    third_time = second_time + 3600. * 0.2
    mock_time.time.return_value = third_time
    second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([second_name, third_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    fourth_time = third_time + 3600. * 0.5
    mock_time.time.return_value = fourth_time
    fourth_name = os.path.join(directory, "ckpt-4")
    second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([third_name, fourth_name],
                     second_manager.checkpoints)
    fifth_time = fourth_time + 3600. * 0.5
    mock_time.time.return_value = fifth_time
    fifth_name = os.path.join(directory, "ckpt-5")
    second_manager.save()
    self.assertEqual([fourth_name, fifth_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    del second_manager
    third_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual(fifth_name, third_manager.latest_checkpoint)
    mock_time.time.return_value += 10.
    third_manager.save()
    sixth_name = os.path.join(directory, "ckpt-6")
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(fourth_time, state.last_preserved_timestamp)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
    self.assertEqual([fifth_name, sixth_name],
                     third_manager.checkpoints)

  @test_util.run_in_graph_and_eager_modes
  def testContinueFromUnmanaged(self):
    # Checkpoints written directly by Checkpoint.save under another prefix
    # are restorable through the manager but are not deleted by it.
    directory = self.get_temp_dir()
    prefix = os.path.join(directory, "unusual_prefix")
    checkpoint = util.Checkpoint()
    first_path = checkpoint.save(prefix)
    second_path = checkpoint.save(prefix)
    del checkpoint
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
    self.assertEqual(2, self.evaluate(checkpoint.save_counter))
    third_path = manager.save()
    self.assertEqual([third_path], manager.checkpoints)
    fourth_path = manager.save()
    self.assertEqual([third_path, fourth_path],
                     manager.checkpoints)
    fifth_path = manager.save()
    self.assertEqual([fourth_path, fifth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testClockReset(self, mock_time):
    # Moves the mocked clock backwards between two managers and checks that a
    # warning is logged and existing checkpoints are left in place.
    directory = self.get_temp_dir()
    mock_time.time.return_value = 10000.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
    first_path = first_manager.save()
    mock_time.time.return_value += 3600.
    second_path = first_manager.save()
    mock_time.time.return_value += 3600.
    third_path = first_manager.save()
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual([third_path], first_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(13600., state.last_preserved_timestamp)
    # Set the clock back in time
    mock_time.time.return_value = 5000.
    del first_manager
    with test.mock.patch.object(logging, "warning") as mock_log:
      second_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory, max_to_keep=1)
      self.assertRegex(
          str(mock_log.call_args),
          "behind the last preserved checkpoint timestamp")
    # We should err on the side of keeping checkpoints around when we're not
    # sure whether they were preserved or not due to clock funkiness.
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    # We know about the existing checkpoints, but they'll never be deleted and
    # so won't go in the CheckpointState proto on save.
    self.assertEqual(third_path, second_manager.latest_checkpoint)
    self.assertEqual([], second_manager.checkpoints)
    mock_time.time.return_value += 10.
    fourth_path = second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual(fourth_path, second_manager.latest_checkpoint)
    self.assertEqual([fourth_path], second_manager.checkpoints)
    mock_time.time.return_value += 10.
    fifth_path = second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual([fifth_path], second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(5000., state.last_preserved_timestamp)
    self.assertEqual([5020.],
                     state.all_model_checkpoint_timestamps)

  @test_util.run_in_graph_and_eager_modes
  def testCustomNumbering(self):
    # checkpoint_number may be a variable or a plain integer and determines
    # the suffix of the saved path.
    directory = self.get_temp_dir()
    step = variables.Variable(0, dtype=dtypes.int64)
    checkpoint = util.Checkpoint(step=step)
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    self.evaluate(step.initializer)
    for i in range(5):
      path = manager.save(checkpoint_number=step)
      expected_suffix = "-%d" % (2 * i,)
      if not path.endswith(expected_suffix):
        self.fail("%s should have suffix %s" % (path, expected_suffix))
      self.evaluate(step.assign_add(2))
    self.assertEqual(5, self.evaluate(checkpoint.save_counter))
    # Test regular integers
    last_path = manager.save(checkpoint_number=32)
    self.assertIn("-32", last_path)
    self.assertEqual(last_path, manager.latest_checkpoint)
    self.assertEqual(
        last_path, checkpoint_management.latest_checkpoint(directory))
    state = checkpoint_management.get_checkpoint_state(directory)
    # Only the most recent two checkpoints are saved
    self.assertEqual([path, last_path], state.all_model_checkpoint_paths)

  @test_util.run_in_graph_and_eager_modes
  def testCustomCheckpointPrefix(self):
    # checkpoint_name sets the file prefix; "ckpt" is used when it is omitted.
    directory = self.get_temp_dir()
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
    path = manager.save(checkpoint_number=5)
    self.assertEqual(os.path.basename(path), "ckpt_name-5")
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    path = manager.save(checkpoint_number=5)
    self.assertEqual(os.path.basename(path), "ckpt-5")

  @test_util.run_in_graph_and_eager_modes
  def testRestoreOrInitialize(self):
    # restore_or_initialize runs init_fn when the managed directory is empty
    # and restores from a managed checkpoint once one exists.
    directory = self.get_temp_dir()

    # Create a checkpoint for initializing.
    init_prefix = os.path.join(directory, "init")
    init_v = variables.Variable(2.0)
    init_ckpt = util.Checkpoint(v=init_v)
    self.evaluate(init_v.initializer)
    init_path = init_ckpt.save(init_prefix)

    # Create the checkpoint manager.
    ckpt_dir = os.path.join(directory, "ckpt")
    v = variables.Variable(1.0)
    checkpoint = util.Checkpoint(v=v)
    manager = checkpoint_management.CheckpointManager(
        checkpoint,
        ckpt_dir,
        max_to_keep=None,
        init_fn=lambda: checkpoint.restore(init_path).run_restore_ops())
    self.evaluate(v.initializer)

    # First call should call `init_fn`.
    self.assertIsNone(manager.restore_or_initialize())
    self.assertEqual(2.0, self.evaluate(v))

    # Save a checkpoint and second call should restore from the checkpoints.
    manager.save()
    self.assertIsNotNone(manager.restore_or_initialize())

  @test_util.run_in_graph_and_eager_modes
  def testCheckpointManagerFSpathDirectory(self):
    # CheckpointManager must accept a pathlib.Path directory and still
    # return plain string paths.
    directory = pathlib.Path(self.get_temp_dir())
    v = variables.Variable(0.0)
    checkpoint = util.Checkpoint(v=v)
    self.evaluate(v.initializer)
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
    save_path = manager.save()
    expected = str(directory / "ckpt_name-1")
    self.assertEqual(expected, save_path)

    restore_path = manager.restore_or_initialize()
    self.assertEqual(str(directory / "ckpt_name-1"), restore_path)

  @test_util.run_in_graph_and_eager_modes
  def testLatestCheckpointFSpathDirectory(self):
    # latest_checkpoint must accept a pathlib.Path directory.
    directory = pathlib.Path(self.get_temp_dir())
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
    manager.save()

    cp_dir = checkpoint_management.latest_checkpoint(directory)
    self.assertEqual(str(directory / "ckpt_name-1"), cp_dir)

  @test_util.run_in_graph_and_eager_modes
  def testCheckpointInterval(self):
    # With checkpoint_interval=2, save(check_interval=True) returns None
    # unless the step counter has advanced far enough since the last save.
    v = variables.Variable(1.0)
    step_counter = variables.Variable(0)
    self.evaluate([v.initializer, step_counter.initializer])
    checkpoint = util.Checkpoint(v=v)
    manager = checkpoint_management.CheckpointManager(
        checkpoint,
        self.get_temp_dir(),
        max_to_keep=None,
        step_counter=step_counter,
        checkpoint_interval=2)

    # step_counter: 0, save an initial checkpoint.
    path = manager.save(check_interval=True)
    self.assertTrue(checkpoint_management.checkpoint_exists(path))

    # step_counter: 1, no checkpoint saved.
    self.evaluate(step_counter.assign_add(1))
    path = manager.save(check_interval=True)
    self.assertIsNone(path)

    # step_counter: 2, checkpoint saved.
    self.evaluate(step_counter.assign_add(1))
    path = manager.save(check_interval=True)
    self.assertTrue(checkpoint_management.checkpoint_exists(path))

    # no checkpoint saved when calling `save` with the same step counter.
    path = manager.save(check_interval=True)
    self.assertIsNone(path)

    # step_counter: 3, no checkpoint saved.
    self.evaluate(step_counter.assign_add(1))
    path = manager.save(check_interval=True)
    self.assertIsNone(path)

    # Always save the checkpoint.
    path = manager.save(check_interval=False)
    self.assertTrue(checkpoint_management.checkpoint_exists(path))

  @test_util.run_in_graph_and_eager_modes
  def testCheckpointIntervalWithRestore(self):
    # After restoring, a save at the restored step is skipped by the
    # interval check.
    directory = self.get_temp_dir()
    v = variables.Variable(1.0)
    step_counter = variables.Variable(0)
    self.evaluate([v.initializer, step_counter.initializer])

    # Prepare a checkpoint.
    checkpoint = util.Checkpoint(v=v)
    checkpoint.save(os.path.join(directory, "ckpt"))

    manager = checkpoint_management.CheckpointManager(
        checkpoint,
        directory,
        max_to_keep=None,
        step_counter=step_counter,
        checkpoint_interval=2)

    # Restore from the checkpoint.
    self.assertIsNotNone(manager.restore_or_initialize())

    # step_counter: 0, no checkpoint saved because it is restored from the
    # checkpoint with the same step.
    path = manager.save()
    self.assertIsNone(path)


if __name__ == "__main__":
  test.main()