# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras SavedModel serialization.

TODO (kathywu): Move to layer_serialization.py. Some model-specific logic should
go to model_serialization.py.
"""

import functools
import threading
import weakref

from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.mixed_precision import autocast_variable
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import load as keras_load
from tensorflow.python.keras.saving.saved_model import serialized_attributes
from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator


# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
sequential_lib = LazyLoader(
    "sequential_lib", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes


def should_skip_serialization(layer):
  """Skip serializing extra objects and functions if layer inputs aren't set.

  Args:
    layer: Keras Layer object.

  Returns:
    True (after logging a warning) when the layer is not built and, for Models,
    no saved-model input spec has been recorded; False otherwise.
  """
  saved_model_input_spec_set = (isinstance(layer, training_lib.Model) and
                                layer._saved_model_inputs_spec is not None)  # pylint: disable=protected-access
  if not layer.built and not saved_model_input_spec_set:
    logging.warning('Skipping full serialization of Keras layer {}, because '
                    'it is not built.'.format(layer))
    return True
  return False


def wrap_layer_objects(layer, serialization_cache):
  """Returns extra trackable objects to attach to the serialized layer.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all checkpointable objects from a
    SerializedAttributes object. See LayerAttributes and ModelAttributes for
    entire list of objects
  """
  # Wrap all regularization losses as tf.functions.
  # First, generate list of all regularization losses in this layer and
  # sublayers.
  all_losses = layer._callable_losses[:]  # pylint: disable=protected-access
  for child_layer in utils.list_all_layers(layer):
    all_losses.extend(child_layer._callable_losses)  # pylint: disable=protected-access
  # Next, wrap all loss functions as tf.functions. Use the serialization cache
  # to store already-wrapped functions.
  keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
  wrapped_loss_functions = []
  for loss_fn in all_losses:
    if loss_fn in keras_loss_cache:
      wrapped_loss_functions.append(keras_loss_cache[loss_fn])
    else:
      # The cache size doubles as a unique index for the wrapped fn's name.
      wrapped_loss = _wrap_unconditional_loss(loss_fn, len(keras_loss_cache))
      keras_loss_cache[loss_fn] = wrapped_loss
      wrapped_loss_functions.append(wrapped_loss)
  # Every loss of this layer was added to the cache by the loop above.
  wrapped_layer_losses = [keras_loss_cache[fn]
                          for fn in layer._callable_losses]  # pylint: disable=protected-access

  layer_metrics = data_structures.wrap_or_unwrap(
      {m.name: m for m in layer._metrics})  # pylint: disable=protected-access
  return dict(
      variables=data_structures.wrap_or_unwrap(layer.variables),
      trainable_variables=data_structures.wrap_or_unwrap(
          layer.trainable_variables),
      non_trainable_variables=data_structures.wrap_or_unwrap(
          layer.non_trainable_variables),
      layers=data_structures.wrap_or_unwrap(utils.list_all_layers(layer)),
      metrics=data_structures.wrap_or_unwrap(layer.metrics),
      regularization_losses=data_structures.wrap_or_unwrap(
          wrapped_loss_functions),
      layer_regularization_losses=data_structures.wrap_or_unwrap(
          wrapped_layer_losses),
      layer_metrics=layer_metrics)


def wrap_layer_functions(layer, serialization_cache):
  """Returns dict of wrapped layer call function and losses in tf.functions.

  While the functions are being wrapped and traced, the child layers' call
  functions are temporarily replaced and the layer losses are temporarily
  cleared. Both are always restored before this function returns or raises.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all keras tf.functions to serialize. See
    LayerAttributes and ModelAttributes for the list of all attributes.
  """
  # Since Sequential models may be modified in place using model.add() or
  # model.pop(), don't use saved functions.
  if (isinstance(layer, keras_load.RevivedLayer) and
      not isinstance(layer, sequential_lib.Sequential)):
    return {fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions}

  # Reset the losses of the layer and its children. The call function in each
  # child layer is replaced with tf.functions.
  original_fns = _replace_child_layer_functions(layer, serialization_cache)
  original_losses = _reset_layer_losses(layer)

  try:
    # Wrap all the layer call and activity regularizer functions.

    # Use LayerCallCollection to ensure that all layer call functions
    # (__call__, call with losses) are traced with the same inputs.
    call_collection = LayerCallCollection(layer)
    call_fn_with_losses = call_collection.add_function(
        _wrap_call_and_conditional_losses(layer),
        '{}_layer_call_and_return_conditional_losses'.format(layer.name),
        # If any of this layer's child layers use the training arg, the traced
        # call functions of this layer will have a training keyword argument.
        # If the original layer does not expect the training arg, then it will
        # have to be removed (by setting `match_layer_training_arg`).
        match_layer_training_arg=True)
    call_fn = call_collection.add_function(
        _extract_outputs_from_fn(layer, call_fn_with_losses),
        '{}_layer_call_fn'.format(layer.name),
        # Since `call_fn` wraps call_fn_with_losses and not the original call
        # function, `match_layer_training_arg` should be set to False.
        match_layer_training_arg=False)

    fns = {'call_and_return_conditional_losses': call_fn_with_losses,
           '__call__': call_fn}

    if layer._activity_regularizer is not None:  # pylint: disable=protected-access
      fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
      fns['call_and_return_all_conditional_losses'] = (
          call_collection.add_function(
              _append_activity_regularizer_loss(
                  layer, call_fn_with_losses, fns['activity_regularizer_fn']),
              '{}_layer_call_and_return_all_conditional_losses'.format(
                  layer.name),
              match_layer_training_arg=False))
    else:
      fns['activity_regularizer_fn'] = None
      fns['call_and_return_all_conditional_losses'] = call_fn_with_losses

    # Manually trigger traces before restoring the overwritten functions. The
    # functions are traced within the layer call context to ensure that layer
    # functions (e.g. add_loss) behave as though running in graph mode.
    with tracing_scope():
      call_collection.trace_with_input_signature()
      with base_layer_utils.call_context().enter(
          layer, inputs=None, build_graph=True, training=None, saving=True):
        for fn in fns.values():
          if fn is not None and fn.input_signature is not None:
            if isinstance(fn, LayerCall):
              fn = fn.wrapped_call
            fn.get_concrete_function()
  finally:
    # Restore overwritten functions and losses even when wrapping or tracing
    # fails, so the live layer is not left with the temporary replacements
    # and with its losses cleared.
    _restore_child_layer_functions(original_fns)
    _restore_layer_losses(original_losses)

  return fns


def default_save_signature(layer):
  """Traces `layer` with `saving_utils.trace_model_call` and returns the fn.

  The layer losses are cleared while tracing and are restored afterwards, even
  if tracing raises.

  Args:
    layer: Keras Layer or Model object.

  Returns:
    The function returned by `saving_utils.trace_model_call`, after its
    concrete function has been created.
  """
  original_losses = _reset_layer_losses(layer)
  try:
    fn = saving_utils.trace_model_call(layer)
    fn.get_concrete_function()
  finally:
    _restore_layer_losses(original_losses)
  return fn


def _replace_child_layer_functions(layer, serialization_cache):
  """Replaces functions in the children layers with wrapped tf.functions.

  This step allows functions from parent layers to reference the wrapped
  functions from their children layers instead of retracing the ops.

  The replaced attributes are stored in the returned dictionary. Use
  `_restore_child_layer_functions` to restore the original attributes. Layer
  losses are not modified here (see `_reset_layer_losses`). If an error occurs
  partway through, the replacements made so far are undone before re-raising.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    Dictionary mapping child layer objects -> original attributes:
      { Child layer 1: {
          'call': Original call function
          '_activity_regularizer': Original activity regularizer},
        Child metric: {
          '__call__': Original __call__,
          'result': Original result,
          'update_state': Original update_state},
        Child layer 2: ...
      }
  """
  # pylint: disable=protected-access
  original_fns = {}

  def replace_layer_functions(child_layer, serialized_fns):
    """Replaces layer call and activity regularizer with wrapped functions."""
    # Record the originals before mutating so they can always be restored.
    original_fns[child_layer] = {
        'call': child_layer.call,
        '_activity_regularizer': child_layer._activity_regularizer
    }
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      try:
        child_layer._activity_regularizer = serialized_fns.get(
            'activity_regularizer_fn')
      except AttributeError:
        # Some layers have an unsettable activity regularizer.
        pass
      child_layer.call = utils.use_wrapped_call(
          child_layer,
          serialized_fns['call_and_return_conditional_losses'],
          default_training_value=False)

  def replace_metric_functions(child_layer, serialized_fns):
    """Replaces metric functions with wrapped functions."""
    original_fns[child_layer] = {
        '__call__': child_layer.__call__,
        'result': child_layer.result,
        'update_state': child_layer.update_state
    }
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      child_layer.__call__ = serialized_fns['__call__']
      child_layer.result = serialized_fns['result']
      child_layer.update_state = serialized_fns['update_state']

  try:
    for child_layer in utils.list_all_layers(layer):
      if isinstance(child_layer, input_layer.InputLayer):
        continue

      if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
        serialized_functions = (
            child_layer._trackable_saved_model_saver._get_serialized_attributes(
                serialization_cache).functions)
      else:
        serialized_functions = (
            serialization_cache[constants.KERAS_CACHE_KEY][child_layer]
            .functions)
      if not serialized_functions:
        # This indicates either:
        #   - circular dependency, which means the current layer's functions
        #     should be wrapped first.
        #   - Child layer's inputs are not defined, so its functions have not
        #     been wrapped. In this case, no replacement is necessary so move
        #     on to the next child.
        continue

      if isinstance(child_layer, metrics.Metric):
        replace_metric_functions(child_layer, serialized_functions)
      else:
        replace_layer_functions(child_layer, serialized_functions)
  except BaseException:
    # Undo the replacements made so far; the caller never receives
    # `original_fns` on failure and so could not restore them itself.
    _restore_child_layer_functions(original_fns)
    raise

  return original_fns
  # pylint: enable=protected-access


def _restore_child_layer_functions(original_fns):
  """Restores attributes replaced with `_replace_child_layer_functions`."""
  for child_layer, fns in original_fns.items():
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      for fn_name, fn in fns.items():
        try:
          setattr(child_layer, fn_name, fn)  # pylint: disable=protected-access
        except AttributeError:
          pass  # In the case of _activity_regularizer, setting the attribute
          # may be disallowed.


# pylint: disable=protected-access
def _reset_layer_losses(parent_layer):
  """Resets losses of layer and its sublayers, and returns original losses."""
  losses_dict = {}
  for layer in utils.list_all_layers_and_sublayers(parent_layer):
    losses_dict[layer] = {'losses': layer._losses[:],
                          'eager_losses': layer._eager_losses[:]}
    with utils.no_automatic_dependency_tracking_scope(layer):
      layer._losses = []
      layer._eager_losses = []
  return losses_dict


def _restore_layer_losses(losses_dict):
  """Restores losses saved by `_reset_layer_losses`."""
  for layer in losses_dict:
    with utils.no_automatic_dependency_tracking_scope(layer):
      layer._losses = losses_dict[layer]['losses']
      layer._eager_losses = losses_dict[layer]['eager_losses']
# pylint: enable=protected-access


class LayerTracingContext(threading.local):
  """Per-thread state used to queue extra traces of layer call functions."""

  def __init__(self):
    super(LayerTracingContext, self).__init__()
    # Whether `add_trace_to_queue` records traces (set by `tracing_scope`).
    self.enable_call_tracing = False
    # Pending (fn, args, kwargs, training) tuples drained by `tracing_scope`.
    self.trace_queue = []

_thread_local_data = LayerTracingContext()


@tf_contextlib.contextmanager
def tracing_scope():
  """Enables tracing scope.

  While active, `add_trace_to_queue` records traces on the current thread. On
  exit, all queued traces are run and the previous thread-local tracing state
  is restored, even if running a queued trace raises.
  """
  # This enables the LayerCallCollection's tracing mechanism to trace all call
  # functions in the collection.
  previous_value = _thread_local_data.enable_call_tracing
  previous_queue = _thread_local_data.trace_queue
  try:
    _thread_local_data.enable_call_tracing = True
    _thread_local_data.trace_queue = []
    yield
  finally:
    try:
      # Run traces from the queue. Tracing is still enabled here, so any
      # traces queued while draining are also run by this loop.
      while _thread_local_data.trace_queue:
        fn, args, kwargs, training = _thread_local_data.trace_queue.pop()
        if training is not None:
          with K.deprecated_internal_learning_phase_scope(training):
            fn.get_concrete_function(*args, **kwargs)
        else:
          fn.get_concrete_function(*args, **kwargs)
    finally:
      # Always restore the outer state; otherwise a failing trace would leave
      # this thread with call tracing permanently enabled.
      _thread_local_data.trace_queue = previous_queue
      _thread_local_data.enable_call_tracing = previous_value


def add_trace_to_queue(fn, args, kwargs, training=None):
  """Queues a trace of `fn` with copies of `args`/`kwargs` if tracing is on.

  Args:
    fn: Object whose `get_concrete_function` is called when the trace runs.
    args: Sequence of positional args (copied via slicing).
    kwargs: Dict of keyword args (copied).
    training: If not None, the learning phase set while the trace runs.
  """
  if tracing_enabled():
    _thread_local_data.trace_queue.append(
        (fn, args[:], kwargs.copy(), training))


def tracing_enabled():
  """Whether to add extra traces to the queue."""
  return _thread_local_data.enable_call_tracing


class LayerCallCollection(object):
  """Groups wrapped layer call functions.

  This is used to ensure that all layer call functions are traced with the same
  inputs:
    - call
    - call_and_return_conditional_losses
    - call_and_return_all_conditional_losses
  """

  def __init__(self, layer):
    self.layer = layer

    self.layer_call_method = _get_layer_call_method(layer)
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    self._training_arg_index = utils.get_training_arg_index(
        self.layer_call_method)

    # If the layer call function has kwargs, then the traced function cannot
    # have an input signature.
    arg_spec = tf_inspect.getfullargspec(self.layer_call_method)
    self._has_kwargs = bool(self._expects_training_arg or
                            arg_spec.defaults or
                            arg_spec.kwonlyargs or
                            arg_spec.varkw)

    self._input_signature = self._generate_input_signature(layer)
    # Weak values: the LayerCall objects own the wrapped tf.functions.
    self._functions = weakref.WeakValueDictionary()

    # Get the input argument name from the args.
    args = arg_spec.args
    if tf_inspect.ismethod(self.layer_call_method):
      args = args[1:]
    self._input_arg_name = args[0] if args else 'inputs'

  def _generate_input_signature(self, layer):
    """Inspects layer object and returns the inferred input signature.

    Args:
      layer: Layer object.

    Returns:
      List of possibly nested TensorSpecs of the layer call function inputs.
      The list does not contain the `training` argument.
    """
    if (isinstance(layer.call, def_function.Function) and
        layer.call.input_signature is not None):
      return layer.call.input_signature
    elif isinstance(layer, training_lib.Model):
      return saving_utils.model_input_signature(layer)
    elif (layer.input_spec is not None and
          layer._use_input_spec_as_call_signature):  # pylint: disable=protected-access

      def to_tensor_spec_or_none(x):
        spec = input_spec.to_tensor_spec(x, layer._compute_dtype)  # pylint: disable=protected-access
        # If the shape is too general (e.g. multiple dimensions are allowed),
        # return None so that separate functions can be generated for each
        # inferred input signature.
        # TODO(b/134962016): currently partial signatures are not supported.
        if spec.shape == tensor_shape.TensorShape(None):
          return None
        return spec
      input_signature = [nest.map_structure(
          to_tensor_spec_or_none, layer.input_spec)]

      return input_signature
    else:
      return None

  def add_trace(self, *args, **kwargs):
    """Queues traces of all functions with the same args and kwargs.

    When the layer expects a training arg, each function is queued twice, with
    `training=True` and `training=False`.

    Args:
      *args: Positional args passed to the original function.
      **kwargs: Keyword args passed to the original function.
    """
    args = list(args)
    kwargs = kwargs.copy()

    for fn in self._functions.values():
      # TODO(kathywu): Replace arguments with broader shapes defined in the
      # input signature.
      if self._expects_training_arg:
        def trace_with_training(value, fn=fn):
          utils.set_training_arg(value, self._training_arg_index, args, kwargs)
          add_trace_to_queue(fn, args, kwargs, value)

        trace_with_training(True)
        trace_with_training(False)
      else:
        add_trace_to_queue(fn, args, kwargs)

  @property
  def fn_input_signature(self):
    """Returns input signature for the wrapped layer call function."""
    if self._has_kwargs:
      # Input signatures may only describe tensor arguments and kwargs are not
      # supported.
      return None
    if None in nest.flatten(self._input_signature):
      # TODO(b/134962016): If input signature cannot be partially defined.
      return None
    return self._input_signature

  def training_arg_was_passed(self, args, kwargs):
    """Returns whether a `training` value was passed in `args`/`kwargs`."""
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      return (utils.get_training_arg(self._training_arg_index, args, kwargs)
              is not None)
    else:
      return self.layer._call_arg_was_passed(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_training_arg_value(self, args, kwargs):
    """Returns the `training` value found in `args`/`kwargs`."""
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      return utils.get_training_arg(self._training_arg_index, args, kwargs)
    else:
      return self.layer._get_call_arg_value(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_input_arg_value(self, args, kwargs):
    """Returns the value of the layer call's first input argument."""
    return self.layer._get_call_arg_value(  # pylint: disable=protected-access
        self._input_arg_name, args, kwargs, inputs_in_args=True)

  def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
    """Wraps call function with added training argument if necessary."""
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      # Add training arg to wrapper function.
      arg_spec = tf_inspect.getfullargspec(call_fn)
      args = arg_spec.args + ['training']
      defaults = list(arg_spec.defaults or [])
      defaults.append(False)
      new_arg_spec = tf_inspect.FullArgSpec(
          args=args,
          varargs=arg_spec.varargs,
          varkw=arg_spec.varkw,
          defaults=defaults,
          kwonlyargs=arg_spec.kwonlyargs,
          kwonlydefaults=arg_spec.kwonlydefaults,
          annotations=arg_spec.annotations)

      # Set new training arg index
      self._training_arg_index = len(args) - 1
      if tf_inspect.ismethod(call_fn):
        self._training_arg_index -= 1

      def wrap_with_training_arg(*args, **kwargs):
        if match_layer_training_arg:
          # Remove the training value, since the original call_fn does not
          # expect a training arg. Instead, the training value will be
          # propagated using the call context created in LayerCall.
          args = list(args)
          kwargs = kwargs.copy()
          utils.remove_training_arg(self._training_arg_index, args, kwargs)
        return call_fn(*args, **kwargs)

      return tf_decorator.make_decorator(
          target=call_fn,
          decorator_func=wrap_with_training_arg,
          decorator_argspec=new_arg_spec)

    return call_fn

  def add_function(self, call_fn, name, match_layer_training_arg):
    """Adds a layer call function to the collection.

    Args:
      call_fn: a python function
      name: Name of call function
      match_layer_training_arg: If True, removes the `training` from the
        function arguments when calling `call_fn`.

    Returns:
      LayerCall (tf.function)
    """
    fn = LayerCall(
        self,
        self._maybe_wrap_with_training_arg(call_fn, match_layer_training_arg),
        name,
        input_signature=self.fn_input_signature)
    self._functions[name] = fn.wrapped_call
    return fn

  def trace_with_input_signature(self):
    """Trace with the layer/models inferred input signature if possible."""
    if (None not in nest.flatten(self._input_signature) and self._has_kwargs):
      # Manually add traces for layers that have keyword arguments and have
      # a fully defined input signature.
      self.add_trace(*self._input_signature)


