# Source: aosp_15_r20/external/federated-compute/fcp/tensorflow/delete_file_test.py
# (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the `delete_file` and `delete_dir` custom ops."""
15
16import os
17
18import tensorflow as tf
19
20from fcp.tensorflow import delete_file
21
22
class DeleteOpTest(tf.test.TestCase):
  """Tests for the `delete_file.delete_file` and `delete_file.delete_dir` ops."""

  def setup_temp_dir(self) -> tuple[str, str]:
    """Sets up a temporary directory suitable for testing.

    The directory contains one file, `checkpoint.ckp`, whose content is the
    string 'content'.

    Returns:
      Tuple of directory and checkpoint file paths.
    """
    temp_dir = self.create_tempdir().full_path
    temp_file = os.path.join(temp_dir, 'checkpoint.ckp')

    expected_content = 'content'
    tf.io.write_file(temp_file, expected_content)
    # `tf.io.read_file` yields a scalar string tensor. Evaluate it and compare
    # bytes to bytes, rather than comparing a Python str against a Tensor
    # (which only "works" via elementwise Tensor equality).
    read_content = self.evaluate(tf.io.read_file(temp_file))
    self.assertEqual(expected_content.encode('utf-8'), read_content)

    self.assertTrue(os.path.isdir(temp_dir))
    self.assertTrue(os.path.exists(temp_file))
    return temp_dir, temp_file

  def test_delete_file_op(self):
    """Deleting a file removes it; deleting it again does not raise."""
    _, temp_file = self.setup_temp_dir()

    delete_file.delete_file(temp_file)
    self.assertFalse(os.path.exists(temp_file))
    # Delete one more time to make sure there is no error when the file
    # doesn't exist.
    delete_file.delete_file(temp_file)
    self.assertFalse(os.path.exists(temp_file))

  def test_delete_file_op_exceptions(self):
    """`delete_file` rejects non-string and non-scalar path arguments."""
    with self.subTest(name='non_string_dtype'):
      with self.assertRaises(TypeError):
        delete_file.delete_file(1.0)
    with self.subTest(name='non_scalar'):
      # Build the fixture outside the assertion context so that only the op
      # call under test can satisfy the expected-exception check.
      _, checkpoint_path = self.setup_temp_dir()
      # `assertRaisesRegex` uses `re.search`, so no `.*` anchors are needed.
      with self.assertRaisesRegex(
          tf.errors.InvalidArgumentError, 'must be a string scalar'
      ):
        delete_file.delete_file([checkpoint_path, checkpoint_path])

  def test_delete_file_and_dir_succeeds(self):
    """An emptied directory can be deleted; deleting it again does not raise."""
    temp_dir, temp_file = self.setup_temp_dir()
    delete_file.delete_file(temp_file)
    self.assertFalse(os.path.exists(temp_file))

    delete_file.delete_dir(temp_dir)
    self.assertFalse(os.path.isdir(temp_dir))
    # Delete the directory one more time to make sure there is no error when
    # it doesn't exist.
    delete_file.delete_dir(temp_dir)
    self.assertFalse(os.path.isdir(temp_dir))

  def test_delete_non_empty_dir_fails(self):
    """Non-recursive deletion leaves a non-empty directory in place.

    The call is made outside any `assertRaises`, so the op is required to
    fail without raising; the directory and its file must still exist.
    """
    temp_dir, temp_file = self.setup_temp_dir()

    delete_file.delete_dir(temp_dir)
    self.assertTrue(os.path.isdir(temp_dir))
    self.assertTrue(os.path.exists(temp_file))

  def test_recursive_delete_non_empty_dir_succeeds(self):
    """With `recursively=True`, a directory and its contents are removed."""
    temp_dir, temp_file = self.setup_temp_dir()

    delete_file.delete_dir(temp_dir, recursively=True)
    self.assertFalse(os.path.isdir(temp_dir))
    self.assertFalse(os.path.exists(temp_file))

  def test_delete_dir_op_exceptions(self):
    """`delete_dir` rejects non-string and non-scalar path arguments."""
    with self.subTest(name='non_string_dtype'):
      with self.assertRaises(TypeError):
        delete_file.delete_dir(1.0)
    with self.subTest(name='non_scalar'):
      # Build the fixture outside the assertion context so that only the op
      # call under test can satisfy the expected-exception check.
      temp_dir, _ = self.setup_temp_dir()
      with self.assertRaisesRegex(
          tf.errors.InvalidArgumentError, 'must be a string scalar'
      ):
        delete_file.delete_dir([temp_dir, temp_dir])
97
98
if __name__ == '__main__':
  # Run the test cases in this module with TensorFlow's test runner.
  tf.test.main()
101