1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A wrapper for gen_simple_hash_table_op.py.
16
17This defines a public API and provides a docstring for the C++ Op defined by
18simple_hash_table_kernel.cc
19"""
20
21import tensorflow as tf
22from tensorflow.examples.custom_ops_doc.simple_hash_table.simple_hash_table_op import gen_simple_hash_table_op
23
24
class SimpleHashTable(tf.saved_model.experimental.TrackableResource):
  """A simple mutable hash table implementation.

  Implement a simple hash table as a Resource using ref-counting.
  This demonstrates a Stateful Op for a general Create/Read/Update/Delete
  (CRUD) style use case.  To instead make an op for a specific lookup table
  case, it is preferable to follow the implementation style of
  TensorFlow's internal ops, e.g. use LookupInterface.

  Data can be inserted by calling the `insert` method and removed by calling
  the `remove` method. It does not support initialization via the init method.

  The `export` and `do_import` methods allow saving and restoring all of
  the key, value pairs. These methods (or their corresponding kernels)
  are intended to be used for supporting SavedModel and checkpointing
  (see `_serialize_to_tensors` and `_restore_from_tensors`).

  Example usage:
    hash_table = simple_hash_table_op.SimpleHashTable(key_dtype, value_dtype,
                                                      111)
    result1 = hash_table.find(1, -999)  # -999
    hash_table.insert(1, 100)
    result2 = hash_table.find(1, -999)  # 100
    hash_table.remove(1)
    result3 = hash_table.find(1, -999)  # -999
  """

  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               name="SimpleHashTable"):
    """Creates an empty `SimpleHashTable` object.

    Creates a table whose keys and values have the types given by `key_dtype`
    and `value_dtype`, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. It is
        converted to a tensor of `value_dtype` and must be a scalar.
      name: A name for the operation (optional). It is also used as the
        prefix of the name scopes created by the other methods.

    Raises:
      ValueError: if `default_value` is not a scalar.
    """
    super(SimpleHashTable, self).__init__()
    self._default_value = tf.convert_to_tensor(default_value, dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._name = name
    # Create the resource eagerly. Methods that use the Resource get its handle
    # through the public `self.resource_handle` property (defined by
    # TrackableResource), which would otherwise call `self._create_resource()`
    # the first time it is accessed.
    self._resource_handle = self._create_resource()

  def _create_resource(self):
    """Create the resource tensor handle.

    `_create_resource` is an override of a method in base class
    `TrackableResource` that is required for SavedModel support. It can be
    called by the `resource_handle` property defined by `TrackableResource`.

    Returns:
      A tensor handle to the lookup table.

    Raises:
      ValueError: if the table's default value is not a scalar.
    """
    # Use an explicit exception rather than `assert`: assertions are stripped
    # when Python runs with -O, which would silently skip this validation.
    if self._default_value.get_shape().ndims != 0:
      raise ValueError(
          "default_value must be a scalar, but got a tensor of shape %s" %
          (self._default_value.get_shape(),))
    return gen_simple_hash_table_op.examples_simple_hash_table_create(
        key_dtype=self._key_dtype,
        value_dtype=self._value_dtype,
        name=self._name)

  def _serialize_to_tensors(self):
    """Implements checkpointing protocols for `Trackable`.

    Returns:
      A dict holding the exported keys under "table-keys" and the exported
      values under "table-values".
    """
    keys, values = self.export()
    return {"table-keys": keys, "table-values": values}

  def _restore_from_tensors(self, restored_tensors):
    """Implements checkpointing protocols for `Trackable`.

    Args:
      restored_tensors: A dict with the "table-keys" and "table-values"
        entries produced by `_serialize_to_tensors`.

    Returns:
      The result of the import op that loads those pairs into the table.
    """
    return gen_simple_hash_table_op.examples_simple_hash_table_import(
        self.resource_handle, restored_tensors["table-keys"],
        restored_tensors["table-values"])

  @property
  def key_dtype(self):
    """The table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The table value dtype."""
    return self._value_dtype

  def find(self, key, dynamic_default_value=None, name=None):
    """Looks up `key` in a table, outputs the corresponding value.

    A default value is used if `key` is not present in the table.

    Args:
      key: Key to look up. Must match the table's key_dtype.
      dynamic_default_value: The value to use if the key is missing in the
        table. If None (by default), the `default_value` given to the
        constructor is used.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the value in the same shape as `key` using the
        table's value type.

    Raises:
      TypeError: when `key` does not match the table data types.
    """
    with tf.name_scope(name or "%s_lookup_table_find" % self._name):
      key = tf.convert_to_tensor(key, dtype=self._key_dtype, name="key")
      if dynamic_default_value is None:
        default_value = self._default_value
      else:
        default_value = tf.convert_to_tensor(
            dynamic_default_value,
            dtype=self._value_dtype,
            name="default_value")
      value = gen_simple_hash_table_op.examples_simple_hash_table_find(
          self.resource_handle, key, default_value)
    return value

  def insert(self, key, value, name=None):
    """Associates `key` with `value`.

    Args:
      key: Scalar key to insert.
      value: Scalar value to be associated with key.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `key` or `value` doesn't match the table data
        types.
    """
    with tf.name_scope(name or "%s_lookup_table_insert" % self._name):
      key = tf.convert_to_tensor(key, self._key_dtype, name="key")
      value = tf.convert_to_tensor(value, self._value_dtype, name="value")
      return gen_simple_hash_table_op.examples_simple_hash_table_insert(
          self.resource_handle, key, value)

  def remove(self, key, name=None):
    """Remove `key`.

    Args:
      key: Scalar key to remove.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `key` doesn't match the table data type.
    """
    with tf.name_scope(name or "%s_lookup_table_remove" % self._name):
      key = tf.convert_to_tensor(key, self._key_dtype, name="key")

      # For remove, the kernel uses only the key; no value is used. But the
      # kernel is specific to key_dtype and value_dtype (i.e. it uses a
      # <key_dtype, value_dtype> template), so value_dtype is passed in
      # explicitly. (key_dtype is specified implicitly by the dtype of key.)
      return gen_simple_hash_table_op.examples_simple_hash_table_remove(
          self.resource_handle, key, value_dtype=self._value_dtype)

  def export(self, name=None):
    """Export all `key` and `value` pairs.

    Args:
      name: A name for the operation (optional).

    Returns:
      A tuple of two tensors, the first with the `keys` and the second with
      the `values`.
    """
    with tf.name_scope(name or "%s_lookup_table_export" % self._name):
      keys, values = gen_simple_hash_table_op.examples_simple_hash_table_export(
          self.resource_handle,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype)
      return keys, values

  def do_import(self, keys, values, name=None):
    """Import all `key` and `value` pairs.

    (Note that "import" is a python reserved word, so it cannot be the name of
    a method.)

    Args:
      keys: Tensor (or tensor-convertible value) of all keys. It is converted
        to the table's key_dtype.
      values: Tensor (or tensor-convertible value) of all values. It is
        converted to the table's value_dtype.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data types.
    """
    with tf.name_scope(name or "%s_lookup_table_import" % self._name):
      # Convert like `insert` does. Python lists and NumPy arrays then get the
      # table's dtypes instead of inferred ones (e.g. int32 for Python ints),
      # which would select a kernel that does not match the table resource.
      keys = tf.convert_to_tensor(keys, self._key_dtype, name="keys")
      values = tf.convert_to_tensor(values, self._value_dtype, name="values")
      return gen_simple_hash_table_op.examples_simple_hash_table_import(
          self.resource_handle, keys, values)
238
239
# All of the simple hash table ops read or mutate the table resource and are
# not differentiable, so register every one of them (including Export and
# Import, which the class uses for SavedModel/checkpoint support) as having
# no gradient.
tf.no_gradient("Examples>SimpleHashTableCreate")
tf.no_gradient("Examples>SimpleHashTableFind")
tf.no_gradient("Examples>SimpleHashTableInsert")
tf.no_gradient("Examples>SimpleHashTableRemove")
tf.no_gradient("Examples>SimpleHashTableExport")
tf.no_gradient("Examples>SimpleHashTableImport")
244