# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A wrapper for gen_simple_hash_table_op.py.

This defines a public API and provides a docstring for the C++ Op defined by
simple_hash_table_kernel.cc.
"""

import tensorflow as tf
from tensorflow.examples.custom_ops_doc.simple_hash_table.simple_hash_table_op import gen_simple_hash_table_op


class SimpleHashTable(tf.saved_model.experimental.TrackableResource):
  """A simple mutable hash table implementation.

  Implement a simple hash table as a Resource using ref-counting.
  This demonstrates a Stateful Op for a general Create/Read/Update/Delete
  (CRUD) style use case. To instead make an op for a specific lookup table
  case, it is preferable to follow the implementation style of
  TensorFlow's internal ops, e.g. use LookupInterface.

  Data can be inserted by calling the `insert` method and removed by calling
  the `remove` method. It does not support initialization via the init method.

  The `export` and `do_import` methods allow exporting and re-importing all of
  the key, value pairs. These methods (or their corresponding kernels)
  are intended to be used for supporting SavedModel.

  Example usage:
    hash_table = simple_hash_table_op.SimpleHashTable(key_dtype, value_dtype,
                                                      111)
    result1 = hash_table.find(1, -999)  # -999
    hash_table.insert(1, 100)
    result2 = hash_table.find(1, -999)  # 100
    hash_table.remove(1)
    result3 = hash_table.find(1, -999)  # -999
  """

  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               name="SimpleHashTable"):
    """Creates an empty `SimpleHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. Must
        be a scalar convertible to `value_dtype`.
      name: A name for the operation (optional).

    Returns:
      A `SimpleHashTable` object.

    Raises:
      ValueError: if `default_value` is not a scalar.
    """
    super(SimpleHashTable, self).__init__()
    self._default_value = tf.convert_to_tensor(default_value, dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._name = name
    # Methods that use the Resource get its handle using the
    # public self.resource_handle property (defined by TrackableResource).
    # This property calls self._create_resource() the first time
    # if private self._resource_handle is not preemptively initialized.
    self._resource_handle = self._create_resource()

  def _create_resource(self):
    """Create the resource tensor handle.

    `_create_resource` is an override of a method in base class
    `TrackableResource` that is required for SavedModel support. It can be
    called by the `resource_handle` property defined by `TrackableResource`.

    Returns:
      A tensor handle to the lookup table.

    Raises:
      ValueError: if the table's default value is not a scalar.
    """
    # An explicit check rather than `assert`: asserts are stripped when Python
    # runs with -O, which would silently skip this validation.
    default_shape = self._default_value.get_shape()
    if default_shape.ndims != 0:
      raise ValueError(
          "SimpleHashTable requires a scalar default_value, got shape %s." %
          default_shape)
    table_ref = gen_simple_hash_table_op.examples_simple_hash_table_create(
        key_dtype=self._key_dtype,
        value_dtype=self._value_dtype,
        name=self._name)
    return table_ref

  def _serialize_to_tensors(self):
    """Implements checkpointing protocols for `Trackable`."""
    keys, values = self.export()
    return {"table-keys": keys, "table-values": values}

  def _restore_from_tensors(self, restored_tensors):
    """Implements checkpointing protocols for `Trackable`."""
    return gen_simple_hash_table_op.examples_simple_hash_table_import(
        self.resource_handle, restored_tensors["table-keys"],
        restored_tensors["table-values"])

  @property
  def key_dtype(self):
    """The table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The table value dtype."""
    return self._value_dtype

  def find(self, key, dynamic_default_value=None, name=None):
    """Looks up `key` in a table, outputs the corresponding value.

    The default value is used if key not present in the table.

    Args:
      key: Key to look up. Must match the table's key_dtype.
      dynamic_default_value: The value to use if the key is missing in the
        table. If None (by default), the `default_value` passed to the
        constructor will be used.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the value in the same shape as `key` using the
      table's value type.

    Raises:
      TypeError or ValueError: when `key` (or `dynamic_default_value`) cannot
        be converted to the table's key (or value) dtype.
    """
    with tf.name_scope(name or "%s_lookup_table_find" % self._name):
      key = tf.convert_to_tensor(key, dtype=self._key_dtype, name="key")
      if dynamic_default_value is not None:
        default = tf.convert_to_tensor(
            dynamic_default_value,
            dtype=self._value_dtype,
            name="default_value")
      else:
        default = self._default_value
      value = gen_simple_hash_table_op.examples_simple_hash_table_find(
          self.resource_handle, key, default)
      return value

  def insert(self, key, value, name=None):
    """Associates `key` with `value`.

    Args:
      key: Scalar key to insert.
      value: Scalar value to be associated with key.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError or ValueError: when `key` or `value` cannot be converted to
        the table's data types.
    """
    with tf.name_scope(name or "%s_lookup_table_insert" % self._name):
      key = tf.convert_to_tensor(key, self._key_dtype, name="key")
      value = tf.convert_to_tensor(value, self._value_dtype, name="value")
      op = gen_simple_hash_table_op.examples_simple_hash_table_insert(
          self.resource_handle, key, value)
      return op

  def remove(self, key, name=None):
    """Remove `key`.

    Args:
      key: Scalar key to remove.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError or ValueError: when `key` cannot be converted to the table's
        key data type.
    """
    with tf.name_scope(name or "%s_lookup_table_remove" % self._name):
      key = tf.convert_to_tensor(key, self._key_dtype, name="key")

      # For remove, just the key is used by the kernel; no value is used.
      # But the kernel is specific to key_dtype and value_dtype
      # (i.e. it uses a <key_dtype, value_dtype> template).
      # So value_dtype is passed in explicitly. (While
      # key_dtype is specified implicitly by the dtype of key.)
      op = gen_simple_hash_table_op.examples_simple_hash_table_remove(
          self.resource_handle, key, value_dtype=self._value_dtype)
      return op

  def export(self, name=None):
    """Export all `key` and `value` pairs.

    Args:
      name: A name for the operation (optional).

    Returns:
      A tuple of two tensors, the first with the `keys` and the second with
      the `values`.
    """
    with tf.name_scope(name or "%s_lookup_table_export" % self._name):
      keys, values = gen_simple_hash_table_op.examples_simple_hash_table_export(
          self.resource_handle,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype)
      return keys, values

  def do_import(self, keys, values, name=None):
    """Import all `key` and `value` pairs.

    (Note that "import" is a python reserved keyword, so it cannot be the name
    of a method.)

    Args:
      keys: Tensor (or tensor-convertible value) of all keys; converted to the
        table's key_dtype.
      values: Tensor (or tensor-convertible value) of all values; converted to
        the table's value_dtype.
      name: A name for the operation (optional).

    Returns:
      The created Operation (the same kind of result `insert` and `remove`
      return), not the imported tensors.

    Raises:
      TypeError or ValueError: when `keys` or `values` cannot be converted to
        the table's data types.
    """
    with tf.name_scope(name or "%s_lookup_table_import" % self._name):
      # The import op receives no explicit dtype attrs (unlike `export` and
      # `remove`), so its key/value types come from the input tensors.
      # Convert explicitly, as `insert` does, so that e.g. a Python list of
      # ints matches an int64-keyed table instead of defaulting to int32.
      keys = tf.convert_to_tensor(keys, self._key_dtype, name="keys")
      values = tf.convert_to_tensor(values, self._value_dtype, name="values")
      op = gen_simple_hash_table_op.examples_simple_hash_table_import(
          self.resource_handle, keys, values)
      return op


# None of the hash table ops are differentiable. Export and Import are
# registered alongside the others so every table op is treated consistently.
tf.no_gradient("Examples>SimpleHashTableCreate")
tf.no_gradient("Examples>SimpleHashTableFind")
tf.no_gradient("Examples>SimpleHashTableInsert")
tf.no_gradient("Examples>SimpleHashTableRemove")
tf.no_gradient("Examples>SimpleHashTableExport")
tf.no_gradient("Examples>SimpleHashTableImport")