# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras SavedModel serialization.

TODO(kathywu): Move to layer_serialization.py. Some model-specific logic should
go to model_serialization.py.
"""

import functools
import threading
import weakref

from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.mixed_precision import autocast_variable
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import load as keras_load
from tensorflow.python.keras.saving.saved_model import serialized_attributes
from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator


# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
sequential_lib = LazyLoader(
    "sequential_lib", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes


def should_skip_serialization(layer):
  """Skip serializing extra objects and functions if layer inputs aren't set."""
  saved_model_input_spec_set = (isinstance(layer, training_lib.Model) and
                                layer._saved_model_inputs_spec is not None)  # pylint: disable=protected-access
  if not layer.built and not saved_model_input_spec_set:
    logging.warning('Skipping full serialization of Keras layer {}, because '
                    'it is not built.'.format(layer))
    return True
  return False
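
# For illustration (a hedged sketch, not part of the original file): a layer
# that was never built is skipped, while a built layer is fully serialized.
#   dense = tf.keras.layers.Dense(4)    # not yet built
#   should_skip_serialization(dense)    # -> True, and logs a warning
#   dense.build((None, 8))
#   should_skip_serialization(dense)    # -> False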


def wrap_layer_objects(layer, serialization_cache):
  """Returns extra trackable objects to attach to the serialized layer.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all checkpointable objects from a
    SerializedAttributes object. See LayerAttributes and ModelAttributes for
    the entire list of objects.
94  """
95  # Wrap all regularization losses as tf.functions.
96  # First, generate list of all regularization losses in this layer and
97  # sublayers.
98  all_losses = layer._callable_losses[:]  # pylint: disable=protected-access
99  for child_layer in utils.list_all_layers(layer):
100    all_losses.extend(child_layer._callable_losses)  # pylint: disable=protected-access
101  # Next, wrap all loss functions as tf.functions. Use the serialization cache
102  # to store already-wrapped functions.
103  keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
104  wrapped_loss_functions = []
105  for loss_fn in all_losses:
106    if loss_fn in keras_loss_cache:
107      wrapped_loss_functions.append(keras_loss_cache[loss_fn])
108    else:
109      wrapped_loss = _wrap_unconditional_loss(loss_fn, len(keras_loss_cache))
110      keras_loss_cache[loss_fn] = wrapped_loss
111      wrapped_loss_functions.append(wrapped_loss)
112  wrapped_layer_losses = [keras_loss_cache[fn]
113                          for fn in layer._callable_losses[:]]  # pylint: disable=protected-access
114
115  layer_metrics = data_structures.wrap_or_unwrap(
116      {m.name: m for m in layer._metrics})  # pylint: disable=protected-access
117  return dict(
118      variables=data_structures.wrap_or_unwrap(layer.variables),
119      trainable_variables=data_structures.wrap_or_unwrap(
120          layer.trainable_variables),
121      non_trainable_variables=data_structures.wrap_or_unwrap(
122          layer.non_trainable_variables),
123      layers=data_structures.wrap_or_unwrap(utils.list_all_layers(layer)),
124      metrics=data_structures.wrap_or_unwrap(layer.metrics),
125      regularization_losses=data_structures.wrap_or_unwrap(
126          wrapped_loss_functions),
127      layer_regularization_losses=data_structures.wrap_or_unwrap(
128          wrapped_layer_losses),
129      layer_metrics=layer_metrics)
130  # pylint: disable=protected-access
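
# Shape of the mapping returned above (a sketch; each value is wrapped as a
# trackable data structure):
#   {'variables': [...], 'trainable_variables': [...],
#    'non_trainable_variables': [...], 'layers': [...], 'metrics': [...],
#    'regularization_losses': [wrapped tf.functions, incl. sublayer losses],
#    'layer_regularization_losses': [wrapped tf.functions, this layer only],
#    'layer_metrics': {metric name: metric object for this layer}}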


def wrap_layer_functions(layer, serialization_cache):
  """Returns dict of wrapped layer call functions and losses as tf.functions.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all keras tf.functions to serialize. See
    LayerAttributes and ModelAttributes for the list of all attributes.
  """
  # Since Sequential models may be modified in place using model.add() or
  # model.pop(), don't use saved functions.
  if (isinstance(layer, keras_load.RevivedLayer) and
      not isinstance(layer, sequential_lib.Sequential)):
    return {fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions}

  # Reset the losses of the layer and its children. The call function in each
  # child layer is replaced with tf.functions.
  original_fns = _replace_child_layer_functions(layer, serialization_cache)
  original_losses = _reset_layer_losses(layer)

  # Wrap all the layer call and activity regularizer functions.

  # Use LayerCallCollection to ensure that all layer call functions (__call__,
  # call with losses) are traced with the same inputs.
  call_collection = LayerCallCollection(layer)
  call_fn_with_losses = call_collection.add_function(
      _wrap_call_and_conditional_losses(layer),
      '{}_layer_call_and_return_conditional_losses'.format(layer.name),
      # If any of this layer's child layers use the training arg, the traced
      # call functions of this layer will have a training keyword argument. If
      # the original layer does not expect the training arg, then it will have
      # to be removed (by setting `match_layer_training_arg`).
      match_layer_training_arg=True)
  call_fn = call_collection.add_function(
      _extract_outputs_from_fn(layer, call_fn_with_losses),
      '{}_layer_call_fn'.format(layer.name),
      # Since `call_fn` wraps call_fn_with_losses and not the original call
      # function, `match_layer_training_arg` should be set to False.
      match_layer_training_arg=False)

  fns = {'call_and_return_conditional_losses': call_fn_with_losses,
         '__call__': call_fn}

  if layer._activity_regularizer is not None:  # pylint: disable=protected-access
    fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
    fns['call_and_return_all_conditional_losses'] = (
        call_collection.add_function(
            _append_activity_regularizer_loss(
                layer, call_fn_with_losses, fns['activity_regularizer_fn']),
            '{}_layer_call_and_return_all_conditional_losses'.format(
                layer.name),
            match_layer_training_arg=False))
  else:
    fns['activity_regularizer_fn'] = None
    fns['call_and_return_all_conditional_losses'] = call_fn_with_losses

  # Manually trigger traces before restoring the overwritten functions. The
  # functions are traced within the layer call context to ensure that layer
  # functions (e.g. add_loss) behave as though running in graph mode.
  with tracing_scope():
    call_collection.trace_with_input_signature()
    with base_layer_utils.call_context().enter(
        layer, inputs=None, build_graph=True, training=None, saving=True):
      for fn in fns.values():
        if fn is not None and fn.input_signature is not None:
          if isinstance(fn, LayerCall):
            fn = fn.wrapped_call
          fn.get_concrete_function()

  # Restore overwritten functions and losses.
  _restore_child_layer_functions(original_fns)
  _restore_layer_losses(original_losses)

  return fns
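
# A note on the returned dict (derived from the code above): '__call__' and
# 'call_and_return_conditional_losses' are always LayerCall objects, while
# 'activity_regularizer_fn' and 'call_and_return_all_conditional_losses' are
# distinct functions only when the layer has an activity regularizer;
# otherwise they are None and an alias of call_fn_with_losses, respectively.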


def default_save_signature(layer):
  original_losses = _reset_layer_losses(layer)
  fn = saving_utils.trace_model_call(layer)
  fn.get_concrete_function()
  _restore_layer_losses(original_losses)
  return fn
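
# `default_save_signature` concretizes `saving_utils.trace_model_call` with
# the layer's losses temporarily cleared, so tracing the default serving
# signature does not leave symbolic losses behind on the layer.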


def _replace_child_layer_functions(layer, serialization_cache):
  """Replaces functions in the children layers with wrapped tf.functions.

  This step allows functions from parent layers to reference the wrapped
  functions from their children layers instead of retracing the ops.

  This function also resets all losses stored in the layer. These are stored in
  the returned dictionary. Use `_restore_child_layer_functions` to restore
  the original attributes.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    Dictionary mapping layer objects -> original functions and losses:
      { Child layer 1: {
          'losses': Original losses,
          'call': Original call function,
          '_activity_regularizer': Original activity regularizer},
        Child layer 2: ...
      }
  """
  # pylint: disable=protected-access
  original_fns = {}

  def replace_layer_functions(child_layer, serialized_fns):
    """Replaces layer call and activity regularizer with wrapped functions."""
    original_fns[child_layer] = {
        'call': child_layer.call,
        '_activity_regularizer': child_layer._activity_regularizer
    }
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      try:
        child_layer._activity_regularizer = serialized_fns.get(
            'activity_regularizer_fn')
      except AttributeError:
        # Some layers have an unsettable activity regularizer.
        pass
      child_layer.call = utils.use_wrapped_call(
          child_layer,
          serialized_fns['call_and_return_conditional_losses'],
          default_training_value=False)

  def replace_metric_functions(child_layer, serialized_fns):
    """Replaces metric functions with wrapped functions."""
    original_fns[child_layer] = {
        '__call__': child_layer.__call__,
        'result': child_layer.result,
        'update_state': child_layer.update_state
    }
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      child_layer.__call__ = serialized_fns['__call__']
      child_layer.result = serialized_fns['result']
      child_layer.update_state = serialized_fns['update_state']

  for child_layer in utils.list_all_layers(layer):
    if isinstance(child_layer, input_layer.InputLayer):
      continue

    if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
      serialized_functions = (
          child_layer._trackable_saved_model_saver._get_serialized_attributes(
              serialization_cache).functions)
    else:
      serialized_functions = (
          serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions)
    if not serialized_functions:
      # This indicates either:
      #   - A circular dependency, which means the current layer's functions
      #     should be wrapped first.
      #   - The child layer's inputs are not defined, so its functions have not
      #     been wrapped. In this case, no replacement is necessary, so move on
      #     to the next child.
      continue

    if isinstance(child_layer, metrics.Metric):
      replace_metric_functions(child_layer, serialized_functions)
    else:
      replace_layer_functions(child_layer, serialized_functions)

  return original_fns
  # pylint: enable=protected-access
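
# Typical pairing (a condensed sketch of the call site in
# `wrap_layer_functions` above):
#   original_fns = _replace_child_layer_functions(layer, serialization_cache)
#   ...  # wrap and trace this layer's functions; children now use wrapped fns
#   _restore_child_layer_functions(original_fns)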


def _restore_child_layer_functions(original_fns):
  """Restores attributes replaced with `_replace_child_layer_functions`."""
  for child_layer, fns in original_fns.items():
    with utils.no_automatic_dependency_tracking_scope(child_layer):
      for fn_name, fn in fns.items():
        try:
          setattr(child_layer, fn_name, fn)  # pylint: disable=protected-access
        except AttributeError:
          # In the case of _activity_regularizer, setting the attribute may be
          # disallowed.
          pass


# pylint: disable=protected-access
def _reset_layer_losses(parent_layer):
  """Resets losses of layer and its sublayers, and returns original losses."""
  losses_dict = {}
  for layer in utils.list_all_layers_and_sublayers(parent_layer):
    losses_dict[layer] = {'losses': layer._losses[:],
                          'eager_losses': layer._eager_losses[:]}
    with utils.no_automatic_dependency_tracking_scope(layer):
      layer._losses = []
      layer._eager_losses = []
  return losses_dict


def _restore_layer_losses(losses_dict):
  for layer in losses_dict:
    with utils.no_automatic_dependency_tracking_scope(layer):
      layer._losses = losses_dict[layer]['losses']
      layer._eager_losses = losses_dict[layer]['eager_losses']
# pylint: enable=protected-access
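
# The two helpers above are symmetric (a usage sketch):
#   snapshot = _reset_layer_losses(layer)   # clears _losses/_eager_losses
#   ...                                     # tracing may add transient losses
#   _restore_layer_losses(snapshot)         # puts the originals back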


class LayerTracingContext(threading.local):

  def __init__(self):
    super(LayerTracingContext, self).__init__()
    self.enable_call_tracing = False
    self.trace_queue = []

_thread_local_data = LayerTracingContext()


@tf_contextlib.contextmanager
def tracing_scope():
  """Enables tracing scope."""
  # This enables the LayerCallCollection's tracing mechanism to trace all call
  # functions in the collection.
  previous_value = _thread_local_data.enable_call_tracing
  previous_queue = _thread_local_data.trace_queue
  try:
    _thread_local_data.enable_call_tracing = True
    _thread_local_data.trace_queue = []
    yield
  finally:
    # Run traces from the queue.
    while _thread_local_data.trace_queue:
      fn, args, kwargs, training = _thread_local_data.trace_queue.pop()
      if training is not None:
        with K.deprecated_internal_learning_phase_scope(training):
          fn.get_concrete_function(*args, **kwargs)
      else:
        fn.get_concrete_function(*args, **kwargs)
    _thread_local_data.trace_queue = previous_queue
    _thread_local_data.enable_call_tracing = previous_value
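
# Usage sketch (mirroring `wrap_layer_functions` above): traces enqueued via
# `add_trace_to_queue` inside the scope are run when the scope exits, so
# nested layer calls during tracing only queue work instead of recursively
# retracing:
#   with tracing_scope():
#     call_collection.trace_with_input_signature()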


def add_trace_to_queue(fn, args, kwargs, training=None):
  if tracing_enabled():
    _thread_local_data.trace_queue.append(
        (fn, args[:], kwargs.copy(), training))


def tracing_enabled():
  """Whether to add extra traces to the queue."""
  return _thread_local_data.enable_call_tracing


class LayerCallCollection(object):
  """Groups wrapped layer call functions.

  This is used to ensure that all layer call functions are traced with the same
  inputs:
    - call
    - call_and_return_conditional_losses
    - call_and_return_all_conditional_losses
  """

  def __init__(self, layer):
    self.layer = layer

    self.layer_call_method = _get_layer_call_method(layer)
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    self._training_arg_index = utils.get_training_arg_index(
        self.layer_call_method)

    # If the layer call function has kwargs, then the traced function cannot
    # have an input signature.
    arg_spec = tf_inspect.getfullargspec(self.layer_call_method)
    self._has_kwargs = bool(self._expects_training_arg or
                            arg_spec.defaults or
                            arg_spec.kwonlyargs or
                            arg_spec.varkw)

    self._input_signature = self._generate_input_signature(layer)
    self._functions = weakref.WeakValueDictionary()

    # Get the input argument name from the args.
    args = arg_spec.args
    if tf_inspect.ismethod(self.layer_call_method):
      args = args[1:]
    self._input_arg_name = args[0] if args else 'inputs'

  def _generate_input_signature(self, layer):
    """Inspects layer object and returns the inferred input signature.

    Args:
      layer: Layer object.

    Returns:
      List of possibly nested TensorSpecs of the layer call function inputs.
      The list does not contain the `training` argument.
    """
    if (isinstance(layer.call, def_function.Function) and
        layer.call.input_signature is not None):
      return layer.call.input_signature
    elif isinstance(layer, training_lib.Model):
      return saving_utils.model_input_signature(layer)
    elif (layer.input_spec is not None and
          layer._use_input_spec_as_call_signature):  # pylint: disable=protected-access

      def to_tensor_spec_or_none(x):
        spec = input_spec.to_tensor_spec(x, layer._compute_dtype)  # pylint: disable=protected-access
        # If the shape is too general (e.g. multiple dimensions are allowed),
        # return None so that separate functions can be generated for each
        # inferred input signature.
        # TODO(b/134962016): currently partial signatures are not supported.
        if spec.shape == tensor_shape.TensorShape(None):
          return None
        return spec
      input_signature = [nest.map_structure(
          to_tensor_spec_or_none, layer.input_spec)]

      return input_signature
    else:
      return None
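
  # Example of the inferred signature (a hedged sketch): a Model built on
  # inputs of shape (None, 32) yields roughly
  #   [tf.TensorSpec(shape=(None, 32), dtype=tf.float32, name='inputs')]
  # while a layer whose InputSpec permits arbitrary rank maps to None above,
  # since partial signatures are not supported.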

  def add_trace(self, *args, **kwargs):
    """Traces all functions with the same args and kwargs.

    Args:
      *args: Positional args passed to the original function.
      **kwargs: Keyword args passed to the original function.
    """
    args = list(args)
    kwargs = kwargs.copy()

    for fn in self._functions.values():
      # TODO(kathywu): Replace arguments with broader shapes defined in the
      # input signature.
      if self._expects_training_arg:
        def trace_with_training(value, fn=fn):
          utils.set_training_arg(value, self._training_arg_index, args, kwargs)
          add_trace_to_queue(fn, args, kwargs, value)

        trace_with_training(True)
        trace_with_training(False)
      else:
        add_trace_to_queue(fn, args, kwargs)
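
  # Note: when the collection expects a `training` argument, each `add_trace`
  # call enqueues two traces per collected function, one with training=True
  # and one with training=False, so both behaviors are captured in the
  # SavedModel.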

  @property
  def fn_input_signature(self):
    """Returns input signature for the wrapped layer call function."""
    if self._has_kwargs:
      # Input signatures may only describe tensor arguments and kwargs are not
      # supported.
      return None
    if None in nest.flatten(self._input_signature):
      # TODO(b/134962016): Partially defined input signatures are not yet
      # supported, so return None.
      return None
    return self._input_signature

  def training_arg_was_passed(self, args, kwargs):
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      return (utils.get_training_arg(self._training_arg_index, args, kwargs)
              is not None)
    else:
      return self.layer._call_arg_was_passed(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_training_arg_value(self, args, kwargs):
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      return utils.get_training_arg(self._training_arg_index, args, kwargs)
    else:
      return self.layer._get_call_arg_value(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_input_arg_value(self, args, kwargs):
    return self.layer._get_call_arg_value(  # pylint: disable=protected-access
        self._input_arg_name, args, kwargs, inputs_in_args=True)

  def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
    """Wraps call function with added training argument if necessary."""
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      # Add training arg to wrapper function.
      arg_spec = tf_inspect.getfullargspec(call_fn)
      args = arg_spec.args + ['training']
      defaults = list(arg_spec.defaults or [])
      defaults.append(False)
      new_arg_spec = tf_inspect.FullArgSpec(
          args=args,
          varargs=arg_spec.varargs,
          varkw=arg_spec.varkw,
          defaults=defaults,
          kwonlyargs=arg_spec.kwonlyargs,
          kwonlydefaults=arg_spec.kwonlydefaults,
          annotations=arg_spec.annotations)

      # Set the new training arg index.
      self._training_arg_index = len(args) - 1
      if tf_inspect.ismethod(call_fn):
        self._training_arg_index -= 1

      def wrap_with_training_arg(*args, **kwargs):
        if match_layer_training_arg:
          # Remove the training value, since the original call_fn does not
          # expect a training arg. Instead, the training value will be
          # propagated using the call context created in LayerCall.
          args = list(args)
          kwargs = kwargs.copy()
          utils.remove_training_arg(self._training_arg_index, args, kwargs)
        return call_fn(*args, **kwargs)

      return tf_decorator.make_decorator(
          target=call_fn,
          decorator_func=wrap_with_training_arg,
          decorator_argspec=new_arg_spec)

    return call_fn
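
  # Effect sketch (assuming an original `call(self, inputs)` with no training
  # arg): the wrapper's exposed argspec becomes `(inputs, training=False)`;
  # with `match_layer_training_arg=True` the training value is stripped again
  # before the original call_fn runs, so it only influences the call context.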

  def add_function(self, call_fn, name, match_layer_training_arg):
    """Adds a layer call function to the collection.

    Args:
      call_fn: A python function.
      name: Name of the call function.
      match_layer_training_arg: If True, removes the `training` argument before
        calling `call_fn`, since the original call function does not expect it.

    Returns:
      LayerCall (tf.function)
    """
    fn = LayerCall(
        self,
        self._maybe_wrap_with_training_arg(call_fn, match_layer_training_arg),
        name,
        input_signature=self.fn_input_signature)
    self._functions[name] = fn.wrapped_call
    return fn

  def trace_with_input_signature(self):
568    """Trace with the layer/models inferred input signature if possible."""
    if (None not in nest.flatten(self._input_signature) and self._has_kwargs):
      # Manually add traces for layers that have keyword arguments and have
      # a fully defined input signature.
      self.add_trace(*self._input_signature)
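
# How the pieces above fit together (a condensed sketch of the flow in
# `wrap_layer_functions`):
#   call_collection = LayerCallCollection(layer)
#   fn = call_collection.add_function(wrapped_call, name, False)
#   with tracing_scope():
#     call_collection.trace_with_input_signature()  # queues traces of all fns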


def _filtered_inputs(inputs):
  return list(filter(tf_utils.is_tensor_or_variable, nest.flatten(inputs)))


def layer_call_wrapper(call_collection, method, name):
  """Ensures layer losses are kept the same, and runs method in call context."""

  # Create wrapper that deals with losses and call context.
  def wrapper(*args, **kwargs):
    """Calls method within call context."""
    layer = call_collection.layer
    training = None
    inputs = _filtered_inputs([args, kwargs])
    # pylint: disable=protected-access
    if (args or kwargs) and call_collection.training_arg_was_passed(
        args, kwargs):
      training = call_collection.get_training_arg_value(args, kwargs)
    # pylint: enable=protected-access
    original_losses = _reset_layer_losses(layer)
    with base_layer_utils.call_context().enter(
        layer, inputs=inputs, build_graph=False, training=training,
        saving=True):
      with autocast_variable.enable_auto_cast_variables(
          layer._compute_dtype_object):  # pylint: disable=protected-access
        ret = method(*args, **kwargs)
    _restore_layer_losses(original_losses)
    return ret

  # Rename to `name`, since tf.function doesn't have a name argument. Without
  # this, all functions returned by this method will be named "call", which
  # would be a nightmare to debug.
  fn = tf_decorator.make_decorator(target=method, decorator_func=wrapper)
  fn.__name__ = name
  return fn
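
# The wrapper produced above is what `LayerCall` hands to
# `def_function.function`, so every trace of a serialized call function runs
# inside the Keras call context, with layer losses snapshotted and restored
# around the underlying method.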


class LayerCall(object):
  """Function that triggers traces of other functions in the same collection."""

  def __init__(self, call_collection, call_fn, name, input_signature):
    """Initializes a LayerCall object.

    Args:
      call_collection: A LayerCallCollection, which contains the other layer
        call functions (e.g. call_and_return_conditional_losses, call). These
        functions should be traced with the same arguments.
      call_fn: A call function.
      name: Name of the call function.
      input_signature: Input signature of call_fn (can be None).
    """
    self.call_collection = call_collection
    self.input_signature = input_signature
    self.wrapped_call = def_function.function(
        layer_call_wrapper(call_collection, call_fn, name),
        input_signature=input_signature)
    self.original_layer_call = call_collection.layer_call_method

  def _maybe_trace(self, args, kwargs):
    # Trigger traces of other call functions + extra training-arg traces.
    if tracing_enabled():
      self.call_collection.add_trace(*args, **kwargs)

  def __call__(self, *args, **kwargs):
    self._maybe_trace(args, kwargs)
    return self.wrapped_call(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    self._maybe_trace(args, kwargs)
    return self.wrapped_call.get_concrete_function(*args, **kwargs)
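
# Calling a LayerCall while `tracing_scope` is active therefore fans out: the
# direct call traces this function, and `_maybe_trace` enqueues matching
# traces for every other function registered in the collection.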


def _wrap_call_and_conditional_losses(layer):
  """Wraps call function that returns a tuple of (outputs, losses).

  The losses returned are conditional on the inputs passed to the call function.
  Unconditional losses (e.g. weight regularization) are wrapped separately.

  Args:
    layer: A Keras layer object.

  Returns:
    A python call function that returns outputs and conditional losses,
    excluding the activity regularizer loss.
658  """
659  # Create function that generates both outputs and losses
660  layer_call = _get_layer_call_method(layer)
661  def call_and_return_conditional_losses(*args, **kwargs):
662    """Returns layer (call_output, conditional losses) tuple."""
663    call_output = layer_call(*args, **kwargs)
664    if version_utils.is_v1_layer_or_model(layer):
665      conditional_losses = layer.get_losses_for(
666          _filtered_inputs([args, kwargs]))
667    else:
668      conditional_losses = [
669          l for l in layer.losses if not hasattr(l, '_unconditional_loss')
670      ]
671    return call_output, conditional_losses
672
673  return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
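
# A hedged example of the wrapped function's return value for a layer with
# one input-conditional loss:
#   outputs, losses = call_and_return_conditional_losses(x)
#   # outputs == layer.call(x); losses == [tensors that depend on x]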


def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
  """Returns a function that returns only call function outputs."""
  if isinstance(layer, keras_load.RevivedLayer):
    return layer.keras_api.__call__  # pylint: disable=protected-access
  def call(inputs, *args, **kwargs):
    return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
  return _create_call_fn_decorator(layer, call)


def _append_activity_regularizer_loss(
    layer, call_fn_with_losses, activity_regularizer_fn):
  """Appends activity regularizer loss to losses returned by the wrapped fn."""
  def fn(inputs, *args, **kwargs):
    outputs, losses = call_fn_with_losses(inputs, *args, **kwargs)
    losses.append(activity_regularizer_fn(outputs))
    return outputs, losses
  return _create_call_fn_decorator(layer, fn)


def _create_call_fn_decorator(layer, wrapped_call):
  call_fn = _get_layer_call_method(layer)
  fn, arg_spec = utils.maybe_add_training_arg(
      call_fn, wrapped_call, layer._expects_training_arg,  # pylint: disable=protected-access
      default_training_value=False)
  return tf_decorator.make_decorator(
      target=call_fn,
      decorator_func=fn,
      decorator_argspec=arg_spec)


def _wrap_unconditional_loss(loss_fn, index):
  """Wraps callable/unconditional loss, returning a serializable function."""
  # Extract original loss function from partial function.
  fn = loss_fn.args[0] if isinstance(loss_fn, functools.partial) else loss_fn
  if isinstance(fn, def_function.Function):
    return fn
  else:
    return def_function.Function(
        fn, 'loss_fn_{}'.format(index), input_signature=[])
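
# Unconditional losses arrive here as zero-argument callables (commonly
# `functools.partial` objects created by `Layer.add_loss`), which is why an
# empty input signature suffices.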


def _wrap_activity_regularizer(layer):
  """Wraps the activity regularizer."""
  # pylint: disable=protected-access
  if isinstance(layer._activity_regularizer, def_function.Function):
    return layer._activity_regularizer
  return def_function.Function(
      layer._activity_regularizer,
      '{}_activity_regularizer'.format(layer.name),
      input_signature=[
          tensor_spec.TensorSpec(None, layer._compute_dtype or K.floatx())
      ])
  # pylint: enable=protected-access


def _get_layer_call_method(layer):
  if isinstance(layer.call, def_function.Function):
    return layer.call.python_function
  return layer.call