xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/checkpoint_management_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tensorflow.python.checkpoint.checkpoint_management."""
16
17import contextlib
18import os
19import pathlib
20import shutil
21import tempfile
22
23from google.protobuf import text_format
24
25from tensorflow.core.protobuf import saver_pb2
26from tensorflow.python.checkpoint import checkpoint as util
27from tensorflow.python.checkpoint import checkpoint_management
28from tensorflow.python.eager import context
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops as ops_lib
31from tensorflow.python.framework import test_util
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import gfile
35from tensorflow.python.platform import test
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.training import saver as saver_module
38from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
39
40
class LatestCheckpointWithRelativePaths(test.TestCase):
  """Checks `latest_checkpoint` for checkpoints saved under relative paths."""

  @staticmethod
  @contextlib.contextmanager
  def tempWorkingDir(temppath):
    """Temporarily makes `temppath` the current working directory.

    The previous working directory is restored on exit, even if the body of
    the `with` statement raises.
    """
    cwd = os.getcwd()
    os.chdir(temppath)
    try:
      yield
    finally:
      os.chdir(cwd)

  @staticmethod
  @contextlib.contextmanager
  def tempDir():
    """Yields the path of a fresh temporary directory, deleted on exit."""
    with tempfile.TemporaryDirectory() as tempdir:
      yield tempdir

  @test_util.run_deprecated_v1
  def testNameCollision(self):
    """Saving to `<dir>/checkpoint` collides with the checkpoint state file."""
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:
      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train"
        os.mkdir(traindir)
        # Collides with the default name of the checkpoint state file.
        filepath = os.path.join(traindir, "checkpoint")

        with self.cached_session() as sess:
          unused_a = variables.Variable(0.0)  # So that Saver saves something.
          self.evaluate(variables.global_variables_initializer())

          # Should fail.
          saver = saver_module.Saver(sharded=False)
          with self.assertRaisesRegex(ValueError, "collides with"):
            saver.save(sess, filepath)

          # Succeeds: the file will be named "checkpoint-<step>".
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

  @test_util.run_deprecated_v1
  def testRelativePath(self):
    """A checkpoint saved to a relative path can be found and restored."""
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:

      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):

        # Save training snapshots to a relative path.
        traindir = "train"
        os.mkdir(traindir)

        filename = "snapshot"
        filepath = os.path.join(traindir, filename)

        with self.cached_session() as sess:
          # Build a simple graph.
          v0 = variables.Variable(0.0)
          inc = v0.assign_add(1.0)

          save = saver_module.Saver({"v0": v0})

          # Record a short training history.
          self.evaluate(variables.global_variables_initializer())
          save.save(sess, filepath, global_step=0)
          self.evaluate(inc)
          save.save(sess, filepath, global_step=1)
          self.evaluate(inc)
          save.save(sess, filepath, global_step=2)

        with self.cached_session() as sess:
          # Build a new graph with different initialization.
          v0 = variables.Variable(-1.0)

          # Create a new saver.
          save = saver_module.Saver({"v0": v0})
          self.evaluate(variables.global_variables_initializer())

          # Get the most recent checkpoint name from the training history file.
          name = checkpoint_management.latest_checkpoint(traindir)
          self.assertIsNotNone(name)

          # Restore "v0" from that checkpoint; the last save was at value 2.0.
          save.restore(sess, name)
          self.assertEqual(self.evaluate(v0), 2.0)
145
146
class CheckpointStateTest(test.TestCase):
  """Tests for generating, writing and reading `CheckpointState` protos."""

  def _get_test_dir(self, dirname):
    """Creates `dirname` under this test's temp dir and returns its path."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def _enter_dir(self, path):
    """Makes `path` the working directory until the current test finishes.

    The previous working directory is restored through `addCleanup`, so a
    test that changes directory cannot leak that change into later tests.
    """
    self.addCleanup(os.chdir, os.getcwd())
    os.chdir(path)

  def _write_state_file(self, save_dir, content):
    """Writes `content` verbatim to the `checkpoint` file in `save_dir`."""
    with open(os.path.join(save_dir, "checkpoint"), "w") as ckpt_file:
      ckpt_file.write(content)

  def testAbsPath(self):
    save_dir = self._get_test_dir("abs_paths")
    abs_path = os.path.join(save_dir, "model-0")
    ckpt = checkpoint_management.generate_checkpoint_state_proto(
        save_dir, abs_path)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path)
    self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testRelPath(self):
    train_dir = "train"
    model = os.path.join(train_dir, "model-0")
    # model_checkpoint_path should have no "train" directory part.
    new_rel_path = "model-0"
    ckpt = checkpoint_management.generate_checkpoint_state_proto(
        train_dir, model)
    self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)

  def testAllModelCheckpointPaths(self):
    save_dir = self._get_test_dir("all_models_test")
    abs_path = os.path.join(save_dir, "model-0")
    # Expected lengths are spelled out rather than derived from `paths` after
    # the call, so the check does not depend on whether the list argument is
    # modified in place.
    for paths, expected_len in [(None, 1), ([], 1), (["model-2"], 2)]:
      ckpt = checkpoint_management.generate_checkpoint_state_proto(
          save_dir, abs_path, all_model_checkpoint_paths=paths)
      self.assertEqual(ckpt.model_checkpoint_path, abs_path)
      self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
      self.assertEqual(len(ckpt.all_model_checkpoint_paths), expected_len)
      self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testUpdateCheckpointState(self):
    save_dir = self._get_test_dir("update_checkpoint_state")
    self._enter_dir(save_dir)
    # Make a temporary train directory.
    train_dir = "train"
    os.mkdir(train_dir)
    abs_path = os.path.join(save_dir, "model-0")
    rel_path = os.path.join("train", "model-2")
    checkpoint_management.update_checkpoint_state(
        train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
    ckpt = checkpoint_management.get_checkpoint_state(train_dir)
    self.assertEqual(ckpt.model_checkpoint_path, rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)

  def testFSPath(self):
    save_dir = self._get_test_dir("fspath")
    self._enter_dir(save_dir)
    # Make a temporary train directory.
    train_dir = "train"
    os.mkdir(train_dir)
    abs_path = os.path.join(save_dir, "model-0")
    rel_path = os.path.join("train", "model-2")
    checkpoint_management.update_checkpoint_state(
        train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
    # get_checkpoint_state must also accept a `pathlib.Path` directory.
    ckpt = checkpoint_management.get_checkpoint_state(pathlib.Path(train_dir))
    self.assertEqual(ckpt.model_checkpoint_path, rel_path)

  def testUpdateCheckpointStateSaveRelativePaths(self):
    save_dir = self._get_test_dir("update_checkpoint_state")
    self._enter_dir(save_dir)
    abs_path2 = os.path.join(save_dir, "model-2")
    rel_path2 = "model-2"
    abs_path0 = os.path.join(save_dir, "model-0")
    rel_path0 = "model-0"
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=save_dir,
        model_checkpoint_path=abs_path2,
        all_model_checkpoint_paths=[rel_path0, abs_path2],
        save_relative_paths=True)

    # File should contain relative paths.
    file_content = file_io.read_file_to_string(
        os.path.join(save_dir, "checkpoint"))
    ckpt = CheckpointState()
    text_format.Merge(file_content, ckpt)
    self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)

    # get_checkpoint_state should return absolute paths.
    ckpt = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)

  def testCheckPointStateFailsWhenIncomplete(self):
    save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
    self._enter_dir(save_dir)
    # An empty state file lacks model_checkpoint_path and must be rejected.
    self._write_state_file(save_dir, "")
    with self.assertRaises(ValueError):
      checkpoint_management.get_checkpoint_state(save_dir)

  def testCheckPointCompletesRelativePaths(self):
    save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
    self._enter_dir(save_dir)
    self._write_state_file(save_dir, """
        model_checkpoint_path: "./model.ckpt-687529"
        all_model_checkpoint_paths: "./model.ckpt-687500"
        all_model_checkpoint_paths: "./model.ckpt-687529"
        """)
    # Relative entries are resolved against the directory holding the file.
    ckpt = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(ckpt.model_checkpoint_path,
                     os.path.join(save_dir, "./model.ckpt-687529"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[0],
                     os.path.join(save_dir, "./model.ckpt-687500"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[1],
                     os.path.join(save_dir, "./model.ckpt-687529"))
274
275
class SaverUtilsTest(test.TestCase):
  """Tests for the checkpoint file helpers in `checkpoint_management`."""

  def setUp(self):
    # Every test saves into the same base directory, removed in tearDown.
    self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
    gfile.MakeDirs(self._base_dir)

  def tearDown(self):
    gfile.DeleteRecursively(self._base_dir)

  @test_util.run_deprecated_v1
  def testCheckpointExists(self):
    """`checkpoint_exists` is False before a save and True for its prefix.

    Covers every combination of sharding and SaverDef format version.
    """
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          self.evaluate(variables.global_variables_initializer())
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          self.assertFalse(
              checkpoint_management.checkpoint_exists(path))  # Not saved yet.

          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

          ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

  @test_util.run_deprecated_v1
  def testGetCheckpointMtimes(self):
    """One mtime is returned per prefix; the later save is not older."""
    prefixes = []
    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
      with self.session(graph=ops_lib.Graph()) as sess:
        unused_v = variables.Variable(1.0, name="v")
        self.evaluate(variables.global_variables_initializer())
        saver = saver_module.Saver(write_version=version)
        prefixes.append(
            saver.save(sess, os.path.join(self._base_dir, str(version))))

    mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes)
    self.assertEqual(2, len(mtimes))
    # assertGreaterEqual reports both values on failure, unlike assertTrue.
    self.assertGreaterEqual(mtimes[1], mtimes[0])

  @test_util.run_deprecated_v1
  def testRemoveCheckpoint(self):
    """`remove_checkpoint` deletes a saved checkpoint of any layout."""
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          self.evaluate(variables.global_variables_initializer())
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
          checkpoint_management.remove_checkpoint(ckpt_prefix, version)
          self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
333
334
class CheckpointManagerTest(test.TestCase):
  """Tests for `checkpoint_management.CheckpointManager`."""

  @test_util.run_in_graph_and_eager_modes
  def testDeletion(self):
    """Only the `max_to_keep` most recent checkpoints survive."""
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, self.get_temp_dir(), max_to_keep=3)
    first_path = manager.save()
    second_path = manager.save()
    third_path = manager.save()
    fourth_path = manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))

  @test_util.run_in_graph_and_eager_modes
  def testKeepAll(self):
    """`max_to_keep=None` keeps everything until a later manager sets a cap."""
    checkpoint = util.Checkpoint()
    directory = os.path.join(
        self.get_temp_dir(),
        # Avoid sharing directories between eager and graph
        # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
        str(context.executing_eagerly()))
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=None)
    first_path = manager.save()
    second_path = manager.save()
    third_path = manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    self.assertEqual(third_path, manager.latest_checkpoint)
    self.assertEqual([first_path, second_path, third_path],
                     manager.checkpoints)
    del manager
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=None)
    fourth_path = manager.save()
    self.assertEqual([first_path, second_path, third_path, fourth_path],
                     manager.checkpoints)
    del manager
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=3)
    # Nothing is deleted when the manager is constructed, only on save.
    self.assertEqual([first_path, second_path, third_path, fourth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    fifth_path = manager.save()
    self.assertEqual([third_path, fourth_path, fifth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testSaveRestoreState(self, mock_time):
    """Timestamps and preservation state round-trip through the state file."""
    directory = self.get_temp_dir()
    mock_time.time.return_value = 3.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    first_time = 10000.
    first_name = os.path.join(directory, "ckpt-1")
    mock_time.time.return_value = first_time
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    # The state read after the first save was previously discarded; check it.
    self.assertEqual([first_time], state.all_model_checkpoint_timestamps)
    second_time = first_time + 3610.
    second_name = os.path.join(directory, "ckpt-2")
    mock_time.time.return_value = second_time
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual([first_time, second_time],
                     state.all_model_checkpoint_timestamps)
    self.assertEqual([first_name, second_name], first_manager.checkpoints)
    self.assertEqual(second_name, first_manager.latest_checkpoint)
    del first_manager

    second_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual([first_name, second_name], second_manager.checkpoints)
    self.assertEqual(second_name, second_manager.latest_checkpoint)
    third_name = os.path.join(directory, "ckpt-3")
    third_time = second_time + 3600. * 0.2
    mock_time.time.return_value = third_time
    second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([second_name, third_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    fourth_time = third_time + 3600. * 0.5
    mock_time.time.return_value = fourth_time
    fourth_name = os.path.join(directory, "ckpt-4")
    second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([third_name, fourth_name],
                     second_manager.checkpoints)
    fifth_time = fourth_time + 3600. * 0.5
    mock_time.time.return_value = fifth_time
    fifth_name = os.path.join(directory, "ckpt-5")
    second_manager.save()
    self.assertEqual([fourth_name, fifth_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    del second_manager
    third_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual(fifth_name, third_manager.latest_checkpoint)
    mock_time.time.return_value += 10.
    third_manager.save()
    sixth_name = os.path.join(directory, "ckpt-6")
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(fourth_time, state.last_preserved_timestamp)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
    self.assertEqual([fifth_name, sixth_name],
                     third_manager.checkpoints)

  @test_util.run_in_graph_and_eager_modes
  def testContinueFromUnmanaged(self):
    """A manager adopts unmanaged checkpoints but never deletes them."""
    directory = self.get_temp_dir()
    prefix = os.path.join(directory, "unusual_prefix")
    checkpoint = util.Checkpoint()
    first_path = checkpoint.save(prefix)
    second_path = checkpoint.save(prefix)
    del checkpoint
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
    self.assertEqual(2, self.evaluate(checkpoint.save_counter))
    third_path = manager.save()
    self.assertEqual([third_path], manager.checkpoints)
    fourth_path = manager.save()
    self.assertEqual([third_path, fourth_path],
                     manager.checkpoints)
    fifth_path = manager.save()
    self.assertEqual([fourth_path, fifth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testClockReset(self, mock_time):
    """Checkpoints of uncertain status are kept when the clock goes back."""
    directory = self.get_temp_dir()
    mock_time.time.return_value = 10000.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
    first_path = first_manager.save()
    mock_time.time.return_value += 3600.
    second_path = first_manager.save()
    mock_time.time.return_value += 3600.
    third_path = first_manager.save()
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual([third_path], first_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(13600., state.last_preserved_timestamp)
    # Set the clock back in time
    mock_time.time.return_value = 5000.
    del first_manager
    with test.mock.patch.object(logging, "warning") as mock_log:
      second_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory, max_to_keep=1)
      self.assertRegex(
          str(mock_log.call_args),
          "behind the last preserved checkpoint timestamp")
    # We should err on the side of keeping checkpoints around when we're not
    # sure whether they were preserved or not due to clock funkiness.
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    # We know about the existing checkpoints, but they'll never be deleted and
    # so won't go in the CheckpointState proto on save.
    self.assertEqual(third_path, second_manager.latest_checkpoint)
    self.assertEqual([], second_manager.checkpoints)
    mock_time.time.return_value += 10.
    fourth_path = second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual(fourth_path, second_manager.latest_checkpoint)
    self.assertEqual([fourth_path], second_manager.checkpoints)
    mock_time.time.return_value += 10.
    fifth_path = second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual([fifth_path], second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(5000., state.last_preserved_timestamp)
    self.assertEqual([5020.],
                     state.all_model_checkpoint_timestamps)

  @test_util.run_in_graph_and_eager_modes
  def testCustomNumbering(self):
    """`checkpoint_number` (a variable or an int) sets the path suffix."""
    directory = self.get_temp_dir()
    step = variables.Variable(0, dtype=dtypes.int64)
    checkpoint = util.Checkpoint(step=step)
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    self.evaluate(step.initializer)
    for i in range(5):
      path = manager.save(checkpoint_number=step)
      expected_suffix = "-%d" % (2 * i,)
      if not path.endswith(expected_suffix):
        self.fail("%s should have suffix %s" % (path, expected_suffix))
      self.evaluate(step.assign_add(2))
    self.assertEqual(5, self.evaluate(checkpoint.save_counter))
    # Test regular integers
    last_path = manager.save(checkpoint_number=32)
    self.assertIn("-32", last_path)
    self.assertEqual(last_path, manager.latest_checkpoint)
    self.assertEqual(
        last_path, checkpoint_management.latest_checkpoint(directory))
    state = checkpoint_management.get_checkpoint_state(directory)
    # Only the most recent two checkpoints are saved
    self.assertEqual([path, last_path], state.all_model_checkpoint_paths)

  @test_util.run_in_graph_and_eager_modes
  def testCustomCheckpointPrefix(self):
    """`checkpoint_name` sets the file prefix; the default is "ckpt"."""
    directory = self.get_temp_dir()
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
    path = manager.save(checkpoint_number=5)
    self.assertEqual(os.path.basename(path), "ckpt_name-5")
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    path = manager.save(checkpoint_number=5)
    self.assertEqual(os.path.basename(path), "ckpt-5")

  @test_util.run_in_graph_and_eager_modes
  def testRestoreOrInitialize(self):
    """`restore_or_initialize` runs `init_fn` first, then restores."""
    directory = self.get_temp_dir()

    # Create a checkpoint for initializing.
    init_prefix = os.path.join(directory, "init")
    init_v = variables.Variable(2.0)
    init_ckpt = util.Checkpoint(v=init_v)
    self.evaluate(init_v.initializer)
    init_path = init_ckpt.save(init_prefix)

    # Create the checkpoint manager.
    ckpt_dir = os.path.join(directory, "ckpt")
    v = variables.Variable(1.0)
    checkpoint = util.Checkpoint(v=v)
    manager = checkpoint_management.CheckpointManager(
        checkpoint,
        ckpt_dir,
        max_to_keep=None,
        init_fn=lambda: checkpoint.restore(init_path).run_restore_ops())
    self.evaluate(v.initializer)

    # First call should call `init_fn`.
    self.assertIsNone(manager.restore_or_initialize())
    self.assertEqual(2.0, self.evaluate(v))

    # Save a checkpoint and second call should restore from the checkpoints.
    manager.save()
    self.assertIsNotNone(manager.restore_or_initialize())

  @test_util.run_in_graph_and_eager_modes
  def testCheckpointManagerFSpathDirectory(self):
    """`CheckpointManager` accepts a `pathlib.Path` directory."""
    directory = pathlib.Path(self.get_temp_dir())
    v = variables.Variable(0.0)
    checkpoint = util.Checkpoint(v=v)
    self.evaluate(v.initializer)
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
    save_path = manager.save()
    expected = str(directory / "ckpt_name-1")
    self.assertEqual(expected, save_path)

    restore_path = manager.restore_or_initialize()
    self.assertEqual(expected, restore_path)

  @test_util.run_in_graph_and_eager_modes
  def testLatestCheckpointFSpathDirectory(self):
    """`latest_checkpoint` accepts a `pathlib.Path` directory."""
    directory = pathlib.Path(self.get_temp_dir())
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2, checkpoint_name="ckpt_name")
    manager.save()

    cp_dir = checkpoint_management.latest_checkpoint(directory)
    self.assertEqual(str(directory / "ckpt_name-1"), cp_dir)

  @test_util.run_in_graph_and_eager_modes
  def testCheckpointInterval(self):
    """With `check_interval=True`, saves honour `checkpoint_interval`."""
    v = variables.Variable(1.0)
    step_counter = variables.Variable(0)
    self.evaluate([v.initializer, step_counter.initializer])
    checkpoint = util.Checkpoint(v=v)
    manager = checkpoint_management.CheckpointManager(
        checkpoint,
        self.get_temp_dir(),
        max_to_keep=None,
        step_counter=step_counter,
        checkpoint_interval=2)

    # step_counter: 0, save an initial checkpoint.
    path = manager.save(check_interval=True)
    self.assertTrue(checkpoint_management.checkpoint_exists(path))

    # step_counter: 1, no checkpoint saved.
    self.evaluate(step_counter.assign_add(1))
    path = manager.save(check_interval=True)
    self.assertIsNone(path)

    # step_counter: 2, checkpoint saved.
    self.evaluate(step_counter.assign_add(1))
    path = manager.save(check_interval=True)
    self.assertTrue(checkpoint_management.checkpoint_exists(path))

    # no checkpoint saved when calling `save` with the same step counter.
    path = manager.save(check_interval=True)
    self.assertIsNone(path)

    # step_counter: 3, no checkpoint saved.
    self.evaluate(step_counter.assign_add(1))
    path = manager.save(check_interval=True)
    self.assertIsNone(path)

    # Always save the checkpoint.
    path = manager.save(check_interval=False)
    self.assertTrue(checkpoint_management.checkpoint_exists(path))

  @test_util.run_in_graph_and_eager_modes
  def testCheckpointIntervalWithRestore(self):
    """After a restore, saving again at the same step is skipped."""
    directory = self.get_temp_dir()
    v = variables.Variable(1.0)
    step_counter = variables.Variable(0)
    self.evaluate([v.initializer, step_counter.initializer])

    # Prepare a checkpoint.
    checkpoint = util.Checkpoint(v=v)
    checkpoint.save(os.path.join(directory, "ckpt"))

    manager = checkpoint_management.CheckpointManager(
        checkpoint,
        directory,
        max_to_keep=None,
        step_counter=step_counter,
        checkpoint_interval=2)

    # Restore from the checkpoint.
    self.assertIsNotNone(manager.restore_or_initialize())

    # step_counter: 0, no checkpoint saved because it is restored from the
    # checkpoint with the same step.
    path = manager.save()
    self.assertIsNone(path)
705
706
if __name__ == "__main__":
  # Run every TestCase in this module through TensorFlow's test runner.
  test.main()
709