1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for simple_hash_table."""
16
17import os.path
18import tempfile
19
20from absl.testing import parameterized
21import tensorflow as tf
22
23from tensorflow.examples.custom_ops_doc.simple_hash_table import simple_hash_table
24from tensorflow.python.eager import def_function
25# This pylint disable is only needed for internal google users
26from tensorflow.python.framework import test_util  # pylint: disable=g-direct-tensorflow-import
27
28
class SimpleHashTableTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the `simple_hash_table.SimpleHashTable` custom-op wrapper."""

  def _use_table(self, key_dtype, value_dtype):
    """Runs "create, find, insert, find, remove, find" on a fresh table.

    Key 1 is looked up with an explicit per-call default of -999 three times:
    before it is inserted, after inserting 1 -> 100, and after removing it.

    Args:
      key_dtype: Key dtype passed to `SimpleHashTable`.
      value_dtype: Value dtype passed to `SimpleHashTable`.

    Returns:
      A stacked tensor of the three lookup results. Callers expect it to
      equal [-999, 100, -999].
    """
    # The table-level default (111) should never show up below, because every
    # find() supplies its own default of -999 that takes precedence.
    hash_table = simple_hash_table.SimpleHashTable(
        key_dtype, value_dtype, default_value=111)
    before_insert = hash_table.find(1, -999)
    hash_table.insert(1, 100)
    after_insert = hash_table.find(1, -999)
    hash_table.remove(1)
    after_remove = hash_table.find(1, -999)
    return tf.stack((before_insert, after_insert, after_remove))

  @parameterized.named_parameters(('int32_float', tf.int32, float),
                                  ('int64_int32', tf.int64, tf.int32))
  def test_find_insert_find_eager(self, key_dtype, value_dtype):
    """Checks "create, find, insert, find, remove, find" in eager mode."""
    results = self._use_table(key_dtype, value_dtype)
    self.assertAllClose(results, [-999, 100, -999])

  @parameterized.named_parameters(('int32_float', tf.int32, float),
                                  ('int64_int32', tf.int64, tf.int32))
  def test_find_insert_find_tf_function(self, key_dtype, value_dtype):
    """Checks the same table sequence when run inside a `tf.function`.

    The ref-counted table resource is both created and used within a single
    call of the traced function.
    """
    use_table = def_function.function(
        lambda: self._use_table(key_dtype, value_dtype))
    # Call the function explicitly and evaluate the resulting tensor, rather
    # than relying on `evaluate` to invoke a callable on our behalf.
    results = self.evaluate(use_table())
    self.assertAllClose(results, [-999, 100, -999])

  def test_find_insert_find_strings_eager(self):
    """Checks find/insert/remove with `tf.string` keys and values (eager)."""
    default = 'Default'
    foo = 'Foo'
    bar = 'Bar'
    hash_table = simple_hash_table.SimpleHashTable(
        tf.string, tf.string, default_value=default)
    self.assertEqual(hash_table.find(foo, default), default)
    hash_table.insert(foo, bar)
    self.assertEqual(hash_table.find(foo, default), bar)
    # Mirror the numeric helper: a removed key falls back to the default.
    hash_table.remove(foo)
    self.assertEqual(hash_table.find(foo, default), default)

  def test_export(self):
    """Checks that `export()` returns every inserted key with its value."""
    table = simple_hash_table.SimpleHashTable(
        tf.int64, tf.int64, default_value=-1)
    table.insert(1, 100)
    table.insert(2, 200)
    table.insert(3, 300)
    keys, values = self.evaluate(table.export())
    # Export order is treated as unspecified, so compare sorted (key, value)
    # pairs. Sorting keys and values independently would not catch a key
    # being exported alongside the wrong value.
    exported_pairs = sorted(zip(keys.tolist(), values.tolist()))
    self.assertEqual(exported_pairs, [(1, 100), (2, 200), (3, 300)])

  def test_import(self):
    """Checks that `do_import` populates the table; other keys use default."""
    table = simple_hash_table.SimpleHashTable(
        tf.int64, tf.int64, default_value=-1)
    keys = tf.constant([1, 2, 3], dtype=tf.int64)
    values = tf.constant([100, 200, 300], dtype=tf.int64)
    table.do_import(keys, values)
    self.assertEqual(table.find(1), 100)
    self.assertEqual(table.find(2), 200)
    self.assertEqual(table.find(3), 300)
    # Key 9 was never imported, so the table-level default is returned.
    self.assertEqual(table.find(9), -1)

  @test_util.run_v2_only
  def testSavedModelSaveRestore(self):
    """Checks that table contents survive a SavedModel save/load round trip."""
    save_dir = os.path.join(self.get_temp_dir(), 'save_restore')
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), 'hash')

    # TODO(b/203097231) is there an alternative that is not __internal__?
    root = tf.__internal__.tracking.AutoTrackable()

    default_value = -1
    root.table = simple_hash_table.SimpleHashTable(
        tf.int64, tf.int64, default_value=default_value)

    # A signature-bound lookup function so the loaded model exposes `lookup`.
    @def_function.function(input_signature=[tf.TensorSpec((), tf.int64)])
    def lookup(key):
      return root.table.find(key)

    root.lookup = lookup

    root.table.insert(1, 100)
    root.table.insert(2, 200)
    root.table.insert(3, 300)
    self.assertEqual(root.lookup(2), 200)
    self.assertAllEqual(3, len(self.evaluate(root.table.export()[0])))
    tf.saved_model.save(root, save_path)

    # Drop the original object so the checks below only see restored state.
    del root
    loaded = tf.saved_model.load(save_path)
    self.assertEqual(loaded.lookup(2), 200)
    self.assertEqual(loaded.lookup(10), default_value)
122
123
if __name__ == '__main__':
  # Run the test cases in this module via TensorFlow's test entry point.
  tf.test.main()
126