# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""V1 Training-related part of the Keras engine."""

import collections
import warnings

import numpy as np

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_arrays_v1
from tensorflow.python.keras.engine import training_distributed_v1
from tensorflow.python.keras.engine import training_eager_v1
from tensorflow.python.keras.engine import training_generator_v1
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import nest

try:
  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
except ImportError:
  issparse = None


class Model(training_lib.Model):
  """`Model` groups layers into an object with training and inference features.

  There are two ways to instantiate a `Model`:

  1 - With the "functional API", where you start from `Input`,
  you chain layer calls to specify the model's forward pass,
  and finally you create your model from inputs and outputs:

  ```python
  import tensorflow as tf

  inputs = tf.keras.Input(shape=(3,))
  x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
  outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
  model = tf.keras.Model(inputs=inputs, outputs=outputs)
  ```

  2 - By subclassing the `Model` class: in that case, you should define your
  layers in `__init__` and you should implement the model's forward pass
  in `call`.

  ```python
  import tensorflow as tf

  class MyModel(tf.keras.Model):

    def __init__(self):
      super(MyModel, self).__init__()
      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
      x = self.dense1(inputs)
      return self.dense2(x)

  model = MyModel()
  ```

  If you subclass `Model`, you can optionally have
  a `training` argument (boolean) in `call`, which you can use to specify
  a different behavior in training and inference:

  ```python
  import tensorflow as tf

  class MyModel(tf.keras.Model):

    def __init__(self):
      super(MyModel, self).__init__()
      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
      self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
      x = self.dense1(inputs)
      if training:
        x = self.dropout(x, training=training)
      return self.dense2(x)

  model = MyModel()
  ```
  """

  def __init__(self, *args, **kwargs):
    super(Model, self).__init__(*args, **kwargs)
    # initializing _distribution_strategy here since it is possible to call
    # predict on a model without compiling it.
    self._distribution_strategy = None
    self._compile_time_distribution_strategy = None
    if (ops.executing_eagerly_outside_functions() and
        distribution_strategy_context.has_strategy()):
      self._set_strategy(
          distribution_strategy_context.get_strategy())

    # This flag is used to track if the user is using the deprecated path of
    # passing distribution strategy to compile rather than creating the model
    # under distribution strategy scope.
    self._compile_distribution = False

    self._run_eagerly = None
    self._experimental_run_tf_function = (
        ops.executing_eagerly_outside_functions())

    self._v1_compile_was_called = False

  def _init_batch_counters(self):
    pass  # Batch counters should not be created in legacy graph mode.

  @trackable.no_automatic_dependency_tracking
  def _set_strategy(self, strategy):
    self._compile_time_distribution_strategy = strategy

  def get_weights(self):
    """Retrieves the weights of the model.

    Returns:
        A flat list of Numpy arrays.
    """
    strategy = (self._distribution_strategy or
                self._compile_time_distribution_strategy)
    if strategy:
      with strategy.scope():
        return base_layer.Layer.get_weights(self)
    return base_layer.Layer.get_weights(self)

  def load_weights(self, filepath, by_name=False, skip_mismatch=False):
    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.

    If `by_name` is False, weights are loaded based on the network's
    topology. This means the architecture should be the same as when the weights
    were saved.  Note that layers that don't have weights are not taken into
    account in the topological ordering, so adding or removing layers is fine as
    long as they don't have weights.

    If `by_name` is True, weights are loaded into layers only if they share the
    same name. This is useful for fine-tuning or transfer-learning models where
    some of the layers have changed.

    Only topological loading (`by_name=False`) is supported when loading weights
    from the TensorFlow format. Note that topological loading differs slightly
    between TensorFlow and HDF5 formats for user-defined classes inheriting from
    `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
    TensorFlow format loads based on the object-local names of attributes to
    which layers are assigned in the `Model`'s constructor.

    Args:
        filepath: String, path to the weights file to load. For weight files in
            TensorFlow format, this is the file prefix (the same as was passed
            to `save_weights`).
        by_name: Boolean, whether to load weights by name or by topological
            order. Only topological loading is supported for weight files in
            TensorFlow format.
        skip_mismatch: Boolean, whether to skip loading of layers where there is
            a mismatch in the number of weights, or a mismatch in the shape of
            the weight (only valid when `by_name=True`).

    Returns:
        When loading a weight file in TensorFlow format, returns the same status
        object as `tf.train.Checkpoint.restore`. When graph building, restore
        ops are run automatically as soon as the network is built (on first call
        for user-defined classes inheriting from `Model`, immediately if it is
        already built).

        When loading weights in HDF5 format, returns `None`.

    Raises:
        ImportError: If h5py is not available and the weight file is in HDF5
            format.
        ValueError: If `skip_mismatch` is set to `True` when `by_name` is
          `False`.
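
    Example (an illustrative sketch; the file name is hypothetical and the
    weights are assumed to have been written earlier by `save_weights`):

    ```python
    # Topological loading: the architecture must match the saving model.
    model.save_weights('my_model_weights.h5')
    model.load_weights('my_model_weights.h5')

    # Name-based loading, e.g. for fine-tuning a partially changed model.
    model.load_weights('my_model_weights.h5', by_name=True)
    ```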
    """
    if backend.is_tpu_strategy(self._distribution_strategy):
      if (self._distribution_strategy.extended.steps_per_run > 1 and
          (not saving_utils.is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
        raise ValueError('Load weights is not yet supported with TPUStrategy '
                         'with steps_per_run greater than 1.')
    return super(Model, self).load_weights(filepath, by_name, skip_mismatch)

  @trackable.no_automatic_dependency_tracking
  def compile(self,
              optimizer='rmsprop',
              loss=None,
              metrics=None,
              loss_weights=None,
              sample_weight_mode=None,
              weighted_metrics=None,
              target_tensors=None,
              distribute=None,
              **kwargs):
    """Configures the model for training.

    Args:
        optimizer: String (name of optimizer) or optimizer instance.
            See `tf.keras.optimizers`.
        loss: String (name of objective function), objective function or
            `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective
            function is any callable with the signature
            `scalar_loss = fn(y_true, y_pred)`. If the model has multiple
            outputs, you can use a different loss on each output by passing a
            dictionary or a list of losses. The loss value that will be
            minimized by the model will then be the sum of all individual
            losses.
        metrics: List of metrics to be evaluated by the model during training
            and testing. Typically you will use `metrics=['accuracy']`.
            To specify different metrics for different outputs of a
            multi-output model, you could also pass a dictionary, such as
            `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`.
            You can also pass a list (len = len(outputs)) of lists of metrics
            such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or
            `metrics=['accuracy', ['accuracy', 'mse']]`.
        loss_weights: Optional list or dictionary specifying scalar
            coefficients (Python floats) to weight the loss contributions
            of different model outputs.
            The loss value that will be minimized by the model
            will then be the *weighted sum* of all individual losses,
            weighted by the `loss_weights` coefficients.
            If a list, it is expected to have a 1:1 mapping
            to the model's outputs. If a dict, it is expected to map
            output names (strings) to scalar coefficients.
        sample_weight_mode: If you need to do timestep-wise
            sample weighting (2D weights), set this to `"temporal"`.
            `None` defaults to sample-wise weights (1D).
            If the model has multiple outputs, you can use a different
            `sample_weight_mode` on each output by passing a
            dictionary or a list of modes.
        weighted_metrics: List of metrics to be evaluated and weighted
            by sample_weight or class_weight during training and testing.
        target_tensors: By default, Keras will create placeholders for the
            model's target, which will be fed with the target data during
            training. If instead you would like to use your own
            target tensors (in turn, Keras will not expect external
            Numpy data for these targets at training time), you
            can specify them via the `target_tensors` argument. It can be
            a single tensor (for a single-output model), a list of tensors,
            or a dict mapping output names to target tensors.
        distribute: NOT SUPPORTED IN TF 2.0, please create and compile the
            model under distribution strategy scope instead of passing it to
            compile.
        **kwargs: Any additional arguments.

    Raises:
        ValueError: In case of invalid arguments for
            `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
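
    Example (a minimal sketch of a common single-output configuration; the
    loss and metric shown are just one typical choice):

    ```python
    model.compile(optimizer='rmsprop',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # For a multi-output model, losses and loss weights may instead be given
    # as dicts keyed by output name (the names here are hypothetical):
    # model.compile(optimizer='rmsprop',
    #               loss={'main': 'mse', 'aux': 'binary_crossentropy'},
    #               loss_weights={'main': 1.0, 'aux': 0.2})
    ```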
    """
    self._assert_built_as_v1()
    self._run_eagerly = kwargs.pop('run_eagerly', None)
    self._experimental_run_tf_function = kwargs.pop(
        'experimental_run_tf_function', True)
    self._v1_compile_was_called = True

    # Prepare Session arguments (legacy).
    kwargs.pop('cloning', None)  # Legacy DistStrat argument, never used.
    self._from_serialized = kwargs.pop('from_serialized', False)
    allowed_kwargs = {'feed_dict', 'fetches', 'options', 'run_metadata'}
    unknown_kwargs = set(kwargs.keys()) - allowed_kwargs
    if unknown_kwargs:
      raise TypeError(
          'Invalid keyword argument(s) in `compile`: %s' % (unknown_kwargs,))
    self._function_kwargs = kwargs
    if self._function_kwargs:
      self._experimental_run_tf_function = False
      if self.run_eagerly:
        raise ValueError(
            'Session keyword arguments are not supported '
            'when `run_eagerly=True`. You passed the following '
            'Session arguments: %s' % (self._function_kwargs,))

    self._set_optimizer(optimizer)
    is_any_keras_optimizer_v1 = any(
        (isinstance(opt, optimizer_v1.Optimizer)
         and not isinstance(opt, optimizer_v1.TFOptimizer)
        ) for opt in nest.flatten(self.optimizer))

    if is_any_keras_optimizer_v1 and ops.executing_eagerly_outside_functions():
      raise ValueError('`tf.compat.v1.keras` Optimizer (%s) is not supported '
                       'when eager execution is enabled. Use a `tf.keras` '
                       'Optimizer instead, or disable eager execution.' %
                       (optimizer,))

    if ((target_tensors is not None)
        or not ops.executing_eagerly_outside_functions()):
      # Fall back for features that aren't supported with v2 loops.
      self._experimental_run_tf_function = False

    if distribute is not None:
      if tf2.enabled() or self._experimental_run_tf_function:
        raise ValueError(
            'Distribute argument in compile is not available in TF 2.0. Please '
            'create the model under the distribution strategy scope.')
      logging.warning('Distribute argument in compile is deprecated. Please '
                      'create the model under the distribution strategy scope.')
      self._distribution_strategy = distribute
      self._compile_distribution = True
    else:
      if distribution_strategy_context.has_strategy():
        # When the user builds the model in the DS scope and cross-replica
        # context, we want the distribution strategy to be set. But when
        # building the replica copies of the models internally, we should not
        # compile with distribution strategy and should use the default
        # compilation path.
        if distribution_strategy_context.in_cross_replica_context():
          self._distribution_strategy = (
              distribution_strategy_context.get_strategy())

    if isinstance(self._distribution_strategy,
                  parameter_server_strategy.ParameterServerStrategyV1):
      raise NotImplementedError(
          '`tf.compat.v1.distribute.experimental.ParameterServerStrategy` '
          'currently only works with the tf.Estimator API')

    if isinstance(self._distribution_strategy,
                  parameter_server_strategy_v2.ParameterServerStrategyV2):
      raise NotImplementedError(
          '`tf.distribute.experimental.ParameterServerStrategy` is only '
          'supported in TF2.')

    if not self._experimental_run_tf_function:
      self._validate_compile_param_for_distribution_strategy(self.run_eagerly,
                                                             sample_weight_mode,
                                                             target_tensors,
                                                             weighted_metrics)
    # We've disabled automatic dependency tracking for this method, but do want
    # to add a checkpoint dependency on the optimizer if it's trackable.
    if isinstance(self.optimizer, trackable.Trackable):
      self._track_trackable(
          self.optimizer, name='optimizer', overwrite=True)
    self.loss = loss or {}
    self.loss_weights = loss_weights
    self.sample_weight_mode = sample_weight_mode
    self._compile_metrics = metrics or []
    self._compile_weighted_metrics = weighted_metrics
    if self.run_eagerly and target_tensors is not None:
      raise ValueError(
          'target_tensors argument is not supported when '
          'running a model eagerly.')

    # _training_endpoints contains a list of _TrainingEndpoint objects, each of
    # which has all the model output/target/loss and related metadata.
    self._training_endpoints = []

    # Used to freeze the behavior of the Model once `compile` has been called.
    self._compiled_trainable_state = self._get_trainable_state()

    # Set tf.distribute.Strategy specific parameters.
    self._distributed_model_cache = {}
    self._distributed_function_cache = {}

    # Clear any `_eager_losses` that were added.
    self._clear_losses()

    if (not context.executing_eagerly() and
        self._distribution_strategy is not None):
      # Ensures a Session is created and configured correctly for Distribution
      # Strategy.
      backend.configure_and_create_distributed_session(
          self._distribution_strategy)
    # Initialize model metric attributes.
    self._init_metric_attributes()
    if not self.built or not self.inputs or not self.outputs:
      # Model is not compilable because it does not know its number of inputs
      # and outputs, nor their shapes and names. We will compile after the first
      # time the model gets called on training data.
      return
    self._is_compiled = True

    # Prepare list of loss functions, same size as model outputs.
    self.loss_functions = training_utils_v1.prepare_loss_functions(
        self.loss, self.output_names)

    target_tensors = self._process_target_tensor_for_compile(target_tensors)

    for o, n, l, t in zip(self.outputs, self.output_names,
                          self.loss_functions, target_tensors):
      endpoint = _TrainingEndpoint(o, n, l)
      endpoint.create_training_target(t, run_eagerly=self.run_eagerly)
      self._training_endpoints.append(endpoint)

    # Prepare list of loss weights, same size as model outputs.
    training_utils_v1.prepare_loss_weights(self._training_endpoints,
                                           loss_weights)

    # Initialization for Eager mode execution.
    if self.run_eagerly:
      self._compile_eagerly(metrics, weighted_metrics, sample_weight_mode)
      return

    with backend.get_graph().as_default():
      # Save all metric attributes per output of the model.
      self._cache_output_metric_attributes(metrics, weighted_metrics)

      # Set metric attributes on model.
      self._set_metric_attributes()

      # Invoke metric functions (unweighted) for all the outputs.
      self._handle_metrics(
          self.outputs,
          targets=self._targets,
          skip_target_masks=self._prepare_skip_target_masks(),
          masks=self._prepare_output_masks())

      # Prepare sample weight modes. List with the same length as model outputs.
      training_utils_v1.prepare_sample_weight_modes(
          self._training_endpoints, sample_weight_mode)

      # Creates the model loss and weighted metrics sub-graphs.
      self._compile_weights_loss_and_weighted_metrics()

      # Functions for train, test and predict will
      # be compiled lazily when required.
      # This saves time when the user is not using all functions.
      self.train_function = None
      self.test_function = None
      self.predict_function = None

      # Collected trainable weights, sorted in topological order.
      self._collected_trainable_weights = self.trainable_weights

      # Validate all variables were correctly created in distribution scope.
      if self._distribution_strategy and not self._compile_distribution:
        for v in self.variables:
          strategy = self._distribution_strategy
          if not strategy.extended.variable_created_in_scope(v):
            raise ValueError(
                'Variable (%s) was not created in the distribution strategy '
                'scope of (%s). It is most likely because some layers, the '
                'model, or the optimizer were created outside the distribution '
                'strategy scope. Try to make sure your code looks similar '
                'to the following.\n'
                'with strategy.scope():\n'
                '  model = _create_model()\n'
                '  model.compile(...)' % (v, strategy))

  @trackable.no_automatic_dependency_tracking
  def _init_distributed_function_cache_if_not_compiled(self):
    if not hasattr(self, '_distributed_function_cache'):
      self._distributed_function_cache = {}

  @property
  def metrics(self):
    """Returns the model's metrics added using `compile`, `add_metric` APIs."""
    metrics = []
    if self._is_compiled:
      if not hasattr(self, '_v1_compile_was_called'):
        # See b/155687393 for more details. The model is created as a v2
        # instance but converted to v1. Fall back to the base Model to retrieve
        # the metrics.
        return super(Model, self).metrics
      metrics += self._compile_metric_functions
    metrics.extend(self._metrics)
    metrics.extend(
        _get_metrics_from_layers(
            list(self._flatten_layers(include_self=False, recursive=False))))
    return metrics

  @property
  def metrics_names(self):
    """Returns the model's display labels for all outputs."""

    # This property includes all output names including `loss` and per-output
    # losses for backward compatibility.
    metrics_names = ['loss']
    if self._is_compiled:
      if not hasattr(self, '_v1_compile_was_called'):
        # See b/155687393 for more details. The model is created as a v2
        # instance but converted to v1. Fall back to the base Model to retrieve
        # the metrics names.
        return super(Model, self).metrics_names

      # Add output loss metric names to the metric names list.
      if len(self._training_endpoints) > 1:
        metrics_names.extend([
            e.loss_name()
            for e in self._training_endpoints
            if not e.should_skip_target()
        ])

    # Add all metric names.
    metrics_names += [m.name for m in self.metrics]
    return metrics_names

  @property
  def run_eagerly(self):
    """Settable attribute indicating whether the model should run eagerly.

    Running eagerly means that your model will be run step by step,
    like Python code. Your model might run slower, but it should become easier
    for you to debug it by stepping into individual layer calls.

    By default, we will attempt to compile your model to a static graph to
    deliver the best execution performance.

    Returns:
      Boolean, whether the model should run eagerly.
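
    Example (a sketch; assumes eager execution is enabled, since otherwise
    `run_eagerly=True` is rejected):

    ```python
    model.compile(optimizer='rmsprop', loss='mse', run_eagerly=True)
    assert model.run_eagerly
    ```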
    """
    if self._run_eagerly is True and not context.executing_eagerly():
      raise ValueError('You can only set `run_eagerly=True` if eager execution '
                       'is enabled.')
    if not self.dynamic:
      if self._run_eagerly is None:
        # Respect `tf.config.run_functions_eagerly` unless
        # `run_eagerly` was explicitly passed to `compile`.
        return def_function.functions_run_eagerly()
      else:
        return self._run_eagerly
    else:
      if not context.executing_eagerly():
        raise ValueError('Your model contains layers that can only be '
                         'successfully run in eager execution (layers '
                         'constructed with `dynamic=True`). '
                         'You must enable eager execution with '
                         '`tf.enable_eager_execution()`.')
      if self._run_eagerly is False:
        # TODO(fchollet): consider using py_func to enable this.
        raise ValueError('Your model contains layers that can only be '
                         'successfully run in eager execution (layers '
                         'constructed with `dynamic=True`). '
                         'You cannot set `run_eagerly=False`.')
      return context.executing_eagerly()

  @run_eagerly.setter
  def run_eagerly(self, value):
    self._run_eagerly = value

  def _select_training_loop(self, inputs):
    """Select training loop for fit/eval/predict based on the inputs."""
    # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely
    #  integrated into the data adapters in the v2 loop. We can't do this yet
    #  because we currently have to fall back for unhandled data types.
    if isinstance(inputs, (iterator_ops.Iterator,
                           iterator_ops.IteratorBase)):
      raise ValueError('For performance reasons Keras `fit`, `evaluate` and '
                       '`predict` accept tf.data `Datasets` as input but not '
                       'iterators that have been manually generated from '
                       'Datasets by users. Please directly pass in the '
                       'original `Dataset` object instead of passing in '
                       '`iter(dataset)`.')

    # Case 1: distribution strategy.
    if self._distribution_strategy:
      if self._in_multi_worker_mode():
        return training_distributed_v1.DistributionMultiWorkerTrainingLoop(
            training_distributed_v1.DistributionSingleWorkerTrainingLoop())
      else:
        return training_distributed_v1.DistributionSingleWorkerTrainingLoop()

    # Case 2: generator-like. Input is Python generator, or Sequence object,
    # or a non-distributed Dataset or iterator in eager execution.
    if data_utils.is_generator_or_sequence(inputs):
      return training_generator_v1.GeneratorOrSequenceTrainingLoop()
    if training_utils_v1.is_eager_dataset_or_iterator(inputs):
      return training_generator_v1.EagerDatasetOrIteratorTrainingLoop()

    # Case 3: Symbolic tensors or Numpy array-like.
    # This includes Datasets and iterators in graph mode (since they
    # generate symbolic tensors).
    if self.run_eagerly:
      return training_generator_v1.GeneratorLikeTrainingLoop()
    else:
      return training_arrays_v1.ArrayLikeTrainingLoop()

  def fit(self,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False,
          **kwargs):
    """Trains the model for a fixed number of epochs (iterations on a dataset).

    Args:
        x: Input data. It could be:
          - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
          - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
          - A dict mapping input names to the corresponding array/tensors,
            if the model has named inputs.
          - A `tf.data` dataset. Should return a tuple
            of either `(inputs, targets)` or
            `(inputs, targets, sample_weights)`.
          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
            or `(inputs, targets, sample weights)`.
        y: Target data. Like the input data `x`,
          it could be either Numpy array(s) or TensorFlow tensor(s).
          It should be consistent with `x` (you cannot have Numpy inputs and
          tensor targets, or inversely). If `x` is a dataset, generator,
          or `keras.utils.Sequence` instance, `y` should
          not be specified (since targets will be obtained from `x`).
        batch_size: Integer or `None`.
            Number of samples per gradient update.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of symbolic tensors, datasets,
            generators, or `keras.utils.Sequence` instances (since they generate
            batches).
        epochs: Integer. Number of epochs to train the model.
            An epoch is an iteration over the entire `x` and `y`
            data provided.
            Note that in conjunction with `initial_epoch`,
            `epochs` is to be understood as "final epoch".
            The model is not trained for a number of iterations
            given by `epochs`, but merely until the epoch
            of index `epochs` is reached.
        verbose: 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = one line per epoch.
            Note that the progress bar is not particularly useful when
            logged to a file, so verbose=2 is recommended when not running
            interactively (e.g., in a production environment).
        callbacks: List of `keras.callbacks.Callback` instances.
            List of callbacks to apply during training.
            See `tf.keras.callbacks`.
        validation_split: Float between 0 and 1.
            Fraction of the training data to be used as validation data.
            The model will set apart this fraction of the training data,
            will not train on it, and will evaluate
            the loss and any model metrics
            on this data at the end of each epoch.
            The validation data is selected from the last samples
            in the `x` and `y` data provided, before shuffling. This argument is
            not supported when `x` is a dataset, generator or
            `keras.utils.Sequence` instance.
        validation_data: Data on which to evaluate
            the loss and any model metrics at the end of each epoch.
            The model will not be trained on this data.
            `validation_data` will override `validation_split`.
            `validation_data` could be:
              - tuple `(x_val, y_val)` of Numpy arrays or tensors
              - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
              - dataset
            For the first two cases, `batch_size` must be provided.
            For the last case, `validation_steps` could be provided.
        shuffle: Boolean (whether to shuffle the training data
            before each epoch) or str (for 'batch').
            'batch' is a special option for dealing with the
            limitations of HDF5 data; it shuffles in batch-sized chunks.
            Has no effect when `steps_per_epoch` is not `None`.
        class_weight: Optional dictionary mapping class indices (integers)
            to a weight (float) value, used for weighting the loss function
            (during training only).
            This can be useful to tell the model to
            "pay more attention" to samples from
            an under-represented class.
        sample_weight: Optional Numpy array of weights for
            the training samples, used for weighting the loss function
            (during training only). You can either pass a flat (1D)
            Numpy array with the same length as the input samples
            (1:1 mapping between weights and samples),
            or in the case of temporal data,
            you can pass a 2D array with shape
            `(samples, sequence_length)`,
            to apply a different weight to every timestep of every sample.
            In this case you should make sure to specify
            `sample_weight_mode="temporal"` in `compile()`. This argument is not
            supported when `x` is a dataset, generator, or
            `keras.utils.Sequence` instance; instead, provide the sample_weights
            as the third element of `x`.
        initial_epoch: Integer.
            Epoch at which to start training
            (useful for resuming a previous training run).
        steps_per_epoch: Integer or `None`.
            Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. When training with input tensors such as
            TensorFlow data tensors, the default `None` is equal to
            the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined. If `x` is a
            `tf.data` dataset, and `steps_per_epoch`
            is None, the epoch will run until the input dataset is exhausted.
            This argument is not supported with array inputs.
        validation_steps: Only relevant if `validation_data` is provided and
            is a `tf.data` dataset. Total number of steps (batches of
            samples) to draw before stopping when performing validation
            at the end of every epoch. If `validation_steps` is None, validation
            will run until the `validation_data` dataset is exhausted. In the
            case of an infinite dataset, it will run into an infinite loop.
            If `validation_steps` is specified and only part of the dataset
            will be consumed, the evaluation will start from the beginning of
            the dataset at each epoch. This ensures that the same validation
            samples are used every time.
        validation_freq: Only relevant if validation data is provided. Integer
            or `collections.abc.Container` instance (e.g. list, tuple, etc.).
            If an integer, specifies how many training epochs to run before a
            new validation run is performed, e.g. `validation_freq=2` runs
            validation every 2 epochs. If a Container, specifies the epochs on
            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
            validation at the end of the 1st, 2nd, and 10th epochs.
        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
            input only. Maximum size for the generator queue.
            If unspecified, `max_queue_size` will default to 10.
        workers: Integer. Used for generator or `keras.utils.Sequence` input
            only. Maximum number of processes to spin up
            when using process-based threading. If unspecified, `workers`
            will default to 1. If 0, will execute the generator on the main
            thread.
        use_multiprocessing: Boolean. Used for generator or
            `keras.utils.Sequence` input only. If `True`, use process-based
            threading. If unspecified, `use_multiprocessing` will default to
            `False`. Note that because this implementation relies on
            multiprocessing, you should not pass non-picklable arguments to
            the generator as they can't be passed easily to child processes.
        **kwargs: Used for backwards compatibility.

    Returns:
        A `History` object. Its `History.history` attribute is
        a record of training loss values and metrics values
        at successive epochs, as well as validation loss values
        and validation metrics values (if applicable).

    Raises:
        RuntimeError: If the model was never compiled.
        ValueError: In case of mismatch between the provided input data
            and what the model expects.
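
    Example (an illustrative sketch with random NumPy data; the shapes and
    hyperparameters are arbitrary and assume a matching compiled model):

    ```python
    import numpy as np

    x_train = np.random.random((1000, 3))
    y_train = np.random.random((1000, 5))

    history = model.fit(x_train, y_train,
                        batch_size=32,
                        epochs=10,
                        validation_split=0.2)
    print(history.history['loss'])
    ```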
    """
    self._assert_built_as_v1()
    # Legacy support
    if 'nb_epoch' in kwargs:
      logging.warning(
          'The `nb_epoch` argument in `fit` has been renamed `epochs`.')
      epochs = kwargs.pop('nb_epoch')
    if kwargs:
      raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
    self._assert_compile_was_called()
    self._check_call_args('fit')

    func = self._select_training_loop(x)
    return func.fit(
        self,
        x=x,
        y=y,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_split=validation_split,
        validation_data=validation_data,
        shuffle=shuffle,
        class_weight=class_weight,
        sample_weight=sample_weight,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

  def evaluate(self,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
    """Returns the loss value & metrics values for the model in test mode.

    Computation is done in batches (see the `batch_size` arg.)

    Args:
        x: Input data. It could be:
          - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
          - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
          - A dict mapping input names to the corresponding array/tensors,
            if the model has named inputs.
          - A `tf.data` dataset.
          - A generator or `keras.utils.Sequence` instance.
        y: Target data. Like the input data `x`,
          it could be either Numpy array(s) or TensorFlow tensor(s).
          It should be consistent with `x` (you cannot have Numpy inputs and
          tensor targets, or inversely).
          If `x` is a dataset, generator or
          `keras.utils.Sequence` instance, `y` should not be specified (since
          targets will be obtained from the iterator/dataset).
        batch_size: Integer or `None`.
            Number of samples per batch of computation.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of symbolic tensors, dataset,
            generators, or `keras.utils.Sequence` instances (since they generate
            batches).
        verbose: 0 or 1. Verbosity mode.
            0 = silent, 1 = progress bar.
        sample_weight: Optional Numpy array of weights for
            the test samples, used for weighting the loss function.
            You can either pass a flat (1D)
            Numpy array with the same length as the input samples
            (1:1 mapping between weights and samples),
            or in the case of temporal data,
            you can pass a 2D array with shape
            `(samples, sequence_length)`,
            to apply a different weight to every timestep of every sample.
            In this case you should make sure to specify
            `sample_weight_mode="temporal"` in `compile()`. This argument is not
            supported when `x` is a dataset; instead, pass
            sample weights as the third element of `x`.
        steps: Integer or `None`.
            Total number of steps (batches of samples)
            before declaring the evaluation round finished.
            Ignored with the default value of `None`.
            If `x` is a `tf.data` dataset and `steps` is
            None, `evaluate` will run until the dataset is exhausted.
            This argument is not supported with array inputs.
        callbacks: List of `keras.callbacks.Callback` instances.
            List of callbacks to apply during evaluation.
            See [callbacks](/api_docs/python/tf/keras/callbacks).
        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
            input only. Maximum size for the generator queue.
            If unspecified, `max_queue_size` will default to 10.
        workers: Integer. Used for generator or `keras.utils.Sequence` input
            only. Maximum number of processes to spin up when using
            process-based threading. If unspecified, `workers` will default
            to 1. If 0, will execute the generator on the main thread.
        use_multiprocessing: Boolean. Used for generator or
            `keras.utils.Sequence` input only. If `True`, use process-based
            threading. If unspecified, `use_multiprocessing` will default to
            `False`. Note that because this implementation relies on
            multiprocessing, you should not pass non-picklable arguments to
            the generator as they can't be passed easily to child processes.

    Returns:
        Scalar test loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the scalar outputs.

    Raises:
        ValueError: in case of invalid arguments.
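
    Example (illustrative; `x_test` and `y_test` are assumed to be NumPy
    arrays matching the model, and the model is assumed to be compiled with
    metrics so that the result is a list aligned with `model.metrics_names`):

    ```python
    results = model.evaluate(x_test, y_test, batch_size=64)
    print(dict(zip(model.metrics_names, results)))
    ```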
    """
    self._assert_built_as_v1()
    self._assert_compile_was_called()
    self._check_call_args('evaluate')

    func = self._select_training_loop(x)
    return func.evaluate(
        self,
        x=x,
        y=y,
        batch_size=batch_size,
        verbose=verbose,
        sample_weight=sample_weight,
        steps=steps,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

  def predict(self,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False):
    """Generates output predictions for the input samples.

    Computation is done in batches (see the `batch_size` arg.)

    Args:
        x: Input samples. It could be:
          - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
          - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
          - A `tf.data` dataset.
          - A generator or `keras.utils.Sequence` instance.
        batch_size: Integer or `None`.
            Number of samples per batch of computation.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of symbolic tensors, dataset,
            generators, or `keras.utils.Sequence` instances (since they generate
            batches).
        verbose: Verbosity mode, 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring the prediction round finished.
            Ignored with the default value of `None`. If `x` is a `tf.data`
            dataset and `steps` is None, `predict` will
            run until the input dataset is exhausted.
        callbacks: List of `keras.callbacks.Callback` instances.
            List of callbacks to apply during prediction.
            See [callbacks](/api_docs/python/tf/keras/callbacks).
        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
            input only. Maximum size for the generator queue.
            If unspecified, `max_queue_size` will default to 10.
        workers: Integer. Used for generator or `keras.utils.Sequence` input
            only. Maximum number of processes to spin up when using
            process-based threading. If unspecified, `workers` will default
            to 1. If 0, will execute the generator on the main thread.
        use_multiprocessing: Boolean. Used for generator or
            `keras.utils.Sequence` input only. If `True`, use process-based
            threading. If unspecified, `use_multiprocessing` will default to
            `False`. Note that because this implementation relies on
            multiprocessing, you should not pass non-picklable arguments to
            the generator as they can't be passed easily to child processes.


    Returns:
        Numpy array(s) of predictions.

    Raises:
        ValueError: In case of mismatch between the provided
            input data and the model's expectations,
            or in case a stateful model receives a number of samples
            that is not a multiple of the batch size.
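
    Example (illustrative; `x_test` is assumed to be a NumPy array with the
    model's expected input shape, batch dimension included):

    ```python
    predictions = model.predict(x_test, batch_size=128)
    print(predictions.shape)
    ```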
    """
    self._assert_built_as_v1()
    self._check_call_args('predict')

    func = self._select_training_loop(x)
    return func.predict(
        self,
        x=x,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

  def reset_metrics(self):
    """Resets the state of metrics."""
    metrics = self._get_training_eval_metrics()
    for m in metrics:
      m.reset_state()

    # Reset metrics on all the distributed (cloned) models.
    if self._distribution_strategy:
      distributed_training_utils_v1._reset_metrics(self)  # pylint: disable=protected-access

  def train_on_batch(self,
                     x,
                     y=None,
                     sample_weight=None,
                     class_weight=None,
                     reset_metrics=True):
    """Runs a single gradient update on a single batch of data.

    Args:
        x: Input data. It could be:
          - A Numpy array (or array-like), or a list of arrays
              (in case the model has multiple inputs).
          - A TensorFlow tensor, or a list of tensors
              (in case the model has multiple inputs).
          - A dict mapping input names to the corresponding array/tensors,
              if the model has named inputs.
          - A `tf.data` dataset.
        y: Target data. Like the input data `x`, it could be either Numpy
          array(s) or TensorFlow tensor(s). It should be consistent with `x`
          (you cannot have Numpy inputs and tensor targets, or inversely). If
          `x` is a dataset, `y` should not be specified
          (since targets will be obtained from the iterator).
        sample_weight: Optional array of the same length as x, containing
          weights to apply to the model's loss for each sample. In the case of
          temporal data, you can pass a 2D array with shape (samples,
          sequence_length), to apply a different weight to every timestep of
          every sample. In this case you should make sure to specify
          sample_weight_mode="temporal" in compile(). This argument is not
          supported when `x` is a dataset.
        class_weight: Optional dictionary mapping class indices (integers) to a
          weight (float) to apply to the model's loss for the samples from this
          class during training. This can be useful to tell the model to "pay
          more attention" to samples from an under-represented class.
        reset_metrics: If `True`, the metrics returned will be only for this
          batch. If `False`, the metrics will be statefully accumulated across
          batches.

    Returns:
        Scalar training loss
        (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the scalar outputs.

    Raises:
      ValueError: In case of invalid user-provided arguments.
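
    Example (a minimal sketch of a custom training loop; `batches` is a
    hypothetical iterable of `(x_batch, y_batch)` NumPy pairs):

    ```python
    for x_batch, y_batch in batches:
      loss_and_metrics = model.train_on_batch(x_batch, y_batch)
    ```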
    """
    self._assert_compile_was_called()
    self._check_call_args('train_on_batch')

    # If at this point we are in the replica context, then it is okay to execute
    # the Eager code path. The expected way to get here is to call `fit`, which
    # calls `train_on_batch` on each replica.
    if (self._distribution_strategy and
        distribution_strategy_context.in_cross_replica_context()):
      raise NotImplementedError('`train_on_batch` is not supported for models '
                                'distributed with tf.distribute.Strategy.')
    # Validate and standardize user data.
    x, y, sample_weights = self._standardize_user_data(
        x, y, sample_weight=sample_weight, class_weight=class_weight,
        extract_tensors_from_dataset=True)

    # If `self._distribution_strategy` is True, then we are in a replica context
    # at this point because of the check above.  `train_on_batch` is being run
    # for each replica by `self._distribution_strategy` and the same code path
    # as Eager is expected to be taken.
    if self.run_eagerly or self._distribution_strategy:
      output_dict = training_eager_v1.train_on_batch(
          self,
          x,
          y,
          sample_weights=sample_weights,
          output_loss_metrics=self._output_loss_metrics)
      outputs = (output_dict['total_loss'] + output_dict['output_losses']
                 + output_dict['metrics'])
      outputs = [_non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
    else:
      x = training_utils_v1.ModelInputs(x).as_list()
      ins = x + list(y or []) + list(sample_weights or [])

      if not isinstance(backend.symbolic_learning_phase(), int):
        ins += [True]  # Add learning phase value.

      self._update_sample_weight_modes(sample_weights=sample_weights)
      self._make_train_function()
      outputs = self.train_function(ins)  # pylint: disable=not-callable

    if reset_metrics:
      self.reset_metrics()

    if len(outputs) == 1:
      return outputs[0]
    return outputs

  def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True):
    """Test the model on a single batch of samples.

    Args:
        x: Input data. It could be:
          - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
          - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
          - A dict mapping input names to the corresponding array/tensors,
            if the model has named inputs.
          - A `tf.data` dataset.
        y: Target data. Like the input data `x`,
          it could be either Numpy array(s) or TensorFlow tensor(s).
          It should be consistent with `x` (you cannot have Numpy inputs and
          tensor targets, or inversely). If `x` is a dataset, `y` should
          not be specified (since targets will be obtained from the iterator).
        sample_weight: Optional array of the same length as x, containing
            weights to apply to the model's loss for each sample.
            In the case of temporal data, you can pass a 2D array
            with shape (samples, sequence_length),
            to apply a different weight to every timestep of every sample.
            In this case you should make sure to specify
            sample_weight_mode="temporal" in compile(). This argument is not
            supported when `x` is a dataset.
        reset_metrics: If `True`, the metrics returned will be only for this
          batch. If `False`, the metrics will be statefully accumulated across
          batches.

    Returns:
        Scalar test loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the scalar outputs.

    Raises:
        ValueError: In case of invalid user-provided arguments.
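
    Example (illustrative; `val_batches` is a hypothetical iterable of
    `(x_batch, y_batch)` pairs, with metrics accumulated across batches):

    ```python
    for x_batch, y_batch in val_batches:
      results = model.test_on_batch(x_batch, y_batch, reset_metrics=False)
    print(dict(zip(model.metrics_names, results)))
    ```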
1136    """
1137    self._assert_compile_was_called()
1138    self._check_call_args('test_on_batch')
1139
1140    if (self._distribution_strategy and
1141        distribution_strategy_context.in_cross_replica_context()):
1142      raise NotImplementedError('`test_on_batch` is not supported for models '
1143                                'distributed with tf.distribute.Strategy.')
1144    # Validate and standardize user data.
1145    x, y, sample_weights = self._standardize_user_data(
1146        x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True)
1147
1148    # If `self._distribution_strategy` is True, then we are in a replica context
1149    # at this point.
1150    if self.run_eagerly or self._distribution_strategy:
1151      output_dict = training_eager_v1.test_on_batch(
1152          self,
1153          x,
1154          y,
1155          sample_weights=sample_weights,
1156          output_loss_metrics=self._output_loss_metrics)
1157      outputs = (output_dict['total_loss'] + output_dict['output_losses']
1158                 + output_dict['metrics'])
1159      outputs = [_non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
1160    else:
1161      x = training_utils_v1.ModelInputs(x).as_list()
1162      inputs = x + list(y or []) + list(sample_weights or [])
1163
1164      self._update_sample_weight_modes(sample_weights=sample_weights)
1165      self._make_test_function()
1166      outputs = self.test_function(inputs)  # pylint: disable=not-callable
1167
1168    if reset_metrics:
1169      self.reset_metrics()
1170
1171    if len(outputs) == 1:
1172      return outputs[0]
1173    return outputs
1174
1175  def predict_on_batch(self, x):
1176    """Returns predictions for a single batch of samples.
1177
1178    Args:
1179        x: Input data. It could be:
1180          - A Numpy array (or array-like), or a list of arrays
1181            (in case the model has multiple inputs).
1182          - A TensorFlow tensor, or a list of tensors
1183            (in case the model has multiple inputs).
1184          - A `tf.data` dataset.
1185
1186    Returns:
1187        Numpy array(s) of predictions.
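
    Example:

        An illustrative sketch, assuming a built single-input model `model`
        whose input accepts batches of shape `(batch_size, 10)`:

            preds = model.predict_on_batch(np.ones((8, 10)))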
1188
1189    Raises:
1190        ValueError: In case of mismatch between given number of inputs and
1191          expectations of the model.
1192    """
1193    self._check_call_args('predict_on_batch')
1194
1195    if (self._distribution_strategy and
1196        distribution_strategy_context.in_cross_replica_context()):
1197      raise NotImplementedError(
1198          '`predict_on_batch` is not supported for models distributed with'
1199          ' tf.distribute.Strategy.')
1200    # Validate and standardize user data.
1201    inputs, _, _ = self._standardize_user_data(
1202        x, extract_tensors_from_dataset=True)
1203    # If `self._distribution_strategy` is set, then we are in a replica context
1204    # at this point.
1205    if self.run_eagerly or self._distribution_strategy:
1206      inputs = training_utils_v1.cast_if_floating_dtype(inputs)
1207      if isinstance(inputs, collections.abc.Sequence):
1208        # Unwrap lists with only one input, as we do when training on batch
1209        if len(inputs) == 1:
1210          inputs = inputs[0]
1211
1212      return self(inputs)  # pylint: disable=not-callable
1213
1214    self._make_predict_function()
1215    outputs = self.predict_function(inputs)
1216
1217    if len(outputs) == 1:
1218      return outputs[0]
1219    return outputs
1220
1221  def fit_generator(self,
1222                    generator,
1223                    steps_per_epoch=None,
1224                    epochs=1,
1225                    verbose=1,
1226                    callbacks=None,
1227                    validation_data=None,
1228                    validation_steps=None,
1229                    validation_freq=1,
1230                    class_weight=None,
1231                    max_queue_size=10,
1232                    workers=1,
1233                    use_multiprocessing=False,
1234                    shuffle=True,
1235                    initial_epoch=0):
1236    """Fits the model on data yielded batch-by-batch by a Python generator.
1237
1238    DEPRECATED:
1239      `Model.fit` now supports generators, so there is no longer any need to use
1240      this endpoint.
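
    Example of the equivalent `Model.fit` call (an illustrative sketch,
    assuming a generator `gen` that yields `(x, y)` batches):

        model.fit(gen, steps_per_epoch=100, epochs=5)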
1241    """
1242    warnings.warn('`model.fit_generator` is deprecated and '
1243                  'will be removed in a future version. '
1244                  'Please use `Model.fit`, which supports generators.')
1245    return self.fit(
1246        generator,
1247        steps_per_epoch=steps_per_epoch,
1248        epochs=epochs,
1249        verbose=verbose,
1250        callbacks=callbacks,
1251        validation_data=validation_data,
1252        validation_steps=validation_steps,
1253        validation_freq=validation_freq,
1254        class_weight=class_weight,
1255        max_queue_size=max_queue_size,
1256        workers=workers,
1257        use_multiprocessing=use_multiprocessing,
1258        shuffle=shuffle,
1259        initial_epoch=initial_epoch)
1260
1261  def evaluate_generator(self,
1262                         generator,
1263                         steps=None,
1264                         callbacks=None,
1265                         max_queue_size=10,
1266                         workers=1,
1267                         use_multiprocessing=False,
1268                         verbose=0):
1269    """Evaluates the model on a data generator.
1270
1271    DEPRECATED:
1272      `Model.evaluate` now supports generators, so there is no longer any need
1273      to use this endpoint.
1274    """
1275    warnings.warn('`Model.evaluate_generator` is deprecated and '
1276                  'will be removed in a future version. '
1277                  'Please use `Model.evaluate`, which supports generators.')
1278    self._check_call_args('evaluate_generator')
1279
1280    return self.evaluate(
1281        generator,
1282        steps=steps,
1283        max_queue_size=max_queue_size,
1284        workers=workers,
1285        use_multiprocessing=use_multiprocessing,
1286        verbose=verbose,
1287        callbacks=callbacks)
1288
1289  def predict_generator(self,
1290                        generator,
1291                        steps=None,
1292                        callbacks=None,
1293                        max_queue_size=10,
1294                        workers=1,
1295                        use_multiprocessing=False,
1296                        verbose=0):
1297    """Generates predictions for the input samples from a data generator.
1298
1299    DEPRECATED:
1300      `Model.predict` now supports generators, so there is no longer any need
1301      to use this endpoint.
1302    """
1303    warnings.warn('`Model.predict_generator` is deprecated and '
1304                  'will be removed in a future version. '
1305                  'Please use `Model.predict`, which supports generators.')
1306    return self.predict(
1307        generator,
1308        steps=steps,
1309        max_queue_size=max_queue_size,
1310        workers=workers,
1311        use_multiprocessing=use_multiprocessing,
1312        verbose=verbose,
1313        callbacks=callbacks)
1314
1315  def _check_call_args(self, method_name):
1316    """Check that `call` has only one positional arg."""
1317    # Always allow first arg, regardless of arg name.
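    # For example, `def call(self, inputs, training=None)` passes this check,
    # while `def call(self, inputs, mask)` fails because `mask` is an extra
    # positional argument without a default value.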
1318    fullargspec = self._call_full_argspec
1319    if fullargspec.defaults:
1320      positional_args = fullargspec.args[:-len(fullargspec.defaults)]
1321    else:
1322      positional_args = fullargspec.args
1323    if 'training' in positional_args:
1324      positional_args.remove('training')
1325
1326    # self and first arg can be positional.
1327    if len(positional_args) > 2:
1328      extra_args = positional_args[2:]
1329      raise ValueError(
1330          'Models passed to `' + method_name + '` can only have `training` '
1331          'and the first argument in `call` as positional arguments, '
1332          'found: ' + str(extra_args) + '.')
1333
1334  def _set_optimizer(self, optimizer):
1335    """Sets self.optimizer.
1336
1337    Sets self.optimizer to `optimizer`, potentially wrapping it with a
1338    LossScaleOptimizer.
1339
1340    Args:
1341      optimizer: The optimizer(s) to assign to self.optimizer.
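
    For example (per the logic below), under a `'mixed_float16'` dtype policy a
    plain `tf.keras.optimizers.SGD` instance ends up wrapped in a dynamic
    `LossScaleOptimizer`, while under a float32 policy it is stored unchanged.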
1342    """
1343    if isinstance(optimizer, (list, tuple)):
1344      self.optimizer = [optimizers.get(opt) for opt in optimizer]
1345    else:
1346      self.optimizer = optimizers.get(optimizer)
1347
1348    if isinstance(self._dtype_policy, policy.PolicyV1):
1349      loss_scale = self._dtype_policy.loss_scale
1350    elif self._dtype_policy.name == 'mixed_float16':
1351      loss_scale = 'dynamic'
1352    else:
1353      loss_scale = None
1354
1355    if (loss_scale is not None and
1356        not isinstance(self.optimizer,
1357                       loss_scale_optimizer.LossScaleOptimizer)):
1358      if isinstance(self.optimizer, list):
1359        raise ValueError('When a dtype policy with a loss scale is used, you '
1360                         'can only pass a single optimizer. Using policy %s '
1361                         'and got optimizers: %s' %
1362                         (self._dtype_policy, self.optimizer))
1363      if not isinstance(self.optimizer, optimizer_v2.OptimizerV2):
1364        raise ValueError('"optimizer" must be an instance of '
1365                         'tf.keras.optimizers.Optimizer when a dtype policy '
1366                         'with a loss scale is used, but got: %s. Using policy: '
1367                         '%s' %
1368                         (self.optimizer, self._dtype_policy))
1369      if loss_scale == 'dynamic':
1370        self.optimizer = loss_scale_optimizer.LossScaleOptimizer(self.optimizer)
1371      else:
1372        self.optimizer = loss_scale_optimizer.LossScaleOptimizerV1(
1373            self.optimizer, loss_scale)
1374
1375  def _prepare_validation_data(self, validation_data, batch_size,
1376                               validation_steps):
1377    """Unpack and check the validation data."""
1378    val_x, val_y, val_sample_weights = training_utils_v1.unpack_validation_data(
1379        validation_data)
1380    return self._standardize_user_data(
1381        val_x,
1382        val_y,
1383        sample_weight=val_sample_weights,
1384        batch_size=batch_size,
1385        steps=validation_steps,
1386        steps_name='validation_steps')
1387
1388  def _validate_compile_param_for_distribution_strategy(
1389      self, run_eagerly, sample_weight_mode, target_tensors, weighted_metrics):
1390    # Validate that arguments passed by the user to `compile` are supported by
1391    # tf.distribute.Strategy.
1392    if self._distribution_strategy:
1393      if sample_weight_mode:
1394        raise NotImplementedError('sample_weight_mode is not supported with '
1395                                  'tf.distribute.Strategy.')
1396      if weighted_metrics:
1397        raise NotImplementedError('weighted_metrics is not supported with '
1398                                  'tf.distribute.Strategy.')
1399      if target_tensors:
1400        raise ValueError('target_tensors is not supported with '
1401                         'tf.distribute.Strategy.')
1402
1403      if run_eagerly:
1404        raise ValueError(
1405            'We currently do not support enabling `run_eagerly` with '
1406            'distribution strategy.')
1407
1408      if (distributed_training_utils_v1.is_distributing_by_cloning(self) and
1409          (not self.built or not self.inputs or not self.outputs)):
1410        raise ValueError(
1411            'We currently do not support distribution strategy with a '
1412            '`Sequential` model that is created without `input_shape`/'
1413            '`input_dim` set in its first layer or a subclassed model.')
1414
1415  def _process_target_tensor_for_compile(self, target_tensors):
1416    if self.run_eagerly:
1417      # Target tensors are not supported with run_eagerly. Create a list with
1418      # None as a placeholder for each output.
1419      return [None for _ in self.output_names]
1420
1421    if target_tensors is not None and not (isinstance(target_tensors, list) and
1422                                           target_tensors == []):  # pylint: disable=g-explicit-bool-comparison
1423      if isinstance(target_tensors, list):
1424        if len(target_tensors) != len(self.outputs):
1425          raise ValueError(
1426              'When passing a list as `target_tensors`, '
1427              'it should have one entry per model output. '
1428              'The model has %s outputs, but you passed target_tensors=%s' %
1429              (len(self.outputs), target_tensors))
1430      elif isinstance(target_tensors, dict):
1431        unexpected_target_tensor_names = set(target_tensors.keys()).difference(
1432            self.output_names)
1433        if unexpected_target_tensor_names:
1434          raise ValueError(
1435              'Unknown entry in `target_tensors` dictionary: "{name}". '
1436              'Only expected the following keys: {keys}'.format(
1437                  name=unexpected_target_tensor_names,
1438                  keys=str(self.output_names)))
1439        tmp_target_tensors = []
1440        for name in self.output_names:
1441          tmp_target_tensors.append(target_tensors.get(name, None))
1442        target_tensors = tmp_target_tensors
1443      elif tensor_util.is_tf_type(target_tensors):
1444        target_tensors = [target_tensors]
1445      else:
1446        raise TypeError('Expected `target_tensors` to be a list or tuple or '
1447                        'dict or a single tensor, but got:', target_tensors)
1448    else:
1449      # In case `target_tensors` is empty or None, create a list of Nones with
1450      # the same length as self.output_names so that the downstream None check
1451      # on target tensors can be skipped.
1452      target_tensors = [None for _ in self.output_names]
1453    return target_tensors
1454
1455  def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode):
1456    # Prepare sample weight modes. List with the same length as model outputs.
1457    training_utils_v1.prepare_sample_weight_modes(
1458        self._training_endpoints, sample_weight_mode)
1459    # Prepare sample weights.
1460    self._prepare_sample_weights()
1461    # Save all metric attributes per output of the model.
1462    self._cache_output_metric_attributes(metrics, weighted_metrics)
1463    self.total_loss = None
1464    # Set metric attributes on model.
1465    self._set_metric_attributes()
1466
1467    self._collected_trainable_weights = self.trainable_weights
1468
1469  def _update_sample_weight_modes(self, sample_weights=None):
1470    """Updates sample weight modes based on training/eval inputs.
1471
1472    Sample weight placeholders will be created for all or no outputs
1473    based on whether sample_weight is provided for any output.
1474
1475    If the model contains `_sample_weight_modes`, we check whether the input
1476    `sample_weights` corresponds to the sample weight modes.
1477      1. Set sample weight mode to be 'temporal' for output i, if `compile`
1478        sample_weight_mode was set to `temporal` and sample weight inputs
1479        are given for one or more outputs.
1480      2. Set sample weight mode to be 'samplewise' for output i, if `compile`
1481        sample_weight_mode was not set and sample weight inputs are given for
1482        one or more outputs.
1483      3. Reset sample weight mode to None for output i if sample weight mode
1484        was set but there is no sample weight input.
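
    For example, for a two-output model compiled without `sample_weight_mode`,
    passing `sample_weights=[w, None]` sets both endpoints to 'samplewise'
    (rule 2), while passing `sample_weights=None` resets both modes to None
    (rule 3).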
1485
1486    Args:
1487      sample_weights: List of sample weights of the same length as model outputs
1488        or None.
1489    """
1490    if not self._is_compiled:
1491      return
1492    if sample_weights and any(s is not None for s in sample_weights):
1493      for endpoint in self._training_endpoints:
1494        endpoint.sample_weight_mode = (
1495            endpoint.sample_weight_mode or 'samplewise')
1496    else:
1497      for endpoint in self._training_endpoints:
1498        endpoint.sample_weight_mode = None
1499
1500  def _recompile_weights_loss_and_weighted_metrics(self):
1501    if not self._is_compiled:
1502      return False
1503    recompile = any(
1504        e.sample_weights_mismatch() for e in self._training_endpoints)
1505
1506    if recompile:
1507      self._compile_weights_loss_and_weighted_metrics()
1508    return recompile
1509
1510  @trackable.no_automatic_dependency_tracking
1511  def _compile_weights_loss_and_weighted_metrics(self, sample_weights=None):
1512    """Compiles the model loss and weighted metric sub-graphs.
1513
1514    This may be used to set graph tensors as sample weights (instead of creating
1515    placeholders). This functionality is necessary for
1516    `tf.keras.estimator.model_to_estimator`, which calls Keras models in a v1
1517    graph, and creates iterator tensors for inputs, targets, and sample weights.
1518
1519    Args:
1520      sample_weights: List of tensors to use as the sample weights. Must be the
1521        same length as the number of outputs. If left as `None`, placeholders
1522        are used instead.
1523    """
1524    with backend.get_graph().as_default():
1525      if sample_weights is not None:
1526        self._update_sample_weight_modes(sample_weights)
1527      self._prepare_sample_weights(sample_weights)
1528
1529      masks = self._prepare_output_masks()
1530
1531      # Compute weighted metrics.
1532      self._handle_metrics(
1533          self.outputs,
1534          targets=self._targets,
1535          skip_target_masks=self._prepare_skip_target_masks(),
1536          sample_weights=self.sample_weights,
1537          masks=masks,
1538          return_weighted_metrics=True)
1539
1540      # Compute total loss.
1541      # Used to keep track of the total loss value (stateless).
1542      # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
1543      #                   loss_weight_2 * output_2_loss_fn(...) +
1544      #                   layer losses.
1545      self.total_loss = self._prepare_total_loss(masks)
1546
1547  def _prepare_skip_target_masks(self):
1548    """Boolean mask for whether the target in the output list should be skipped.
1549
1550    If the loss function corresponding to a model output is None, then this
1551    output will be skipped during total loss calculation and feed targets
1552    preparation.
1553
1554    Returns:
1555      A boolean list for whether the corresponding target in the output list
1556      should be skipped during loss calculation.
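
    For example, if a two-output model was compiled with `loss=[None, 'mse']`,
    this returns `[True, False]`.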
1557    """
1558    return [l is None for l in self.loss_functions]
1559
1560  def _prepare_output_masks(self):
1561    """Returns masks corresponding to model outputs."""
1562    return [getattr(x, '_keras_mask', None) for x in self.outputs]
1563
1564  def _prepare_total_loss(self, masks):
1565    """Computes total loss from loss functions.
1566
1567    Args:
1568        masks: List of mask values corresponding to each model output.
1569
1570    Returns:
1571        A tensor containing the total loss value, or `0.` if there are no losses.
1572
1573    Raises:
1574        TypeError: If model run_eagerly is True.
1575    """
1576    if self.run_eagerly:
1577      raise TypeError('total loss cannot be computed when compiled with '
1578                      'run_eagerly = True.')
1579    loss_list = []
1580    with backend.name_scope('loss'):
1581      for endpoint, mask in zip(self._training_endpoints, masks):
1582        if endpoint.should_skip_target():
1583          continue
1584        y_true = endpoint.training_target.target
1585        y_pred = endpoint.output
1586        loss_fn = endpoint.loss_fn
1587        loss_weight = endpoint.loss_weight
1588        loss_name = endpoint.loss_name()
1589        sample_weight = endpoint.sample_weight
1590
1591        with backend.name_scope(loss_name):
1592          if mask is not None:
1593            mask = math_ops.cast(mask, y_pred.dtype)
1594            # Update weights with mask.
1595            if sample_weight is None:
1596              sample_weight = mask
1597            else:
1598              # Update dimensions of weights to match with mask if possible.
1599              mask, _, sample_weight = (
1600                  losses_utils.squeeze_or_expand_dimensions(
1601                      mask, sample_weight=sample_weight))
1602              sample_weight *= mask
1603
1604          if hasattr(loss_fn, 'reduction'):
1605            per_sample_losses = loss_fn.call(y_true, y_pred)
1606            weighted_losses = losses_utils.compute_weighted_loss(
1607                per_sample_losses,
1608                sample_weight=sample_weight,
1609                reduction=losses_utils.ReductionV2.NONE)
1610            loss_reduction = loss_fn.reduction
1611
1612            # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
1613            # compile use cases.
1614            if loss_reduction == losses_utils.ReductionV2.AUTO:
1615              loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
1616
1617            # Compute the stateless loss value.
1618            output_loss = losses_utils.reduce_weighted_loss(
1619                weighted_losses, reduction=loss_reduction)
1620          else:
1621            # Compute the stateless loss value for a custom loss class.
1622            # Here we assume that the class takes care of loss reduction
1623            # because if this class returns a vector value we cannot
1624            # differentiate between use case where a custom optimizer
1625            # expects a vector loss value vs unreduced per-sample loss value.
1626            output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
1627            loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
1628
1629        if len(self.outputs) > 1:
1630          # Keep track of stateful result tensor for the loss.
1631          endpoint.output_loss_metric(output_loss)
1632
1633        # Scale output loss for distribution. For custom losses we assume
1634        # reduction was mean.
1635        if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
1636          output_loss = losses_utils.scale_loss_for_distribution(output_loss)
1637
1638        loss_list.append(loss_weight * output_loss)
1639      if not loss_list and not self.losses:
1640        raise ValueError('The model cannot be compiled '
1641                         'because it has no loss to optimize.')
1642
1643      # Add regularization penalties and other layer-specific losses.
1644      custom_losses = self.get_losses_for(None) + self.get_losses_for(
1645          self.inputs)
1646      if custom_losses:
1647        total_custom_loss = math_ops.add_n(
1648            losses_utils.cast_losses_to_common_dtype(custom_losses))
1649        loss_list.append(
1650            losses_utils.scale_loss_for_distribution(total_custom_loss))
1651
1652      loss_list = losses_utils.cast_losses_to_common_dtype(loss_list)
1653      if loss_list:
1654        total_loss = math_ops.add_n(loss_list)
1655      else:
1656        total_loss = 0.
1657    return total_loss
1658
1659  def _get_callback_model(self):
1660    """Returns the Callback Model for this Model."""
1661
1662    if hasattr(self, '_replicated_model') and self._replicated_model:
1663      # When using training_distributed, we set the callback model
1664      # to an instance of the `DistributedModel` that we create in
1665      # the `compile` call. The `DistributedModel` is initialized
1666      # with the first replicated model. We need to set the callback
1667      # model to a DistributedModel to allow us to override saving
1668      # and loading weights when we checkpoint the model during training.
1669      return self._replicated_model
1670    if hasattr(self, 'callback_model') and self.callback_model:
1671      return self.callback_model
1672    return self
1673
1674  @trackable.no_automatic_dependency_tracking
1675  def _make_callback_model(self, grouped_model):
1676    first_replicated_model = self._distribution_strategy.unwrap(
1677        grouped_model)[0]
1678    # We initialize the callback model with the first replicated model.
1679    self._replicated_model = DistributedCallbackModel(first_replicated_model)
1680    self._replicated_model.set_original_model(self)
1681
1682  def _validate_or_infer_batch_size(self, batch_size, steps, x):
1683    """Validates that the `batch_size` provided is consistent with InputLayer.
1684
1685    It's possible that the user specified a static batch size in their
1686    InputLayer. If so, this method checks the provided `batch_size` and `x`
1687    arguments are consistent with this static batch size. Also, if
1688    `batch_size` is `None`, this method will attempt to infer the batch size
1689    from the static batch size of the InputLayer. Lastly, ValueError will be
1690    raised if `x` is a tf.data.Dataset and `batch_size` is specified as we
1691    expect users to provide batched datasets.
1692
1693    Args:
1694      batch_size: The batch_size provided as an argument to
1695        fit/evaluate/predict.
1696      steps: The steps provided as an argument to fit/evaluate/predict.
1697      x: The data passed as `x` to fit/evaluate/predict.
1698
1699    Returns:
1700      The validated batch_size, auto-inferred from the first layer if not
1701      provided.
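
    For example, if the InputLayer was built with a static batch size of 32 and
    training runs under a distribution strategy with 4 replicas in sync (and
    global batch sizes), a user-supplied `batch_size` must be 128 (4 * 32); if
    both `batch_size` and `steps` are left unset, 128 is inferred.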
1702    """
1703    if (isinstance(x, (dataset_ops.DatasetV1,
1704                       dataset_ops.DatasetV2,
1705                       data_utils.Sequence)) or
1706        tf_inspect.isgenerator(x)):
1707      if batch_size is not None:
1708        raise ValueError(
1709            'The `batch_size` argument must not be specified for the given '
1710            'input type. Received input: {}, batch_size: {}'.format(
1711                x, batch_size))
1712      return
1713
1714    # Avoids the override in Sequential.layers which filters Input layers.
1715    # (Which are often the very layers that we're after.)
1716    layers = self._flatten_layers(include_self=False, recursive=False)
1717    first_layer = next(layers, None)
1718    if first_layer:
1719      # The per-replica static batch size.
1720      static_batch_size = training_utils.get_static_batch_size(first_layer)
1721      if static_batch_size is not None:
1722
1723        # Determine number of times the user-supplied batch size will be split.
1724        if (self._distribution_strategy and
1725            distributed_training_utils.global_batch_size_supported(
1726                self._distribution_strategy)):
1727          num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync
1728        else:
1729          num_splits_for_ds = 1
1730
1731        # Check `batch_size` argument is consistent with InputLayer.
1732        if batch_size is not None:
1733          if batch_size % num_splits_for_ds != 0:
1734            raise ValueError('The `batch_size` argument ({}) must be divisible '
1735                             'by the number of replicas ({})'.format(
1736                                 batch_size, num_splits_for_ds))
1737          per_replica_batch_size = batch_size // num_splits_for_ds
1738
1739          if per_replica_batch_size != static_batch_size:
1740            raise ValueError('The `batch_size` argument value {} is '
1741                             'incompatible with the specified batch size of '
1742                             'your Input Layer: {}'.format(
1743                                 per_replica_batch_size, static_batch_size))
1744
1745        # Check Dataset/Iterator batch size is consistent with InputLayer.
1746        if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
1747                          iterator_ops.IteratorBase)):
1748          ds_batch_size = tensor_shape.Dimension(
1749              nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
1750          if ds_batch_size is not None:
1751            if ds_batch_size % num_splits_for_ds != 0:
1752              raise ValueError(
1753                  'The batch output shape of your `Dataset` {} '
1754                  'is not divisible by the number of replicas {}'.format(
1755                      ds_batch_size, num_splits_for_ds))
1756
1757            ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds
1758            if ds_per_replica_batch_size != static_batch_size:
1759              raise ValueError('The batch output shape of your `Dataset` is '
1760                               '{}, which is incompatible with the specified '
1761                               'batch size of your Input Layer: {}'.format(
1762                                   ds_per_replica_batch_size,
1763                                   static_batch_size))
1764
1765        # Set inferred batch size from the InputLayer.
1766        if steps is None:
1767          batch_size = static_batch_size * num_splits_for_ds
1768
1769    if batch_size is None and steps is None:
1770      # Backwards compatibility
1771      batch_size = 32
1772    return batch_size
1773
1774  def _prepare_sample_weights(self, sample_weights=None):
1775    """Sets sample weight attribute on the model."""
1776    # List with the same length as model outputs.
1777    if sample_weights is not None:
1778      if len(sample_weights) != len(self._training_endpoints):
1779        raise ValueError('Provided sample weights must have the same length '
1780                         'as the number of outputs. Expected: {}, got: {}.'.format(
1781                             len(self._training_endpoints),
1782                             len(sample_weights)))
1783    else:
1784      sample_weights = [None] * len(self._training_endpoints)
1785    for endpoint, weight in zip(self._training_endpoints, sample_weights):
1786      endpoint.populate_sample_weight(weight, endpoint.sample_weight_mode)
1787
1788  def _cache_output_metric_attributes(self, metrics, weighted_metrics):
1789    """Caches metric name and function attributes for every model output."""
1790    output_shapes = []
1791    for output in self.outputs:
1792      if output is None or output.shape.rank is None:
1793        output_shapes.append(None)
1794      else:
1795        output_shapes.append(output.shape.as_list())
1796    self._per_output_metrics = training_utils_v1.collect_per_output_metric_info(
1797        metrics, self.output_names, output_shapes, self.loss_functions,
1798        from_serialized=self._from_serialized)
1799    self._per_output_weighted_metrics = (
1800        training_utils_v1.collect_per_output_metric_info(
1801            weighted_metrics,
1802            self.output_names,
1803            output_shapes,
1804            self.loss_functions,
1805            from_serialized=self._from_serialized,
1806            is_weighted=True))
1807
1808  def _add_unique_metric_name(self, metric_name, metric_fn, output_index):
1809    """Makes the metric name unique.
1810
1811    If there are multiple outputs for which the metrics are calculated, the
1812    metric names have to be made unique by appending an integer.
1813
1814    Args:
1815      metric_name: Metric name that corresponds to the metric specified by the
1816          user. For example: 'acc'.
1817      metric_fn: The Metric object.
1818      output_index: The index of the model output for which the metric name is
1819        being added.
1820
1821    Returns:
1822      string, name of the model's unique metric name
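
    For example, for a two-output model with output names 'dense' and 'dense_1',
    the metric name 'acc' for output index 1 becomes 'dense_1_acc'; if that name
    is already taken, it becomes 'dense_1_acc_1', and so on.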
1823    """
1824    # For multi-output models, prepend the output names to the metric name.
1825    if len(self.output_names) > 1:
1826      # If we're loading from an already-serialized model, we've already
1827      # prepended the output name, and we don't want to do it again.
1828      #
1829      # Alternatively, we may be receiving a stateless metric (e.g. the string
1830      # "accuracy") rather than a `Metric` object, in which case we want to
1831      # prepend the output name even if we are loading a serialized model.
1832      if not getattr(metric_fn, '_from_serialized', False):
1833        metric_name = '%s_%s' % (self.output_names[output_index], metric_name)
1834
1835    j = 1
1836    base_metric_name = metric_name
1837    while metric_name in self.metrics_names:
1838      metric_name = '%s_%d' % (base_metric_name, j)
1839      j += 1
1840
1841    return metric_name
1842
1843  def _init_metric_attributes(self):
1844    """Initializes model metric attributes."""
1845    # List of stateful metric functions. Used for resetting metric state during
1846    # training/eval.
1847    self._compile_metric_functions = []
1848
1849  def _set_per_output_metric_attributes(self, metrics_dict, output_index):
1850    """Sets the metric attributes on the model for the given output.
1851
1852    Args:
1853      metrics_dict: A dict with metric names as keys and metric fns as values.
1854      output_index: The index of the model output for which the metric
1855        attributes are added.
1856
1857    Returns:
1858      Metrics dict updated with unique metric names as keys.
1859    """
1860    updated_metrics_dict = collections.OrderedDict()
1861    for metric_name, metric_fn in metrics_dict.items():
1862      metric_name = self._add_unique_metric_name(
1863          metric_name, metric_fn, output_index)
1864
1865      # Update the name on the metric class to be the unique generated name.
1866      metric_fn._name = metric_name  # pylint: disable=protected-access
1867      updated_metrics_dict[metric_name] = metric_fn
1868      # Keep track of metric name and function.
1869      self._compile_metric_functions.append(metric_fn)
1870    return updated_metrics_dict
1871
1872  def _set_metric_attributes(self):
1873    """Sets the metric attributes on the model for all the model outputs."""
1874    updated_per_output_metrics = []
1875    updated_per_output_weighted_metrics = []
1876    for i, endpoint in enumerate(self._training_endpoints):
1877      if endpoint.should_skip_target():
1878        updated_per_output_metrics.append(self._per_output_metrics[i])
1879        updated_per_output_weighted_metrics.append(
1880            self._per_output_weighted_metrics[i])
1881        continue
1882      updated_per_output_metrics.append(
1883          self._set_per_output_metric_attributes(self._per_output_metrics[i],
1884                                                 i))
1885      updated_per_output_weighted_metrics.append(
1886          self._set_per_output_metric_attributes(
1887              self._per_output_weighted_metrics[i], i))
1888
1889    # Create a metric wrapper for each output loss. This computes mean of an
1890    # output loss across mini-batches (irrespective of how we reduce within a
1891    # batch).
1892    if len(self._training_endpoints) > 1:
1893      for endpoint in self._training_endpoints:
1894        if not endpoint.should_skip_target():
1895          endpoint.output_loss_metric = metrics_module.Mean(
1896              name=endpoint.loss_name())
1897
1898    self._per_output_metrics = updated_per_output_metrics
1899    self._per_output_weighted_metrics = updated_per_output_weighted_metrics
1900
1901  def _handle_per_output_metrics(self,
1902                                 metrics_dict,
1903                                 y_true,
1904                                 y_pred,
1905                                 mask,
1906                                 weights=None):
1907    """Calls metric functions for a single output.
1908
1909    Args:
1910      metrics_dict: A dict with metric names as keys and metric fns as values.
1911      y_true: Target output.
1912      y_pred: Predicted output.
1913      mask: Computed mask value for the current output.
1914      weights: Weights to be applied on the current output.
1915
1916    Returns:
1917      A list of metric result tensors.
1918    """
1919    metric_results = []
1920    for metric_name, metric_fn in metrics_dict.items():
1921      with backend.name_scope(metric_name):
1922        metric_result = training_utils_v1.call_metric_function(
1923            metric_fn, y_true, y_pred, weights=weights, mask=mask)
1924        metric_results.append(metric_result)
1925    return metric_results
1926
1927  def _handle_metrics(self,
1928                      outputs,
1929                      targets=None,
1930                      skip_target_masks=None,
1931                      sample_weights=None,
1932                      masks=None,
1933                      return_weighted_metrics=False,
1934                      return_weighted_and_unweighted_metrics=False):
1935    """Handles calling metric functions.
1936
1937    Args:
1938      outputs: List of outputs (predictions).
1939      targets: List of targets.
1940      skip_target_masks: Optional. List of booleans indicating whether the
1941        corresponding target should be ignored.
1942      sample_weights: Optional list of sample weight arrays.
1943      masks: List of computed output mask values.
1944      return_weighted_metrics: Flag that indicates whether weighted metrics
1945        should be computed instead of unweighted metrics. This flag is ignored
1946        when `return_weighted_and_unweighted_metrics` is enabled.
1947      return_weighted_and_unweighted_metrics: Flag that is used to indicate
1948        whether both weighted and unweighted metrics should be computed. When
1949        this is not enabled, we use `return_weighted_metrics` param to indicate
1950        whether weighted or unweighted metrics should be returned.
1951
1952    Returns:
1953      A list of metric result tensors.
1954    """
1955    # TODO(scottzhu): Update this to use the new training_endpoints. Currently
1956    # the eager and graph logic is a bit different.
1957    skip_target_masks = skip_target_masks or [False] * len(outputs)
1958    metric_results = []
1959    with backend.name_scope('metrics'):
1960      # Invoke all metrics added using `compile`.
1961      for i in range(len(outputs)):
1962        if skip_target_masks[i]:
1963          continue
1964        output = outputs[i] if outputs else None
1965        target = targets[i] if targets else None
1966        output_mask = masks[i] if masks else None
1967
1968        if (return_weighted_and_unweighted_metrics or
1969            not return_weighted_metrics):
1970          metric_results.extend(
1971              self._handle_per_output_metrics(self._per_output_metrics[i],
1972                                              target, output, output_mask))
1973        if return_weighted_and_unweighted_metrics or return_weighted_metrics:
1974          metric_results.extend(
1975              self._handle_per_output_metrics(
1976                  self._per_output_weighted_metrics[i],
1977                  target,
1978                  output,
1979                  output_mask,
1980                  weights=sample_weights[i] if sample_weights else None))
1981    return metric_results
1982
1983  def _check_trainable_weights_consistency(self):
1984    """Check trainable weights count consistency.
1985
1986    This will raise a warning if `trainable_weights` and
1987    `_collected_trainable_weights` are inconsistent (i.e. have different
1988    number of parameters).
1989    Inconsistency will typically arise when one modifies `model.trainable`
1990    without calling `model.compile` again.
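
    For example, freezing a layer (setting its `trainable` attribute to False)
    after `compile` changes `trainable_weights` but not
    `_collected_trainable_weights`, which triggers this warning on the next
    training call.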
1991    """
1992    if not hasattr(self, '_collected_trainable_weights'):
1993      return
1994
1995    if len(self.trainable_weights) != len(self._collected_trainable_weights):
1996      logging.log_first_n(
1997          logging.WARN, 'Discrepancy between trainable weights and collected'
1998          ' trainable weights, did you set `model.trainable`'
1999          ' without calling `model.compile` afterwards?', 1)
2000
2001  def _make_train_function(self):
2002    has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
2003    self._check_trainable_weights_consistency()
2004    if isinstance(self.optimizer, list):
2005      raise ValueError('The `optimizer` in `compile` should be a single '
2006                       'optimizer.')
2007    # If we have re-compiled the loss/weighted metric sub-graphs then create
2008    # train function even if one exists already. This is because
2009    # `_feed_sample_weights` list has been updated on re-compile.
2010    if getattr(self, 'train_function', None) is None or has_recompiled:
2011      # Restore the compiled trainable state.
2012      current_trainable_state = self._get_trainable_state()
2013      self._set_trainable_state(self._compiled_trainable_state)
2014
2015      inputs = (self._feed_inputs +
2016                self._feed_targets +
2017                self._feed_sample_weights)
2018      if not isinstance(backend.symbolic_learning_phase(), int):
2019        inputs += [backend.symbolic_learning_phase()]
2020
2021      with backend.get_graph().as_default():
2022        with backend.name_scope('training'):
2023          # Training updates
2024          updates = self.optimizer.get_updates(
2025              params=self._collected_trainable_weights, loss=self.total_loss)
2026          # Unconditional updates
2027          updates += self.get_updates_for(None)
2028          # Conditional updates relevant to this model
2029          updates += self.get_updates_for(self.inputs)
2030
2031        metrics = self._get_training_eval_metrics()
2032        metrics_tensors = [
2033            m._call_result for m in metrics if hasattr(m, '_call_result')  # pylint: disable=protected-access
2034        ]
2035
2036      with backend.name_scope('training'):
2037        # Gets loss and metrics. Updates weights at each call.
2038        fn = backend.function(
2039            inputs, [self.total_loss] + metrics_tensors,
2040            updates=updates,
2041            name='train_function',
2042            **self._function_kwargs)
2043        setattr(self, 'train_function', fn)
2044
2045      # Restore the current trainable state
2046      self._set_trainable_state(current_trainable_state)
2047
2048  def _make_test_function(self):
2049    has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
2050    # If we have re-compiled the loss/weighted metric sub-graphs then create
2051    # test function even if one exists already. This is because
2052    # `_feed_sample_weights` list has been updated on re-compile.
2053    if getattr(self, 'test_function', None) is None or has_recompiled:
2054      inputs = (self._feed_inputs +
2055                self._feed_targets +
2056                self._feed_sample_weights)
2057
2058      with backend.get_graph().as_default():
2059        metrics = self._get_training_eval_metrics()
2060        metrics_tensors = [
2061            m._call_result for m in metrics if hasattr(m, '_call_result')  # pylint: disable=protected-access
2062        ]
2063
2064      with backend.name_scope('evaluation'):
2065        updates = self.state_updates
2066        # Return loss and metrics, no gradient updates.
2067        # Does update the network states.
2068        fn = backend.function(
2069            inputs, [self.total_loss] + metrics_tensors,
2070            updates=updates,
2071            name='test_function',
2072            **self._function_kwargs)
2073        setattr(self, 'test_function', fn)
2074
2075  def _make_predict_function(self):
2076    if not hasattr(self, 'predict_function'):
2077      self.predict_function = None
2078    if self.predict_function is None:
2079      inputs = self._feed_inputs
2080      # Gets network outputs. Does not update weights.
2081      # Does update the network states.
2082      kwargs = getattr(self, '_function_kwargs', {})
2083      with backend.name_scope(ModeKeys.PREDICT):
2084        self.predict_function = backend.function(
2085            inputs,
2086            self.outputs,
2087            updates=self.state_updates,
2088            name='predict_function',
2089            **kwargs)
2090
2091  def _make_execution_function(self, mode):
2092    if mode == ModeKeys.TRAIN:
2093      self._make_train_function()
2094      return self.train_function
2095    if mode == ModeKeys.TEST:
2096      self._make_test_function()
2097      return self.test_function
2098    if mode == ModeKeys.PREDICT:
2099      self._make_predict_function()
2100      return self.predict_function
2101
2102  def _distribution_standardize_user_data(self,
2103                                          x,
2104                                          y=None,
2105                                          sample_weight=None,
2106                                          class_weight=None,
2107                                          batch_size=None,
2108                                          validation_split=0,
2109                                          shuffle=False,
2110                                          epochs=1,
2111                                          allow_partial_batch=False):
2112    """Runs validation checks on input and target data passed by the user.
2113
2114    This is called when using tf.distribute.Strategy to train, evaluate or serve
2115    the model.
2116
2117    Args:
2118      x: Input data. A numpy array or `tf.data` dataset.
2119      y: Target data. A numpy array or None if x is a `tf.data` dataset.
2120      sample_weight: An optional sample-weight array passed by the user to
2121        weight the importance of each sample in `x`.
2122      class_weight: An optional class-weight array passed by the user to
2123        weight the importance of samples in `x` based on the class they belong
2124        to, as conveyed by `y`.
2125      batch_size: Integer batch size. If provided, it is used to run additional
2126        validation checks on stateful models.
2127      validation_split: Float between 0 and 1.
2128        Fraction of the training data to be used as validation data.
2129      shuffle: Boolean whether to shuffle the training data before each epoch.
2130      epochs: Integer epochs. If > 1, repeat the numpy training data epochs
2131        times when converting to training dataset.
2132      allow_partial_batch: Boolean whether to allow a smaller final batch
2133        (i.e. not require that all batches have the same size).
2134
2135    Returns:
2136      Dataset instance.
2137
2138    Raises:
2139      ValueError: In case of invalid user-provided data.
2140      RuntimeError: If the model was never compiled.
2141    """
2142    if class_weight:
2143      raise NotImplementedError('`class_weight` is currently not supported '
2144                                'when using tf.distribute.Strategy.')
2145
2146    if (sample_weight is not None and sample_weight.all() and
2147        backend.is_tpu_strategy(self._distribution_strategy)):
2148      raise NotImplementedError('`sample_weight` is currently not supported '
2149                                'when using TPUStrategy.')
2150
2151    # Validates `steps` and `shuffle` arguments right at the beginning
2152    # since we use it to construct the dataset object.
2153    # TODO(anjalisridhar): Remove this check once we refactor the
2154    # _standardize_user_data code path. This check is already present elsewhere
2155    # in the codebase.
2156    if isinstance(x, dataset_ops.DatasetV2):
2157      if shuffle:
2158        training_utils_v1.verify_dataset_shuffled(x)
2159
2160    strategy = self._distribution_strategy
2161    with strategy.scope():
2162      # We should be sure to call get_session() inside the strategy.scope()
2163      # so the strategy can affect the session options.
2164      if ops.executing_eagerly_outside_functions():
2165        session = None
2166      else:
2167        session = backend.get_session()
2168
2169      first_x_value = nest.flatten(x)[0]
2170      if isinstance(first_x_value, np.ndarray):
2171        x = training_utils.list_to_tuple(x)
2172        if y is not None:
2173          y = training_utils.list_to_tuple(y)
2174          if sample_weight is not None:
2175            sample_weight = training_utils.list_to_tuple(sample_weight)
2176            in_tuple = (x, y, sample_weight)
2177          else:
2178            in_tuple = (x, y)
2179        else:
2180          in_tuple = x
2181
2182        ds = strategy.extended.experimental_make_numpy_dataset(in_tuple,
2183                                                               session=session)
2184        if shuffle:
2185          # We want a buffer size that is larger than the batch size provided by
2186          # the user and provides sufficient randomness. Note that larger
2187          # numbers introduce more memory usage based on the size of each
2188          # sample.
2189          ds = ds.shuffle(max(1024, batch_size * 8))
2190        if epochs > 1:
2191          ds = ds.repeat(epochs)
2192
2193        # We need to use the drop_remainder argument to get a known static
2194        # input shape which is required for TPUs.
2195        drop_remainder = (not allow_partial_batch and
2196                          strategy.extended.experimental_require_static_shapes)
2197
2198        # TODO(b/131720208): We still drop remainder here if number of examples
2199        # is divisible by batch size, as sometimes dynamic padder will time out
2200        # with keras.metrics.CategoricalAccuracy() metric.
2201        if backend.is_tpu_strategy(strategy) and not drop_remainder:
2202          dataset_size = first_x_value.shape[0]
2203          if dataset_size % batch_size == 0:
2204            drop_remainder = True
2205
2206        x = ds.batch(batch_size, drop_remainder=drop_remainder)
2207      else:
2208        assert isinstance(x, dataset_ops.DatasetV2)
2209        training_utils_v1.validate_dataset_input(x, y, sample_weight,
2210                                                 validation_split)
2211    return x
2212
2213  def _standardize_user_data(self,
2214                             x,
2215                             y=None,
2216                             sample_weight=None,
2217                             class_weight=None,
2218                             batch_size=None,
2219                             check_steps=False,
2220                             steps_name='steps',
2221                             steps=None,
2222                             validation_split=0,
2223                             shuffle=False,
2224                             extract_tensors_from_dataset=False):
2225    """Runs validation checks on input and target data passed by the user.
2226
2227    Also standardizes the data to lists of arrays, in order.
2228
2229    Also builds and compiles the model on the fly if it is a subclassed model
2230    that has never been called before (and thus has no inputs/outputs).
2231
2232    This is a purely internal method, subject to refactoring at any time.
2233
2234    Args:
2235      x: Input data. It could be:
2236        - A Numpy array (or array-like), or a list of arrays
2237          (in case the model has multiple inputs).
2238        - A TensorFlow tensor, or a list of tensors
2239          (in case the model has multiple inputs).
2240        - A dict mapping input names to the corresponding array/tensors,
2241          if the model has named inputs.
2242        - A `tf.data` dataset.
2243      y: Target data. Like the input data `x`,
2244        it could be either Numpy array(s) or TensorFlow tensor(s).
2245        It should be consistent with `x` (you cannot have Numpy inputs and
2246        tensor targets, or inversely). If `x` is a dataset, `y` should not be
2247        specified (since targets will be obtained from the iterator).
2248      sample_weight: An optional sample-weight array passed by the user to
2249        weight the importance of each sample in `x`.
2250      class_weight: An optional class-weight array passed by the user to
2251        weight the importance of samples in `x` based on the class they belong
2252        to, as conveyed by `y`. If both `sample_weight` and `class_weight` are
2253        provided, the weights are multiplied.
2254      batch_size: Integer batch size. If provided, it is used to run additional
2255        validation checks on stateful models.
2256      check_steps: boolean, True if we want to check for validity of `steps` and
2257        False, otherwise. For example, when we are standardizing one batch of
2258        data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps`
2259        value is not required and we should not check for its validity in these
2260        cases.
2261      steps_name: The public API's parameter name for `steps`.
2262      steps: Integer or `None`. Total number of steps (batches of samples) to
2263        execute.
2264      validation_split: Float between 0 and 1.
2265        Fraction of the training data to be used as validation data.
2266      shuffle: Boolean whether to shuffle the training data before each epoch.
2267      extract_tensors_from_dataset: Boolean. When `x` is a dataset instance,
2268        this indicates whether to extract actual tensors from the dataset or
2269        instead output the dataset instance itself.
2270        Set to True when calling from `train_on_batch`/etc.
2271
2272    Returns:
2273      A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict
2274      or not), target arrays, sample-weight arrays.
2275      If the model's input and targets are symbolic, these lists are empty
2276      (since the model takes no user-provided data, instead the data comes
2277      from the symbolic inputs/targets).
2278
2279    Raises:
2280      ValueError: In case of invalid user-provided data.
2281      RuntimeError: If the model was never compiled.
2282    """
2283    if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
2284      # Graph mode dataset. We'll pass the dataset as-is (unless
2285      # `extract_tensors_from_dataset` is True, in which case we extract
2286      # the tensors from the dataset and output those instead).
2287      training_utils_v1.validate_dataset_input(x, y, sample_weight,
2288                                               validation_split)
2289      if shuffle:
2290        training_utils_v1.verify_dataset_shuffled(x)
2291
2292      is_dataset = True
2293      if extract_tensors_from_dataset:
2294        # We do this for `train_on_batch`/etc.
2295        x, y, sample_weight = training_utils_v1.extract_tensors_from_dataset(x)
2296    elif isinstance(x, iterator_ops.Iterator):
2297      # Graph mode iterator. We extract the symbolic tensors.
2298      training_utils_v1.validate_dataset_input(x, y, sample_weight,
2299                                               validation_split)
2300      iterator = x
2301      x, y, sample_weight = training_utils_v1.unpack_iterator_input(iterator)
2302      is_dataset = True
2303    else:
2304      is_dataset = False
2305
2306    # Validates `steps` argument based on x's type.
2307    if check_steps:
2308      training_utils_v1.check_steps_argument(x, steps, steps_name)
2309
2310    # First, we build the model on the fly if necessary.
2311    if not self.inputs:
2312      all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
2313      is_build_called = True
2314    else:
2315      all_inputs = []
2316      # Whether this is a subclassed model that expects dictionary inputs
2317      # rather than list inputs (e.g. FeatureColumn-based models).
2318      dict_inputs = isinstance(self.inputs, dict)
2319      is_build_called = False
2320      y_input = y
2321
2322    # Second, we compile the model on the fly if necessary, mostly for subclass
2323    # models.
2324    is_compile_called = False
2325    if not self._is_compiled and self.optimizer:
2326      self._compile_from_inputs(all_inputs, y_input, x, y)
2327      is_compile_called = True
2328
2329    # In graph mode, if we had just set inputs and targets as symbolic tensors
2330    # by invoking build and compile on the model respectively, we do not have to
2331    # feed anything to the model. Model already has input and target data as
2332    # part of the graph.
2333    # Note: in this case, `any` and `all` are equivalent since we disallow
2334    # mixed symbolic/value inputs.
2335
2336    # self.run_eagerly is not free to compute, so we want to reuse the value.
2337    run_eagerly = self.run_eagerly
2338
2339    if (not run_eagerly and is_build_called and is_compile_called and
2340        not is_dataset and any(_is_symbolic_tensor(v) for v in all_inputs)):
2341      return [], [], None
2342
2343    return self._standardize_tensors(
2344        x, y, sample_weight,
2345        run_eagerly=run_eagerly,
2346        dict_inputs=dict_inputs,
2347        is_dataset=is_dataset,
2348        class_weight=class_weight,
2349        batch_size=batch_size)
2350
2351  def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,
2352                           is_dataset, class_weight=None, batch_size=None):
2353    if run_eagerly:
2354      # In eager mode, do not do shape validation
2355      # since the network has no input nodes (placeholders) to be fed.
2356      feed_input_names = self.input_names
2357      feed_input_shapes = None
2358    elif not self._is_graph_network:
2359      # Case: symbolic-mode subclassed network. Do not do shape validation.
2360      feed_input_names = self._feed_input_names
2361      feed_input_shapes = None
2362    else:
2363      # Case: symbolic-mode graph network.
2364      # In this case, we run extensive shape validation checks.
2365      feed_input_names = self._feed_input_names
2366      feed_input_shapes = self._feed_input_shapes
2367
2368    # Standardize the inputs.
2369    if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
2370      # TODO(fchollet): run static checks with dataset output shape(s).
2371      x = training_utils_v1.standardize_input_data(
2372          x,
2373          feed_input_names,
2374          feed_input_shapes,
2375          check_batch_axis=False,  # Don't enforce the batch size.
2376          exception_prefix='input')
2377
2378    # Get typespecs for the input data and sanitize it if necessary.
2379    # TODO(momernick): This should be capable of doing full input validation
2380    # at all times - validate that this is so and refactor the standardization
2381    # code.
2382    if isinstance(x, dataset_ops.DatasetV2):
2383      x_shapes = dataset_ops.get_structure(x)
2384      if isinstance(x_shapes, tuple):
2385        # If the output of a Dataset is a tuple, we assume it's either of the
2386        # form (x_data, y_data) or (x_data, y_data, sample_weights). In either
2387        # case, we only care about x_data here.
2388        x_shapes = x_shapes[0]
2389    else:
2390      flat_inputs = nest.flatten(x, expand_composites=False)
2391      flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
2392      converted_x = []
2393      for (a, b) in zip(flat_inputs, flat_expected_inputs):
2394        converted_x.append(_convert_scipy_sparse_tensor(a, b))
2395      x = nest.pack_sequence_as(x, converted_x, expand_composites=False)
2396
2397      def _type_spec_from_value(value):
2398        """Grab type_spec without converting array-likes to tensors."""
2399        if tf_utils.is_extension_type(value):
2400          return value._type_spec  # pylint: disable=protected-access
2401        # Get a TensorSpec for array-like data without
2402        # converting the data to a Tensor
2403        if hasattr(value, 'shape') and hasattr(value, 'dtype'):
2404          return tensor_spec.TensorSpec(value.shape, value.dtype)
2405        else:
2406          return type_spec.type_spec_from_value(value)
2407
2408      x_shapes = nest.map_structure(_type_spec_from_value, x)
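      # Illustrative sketch (hypothetical input): an array-like such as
      #   np.zeros((32, 10), dtype='float32')
      # is summarized as TensorSpec(shape=(32, 10), dtype=tf.float32) without
      # copying the data into a Tensor.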
2409
2410    flat_inputs = nest.flatten(x_shapes, expand_composites=False)
2411    flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
2412    for (a, b) in zip(flat_inputs, flat_expected_inputs):
2413      nest.assert_same_structure(a, b, expand_composites=True)
2414
2415    if y is not None:
2416      # Prepare self._sample_weight_modes. List with the same length as
2417      # model outputs.
2418      training_utils_v1.prepare_sample_weight_modes(self._training_endpoints,
2419                                                    self.sample_weight_mode)
2420      feed_output_names = self._feed_output_names
2421      feed_sample_weight_modes = self._sample_weight_modes
2422      if not self._is_graph_network:
2423        feed_output_shapes = None
2424      else:
2425        feed_output_shapes = self._feed_output_shapes
2426
2427      # Standardize the outputs.
2428      y = training_utils_v1.standardize_input_data(
2429          y,
2430          feed_output_names,
2431          # Don't enforce target shapes to match output shapes.
2432          # Precise checks will be run in `check_loss_and_target_compatibility`.
2433          shapes=None,
2434          check_batch_axis=False,  # Don't enforce the batch size.
2435          exception_prefix='target')
2436
2437      # Generate sample-wise weight values given the `sample_weight` and
2438      # `class_weight` arguments.
2439      sample_weights = training_utils_v1.standardize_sample_weights(
2440          sample_weight, feed_output_names)
2441      class_weights = training_utils_v1.standardize_class_weights(
2442          class_weight, feed_output_names)
2443
2444      sample_weights = [
2445          training_utils_v1.standardize_weights(ref, sw, cw, mode)
2446          for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights,
2447                                         feed_sample_weight_modes)
2448      ]
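      # Illustrative sketch (hypothetical values): with class_weight={0: 1., 1: 5.}
      # and integer targets y=[0, 1, 1] for an output, the resulting per-sample
      # weights for that output are [1., 5., 5.].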
2449      # Check that all arrays have the same length.
2450      if not self._distribution_strategy:
2451        training_utils_v1.check_array_lengths(x, y, sample_weights)
2452        if self._is_graph_network and not run_eagerly:
2453          # Additional checks to avoid users mistakenly using improper loss fns.
2454          training_utils_v1.check_loss_and_target_compatibility(
2455              y, self._feed_loss_fns, feed_output_shapes)
2456
2457      sample_weights, _, _ = training_utils.handle_partial_sample_weights(
2458          y, sample_weights, feed_sample_weight_modes, check_all_flat=True)
2459    else:
2460      y = []
2461      sample_weights = None
2462
2463    if self.stateful and batch_size and not is_dataset:
2464      # Check that, for stateful networks, the number of samples is a
2465      # multiple of the static batch size.
2466      if x[0].shape[0] % batch_size != 0:
2467        raise ValueError('In a stateful network, '
2468                         'you should only pass inputs with '
2469                         'a number of samples that is '
2470                         'divisible by the batch size. Found: ' +
2471                         str(x[0].shape[0]) + ' samples')
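      # Illustrative (hypothetical sizes): a stateful model fit with
      # batch_size=32 accepts 96 samples here (96 % 32 == 0) but rejects 100.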
2472
2473    # If dictionary inputs were provided, we return a dictionary as well.
2474    if dict_inputs and not isinstance(x, (dataset_ops.DatasetV1,
2475                                          dataset_ops.DatasetV2)):
2476      x = dict(zip(feed_input_names, x))
2477    return x, y, sample_weights
2478
2479  def _build_model_with_inputs(self, inputs, targets):
2480    """Build the model (set model inputs/outputs), mainly for subclass model."""
2481    processed_inputs = []
2482    is_dict_inputs = False
2483    orig_inputs = inputs
2484    # We need to use `inputs` to set the model inputs.
2485    # If input data is a dataset iterator in graph mode or if it is an eager
2486    # iterator and only one batch of samples is required, we fetch the data
2487    # tensors from the iterator and then standardize them.
2488    if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
2489      inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset(
2490          inputs)
2491    # We type-check that `inputs` is either a single array or a list/dict
2492    # of arrays, and extract a flat list of inputs from the passed
2493    # structure.
2494    training_utils_v1.validate_input_types(inputs, orig_inputs)
2495
2496    if isinstance(inputs, (list, tuple)):
2497      processed_inputs += list(inputs)
2498    elif isinstance(inputs, dict):
2499      is_dict_inputs = True
2500      keys = sorted(inputs.keys())
2501      processed_inputs = [inputs[k] for k in keys]
2502    else:
2503      processed_inputs.append(inputs)
2504    # Now that we have a flat set of inputs, we make sure that none of them
2505    # are CompositeTensors or CompositeTensorValues of any type (or scipy
2506    # sparse arrays, which we treat as SparseTensor values). We cannot safely
2507    # infer input data from an arbitrary composite tensor, so we don't try -
2508    # users should explicitly add composite tensor inputs to their subclassed
2509    # models.
2510    for input_tensor in processed_inputs:
2511      if training_utils_v1.is_composite_or_composite_value(input_tensor):
2512        # TODO(b/132691975): Document subclass-model CT input handling.
2513        raise ValueError(
2514            'All SparseTensor and RaggedTensor inputs must be explicitly '
2515            'declared using a keras.Input() with sparse=True or ragged=True. '
2516            'We found an undeclared input %s. For Sequential models, please '
2517            'add a keras.Input() as your first Layer. For subclassed models, '
2518            'please call self._set_inputs() on your input set, which you can '
2519            'create using keras.Input() for each input to your model.' %
2520            (input_tensor,))
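    # Illustrative sketch of the remedy suggested above (hypothetical model,
    # not executed here):
    #
    #   sparse_in = tf.keras.Input(shape=(10,), sparse=True)
    #   x = tf.keras.layers.Lambda(tf.sparse.to_dense)(sparse_in)
    #   model = tf.keras.Model(sparse_in, tf.keras.layers.Dense(1)(x))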
2521    # Build the model using the retrieved inputs (value or symbolic).
2522    # If values are generated from a dataset, then in symbolic-mode
2523    # placeholders will be created to match the value shapes.
2524    if isinstance(orig_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
2525                                iterator_ops.Iterator)):
2526      if not self.inputs:
2527        # For subclassed models, a robust input spec is not available so we
2528        # must cast to the model dtype.
2529        inputs = training_utils_v1.cast_if_floating_dtype(inputs, self.dtype)
2530
2531      def create_tensor_spec(t):
2532        return tensor_spec.TensorSpec(t.shape, t.dtype)
2533
2534      cast_inputs = nest.map_structure(create_tensor_spec, inputs)
2535    elif training_utils_v1.has_tensors(inputs):
2536      cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs)
2537    else:
2538      cast_inputs = inputs
2539    self._set_inputs(cast_inputs)
2540    return processed_inputs, targets, is_dict_inputs
2541
2542  def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target):
2543    if target is not None:
2544      # We need to use `y` to set the model targets.
2545      if training_utils_v1.has_tensors(target):
2546        target = training_utils_v1.cast_if_floating_dtype_and_mismatch(
2547            target, self.outputs)
2548      training_utils_v1.validate_input_types(
2549          target, orig_target, allow_dict=False, field_name='target')
2550      if isinstance(target, (list, tuple)):
2551        all_inputs += list(target)
2552      else:
2553        all_inputs.append(target)
2554    # Type check that all inputs are *either* value *or* symbolic.
2555    # TODO(fchollet): this check could be removed in Eager mode?
2556    if any(tensor_util.is_tf_type(v) for v in all_inputs):
2557      if not all(tensor_util.is_tf_type(v) for v in all_inputs):
2558        raise ValueError('Do not pass inputs that mix Numpy arrays and '
2559                         'TensorFlow tensors. '
2560                         'You passed: x=' + str(orig_inputs) +
2561                         '; y=' + str(orig_target))
2562    is_dataset = isinstance(orig_inputs, (dataset_ops.DatasetV1,
2563                                          dataset_ops.DatasetV2,
2564                                          iterator_ops.Iterator))
2565    if is_dataset or context.executing_eagerly():
2566      target_tensors = None
2567    else:
2568      # Handle target tensors if any passed.
2569      if target is not None:
2570        if not isinstance(target, (list, tuple)):
2571          target = [target]
2572        target_tensors = [v for v in target if _is_symbolic_tensor(v)]
2573      else:
2574        target_tensors = None
2575
2576    self.compile(
2577        optimizer=self.optimizer,
2578        loss=self.loss,
2579        metrics=self._compile_metrics,
2580        weighted_metrics=self._compile_weighted_metrics,
2581        loss_weights=self.loss_weights,
2582        target_tensors=target_tensors,
2583        sample_weight_mode=self.sample_weight_mode,
2584        run_eagerly=self.run_eagerly,
2585        experimental_run_tf_function=self._experimental_run_tf_function)
2586
2587  # TODO(omalleyt): Consider changing to a more descriptive function name.
2588  def _set_inputs(self, inputs, outputs=None, training=None):
2589    """Set model's input and output specs based on the input data received.
2590
2591    This is to be used for Model subclasses, which do not know at instantiation
2592    time what their inputs look like.
2593
2594    Args:
2595      inputs: Single array, or list of arrays. The arrays could be placeholders,
2596        Numpy arrays, data tensors, or TensorSpecs.
2597        - if placeholders: the model is built on top of these placeholders,
2598          and we expect Numpy data to be fed for them when calling `fit`/etc.
2599        - if Numpy data or TensorShapes: we create placeholders matching the
2600          TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be
2601          fed for these placeholders when calling `fit`/etc.
2602        - if data tensors: the model is built on top of these tensors.
2603          We do not expect any Numpy data to be provided when calling `fit`/etc.
2604      outputs: None, a data tensor, or a list of tensors. If None, the
2605        outputs will be determined by invoking `self.call()`, otherwise the
2606        provided value will be used.
2607      training: Boolean or None. Only relevant in symbolic mode. Specifies
2608        whether to build the model's graph in inference mode (False), training
2609        mode (True), or using the Keras learning phase (None).
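
    Example (an illustrative sketch; `MyModel` is a hypothetical subclassed
    model):

      model = MyModel()
      model._set_inputs(np.zeros((1, 4), dtype='float32'))
      # `model.inputs` and `model.input_names` are now set from the array.
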
2610    Raises:
2611      ValueError: If dict inputs are passed to a Sequential Model where the
2612        first layer isn't a FeatureLayer.
2613    """
2614    self._set_save_spec(inputs)
2615    inputs = self._set_input_attrs(inputs)
2616
2617    if outputs is None:
2618      kwargs = {}
2619      if self._expects_training_arg:
2620        # In V2 mode, feeding `training=None` is not allowed because any value
2621        # explicitly passed by the user is respected, even `None`.
2622        if training is None and not ops.executing_eagerly_outside_functions():
2623          training = backend.learning_phase()
2624        if training is not None:
2625          kwargs['training'] = training
2626      try:
2627        outputs = self(inputs, **kwargs)
2628      except NotImplementedError:
2629        # This Model or a submodel is dynamic and hasn't overridden
2630        # `compute_output_shape`.
2631        outputs = None
2632
2633    self._set_output_attrs(outputs)
2634
2635  @trackable.no_automatic_dependency_tracking
2636  def _set_input_attrs(self, inputs):
2637    """Sets attributes related to the inputs of the Model."""
2638    if self.inputs:
2639      raise ValueError('Model inputs are already set.')
2640
2641    if self.__class__.__name__ == 'Sequential' and not self.built:
2642      if tensor_util.is_tf_type(inputs):
2643        input_shape = (None,) + tuple(inputs.shape.as_list()[1:])
2644      elif isinstance(inputs, tensor_shape.TensorShape):
2645        input_shape = (None,) + tuple(inputs.as_list()[1:])
2646      elif isinstance(inputs, dict):
2647        # We assert that the first layer is a FeatureLayer.
2648        if not training_utils_v1.is_feature_layer(self.layers[0]):
2649          raise ValueError('Passing a dictionary input to a Sequential Model '
2650                           'which doesn\'t have FeatureLayer as the first layer'
2651                           ' is an error.')
2652        input_shape = (None,)
2653      else:
2654        input_shape = (None,) + tuple(inputs.shape[1:])
2655      self._build_input_shape = input_shape
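      # Illustrative (hypothetical shapes): a NumPy batch of shape (32, 784)
      # yields _build_input_shape == (None, 784); the static batch size is
      # intentionally dropped.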
2656
2657    # Cast inputs to the compute dtype. This is primarily used
2658    # when saving to determine the correct dtype in the input signature.
2659    inputs = self._maybe_cast_inputs(inputs)
2660
2661    # On-the-fly setting of symbolic model inputs (either by using the tensor
2662    # provided, or by creating a placeholder if Numpy data was provided).
2663    model_inputs = training_utils_v1.ModelInputs(inputs)
2664    inputs = model_inputs.get_symbolic_inputs()
2665    self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
2666    self.input_names = model_inputs.get_input_names()
2667
2668    self._feed_inputs = []
2669    self._feed_input_names = []
2670    self._feed_input_shapes = []
2671
2672    for k, v in model_inputs.as_dict():
2673      if backend.is_placeholder(v):
2674        self._feed_input_names.append(k)
2675        self._feed_inputs.append(v)
2676        self._feed_input_shapes.append(backend.int_shape(v))
2677
2678    return inputs
2679
2680  @trackable.no_automatic_dependency_tracking
2681  def _set_output_attrs(self, outputs):
2682    """Sets attributes related to the outputs of the Model."""
2683    # NOTE(taylorrobie): This convention cannot be changed without updating the
2684    #                    data adapter since it assumes nest.flatten ordering.
2685    outputs = nest.flatten(outputs)
2686    self.outputs = outputs
2687    self.output_names = training_utils_v1.generic_output_names(outputs)
2688    # TODO(scottzhu): Should we cleanup the self._training_endpoints here?
2689    self.built = True
2690
2691  @property
2692  def _targets(self):
2693    """The output target tensors for the model."""
2694    return [
2695        e.training_target.target
2696        for e in self._training_endpoints
2697        if e.has_training_target()
2698    ]
2699
2700  @property
2701  def _feed_targets(self):
2702    return [
2703        e.training_target.target
2704        for e in self._training_endpoints
2705        if e.has_feedable_training_target()
2706    ]
2707
2708  @property
2709  def _feed_output_names(self):
2710    return [
2711        e.output_name
2712        for e in self._training_endpoints
2713        if e.has_feedable_training_target()
2714    ]
2715
2716  @property
2717  def _feed_output_shapes(self):
2718    return [
2719        e.feed_output_shape
2720        for e in self._training_endpoints
2721        if e.has_feedable_training_target()
2722    ]
2723
2724  @property
2725  def _feed_loss_fns(self):
2726    return [
2727        e.loss_fn
2728        for e in self._training_endpoints
2729        if e.has_feedable_training_target()
2730    ]
2731
2732  @property
2733  def _loss_weights_list(self):
2734    return [e.loss_weight for e in self._training_endpoints]
2735
2736  @property
2737  def _output_loss_metrics(self):
2738    if hasattr(self, '_training_endpoints'):
2739      return [
2740          e.output_loss_metric
2741          for e in self._training_endpoints
2742          if e.output_loss_metric is not None
2743      ]
2744    return None
2745
2746  @property
2747  def sample_weights(self):
2748    return [e.sample_weight for e in self._training_endpoints]
2749
2750  @property
2751  def _sample_weight_modes(self):
2752    return [e.sample_weight_mode for e in self._training_endpoints]
2753
2754  @property
2755  def _feed_sample_weights(self):
2756    return [e.sample_weight for e in self._training_endpoints
2757            if e.sample_weight is not None]
2758
2759  def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
2760    """Maybe load initial epoch from ckpt considering possible worker recovery.
2761
2762    Refer to tensorflow/python/keras/distribute/worker_training_state.py
2763    for more information.
2764
2765    Args:
2766      initial_epoch: The original initial_epoch the user passed to `fit()`.
2767      mode: The mode for running `model.fit()`.
2768
2769    Returns:
2770      If the training is recovering from a previous failure under a multi-worker
2771      training setting, return the epoch at which the training is supposed to
2772      continue. Otherwise, return the `initial_epoch` the user passed in.
2773    """
2774    if self._training_state is not None:
2775      return self._training_state.maybe_load_initial_epoch_from_ckpt(
2776          initial_epoch, mode)
2777    return initial_epoch
2778
2779  def _get_training_eval_metrics(self):
2780    """Returns all the metrics that are to be reported.
2781
2782    This includes the output loss metrics, the compile metrics/weighted
2783    metrics, and the `add_metric` metrics.
2784    """
2785    metrics = []
2786    metrics.extend(getattr(self, '_output_loss_metrics', None) or [])
2787    metrics.extend(getattr(self, 'metrics', None) or [])
2788    return metrics
2789
2790  def _assert_compile_was_called(self):
2791    # Checks whether `compile` has been called. If it has been called, then
2792    # the optimizer is set. This is different from whether the model is
2793    # compiled (i.e. whether the model is built and its inputs/outputs are
2794    # set).
2795    if not self._compile_was_called:
2796      raise RuntimeError('You must compile your model before '
2797                         'training/testing. '
2798                         'Use `model.compile(optimizer, loss)`.')
2799
2800  def _in_multi_worker_mode(self):
2801    """Method to infer if this `Model` is working in multi-worker settings.
2802
2803    Multi-worker training refers to the setup where the training is
2804    distributed across multiple workers, as opposed to the case where
2805    only a local process performs the training. This function is used to
2806    infer, for example, whether or not a distribute coordinator should be
2807    run (and thus whether TensorFlow servers should be started for
2808    communication with other servers in the cluster), or whether or not
2809    saving/restoring checkpoints is relevant for preemption fault tolerance.
2810
2811    Experimental. Signature and implementation are subject to change.
2812
2813    Returns:
2814      Whether this model indicates it's working in multi-worker settings.
2815    """
2816    strategy = self._distribution_strategy
2817
2818    # If the model has no strategy, use the strategy whose scope this is in.
2819    if not strategy and distribution_strategy_context.has_strategy():
2820      strategy = distribution_strategy_context.get_strategy()
2821    return strategy and strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2822
2823  @property
2824  def _trackable_saved_model_saver(self):
2825    return model_serialization.ModelSavedModelSaver(self)
2826
2827  def _get_compile_args(self, user_metrics=True):
2828    del user_metrics
2829    self._assert_compile_was_called()
2830    kwargs = {
2831        'loss': self.loss,
2832        'metrics': self._compile_metrics,
2833        'loss_weights': self.loss_weights,
2834        'sample_weight_mode': self.sample_weight_mode,
2835        'weighted_metrics': self._compile_weighted_metrics,
2836    }
2837    return kwargs
2838
2839  @property
2840  def _compile_was_called(self):
2841    return self._v1_compile_was_called
2842
2843
2844class DistributedCallbackModel(Model):
2845  """Model that is used for callbacks with tf.distribute.Strategy."""
2846
2847  def __init__(self, model):
2848    super(DistributedCallbackModel, self).__init__()
2849    self.optimizer = model.optimizer
2850
2851  def set_original_model(self, orig_model):
2852    self._original_model = orig_model
2853
2854  def save_weights(self, filepath, overwrite=True, save_format=None):
2855    self._replicated_model.save_weights(filepath, overwrite=overwrite,
2856                                        save_format=save_format)
2857
2858  def save(self, filepath, overwrite=True, include_optimizer=True):
2859    # Save weights from the distributed model to the original model.
2860    distributed_model_weights = self.get_weights()
2861    self._original_model.set_weights(distributed_model_weights)
2862    # TODO(anjalisridhar): Do we need to save the original model here?
2863    # Saving the first replicated model works as well.
2864    self._original_model.save(filepath, overwrite=True, include_optimizer=False)
2865
2866  def load_weights(self, filepath, by_name=False):
2867    self._original_model.load_weights(filepath, by_name=by_name)
2868    # Copy the weights from the original model to each of the replicated models.
2869    orig_model_weights = self._original_model.get_weights()
2870    distributed_training_utils_v1.set_weights(
2871        self._original_model._distribution_strategy, self,  # pylint: disable=protected-access
2872        orig_model_weights)
2873
2874  def __getattr__(self, item):
2875    # Allowed attributes of the model that can be accessed by the user
2876    # during a callback.
2877    if item not in ('_setattr_tracking', '_layers'):
2878      logging.warning('You are accessing attribute ' + item + ' of the '
2879                      'DistributedCallbackModel that may not have been set '
2880                      'correctly.')
2881    return super(DistributedCallbackModel, self).__getattr__(item)
2882
2883
2884class _TrainingEndpoint(object):
2885  """A container for the training output/target and related entities.
2886
2887  In the case of a model with multiple outputs, there is a one-to-one mapping
2888  between model output (y_pred), model target (y_true), loss, metrics, etc.
2889  By unifying these entities into one class, each entity can access information
2890  about the others, rather than reaching into separate per-output lists of
2891  attributes on the model.
2892  """
2893
2894  def __init__(self,
2895               output,
2896               output_name,
2897               loss_fn,
2898               loss_weight=None,
2899               training_target=None,
2900               output_loss_metric=None,
2901               sample_weight=None,
2902               sample_weight_mode=None):
2903    """Initialize the _TrainingEndpoint.
2904
2905    Note that the output and output_name should be stable as long as the model
2906    structure doesn't change. The training_target is expected to be mutable
2907    since the information is provided via `compile()`.
2908
2909    Args:
2910      output: the output tensor of the model.
2911      output_name: the unique name of the output tensor.
2912      loss_fn: the loss function for the output tensor.
2913      loss_weight: float, the weight for the loss.
2914      training_target: the _TrainingTarget for the model.
2915      output_loss_metric: the metric object for the loss function.
2916      sample_weight: the weights for how a sample is weighted during metric and
2917        loss calculation. Could be None.
2918      sample_weight_mode: string, 'temporal', 'samplewise' or None. The mode for
2919        how the sample_weight is populated.
2920    """
2921    self._output = output
2922    self._output_name = output_name
2923    self._loss_fn = loss_fn
2924    self._loss_weight = loss_weight
2925    self._training_target = training_target
2926    self._output_loss_metric = output_loss_metric
2927    self._sample_weight = sample_weight
2928    self._sample_weight_mode = sample_weight_mode
2929
2930  @property
2931  def output(self):
2932    return self._output
2933
2934  @property
2935  def output_name(self):
2936    return self._output_name
2937
2938  @property
2939  def shape(self):
2940    return backend.int_shape(self.output)
2941
2942  @property
2943  def loss_fn(self):
2944    return self._loss_fn
2945
2946  @property
2947  def loss_weight(self):
2948    return self._loss_weight
2949
2950  @loss_weight.setter
2951  def loss_weight(self, value):
2952    self._loss_weight = value
2953
2954  @property
2955  def training_target(self):
2956    return self._training_target
2957
2958  @training_target.setter
2959  def training_target(self, value):
2960    self._training_target = value
2961
2962  def create_training_target(self, target, run_eagerly=False):
2963    """Create training_target instance and update the self.training_target.
2964
2965    Note that the input target should just be a tensor or None, and the
2966    corresponding training target will be created based on the output and
2967    loss_fn.
2968
2969    Args:
2970      target: the target tensor for the current output. Could be None.
2971      run_eagerly: boolean, whether the model is in run_eagerly mode.
2972
2973    Raises:
2974      ValueError: If the training_target field for the current instance has
2975      already been populated.
2976    """
2977    if self.has_training_target():
2978      raise ValueError('The training_target field for the _TrainingEndpoint '
2979                       'instance has already been populated')
2980    if run_eagerly:
2981      # When run_eagerly, the target tensor is ignored, and a None placeholder
2982      # is used instead.
2983      self.training_target = _TrainingTarget(
2984          None, feedable=True, skip_target_weights=False)
2985      return
2986
2987    if self.should_skip_target():
2988      self.training_target = _TrainingTarget(None)
2989    else:
2990      if target is not None and not backend.is_placeholder(target):
2991        feedable = False
2992        skip_target_weights = True
2993      else:
2994        feedable = True
2995        skip_target_weights = False
2996
2997      if target is None:
2998        target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get(
2999            self.loss_fn, backend.dtype(self.output))
3000
3001        target = backend.placeholder(
3002            ndim=len(self.shape),
3003            name=self.output_name + '_target',
3004            sparse=backend.is_sparse(self.output),
3005            dtype=target_dtype)
3006
3007      self.training_target = _TrainingTarget(
3008          target,
3009          feedable=feedable,
3010          skip_target_weights=skip_target_weights)
3011
3012  @property
3013  def output_loss_metric(self):
3014    return self._output_loss_metric
3015
3016  @output_loss_metric.setter
3017  def output_loss_metric(self, value):
3018    self._output_loss_metric = value
3019
3020  @property
3021  def sample_weight(self):
3022    return self._sample_weight
3023
3024  @sample_weight.setter
3025  def sample_weight(self, value):
3026    self._sample_weight = value
3027
3028  @property
3029  def sample_weight_mode(self):
3030    return self._sample_weight_mode
3031
3032  @sample_weight_mode.setter
3033  def sample_weight_mode(self, value):
3034    self._sample_weight_mode = value
3035
3036  def should_skip_target(self):
3037    return self._loss_fn is None
3038
3039  def should_skip_target_weights(self):
3040    return (self.should_skip_target() or self.training_target is None or
3041            self.training_target.skip_target_weights)
3042
3043  def has_training_target(self):
3044    return self.training_target is not None
3045
3046  def has_feedable_training_target(self):
3047    return (not self.should_skip_target() and
3048            self.training_target is not None and self.training_target.feedable)
3049
3050  def loss_name(self):
3051    if self._loss_fn is not None:
3052      return self._output_name + '_loss'
3053    return None
3054
3055  @property
3056  def feed_output_shape(self):
3057    """The output shape for the feedable target."""
3058    if not self.has_feedable_training_target():
3059      return None
3060
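    # Illustrative sketch (hypothetical shapes): for an output of shape
    # (None, 10) trained with sparse categorical crossentropy, targets are fed
    # as integer class ids of shape (None, 1) rather than one-hot (None, 10),
    # which is what the first branch below computes.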
3061    if ((isinstance(self.loss_fn, losses.LossFunctionWrapper) and
3062         self.loss_fn.fn == losses.sparse_categorical_crossentropy)) or (
3063             isinstance(self.loss_fn, losses.SparseCategoricalCrossentropy)):
3064      if backend.image_data_format() == 'channels_first':
3065        return (self.shape[0], 1) + self.shape[2:]
3066      else:
3067        return self.shape[:-1] + (1,)
3068    elif (not isinstance(self.loss_fn, losses.Loss) or
3069          (isinstance(self.loss_fn, losses.LossFunctionWrapper) and
3070           (getattr(losses, self.loss_fn.fn.__name__, None) is None))):
3071      # If the given loss is not an instance of the `Loss` class (custom
3072      # class) or if the loss function that is wrapped is not in the
3073      # `losses` module, then it is a user-defined loss and we make no
3074      # assumptions about it.
3075      return None
3076    else:
3077      return self.shape
3078
3079  def sample_weights_mismatch(self):
3080    """Check if the sample weight and the mode match or not."""
3081    # If there is a mismatch between sample weight mode and the placeholders
3082    # created, then recompile the sub-graphs that depend on sample weights.
3083    return (
3084        (self.sample_weight_mode is not None and self.sample_weight is None) or
3085        (self.sample_weight_mode is None and self.sample_weight is not None))
3086
3087  def populate_sample_weight(self, sample_weight, sample_weight_mode):
3088    """Populate the sample weight and based on the sample weight mode."""
3089    if (sample_weight is None and
3090        (self.should_skip_target_weights() or sample_weight_mode is None or
3091         context.executing_eagerly())):
3092      self._sample_weight = None
3093      return
3094
3095    assert sample_weight_mode in ['temporal', 'samplewise']
3096    if sample_weight_mode == 'temporal':
3097      default_value = [[1.]]
3098      shape = [None, None]
3099    else:
3100      # sample_weight_mode == 'samplewise'
3101      default_value = [1.]
3102      shape = [None]
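    # Illustrative sketch (hypothetical values, assuming the model was compiled
    # with sample_weight_mode='temporal'): fit() then expects one weight per
    # timestep per sample, e.g.
    #
    #   sw = np.ones((batch_size, timesteps))
    #   model.fit(x, y, sample_weight=sw)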
3103
3104    if sample_weight is not None:
3105      if not sample_weight.shape.is_compatible_with(shape):
3106        raise ValueError('Received sample weight with shape {}. Expected shape '
3107                         '{}.'.format(sample_weight.shape, shape))
3108      self._sample_weight = sample_weight
3109    else:
3110      self._sample_weight = array_ops.placeholder_with_default(
3111          constant_op.constant(default_value, dtype=backend.floatx()),
3112          shape=shape,
3113          name=self.output_name + '_sample_weights')
3114
3115
3116class _TrainingTarget(object):
3117  """Container for a target tensor (y_true) and its metadata (shape, loss...).
3118
3119  Args:
3120    target: A target tensor for the model. It may be `None` if the
3121      output is excluded from loss computation; the entry is still kept (as
3122      `None`) so that each output of the model has a corresponding target. If
3123      the target is `None`, the rest of the attributes will be `None` as well.
3124    feedable: Boolean, whether the target is feedable (requires data to be
3125      passed in `fit` or `train_on_batch`), or not (model compiled with
3126      `target_tensors` argument).
3127    skip_target_weights: Boolean, whether the target should be skipped during
3128      weights calculation.
3129  """
3130
3131  def __init__(self, target, feedable=False, skip_target_weights=True):
3132    self._target = target
3133    self._feedable = feedable
3134    self._skip_target_weights = skip_target_weights
3135
3136  @property
3137  def target(self):
3138    return self._target
3139
3140  @property
3141  def feedable(self):
3142    return self._feedable
3143
3144  @property
3145  def skip_target_weights(self):
3146    return self._skip_target_weights
3147
3148
3149def _is_symbolic_tensor(x):
3150  return tensor_util.is_tf_type(x)
3151
3152
3153def _convert_scipy_sparse_tensor(value, expected_input):
3154  """Handle scipy sparse tensor conversions.
3155
3156  This method takes a value 'value' and returns the proper conversion. If
3157  value is a scipy sparse tensor and the expected input is a dense tensor,
3158  we densify 'value'. If value is a scipy sparse tensor and the expected input
3159  is a TF SparseTensor, we convert 'value' to a SparseTensor. If 'value' is
3160  not a scipy sparse tensor, or scipy is not imported, we pass it through
3161  unchanged.
3162
3163  Args:
3164    value: An object that may be a scipy sparse tensor
3165    expected_input: The expected input placeholder.
3166
3167  Returns:
3168    The possibly-converted 'value'.
3169  """
3170  if issparse is not None and issparse(value):
3171    if backend.is_sparse(expected_input):
3172      sparse_coo = value.tocoo()
3173      row, col = sparse_coo.row, sparse_coo.col
3174      data, shape = sparse_coo.data, sparse_coo.shape
3175      indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)),
3176                               1)
3177      return sparse_tensor.SparseTensor(indices, data, shape)
3178    else:
3179      if ops.executing_eagerly_outside_functions():
3180        # In TF2 we do not silently densify sparse matrices.
3181        raise ValueError('A SciPy sparse matrix was passed to a model '
3182                         'that expects dense inputs. Please densify your '
3183                         'inputs first, such as by calling `x.toarray()`.')
3184      return value.toarray()
3185  else:
3186    return value
3187
3188
3189def _get_metrics_from_layers(layers):
3190  """Returns list of metrics from the given layers.
3191
3192  This will not include the `compile` metrics of a model layer.
3193
3194  Args:
3195    layers: List of layers.
3196
3197  Returns:
3198    List of metrics.
3199  """
3200  metrics = []
3201  layers = layer_utils.filter_empty_layer_containers(layers)
3202  for layer in layers:
3203    if isinstance(layer, Model):
3204      # We cannot call 'metrics' on the model because we do not want to
3205      # include the metrics added via the compile API of a nested model.
3206      metrics.extend(layer._metrics)  # pylint: disable=protected-access
3207      metrics.extend(_get_metrics_from_layers(layer.layers))
3208    else:
3209      metrics.extend(layer.metrics)
3210  return metrics
3211
3212
3213def _non_none_constant_value(v):
3214  constant_value = tensor_util.constant_value(v)
3215  return constant_value if constant_value is not None else v
3216