# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-classes-have-attributes
"""Module implementing RNN Cells.

This module provides a number of basic, commonly used RNN cells, such as LSTM
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
operators that allow adding dropout, projections, or embeddings for inputs.
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
calling the `rnn` ops several times.
"""
import collections
import warnings

from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl
from tensorflow.python.keras.legacy_tf_layers import base as base_layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"

# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"


def _hasattr(obj, attr_name):
  try:
    getattr(obj, attr_name)
  except AttributeError:
    return False
  else:
    return True


def assert_like_rnncell(cell_name, cell):
  """Raises a TypeError if cell is not like an RNNCell.

  NOTE: Do not rely on the error message (in particular in tests), which can
  be subject to change to increase readability. Use
  ASSERT_LIKE_RNNCELL_ERROR_REGEXP.

  Args:
    cell_name: A string to give a meaningful error referring to the name of
      the function argument.
    cell: The object which should behave like an RNNCell.
83 84 Raises: 85 TypeError: A human-friendly exception. 86 """ 87 conditions = [ 88 _hasattr(cell, "output_size"), 89 _hasattr(cell, "state_size"), 90 _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"), 91 callable(cell), 92 ] 93 errors = [ 94 "'output_size' property is missing", "'state_size' property is missing", 95 "either 'zero_state' or 'get_initial_state' method is required", 96 "is not callable" 97 ] 98 99 if not all(conditions): 100 101 errors = [error for error, cond in zip(errors, conditions) if not cond] 102 raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format( 103 cell_name, cell, ", ".join(errors))) 104 105 106def _concat(prefix, suffix, static=False): 107 """Concat that enables int, Tensor, or TensorShape values. 108 109 This function takes a size specification, which can be an integer, a 110 TensorShape, or a Tensor, and converts it into a concatenated Tensor 111 (if static = False) or a list of integers (if static = True). 112 113 Args: 114 prefix: The prefix; usually the batch size (and/or time step size). 115 (TensorShape, int, or Tensor.) 116 suffix: TensorShape, int, or Tensor. 117 static: If `True`, return a python list with possibly unknown dimensions. 118 Otherwise return a `Tensor`. 119 120 Returns: 121 shape: the concatenation of prefix and suffix. 122 123 Raises: 124 ValueError: if `suffix` is not a scalar or vector (or TensorShape). 125 ValueError: if prefix or suffix was `None` and asked for dynamic 126 Tensors out. 127 """ 128 if isinstance(prefix, ops.Tensor): 129 p = prefix 130 p_static = tensor_util.constant_value(prefix) 131 if p.shape.ndims == 0: 132 p = array_ops.expand_dims(p, 0) 133 elif p.shape.ndims != 1: 134 raise ValueError("prefix tensor must be either a scalar or vector, " 135 "but saw tensor: %s" % p) 136 else: 137 p = tensor_shape.TensorShape(prefix) 138 p_static = p.as_list() if p.ndims is not None else None 139 p = ( 140 constant_op.constant(p.as_list(), dtype=dtypes.int32) 141 if p.is_fully_defined() else None) 142 if isinstance(suffix, ops.Tensor): 143 s = suffix 144 s_static = tensor_util.constant_value(suffix) 145 if s.shape.ndims == 0: 146 s = array_ops.expand_dims(s, 0) 147 elif s.shape.ndims != 1: 148 raise ValueError("suffix tensor must be either a scalar or vector, " 149 "but saw tensor: %s" % s) 150 else: 151 s = tensor_shape.TensorShape(suffix) 152 s_static = s.as_list() if s.ndims is not None else None 153 s = ( 154 constant_op.constant(s.as_list(), dtype=dtypes.int32) 155 if s.is_fully_defined() else None) 156 157 if static: 158 shape = tensor_shape.TensorShape(p_static).concatenate(s_static) 159 shape = shape.as_list() if shape.ndims is not None else None 160 else: 161 if p is None or s is None: 162 raise ValueError("Provided a prefix or suffix of None: %s and %s" % 163 (prefix, suffix)) 164 shape = array_ops.concat((p, s), 0) 165 return shape 166 167 168def _zero_state_tensors(state_size, batch_size, dtype): 169 """Create tensors of zeros based on state_size, batch_size, and dtype.""" 170 171 def get_state_shape(s): 172 """Combine s with batch_size to get a proper tensor shape.""" 173 c = _concat(batch_size, s) 174 size = array_ops.zeros(c, dtype=dtype) 175 if not context.executing_eagerly(): 176 c_static = _concat(batch_size, s, static=True) 177 size.set_shape(c_static) 178 return size 179 180 return nest.map_structure(get_state_shape, state_size) 181 182 183@keras_export(v1=["keras.__internal__.legacy.rnn_cell.RNNCell"]) 184@tf_export(v1=["nn.rnn_cell.RNNCell"]) 185class 
RNNCell(base_layer.Layer):
  """Abstract object representing an RNN cell.

  Every `RNNCell` must have the properties below and implement `call` with
  the signature `(output, next_state) = call(input, state)`. The optional
  third input argument, `scope`, is allowed for backwards compatibility
  purposes, but should be left off for new subclasses.

  This definition of cell differs from the definition used in the literature.
  In the literature, 'cell' refers to an object with a single scalar output.
  This definition refers to a horizontal array of such units.

  An RNN cell, in the most abstract setting, is anything that has
  a state and performs some operation that takes a matrix of inputs.
  This operation results in an output matrix with `self.output_size` columns.
  If `self.state_size` is an integer, this operation also results in a new
  state matrix with `self.state_size` columns. If `self.state_size` is a
  (possibly nested tuple of) TensorShape object(s), then it should return a
  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.state_size`.
  """

  def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
    super(RNNCell, self).__init__(
        trainable=trainable, name=name, dtype=dtype, **kwargs)
    # Attribute that indicates whether the cell is a TF RNN cell, due to the
    # slight difference between TF and Keras RNN cells. Notably, a TF cell
    # does not wrap a single-tensor state in a list, whereas a Keras cell
    # wraps the state into a list, and call() has to unwrap it.
    self._is_tf_rnn_cell = True

  def __call__(self, inputs, state, scope=None):
    """Run this RNN cell on inputs, starting from the given state.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
        with shape `[batch_size, self.state_size]`. Otherwise, if
        `self.state_size` is a tuple of integers, this should be a tuple with
        shapes `[batch_size, s] for s in self.state_size`.
      scope: VariableScope for the created subgraph; defaults to class name.

    Returns:
      A pair containing:

      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
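
    For illustration only, a single step with a concrete subclass might look
    like the sketch below. It assumes TF1-style graph mode and
    `import tensorflow as tf`; `BasicRNNCell` is defined later in this module,
    and the shapes are arbitrary placeholders:

    ```python
    cell = BasicRNNCell(num_units=64)
    inputs = tf.compat.v1.placeholder(tf.float32, [32, 128])  # [batch, depth]
    state = cell.zero_state(batch_size=32, dtype=tf.float32)
    output, next_state = cell(inputs, state)  # each shaped [32, 64]
    ```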
233 """ 234 if scope is not None: 235 with vs.variable_scope( 236 scope, custom_getter=self._rnn_get_variable) as scope: 237 return super(RNNCell, self).__call__(inputs, state, scope=scope) 238 else: 239 scope_attrname = "rnncell_scope" 240 scope = getattr(self, scope_attrname, None) 241 if scope is None: 242 scope = vs.variable_scope( 243 vs.get_variable_scope(), custom_getter=self._rnn_get_variable) 244 setattr(self, scope_attrname, scope) 245 with scope: 246 return super(RNNCell, self).__call__(inputs, state) 247 248 def _rnn_get_variable(self, getter, *args, **kwargs): 249 variable = getter(*args, **kwargs) 250 if ops.executing_eagerly_outside_functions(): 251 trainable = variable.trainable 252 else: 253 trainable = ( 254 variable in tf_variables.trainable_variables() or 255 (base_layer_utils.is_split_variable(variable) and 256 list(variable)[0] in tf_variables.trainable_variables())) 257 if trainable and all(variable is not v for v in self._trainable_weights): 258 self._trainable_weights.append(variable) 259 elif not trainable and all( 260 variable is not v for v in self._non_trainable_weights): 261 self._non_trainable_weights.append(variable) 262 return variable 263 264 @property 265 def state_size(self): 266 """size(s) of state(s) used by this cell. 267 268 It can be represented by an Integer, a TensorShape or a tuple of Integers 269 or TensorShapes. 270 """ 271 raise NotImplementedError("Abstract method") 272 273 @property 274 def output_size(self): 275 """Integer or TensorShape: size of outputs produced by this cell.""" 276 raise NotImplementedError("Abstract method") 277 278 def build(self, _): 279 # This tells the parent Layer object that it's OK to call 280 # self.add_variable() inside the call() method. 281 pass 282 283 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 284 if inputs is not None: 285 # Validate the given batch_size and dtype against inputs if provided. 286 inputs = ops.convert_to_tensor_v2_with_dispatch(inputs, name="inputs") 287 if batch_size is not None: 288 if tensor_util.is_tf_type(batch_size): 289 static_batch_size = tensor_util.constant_value( 290 batch_size, partial=True) 291 else: 292 static_batch_size = batch_size 293 if inputs.shape.dims[0].value != static_batch_size: 294 raise ValueError( 295 "batch size from input tensor is different from the " 296 "input param. Input tensor batch: {}, batch_size: {}".format( 297 inputs.shape.dims[0].value, batch_size)) 298 299 if dtype is not None and inputs.dtype != dtype: 300 raise ValueError( 301 "dtype from input tensor is different from the " 302 "input param. Input tensor dtype: {}, dtype: {}".format( 303 inputs.dtype, dtype)) 304 305 batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0] 306 dtype = inputs.dtype 307 if batch_size is None or dtype is None: 308 raise ValueError( 309 "batch_size and dtype cannot be None while constructing initial " 310 "state: batch_size={}, dtype={}".format(batch_size, dtype)) 311 return self.zero_state(batch_size, dtype) 312 313 def zero_state(self, batch_size, dtype): 314 """Return zero-filled state tensor(s). 315 316 Args: 317 batch_size: int, float, or unit Tensor representing the batch size. 318 dtype: the data type to use for the state. 319 320 Returns: 321 If `state_size` is an int or TensorShape, then the return value is a 322 `N-D` tensor of shape `[batch_size, state_size]` filled with zeros. 

    If `state_size` is a nested list or tuple, then the return value is
    a nested list or tuple (of the same structure) of `2-D` tensors with
    the shapes `[batch_size, s]` for each s in `state_size`.
    """
    # Try to use the last cached zero_state. This is done to avoid recreating
    # zeros, especially when eager execution is enabled.
    state_size = self.state_size
    is_eager = context.executing_eagerly()
    if is_eager and _hasattr(self, "_last_zero_state"):
      (last_state_size, last_batch_size, last_dtype,
       last_output) = getattr(self, "_last_zero_state")
      if (last_batch_size == batch_size and last_dtype == dtype and
          last_state_size == state_size):
        return last_output
    with backend.name_scope(type(self).__name__ + "ZeroState"):
      output = _zero_state_tensors(state_size, batch_size, dtype)
    if is_eager:
      self._last_zero_state = (state_size, batch_size, dtype, output)
    return output

  # TODO(b/134773139): Remove when contrib RNN cells implement `get_config`
  def get_config(self):  # pylint: disable=useless-super-delegation
    return super(RNNCell, self).get_config()

  @property
  def _use_input_spec_as_call_signature(self):
    # We do not store the shape information for the state argument in the call
    # function for legacy RNN cells, so do not generate an input signature.
    return False


class LayerRNNCell(RNNCell):
  """Subclass of RNNCells that act like proper `tf.Layer` objects.

  For backwards compatibility purposes, most `RNNCell` instances allow their
  `call` methods to instantiate variables via `tf.compat.v1.get_variable`. The
  underlying variable scope thus keeps track of any variables and returns
  cached versions. This is atypical of `tf.layer` objects, which separate this
  part of layer building into a `build` method that is only called once.

  Here we provide a subclass for `RNNCell` objects that act exactly as
  `Layer` objects do. They must provide a `build` method, and their
  `call` methods do not access Variables via `tf.compat.v1.get_variable`.
  """

  def __call__(self, inputs, state, scope=None, *args, **kwargs):
    """Run this RNN cell on inputs, starting from the given state.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
        with shape `[batch_size, self.state_size]`. Otherwise, if
        `self.state_size` is a tuple of integers, this should be a tuple with
        shapes `[batch_size, s] for s in self.state_size`.
      scope: optional cell scope.
      *args: Additional positional arguments.
      **kwargs: Additional keyword arguments.

    Returns:
      A pair containing:

      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
    """
    # Bypass RNNCell's variable capturing semantics for LayerRNNCell.
    # Instead, it is up to subclasses to provide a proper build
    # method. See the class docstring for more details.
    return base_layer.Layer.__call__(
        self, inputs, state, scope=scope, *args, **kwargs)


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicRNNCell"])
@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
class BasicRNNCell(LayerRNNCell):
  """The most basic RNN cell.

  Note that this cell is not optimized for performance.
Please use 403 `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU. 404 405 Args: 406 num_units: int, The number of units in the RNN cell. 407 activation: Nonlinearity to use. Default: `tanh`. It could also be string 408 that is within Keras activation function names. 409 reuse: (optional) Python boolean describing whether to reuse variables in an 410 existing scope. If not `True`, and the existing scope already has the 411 given variables, an error is raised. 412 name: String, the name of the layer. Layers with the same name will share 413 weights, but to avoid mistakes we require reuse=True in such cases. 414 dtype: Default dtype of the layer (default of `None` means use the type of 415 the first input). Required when `build` is called before `call`. 416 **kwargs: Dict, keyword named properties for common layer attributes, like 417 `trainable` etc when constructing the cell from configs of get_config(). 418 """ 419 420 def __init__(self, 421 num_units, 422 activation=None, 423 reuse=None, 424 name=None, 425 dtype=None, 426 **kwargs): 427 warnings.warn("`tf.nn.rnn_cell.BasicRNNCell` is deprecated and will be " 428 "removed in a future version. This class " 429 "is equivalent as `tf.keras.layers.SimpleRNNCell`, " 430 "and will be replaced by that in Tensorflow 2.0.") 431 super(BasicRNNCell, self).__init__( 432 _reuse=reuse, name=name, dtype=dtype, **kwargs) 433 _check_supported_dtypes(self.dtype) 434 if context.executing_eagerly() and tf_config.list_logical_devices("GPU"): 435 logging.warning( 436 "%s: Note that this cell is not optimized for performance. " 437 "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better " 438 "performance on GPU.", self) 439 440 # Inputs must be 2-dimensional. 441 self.input_spec = input_spec.InputSpec(ndim=2) 442 443 self._num_units = num_units 444 if activation: 445 self._activation = activations.get(activation) 446 else: 447 self._activation = math_ops.tanh 448 449 @property 450 def state_size(self): 451 return self._num_units 452 453 @property 454 def output_size(self): 455 return self._num_units 456 457 @tf_utils.shape_type_conversion 458 def build(self, inputs_shape): 459 if inputs_shape[-1] is None: 460 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % 461 str(inputs_shape)) 462 _check_supported_dtypes(self.dtype) 463 464 input_depth = inputs_shape[-1] 465 self._kernel = self.add_variable( 466 _WEIGHTS_VARIABLE_NAME, 467 shape=[input_depth + self._num_units, self._num_units]) 468 self._bias = self.add_variable( 469 _BIAS_VARIABLE_NAME, 470 shape=[self._num_units], 471 initializer=init_ops.zeros_initializer(dtype=self.dtype)) 472 473 self.built = True 474 475 def call(self, inputs, state): 476 """Most basic RNN: output = new_state = act(W * input + U * state + B).""" 477 _check_rnn_cell_input_dtypes([inputs, state]) 478 gate_inputs = math_ops.matmul( 479 array_ops.concat([inputs, state], 1), self._kernel) 480 gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) 481 output = self._activation(gate_inputs) 482 return output, output 483 484 def get_config(self): 485 config = { 486 "num_units": self._num_units, 487 "activation": activations.serialize(self._activation), 488 "reuse": self._reuse, 489 } 490 base_config = super(BasicRNNCell, self).get_config() 491 return dict(list(base_config.items()) + list(config.items())) 492 493 494@keras_export(v1=["keras.__internal__.legacy.rnn_cell.GRUCell"]) 495@tf_export(v1=["nn.rnn_cell.GRUCell"]) 496class GRUCell(LayerRNNCell): 497 """Gated Recurrent Unit cell. 
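
  As a rough sketch, `call` implements the standard GRU update below, using
  the `gates` and `candidate` kernels/biases created in `build` (the symbols
  are informal shorthand, with `x` the input and `h` the previous state):

    r, u = sigmoid(W_gates . [x, h] + b_gates)    # reset and update gates
    c = activation(W_cand . [x, r * h] + b_cand)  # candidate state
    new_h = u * h + (1 - u) * c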
498 499 Note that this cell is not optimized for performance. Please use 500 `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or 501 `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU. 502 503 Args: 504 num_units: int, The number of units in the GRU cell. 505 activation: Nonlinearity to use. Default: `tanh`. 506 reuse: (optional) Python boolean describing whether to reuse variables in an 507 existing scope. If not `True`, and the existing scope already has the 508 given variables, an error is raised. 509 kernel_initializer: (optional) The initializer to use for the weight and 510 projection matrices. 511 bias_initializer: (optional) The initializer to use for the bias. 512 name: String, the name of the layer. Layers with the same name will share 513 weights, but to avoid mistakes we require reuse=True in such cases. 514 dtype: Default dtype of the layer (default of `None` means use the type of 515 the first input). Required when `build` is called before `call`. 516 **kwargs: Dict, keyword named properties for common layer attributes, like 517 `trainable` etc when constructing the cell from configs of get_config(). 518 519 References: 520 Learning Phrase Representations using RNN Encoder Decoder for Statistical 521 Machine Translation: 522 [Cho et al., 2014] 523 (https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179) 524 ([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf)) 525 """ 526 527 def __init__(self, 528 num_units, 529 activation=None, 530 reuse=None, 531 kernel_initializer=None, 532 bias_initializer=None, 533 name=None, 534 dtype=None, 535 **kwargs): 536 warnings.warn("`tf.nn.rnn_cell.GRUCell` is deprecated and will be removed " 537 "in a future version. This class " 538 "is equivalent as `tf.keras.layers.GRUCell`, " 539 "and will be replaced by that in Tensorflow 2.0.") 540 super(GRUCell, self).__init__( 541 _reuse=reuse, name=name, dtype=dtype, **kwargs) 542 _check_supported_dtypes(self.dtype) 543 544 if context.executing_eagerly() and tf_config.list_logical_devices("GPU"): 545 logging.warning( 546 "%s: Note that this cell is not optimized for performance. " 547 "Please use tf.contrib.cudnn_rnn.CudnnGRU for better " 548 "performance on GPU.", self) 549 # Inputs must be 2-dimensional. 
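    # (i.e. shaped `[batch_size, input_size]`). For reference, `build` below
    # creates two weight groups from that input size (shapes as a sketch):
    #   gates/kernel:     [input_depth + num_units, 2 * num_units]
    #   gates/bias:       [2 * num_units]   (defaults to ones)
    #   candidate/kernel: [input_depth + num_units, num_units]
    #   candidate/bias:   [num_units]       (defaults to zeros)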
550 self.input_spec = input_spec.InputSpec(ndim=2) 551 552 self._num_units = num_units 553 if activation: 554 self._activation = activations.get(activation) 555 else: 556 self._activation = math_ops.tanh 557 self._kernel_initializer = initializers.get(kernel_initializer) 558 self._bias_initializer = initializers.get(bias_initializer) 559 560 @property 561 def state_size(self): 562 return self._num_units 563 564 @property 565 def output_size(self): 566 return self._num_units 567 568 @tf_utils.shape_type_conversion 569 def build(self, inputs_shape): 570 if inputs_shape[-1] is None: 571 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % 572 str(inputs_shape)) 573 _check_supported_dtypes(self.dtype) 574 input_depth = inputs_shape[-1] 575 self._gate_kernel = self.add_variable( 576 "gates/%s" % _WEIGHTS_VARIABLE_NAME, 577 shape=[input_depth + self._num_units, 2 * self._num_units], 578 initializer=self._kernel_initializer) 579 self._gate_bias = self.add_variable( 580 "gates/%s" % _BIAS_VARIABLE_NAME, 581 shape=[2 * self._num_units], 582 initializer=(self._bias_initializer 583 if self._bias_initializer is not None else 584 init_ops.constant_initializer(1.0, dtype=self.dtype))) 585 self._candidate_kernel = self.add_variable( 586 "candidate/%s" % _WEIGHTS_VARIABLE_NAME, 587 shape=[input_depth + self._num_units, self._num_units], 588 initializer=self._kernel_initializer) 589 self._candidate_bias = self.add_variable( 590 "candidate/%s" % _BIAS_VARIABLE_NAME, 591 shape=[self._num_units], 592 initializer=(self._bias_initializer 593 if self._bias_initializer is not None else 594 init_ops.zeros_initializer(dtype=self.dtype))) 595 596 self.built = True 597 598 def call(self, inputs, state): 599 """Gated recurrent unit (GRU) with nunits cells.""" 600 _check_rnn_cell_input_dtypes([inputs, state]) 601 602 gate_inputs = math_ops.matmul( 603 array_ops.concat([inputs, state], 1), self._gate_kernel) 604 gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) 605 606 value = math_ops.sigmoid(gate_inputs) 607 r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) 608 609 r_state = r * state 610 611 candidate = math_ops.matmul( 612 array_ops.concat([inputs, r_state], 1), self._candidate_kernel) 613 candidate = nn_ops.bias_add(candidate, self._candidate_bias) 614 615 c = self._activation(candidate) 616 new_h = u * state + (1 - u) * c 617 return new_h, new_h 618 619 def get_config(self): 620 config = { 621 "num_units": self._num_units, 622 "kernel_initializer": initializers.serialize(self._kernel_initializer), 623 "bias_initializer": initializers.serialize(self._bias_initializer), 624 "activation": activations.serialize(self._activation), 625 "reuse": self._reuse, 626 } 627 base_config = super(GRUCell, self).get_config() 628 return dict(list(base_config.items()) + list(config.items())) 629 630 631_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) 632 633 634@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMStateTuple"]) 635@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"]) 636class LSTMStateTuple(_LSTMStateTuple): 637 """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. 638 639 Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state 640 and `h` is the output. 641 642 Only used when `state_is_tuple=True`. 
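
  A minimal illustration (the tensors here are placeholders, not part of this
  module):

  ```python
  state = LSTMStateTuple(c=cell_state_tensor, h=output_tensor)
  c, h = state  # or state.c, state.h
  ```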
643 """ 644 __slots__ = () 645 646 @property 647 def dtype(self): 648 (c, h) = self 649 if c.dtype != h.dtype: 650 raise TypeError("Inconsistent internal state: %s vs %s" % 651 (str(c.dtype), str(h.dtype))) 652 return c.dtype 653 654 655@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicLSTMCell"]) 656@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"]) 657class BasicLSTMCell(LayerRNNCell): 658 """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead. 659 660 Basic LSTM recurrent network cell. 661 662 The implementation is based on 663 664 We add forget_bias (default: 1) to the biases of the forget gate in order to 665 reduce the scale of forgetting in the beginning of the training. 666 667 It does not allow cell clipping, a projection layer, and does not 668 use peep-hole connections: it is the basic baseline. 669 670 For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell` 671 that follows. 672 673 Note that this cell is not optimized for performance. Please use 674 `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or 675 `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for 676 better performance on CPU. 677 """ 678 679 def __init__(self, 680 num_units, 681 forget_bias=1.0, 682 state_is_tuple=True, 683 activation=None, 684 reuse=None, 685 name=None, 686 dtype=None, 687 **kwargs): 688 """Initialize the basic LSTM cell. 689 690 Args: 691 num_units: int, The number of units in the LSTM cell. 692 forget_bias: float, The bias added to forget gates (see above). Must set 693 to `0.0` manually when restoring from CudnnLSTM-trained checkpoints. 694 state_is_tuple: If True, accepted and returned states are 2-tuples of the 695 `c_state` and `m_state`. If False, they are concatenated along the 696 column axis. The latter behavior will soon be deprecated. 697 activation: Activation function of the inner states. Default: `tanh`. It 698 could also be string that is within Keras activation function names. 699 reuse: (optional) Python boolean describing whether to reuse variables in 700 an existing scope. If not `True`, and the existing scope already has 701 the given variables, an error is raised. 702 name: String, the name of the layer. Layers with the same name will share 703 weights, but to avoid mistakes we require reuse=True in such cases. 704 dtype: Default dtype of the layer (default of `None` means use the type of 705 the first input). Required when `build` is called before `call`. 706 **kwargs: Dict, keyword named properties for common layer attributes, like 707 `trainable` etc when constructing the cell from configs of get_config(). 708 When restoring from CudnnLSTM-trained checkpoints, must use 709 `CudnnCompatibleLSTMCell` instead. 710 """ 711 warnings.warn("`tf.nn.rnn_cell.BasicLSTMCell` is deprecated and will be " 712 "removed in a future version. This class " 713 "is equivalent as `tf.keras.layers.LSTMCell`, " 714 "and will be replaced by that in Tensorflow 2.0.") 715 super(BasicLSTMCell, self).__init__( 716 _reuse=reuse, name=name, dtype=dtype, **kwargs) 717 _check_supported_dtypes(self.dtype) 718 if not state_is_tuple: 719 logging.warning( 720 "%s: Using a concatenated state is slower and will soon be " 721 "deprecated. Use state_is_tuple=True.", self) 722 if context.executing_eagerly() and tf_config.list_logical_devices("GPU"): 723 logging.warning( 724 "%s: Note that this cell is not optimized for performance. 
" 725 "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " 726 "performance on GPU.", self) 727 728 # Inputs must be 2-dimensional. 729 self.input_spec = input_spec.InputSpec(ndim=2) 730 731 self._num_units = num_units 732 self._forget_bias = forget_bias 733 self._state_is_tuple = state_is_tuple 734 if activation: 735 self._activation = activations.get(activation) 736 else: 737 self._activation = math_ops.tanh 738 739 @property 740 def state_size(self): 741 return (LSTMStateTuple(self._num_units, self._num_units) 742 if self._state_is_tuple else 2 * self._num_units) 743 744 @property 745 def output_size(self): 746 return self._num_units 747 748 @tf_utils.shape_type_conversion 749 def build(self, inputs_shape): 750 if inputs_shape[-1] is None: 751 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % 752 str(inputs_shape)) 753 _check_supported_dtypes(self.dtype) 754 input_depth = inputs_shape[-1] 755 h_depth = self._num_units 756 self._kernel = self.add_variable( 757 _WEIGHTS_VARIABLE_NAME, 758 shape=[input_depth + h_depth, 4 * self._num_units]) 759 self._bias = self.add_variable( 760 _BIAS_VARIABLE_NAME, 761 shape=[4 * self._num_units], 762 initializer=init_ops.zeros_initializer(dtype=self.dtype)) 763 764 self.built = True 765 766 def call(self, inputs, state): 767 """Long short-term memory cell (LSTM). 768 769 Args: 770 inputs: `2-D` tensor with shape `[batch_size, input_size]`. 771 state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size, 772 num_units]`, if `state_is_tuple` has been set to `True`. Otherwise, a 773 `Tensor` shaped `[batch_size, 2 * num_units]`. 774 775 Returns: 776 A pair containing the new hidden state, and the new state (either a 777 `LSTMStateTuple` or a concatenated state, depending on 778 `state_is_tuple`). 779 """ 780 _check_rnn_cell_input_dtypes([inputs, state]) 781 782 sigmoid = math_ops.sigmoid 783 one = constant_op.constant(1, dtype=dtypes.int32) 784 # Parameters of gates are concatenated into one multiply for efficiency. 785 if self._state_is_tuple: 786 c, h = state 787 else: 788 c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one) 789 790 gate_inputs = math_ops.matmul( 791 array_ops.concat([inputs, h], 1), self._kernel) 792 gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) 793 794 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 795 i, j, f, o = array_ops.split( 796 value=gate_inputs, num_or_size_splits=4, axis=one) 797 798 forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) 799 # Note that using `add` and `multiply` instead of `+` and `*` gives a 800 # performance improvement. So using those at the cost of readability. 
801 add = math_ops.add 802 multiply = math_ops.multiply 803 new_c = add( 804 multiply(c, sigmoid(add(f, forget_bias_tensor))), 805 multiply(sigmoid(i), self._activation(j))) 806 new_h = multiply(self._activation(new_c), sigmoid(o)) 807 808 if self._state_is_tuple: 809 new_state = LSTMStateTuple(new_c, new_h) 810 else: 811 new_state = array_ops.concat([new_c, new_h], 1) 812 return new_h, new_state 813 814 def get_config(self): 815 config = { 816 "num_units": self._num_units, 817 "forget_bias": self._forget_bias, 818 "state_is_tuple": self._state_is_tuple, 819 "activation": activations.serialize(self._activation), 820 "reuse": self._reuse, 821 } 822 base_config = super(BasicLSTMCell, self).get_config() 823 return dict(list(base_config.items()) + list(config.items())) 824 825 826@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMCell"]) 827@tf_export(v1=["nn.rnn_cell.LSTMCell"]) 828class LSTMCell(LayerRNNCell): 829 """Long short-term memory unit (LSTM) recurrent network cell. 830 831 The default non-peephole implementation is based on (Gers et al., 1999). 832 The peephole implementation is based on (Sak et al., 2014). 833 834 The class uses optional peep-hole connections, optional cell clipping, and 835 an optional projection layer. 836 837 Note that this cell is not optimized for performance. Please use 838 `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or 839 `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for 840 better performance on CPU. 841 References: 842 Long short-term memory recurrent neural network architectures for large 843 scale acoustic modeling: 844 [Sak et al., 2014] 845 (https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html) 846 ([pdf] 847 (https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf)) 848 Learning to forget: 849 [Gers et al., 1999] 850 (http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218) 851 ([pdf](https://arxiv.org/pdf/1409.2329.pdf)) 852 Long Short-Term Memory: 853 [Hochreiter et al., 1997] 854 (https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735) 855 ([pdf](http://ml.jku.at/publications/older/3504.pdf)) 856 """ 857 858 def __init__(self, 859 num_units, 860 use_peepholes=False, 861 cell_clip=None, 862 initializer=None, 863 num_proj=None, 864 proj_clip=None, 865 num_unit_shards=None, 866 num_proj_shards=None, 867 forget_bias=1.0, 868 state_is_tuple=True, 869 activation=None, 870 reuse=None, 871 name=None, 872 dtype=None, 873 **kwargs): 874 """Initialize the parameters for an LSTM cell. 875 876 Args: 877 num_units: int, The number of units in the LSTM cell. 878 use_peepholes: bool, set True to enable diagonal/peephole connections. 879 cell_clip: (optional) A float value, if provided the cell state is clipped 880 by this value prior to the cell output activation. 881 initializer: (optional) The initializer to use for the weight and 882 projection matrices. 883 num_proj: (optional) int, The output dimensionality for the projection 884 matrices. If None, no projection is performed. 885 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 886 provided, then the projected values are clipped elementwise to within 887 `[-proj_clip, proj_clip]`. 888 num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a 889 variable_scope partitioner instead. 890 num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a 891 variable_scope partitioner instead. 
      forget_bias: Biases of the forget gate are initialized by default to 1 in
        order to reduce the scale of forgetting at the beginning of training.
        Must be set manually to `0.0` when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of the
        `c_state` and `m_state`. If False, they are concatenated along the
        column axis. This latter behavior will soon be deprecated.
      activation: Activation function of the inner states. Default: `tanh`. It
        can also be the string name of a Keras activation function.
      reuse: (optional) Python boolean describing whether to reuse variables in
        an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will share
        weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes,
        like `trainable` etc. when constructing the cell from configs of
        get_config(). When restoring from CudnnLSTM-trained checkpoints, use
        `CudnnCompatibleLSTMCell` instead.
    """
    warnings.warn("`tf.nn.rnn_cell.LSTMCell` is deprecated and will be "
                  "removed in a future version. This class "
                  "is equivalent to `tf.keras.layers.LSTMCell`, "
                  "and will be replaced by that in Tensorflow 2.0.")
    super(LSTMCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    _check_supported_dtypes(self.dtype)
    if not state_is_tuple:
      logging.warning(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated. Use state_is_tuple=True.", self)
    if num_unit_shards is not None or num_proj_shards is not None:
      logging.warning(
          "%s: The num_unit_shards and num_proj_shards parameters are "
          "deprecated and will be removed in Jan 2017. "
          "Use a variable scope with a partitioner instead.", self)
    if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
      logging.warning(
          "%s: Note that this cell is not optimized for performance. "
          "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
          "performance on GPU.", self)

    # Inputs must be 2-dimensional.
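    # (i.e. shaped `[batch_size, input_size]`). For reference, the sizing logic
    # just below works out to:
    #   num_proj set:   output_size = num_proj;  state is (num_units, num_proj)
    #   num_proj unset: output_size = num_units; state is (num_units, num_units)
    # with the (c, m) pair concatenated into a single tensor when
    # `state_is_tuple=False`.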
936 self.input_spec = input_spec.InputSpec(ndim=2) 937 938 self._num_units = num_units 939 self._use_peepholes = use_peepholes 940 self._cell_clip = cell_clip 941 self._initializer = initializers.get(initializer) 942 self._num_proj = num_proj 943 self._proj_clip = proj_clip 944 self._num_unit_shards = num_unit_shards 945 self._num_proj_shards = num_proj_shards 946 self._forget_bias = forget_bias 947 self._state_is_tuple = state_is_tuple 948 if activation: 949 self._activation = activations.get(activation) 950 else: 951 self._activation = math_ops.tanh 952 953 if num_proj: 954 self._state_size = ( 955 LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units + 956 num_proj) 957 self._output_size = num_proj 958 else: 959 self._state_size = ( 960 LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 * 961 num_units) 962 self._output_size = num_units 963 964 @property 965 def state_size(self): 966 return self._state_size 967 968 @property 969 def output_size(self): 970 return self._output_size 971 972 @tf_utils.shape_type_conversion 973 def build(self, inputs_shape): 974 if inputs_shape[-1] is None: 975 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % 976 str(inputs_shape)) 977 _check_supported_dtypes(self.dtype) 978 input_depth = inputs_shape[-1] 979 h_depth = self._num_units if self._num_proj is None else self._num_proj 980 maybe_partitioner = ( 981 partitioned_variables.fixed_size_partitioner(self._num_unit_shards) 982 if self._num_unit_shards is not None else None) 983 self._kernel = self.add_variable( 984 _WEIGHTS_VARIABLE_NAME, 985 shape=[input_depth + h_depth, 4 * self._num_units], 986 initializer=self._initializer, 987 partitioner=maybe_partitioner) 988 if self.dtype is None: 989 initializer = init_ops.zeros_initializer 990 else: 991 initializer = init_ops.zeros_initializer(dtype=self.dtype) 992 self._bias = self.add_variable( 993 _BIAS_VARIABLE_NAME, 994 shape=[4 * self._num_units], 995 initializer=initializer) 996 if self._use_peepholes: 997 self._w_f_diag = self.add_variable( 998 "w_f_diag", shape=[self._num_units], initializer=self._initializer) 999 self._w_i_diag = self.add_variable( 1000 "w_i_diag", shape=[self._num_units], initializer=self._initializer) 1001 self._w_o_diag = self.add_variable( 1002 "w_o_diag", shape=[self._num_units], initializer=self._initializer) 1003 1004 if self._num_proj is not None: 1005 maybe_proj_partitioner = ( 1006 partitioned_variables.fixed_size_partitioner(self._num_proj_shards) 1007 if self._num_proj_shards is not None else None) 1008 self._proj_kernel = self.add_variable( 1009 "projection/%s" % _WEIGHTS_VARIABLE_NAME, 1010 shape=[self._num_units, self._num_proj], 1011 initializer=self._initializer, 1012 partitioner=maybe_proj_partitioner) 1013 1014 self.built = True 1015 1016 def call(self, inputs, state): 1017 """Run one step of LSTM. 1018 1019 Args: 1020 inputs: input Tensor, must be 2-D, `[batch, input_size]`. 1021 state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, 1022 [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple 1023 of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. 1024 1025 Returns: 1026 A tuple containing: 1027 1028 - A `2-D, [batch, output_dim]`, Tensor representing the output of the 1029 LSTM after reading `inputs` when previous state was `state`. 1030 Here output_dim is: 1031 num_proj if num_proj was set, 1032 num_units otherwise. 
1033 - Tensor(s) representing the new state of LSTM after reading `inputs` when 1034 the previous state was `state`. Same type and shape(s) as `state`. 1035 1036 Raises: 1037 ValueError: If input size cannot be inferred from inputs via 1038 static shape inference. 1039 """ 1040 _check_rnn_cell_input_dtypes([inputs, state]) 1041 1042 num_proj = self._num_units if self._num_proj is None else self._num_proj 1043 sigmoid = math_ops.sigmoid 1044 1045 if self._state_is_tuple: 1046 (c_prev, m_prev) = state 1047 else: 1048 c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) 1049 m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) 1050 1051 input_size = inputs.get_shape().with_rank(2).dims[1].value 1052 if input_size is None: 1053 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 1054 1055 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 1056 lstm_matrix = math_ops.matmul( 1057 array_ops.concat([inputs, m_prev], 1), self._kernel) 1058 lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias) 1059 1060 i, j, f, o = array_ops.split( 1061 value=lstm_matrix, num_or_size_splits=4, axis=1) 1062 # Diagonal connections 1063 if self._use_peepholes: 1064 c = ( 1065 sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + 1066 sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) 1067 else: 1068 c = ( 1069 sigmoid(f + self._forget_bias) * c_prev + 1070 sigmoid(i) * self._activation(j)) 1071 1072 if self._cell_clip is not None: 1073 # pylint: disable=invalid-unary-operand-type 1074 c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 1075 # pylint: enable=invalid-unary-operand-type 1076 if self._use_peepholes: 1077 m = sigmoid(o + self._w_o_diag * c) * self._activation(c) 1078 else: 1079 m = sigmoid(o) * self._activation(c) 1080 1081 if self._num_proj is not None: 1082 m = math_ops.matmul(m, self._proj_kernel) 1083 1084 if self._proj_clip is not None: 1085 # pylint: disable=invalid-unary-operand-type 1086 m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) 1087 # pylint: enable=invalid-unary-operand-type 1088 1089 new_state = ( 1090 LSTMStateTuple(c, m) 1091 if self._state_is_tuple else array_ops.concat([c, m], 1)) 1092 return m, new_state 1093 1094 def get_config(self): 1095 config = { 1096 "num_units": self._num_units, 1097 "use_peepholes": self._use_peepholes, 1098 "cell_clip": self._cell_clip, 1099 "initializer": initializers.serialize(self._initializer), 1100 "num_proj": self._num_proj, 1101 "proj_clip": self._proj_clip, 1102 "num_unit_shards": self._num_unit_shards, 1103 "num_proj_shards": self._num_proj_shards, 1104 "forget_bias": self._forget_bias, 1105 "state_is_tuple": self._state_is_tuple, 1106 "activation": activations.serialize(self._activation), 1107 "reuse": self._reuse, 1108 } 1109 base_config = super(LSTMCell, self).get_config() 1110 return dict(list(base_config.items()) + list(config.items())) 1111 1112 1113class _RNNCellWrapperV1(RNNCell): 1114 """Base class for cells wrappers V1 compatibility. 1115 1116 This class along with `_RNNCellWrapperV2` allows to define cells wrappers that 1117 are compatible with V1 and V2, and defines helper methods for this purpose. 
1118 """ 1119 1120 def __init__(self, cell, *args, **kwargs): 1121 super(_RNNCellWrapperV1, self).__init__(*args, **kwargs) 1122 assert_like_rnncell("cell", cell) 1123 self.cell = cell 1124 if isinstance(cell, trackable.Trackable): 1125 self._track_trackable(self.cell, name="cell") 1126 1127 def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): 1128 """Calls the wrapped cell and performs the wrapping logic. 1129 1130 This method is called from the wrapper's `call` or `__call__` methods. 1131 1132 Args: 1133 inputs: A tensor with wrapped cell's input. 1134 state: A tensor or tuple of tensors with wrapped cell's state. 1135 cell_call_fn: Wrapped cell's method to use for step computation (cell's 1136 `__call__` or 'call' method). 1137 **kwargs: Additional arguments. 1138 1139 Returns: 1140 A pair containing: 1141 - Output: A tensor with cell's output. 1142 - New state: A tensor or tuple of tensors with new wrapped cell's state. 1143 """ 1144 raise NotImplementedError 1145 1146 def __call__(self, inputs, state, scope=None): 1147 """Runs the RNN cell step computation. 1148 1149 We assume that the wrapped RNNCell is being built within its `__call__` 1150 method. We directly use the wrapped cell's `__call__` in the overridden 1151 wrapper `__call__` method. 1152 1153 This allows to use the wrapped cell and the non-wrapped cell equivalently 1154 when using `__call__`. 1155 1156 Args: 1157 inputs: A tensor with wrapped cell's input. 1158 state: A tensor or tuple of tensors with wrapped cell's state. 1159 scope: VariableScope for the subgraph created in the wrapped cells' 1160 `__call__`. 1161 1162 Returns: 1163 A pair containing: 1164 1165 - Output: A tensor with cell's output. 1166 - New state: A tensor or tuple of tensors with new wrapped cell's state. 1167 """ 1168 return self._call_wrapped_cell( 1169 inputs, state, cell_call_fn=self.cell.__call__, scope=scope) 1170 1171 def get_config(self): 1172 config = { 1173 "cell": { 1174 "class_name": self.cell.__class__.__name__, 1175 "config": self.cell.get_config() 1176 }, 1177 } 1178 base_config = super(_RNNCellWrapperV1, self).get_config() 1179 return dict(list(base_config.items()) + list(config.items())) 1180 1181 @classmethod 1182 def from_config(cls, config, custom_objects=None): 1183 config = config.copy() 1184 cell = config.pop("cell") 1185 try: 1186 assert_like_rnncell("cell", cell) 1187 return cls(cell, **config) 1188 except TypeError: 1189 raise ValueError("RNNCellWrapper cannot reconstruct the wrapped cell. 
" 1190 "Please overwrite the cell in the config with a RNNCell " 1191 "instance.") 1192 1193 1194@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DropoutWrapper"]) 1195@tf_export(v1=["nn.rnn_cell.DropoutWrapper"]) 1196class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase, 1197 _RNNCellWrapperV1): 1198 """Operator adding dropout to inputs and outputs of the given cell.""" 1199 1200 def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation 1201 super(DropoutWrapper, self).__init__(*args, **kwargs) 1202 1203 __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__ 1204 1205 1206@keras_export(v1=["keras.__internal__.legacy.rnn_cell.ResidualWrapper"]) 1207@tf_export(v1=["nn.rnn_cell.ResidualWrapper"]) 1208class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase, 1209 _RNNCellWrapperV1): 1210 """RNNCell wrapper that ensures cell inputs are added to the outputs.""" 1211 1212 def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation 1213 super(ResidualWrapper, self).__init__(*args, **kwargs) 1214 1215 __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__ 1216 1217 1218@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DeviceWrapper"]) 1219@tf_export(v1=["nn.rnn_cell.DeviceWrapper"]) 1220class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase, 1221 _RNNCellWrapperV1): 1222 1223 def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation 1224 super(DeviceWrapper, self).__init__(*args, **kwargs) 1225 1226 __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__ 1227 1228 1229@keras_export(v1=["keras.__internal__.legacy.rnn_cell.MultiRNNCell"]) 1230@tf_export(v1=["nn.rnn_cell.MultiRNNCell"]) 1231class MultiRNNCell(RNNCell): 1232 """RNN cell composed sequentially of multiple simple cells. 1233 1234 Example: 1235 1236 ```python 1237 num_units = [128, 64] 1238 cells = [BasicLSTMCell(num_units=n) for n in num_units] 1239 stacked_rnn_cell = MultiRNNCell(cells) 1240 ``` 1241 """ 1242 1243 def __init__(self, cells, state_is_tuple=True): 1244 """Create a RNN cell composed sequentially of a number of RNNCells. 1245 1246 Args: 1247 cells: list of RNNCells that will be composed in this order. 1248 state_is_tuple: If True, accepted and returned states are n-tuples, where 1249 `n = len(cells)`. If False, the states are all concatenated along the 1250 column axis. This latter behavior will soon be deprecated. 1251 1252 Raises: 1253 ValueError: if cells is empty (not allowed), or at least one of the cells 1254 returns a state tuple but the flag `state_is_tuple` is `False`. 1255 """ 1256 logging.warning("`tf.nn.rnn_cell.MultiRNNCell` is deprecated. This class " 1257 "is equivalent as `tf.keras.layers.StackedRNNCells`, " 1258 "and will be replaced by that in Tensorflow 2.0.") 1259 super(MultiRNNCell, self).__init__() 1260 if not cells: 1261 raise ValueError("Must specify at least one cell for MultiRNNCell.") 1262 if not nest.is_nested(cells): 1263 raise TypeError("cells must be a list or tuple, but saw: %s." % cells) 1264 1265 if len(set(id(cell) for cell in cells)) < len(cells): 1266 logging.log_first_n( 1267 logging.WARN, "At least two cells provided to MultiRNNCell " 1268 "are the same object and will share weights.", 1) 1269 1270 self._cells = cells 1271 for cell_number, cell in enumerate(self._cells): 1272 # Add Trackable dependencies on these cells so their variables get 1273 # saved with this object when using object-based saving. 
1274 if isinstance(cell, trackable.Trackable): 1275 # TODO(allenl): Track down non-Trackable callers. 1276 self._track_trackable(cell, name="cell-%d" % (cell_number,)) 1277 self._state_is_tuple = state_is_tuple 1278 if not state_is_tuple: 1279 if any(nest.is_nested(c.state_size) for c in self._cells): 1280 raise ValueError("Some cells return tuples of states, but the flag " 1281 "state_is_tuple is not set. State sizes are: %s" % 1282 str([c.state_size for c in self._cells])) 1283 1284 @property 1285 def state_size(self): 1286 if self._state_is_tuple: 1287 return tuple(cell.state_size for cell in self._cells) 1288 else: 1289 return sum(cell.state_size for cell in self._cells) 1290 1291 @property 1292 def output_size(self): 1293 return self._cells[-1].output_size 1294 1295 def zero_state(self, batch_size, dtype): 1296 with backend.name_scope(type(self).__name__ + "ZeroState"): 1297 if self._state_is_tuple: 1298 return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells) 1299 else: 1300 # We know here that state_size of each cell is not a tuple and 1301 # presumably does not contain TensorArrays or anything else fancy 1302 return super(MultiRNNCell, self).zero_state(batch_size, dtype) 1303 1304 @property 1305 def trainable_weights(self): 1306 if not self.trainable: 1307 return [] 1308 weights = [] 1309 for cell in self._cells: 1310 if isinstance(cell, base_layer.Layer): 1311 weights += cell.trainable_weights 1312 return weights 1313 1314 @property 1315 def non_trainable_weights(self): 1316 weights = [] 1317 for cell in self._cells: 1318 if isinstance(cell, base_layer.Layer): 1319 weights += cell.non_trainable_weights 1320 if not self.trainable: 1321 trainable_weights = [] 1322 for cell in self._cells: 1323 if isinstance(cell, base_layer.Layer): 1324 trainable_weights += cell.trainable_weights 1325 return trainable_weights + weights 1326 return weights 1327 1328 def call(self, inputs, state): 1329 """Run this multi-layer cell on inputs, starting from state.""" 1330 cur_state_pos = 0 1331 cur_inp = inputs 1332 new_states = [] 1333 for i, cell in enumerate(self._cells): 1334 with vs.variable_scope("cell_%d" % i): 1335 if self._state_is_tuple: 1336 if not nest.is_nested(state): 1337 raise ValueError( 1338 "Expected state to be a tuple of length %d, but received: %s" % 1339 (len(self.state_size), state)) 1340 cur_state = state[i] 1341 else: 1342 cur_state = array_ops.slice(state, [0, cur_state_pos], 1343 [-1, cell.state_size]) 1344 cur_state_pos += cell.state_size 1345 cur_inp, new_state = cell(cur_inp, cur_state) 1346 new_states.append(new_state) 1347 1348 new_states = ( 1349 tuple(new_states) if self._state_is_tuple else array_ops.concat( 1350 new_states, 1)) 1351 1352 return cur_inp, new_states 1353 1354 1355def _check_rnn_cell_input_dtypes(inputs): 1356 """Check whether the input tensors are with supported dtypes. 1357 1358 Default RNN cells only support floats and complex as its dtypes since the 1359 activation function (tanh and sigmoid) only allow those types. This function 1360 will throw a proper error message if the inputs is not in a supported type. 1361 1362 Args: 1363 inputs: tensor or nested structure of tensors that are feed to RNN cell as 1364 input or state. 1365 1366 Raises: 1367 ValueError: if any of the input tensor are not having dtypes of float or 1368 complex. 
1369 """ 1370 for t in nest.flatten(inputs): 1371 _check_supported_dtypes(t.dtype) 1372 1373 1374def _check_supported_dtypes(dtype): 1375 if dtype is None: 1376 return 1377 dtype = dtypes.as_dtype(dtype) 1378 if not (dtype.is_floating or dtype.is_complex): 1379 raise ValueError("RNN cell only supports floating point inputs, " 1380 "but saw dtype: %s" % dtype) 1381
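

# Example usage (an illustrative sketch only, not part of this module's API).
# It assumes TF1-style graph mode with `import tensorflow as tf`; the sizes
# below are arbitrary:
#
#   cells = [BasicLSTMCell(num_units=n) for n in [128, 64]]
#   stacked = MultiRNNCell(cells)
#   inputs = tf.compat.v1.placeholder(tf.float32, [32, 100])
#   state = stacked.zero_state(batch_size=32, dtype=tf.float32)
#   output, new_state = stacked(inputs, state)  # output shaped [32, 64]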