def _filtered_inputs(inputs):
  """Returns the tensors and variables in the flattened `inputs` structure."""
  return list(filter(tf_utils.is_tensor_or_variable, nest.flatten(inputs)))


def layer_call_wrapper(call_collection, method, name):
  """Ensures layer losses are kept the same, and runs method in call context."""

  # Create wrapper that deals with losses and call context.
  def wrapper(*args, **kwargs):
    """Calls method within call context."""
    layer = call_collection.layer
    training = None
    inputs = _filtered_inputs([args, kwargs])
    if (args or kwargs) and call_collection.training_arg_was_passed(
        args, kwargs):
      training = call_collection.get_training_arg_value(args, kwargs)
    original_losses = _reset_layer_losses(layer)
    try:
      with base_layer_utils.call_context().enter(
          layer, inputs=inputs, build_graph=False, training=training,
          saving=True):
        with autocast_variable.enable_auto_cast_variables(
            layer._compute_dtype_object):  # pylint: disable=protected-access
          return method(*args, **kwargs)
    finally:
      # Restore losses even if `method` raises, so the layer keeps them.
      _restore_layer_losses(original_losses)

  # Rename to `name`, since tf.function doesn't have a name argument. Without
  # this, all functions returned by this method will be named "call", which
  # would be a nightmare to debug.
  fn = tf_decorator.make_decorator(target=method, decorator_func=wrapper)
  fn.__name__ = name
  return fn


