# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-classes-have-attributes
"""Module implementing RNN Cells.

This module provides a number of basic commonly used RNN cells, such as LSTM
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
operators that allow adding dropouts, projections, or embeddings for inputs.
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
calling the `rnn` ops several times.
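
For illustration, a minimal sketch of composing cells (not part of the
original API docs; it assumes the TF1 compat API is available as
`tf.compat.v1`, and shapes/values are arbitrary):

```python
import tensorflow.compat.v1 as tf

cells = [tf.nn.rnn_cell.BasicLSTMCell(64) for _ in range(2)]
cells = [tf.nn.rnn_cell.DropoutWrapper(c, output_keep_prob=0.9)
         for c in cells]
stacked_cell = tf.nn.rnn_cell.MultiRNNCell(cells)
initial_state = stacked_cell.zero_state(batch_size=8, dtype=tf.float32)
```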
23"""
24import collections
25import warnings
26
27from tensorflow.python.eager import context
28from tensorflow.python.framework import config as tf_config
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.keras import activations
35from tensorflow.python.keras import backend
36from tensorflow.python.keras import initializers
37from tensorflow.python.keras.engine import base_layer_utils
38from tensorflow.python.keras.engine import input_spec
39from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl
40from tensorflow.python.keras.legacy_tf_layers import base as base_layer
41from tensorflow.python.keras.utils import tf_utils
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import clip_ops
44from tensorflow.python.ops import init_ops
45from tensorflow.python.ops import math_ops
46from tensorflow.python.ops import nn_ops
47from tensorflow.python.ops import partitioned_variables
48from tensorflow.python.ops import variable_scope as vs
49from tensorflow.python.ops import variables as tf_variables
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.trackable import base as trackable
52from tensorflow.python.util import nest
53from tensorflow.python.util.tf_export import keras_export
54from tensorflow.python.util.tf_export import tf_export
55
56_BIAS_VARIABLE_NAME = "bias"
57_WEIGHTS_VARIABLE_NAME = "kernel"
58
59# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
60ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
61
62
63def _hasattr(obj, attr_name):
64  try:
65    getattr(obj, attr_name)
66  except AttributeError:
67    return False
68  else:
69    return True
70
71
72def assert_like_rnncell(cell_name, cell):
73  """Raises a TypeError if cell is not like an RNNCell.
74
75  NOTE: Do not rely on the error message (in particular in tests) which can be
76  subject to change to increase readability. Use
77  ASSERT_LIKE_RNNCELL_ERROR_REGEXP.
78
79  Args:
    cell_name: A string used to compose a meaningful error message that
      references the name of the function argument.
    cell: The object which should behave like an RNNCell.

  Raises:
    TypeError: A human-friendly exception.
  """
  conditions = [
      _hasattr(cell, "output_size"),
      _hasattr(cell, "state_size"),
      _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"),
      callable(cell),
  ]
  errors = [
      "'output_size' property is missing", "'state_size' property is missing",
      "either 'zero_state' or 'get_initial_state' method is required",
      "is not callable"
  ]

  if not all(conditions):

    errors = [error for error, cond in zip(errors, conditions) if not cond]
    raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format(
        cell_name, cell, ", ".join(errors)))


def _concat(prefix, suffix, static=False):
  """Concat that enables int, Tensor, or TensorShape values.

  This function takes a size specification, which can be an integer, a
  TensorShape, or a Tensor, and converts it into a concatenated Tensor
  (if static = False) or a list of integers (if static = True).

  Args:
    prefix: The prefix; usually the batch size (and/or time step size).
      (TensorShape, int, or Tensor.)
    suffix: TensorShape, int, or Tensor.
    static: If `True`, return a python list with possibly unknown dimensions.
      Otherwise return a `Tensor`.

  Returns:
    shape: the concatenation of prefix and suffix.

  Raises:
    ValueError: if `suffix` is not a scalar or vector (or TensorShape).
    ValueError: if prefix or suffix was `None` and asked for dynamic
      Tensors out.
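
  For example (an illustrative sketch, not part of the original docstring;
  both calls assume plain int arguments):

  ```python
  _concat(32, 7, static=True)   # -> [32, 7]
  _concat(32, 7)                # -> int32 Tensor with value [32, 7]
  ```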
127  """
128  if isinstance(prefix, ops.Tensor):
129    p = prefix
130    p_static = tensor_util.constant_value(prefix)
131    if p.shape.ndims == 0:
132      p = array_ops.expand_dims(p, 0)
133    elif p.shape.ndims != 1:
134      raise ValueError("prefix tensor must be either a scalar or vector, "
135                       "but saw tensor: %s" % p)
136  else:
137    p = tensor_shape.TensorShape(prefix)
138    p_static = p.as_list() if p.ndims is not None else None
139    p = (
140        constant_op.constant(p.as_list(), dtype=dtypes.int32)
141        if p.is_fully_defined() else None)
142  if isinstance(suffix, ops.Tensor):
143    s = suffix
144    s_static = tensor_util.constant_value(suffix)
145    if s.shape.ndims == 0:
146      s = array_ops.expand_dims(s, 0)
147    elif s.shape.ndims != 1:
148      raise ValueError("suffix tensor must be either a scalar or vector, "
149                       "but saw tensor: %s" % s)
150  else:
151    s = tensor_shape.TensorShape(suffix)
152    s_static = s.as_list() if s.ndims is not None else None
153    s = (
154        constant_op.constant(s.as_list(), dtype=dtypes.int32)
155        if s.is_fully_defined() else None)
156
157  if static:
158    shape = tensor_shape.TensorShape(p_static).concatenate(s_static)
159    shape = shape.as_list() if shape.ndims is not None else None
160  else:
161    if p is None or s is None:
162      raise ValueError("Provided a prefix or suffix of None: %s and %s" %
163                       (prefix, suffix))
164    shape = array_ops.concat((p, s), 0)
165  return shape
166
167
168def _zero_state_tensors(state_size, batch_size, dtype):
169  """Create tensors of zeros based on state_size, batch_size, and dtype."""
170
171  def get_state_shape(s):
172    """Combine s with batch_size to get a proper tensor shape."""
173    c = _concat(batch_size, s)
174    size = array_ops.zeros(c, dtype=dtype)
175    if not context.executing_eagerly():
176      c_static = _concat(batch_size, s, static=True)
177      size.set_shape(c_static)
178    return size
179
180  return nest.map_structure(get_state_shape, state_size)
181
182
183@keras_export(v1=["keras.__internal__.legacy.rnn_cell.RNNCell"])
184@tf_export(v1=["nn.rnn_cell.RNNCell"])
185class RNNCell(base_layer.Layer):
186  """Abstract object representing an RNN cell.
187
188  Every `RNNCell` must have the properties below and implement `call` with
  the signature `(output, next_state) = call(input, state)`.  The optional
  third input argument, `scope`, is allowed for backwards compatibility
  purposes, but should be left off for new subclasses.

  This definition of cell differs from the definition used in the literature.
  In the literature, 'cell' refers to an object with a single scalar output.
  This definition refers to a horizontal array of such units.

  An RNN cell, in the most abstract setting, is anything that has
  a state and performs some operation that takes a matrix of inputs.
  This operation results in an output matrix with `self.output_size` columns.
  If `self.state_size` is an integer, this operation also results in a new
  state matrix with `self.state_size` columns.  If `self.state_size` is a
  (possibly nested tuple of) TensorShape object(s), then it should return a
  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.state_size`.
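
  For illustration, a minimal sketch of a concrete subclass (hypothetical and
  not part of this module; it assumes `input_size == num_units` so that no
  weights are needed):

  ```python
  class MinimalRNNCell(RNNCell):
    """Toy cell: output = next_state = tanh(input + state)."""

    def __init__(self, num_units):
      super(MinimalRNNCell, self).__init__()
      self._num_units = num_units

    @property
    def state_size(self):
      return self._num_units

    @property
    def output_size(self):
      return self._num_units

    def call(self, inputs, state):
      output = math_ops.tanh(inputs + state)
      return output, output
  ```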
205  """
206
207  def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
208    super(RNNCell, self).__init__(
209        trainable=trainable, name=name, dtype=dtype, **kwargs)
210    # Attribute that indicates whether the cell is a TF RNN cell, due the slight
211    # difference between TF and Keras RNN cell. Notably the state is not wrapped
212    # in a list for TF cell where they are single tensor state, whereas keras
213    # cell will wrap the state into a list, and call() will have to unwrap them.
    self._is_tf_rnn_cell = True

  def __call__(self, inputs, state, scope=None):
    """Run this RNN cell on inputs, starting from the given state.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
        with shape `[batch_size, self.state_size]`.  Otherwise, if
        `self.state_size` is a tuple of integers, this should be a tuple with
        shapes `[batch_size, s] for s in self.state_size`.
      scope: VariableScope for the created subgraph; defaults to class name.

    Returns:
      A pair containing:

      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
    """
    if scope is not None:
      with vs.variable_scope(
          scope, custom_getter=self._rnn_get_variable) as scope:
        return super(RNNCell, self).__call__(inputs, state, scope=scope)
    else:
      scope_attrname = "rnncell_scope"
      scope = getattr(self, scope_attrname, None)
      if scope is None:
        scope = vs.variable_scope(
            vs.get_variable_scope(), custom_getter=self._rnn_get_variable)
        setattr(self, scope_attrname, scope)
      with scope:
        return super(RNNCell, self).__call__(inputs, state)

  def _rnn_get_variable(self, getter, *args, **kwargs):
    variable = getter(*args, **kwargs)
    if ops.executing_eagerly_outside_functions():
      trainable = variable.trainable
    else:
      trainable = (
          variable in tf_variables.trainable_variables() or
          (base_layer_utils.is_split_variable(variable) and
           list(variable)[0] in tf_variables.trainable_variables()))
    if trainable and all(variable is not v for v in self._trainable_weights):
      self._trainable_weights.append(variable)
    elif not trainable and all(
        variable is not v for v in self._non_trainable_weights):
      self._non_trainable_weights.append(variable)
    return variable

  @property
  def state_size(self):
    """size(s) of state(s) used by this cell.

    It can be represented by an Integer, a TensorShape or a tuple of Integers
    or TensorShapes.
    """
    raise NotImplementedError("Abstract method")

  @property
  def output_size(self):
    """Integer or TensorShape: size of outputs produced by this cell."""
    raise NotImplementedError("Abstract method")

  def build(self, _):
    # This tells the parent Layer object that it's OK to call
    # self.add_variable() inside the call() method.
    pass

  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    if inputs is not None:
      # Validate the given batch_size and dtype against inputs if provided.
      inputs = ops.convert_to_tensor_v2_with_dispatch(inputs, name="inputs")
      if batch_size is not None:
        if tensor_util.is_tf_type(batch_size):
          static_batch_size = tensor_util.constant_value(
              batch_size, partial=True)
        else:
          static_batch_size = batch_size
        if inputs.shape.dims[0].value != static_batch_size:
          raise ValueError(
              "batch size from input tensor is different from the "
              "input param. Input tensor batch: {}, batch_size: {}".format(
                  inputs.shape.dims[0].value, batch_size))

      if dtype is not None and inputs.dtype != dtype:
        raise ValueError(
            "dtype from input tensor is different from the "
            "input param. Input tensor dtype: {}, dtype: {}".format(
                inputs.dtype, dtype))

      batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0]
      dtype = inputs.dtype
    if batch_size is None or dtype is None:
      raise ValueError(
          "batch_size and dtype cannot be None while constructing initial "
          "state: batch_size={}, dtype={}".format(batch_size, dtype))
    return self.zero_state(batch_size, dtype)

  def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is an int or TensorShape, then the return value is a
      `N-D` tensor of shape `[batch_size, state_size]` filled with zeros.

      If `state_size` is a nested list or tuple, then the return value is
      a nested list or tuple (of the same structure) of `2-D` tensors with
      the shapes `[batch_size, s]` for each s in `state_size`.
    """
    # Try to use the last cached zero_state. This is done to avoid recreating
    # zeros, especially when eager execution is enabled.
    state_size = self.state_size
    is_eager = context.executing_eagerly()
    if is_eager and _hasattr(self, "_last_zero_state"):
      (last_state_size, last_batch_size, last_dtype,
       last_output) = getattr(self, "_last_zero_state")
      if (last_batch_size == batch_size and last_dtype == dtype and
          last_state_size == state_size):
        return last_output
    with backend.name_scope(type(self).__name__ + "ZeroState"):
      output = _zero_state_tensors(state_size, batch_size, dtype)
    if is_eager:
      self._last_zero_state = (state_size, batch_size, dtype, output)
    return output

  # TODO(b/134773139): Remove when contrib RNN cells implement `get_config`
  def get_config(self):  # pylint: disable=useless-super-delegation
    return super(RNNCell, self).get_config()

  @property
  def _use_input_spec_as_call_signature(self):
    # We do not store the shape information for the state argument in the call
    # function for legacy RNN cells, so do not generate an input signature.
    return False


class LayerRNNCell(RNNCell):
  """Subclass of RNNCells that act like proper `tf.Layer` objects.

  For backwards compatibility purposes, most `RNNCell` instances allow their
  `call` methods to instantiate variables via `tf.compat.v1.get_variable`.  The
  underlying variable scope thus keeps track of any variables and returns
  cached versions.  This is atypical of `tf.layer` objects, which separate this
  part of layer building into a `build` method that is only called once.

  Here we provide a subclass of `RNNCell` objects that act exactly as
  `Layer` objects do.  They must provide a `build` method, and their
  `call` methods do not access variables via `tf.compat.v1.get_variable`.
  """

  def __call__(self, inputs, state, scope=None, *args, **kwargs):
    """Run this RNN cell on inputs, starting from the given state.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
        with shape `[batch_size, self.state_size]`.  Otherwise, if
        `self.state_size` is a tuple of integers, this should be a tuple with
        shapes `[batch_size, s] for s in self.state_size`.
      scope: optional cell scope.
      *args: Additional positional arguments.
      **kwargs: Additional keyword arguments.

    Returns:
      A pair containing:

      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
    """
    # Bypass RNNCell's variable capturing semantics for LayerRNNCell.
    # Instead, it is up to subclasses to provide a proper build
    # method.  See the class docstring for more details.
    return base_layer.Layer.__call__(
        self, inputs, state, scope=scope, *args, **kwargs)


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicRNNCell"])
@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
class BasicRNNCell(LayerRNNCell):
  """The most basic RNN cell.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU.

  Args:
    num_units: int, The number of units in the RNN cell.
    activation: Nonlinearity to use.  Default: `tanh`. It can also be a string
      that names a Keras activation function.
    reuse: (optional) Python boolean describing whether to reuse variables in an
      existing scope.  If not `True`, and the existing scope already has the
      given variables, an error is raised.
    name: String, the name of the layer. Layers with the same name will share
      weights, but to avoid mistakes we require reuse=True in such cases.
    dtype: Default dtype of the layer (default of `None` means use the type of
      the first input). Required when `build` is called before `call`.
    **kwargs: Dict, keyword named properties for common layer attributes, like
      `trainable` etc when constructing the cell from configs of get_config().
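
  Example (an illustrative sketch, not part of the original docs; it assumes
  the TF1 compat API is available as `tf.compat.v1`, and shapes are arbitrary):

  ```python
  import tensorflow.compat.v1 as tf

  cell = tf.nn.rnn_cell.BasicRNNCell(num_units=64)
  inputs = tf.zeros([8, 32])  # [batch_size, input_size]
  state = cell.zero_state(batch_size=8, dtype=tf.float32)
  output, new_state = cell(inputs, state)  # both shaped [8, 64]
  ```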
418  """
419
420  def __init__(self,
421               num_units,
422               activation=None,
423               reuse=None,
424               name=None,
425               dtype=None,
426               **kwargs):
427    warnings.warn("`tf.nn.rnn_cell.BasicRNNCell` is deprecated and will be "
428                  "removed in a future version. This class "
429                  "is equivalent as `tf.keras.layers.SimpleRNNCell`, "
430                  "and will be replaced by that in Tensorflow 2.0.")
    super(BasicRNNCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    _check_supported_dtypes(self.dtype)
    if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
      logging.warning(
          "%s: Note that this cell is not optimized for performance. "
          "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better "
          "performance on GPU.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    if activation:
      self._activation = activations.get(activation)
    else:
      self._activation = math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  @tf_utils.shape_type_conversion
  def build(self, inputs_shape):
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       str(inputs_shape))
    _check_supported_dtypes(self.dtype)

    input_depth = inputs_shape[-1]
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    _check_rnn_cell_input_dtypes([inputs, state])
    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(BasicRNNCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.GRUCell"])
@tf_export(v1=["nn.rnn_cell.GRUCell"])
class GRUCell(LayerRNNCell):
  """Gated Recurrent Unit cell.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or
  `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU.

  Args:
    num_units: int, The number of units in the GRU cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables in an
      existing scope.  If not `True`, and the existing scope already has the
      given variables, an error is raised.
    kernel_initializer: (optional) The initializer to use for the weight and
      projection matrices.
    bias_initializer: (optional) The initializer to use for the bias.
    name: String, the name of the layer. Layers with the same name will share
      weights, but to avoid mistakes we require reuse=True in such cases.
    dtype: Default dtype of the layer (default of `None` means use the type of
      the first input). Required when `build` is called before `call`.
    **kwargs: Dict, keyword named properties for common layer attributes, like
      `trainable` etc when constructing the cell from configs of get_config().

  References:
    Learning Phrase Representations using RNN Encoder Decoder for Statistical
    Machine Translation:
      [Cho et al., 2014]
      (https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179)
      ([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf))
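
  Example (an illustrative sketch, not part of the original docs; assumes the
  TF1 compat API is available as `tf.compat.v1`):

  ```python
  import tensorflow.compat.v1 as tf

  cell = tf.nn.rnn_cell.GRUCell(num_units=64)
  inputs = tf.zeros([8, 32])  # [batch_size, input_size]
  state = cell.zero_state(batch_size=8, dtype=tf.float32)
  output, new_state = cell(inputs, state)  # both shaped [8, 64]
  ```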
525  """
526
527  def __init__(self,
528               num_units,
529               activation=None,
530               reuse=None,
531               kernel_initializer=None,
532               bias_initializer=None,
533               name=None,
534               dtype=None,
535               **kwargs):
536    warnings.warn("`tf.nn.rnn_cell.GRUCell` is deprecated and will be removed "
537                  "in a future version. This class "
538                  "is equivalent as `tf.keras.layers.GRUCell`, "
539                  "and will be replaced by that in Tensorflow 2.0.")
    super(GRUCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    _check_supported_dtypes(self.dtype)

    if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
      logging.warning(
          "%s: Note that this cell is not optimized for performance. "
          "Please use tf.contrib.cudnn_rnn.CudnnGRU for better "
          "performance on GPU.", self)
    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    if activation:
      self._activation = activations.get(activation)
    else:
      self._activation = math_ops.tanh
    self._kernel_initializer = initializers.get(kernel_initializer)
    self._bias_initializer = initializers.get(bias_initializer)

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  @tf_utils.shape_type_conversion
  def build(self, inputs_shape):
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       str(inputs_shape))
    _check_supported_dtypes(self.dtype)
    input_depth = inputs_shape[-1]
    self._gate_kernel = self.add_variable(
        "gates/%s" % _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, 2 * self._num_units],
        initializer=self._kernel_initializer)
    self._gate_bias = self.add_variable(
        "gates/%s" % _BIAS_VARIABLE_NAME,
        shape=[2 * self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.constant_initializer(1.0, dtype=self.dtype)))
    self._candidate_kernel = self.add_variable(
        "candidate/%s" % _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, self._num_units],
        initializer=self._kernel_initializer)
    self._candidate_bias = self.add_variable(
        "candidate/%s" % _BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.zeros_initializer(dtype=self.dtype)))

    self.built = True

  def call(self, inputs, state):
    """Gated recurrent unit (GRU) with `num_units` cells."""
    _check_rnn_cell_input_dtypes([inputs, state])

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._gate_kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

    value = math_ops.sigmoid(gate_inputs)
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * state

    candidate = math_ops.matmul(
        array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
    candidate = nn_ops.bias_add(candidate, self._candidate_bias)

    c = self._activation(candidate)
    new_h = u * state + (1 - u) * c
    return new_h, new_h

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "kernel_initializer": initializers.serialize(self._kernel_initializer),
        "bias_initializer": initializers.serialize(self._bias_initializer),
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(GRUCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMStateTuple"])
@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"])
class LSTMStateTuple(_LSTMStateTuple):
  """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.

  Stores two elements, `(c, h)`, in that order, where `c` is the cell state
  and `h` is the output.

  Only used when `state_is_tuple=True`.
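
  For illustration (a sketch, not part of the original docs; assumes the TF1
  compat API is available as `tf.compat.v1`):

  ```python
  import tensorflow.compat.v1 as tf

  c = tf.zeros([8, 64])  # cell state
  h = tf.zeros([8, 64])  # cell output
  state = tf.nn.rnn_cell.LSTMStateTuple(c, h)
  assert state.c is c and state.h is h
  ```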
643  """
644  __slots__ = ()
645
646  @property
647  def dtype(self):
648    (c, h) = self
649    if c.dtype != h.dtype:
650      raise TypeError("Inconsistent internal state: %s vs %s" %
651                      (str(c.dtype), str(h.dtype)))
652    return c.dtype
653
654
655@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicLSTMCell"])
656@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"])
657class BasicLSTMCell(LayerRNNCell):
658  """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead.
659
660  Basic LSTM recurrent network cell.
661
  The implementation is based on the standard (non-peephole) LSTM formulation.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting at the beginning of training.

  It does not allow cell clipping or a projection layer, and it does not
  use peephole connections: it is the basic baseline.

  For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell`
  that follows.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
  better performance on CPU.
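
  Example (an illustrative sketch, not part of the original docs; assumes the
  TF1 compat API is available as `tf.compat.v1`):

  ```python
  import tensorflow.compat.v1 as tf

  cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=64)
  inputs = tf.zeros([8, 32])  # [batch_size, input_size]
  state = cell.zero_state(batch_size=8, dtype=tf.float32)  # LSTMStateTuple
  output, new_state = cell(inputs, state)  # output: [8, 64]
  ```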
677  """
678
679  def __init__(self,
680               num_units,
681               forget_bias=1.0,
682               state_is_tuple=True,
683               activation=None,
684               reuse=None,
685               name=None,
686               dtype=None,
687               **kwargs):
688    """Initialize the basic LSTM cell.
689
690    Args:
691      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above). Must be
        set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of the
        `c_state` and `m_state`.  If False, they are concatenated along the
        column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`. It
        can also be a string that names a Keras activation function.
      reuse: (optional) Python boolean describing whether to reuse variables in
        an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will share
        weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type of
        the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes, like
        `trainable` etc when constructing the cell from configs of get_config().
        When restoring from CudnnLSTM-trained checkpoints, must use
        `CudnnCompatibleLSTMCell` instead.
    """
    warnings.warn("`tf.nn.rnn_cell.BasicLSTMCell` is deprecated and will be "
                  "removed in a future version. This class "
                  "is equivalent to `tf.keras.layers.LSTMCell`, "
                  "and will be replaced by that in Tensorflow 2.0.")
    super(BasicLSTMCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    _check_supported_dtypes(self.dtype)
    if not state_is_tuple:
      logging.warning(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True.", self)
    if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
      logging.warning(
          "%s: Note that this cell is not optimized for performance. "
          "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
          "performance on GPU.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    if activation:
      self._activation = activations.get(activation)
    else:
      self._activation = math_ops.tanh

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  @tf_utils.shape_type_conversion
  def build(self, inputs_shape):
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       str(inputs_shape))
    _check_supported_dtypes(self.dtype)
    input_depth = inputs_shape[-1]
    h_depth = self._num_units
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size,
        num_units]`, if `state_is_tuple` has been set to `True`.  Otherwise, a
        `Tensor` shaped `[batch_size, 2 * num_units]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    _check_rnn_cell_input_dtypes([inputs, state])

    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(
        multiply(c, sigmoid(add(f, forget_bias_tensor))),
        multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "forget_bias": self._forget_bias,
        "state_is_tuple": self._state_is_tuple,
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(BasicLSTMCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMCell"])
@tf_export(v1=["nn.rnn_cell.LSTMCell"])
class LSTMCell(LayerRNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on (Gers et al., 1999).
  The peephole implementation is based on (Sak et al., 2014).

  The class uses optional peep-hole connections, optional cell clipping, and
  an optional projection layer.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
  better performance on CPU.
  References:
    Long short-term memory recurrent neural network architectures for large
    scale acoustic modeling:
      [Sak et al., 2014]
      (https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html)
      ([pdf]
      (https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf))
    Learning to forget:
      [Gers et al., 1999]
      (http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218)
      ([pdf](https://arxiv.org/pdf/1409.2329.pdf))
    Long Short-Term Memory:
      [Hochreiter et al., 1997]
      (https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735)
      ([pdf](http://ml.jku.at/publications/older/3504.pdf))
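
  Example (an illustrative sketch, not part of the original docs; assumes the
  TF1 compat API is available as `tf.compat.v1`):

  ```python
  import tensorflow.compat.v1 as tf

  cell = tf.nn.rnn_cell.LSTMCell(num_units=128, use_peepholes=True,
                                 num_proj=64)
  inputs = tf.zeros([8, 32])  # [batch_size, input_size]
  state = cell.zero_state(batch_size=8, dtype=tf.float32)
  output, new_state = cell(inputs, state)  # output: [8, 64], i.e. num_proj
  ```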
856  """
857
858  def __init__(self,
859               num_units,
860               use_peepholes=False,
861               cell_clip=None,
862               initializer=None,
863               num_proj=None,
864               proj_clip=None,
865               num_unit_shards=None,
866               num_proj_shards=None,
867               forget_bias=1.0,
868               state_is_tuple=True,
869               activation=None,
870               reuse=None,
871               name=None,
872               dtype=None,
873               **kwargs):
874    """Initialize the parameters for an LSTM cell.
875
876    Args:
877      num_units: int, The number of units in the LSTM cell.
878      use_peepholes: bool, set True to enable diagonal/peephole connections.
879      cell_clip: (optional) A float value, if provided the cell state is clipped
880        by this value prior to the cell output activation.
881      initializer: (optional) The initializer to use for the weight and
882        projection matrices.
883      num_proj: (optional) int, The output dimensionality for the projection
884        matrices.  If None, no projection is performed.
885      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
886        provided, then the projected values are clipped elementwise to within
887        `[-proj_clip, proj_clip]`.
888      num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a
889        variable_scope partitioner instead.
890      num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a
891        variable_scope partitioner instead.
      forget_bias: Biases of the forget gate are initialized by default to 1 in
        order to reduce the scale of forgetting at the beginning of training.
        Must be set manually to `0.0` when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of the
        `c_state` and `m_state`.  If False, they are concatenated along the
        column axis.  This latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`. It
        can also be a string that names a Keras activation function.
      reuse: (optional) Python boolean describing whether to reuse variables in
        an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will share
        weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type of
        the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes, like
        `trainable` etc when constructing the cell from configs of get_config().
        When restoring from CudnnLSTM-trained checkpoints, use
        `CudnnCompatibleLSTMCell` instead.
    """
    warnings.warn("`tf.nn.rnn_cell.LSTMCell` is deprecated and will be "
                  "removed in a future version. This class "
                  "is equivalent to `tf.keras.layers.LSTMCell`, "
                  "and will be replaced by that in Tensorflow 2.0.")
    super(LSTMCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    _check_supported_dtypes(self.dtype)
    if not state_is_tuple:
      logging.warning(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True.", self)
    if num_unit_shards is not None or num_proj_shards is not None:
      logging.warning(
926          "%s: The num_unit_shards and proj_unit_shards parameters are "
927          "deprecated and will be removed in Jan 2017.  "
928          "Use a variable scope with a partitioner instead.", self)
929    if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
930      logging.warning(
931          "%s: Note that this cell is not optimized for performance. "
932          "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
933          "performance on GPU.", self)
934
935    # Inputs must be 2-dimensional.
936    self.input_spec = input_spec.InputSpec(ndim=2)
937
938    self._num_units = num_units
939    self._use_peepholes = use_peepholes
940    self._cell_clip = cell_clip
941    self._initializer = initializers.get(initializer)
942    self._num_proj = num_proj
943    self._proj_clip = proj_clip
944    self._num_unit_shards = num_unit_shards
945    self._num_proj_shards = num_proj_shards
946    self._forget_bias = forget_bias
947    self._state_is_tuple = state_is_tuple
948    if activation:
949      self._activation = activations.get(activation)
950    else:
951      self._activation = math_ops.tanh
952
953    if num_proj:
954      self._state_size = (
955          LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units +
956          num_proj)
957      self._output_size = num_proj
958    else:
959      self._state_size = (
960          LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 *
961          num_units)
962      self._output_size = num_units
963
964  @property
965  def state_size(self):
966    return self._state_size
967
968  @property
969  def output_size(self):
970    return self._output_size
971
972  @tf_utils.shape_type_conversion
973  def build(self, inputs_shape):
974    if inputs_shape[-1] is None:
975      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
976                       str(inputs_shape))
977    _check_supported_dtypes(self.dtype)
978    input_depth = inputs_shape[-1]
979    h_depth = self._num_units if self._num_proj is None else self._num_proj
980    maybe_partitioner = (
981        partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
982        if self._num_unit_shards is not None else None)
983    self._kernel = self.add_variable(
984        _WEIGHTS_VARIABLE_NAME,
985        shape=[input_depth + h_depth, 4 * self._num_units],
986        initializer=self._initializer,
987        partitioner=maybe_partitioner)
988    if self.dtype is None:
989      initializer = init_ops.zeros_initializer
990    else:
991      initializer = init_ops.zeros_initializer(dtype=self.dtype)
992    self._bias = self.add_variable(
993        _BIAS_VARIABLE_NAME,
994        shape=[4 * self._num_units],
995        initializer=initializer)
996    if self._use_peepholes:
997      self._w_f_diag = self.add_variable(
998          "w_f_diag", shape=[self._num_units], initializer=self._initializer)
999      self._w_i_diag = self.add_variable(
1000          "w_i_diag", shape=[self._num_units], initializer=self._initializer)
1001      self._w_o_diag = self.add_variable(
1002          "w_o_diag", shape=[self._num_units], initializer=self._initializer)
1003
1004    if self._num_proj is not None:
1005      maybe_proj_partitioner = (
1006          partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
1007          if self._num_proj_shards is not None else None)
1008      self._proj_kernel = self.add_variable(
1009          "projection/%s" % _WEIGHTS_VARIABLE_NAME,
1010          shape=[self._num_units, self._num_proj],
1011          initializer=self._initializer,
1012          partitioner=maybe_proj_partitioner)
1013
1014    self.built = True

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
      state: if `state_is_tuple` is False, this must be a state Tensor, `2-D,
        [batch, state_size]`.  If `state_is_tuple` is True, this must be a tuple
        of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch, output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    _check_rnn_cell_input_dtypes([inputs, state])

    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    input_size = inputs.get_shape().with_rank(2).dims[1].value
    if input_size is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    lstm_matrix = math_ops.matmul(
        array_ops.concat([inputs, m_prev], 1), self._kernel)
    lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)

    i, j, f, o = array_ops.split(
        value=lstm_matrix, num_or_size_splits=4, axis=1)
    # Diagonal connections
    if self._use_peepholes:
      c = (
          sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (
          sigmoid(f + self._forget_bias) * c_prev +
          sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      m = math_ops.matmul(m, self._proj_kernel)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "use_peepholes": self._use_peepholes,
        "cell_clip": self._cell_clip,
        "initializer": initializers.serialize(self._initializer),
        "num_proj": self._num_proj,
        "proj_clip": self._proj_clip,
        "num_unit_shards": self._num_unit_shards,
        "num_proj_shards": self._num_proj_shards,
        "forget_bias": self._forget_bias,
        "state_is_tuple": self._state_is_tuple,
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(LSTMCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class _RNNCellWrapperV1(RNNCell):
1114  """Base class for cells wrappers V1 compatibility.
1115
1116  This class along with `_RNNCellWrapperV2` allows to define cells wrappers that
1117  are compatible with V1 and V2, and defines helper methods for this purpose.
1118  """

  def __init__(self, cell, *args, **kwargs):
    super(_RNNCellWrapperV1, self).__init__(*args, **kwargs)
    assert_like_rnncell("cell", cell)
    self.cell = cell
    if isinstance(cell, trackable.Trackable):
      self._track_trackable(self.cell, name="cell")

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Calls the wrapped cell and performs the wrapping logic.

    This method is called from the wrapper's `call` or `__call__` methods.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments.

    Returns:
      A pair containing:
      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
    raise NotImplementedError

  def __call__(self, inputs, state, scope=None):
    """Runs the RNN cell step computation.

    We assume that the wrapped RNNCell is being built within its `__call__`
    method. We directly use the wrapped cell's `__call__` in the overridden
    wrapper `__call__` method.

    This allows the wrapped and non-wrapped cells to be used equivalently
    when using `__call__`.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      scope: VariableScope for the subgraph created in the wrapped cells'
        `__call__`.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
    return self._call_wrapped_cell(
        inputs, state, cell_call_fn=self.cell.__call__, scope=scope)

  def get_config(self):
    config = {
        "cell": {
            "class_name": self.cell.__class__.__name__,
            "config": self.cell.get_config()
        },
    }
    base_config = super(_RNNCellWrapperV1, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    config = config.copy()
    cell = config.pop("cell")
    try:
      assert_like_rnncell("cell", cell)
      return cls(cell, **config)
    except TypeError:
      raise ValueError("RNNCellWrapper cannot reconstruct the wrapped cell. "
                       "Please overwrite the cell in the config with an "
                       "RNNCell instance.")


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DropoutWrapper"])
@tf_export(v1=["nn.rnn_cell.DropoutWrapper"])
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
                     _RNNCellWrapperV1):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(DropoutWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.ResidualWrapper"])
@tf_export(v1=["nn.rnn_cell.ResidualWrapper"])
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
                      _RNNCellWrapperV1):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(ResidualWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DeviceWrapper"])
@tf_export(v1=["nn.rnn_cell.DeviceWrapper"])
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
                    _RNNCellWrapperV1):

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(DeviceWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.MultiRNNCell"])
@tf_export(v1=["nn.rnn_cell.MultiRNNCell"])
class MultiRNNCell(RNNCell):
  """RNN cell composed sequentially of multiple simple cells.

  Example:

  ```python
  num_units = [128, 64]
  cells = [BasicLSTMCell(num_units=n) for n in num_units]
  stacked_rnn_cell = MultiRNNCell(cells)
  ```
  """

  def __init__(self, cells, state_is_tuple=True):
1244    """Create a RNN cell composed sequentially of a number of RNNCells.

    Args:
      cells: list of RNNCells that will be composed in this order.
      state_is_tuple: If True, accepted and returned states are n-tuples, where
        `n = len(cells)`.  If False, the states are all concatenated along the
        column axis.  This latter behavior will soon be deprecated.

    Raises:
      ValueError: if cells is empty (not allowed), or at least one of the cells
        returns a state tuple but the flag `state_is_tuple` is `False`.
    """
    logging.warning("`tf.nn.rnn_cell.MultiRNNCell` is deprecated. This class "
                    "is equivalent to `tf.keras.layers.StackedRNNCells`, "
                    "and will be replaced by that in Tensorflow 2.0.")
    super(MultiRNNCell, self).__init__()
    if not cells:
      raise ValueError("Must specify at least one cell for MultiRNNCell.")
    if not nest.is_nested(cells):
      raise TypeError("cells must be a list or tuple, but saw: %s." % cells)

    if len(set(id(cell) for cell in cells)) < len(cells):
      logging.log_first_n(
          logging.WARN, "At least two cells provided to MultiRNNCell "
          "are the same object and will share weights.", 1)

    self._cells = cells
    for cell_number, cell in enumerate(self._cells):
      # Add Trackable dependencies on these cells so their variables get
      # saved with this object when using object-based saving.
      if isinstance(cell, trackable.Trackable):
        # TODO(allenl): Track down non-Trackable callers.
        self._track_trackable(cell, name="cell-%d" % (cell_number,))
    self._state_is_tuple = state_is_tuple
    if not state_is_tuple:
      if any(nest.is_nested(c.state_size) for c in self._cells):
        raise ValueError("Some cells return tuples of states, but the flag "
                         "state_is_tuple is not set.  State sizes are: %s" %
                         str([c.state_size for c in self._cells]))

  @property
  def state_size(self):
    if self._state_is_tuple:
      return tuple(cell.state_size for cell in self._cells)
    else:
      return sum(cell.state_size for cell in self._cells)

  @property
  def output_size(self):
    return self._cells[-1].output_size

  def zero_state(self, batch_size, dtype):
    with backend.name_scope(type(self).__name__ + "ZeroState"):
      if self._state_is_tuple:
        return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
      else:
        # We know here that state_size of each cell is not a tuple and
        # presumably does not contain TensorArrays or anything else fancy
        return super(MultiRNNCell, self).zero_state(batch_size, dtype)

  @property
  def trainable_weights(self):
    if not self.trainable:
      return []
    weights = []
    for cell in self._cells:
      if isinstance(cell, base_layer.Layer):
        weights += cell.trainable_weights
    return weights

  @property
  def non_trainable_weights(self):
    weights = []
    for cell in self._cells:
      if isinstance(cell, base_layer.Layer):
        weights += cell.non_trainable_weights
    if not self.trainable:
      trainable_weights = []
      for cell in self._cells:
        if isinstance(cell, base_layer.Layer):
          trainable_weights += cell.trainable_weights
      return trainable_weights + weights
    return weights

  def call(self, inputs, state):
    """Run this multi-layer cell on inputs, starting from state."""
    cur_state_pos = 0
    cur_inp = inputs
    new_states = []
    for i, cell in enumerate(self._cells):
      with vs.variable_scope("cell_%d" % i):
        if self._state_is_tuple:
          if not nest.is_nested(state):
            raise ValueError(
                "Expected state to be a tuple of length %d, but received: %s" %
                (len(self.state_size), state))
          cur_state = state[i]
        else:
          cur_state = array_ops.slice(state, [0, cur_state_pos],
                                      [-1, cell.state_size])
          cur_state_pos += cell.state_size
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)

    new_states = (
        tuple(new_states) if self._state_is_tuple else array_ops.concat(
            new_states, 1))

    return cur_inp, new_states


def _check_rnn_cell_input_dtypes(inputs):
  """Check whether the input tensors have supported dtypes.

  Default RNN cells only support floating point and complex dtypes, since the
  activation functions (tanh and sigmoid) only allow those types. This function
  raises a proper error message if an input does not have a supported dtype.

  Args:
    inputs: tensor or nested structure of tensors that are fed to the RNN cell
      as input or state.

  Raises:
    ValueError: if any of the input tensors does not have a float or complex
      dtype.
  """
  for t in nest.flatten(inputs):
    _check_supported_dtypes(t.dtype)


def _check_supported_dtypes(dtype):
  if dtype is None:
    return
  dtype = dtypes.as_dtype(dtype)
  if not (dtype.is_floating or dtype.is_complex):
    raise ValueError("RNN cell only supports floating point and complex "
                     "inputs, but saw dtype: %s" % dtype)
