# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the `append_slices` and `merge_appended_slices` custom ops."""

import os
import tensorflow as tf

from fcp.tensorflow import append_slices
from fcp.tensorflow import delete_file


class AppendSlicesTest(tf.test.TestCase):
  """Appends int32 scalars to a file, merges them, and restores them."""

  def new_tempfile_path(self):
    """Builds a `checkpoint.ckp` path inside a newly created temp directory."""
    tempdir = self.create_tempdir()
    return os.path.join(tempdir, 'checkpoint.ckp')

  def _append_int32(self, path, named_values):
    """Issues one `append_slices` call for the given `(name, value)` pairs.

    Args:
      path: The file to append to.
      named_values: A list of `(tensor_name, python_int)` pairs; each value is
        converted to an int32 constant before being appended.
    """
    names = []
    constants = []
    for name, value in named_values:
      names.append(name)
      constants.append(tf.constant(value, dtype=tf.int32))
    append_slices.append_slices(
        filename=path,
        tensor_names=names,
        data=constants,
        shapes_and_slices=[''] * len(names))

  def _assert_restored_int32(self, path, expected):
    """Restores tensors from `path` via `RestoreV2` and checks their values.

    Args:
      path: The checkpoint prefix to restore from.
      expected: A list of `(tensor_name, expected_value)` pairs. Tensors are
        requested, and compared, in the order given.
    """
    count = len(expected)
    restored = tf.raw_ops.RestoreV2(
        prefix=path,
        tensor_names=[name for name, _ in expected],
        shape_and_slices=[''] * count,
        dtypes=[tf.int32] * count)
    for index, (_, value) in enumerate(expected):
      self.assertEqual(restored[index], value)

  def test_converts_single_element_once_appended_file_to_checkpoint(self):
    """A single append of a single tensor can be merged and restored."""
    path = self.new_tempfile_path()
    entries = [('a', 42)]
    self._append_int32(path, entries)
    append_slices.merge_appended_slices(path)
    self._assert_restored_int32(path, entries)

  def test_converts_single_element_twice_appended_file_to_checkpoint(self):
    """Two appends of one tensor each can be merged and restored."""
    path = self.new_tempfile_path()
    entries = [('a', 7), ('b', 11)]
    # Each tensor goes into the file through its own append call.
    for entry in entries:
      self._append_int32(path, [entry])
    append_slices.merge_appended_slices(path)
    self._assert_restored_int32(path, entries)

  def test_converts_two_element_once_appended_file_to_checkpoint(self):
    """A single append carrying two tensors can be merged and restored."""
    path = self.new_tempfile_path()
    entries = [('a', 16), ('b', 17)]
    self._append_int32(path, entries)
    append_slices.merge_appended_slices(path)
    self._assert_restored_int32(path, entries)

  def test_converts_two_element_multi_twice_appended_file_to_checkpoint(self):
    """Two appends of two tensors each can be merged and restored."""
    path = self.new_tempfile_path()
    # The names are interleaved across the two appends ('a', 'c' then
    # 'b', 'd') so that the merge has to mix entries from both appends.
    first_append = [('a', 12), ('c', 55)]
    second_append = [('b', 40), ('d', 88)]
    for group in (first_append, second_append):
      self._append_int32(path, group)
    append_slices.merge_appended_slices(path)
    self._assert_restored_int32(
        path, [('a', 12), ('b', 40), ('c', 55), ('d', 88)])

  def test_converts_nonalphabetical_two_element_multi_twice_appended_file_to_checkpoint(
      self):
    """Same as the two-append case, with names out of alphabetical order."""
    path = self.new_tempfile_path()
    # Within each append the names are given in reverse alphabetical order,
    # and the restore below also requests them non-alphabetically.
    first_append = [('b', 12), ('a', 55)]
    second_append = [('d', 40), ('c', 88)]
    for group in (first_append, second_append):
      self._append_int32(path, group)
    append_slices.merge_appended_slices(path)
    self._assert_restored_int32(
        path, [('d', 40), ('c', 88), ('b', 12), ('a', 55)])

  def test_merge_missing_checkpoint_file_raises(self):
    """Merging a path that was never appended to raises `NotFoundError`."""
    never_written_path = self.new_tempfile_path()
    with self.assertRaises(tf.errors.NotFoundError):
      append_slices.merge_appended_slices(never_written_path)

  def test_duplicate_named_tensor_raises(self):
    """Merging two appended entries with the same name is rejected."""
    path = self.new_tempfile_path()
    # Append the name 'a' twice, with a different value each time.
    for value in (7, 11):
      self._append_int32(path, [('a', value)])
    expected_message = (
        'Attempted to merge two checkpoint entries for slice name: `a`')
    with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                expected_message):
      append_slices.merge_appended_slices(path)

  def test_append_and_merge_using_same_filename(self):
    """The append/merge/restore cycle can be repeated on one filename."""
    path = self.new_tempfile_path()
    entries = [('a', 7), ('b', 11)]
    for _ in range(2):
      # Without calling this we might append to a previously used file.
      delete_file.delete_file(path)

      for entry in entries:
        self._append_int32(path, [entry])
      append_slices.merge_appended_slices(path)
      self._assert_restored_int32(path, entries)


if __name__ == '__main__':
  tf.test.main()