class LayerCall(object):
  """Function that triggers traces of other functions in the same collection."""

  def __init__(self, call_collection, call_fn, name, input_signature):
    """Initializes a LayerCall object.

    Args:
      call_collection: a LayerCallCollection, which contains the other layer
        call functions (e.g. call_with_conditional_losses, call). These
        functions should be traced with the same arguments.
      call_fn: A call function.
      name: Name of the call function.
      input_signature: Input signature of call_fn (can be None).
    """
    self.call_collection = call_collection
    self.input_signature = input_signature
    self.wrapped_call = def_function.function(
        layer_call_wrapper(call_collection, call_fn, name),
        input_signature=input_signature)
    self.original_layer_call = call_collection.layer_call_method

  def _maybe_trace(self, args, kwargs):
    # Trigger traces of other call functions + extra training-arg traces.
    if tracing_enabled():
      self.call_collection.add_trace(*args, **kwargs)

  def __call__(self, *args, **kwargs):
    self._maybe_trace(args, kwargs)
    return self.wrapped_call(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    self._maybe_trace(args, kwargs)
    return self.wrapped_call.get_concrete_function(*args, **kwargs)


def _wrap_call_and_conditional_losses(layer):
  """Wraps call function that returns a tuple of (outputs, losses).

  The losses returned are conditional on the inputs passed to the call function.
  Unconditional losses (e.g. weight regularization) are wrapped separately.

  Args:
    layer: a Keras layer object

  Returns:
    python call function that returns outputs and conditional losses -- excludes
    activity regularizer
  """
  # Create function that generates both outputs and losses
  layer_call = _get_layer_call_method(layer)
  def call_and_return_conditional_losses(*args, **kwargs):
    """Returns layer (call_output, conditional losses) tuple."""
    call_output = layer_call(*args, **kwargs)
    if version_utils.is_v1_layer_or_model(layer):
      conditional_losses = layer.get_losses_for(
          _filtered_inputs([args, kwargs]))
    else:
      conditional_losses = [
          l for l in layer.losses if not hasattr(l, '_unconditional_loss')
      ]
    return call_output, conditional_losses

  return _create_call_fn_decorator(layer, call_and_return_conditional_losses)


def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
  """Returns a function that returns only call function outputs."""
  if isinstance(layer, keras_load.RevivedLayer):
    return layer.keras_api.__call__  # pylint: disable=protected-access
  def call(inputs, *args, **kwargs):
    return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
  return _create_call_fn_decorator(layer, call)


def _append_activity_regularizer_loss(
    layer, call_fn_with_losses, activity_regularizer_fn):
  """Appends activity regularizer loss to losses returned by the wrapped fn."""
  def fn(inputs, *args, **kwargs):
    outputs, losses = call_fn_with_losses(inputs, *args, **kwargs)
    losses.append(activity_regularizer_fn(outputs))
    return outputs, losses
  return _create_call_fn_decorator(layer, fn)


def _create_call_fn_decorator(layer, wrapped_call):
  """Decorates `wrapped_call` with the layer call method's argspec.

  A `training` argument (default False) may be added by
  `utils.maybe_add_training_arg` based on `layer._expects_training_arg`.
  """
  call_fn = _get_layer_call_method(layer)
  fn, arg_spec = utils.maybe_add_training_arg(
      call_fn, wrapped_call, layer._expects_training_arg,  # pylint: disable=protected-access
      default_training_value=False)
  return tf_decorator.make_decorator(
      target=call_fn,
      decorator_func=fn,
      decorator_argspec=arg_spec)


def _wrap_unconditional_loss(loss_fn, index):
  """Wraps callable/unconditional loss, returning a serializable function."""
  # Extract original loss function from partial function
  fn = loss_fn.args[0] if isinstance(loss_fn, functools.partial) else loss_fn
  if isinstance(fn, def_function.Function):
    return fn
  else:
    return def_function.Function(
        fn, 'loss_fn_{}'.format(index), input_signature=[])


def _wrap_activity_regularizer(layer):
  """Wraps the activity regularizer."""
  # pylint: disable=protected-access
  if isinstance(layer._activity_regularizer, def_function.Function):
    return layer._activity_regularizer
  return def_function.Function(
      layer._activity_regularizer,
      '{}_activity_regularizer'.format(layer.name),
      input_signature=[
          tensor_spec.TensorSpec(None, layer._compute_dtype or K.floatx())
      ])
  # pylint: enable=protected-access


def _get_layer_call_method(layer):
  """Returns `layer.call`, unwrapped to its python function if a tf.function."""
  if isinstance(layer.call, def_function.Function):
    return layer.call.python_function
  return layer.call