xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module containing the implementation of RNN cell wrappers."""
import hashlib
import numbers
import sys
import types as python_types
import warnings

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest


class DropoutWrapperBase(object):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self,
               cell,
               input_keep_prob=1.0,
               output_keep_prob=1.0,
               state_keep_prob=1.0,
               variational_recurrent=False,
               input_size=None,
               dtype=None,
               seed=None,
               dropout_state_filter_visitor=None,
               **kwargs):
    """Create a cell with added input, state, and/or output dropout.

    If `variational_recurrent` is set to `True` (**NOT** the default behavior),
    then the same dropout mask is applied at every step, as described in:
    [A Theoretically Grounded Application of Dropout in Recurrent
    Neural Networks. Y. Gal, Z. Ghahramani](https://arxiv.org/abs/1512.05287).

    Otherwise a different dropout mask is applied at every time step.

    Note, by default (unless a custom `dropout_state_filter` is provided),
    the memory state (`c` component of any `LSTMStateTuple`) passing through
    a `DropoutWrapper` is never modified.  This behavior is described in the
    above article.

    Args:
      cell: an RNNCell, the cell to wrap and apply dropout to.
      input_keep_prob: unit Tensor or float between 0 and 1, input keep
        probability; if it is constant and 1, no input dropout will be added.
      output_keep_prob: unit Tensor or float between 0 and 1, output keep
        probability; if it is constant and 1, no output dropout will be added.
      state_keep_prob: unit Tensor or float between 0 and 1, state keep
        probability; if it is constant and 1, no state dropout will be added.
        State dropout is performed on the outgoing states of the cell. **Note**
        the state components to which dropout is applied when `state_keep_prob`
        is in `(0, 1)` are also determined by the argument
        `dropout_state_filter_visitor` (e.g. by default dropout is never applied
        to the `c` component of an `LSTMStateTuple`).
      variational_recurrent: Python bool.  If `True`, then the same dropout
        pattern is applied across all time steps per run call. If this
        parameter is set, `input_size` **must** be provided.
      input_size: (optional) (possibly nested tuple of) `TensorShape` objects
        containing the depth(s) of the input tensors expected to be passed in
        to the `DropoutWrapper`.  Required and used **iff**
        `variational_recurrent = True` and `input_keep_prob < 1`.
      dtype: (optional) The `dtype` of the input, state, and output tensors.
        Required and used **iff** `variational_recurrent = True`.
      seed: (optional) integer, the randomness seed.
      dropout_state_filter_visitor: (optional), default: (see below).  Function
        that takes any hierarchical level of the state and returns a scalar or
        depth=1 structure of Python booleans describing which terms in the
        state should be dropped out.  In addition, if the function returns
        `True`, dropout is applied across this sublevel.  If the function
        returns `False`, dropout is not applied across this entire sublevel.
        Default behavior: perform dropout on all terms except the memory (`c`)
        state of `LSTMStateTuple` objects, and don't try to apply dropout to
        `TensorArray` objects:

        ```
        def dropout_state_filter_visitor(s):
          if isinstance(s, LSTMStateTuple):
            # Never perform dropout on the c state.
            return LSTMStateTuple(c=False, h=True)
          elif isinstance(s, TensorArray):
            return False
          return True
        ```
      **kwargs: dict of keyword arguments for base layer.

    Raises:
      TypeError: if `cell` is not an `RNNCell`, or
        `dropout_state_filter_visitor` is provided but not `callable`.
      ValueError: if any of the keep_probs are not between 0 and 1.
    """
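    # Illustrative usage sketch (not part of this module): concrete wrappers
    # such as `tf.compat.v1.nn.rnn_cell.DropoutWrapper` are built from a base
    # like this one. Assuming `import tensorflow as tf`, a construction might
    # look like the following (cell type, sizes, and seed are assumptions for
    # illustration only):
    #
    #   base_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(128)
    #   cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
    #       base_cell,
    #       input_keep_prob=0.9,
    #       output_keep_prob=0.9,
    #       state_keep_prob=0.9,
    #       variational_recurrent=True,   # reuse one mask across time steps
    #       input_size=tf.TensorShape([64]),
    #       dtype=tf.float32,
    #       seed=42)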
    super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs)

    if (dropout_state_filter_visitor is not None and
        not callable(dropout_state_filter_visitor)):
      raise TypeError("dropout_state_filter_visitor must be callable")
    self._dropout_state_filter = (
        dropout_state_filter_visitor or _default_dropout_state_filter_visitor)
    with ops.name_scope_v2("DropoutWrapperInit"):

      def tensor_and_const_value(v):
        tensor_value = ops.convert_to_tensor_v2_with_dispatch(v)
        const_value = tensor_util.constant_value(tensor_value)
        return (tensor_value, const_value)

      for prob, attr in [(input_keep_prob, "input_keep_prob"),
                         (state_keep_prob, "state_keep_prob"),
                         (output_keep_prob, "output_keep_prob")]:
        tensor_prob, const_prob = tensor_and_const_value(prob)
        if const_prob is not None:
          if const_prob < 0 or const_prob > 1:
            raise ValueError("Parameter %s must be between 0 and 1: %g" %
                             (attr, const_prob))
          setattr(self, "_%s" % attr, float(const_prob))
        else:
          setattr(self, "_%s" % attr, tensor_prob)

    # Set variational_recurrent, seed before running the code below
    self._variational_recurrent = variational_recurrent
    self._input_size = input_size
    self._seed = seed

    self._recurrent_input_noise = None
    self._recurrent_state_noise = None
    self._recurrent_output_noise = None

    if variational_recurrent:
      if dtype is None:
        raise ValueError(
            "When variational_recurrent=True, dtype must be provided")

      def convert_to_batch_shape(s):
        # Prepend a 1 for the batch dimension; for recurrent
        # variational dropout we use the same dropout mask for all
        # batch elements.
        return array_ops.concat(([1], tensor_shape.TensorShape(s).as_list()), 0)

      def batch_noise(s, inner_seed):
        shape = convert_to_batch_shape(s)
        return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)

      if (not isinstance(self._input_keep_prob, numbers.Real) or
          self._input_keep_prob < 1.0):
        if input_size is None:
          raise ValueError(
              "When variational_recurrent=True and input_keep_prob < 1.0 or "
              "is unknown, input_size must be provided")
        self._recurrent_input_noise = _enumerated_map_structure_up_to(
            input_size,
            lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
            input_size)
      self._recurrent_state_noise = _enumerated_map_structure_up_to(
          cell.state_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
          cell.state_size)
      self._recurrent_output_noise = _enumerated_map_structure_up_to(
          cell.output_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
          cell.output_size)

  def _gen_seed(self, salt_prefix, index):
    """Derives a deterministic per-tensor seed from the wrapper's seed."""
    if self._seed is None:
      return None
    salt = "%s_%d" % (salt_prefix, index)
    string = (str(self._seed) + salt).encode("utf-8")
    return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF

  @property
  def wrapped_cell(self):
    return self.cell

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def build(self, inputs_shape):
    self.cell.build(inputs_shape)
    self.built = True

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _variational_recurrent_dropout_value(
      self, unused_index, value, noise, keep_prob):
    """Performs dropout given the pre-calculated noise tensor."""
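    # Worked example (illustrative): `noise` is pre-sampled uniform in [0, 1).
    # With keep_prob = 0.8, random_tensor = 0.8 + noise lies in [0.8, 1.8), so
    # flooring it yields 1 with probability 0.8 and 0 otherwise, i.e. a fixed
    # Bernoulli(keep_prob) mask. Dividing the kept values by keep_prob
    # preserves the expected activation (inverted dropout scaling).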
    # uniform [keep_prob, 1.0 + keep_prob)
    random_tensor = keep_prob + noise

    # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
    binary_tensor = math_ops.floor(random_tensor)
    ret = math_ops.divide(value, keep_prob) * binary_tensor
    ret.set_shape(value.get_shape())
    return ret

  def _dropout(self,
               values,
               salt_prefix,
               recurrent_noise,
               keep_prob,
               shallow_filtered_substructure=None):
    """Decides whether to perform standard dropout or recurrent dropout."""

    if shallow_filtered_substructure is None:
      # Put something so we traverse the entire structure; inside the
      # dropout function we check to see if leaves of this are bool or not.
      shallow_filtered_substructure = values

    if not self._variational_recurrent:

      def dropout(i, do_dropout, v):
        if not isinstance(do_dropout, bool) or do_dropout:
          return nn_ops.dropout_v2(
              v, rate=1. - keep_prob, seed=self._gen_seed(salt_prefix, i))
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values])
    else:

      def dropout(i, do_dropout, v, n):
        if not isinstance(do_dropout, bool) or do_dropout:
          return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values, recurrent_noise])

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the wrapped cell and applies dropout.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """

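    # Note: constant keep probabilities were converted to plain Python floats
    # in __init__, so a float equal to 1 means dropout can be skipped here;
    # tensor-valued probabilities always take the dropout path.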
    def _should_dropout(p):
      return (not isinstance(p, float)) or p < 1

    if _should_dropout(self._input_keep_prob):
      inputs = self._dropout(inputs, "input", self._recurrent_input_noise,
                             self._input_keep_prob)
    output, new_state = cell_call_fn(inputs, state, **kwargs)
    if _should_dropout(self._state_keep_prob):
      # Identify which subsets of the state to perform dropout on and
      # which ones to keep.
      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
          self._dropout_state_filter, new_state)
      new_state = self._dropout(new_state, "state", self._recurrent_state_noise,
                                self._state_keep_prob,
                                shallow_filtered_substructure)
    if _should_dropout(self._output_keep_prob):
      output = self._dropout(output, "output", self._recurrent_output_noise,
                             self._output_keep_prob)
    return output, new_state

  def get_config(self):
    """Returns the config of the dropout wrapper."""
    config = {
        "input_keep_prob": self._input_keep_prob,
        "output_keep_prob": self._output_keep_prob,
        "state_keep_prob": self._state_keep_prob,
        "variational_recurrent": self._variational_recurrent,
        "input_size": self._input_size,
        "seed": self._seed,
    }
    if self._dropout_state_filter != _default_dropout_state_filter_visitor:
      function, function_type, function_module = _serialize_function_to_config(
          self._dropout_state_filter)
      config.update({"dropout_fn": function,
                     "dropout_fn_type": function_type,
                     "dropout_fn_module": function_module})
    base_config = super(DropoutWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "dropout_fn" in config:
      config = config.copy()
      dropout_state_filter = _parse_config_to_function(
          config, custom_objects, "dropout_fn", "dropout_fn_type",
          "dropout_fn_module")
      config.pop("dropout_fn")
      config["dropout_state_filter_visitor"] = dropout_state_filter
    return super(DropoutWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)


class ResidualWrapperBase(object):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, cell, residual_fn=None, **kwargs):
    """Constructs a `ResidualWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      residual_fn: (Optional) The function to map raw cell inputs and raw cell
        outputs to the actual cell outputs of the residual network.
        Defaults to calling `nest.map_structure(lambda i, o: i + o, inputs,
        outputs)`.
      **kwargs: dict of keyword arguments for base layer.
    """
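    # Hedged sketch: a custom `residual_fn` is useful when the raw inputs and
    # the cell outputs have different depths and cannot be added directly.
    # Assuming `import tensorflow as tf`, one illustrative (not canonical)
    # option is to project the inputs before the addition:
    #
    #   proj = tf.compat.v1.layers.Dense(cell.output_size, use_bias=False)
    #   def residual_fn(inputs, outputs):
    #     # Project the raw inputs to the output depth, then add.
    #     return proj(inputs) + outputs
    #
    #   wrapper = tf.compat.v1.nn.rnn_cell.ResidualWrapper(cell, residual_fn)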
    super(ResidualWrapperBase, self).__init__(cell, **kwargs)
    self._residual_fn = residual_fn

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Run the cell and then apply the residual_fn on its inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = cell_call_fn(inputs, state, **kwargs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    def default_residual_fn(inputs, outputs):
      nest.assert_same_structure(inputs, outputs)
      nest.map_structure(assert_shape_match, inputs, outputs)
      return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)

    res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
    return (res_outputs, new_state)

  def get_config(self):
    """Returns the config of the residual wrapper."""
    if self._residual_fn is not None:
      function, function_type, function_module = _serialize_function_to_config(
          self._residual_fn)
      config = {
          "residual_fn": function,
          "residual_fn_type": function_type,
          "residual_fn_module": function_module
      }
    else:
      config = {}
    base_config = super(ResidualWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "residual_fn" in config:
      config = config.copy()
      residual_function = _parse_config_to_function(config, custom_objects,
                                                    "residual_fn",
                                                    "residual_fn_type",
                                                    "residual_fn_module")
      config["residual_fn"] = residual_function
    return super(ResidualWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)


class DeviceWrapperBase(object):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, cell, device, **kwargs):
    """Construct a `DeviceWrapper` for `cell` with device `device`.

    Ensures the wrapped `cell` is called with `tf.device(device)`.

    Args:
      cell: An instance of `RNNCell`.
      device: A device string or function, for passing to `tf.device`.
      **kwargs: dict of keyword arguments for base layer.
    """
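    # Hedged usage sketch: the concrete `tf.compat.v1.nn.rnn_cell.DeviceWrapper`
    # is built from a base like this one. Pinning a cell to a device might look
    # like the following (cell type and device string are examples only):
    #
    #   cell = tf.compat.v1.nn.rnn_cell.DeviceWrapper(
    #       tf.compat.v1.nn.rnn_cell.GRUCell(64), "/device:GPU:0")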
    super(DeviceWrapperBase, self).__init__(cell, **kwargs)
    self._device = device

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      with ops.device(self._device):
        return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Run the cell on specified device."""
    with ops.device(self._device):
      return cell_call_fn(inputs, state, **kwargs)

  def get_config(self):
    config = {"device": self._device}
    base_config = super(DeviceWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


def _serialize_function_to_config(function):
  """Serialize the function for get_config()."""
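  # Hedged note on the two serialization paths below: named functions are
  # stored by name and later resolved via `custom_objects` or the recorded
  # module in `_parse_config_to_function`; lambdas have no stable name, so
  # `generic_utils.func_dump` captures their code for `func_load` to rebuild,
  # which is why deserializing them is flagged as unsafe there.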
  if isinstance(function, python_types.LambdaType):
    output = generic_utils.func_dump(function)
    output_type = "lambda"
    module = function.__module__
  elif callable(function):
    output = function.__name__
    output_type = "function"
    module = function.__module__
  else:
    raise ValueError("Unrecognized function type for input: {}".format(
        type(function)))

  return output, output_type, module


def _parse_config_to_function(config, custom_objects, func_attr_name,
                              func_type_attr_name, module_attr_name):
  """Reconstruct the function from the config."""
  globs = globals()
  module = config.pop(module_attr_name, None)
  if module in sys.modules:
    globs.update(sys.modules[module].__dict__)
  elif module is not None:
    # Note: we don't know the name of the function if it's a lambda.
    warnings.warn("{} is not loaded, but a layer uses it. "
                  "It may cause errors.".format(module), UserWarning)
  if custom_objects:
    globs.update(custom_objects)
  function_type = config.pop(func_type_attr_name)
  if function_type == "function":
    # Simple lookup in custom objects
    function = generic_utils.deserialize_keras_object(
        config[func_attr_name],
        custom_objects=custom_objects,
        printable_module_name="function in wrapper")
  elif function_type == "lambda":
    # Unsafe deserialization from bytecode
    function = generic_utils.func_load(
        config[func_attr_name], globs=globs)
  else:
    raise TypeError("Unknown function type: %s" % function_type)
  return function


def _default_dropout_state_filter_visitor(substate):
  """Default state filter: skip dropout for LSTM `c` states and TensorArrays."""
  from tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl import LSTMStateTuple  # pylint: disable=g-import-not-at-top
  if isinstance(substate, LSTMStateTuple):
    # Do not perform dropout on the memory state.
    return LSTMStateTuple(c=False, h=True)
  elif isinstance(substate, tensor_array_ops.TensorArray):
    return False
  return True


def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
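  # Applies `map_fn` to each leaf of `args` (up to `shallow_structure`), also
  # passing a running leaf index as the first argument. For example, over a
  # nested structure ((a, b), c) the calls are map_fn(0, a, ...),
  # map_fn(1, b, ...), map_fn(2, c, ...). This index is what `_gen_seed` uses
  # to derive a distinct, deterministic dropout seed per tensor.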
  ix = [0]

  def enumerated_fn(*inner_args, **inner_kwargs):
    r = map_fn(ix[0], *inner_args, **inner_kwargs)
    ix[0] += 1
    return r

  return nest.map_structure_up_to(shallow_structure, enumerated_fn, *args,
                                  **kwargs)
513