xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/save_util_v1_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for the checkpoint/util.py."""
16
17from tensorflow.python.checkpoint import graph_view
18from tensorflow.python.checkpoint import save_util_v1
19from tensorflow.python.eager import test
20from tensorflow.python.ops import variables
21from tensorflow.python.saved_model import registration
22from tensorflow.python.trackable import autotrackable
23from tensorflow.python.util import object_identity
24
25
class TrackableWithRegisteredSaver(autotrackable.AutoTrackable):
  """Marker trackable picked out by the "RegisteredSaver" predicate.

  Adds no behavior of its own. The module-level registration identifies these
  objects with an `isinstance` check against this class.
  """
28
29
def _is_registered_saver_trackable(x):
  """Predicate for the test saver: true for `TrackableWithRegisteredSaver`."""
  return isinstance(x, TrackableWithRegisteredSaver)


def _noop_save(trackables, file_prefix):
  """Save stub for the test saver; always returns an empty list."""
  del trackables, file_prefix  # Unused.
  return []


def _noop_restore(trackables, merged_prefix):
  """Restore stub for the test saver; does nothing and returns None."""
  del trackables, merged_prefix  # Unused.


# Register a do-nothing checkpoint saver so the tests can check that objects
# matching the predicate are reported under "Custom.RegisteredSaver".
registration.register_checkpoint_saver(
    name="RegisteredSaver",
    predicate=_is_registered_saver_trackable,
    save_fn=_noop_save,
    restore_fn=_noop_restore)
35
36
class SerializationTest(test.TestCase):
  """Exercises `save_util_v1.serialize_gathered_objects` on a tiny graph.

  The graph is a root `AutoTrackable` with one plain variable (`v`) and one
  object claimed by the "RegisteredSaver" registered saver (`registered`).
  """

  def _build_root(self):
    """Returns a fresh root holding attributes `v` and `registered`."""
    root = autotrackable.AutoTrackable()
    root.v = variables.Variable(1.0)
    root.registered = TrackableWithRegisteredSaver()
    return root

  def test_serialize_gathered_objects(self):
    root = self._build_root()
    view = graph_view.ObjectGraphView(root)
    saveables, _, _, savers_by_name = (
        save_util_v1.serialize_gathered_objects(view))

    # The variable is the only saveable. The registered object is reported
    # under its registered saver's name instead.
    self.assertLen(saveables, 1)
    self.assertIs(saveables[0].op, root.v)
    expected_savers = {
        "Custom.RegisteredSaver": {"registered": root.registered}}
    self.assertDictEqual(expected_savers, savers_by_name)

  def test_serialize_gathered_objects_with_map(self):
    root = self._build_root()

    # Replacement objects that `object_map` substitutes for the originals.
    mapped_registered = TrackableWithRegisteredSaver()
    mapped_v = variables.Variable(1.0)
    object_map = object_identity.ObjectIdentityDictionary()
    object_map[root.registered] = mapped_registered
    object_map[root.v] = mapped_v

    view = graph_view.ObjectGraphView(root)
    saveables, _, _, savers_by_name = (
        save_util_v1.serialize_gathered_objects(view, object_map))

    # The saveable must wrap the mapped variable, not the original one.
    self.assertLen(saveables, 1)
    self.assertIsNot(saveables[0].op, root.v)
    self.assertIs(saveables[0].op, mapped_v)

    # Likewise, the registered saver must be handed the mapped object.
    saver_target = savers_by_name["Custom.RegisteredSaver"]["registered"]
    self.assertIsNot(root.registered, saver_target)
    self.assertIs(mapped_registered, saver_target)
75
76
77if __name__ == "__main__":
78  test.main()
79