# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the `delete_file` and `delete_dir` custom ops."""

import os

import tensorflow as tf

from fcp.tensorflow import delete_file


class DeleteOpTest(tf.test.TestCase):
  """Tests for `delete_file.delete_file` and `delete_file.delete_dir`."""

  def setup_temp_dir(self) -> tuple[str, str]:
    """Sets up a temporary directory suitable for testing.

    The filesystem consists of a directory with one file inside.

    Returns:
      Tuple of directory and checkpoint paths.
    """
    temp_dir = self.create_tempdir().full_path
    temp_file = os.path.join(temp_dir, 'checkpoint.ckp')

    expected_content = 'content'
    tf.io.write_file(temp_file, expected_content)
    read_content = tf.io.read_file(temp_file)
    # `read_file` yields a string tensor; compare its concrete bytes value
    # instead of relying on tensor `==` overloading (eager-only behavior).
    self.assertEqual(expected_content.encode(), self.evaluate(read_content))

    self.assertTrue(os.path.isdir(temp_dir))
    self.assertTrue(os.path.exists(temp_file))
    return temp_dir, temp_file

  def test_delete_file_op(self):
    """Deleting a file removes it; deleting it again does not raise."""
    _, temp_file = self.setup_temp_dir()

    delete_file.delete_file(temp_file)
    # Delete one more time to make sure no error when the file doesn't exist.
    delete_file.delete_file(temp_file)
    self.assertFalse(os.path.exists(temp_file))

  def test_delete_file_op_exceptions(self):
    """`delete_file` rejects non-string and non-scalar arguments."""
    with self.subTest(name='non_string_dtype'):
      with self.assertRaises(TypeError):
        delete_file.delete_file(1.0)
    with self.subTest(name='non_scalar'):
      # Set up outside the `assertRaises` context so that only the op call
      # itself is expected to raise.
      _, checkpoint_path = self.setup_temp_dir()
      with self.assertRaisesRegex(
          tf.errors.InvalidArgumentError, 'must be a string scalar'
      ):
        delete_file.delete_file([checkpoint_path, checkpoint_path])

  def test_delete_file_and_dir_succeeds(self):
    """An emptied directory can be deleted, and deleting it again is safe."""
    temp_dir, temp_file = self.setup_temp_dir()
    delete_file.delete_file(temp_file)
    self.assertFalse(os.path.exists(temp_file))

    delete_file.delete_dir(temp_dir)
    # Delete the dir one more time to make sure no error is raised when the
    # dir doesn't exist.
    delete_file.delete_dir(temp_dir)
    self.assertFalse(os.path.isdir(temp_dir))

  def test_delete_non_empty_dir_fails(self):
    """A non-recursive delete leaves a non-empty directory in place."""
    temp_dir, temp_file = self.setup_temp_dir()

    # The call is expected not to raise; "fails" here means the directory and
    # its file are still present afterwards.
    delete_file.delete_dir(temp_dir)
    self.assertTrue(os.path.isdir(temp_dir))
    self.assertTrue(os.path.exists(temp_file))

  def test_recursive_delete_non_empty_dir_succeeds(self):
    """A recursive delete removes a directory together with its contents."""
    temp_dir, temp_file = self.setup_temp_dir()

    delete_file.delete_dir(temp_dir, recursively=True)
    self.assertFalse(os.path.isdir(temp_dir))
    self.assertFalse(os.path.exists(temp_file))

  def test_delete_dir_op_exceptions(self):
    """`delete_dir` rejects non-string and non-scalar arguments."""
    with self.subTest(name='non_string_dtype'):
      with self.assertRaises(TypeError):
        delete_file.delete_dir(1.0)
    with self.subTest(name='non_scalar'):
      # Set up outside the `assertRaises` context so that only the op call
      # itself is expected to raise.
      temp_dir, _ = self.setup_temp_dir()
      with self.assertRaisesRegex(
          tf.errors.InvalidArgumentError, 'must be a string scalar'
      ):
        delete_file.delete_dir([temp_dir, temp_dir])


if __name__ == '__main__':
  tf.test.main()