xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/append_slices_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for the `append_slices` and `merge_appended_slices` custom ops."""
15
16import os
17import tensorflow as tf
18
19from fcp.tensorflow import append_slices
20from fcp.tensorflow import delete_file
21
22
class AppendSlicesTest(tf.test.TestCase):
  """Round-trip tests for `append_slices` and `merge_appended_slices`.

  Every test writes int32 scalar tensors to a fresh file through one or more
  `append_slices` calls, runs `merge_appended_slices` on that file, and then
  reads the tensors back from the same path with `tf.raw_ops.RestoreV2`.
  """

  def new_tempfile_path(self):
    """Returns a path that can be used to store a new tempfile."""
    return os.path.join(self.create_tempdir(), 'checkpoint.ckp')

  def _append_int32_scalars(self, checkpoint_path, named_values):
    """Appends `(name, value)` pairs with a single `append_slices` call.

    Args:
      checkpoint_path: Path of the file to append to.
      named_values: Sequence of `(tensor_name, value)` pairs. Each value is
        wrapped in an int32 `tf.constant` and written with an empty
        shape-and-slice spec.
    """
    append_slices.append_slices(
        filename=checkpoint_path,
        tensor_names=[name for name, _ in named_values],
        data=[tf.constant(value, dtype=tf.int32) for _, value in named_values],
        shapes_and_slices=[''] * len(named_values))

  def _restore_int32_scalars(self, checkpoint_path, tensor_names):
    """Restores the named int32 tensors from `checkpoint_path`.

    Args:
      checkpoint_path: Path passed to `RestoreV2` as the checkpoint prefix.
      tensor_names: Names to restore; results come back in this order.

    Returns:
      The result of `tf.raw_ops.RestoreV2`, one entry per requested name.
    """
    return tf.raw_ops.RestoreV2(
        prefix=checkpoint_path,
        tensor_names=tensor_names,
        shape_and_slices=[''] * len(tensor_names),
        dtypes=[tf.int32] * len(tensor_names))

  def _assert_restored_int32_scalars(self, checkpoint_path, expected):
    """Asserts that restoring `checkpoint_path` yields the expected values.

    Args:
      checkpoint_path: Path of an already merged checkpoint.
      expected: Sequence of `(tensor_name, value)` pairs. Names are requested
        from `RestoreV2` in exactly this order.
    """
    restored = self._restore_int32_scalars(
        checkpoint_path, [name for name, _ in expected])
    # Guard against `zip` silently dropping a missing trailing result.
    self.assertLen(restored, len(expected))
    for (name, value), tensor in zip(expected, restored):
      self.assertEqual(tensor, value, msg=f'tensor {name!r}')

  def test_converts_single_element_once_appended_file_to_checkpoint(self):
    checkpoint_path = self.new_tempfile_path()
    self._append_int32_scalars(checkpoint_path, [('a', 42)])
    append_slices.merge_appended_slices(checkpoint_path)
    self._assert_restored_int32_scalars(checkpoint_path, [('a', 42)])

  def test_converts_single_element_twice_appended_file_to_checkpoint(self):
    checkpoint_path = self.new_tempfile_path()
    expected = [('a', 7), ('b', 11)]
    # One `append_slices` call per tensor.
    for named_value in expected:
      self._append_int32_scalars(checkpoint_path, [named_value])
    append_slices.merge_appended_slices(checkpoint_path)
    self._assert_restored_int32_scalars(checkpoint_path, expected)

  def test_converts_two_element_once_appended_file_to_checkpoint(self):
    checkpoint_path = self.new_tempfile_path()
    expected = [('a', 16), ('b', 17)]
    # Both tensors go into a single `append_slices` call.
    self._append_int32_scalars(checkpoint_path, expected)
    append_slices.merge_appended_slices(checkpoint_path)
    self._assert_restored_int32_scalars(checkpoint_path, expected)

  def test_converts_two_element_multi_twice_appended_file_to_checkpoint(self):
    # Note: the interleaved ordering ensures that the resulting merged
    # checkpoint is able to mix together the two input checkpoints properly.
    checkpoint_path = self.new_tempfile_path()
    appends = [
        [('a', 12), ('c', 55)],
        [('b', 40), ('d', 88)],
    ]
    for named_values in appends:
      self._append_int32_scalars(checkpoint_path, named_values)
    append_slices.merge_appended_slices(checkpoint_path)
    self._assert_restored_int32_scalars(
        checkpoint_path, [('a', 12), ('b', 40), ('c', 55), ('d', 88)])

  def test_converts_nonalphabetical_two_element_multi_twice_appended_file_to_checkpoint(
      self):
    # Note: names are appended in non-alphabetical order within each call and
    # restored in yet another order; every tensor must still come back with
    # the value it was appended with.
    checkpoint_path = self.new_tempfile_path()
    appends = [
        [('b', 12), ('a', 55)],
        [('d', 40), ('c', 88)],
    ]
    for named_values in appends:
      self._append_int32_scalars(checkpoint_path, named_values)
    append_slices.merge_appended_slices(checkpoint_path)
    self._assert_restored_int32_scalars(
        checkpoint_path, [('d', 40), ('c', 88), ('b', 12), ('a', 55)])

  def test_merge_missing_checkpoint_file_raises(self):
    # Nothing has been appended, so no file exists at this path.
    checkpoint_path = self.new_tempfile_path()
    with self.assertRaises(tf.errors.NotFoundError):
      append_slices.merge_appended_slices(checkpoint_path)

  def test_duplicate_named_tensor_raises(self):
    checkpoint_path = self.new_tempfile_path()
    # Two separate appends that both use the name 'a'.
    for value in (7, 11):
      self._append_int32_scalars(checkpoint_path, [('a', value)])
    with self.assertRaisesRegex(
        tf.errors.InvalidArgumentError,
        'Attempted to merge two checkpoint entries for slice name: `a`'):
      append_slices.merge_appended_slices(checkpoint_path)

  def test_append_and_merge_using_same_filename(self):
    checkpoint_path = self.new_tempfile_path()
    expected = [('a', 7), ('b', 11)]
    for _ in range(2):
      # Without calling this we might append to a previously used file.
      delete_file.delete_file(checkpoint_path)

      for named_value in expected:
        self._append_int32_scalars(checkpoint_path, [named_value])
      append_slices.merge_appended_slices(checkpoint_path)
      self._assert_restored_int32_scalars(checkpoint_path, expected)
180
181
# Only start the TensorFlow test entry point when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
  tf.test.main()
184