# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module contains the implementation of RNN cell wrappers."""
import hashlib
import numbers
import sys
import types as python_types
import warnings

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest


class DropoutWrapperBase(object):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self,
               cell,
               input_keep_prob=1.0,
               output_keep_prob=1.0,
               state_keep_prob=1.0,
               variational_recurrent=False,
               input_size=None,
               dtype=None,
               seed=None,
               dropout_state_filter_visitor=None,
               **kwargs):
    """Create a cell with added input, state, and/or output dropout.

    If `variational_recurrent` is set to `True` (**NOT** the default behavior),
    then the same dropout mask is applied at every step, as described in:
    [A Theoretically Grounded Application of Dropout in Recurrent
    Neural Networks. Y. Gal, Z. Ghahramani](https://arxiv.org/abs/1512.05287).

    Otherwise a different dropout mask is applied at every time step.

    Note, by default (unless a custom `dropout_state_filter_visitor` is
    provided), the memory state (`c` component of any `LSTMStateTuple`) passing
    through a `DropoutWrapper` is never modified. This behavior is described in
    the above article.
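
    For example, a minimal usage sketch (assuming TensorFlow is imported as
    `tf` and this base is used through a concrete wrapper such as
    `tf.compat.v1.nn.rnn_cell.DropoutWrapper`):

    ```
    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(64)
    cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
        cell, input_keep_prob=0.9, output_keep_prob=0.9, state_keep_prob=0.9)
    ```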

    Args:
      cell: an RNNCell.
      input_keep_prob: unit Tensor or float between 0 and 1, input keep
        probability; if it is constant and 1, no input dropout will be added.
      output_keep_prob: unit Tensor or float between 0 and 1, output keep
        probability; if it is constant and 1, no output dropout will be added.
      state_keep_prob: unit Tensor or float between 0 and 1, state keep
        probability; if it is constant and 1, no state dropout will be added.
        State dropout is performed on the outgoing states of the cell. **Note**
        the state components to which dropout is applied when `state_keep_prob`
        is in `(0, 1)` are also determined by the argument
        `dropout_state_filter_visitor` (e.g. by default dropout is never
        applied to the `c` component of an `LSTMStateTuple`).
      variational_recurrent: Python bool. If `True`, then the same dropout
        pattern is applied across all time steps per run call. If this
        parameter is set, `input_size` **must** be provided.
      input_size: (optional) (possibly nested tuple of) `TensorShape` objects
        containing the depth(s) of the input tensors expected to be passed in
        to the `DropoutWrapper`. Required and used **iff**
        `variational_recurrent = True` and `input_keep_prob < 1`.
      dtype: (optional) The `dtype` of the input, state, and output tensors.
        Required and used **iff** `variational_recurrent = True`.
      seed: (optional) integer, the randomness seed.
      dropout_state_filter_visitor: (optional), default: (see below). Function
        that takes any hierarchical level of the state and returns a scalar or
        depth=1 structure of Python booleans describing which terms in the
        state should be dropped out. In addition, if the function returns
        `True`, dropout is applied across this sublevel. If the function
        returns `False`, dropout is not applied across this entire sublevel.
        Default behavior: perform dropout on all terms except the memory (`c`)
        state of `LSTMStateTuple` objects, and don't try to apply dropout to
        `TensorArray` objects:
        ```
        def dropout_state_filter_visitor(s):
          if isinstance(s, LSTMStateTuple):
            # Never perform dropout on the c state.
            return LSTMStateTuple(c=False, h=True)
          elif isinstance(s, TensorArray):
            return False
          return True
        ```
      **kwargs: dict of keyword arguments for base layer.

    Raises:
      TypeError: if `cell` is not an `RNNCell`, or if
        `dropout_state_filter_visitor` is provided but not `callable`.
      ValueError: if any of the keep_probs are not between 0 and 1.
    """
    super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs)

    if (dropout_state_filter_visitor is not None and
        not callable(dropout_state_filter_visitor)):
      raise TypeError("dropout_state_filter_visitor must be callable")
    self._dropout_state_filter = (
        dropout_state_filter_visitor or _default_dropout_state_filter_visitor)
    with ops.name_scope_v2("DropoutWrapperInit"):

      def tensor_and_const_value(v):
        tensor_value = ops.convert_to_tensor_v2_with_dispatch(v)
        const_value = tensor_util.constant_value(tensor_value)
        return (tensor_value, const_value)

      for prob, attr in [(input_keep_prob, "input_keep_prob"),
                         (state_keep_prob, "state_keep_prob"),
                         (output_keep_prob, "output_keep_prob")]:
        tensor_prob, const_prob = tensor_and_const_value(prob)
        if const_prob is not None:
          if const_prob < 0 or const_prob > 1:
            raise ValueError("Parameter %s must be between 0 and 1: %g" %
                             (attr, const_prob))
          setattr(self, "_%s" % attr, float(const_prob))
        else:
          setattr(self, "_%s" % attr, tensor_prob)

    # Set variational_recurrent and seed before running the code below.
    self._variational_recurrent = variational_recurrent
    self._input_size = input_size
    self._seed = seed

    self._recurrent_input_noise = None
    self._recurrent_state_noise = None
    self._recurrent_output_noise = None

    if variational_recurrent:
      if dtype is None:
        raise ValueError(
            "When variational_recurrent=True, dtype must be provided")

      def convert_to_batch_shape(s):
        # Prepend a 1 for the batch dimension; for recurrent
        # variational dropout we use the same dropout mask for all
        # batch elements.
        return array_ops.concat(([1], tensor_shape.TensorShape(s).as_list()),
                                 0)

      def batch_noise(s, inner_seed):
        shape = convert_to_batch_shape(s)
        return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)

      if (not isinstance(self._input_keep_prob, numbers.Real) or
          self._input_keep_prob < 1.0):
        if input_size is None:
          raise ValueError(
              "When variational_recurrent=True and input_keep_prob < 1.0 or "
              "is unknown, input_size must be provided")
        self._recurrent_input_noise = _enumerated_map_structure_up_to(
            input_size,
            lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
            input_size)
      self._recurrent_state_noise = _enumerated_map_structure_up_to(
          cell.state_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
          cell.state_size)
      self._recurrent_output_noise = _enumerated_map_structure_up_to(
          cell.output_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
          cell.output_size)

  def _gen_seed(self, salt_prefix, index):
    if self._seed is None:
      return None
    salt = "%s_%d" % (salt_prefix, index)
    string = (str(self._seed) + salt).encode("utf-8")
    return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF

  @property
  def wrapped_cell(self):
    return self.cell

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def build(self, inputs_shape):
    self.cell.build(inputs_shape)
    self.built = True

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _variational_recurrent_dropout_value(
      self, unused_index, value, noise, keep_prob):
    """Performs dropout given the pre-calculated noise tensor."""
    # uniform [keep_prob, 1.0 + keep_prob)
    random_tensor = keep_prob + noise

    # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
    binary_tensor = math_ops.floor(random_tensor)
    ret = math_ops.divide(value, keep_prob) * binary_tensor
    ret.set_shape(value.get_shape())
    return ret
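
  # Worked example (comment only, not executed): with keep_prob = 0.75 and a
  # noise sample u drawn uniformly from [0, 1), random_tensor = 0.75 + u lies
  # in [0.75, 1.75), so floor(random_tensor) is 1 with probability 0.75 and 0
  # otherwise; kept values are rescaled by 1 / 0.75 so their expected
  # magnitude is unchanged.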

  def _dropout(self,
               values,
               salt_prefix,
               recurrent_noise,
               keep_prob,
               shallow_filtered_substructure=None):
    """Decides whether to perform standard dropout or recurrent dropout."""

    if shallow_filtered_substructure is None:
      # Put something so we traverse the entire structure; inside the
      # dropout function we check to see if leaves of this are bool or not.
      shallow_filtered_substructure = values

    if not self._variational_recurrent:

      def dropout(i, do_dropout, v):
        if not isinstance(do_dropout, bool) or do_dropout:
          return nn_ops.dropout_v2(
              v, rate=1. - keep_prob, seed=self._gen_seed(salt_prefix, i))
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values])
    else:

      def dropout(i, do_dropout, v, n):
        if not isinstance(do_dropout, bool) or do_dropout:
          return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values, recurrent_noise])

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the wrapped cell and applies dropout.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """

    def _should_dropout(p):
      return (not isinstance(p, float)) or p < 1

    if _should_dropout(self._input_keep_prob):
      inputs = self._dropout(inputs, "input", self._recurrent_input_noise,
                             self._input_keep_prob)
    output, new_state = cell_call_fn(inputs, state, **kwargs)
    if _should_dropout(self._state_keep_prob):
      # Identify which subsets of the state to perform dropout on and
      # which ones to keep.
      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
          self._dropout_state_filter, new_state)
      new_state = self._dropout(new_state, "state", self._recurrent_state_noise,
                                self._state_keep_prob,
                                shallow_filtered_substructure)
    if _should_dropout(self._output_keep_prob):
      output = self._dropout(output, "output", self._recurrent_output_noise,
                             self._output_keep_prob)
    return output, new_state

  def get_config(self):
    """Returns the config of the dropout wrapper."""
    config = {
        "input_keep_prob": self._input_keep_prob,
        "output_keep_prob": self._output_keep_prob,
        "state_keep_prob": self._state_keep_prob,
        "variational_recurrent": self._variational_recurrent,
        "input_size": self._input_size,
        "seed": self._seed,
    }
    if self._dropout_state_filter != _default_dropout_state_filter_visitor:
      function, function_type, function_module = _serialize_function_to_config(
          self._dropout_state_filter)
      config.update({"dropout_fn": function,
                     "dropout_fn_type": function_type,
                     "dropout_fn_module": function_module})
    base_config = super(DropoutWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "dropout_fn" in config:
      config = config.copy()
      dropout_state_filter = _parse_config_to_function(
          config, custom_objects, "dropout_fn", "dropout_fn_type",
          "dropout_fn_module")
      config.pop("dropout_fn")
      config["dropout_state_filter_visitor"] = dropout_state_filter
    return super(DropoutWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)


class ResidualWrapperBase(object):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, cell, residual_fn=None, **kwargs):
    """Constructs a `ResidualWrapper` for `cell`.
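
    For example, a minimal usage sketch (assuming TensorFlow is imported as
    `tf` and this base is used through a concrete wrapper such as
    `tf.compat.v1.nn.rnn_cell.ResidualWrapper`; with the default residual
    function the wrapped cell's output size must match its input depth):

    ```
    cell = tf.compat.v1.nn.rnn_cell.GRUCell(64)
    cell = tf.compat.v1.nn.rnn_cell.ResidualWrapper(cell)
    ```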

    Args:
      cell: An instance of `RNNCell`.
      residual_fn: (Optional) The function used to map the raw cell inputs and
        raw cell outputs to the actual cell outputs of the residual network.
        Defaults to `nest.map_structure(lambda i, o: i + o, inputs, outputs)`.
      **kwargs: dict of keyword arguments for base layer.
    """
    super(ResidualWrapperBase, self).__init__(cell, **kwargs)
    self._residual_fn = residual_fn

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the wrapped cell and applies the residual function to its output.

    Args:
      inputs: cell inputs.
      state: cell state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = cell_call_fn(inputs, state, **kwargs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    def default_residual_fn(inputs, outputs):
      nest.assert_same_structure(inputs, outputs)
      nest.map_structure(assert_shape_match, inputs, outputs)
      return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)

    res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
    return (res_outputs, new_state)

  def get_config(self):
    """Returns the config of the residual wrapper."""
    if self._residual_fn is not None:
      function, function_type, function_module = _serialize_function_to_config(
          self._residual_fn)
      config = {
          "residual_fn": function,
          "residual_fn_type": function_type,
          "residual_fn_module": function_module
      }
    else:
      config = {}
    base_config = super(ResidualWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "residual_fn" in config:
      config = config.copy()
      residual_function = _parse_config_to_function(config, custom_objects,
                                                    "residual_fn",
                                                    "residual_fn_type",
                                                    "residual_fn_module")
      config["residual_fn"] = residual_function
    return super(ResidualWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)


class DeviceWrapperBase(object):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, cell, device, **kwargs):
    """Construct a `DeviceWrapper` for `cell` with device `device`.

    Ensures the wrapped `cell` is called with `tf.device(device)`.
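
    For example, a minimal usage sketch (assuming TensorFlow is imported as
    `tf` and this base is used through a concrete wrapper such as
    `tf.compat.v1.nn.rnn_cell.DeviceWrapper`):

    ```
    cell = tf.compat.v1.nn.rnn_cell.GRUCell(64)
    cell = tf.compat.v1.nn.rnn_cell.DeviceWrapper(cell, "/device:GPU:0")
    ```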

    Args:
      cell: An instance of `RNNCell`.
      device: A device string or function, for passing to `tf.device`.
      **kwargs: dict of keyword arguments for base layer.
    """
    super(DeviceWrapperBase, self).__init__(cell, **kwargs)
    self._device = device

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      with ops.device(self._device):
        return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Run the cell on the specified device."""
    with ops.device(self._device):
      return cell_call_fn(inputs, state, **kwargs)

  def get_config(self):
    config = {"device": self._device}
    base_config = super(DeviceWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


def _serialize_function_to_config(function):
  """Serialize the function for get_config()."""
  if isinstance(function, python_types.LambdaType):
    output = generic_utils.func_dump(function)
    output_type = "lambda"
    module = function.__module__
  elif callable(function):
    output = function.__name__
    output_type = "function"
    module = function.__module__
  else:
    raise ValueError("Unrecognized function type for input: {}".format(
        type(function)))

  return output, output_type, module


def _parse_config_to_function(config, custom_objects, func_attr_name,
                              func_type_attr_name, module_attr_name):
  """Reconstruct the function from the config."""
  globs = globals()
  module = config.pop(module_attr_name, None)
  if module in sys.modules:
    globs.update(sys.modules[module].__dict__)
  elif module is not None:
    # Note: we don't know the name of the function if it's a lambda.
    warnings.warn("{} is not loaded, but a layer uses it. "
                  "It may cause errors.".format(module), UserWarning)
  if custom_objects:
    globs.update(custom_objects)
  function_type = config.pop(func_type_attr_name)
  if function_type == "function":
    # Simple lookup in custom objects
    function = generic_utils.deserialize_keras_object(
        config[func_attr_name],
        custom_objects=custom_objects,
        printable_module_name="function in wrapper")
  elif function_type == "lambda":
    # Unsafe deserialization from bytecode
    function = generic_utils.func_load(
        config[func_attr_name], globs=globs)
  else:
    raise TypeError("Unknown function type: {}".format(function_type))
  return function


def _default_dropout_state_filter_visitor(substate):
  """Default state filter: skip dropout for the LSTM `c` state and TensorArrays."""
  from tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl import LSTMStateTuple  # pylint: disable=g-import-not-at-top
  if isinstance(substate, LSTMStateTuple):
    # Do not perform dropout on the memory state.
    return LSTMStateTuple(c=False, h=True)
  elif isinstance(substate, tensor_array_ops.TensorArray):
    return False
  return True


def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args,
                                    **kwargs):
  """Maps `map_fn` over the structures, passing a running leaf index first."""
  ix = [0]

  def enumerated_fn(*inner_args, **inner_kwargs):
    r = map_fn(ix[0], *inner_args, **inner_kwargs)
    ix[0] += 1
    return r

  return nest.map_structure_up_to(shallow_structure, enumerated_fn, *args,
                                  **kwargs)
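
# Illustrative sketch (comment only, not executed):
# `_enumerated_map_structure_up_to` threads a running leaf index into
# `map_fn`, so
#   _enumerated_map_structure_up_to((0, 0), lambda i, v: v * 10 + i, (1, 2))
# returns (10, 21): the first leaf gets index 0 and the second gets index 1.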