# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for simple_hash_table."""

import os.path
import tempfile

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.examples.custom_ops_doc.simple_hash_table import simple_hash_table
from tensorflow.python.eager import def_function
# This pylint disable is only needed for internal google users
from tensorflow.python.framework import test_util  # pylint: disable=g-direct-tensorflow-import

# (test name, key dtype, value dtype) combinations shared by the
# parameterized "find, insert, find, remove, find" tests below.
_KEY_VALUE_DTYPE_CASES = (
    ('int32_float', tf.int32, float),
    ('int64_int32', tf.int64, tf.int32),
)


class SimpleHashTableTest(tf.test.TestCase, parameterized.TestCase):

  # Helper that runs the sequence "create, find, insert, find, remove, find"
  # on a fresh table for key 1 and stacks the three lookup results into a
  # single tensor. Callers expect [-999, 100, -999]: the per-call default
  # before the insert, the inserted value, and the default again after the
  # remove.
  def _use_table(self, key_dtype, value_dtype):
    table = simple_hash_table.SimpleHashTable(key_dtype, value_dtype, 111)
    before_insert = table.find(1, -999)
    table.insert(1, 100)
    after_insert = table.find(1, -999)
    table.remove(1)
    after_remove = table.find(1, -999)
    return tf.stack([before_insert, after_insert, after_remove])

  # Test of "create, find, insert, find, remove, find" in eager mode.
  @parameterized.named_parameters(*_KEY_VALUE_DTYPE_CASES)
  def test_find_insert_find_eager(self, key_dtype, value_dtype):
    stacked = self._use_table(key_dtype, value_dtype)
    self.assertAllClose(stacked, [-999, 100, -999])

  # Test of "create, find, insert, find, remove, find" in a tf.function.
  # Note that the creation and use of the ref-counted resource occurs inside
  # a single self.evaluate: the tf.function object itself (not the result of
  # calling it) is what gets handed to self.evaluate.
  @parameterized.named_parameters(*_KEY_VALUE_DTYPE_CASES)
  def test_find_insert_find_tf_function(self, key_dtype, value_dtype):
    traced = def_function.function(
        lambda: self._use_table(key_dtype, value_dtype))
    evaluated = self.evaluate(traced)
    self.assertAllClose(evaluated, [-999.0, 100.0, -999.0])

  # Same "find, insert, find" flow with strings for both key and value.
  def test_find_insert_find_strings_eager(self):
    default = 'Default'
    foo = 'Foo'
    bar = 'Bar'
    table = simple_hash_table.SimpleHashTable(tf.string, tf.string, default)

    # Nothing has been inserted yet, so the default is expected back.
    self.assertEqual(table.find(foo, default), default)

    table.insert(foo, bar)
    self.assertEqual(table.find(foo, default), bar)

  # Insert three entries, then check that export() yields all of them. The
  # exported keys and values are sorted before comparing, so the test does
  # not depend on the order in which export() returns entries.
  def test_export(self):
    table = simple_hash_table.SimpleHashTable(
        tf.int64, tf.int64, default_value=-1)
    for key, value in ((1, 100), (2, 200), (3, 300)):
      table.insert(key, value)

    exported_keys, exported_values = self.evaluate(table.export())
    self.assertAllEqual(sorted(exported_keys), [1, 2, 3])
    self.assertAllEqual(sorted(exported_values), [100, 200, 300])

  # Bulk-load three entries with do_import(), then look each of them up.
  # Key 9 was never imported, so the table's default value (-1) is expected.
  def test_import(self):
    table = simple_hash_table.SimpleHashTable(
        tf.int64, tf.int64, default_value=-1)
    table.do_import(
        tf.constant([1, 2, 3], dtype=tf.int64),
        tf.constant([100, 200, 300], dtype=tf.int64))

    for key, expected in ((1, 100), (2, 200), (3, 300), (9, -1)):
      self.assertEqual(table.find(key), expected)

  # Save an object that owns a populated table plus a lookup tf.function as a
  # SavedModel, drop the original object, reload, and check that lookups on
  # the loaded object return the saved entries (and -1 for a missing key).
  @test_util.run_v2_only
  def testSavedModelSaveRestore(self):
    # The mkdtemp prefix is rooted in this test's temp directory.
    prefix = os.path.join(self.get_temp_dir(), 'save_restore')
    export_path = os.path.join(tempfile.mkdtemp(prefix=prefix), 'hash')

    # TODO(b/203097231) is there an alternative that is not __internal__?
    root = tf.__internal__.tracking.AutoTrackable()

    default_value = -1
    root.table = simple_hash_table.SimpleHashTable(
        tf.int64, tf.int64, default_value=default_value)

    @def_function.function(input_signature=[tf.TensorSpec((), tf.int64)])
    def lookup(key):
      return root.table.find(key)

    root.lookup = lookup

    for key, value in ((1, 100), (2, 200), (3, 300)):
      root.table.insert(key, value)
    self.assertEqual(root.lookup(2), 200)
    exported_keys = self.evaluate(root.table.export()[0])
    self.assertAllEqual(3, len(exported_keys))
    tf.saved_model.save(root, export_path)

    # Drop the original object so the checks below only use the loaded one.
    del root
    loaded = tf.saved_model.load(export_path)
    self.assertEqual(loaded.lookup(2), 200)
    self.assertEqual(loaded.lookup(10), -1)


if __name__ == '__main__':
  tf.test.main()