xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/callbacks.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-import-not-at-top
16# pylint: disable=g-classes-have-attributes
17"""Callbacks: utilities called at certain points during model training."""
18
19import collections
20import copy
21import csv
22import json
23import os
24import re
25import sys
26import time
27
28import numpy as np
29
30from tensorflow.core.framework import summary_pb2
31from tensorflow.python.checkpoint import checkpoint_management
32from tensorflow.python.checkpoint import checkpoint_options as checkpoint_options_lib
33from tensorflow.python.data.ops import iterator_ops
34from tensorflow.python.distribute import collective_all_reduce_strategy
35from tensorflow.python.distribute import distribution_strategy_context as ds_context
36from tensorflow.python.distribute import mirrored_strategy
37from tensorflow.python.distribute import parameter_server_strategy_v2
38from tensorflow.python.distribute import tpu_strategy
39from tensorflow.python.eager import context
40from tensorflow.python.framework import constant_op
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import errors
43from tensorflow.python.framework import ops
44from tensorflow.python.keras import backend
45from tensorflow.python.keras.distribute import distributed_file_utils
46from tensorflow.python.keras.distribute import worker_training_state
47from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
48from tensorflow.python.keras.utils import generic_utils
49from tensorflow.python.keras.utils import tf_utils
50from tensorflow.python.keras.utils import version_utils
51from tensorflow.python.keras.utils.data_utils import Sequence
52from tensorflow.python.keras.utils.generic_utils import Progbar
53from tensorflow.python.keras.utils.io_utils import path_to_string
54from tensorflow.python.keras.utils.mode_keys import ModeKeys
55from tensorflow.python.lib.io import file_io
56from tensorflow.python.ops import array_ops
57from tensorflow.python.ops import math_ops
58from tensorflow.python.ops import summary_ops_v2
59from tensorflow.python.platform import gfile
60from tensorflow.python.platform import tf_logging as logging
61from tensorflow.python.profiler import profiler_v2 as profiler
62from tensorflow.python.saved_model import save_options as save_options_lib
63from tensorflow.python.util import nest
64from tensorflow.python.util.tf_export import keras_export
65from tensorflow.tools.docs import doc_controls
66
67try:
68  import requests
69except ImportError:
70  requests = None
71
72
73# Note: `configure_callbacks` is only used in TF1.
74def configure_callbacks(callbacks,
75                        model,
76                        do_validation=False,
77                        batch_size=None,
78                        epochs=None,
79                        steps_per_epoch=None,
80                        samples=None,
81                        verbose=1,
82                        count_mode='steps',
83                        mode=ModeKeys.TRAIN):
84  """Configures callbacks for use in various training loops.
85
86  Args:
87      callbacks: List of Callbacks.
88      model: Model being trained.
89      do_validation: Whether or not validation loop will be run.
90      batch_size: Number of samples per batch.
91      epochs: Number of epochs to train.
92      steps_per_epoch: Number of batches to run per training epoch.
93      samples: Number of training samples.
94      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
95      count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count.
96      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
97        Which loop mode to configure callbacks for.
98
99  Returns:
100      Instance of CallbackList used to control all Callbacks.
101  """
102  # Check if callbacks have already been configured.
103  if isinstance(callbacks, CallbackList):
104    return callbacks
105
106  if not callbacks:
107    callbacks = []
108
109  # Add additional callbacks during training.
110  if mode == ModeKeys.TRAIN:
111    model.history = History()
112    callbacks = [BaseLogger()] + (callbacks or []) + [model.history]
113    if verbose:
114      callbacks.append(ProgbarLogger(count_mode))
115  callback_list = CallbackList(callbacks)
116
117  # Set callback model
118  callback_model = model._get_callback_model()  # pylint: disable=protected-access
119  callback_list.set_model(callback_model)
120
121  set_callback_parameters(
122      callback_list,
123      model,
124      do_validation=do_validation,
125      batch_size=batch_size,
126      epochs=epochs,
127      steps_per_epoch=steps_per_epoch,
128      samples=samples,
129      verbose=verbose,
130      mode=mode)
131
132  callback_list.model.stop_training = False
133  return callback_list
134
135
136def set_callback_parameters(callback_list,
137                            model,
138                            do_validation=False,
139                            batch_size=None,
140                            epochs=None,
141                            steps_per_epoch=None,
142                            samples=None,
143                            verbose=1,
144                            mode=ModeKeys.TRAIN):
145  """Sets callback parameters.
146
147  Args:
148      callback_list: CallbackList instance.
149      model: Model being trained.
150      do_validation: Whether or not validation loop will be run.
151      batch_size: Number of samples per batch.
152      epochs: Number of epochs to train.
153      steps_per_epoch: Number of batches to run per training epoch.
154      samples: Number of training samples.
155      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
156      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
157        Which loop mode to configure callbacks for.
158  """
159  metric_names = model.metrics_names
160  for cbk in callback_list:
161    if isinstance(cbk, (BaseLogger, ProgbarLogger)):
162      cbk.stateful_metrics = metric_names[1:]  # Exclude `loss`
163
164  # Set callback parameters
165  callback_metrics = []
166  # In the deferred-build scenario with iterator input, compilation happens
167  # when the first batch of data is standardized.
168  if mode != ModeKeys.PREDICT:
169    callback_metrics = copy.copy(metric_names)
170    if do_validation:
171      callback_metrics += ['val_' + n for n in metric_names]
172  callback_params = {
173      'batch_size': batch_size,
174      'epochs': epochs,
175      'steps': steps_per_epoch,
176      'samples': samples,
177      'verbose': verbose,
178      'do_validation': do_validation,
179      'metrics': callback_metrics,
180  }
181  callback_list.set_params(callback_params)
182
183
184def _is_generator_like(data):
185  """Checks if data is a generator, Sequence, or Iterator."""
186  return (hasattr(data, '__next__') or hasattr(data, 'next') or isinstance(
187      data, (Sequence, iterator_ops.Iterator, iterator_ops.IteratorBase)))
188
189
190def make_logs(model, logs, outputs, mode, prefix=''):
191  """Computes logs for sending to `on_batch_end` methods."""
192  metric_names = model.metrics_names
193  if mode in {ModeKeys.TRAIN, ModeKeys.TEST} and metric_names:
194    for label, output in zip(metric_names, outputs):
195      logs[prefix + label] = output
196  else:
197    logs['outputs'] = outputs
198  return logs
199
200
201@keras_export('keras.callbacks.CallbackList')
202class CallbackList:
203  """Container abstracting a list of callbacks."""
204
205  def __init__(self,
206               callbacks=None,
207               add_history=False,
208               add_progbar=False,
209               model=None,
210               **params):
211    """Container for `Callback` instances.
212
213    This object wraps a list of `Callback` instances, making it possible
214    to call them all at once via a single endpoint
215    (e.g. `callback_list.on_epoch_end(...)`).
216
217    Args:
218      callbacks: List of `Callback` instances.
219      add_history: Whether a `History` callback should be added, if one does not
220        already exist in the `callbacks` list.
221      add_progbar: Whether a `ProgbarLogger` callback should be added, if one
222        does not already exist in the `callbacks` list.
223      model: The `Model` these callbacks are used with.
224      **params: If provided, parameters will be passed to each `Callback` via
225        `Callback.set_params`.
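
    Example (an illustrative sketch, not from the original docs; assumes a
    compiled `tf.keras.Model` bound to `model`):

    ```python
    callbacks = tf.keras.callbacks.CallbackList(
        [tf.keras.callbacks.TerminateOnNaN()],
        add_history=True,
        add_progbar=True,
        model=model,
        verbose=1,
        epochs=1,
        steps=10)

    callbacks.on_train_begin()
    callbacks.on_epoch_begin(0)
    for step in range(10):
      callbacks.on_train_batch_begin(step)
      logs = {'loss': 0.1}  # Placeholder for the real `train_step` results.
      callbacks.on_train_batch_end(step, logs)
    callbacks.on_epoch_end(0, {'loss': 0.1})
    callbacks.on_train_end()
    ```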
226    """
227    self.callbacks = nest.flatten(callbacks) if callbacks else []
228    self._add_default_callbacks(add_history, add_progbar)
229
230    if model:
231      self.set_model(model)
232    if params:
233      self.set_params(params)
234
235    # Performance optimization: determines if batch hooks need to be called.
236    # pylint: disable=protected-access
237    self._supports_tf_logs = all(
238        getattr(cb, '_supports_tf_logs', False) for cb in self.callbacks)
239    self._batch_hooks_support_tf_logs = all(
240        getattr(cb, '_supports_tf_logs', False)
241        for cb in self.callbacks
242        if cb._implements_train_batch_hooks() or cb
243        ._implements_test_batch_hooks() or cb._implements_predict_batch_hooks())
244
245    self._should_call_train_batch_hooks = any(
246        cb._implements_train_batch_hooks() for cb in self.callbacks)
247    self._should_call_test_batch_hooks = any(
248        cb._implements_test_batch_hooks() for cb in self.callbacks)
249    self._should_call_predict_batch_hooks = any(
250        cb._implements_predict_batch_hooks() for cb in self.callbacks)
251    # pylint: enable=protected-access
252
253    self._disallow_batch_hooks_in_ps_strategy()
254
255    # Performance check: Check batch hooks for slowness compared to batch time.
256    # Only run the check for custom callbacks (i.e. not present in this file).
257    self._check_timing = any(
258        cbk.__class__.__name__ not in globals() for cbk in self.callbacks)
259    self._num_batches_for_timing_check = 5
260    self._hook_times = {}
261    self._batch_start_time = None
262    self._batch_times = []
263
264  def _add_default_callbacks(self, add_history, add_progbar):
265    """Adds `Callback`s that are always present."""
266    self._progbar = None
267    self._history = None
268
269    for cb in self.callbacks:
270      if isinstance(cb, ProgbarLogger):
271        self._progbar = cb
272      elif isinstance(cb, History):
273        self._history = cb
274
275    if self._progbar is None and add_progbar:
276      self._progbar = ProgbarLogger(count_mode='steps')
277      self.callbacks.insert(0, self._progbar)
278
279    if self._history is None and add_history:
280      self._history = History()
281      self.callbacks.append(self._history)
282
283  def _process_logs(self, logs, is_batch_hook=False):
284    """Turns tensors into numpy arrays or Python scalars if necessary."""
285    if logs is None:
286      return {}
287    if self._supports_tf_logs:
288      return logs
289    if is_batch_hook and self._batch_hooks_support_tf_logs:
290      return logs
291    return tf_utils.sync_to_numpy_or_python_type(logs)
292
293  def append(self, callback):
294    self.callbacks.append(callback)
295
296  def set_params(self, params):
297    self.params = params
298    for callback in self.callbacks:
299      callback.set_params(params)
300
301  def set_model(self, model):
302    self.model = model
303    if self._history:
304      model.history = self._history
305    for callback in self.callbacks:
306      callback.set_model(model)
307
308  def _call_batch_hook(self, mode, hook, batch, logs=None):
309    """Helper function for all batch_{begin | end} methods."""
310    if not self.callbacks:
311      return
312
313    if hook == 'begin':
314      self._call_batch_begin_hook(mode, batch, logs)
315    elif hook == 'end':
316      self._call_batch_end_hook(mode, batch, logs)
317    else:
318      raise ValueError('Unrecognized hook: {}'.format(hook))
319
320  def _call_batch_begin_hook(self, mode, batch, logs):
321    """Helper function for `on_*_batch_begin` methods."""
322    hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
323    self._call_batch_hook_helper(hook_name, batch, logs)
324
325    if self._check_timing:
326      self._batch_start_time = time.time()
327
328  def _call_batch_end_hook(self, mode, batch, logs):
329    """Helper function for `on_*_batch_end` methods."""
330    hook_name = 'on_{mode}_batch_end'.format(mode=mode)
331
332    if self._check_timing and batch >= 1:
333      batch_time = time.time() - self._batch_start_time
334      self._batch_times.append(batch_time)
335
336    self._call_batch_hook_helper(hook_name, batch, logs)
337
338    if len(self._batch_times) >= self._num_batches_for_timing_check:
339      end_hook_name = hook_name
340      begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
341      avg_batch_time = sum(self._batch_times) / len(self._batch_times)
342      avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len(
343          self._hook_times[end_hook_name])
344      avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
345          self._hook_times[begin_hook_name])
346
347      threshold_time = 1.0 * avg_batch_time
348      warning_msg = ('Callback method `{hook}` is slow compared to '
349                     'the batch time (batch time: {batch_time:.4f}s vs '
350                     '`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
351      if avg_begin_hook_time > threshold_time:
352        logging.warning(warning_msg.format(
353            hook=begin_hook_name,
354            batch_time=avg_batch_time,
355            hook_time=avg_begin_hook_time))
356      if avg_end_hook_time > threshold_time:
357        logging.warning(warning_msg.format(
358            hook=end_hook_name,
359            batch_time=avg_batch_time,
360            hook_time=avg_end_hook_time))
361      self._check_timing = False
362      self._batch_start_time = None
363      self._batch_times = []
364      self._hook_times = {}
365
366  def _call_batch_hook_helper(self, hook_name, batch, logs):
367    """Helper function for `on_*_batch_*` methods."""
368    if self._check_timing:
369      start_time = time.time()
370
371    logs = self._process_logs(logs, is_batch_hook=True)
372    for callback in self.callbacks:
373      hook = getattr(callback, hook_name)
374      hook(batch, logs)
375
376    if self._check_timing:
377      if hook_name not in self._hook_times:
378        self._hook_times[hook_name] = []
379      self._hook_times[hook_name].append(time.time() - start_time)
380
381  def _call_begin_hook(self, mode):
382    """Helper function for on_{train|test|predict}_begin methods."""
383    if mode == ModeKeys.TRAIN:
384      self.on_train_begin()
385    elif mode == ModeKeys.TEST:
386      self.on_test_begin()
387    else:
388      self.on_predict_begin()
389
390  def _call_end_hook(self, mode):
391    """Helper function for on_{train|test|predict}_end methods."""
392    if mode == ModeKeys.TRAIN:
393      self.on_train_end()
394    elif mode == ModeKeys.TEST:
395      self.on_test_end()
396    else:
397      self.on_predict_end()
398
399  def on_batch_begin(self, batch, logs=None):
400    if self._should_call_train_batch_hooks:
401      self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
402
403  def on_batch_end(self, batch, logs=None):
404    if self._should_call_train_batch_hooks:
405      self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
406
407  def on_epoch_begin(self, epoch, logs=None):
408    """Calls the `on_epoch_begin` methods of its callbacks.
409
410    This function should only be called during TRAIN mode.
411
412    Args:
413        epoch: Integer, index of epoch.
414        logs: Dict. Currently no data is passed to this argument for this method
415          but that may change in the future.
416    """
417    logs = self._process_logs(logs)
418    for callback in self.callbacks:
419      callback.on_epoch_begin(epoch, logs)
420
421  def on_epoch_end(self, epoch, logs=None):
422    """Calls the `on_epoch_end` methods of its callbacks.
423
424    This function should only be called during TRAIN mode.
425
426    Args:
427        epoch: Integer, index of epoch.
428        logs: Dict, metric results for this training epoch, and for the
429          validation epoch if validation is performed. Validation result keys
430          are prefixed with `val_`.
431    """
432    logs = self._process_logs(logs)
433    for callback in self.callbacks:
434      callback.on_epoch_end(epoch, logs)
435
436  def on_train_batch_begin(self, batch, logs=None):
437    """Calls the `on_train_batch_begin` methods of its callbacks.
438
439    Args:
440        batch: Integer, index of batch within the current epoch.
441        logs: Dict, contains the return value of `model.train_step`. Typically,
442          the values of the `Model`'s metrics are returned.  Example:
443          `{'loss': 0.2, 'accuracy': 0.7}`.
444    """
445    if self._should_call_train_batch_hooks:
446      self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
447
448  def on_train_batch_end(self, batch, logs=None):
449    """Calls the `on_train_batch_end` methods of its callbacks.
450
451    Args:
452        batch: Integer, index of batch within the current epoch.
453        logs: Dict. Aggregated metric results up until this batch.
454    """
455    if self._should_call_train_batch_hooks:
456      self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
457
458  def on_test_batch_begin(self, batch, logs=None):
459    """Calls the `on_test_batch_begin` methods of its callbacks.
460
461    Args:
462        batch: Integer, index of batch within the current epoch.
463        logs: Dict, contains the return value of `model.test_step`. Typically,
464          the values of the `Model`'s metrics are returned.  Example:
465          `{'loss': 0.2, 'accuracy': 0.7}`.
466    """
467    if self._should_call_test_batch_hooks:
468      self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)
469
470  def on_test_batch_end(self, batch, logs=None):
471    """Calls the `on_test_batch_end` methods of its callbacks.
472
473    Args:
474        batch: Integer, index of batch within the current epoch.
475        logs: Dict. Aggregated metric results up until this batch.
476    """
477    if self._should_call_test_batch_hooks:
478      self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
479
480  def on_predict_batch_begin(self, batch, logs=None):
481    """Calls the `on_predict_batch_begin` methods of its callbacks.
482
483    Args:
484        batch: Integer, index of batch within the current epoch.
485        logs: Dict, contains the return value of `model.predict_step`;
486          it typically returns a dict with a key 'outputs' containing
487          the model's outputs.
488    """
489    if self._should_call_predict_batch_hooks:
490      self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)
491
492  def on_predict_batch_end(self, batch, logs=None):
493    """Calls the `on_predict_batch_end` methods of its callbacks.
494
495    Args:
496        batch: Integer, index of batch within the current epoch.
497        logs: Dict. Aggregated metric results up until this batch.
498    """
499    if self._should_call_predict_batch_hooks:
500      self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
501
502  def on_train_begin(self, logs=None):
503    """Calls the `on_train_begin` methods of its callbacks.
504
505    Args:
506        logs: Dict. Currently no data is passed to this argument for this method
507          but that may change in the future.
508    """
509    logs = self._process_logs(logs)
510    for callback in self.callbacks:
511      callback.on_train_begin(logs)
512
513  def on_train_end(self, logs=None):
514    """Calls the `on_train_end` methods of its callbacks.
515
516    Args:
517        logs: Dict. Currently no data is passed to this argument for this method
518          but that may change in the future.
519    """
520    logs = self._process_logs(logs)
521    for callback in self.callbacks:
522      callback.on_train_end(logs)
523
524  def on_test_begin(self, logs=None):
525    """Calls the `on_test_begin` methods of its callbacks.
526
527    Args:
528        logs: Dict. Currently no data is passed to this argument for this method
529          but that may change in the future.
530    """
531    logs = self._process_logs(logs)
532    for callback in self.callbacks:
533      callback.on_test_begin(logs)
534
535  def on_test_end(self, logs=None):
536    """Calls the `on_test_end` methods of its callbacks.
537
538    Args:
539        logs: Dict. Currently no data is passed to this argument for this method
540          but that may change in the future.
541    """
542    logs = self._process_logs(logs)
543    for callback in self.callbacks:
544      callback.on_test_end(logs)
545
546  def on_predict_begin(self, logs=None):
547    """Calls the 'on_predict_begin` methods of its callbacks.
548
549    Args:
550        logs: Dict. Currently no data is passed to this argument for this method
551          but that may change in the future.
552    """
553    logs = self._process_logs(logs)
554    for callback in self.callbacks:
555      callback.on_predict_begin(logs)
556
557  def on_predict_end(self, logs=None):
558    """Calls the `on_predict_end` methods of its callbacks.
559
560    Args:
561        logs: Dict. Currently no data is passed to this argument for this method
562          but that may change in the future.
563    """
564    logs = self._process_logs(logs)
565    for callback in self.callbacks:
566      callback.on_predict_end(logs)
567
568  def __iter__(self):
569    return iter(self.callbacks)
570
571  def _disallow_batch_hooks_in_ps_strategy(self):
572    """Error out if batch-level callbacks are passed with PSStrategy."""
573    # pylint: disable=protected-access
574    strategy = ds_context.get_strategy()
575    if strategy._should_use_with_coordinator:
576      unsupported_callbacks = []
577      for cb in self.callbacks:
578        # These Callbacks can accept RemoteValues directly.
579        if getattr(cb, '_supports_tf_logs', False):
580          continue
581        if (cb._implements_train_batch_hooks() or
582            cb._implements_test_batch_hooks() or
583            cb._implements_predict_batch_hooks()):
584          unsupported_callbacks.append(cb)
585      if unsupported_callbacks:
586        raise ValueError('Batch-level `Callback`s are not supported with '
587                         '`ParameterServerStrategy`. Found unsupported '
588                         'callbacks: {}'.format(unsupported_callbacks))
589    # pylint: enable=protected-access
590
591
592@keras_export('keras.callbacks.Callback')
593class Callback:
594  """Abstract base class used to build new callbacks.
595
596  Callbacks can be passed to keras methods such as `fit`, `evaluate`, and
597  `predict` in order to hook into the various stages of the model training and
598  inference lifecycle.
599
600  To create a custom callback, subclass `keras.callbacks.Callback` and override
601  the method associated with the stage of interest. See
602  https://www.tensorflow.org/guide/keras/custom_callback for more information.
603
604  Example:
605
606  >>> training_finished = False
607  >>> class MyCallback(tf.keras.callbacks.Callback):
608  ...   def on_train_end(self, logs=None):
609  ...     global training_finished
610  ...     training_finished = True
611  >>> model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
612  >>> model.compile(loss='mean_squared_error')
613  >>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
614  ...           callbacks=[MyCallback()])
615  >>> assert training_finished == True
616
617  If you want to use `Callback` objects in a custom training loop:
618
619  1. You should pack all your callbacks into a single `callbacks.CallbackList`
620     so they can all be called together.
621  2. You will need to manually call all the `on_*` methods at the appropriate
622     locations in your loop. Like this:
623
624     ```
625     callbacks = tf.keras.callbacks.CallbackList([...])
626     callbacks.append(...)
627
628     callbacks.on_train_begin(...)
629     for epoch in range(EPOCHS):
630       callbacks.on_epoch_begin(epoch)
631       for i, data in dataset.enumerate():
632         callbacks.on_train_batch_begin(i)
633         batch_logs = model.train_step(data)
634         callbacks.on_train_batch_end(i, batch_logs)
635       epoch_logs = ...
636       callbacks.on_epoch_end(epoch, epoch_logs)
637     final_logs = ...
638     callbacks.on_train_end(final_logs)
639     ```
640
641  Attributes:
642      params: Dict. Training parameters
643          (e.g. verbosity, batch size, number of epochs...).
644      model: Instance of `keras.models.Model`.
645          Reference of the model being trained.
646
647  The `logs` dictionary that callback methods
648  take as argument will contain keys for quantities relevant to
649  the current batch or epoch (see method-specific docstrings).
650  """
651
652  def __init__(self):
653    self.validation_data = None  # pylint: disable=g-missing-from-attributes
654    self.model = None
655    # Whether this Callback should only run on the chief worker in a
656    # Multi-Worker setting.
657    # TODO(omalleyt): Make this attr public once solution is stable.
658    self._chief_worker_only = None
659    self._supports_tf_logs = False
660
661  def set_params(self, params):
662    self.params = params
663
664  def set_model(self, model):
665    self.model = model
666
667  @doc_controls.for_subclass_implementers
668  @generic_utils.default
669  def on_batch_begin(self, batch, logs=None):
670    """A backwards compatibility alias for `on_train_batch_begin`."""
671
672  @doc_controls.for_subclass_implementers
673  @generic_utils.default
674  def on_batch_end(self, batch, logs=None):
675    """A backwards compatibility alias for `on_train_batch_end`."""
676
677  @doc_controls.for_subclass_implementers
678  def on_epoch_begin(self, epoch, logs=None):
679    """Called at the start of an epoch.
680
681    Subclasses should override for any actions to run. This function should only
682    be called during TRAIN mode.
683
684    Args:
685        epoch: Integer, index of epoch.
686        logs: Dict. Currently no data is passed to this argument for this method
687          but that may change in the future.
688    """
689
690  @doc_controls.for_subclass_implementers
691  def on_epoch_end(self, epoch, logs=None):
692    """Called at the end of an epoch.
693
694    Subclasses should override for any actions to run. This function should only
695    be called during TRAIN mode.
696
697    Args:
698        epoch: Integer, index of epoch.
699        logs: Dict, metric results for this training epoch, and for the
700          validation epoch if validation is performed. Validation result keys
701          are prefixed with `val_`. For the training epoch, the values of the
702          `Model`'s metrics are returned. Example: `{'loss': 0.2,
703          'accuracy': 0.7}`.
704    """
705
706  @doc_controls.for_subclass_implementers
707  @generic_utils.default
708  def on_train_batch_begin(self, batch, logs=None):
709    """Called at the beginning of a training batch in `fit` methods.
710
711    Subclasses should override for any actions to run.
712
713    Note that if the `steps_per_execution` argument to `compile` in
714    `tf.keras.Model` is set to `N`, this method will only be called every `N`
715    batches.
716
717    Args:
718        batch: Integer, index of batch within the current epoch.
719        logs: Dict, contains the return value of `model.train_step`. Typically,
720          the values of the `Model`'s metrics are returned.  Example:
721          `{'loss': 0.2, 'accuracy': 0.7}`.
722    """
723    # For backwards compatibility.
724    self.on_batch_begin(batch, logs=logs)
725
726  @doc_controls.for_subclass_implementers
727  @generic_utils.default
728  def on_train_batch_end(self, batch, logs=None):
729    """Called at the end of a training batch in `fit` methods.
730
731    Subclasses should override for any actions to run.
732
733    Note that if the `steps_per_execution` argument to `compile` in
734    `tf.keras.Model` is set to `N`, this method will only be called every `N`
735    batches.
736
737    Args:
738        batch: Integer, index of batch within the current epoch.
739        logs: Dict. Aggregated metric results up until this batch.
740    """
741    # For backwards compatibility.
742    self.on_batch_end(batch, logs=logs)
743
744  @doc_controls.for_subclass_implementers
745  @generic_utils.default
746  def on_test_batch_begin(self, batch, logs=None):
747    """Called at the beginning of a batch in `evaluate` methods.
748
749    Also called at the beginning of a validation batch in the `fit`
750    methods, if validation data is provided.
751
752    Subclasses should override for any actions to run.
753
754    Note that if the `steps_per_execution` argument to `compile` in
755    `tf.keras.Model` is set to `N`, this method will only be called every `N`
756    batches.
757
758    Args:
759        batch: Integer, index of batch within the current epoch.
760        logs: Dict, contains the return value of `model.test_step`. Typically,
761          the values of the `Model`'s metrics are returned.  Example:
762          `{'loss': 0.2, 'accuracy': 0.7}`.
763    """
764
765  @doc_controls.for_subclass_implementers
766  @generic_utils.default
767  def on_test_batch_end(self, batch, logs=None):
768    """Called at the end of a batch in `evaluate` methods.
769
770    Also called at the end of a validation batch in the `fit`
771    methods, if validation data is provided.
772
773    Subclasses should override for any actions to run.
774
775    Note that if the `steps_per_execution` argument to `compile` in
776    `tf.keras.Model` is set to `N`, this method will only be called every `N`
777    batches.
778
779    Args:
780        batch: Integer, index of batch within the current epoch.
781        logs: Dict. Aggregated metric results up until this batch.
782    """
783
784  @doc_controls.for_subclass_implementers
785  @generic_utils.default
786  def on_predict_batch_begin(self, batch, logs=None):
787    """Called at the beginning of a batch in `predict` methods.
788
789    Subclasses should override for any actions to run.
790
791    Note that if the `steps_per_execution` argument to `compile` in
792    `tf.keras.Model` is set to `N`, this method will only be called every `N`
793    batches.
794
795    Args:
796        batch: Integer, index of batch within the current epoch.
797        logs: Dict, contains the return value of `model.predict_step`;
798          it typically returns a dict with a key 'outputs' containing
799          the model's outputs.
800    """
801
802  @doc_controls.for_subclass_implementers
803  @generic_utils.default
804  def on_predict_batch_end(self, batch, logs=None):
805    """Called at the end of a batch in `predict` methods.
806
807    Subclasses should override for any actions to run.
808
809    Note that if the `steps_per_execution` argument to `compile` in
810    `tf.keras.Model` is set to `N`, this method will only be called every `N`
811    batches.
812
813    Args:
814        batch: Integer, index of batch within the current epoch.
815        logs: Dict. Aggregated metric results up until this batch.
816    """
817
818  @doc_controls.for_subclass_implementers
819  def on_train_begin(self, logs=None):
820    """Called at the beginning of training.
821
822    Subclasses should override for any actions to run.
823
824    Args:
825        logs: Dict. Currently no data is passed to this argument for this method
826          but that may change in the future.
827    """
828
829  @doc_controls.for_subclass_implementers
830  def on_train_end(self, logs=None):
831    """Called at the end of training.
832
833    Subclasses should override for any actions to run.
834
835    Args:
836        logs: Dict. Currently the output of the last call to `on_epoch_end()`
837          is passed to this argument for this method but that may change in
838          the future.
839    """
840
841  @doc_controls.for_subclass_implementers
842  def on_test_begin(self, logs=None):
843    """Called at the beginning of evaluation or validation.
844
845    Subclasses should override for any actions to run.
846
847    Args:
848        logs: Dict. Currently no data is passed to this argument for this method
849          but that may change in the future.
850    """
851
852  @doc_controls.for_subclass_implementers
853  def on_test_end(self, logs=None):
854    """Called at the end of evaluation or validation.
855
856    Subclasses should override for any actions to run.
857
858    Args:
859        logs: Dict. Currently the output of the last call to
860          `on_test_batch_end()` is passed to this argument for this method
861          but that may change in the future.
862    """
863
864  @doc_controls.for_subclass_implementers
865  def on_predict_begin(self, logs=None):
866    """Called at the beginning of prediction.
867
868    Subclasses should override for any actions to run.
869
870    Args:
871        logs: Dict. Currently no data is passed to this argument for this method
872          but that may change in the future.
873    """
874
875  @doc_controls.for_subclass_implementers
876  def on_predict_end(self, logs=None):
877    """Called at the end of prediction.
878
879    Subclasses should override for any actions to run.
880
881    Args:
882        logs: Dict. Currently no data is passed to this argument for this method
883          but that may change in the future.
884    """
885
886  def _implements_train_batch_hooks(self):
887    """Determines if this Callback should be called for each train batch."""
888    return (not generic_utils.is_default(self.on_batch_begin) or
889            not generic_utils.is_default(self.on_batch_end) or
890            not generic_utils.is_default(self.on_train_batch_begin) or
891            not generic_utils.is_default(self.on_train_batch_end))
892
893  def _implements_test_batch_hooks(self):
894    """Determines if this Callback should be called for each test batch."""
895    return (not generic_utils.is_default(self.on_test_batch_begin) or
896            not generic_utils.is_default(self.on_test_batch_end))
897
898  def _implements_predict_batch_hooks(self):
899    """Determines if this Callback should be called for each predict batch."""
900    return (not generic_utils.is_default(self.on_predict_batch_begin) or
901            not generic_utils.is_default(self.on_predict_batch_end))
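  # Note (illustrative, not part of the original code): a subclass that only
  # overrides `on_epoch_end` leaves all three checks above returning False, so
  # `CallbackList` can skip its per-batch dispatch for that callback entirely;
  # overriding e.g. `on_train_batch_end` flips
  # `_implements_train_batch_hooks()` to True.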
902
903
904@keras_export('keras.callbacks.BaseLogger')
905class BaseLogger(Callback):
906  """Callback that accumulates epoch averages of metrics.
907
908  This callback is automatically applied to every Keras model.
909
910  Args:
911      stateful_metrics: Iterable of string names of metrics that
912          should *not* be averaged over an epoch.
913          Metrics in this list will be logged as-is in `on_epoch_end`.
914          All others will be averaged in `on_epoch_end`.
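
  For illustration: with two batches of sizes 2 and 3 reporting `loss` values
  1.0 and 2.0, the value logged at `on_epoch_end` is the sample-weighted
  average (1.0 * 2 + 2.0 * 3) / (2 + 3) = 1.6, whereas a metric listed in
  `stateful_metrics` would be reported as its most recent batch value.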
915  """
916
917  def __init__(self, stateful_metrics=None):
918    super(BaseLogger, self).__init__()
919    self.stateful_metrics = set(stateful_metrics or [])
920
921  def on_epoch_begin(self, epoch, logs=None):
922    self.seen = 0
923    self.totals = {}
924
925  def on_batch_end(self, batch, logs=None):
926    logs = logs or {}
927    batch_size = logs.get('size', 0)
928    # With a distribution strategy we can potentially run multiple steps at the
929    # same time, so we account for that in the `seen` calculation.
930    num_steps = logs.get('num_steps', 1)
931    self.seen += batch_size * num_steps
932
933    for k, v in logs.items():
934      if k in self.stateful_metrics:
935        self.totals[k] = v
936      else:
937        if k in self.totals:
938          self.totals[k] += v * batch_size
939        else:
940          self.totals[k] = v * batch_size
941
942  def on_epoch_end(self, epoch, logs=None):
943    if logs is not None:
944      for k in self.params['metrics']:
945        if k in self.totals:
946          # Make value available to next callbacks.
947          if k in self.stateful_metrics:
948            logs[k] = self.totals[k]
949          else:
950            logs[k] = self.totals[k] / self.seen
951
952
953@keras_export('keras.callbacks.TerminateOnNaN')
954class TerminateOnNaN(Callback):
955  """Callback that terminates training when a NaN loss is encountered.
956  """
957
958  def __init__(self):
959    super(TerminateOnNaN, self).__init__()
960    self._supports_tf_logs = True
961
962  def on_batch_end(self, batch, logs=None):
963    logs = logs or {}
964    loss = logs.get('loss')
965    if loss is not None:
966      loss = tf_utils.sync_to_numpy_or_python_type(loss)
967      if np.isnan(loss) or np.isinf(loss):
968        print('Batch %d: Invalid loss, terminating training' % (batch))
969        self.model.stop_training = True
970
971
972@keras_export('keras.callbacks.ProgbarLogger')
973class ProgbarLogger(Callback):
974  """Callback that prints metrics to stdout.
975
976  Args:
977      count_mode: One of `"steps"` or `"samples"`.
978          Whether the progress bar should
979          count samples seen or steps (batches) seen.
980      stateful_metrics: Iterable of string names of metrics that
981          should *not* be averaged over an epoch.
982          Metrics in this list will be logged as-is.
983          All others will be averaged over time (e.g. loss, etc).
984          If not provided, defaults to the `Model`'s metrics.
985
986  Raises:
987      ValueError: In case of invalid `count_mode`.
988  """
989
990  def __init__(self, count_mode='samples', stateful_metrics=None):
991    super(ProgbarLogger, self).__init__()
992    self._supports_tf_logs = True
993    if count_mode == 'samples':
994      self.use_steps = False
995    elif count_mode == 'steps':
996      self.use_steps = True
997    else:
998      raise ValueError('Unknown `count_mode`: ' + str(count_mode))
999    # Defaults to all Model's metrics except for loss.
1000    self.stateful_metrics = set(stateful_metrics) if stateful_metrics else set()
1001
1002    self.seen = 0
1003    self.progbar = None
1004    self.target = None
1005    self.verbose = 1
1006    self.epochs = 1
1007
1008    self._train_step, self._test_step, self._predict_step = None, None, None
1009    self._call_batch_hooks = True
1010
1011    self._called_in_fit = False
1012
1013  def set_params(self, params):
1014    self.verbose = params['verbose']
1015    self.epochs = params['epochs']
1016    if self.use_steps and 'steps' in params:
1017      self.target = params['steps']
1018    elif not self.use_steps and 'samples' in params:
1019      self.target = params['samples']
1020    else:
1021      self.target = None  # Will be inferred at the end of the first epoch.
1022
1023    self._call_batch_hooks = self.verbose == 1
1024    if self.target is None:
1025      try:
1026        self._train_step = self.model._train_counter  # pylint: disable=protected-access
1027        self._test_step = self.model._test_counter  # pylint: disable=protected-access
1028        self._predict_step = self.model._predict_counter  # pylint: disable=protected-access
1029      except AttributeError:
1030        self._call_batch_hooks = True
1031
1032  def on_train_begin(self, logs=None):
1033    # When this logger is called inside `fit`, validation is silent.
1034    self._called_in_fit = True
1035
1036  def on_test_begin(self, logs=None):
1037    if not self._called_in_fit:
1038      self._reset_progbar()
1039      self._maybe_init_progbar()
1040
1041  def on_predict_begin(self, logs=None):
1042    self._reset_progbar()
1043    self._maybe_init_progbar()
1044
1045  def on_epoch_begin(self, epoch, logs=None):
1046    self._reset_progbar()
1047    self._maybe_init_progbar()
1048    if self.verbose and self.epochs > 1:
1049      print('Epoch %d/%d' % (epoch + 1, self.epochs))
1050
1051  def on_train_batch_end(self, batch, logs=None):
1052    self._batch_update_progbar(batch, logs)
1053
1054  def on_test_batch_end(self, batch, logs=None):
1055    if not self._called_in_fit:
1056      self._batch_update_progbar(batch, logs)
1057
1058  def on_predict_batch_end(self, batch, logs=None):
1059    # Don't pass prediction results.
1060    self._batch_update_progbar(batch, None)
1061
1062  def on_epoch_end(self, epoch, logs=None):
1063    self._finalize_progbar(logs, self._train_step)
1064
1065  def on_test_end(self, logs=None):
1066    if not self._called_in_fit:
1067      self._finalize_progbar(logs, self._test_step)
1068
1069  def on_predict_end(self, logs=None):
1070    self._finalize_progbar(logs, self._predict_step)
1071
1072  def _reset_progbar(self):
1073    self.seen = 0
1074    self.progbar = None
1075
1076  def _maybe_init_progbar(self):
1077    """Instantiate a `Progbar` if not yet, and update the stateful metrics."""
1078    # TODO(rchao): Legacy TF1 code path may use list for
1079    # `self.stateful_metrics`. Remove "cast to set" when TF1 support is dropped.
1080    self.stateful_metrics = set(self.stateful_metrics)
1081
1082    if self.model:
1083      # Update the existing stateful metrics as `self.model.metrics` may contain
1084      # updated metrics after `MetricsContainer` is built in the first train
1085      # step.
1086      self.stateful_metrics = self.stateful_metrics.union(
1087          set(m.name for m in self.model.metrics))
1088
1089    if self.progbar is None:
1090      self.progbar = Progbar(
1091          target=self.target,
1092          verbose=self.verbose,
1093          stateful_metrics=self.stateful_metrics,
1094          unit_name='step' if self.use_steps else 'sample')
1095
1096    self.progbar._update_stateful_metrics(self.stateful_metrics)  # pylint: disable=protected-access
1097
1098  def _implements_train_batch_hooks(self):
1099    return self._call_batch_hooks
1100
1101  def _implements_test_batch_hooks(self):
1102    return self._call_batch_hooks
1103
1104  def _implements_predict_batch_hooks(self):
1105    return self._call_batch_hooks
1106
1107  def _batch_update_progbar(self, batch, logs=None):
1108    """Updates the progbar."""
1109    logs = logs or {}
1110    self._maybe_init_progbar()
1111    if self.use_steps:
1112      self.seen = batch + 1  # One-indexed.
1113    else:
1114      # v1 path only.
1115      logs = copy.copy(logs)
1116      batch_size = logs.pop('size', 0)
1117      num_steps = logs.pop('num_steps', 1)
1118      logs.pop('batch', None)
1119      add_seen = num_steps * batch_size
1120      self.seen += add_seen
1121
1122    if self.verbose == 1:
1123      # Only block async when verbose = 1.
1124      logs = tf_utils.sync_to_numpy_or_python_type(logs)
1125      self.progbar.update(self.seen, list(logs.items()), finalize=False)
1126
1127  def _finalize_progbar(self, logs, counter):
1128    logs = tf_utils.sync_to_numpy_or_python_type(logs or {})
1129    if self.target is None:
1130      if counter is not None:
1131        counter = counter.numpy()
1132        if not self.use_steps:
1133          counter *= logs.get('size', 1)
1134      self.target = counter or self.seen
1135      self.progbar.target = self.target
1136    self.progbar.update(self.target, list(logs.items()), finalize=True)
1137
1138
1139@keras_export('keras.callbacks.History')
1140class History(Callback):
1141  """Callback that records events into a `History` object.
1142
1143  This callback is automatically applied to
1144  every Keras model. The `History` object
1145  gets returned by the `fit` method of models.
1146
1147  Example:
1148
1149  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1150  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1151  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
1152  ...                     epochs=10, verbose=1)
1153  >>> print(history.params)
1154  {'verbose': 1, 'epochs': 10, 'steps': 1}
1155  >>> # check the keys of history object
1156  >>> print(history.history.keys())
1157  dict_keys(['loss'])
1158
1159  """
1160
1161  def __init__(self):
1162    super(History, self).__init__()
1163    self.history = {}
1164
1165  def on_train_begin(self, logs=None):
1166    self.epoch = []
1167
1168  def on_epoch_end(self, epoch, logs=None):
1169    logs = logs or {}
1170    self.epoch.append(epoch)
1171    for k, v in logs.items():
1172      self.history.setdefault(k, []).append(v)
1173
1174    # Set the history attribute on the model after the epoch ends. This will
1175    # make sure that the state which is set is the latest one.
1176    self.model.history = self
1177
1178
1179@keras_export('keras.callbacks.ModelCheckpoint')
1180class ModelCheckpoint(Callback):
1181  """Callback to save the Keras model or model weights at some frequency.
1182
1183  `ModelCheckpoint` callback is used in conjunction with training using
1184  `model.fit()` to save a model or weights (in a checkpoint file) at some
1185  interval, so the model or weights can be loaded later to continue the training
1186  from the state saved.
1187
1188  A few options this callback provides include:
1189
1190  - Whether to only keep the model that has achieved the "best performance" so
1191    far, or whether to save the model at the end of every epoch regardless of
1192    performance.
1193  - Definition of 'best'; which quantity to monitor and whether it should be
1194    maximized or minimized.
1195  - The frequency it should save at. Currently, the callback supports saving at
1196    the end of every epoch, or after a fixed number of training batches.
1197  - Whether only weights are saved, or the whole model is saved.
1198
1199  Note: If you get `WARNING:tensorflow:Can save best model only with <name>
1200  available, skipping`, see the description of the `monitor` argument for
1201  details on how to get this right.
1202
1203  Example:
1204
1205  ```python
1206  model.compile(loss=..., optimizer=...,
1207                metrics=['accuracy'])
1208
1209  EPOCHS = 10
1210  checkpoint_filepath = '/tmp/checkpoint'
1211  model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
1212      filepath=checkpoint_filepath,
1213      save_weights_only=True,
1214      monitor='val_accuracy',
1215      mode='max',
1216      save_best_only=True)
1217
1218  # Model weights are saved at the end of every epoch, if it's the best seen
1219  # so far.
1220  model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
1221
1222  # The model weights (that are considered the best) are loaded into the model.
1223  model.load_weights(checkpoint_filepath)
1224  ```
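
  A further sketch (assumed usage, not part of the original example) of saving
  every 500 training batches instead of once per epoch:

  ```python
  batch_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath='/tmp/batch_ckpt/weights.{epoch:02d}',
      save_weights_only=True,
      save_freq=500)  # Save every 500 batches.

  model.fit(epochs=EPOCHS, callbacks=[batch_checkpoint_callback])
  ```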
1225
1226  Args:
1227      filepath: string or `PathLike`, path to save the model file. e.g.
1228        filepath = os.path.join(working_dir, 'ckpt', file_name). `filepath`
1229        can contain named formatting options, which will be filled with the value of
1230        `epoch` and keys in `logs` (passed in `on_epoch_end`). For example: if
1231        `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model
1232        checkpoints will be saved with the epoch number and the validation loss
1233        in the filename. The directory of the filepath should not be reused by
1234        any other callbacks to avoid conflicts.
1235      monitor: The metric name to monitor. Typically the metrics are set by the
1236        `Model.compile` method. Note:
1237
1238        * Prefix the name with `"val_"` to monitor validation metrics.
1239        * Use `"loss"` or `"val_loss"` to monitor the model's total loss.
1240        * If you specify metrics as strings, like `"accuracy"`, pass the same
1241          string (with or without the `"val_"` prefix).
1242        * If you pass `metrics.Metric` objects, `monitor` should be set to
1243          `metric.name`.
1244        * If you're not sure about the metric names, you can check the contents
1245          of the `history.history` dictionary returned by
1246          `history = model.fit()`.
1247        * Multi-output models set additional prefixes on the metric names.
1248
1249      verbose: verbosity mode, 0 or 1.
1250      save_best_only: if `save_best_only=True`, it only saves when the model
1251        is considered the "best" and the latest best model according to the
1252        quantity monitored will not be overwritten. If `filepath` doesn't
1253        contain formatting options like `{epoch}` then `filepath` will be
1254        overwritten by each new better model.
1255      mode: one of {'auto', 'min', 'max'}. If `save_best_only=True`, the
1256        decision to overwrite the current save file is made based on either
1257        the maximization or the minimization of the monitored quantity.
1258        For `val_acc`, this should be `max`, for `val_loss` this should be
1259        `min`, etc. In `auto` mode, the mode is set to `max` if the monitored
1260        quantity contains 'acc' or starts with 'fmeasure', and to `min`
1261        otherwise.
1262      save_weights_only: if True, then only the model's weights will be saved
1263        (`model.save_weights(filepath)`), else the full model is saved
1264        (`model.save(filepath)`).
1265      save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves
1266        the model after each epoch. When using an integer, the callback saves
1267        the model at the end of this many batches. If the `Model` is compiled with
1268        `steps_per_execution=N`, then the saving criteria will be
1269        checked every Nth batch. Note that if the saving isn't aligned to
1270        epochs, the monitored metric may potentially be less reliable (it
1271        could reflect as little as 1 batch, since the metrics get reset every
1272        epoch). Defaults to `'epoch'`.
1273      options: Optional `tf.train.CheckpointOptions` object if
1274        `save_weights_only` is true or optional `tf.saved_model.SaveOptions`
1275        object if `save_weights_only` is false.
1276      **kwargs: Additional arguments for backwards compatibility. Possible key
1277        is `period`.
1278  """
1279
1280  def __init__(self,
1281               filepath,
1282               monitor='val_loss',
1283               verbose=0,
1284               save_best_only=False,
1285               save_weights_only=False,
1286               mode='auto',
1287               save_freq='epoch',
1288               options=None,
1289               **kwargs):
1290    super(ModelCheckpoint, self).__init__()
1291    self._supports_tf_logs = True
1292    self.monitor = monitor
1293    self.verbose = verbose
1294    self.filepath = path_to_string(filepath)
1295    self.save_best_only = save_best_only
1296    self.save_weights_only = save_weights_only
1297    self.save_freq = save_freq
1298    self.epochs_since_last_save = 0
1299    self._batches_seen_since_last_saving = 0
1300    self._last_batch_seen = 0
1301
1302    if save_weights_only:
1303      if options is None or isinstance(
1304          options, checkpoint_options_lib.CheckpointOptions):
1305        self._options = options or checkpoint_options_lib.CheckpointOptions()
1306      else:
1307        raise TypeError('If save_weights_only is True, then `options` must be '
1308                        'either None or a tf.train.CheckpointOptions')
1309    else:
1310      if options is None or isinstance(options, save_options_lib.SaveOptions):
1311        self._options = options or save_options_lib.SaveOptions()
1312      else:
1313        raise TypeError('If save_weights_only is False, then `options` must be '
1314                        'either None or a tf.saved_model.SaveOptions')
1315
1316    # Deprecated field `load_weights_on_restart` is for loading the checkpoint
1317    # file from `filepath` at the start of `model.fit()`.
1318    # TODO(rchao): Remove the arg during next breaking release.
1319    if 'load_weights_on_restart' in kwargs:
1320      self.load_weights_on_restart = kwargs['load_weights_on_restart']
1321      logging.warning('`load_weights_on_restart` argument is deprecated. '
1322                      'Please use `model.load_weights()` for loading weights '
1323                      'before the start of `model.fit()`.')
1324    else:
1325      self.load_weights_on_restart = False
1326
1327    # Deprecated field `period` is for the number of epochs between which
1328    # the model is saved.
1329    if 'period' in kwargs:
1330      self.period = kwargs['period']
1331      logging.warning('`period` argument is deprecated. Please use `save_freq` '
1332                      'to specify the frequency in number of batches seen.')
1333    else:
1334      self.period = 1
1335
1336    if mode not in ['auto', 'min', 'max']:
1337      logging.warning('ModelCheckpoint mode %s is unknown, '
1338                      'fallback to auto mode.', mode)
1339      mode = 'auto'
1340
1341    if mode == 'min':
1342      self.monitor_op = np.less
1343      self.best = np.Inf
1344    elif mode == 'max':
1345      self.monitor_op = np.greater
1346      self.best = -np.Inf
1347    else:
1348      if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
1349        self.monitor_op = np.greater
1350        self.best = -np.Inf
1351      else:
1352        self.monitor_op = np.less
1353        self.best = np.Inf
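    # Illustrative examples (not part of the original code): with mode='auto',
    # monitor='val_accuracy' resolves to np.greater (best starts at -inf)
    # because the name contains 'acc', while monitor='val_loss' resolves to
    # np.less (best starts at +inf).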
1354
1355    if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
1356      raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq))
1357
1358    # Only the chief worker writes model checkpoints, but all workers
1359    # restore checkpoint at on_train_begin().
1360    self._chief_worker_only = False
1361
1362  def on_train_begin(self, logs=None):
1363    if self.load_weights_on_restart:
1364      filepath_to_load = (
1365          self._get_most_recently_modified_file_matching_pattern(self.filepath))
1366      if (filepath_to_load is not None and
1367          self._checkpoint_exists(filepath_to_load)):
1368        try:
1369          # `filepath` may contain placeholders such as `{epoch:02d}`, and
1370          # thus it attempts to load the most recently modified file with file
1371          # name matching the pattern.
1372          self.model.load_weights(filepath_to_load)
1373        except (IOError, ValueError) as e:
1374          raise ValueError('Error loading file from {}. Reason: {}'.format(
1375              filepath_to_load, e))
1376
1377  def _implements_train_batch_hooks(self):
1378    # Only call batch hooks when saving on batch
1379    return self.save_freq != 'epoch'
1380
1381  def on_train_batch_end(self, batch, logs=None):
1382    if self._should_save_on_batch(batch):
1383      self._save_model(epoch=self._current_epoch, logs=logs)
1384
1385  def on_epoch_begin(self, epoch, logs=None):
1386    self._current_epoch = epoch
1387
1388  def on_epoch_end(self, epoch, logs=None):
1389    self.epochs_since_last_save += 1
1390    # pylint: disable=protected-access
1391    if self.save_freq == 'epoch':
1392      self._save_model(epoch=epoch, logs=logs)
1393
1394  def _should_save_on_batch(self, batch):
1395    """Handles batch-level saving logic, supports steps_per_execution."""
1396    if self.save_freq == 'epoch':
1397      return False
1398
1399    if batch <= self._last_batch_seen:  # New epoch.
1400      add_batches = batch + 1  # batches are zero-indexed.
1401    else:
1402      add_batches = batch - self._last_batch_seen
1403    self._batches_seen_since_last_saving += add_batches
1404    self._last_batch_seen = batch
1405
1406    if self._batches_seen_since_last_saving >= self.save_freq:
1407      self._batches_seen_since_last_saving = 0
1408      return True
1409    return False
1410
1411  def _save_model(self, epoch, logs):
1412    """Saves the model.
1413
1414    Args:
1415        epoch: the epoch this iteration is in.
1416        logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
1417    """
1418    logs = logs or {}
1419
1420    if isinstance(self.save_freq,
1421                  int) or self.epochs_since_last_save >= self.period:
1422      # Block only when saving interval is reached.
1423      logs = tf_utils.sync_to_numpy_or_python_type(logs)
1424      self.epochs_since_last_save = 0
1425      filepath = self._get_file_path(epoch, logs)
1426
1427      try:
1428        if self.save_best_only:
1429          current = logs.get(self.monitor)
1430          if current is None:
1431            logging.warning('Can save best model only with %s available, '
1432                            'skipping.', self.monitor)
1433          else:
1434            if self.monitor_op(current, self.best):
1435              if self.verbose > 0:
1436                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
1437                      ' saving model to %s' % (epoch + 1, self.monitor,
1438                                               self.best, current, filepath))
1439              self.best = current
1440              if self.save_weights_only:
1441                self.model.save_weights(
1442                    filepath, overwrite=True, options=self._options)
1443              else:
1444                self.model.save(filepath, overwrite=True, options=self._options)
1445            else:
1446              if self.verbose > 0:
1447                print('\nEpoch %05d: %s did not improve from %0.5f' %
1448                      (epoch + 1, self.monitor, self.best))
1449        else:
1450          if self.verbose > 0:
1451            print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
1452          if self.save_weights_only:
1453            self.model.save_weights(
1454                filepath, overwrite=True, options=self._options)
1455          else:
1456            self.model.save(filepath, overwrite=True, options=self._options)
1457
1458        self._maybe_remove_file()
1459      except IsADirectoryError as e:  # h5py 3.x
1460        raise IOError('Please specify a non-directory filepath for '
1461                      'ModelCheckpoint. Filepath used is an existing '
1462                      'directory: {}'.format(filepath))
1463      except IOError as e:  # h5py 2.x
1464        # `e.errno` appears to be `None`, so check the content of `e.args[0]`.
1465        if 'is a directory' in str(e.args[0]).lower():
1466          raise IOError('Please specify a non-directory filepath for '
1467                        'ModelCheckpoint. Filepath used is an existing '
1468                        'directory: {}'.format(filepath))
1469        # Re-throw the error for any other causes.
1470        raise e
1471
1472  def _get_file_path(self, epoch, logs):
1473    """Returns the file path for checkpoint."""
1474    # pylint: disable=protected-access
1475    try:
1476      # `filepath` may contain placeholders such as `{epoch:02d}` and
1477      # `{mape:.2f}`. A mismatch between logged metrics and the path's
1478      # placeholders can cause formatting to fail.
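      # For example, 'weights.{epoch:02d}-{val_loss:.2f}.hdf5' with
      # `epoch + 1 == 2` and `logs == {'val_loss': 0.35}` formats to
      # 'weights.02-0.35.hdf5'.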
1479      file_path = self.filepath.format(epoch=epoch + 1, **logs)
1480    except KeyError as e:
1481      raise KeyError('Failed to format this callback filepath: "{}". '
1482                     'Reason: {}'.format(self.filepath, e))
1483    self._write_filepath = distributed_file_utils.write_filepath(
1484        file_path, self.model.distribute_strategy)
1485    return self._write_filepath
1486
1487  def _maybe_remove_file(self):
1488    # Remove the checkpoint directory in multi-worker training where this worker
1489    # should not save checkpoints. It is a dummy directory previously written
1490    # only to keep the workers in sync during distributed training.
1491    distributed_file_utils.remove_temp_dir_with_filepath(
1492        self._write_filepath, self.model.distribute_strategy)
1493
1494  def _checkpoint_exists(self, filepath):
1495    """Returns whether the checkpoint that `filepath` refers to exists."""
1496    if filepath.endswith('.h5'):
1497      return file_io.file_exists_v2(filepath)
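    # For the TF format, the checkpoint is either a SavedModel directory at
    # `filepath` or a weights-only checkpoint with a `filepath + '.index'` file.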
1498    tf_saved_model_exists = file_io.file_exists_v2(filepath)
1499    tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
1500        filepath + '.index')
1501    return tf_saved_model_exists or tf_weights_only_checkpoint_exists
1502
1503  def _get_most_recently_modified_file_matching_pattern(self, pattern):
1504    """Returns the most recently modified filepath matching pattern.
1505
1506    Pattern may contain python formatting placeholders. If
1507    `tf.train.latest_checkpoint()` does not return None, use that; otherwise,
1508    check for the most recently modified file that matches the pattern.
1509
1510    In the rare case where more than one pattern-matching file shares the most
1511    recent modified time, return the filepath that is largest (by the `>`
1512    operator, i.e. lexicographically, using the numeric equivalents). This
1513    provides a tie-breaker when multiple files are most recent. Note that a
1514    larger `filepath` can sometimes indicate a later modification time (for
1515    instance, when epoch/batch is used as a formatting option), but not
1516    necessarily (when accuracy or loss is used). The tie-breaker is a
1517    best-effort attempt to return the most recent file and to avoid a
1518    nondeterministic result.
1519
1520    Modified time of a file is obtained with `os.path.getmtime()`.
1521
1522    This utility function is best demonstrated via an example:
1523
1524    ```python
1525    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
1526    test_dir = self.get_temp_dir()
1527    path_pattern = os.path.join(test_dir, file_pattern)
1528    file_paths = [
1529        os.path.join(test_dir, file_name) for file_name in
1530        ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
1531    ]
1532    for file_path in file_paths:
1533      # Write something to each of the files
1534    self.assertEqual(
1535        _get_most_recently_modified_file_matching_pattern(path_pattern),
1536        file_paths[-1])
1537    ```
1538
1539    Args:
1540        pattern: The file pattern that may optionally contain python placeholder
1541            such as `{epoch:02d}`.
1542
1543    Returns:
1544        The most recently modified file's full filepath matching `pattern`. If
1545        `pattern` does not contain any placeholder, this returns the filepath
1546        that exactly matches `pattern`. Returns `None` if no match is
1547        found.
1548    """
1549    dir_name = os.path.dirname(pattern)
1550    base_name = os.path.basename(pattern)
1551    base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$'
1552
1553    # If tf.train.latest_checkpoint tells us there exists a latest checkpoint,
1554    # use that as it is more robust than `os.path.getmtime()`.
1555    latest_tf_checkpoint = checkpoint_management.latest_checkpoint(dir_name)
1556    if latest_tf_checkpoint is not None and re.match(
1557        base_name_regex, os.path.basename(latest_tf_checkpoint)):
1558      return latest_tf_checkpoint
1559
1560    latest_mod_time = 0
1561    file_path_with_latest_mod_time = None
1562    n_file_with_latest_mod_time = 0
1563    file_path_with_largest_file_name = None
1564
1565    if file_io.file_exists_v2(dir_name):
1566      for file_name in os.listdir(dir_name):
1567        # Only consider if `file_name` matches the pattern.
1568        if re.match(base_name_regex, file_name):
1569          file_path = os.path.join(dir_name, file_name)
1570          mod_time = os.path.getmtime(file_path)
1571          if (file_path_with_largest_file_name is None or
1572              file_path > file_path_with_largest_file_name):
1573            file_path_with_largest_file_name = file_path
1574          if mod_time > latest_mod_time:
1575            latest_mod_time = mod_time
1576            file_path_with_latest_mod_time = file_path
1577            # In the case a file with later modified time is found, reset
1578            # the counter for the number of files with latest modified time.
1579            n_file_with_latest_mod_time = 1
1580          elif mod_time == latest_mod_time:
1581            # In the case a file has modified time tied with the most recent,
1582            # increment the counter for the number of files with latest modified
1583            # time by 1.
1584            n_file_with_latest_mod_time += 1
1585
1586    if n_file_with_latest_mod_time == 1:
1587      # Return the sole file that has most recent modified time.
1588      return file_path_with_latest_mod_time
1589    else:
1590      # If more than one file has the latest modified time, return
1591      # the file path with the largest file name.
1592      return file_path_with_largest_file_name
1593
1594
1595@keras_export('keras.callbacks.experimental.BackupAndRestore', v1=[])
1596class BackupAndRestore(Callback):
1597  """Callback to back up and restore the training state.
1598
1599  The `BackupAndRestore` callback is intended to recover training from an
1600  interruption that happened in the middle of a `model.fit` execution, by
1601  backing up the training state in a temporary checkpoint file (based on TF
1602  CheckpointManager) at the end of each epoch. If training is restarted before
1603  completion, the training state and model are restored to the most recently
1604  saved state at the beginning of a new `model.fit()` run.
1605  Note that the user is responsible for bringing jobs back up.
1606  This callback provides the backup and restore mechanism for fault
1607  tolerance. The model restored from a previous checkpoint is expected to be
1608  the same as the one used to back it up. If the user changes the arguments
1609  passed to `compile` or `fit`, the checkpoint saved for fault tolerance can
1610  become invalid.
1611
1612  Note:
1613  1. This callback is not compatible with disabling eager execution.
1614  2. A checkpoint is saved at the end of each epoch. When restoring, any partial
1615  work from the unfinished epoch in which training was interrupted is redone
1616  (so the work done before an interruption doesn't affect the final model state).
1617  3. This works for both single-worker and multi-worker modes; only
1618  MirroredStrategy and MultiWorkerMirroredStrategy are supported for now.
1619
1620  Example:
1621
1622  >>> class InterruptingCallback(tf.keras.callbacks.Callback):
1623  ...   def on_epoch_begin(self, epoch, logs=None):
1624  ...     if epoch == 4:
1625  ...       raise RuntimeError('Interrupting!')
1626  >>> callback = tf.keras.callbacks.experimental.BackupAndRestore(
1627  ... backup_dir="/tmp/backup")
1628  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1629  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1630  >>> try:
1631  ...   model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
1632  ...             batch_size=1, callbacks=[callback, InterruptingCallback()],
1633  ...             verbose=0)
1634  ... except:
1635  ...   pass
1636  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
1637  ...             batch_size=1, callbacks=[callback], verbose=0)
1638  >>> # Only 6 more epochs are run, since the first run was interrupted at
1639  >>> # zero-indexed epoch 4; the second run continues from epoch 4 to 9.
1640  >>> len(history.history['loss'])
1641  6
1642
1643  Args:
1644      backup_dir: String, path to store the checkpoint.
1645        e.g. backup_dir = os.path.join(working_dir, 'backup')
1646        This is the directory in which the system stores temporary files to
1647        recover the model from jobs terminated unexpectedly. The directory
1648        cannot be reused elsewhere to store other files, e.g. by the
1649        BackupAndRestore callback of another training run, or by another
1650        callback (such as ModelCheckpoint) of the same training run.
1651  """
1652
1653  def __init__(self, backup_dir):
1654    super(BackupAndRestore, self).__init__()
1655    self.backup_dir = backup_dir
1656    self._supports_tf_logs = True
1657    self._supported_strategies = (
1658        mirrored_strategy.MirroredStrategy,
1659        collective_all_reduce_strategy.CollectiveAllReduceStrategy,
1660        tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2,
1661        parameter_server_strategy_v2.ParameterServerStrategyV2)
1662
1663    if not context.executing_eagerly():
1664      if ops.inside_function():
1665        raise ValueError('This Callback\'s method contains Python state and '
1666                         'should be called outside of `tf.function`s.')
1667      else:  # Legacy graph mode:
1668        raise ValueError(
1669            'BackupAndRestore only supports eager mode. In graph '
1670            'mode, consider using ModelCheckpoint to manually save '
1671            'and restore weights with `model.load_weights()` and by '
1672            'providing `initial_epoch` in `model.fit()` for fault tolerance.')
1673
1674    # Only the chief worker writes model checkpoints, but all workers
1675    # restore the checkpoint at on_train_begin().
1676    self._chief_worker_only = False
1677
1678  def on_train_begin(self, logs=None):
1679    # TrainingState is used to manage the training state needed for
1680    # failure-recovery of a worker in training.
1681    # pylint: disable=protected-access
1682
1683    if self.model._distribution_strategy and not isinstance(
1684        self.model.distribute_strategy, self._supported_strategies):
1685      raise NotImplementedError(
1686          '%s is not supported yet. '
1687          'Currently BackupAndRestore callback only supports empty strategy, '
1688          'MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy.' %
1689          type(self.model.distribute_strategy).__name__)
1690    self.model._training_state = (
1691        worker_training_state.WorkerTrainingState(self.model, self.backup_dir))
1692    self._training_state = self.model._training_state
1693    self._training_state.restore()
1694
1695  def on_train_end(self, logs=None):
1696    # pylint: disable=protected-access
1697    # On exit of training, delete the training state backup file that was saved
1698    # for the purpose of worker recovery.
1699    self._training_state.delete_backup()
1700
1701    # Clean up the training state.
1702    del self._training_state
1703    del self.model._training_state
1704
1705  def on_epoch_end(self, epoch, logs=None):
1706    # Back up the model and current epoch for possible future recovery.
1707    self._training_state.back_up(epoch)
1708
1709
1710@keras_export('keras.callbacks.EarlyStopping')
1711class EarlyStopping(Callback):
1712  """Stop training when a monitored metric has stopped improving.
1713
1714  Assume the goal of training is to minimize the loss. With this, the
1715  metric to be monitored would be `'loss'`, and the mode would be `'min'`. A
1716  `model.fit()` training loop checks at the end of every epoch whether
1717  the loss is still decreasing, taking `min_delta` and `patience` into
1718  account if applicable. Once the loss is found to no longer be decreasing,
1719  `model.stop_training` is set to True and training terminates.
1720
1721  The quantity to be monitored needs to be available in the `logs` dict.
1722  To make it so, pass the loss or metrics at `model.compile()`.
1723
1724  Args:
1725    monitor: Quantity to be monitored.
1726    min_delta: Minimum change in the monitored quantity
1727        to qualify as an improvement, i.e. an absolute
1728        change of less than `min_delta` will count as no
1729        improvement.
1730    patience: Number of epochs with no improvement
1731        after which training will be stopped.
1732    verbose: verbosity mode.
1733    mode: One of `{"auto", "min", "max"}`. In `min` mode,
1734        training will stop when the quantity
1735        monitored has stopped decreasing; in `"max"`
1736        mode it will stop when the quantity
1737        monitored has stopped increasing; in `"auto"`
1738        mode, the direction is automatically inferred
1739        from the name of the monitored quantity.
1740    baseline: Baseline value for the monitored quantity.
1741        Training will stop if the model doesn't show improvement over the
1742        baseline.
1743    restore_best_weights: Whether to restore model weights from
1744        the epoch with the best value of the monitored quantity.
1745        If False, the model weights obtained at the last step of
1746        training are used. An epoch will be restored regardless
1747        of the performance relative to the `baseline`. If no epoch
1748        improves on `baseline`, training will run for `patience`
1749        epochs and restore weights from the best epoch in that set.
1750
1751  Example:
1752
1753  >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
1754  >>> # This callback will stop the training when there is no improvement in
1755  >>> # the loss for three consecutive epochs.
1756  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1757  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1758  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
1759  ...                     epochs=10, batch_size=1, callbacks=[callback],
1760  ...                     verbose=0)
1761  >>> len(history.history['loss'])  # Only 4 epochs are run.
1762  4
1763  """
1764
1765  def __init__(self,
1766               monitor='val_loss',
1767               min_delta=0,
1768               patience=0,
1769               verbose=0,
1770               mode='auto',
1771               baseline=None,
1772               restore_best_weights=False):
1773    super(EarlyStopping, self).__init__()
1774
1775    self.monitor = monitor
1776    self.patience = patience
1777    self.verbose = verbose
1778    self.baseline = baseline
1779    self.min_delta = abs(min_delta)
1780    self.wait = 0
1781    self.stopped_epoch = 0
1782    self.restore_best_weights = restore_best_weights
1783    self.best_weights = None
1784
1785    if mode not in ['auto', 'min', 'max']:
1786      logging.warning('EarlyStopping mode %s is unknown, '
1787                      'fallback to auto mode.', mode)
1788      mode = 'auto'
1789
1790    if mode == 'min':
1791      self.monitor_op = np.less
1792    elif mode == 'max':
1793      self.monitor_op = np.greater
1794    else:
1795      if 'acc' in self.monitor:
1796        self.monitor_op = np.greater
1797      else:
1798        self.monitor_op = np.less
1799
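    # Align the sign of `min_delta` with the improvement direction so that
    # `_is_improvement` can apply it uniformly in both modes.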
1800    if self.monitor_op == np.greater:
1801      self.min_delta *= 1
1802    else:
1803      self.min_delta *= -1
1804
1805  def on_train_begin(self, logs=None):
1806    # Allow instances to be re-used
1807    self.wait = 0
1808    self.stopped_epoch = 0
1809    self.best = np.Inf if self.monitor_op == np.less else -np.Inf
1810    self.best_weights = None
1811
1812  def on_epoch_end(self, epoch, logs=None):
1813    current = self.get_monitor_value(logs)
1814    if current is None:
1815      return
1816    if self.restore_best_weights and self.best_weights is None:
1817      # Store the first-epoch weights in case no progress is ever made.
1818      self.best_weights = self.model.get_weights()
1819
1820    self.wait += 1
1821    if self._is_improvement(current, self.best):
1822      self.best = current
1823      if self.restore_best_weights:
1824        self.best_weights = self.model.get_weights()
1825      # Only restart wait if we beat both the baseline and our previous best.
1826      if self.baseline is None or self._is_improvement(current, self.baseline):
1827        self.wait = 0
1828
1829    if self.wait >= self.patience:
1830      self.stopped_epoch = epoch
1831      self.model.stop_training = True
1832      if self.restore_best_weights and self.best_weights is not None:
1833        if self.verbose > 0:
1834          print('Restoring model weights from the end of the best epoch.')
1835        self.model.set_weights(self.best_weights)
1836
1837  def on_train_end(self, logs=None):
1838    if self.stopped_epoch > 0 and self.verbose > 0:
1839      print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
1840
1841  def get_monitor_value(self, logs):
1842    logs = logs or {}
1843    monitor_value = logs.get(self.monitor)
1844    if monitor_value is None:
1845      logging.warning('Early stopping conditioned on metric `%s` '
1846                      'which is not available. Available metrics are: %s',
1847                      self.monitor, ','.join(list(logs.keys())))
1848    return monitor_value
1849
1850  def _is_improvement(self, monitor_value, reference_value):
1851    return self.monitor_op(monitor_value - self.min_delta, reference_value)
1852
1853
1854@keras_export('keras.callbacks.RemoteMonitor')
1855class RemoteMonitor(Callback):
1856  """Callback used to stream events to a server.
1857
1858  Requires the `requests` library.
1859  Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
1860  HTTP POST, with a `data` argument which is a
1861  JSON-encoded dictionary of event data.
1862  If `send_as_json=True`, the content type of the request will be
1863  `"application/json"`.
1864  Otherwise the serialized JSON will be sent within a form.
1865
1866  Args:
1867    root: String; root url of the target server.
1868    path: String; path relative to `root` to which the events will be sent.
1869    field: String; JSON field under which the data will be stored.
1870        The field is used only if the payload is sent within a form
1871        (i.e. send_as_json is set to False).
1872    headers: Dictionary; optional custom HTTP headers.
1873    send_as_json: Boolean; whether the request should be
1874        sent as `"application/json"`.
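
  Example (an illustrative sketch; the URL below is a placeholder for a server
  you run yourself that accepts these POST requests, and `model`, `x_train`,
  `y_train` are assumed to already exist):

  ```python
  monitor = RemoteMonitor(root='http://localhost:9000',
                          path='/publish/epoch/end/',
                          field='data',
                          send_as_json=False)
  model.fit(x_train, y_train, epochs=2, callbacks=[monitor])
  ```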
1875  """
1876
1877  def __init__(self,
1878               root='http://localhost:9000',
1879               path='/publish/epoch/end/',
1880               field='data',
1881               headers=None,
1882               send_as_json=False):
1883    super(RemoteMonitor, self).__init__()
1884
1885    self.root = root
1886    self.path = path
1887    self.field = field
1888    self.headers = headers
1889    self.send_as_json = send_as_json
1890
1891  def on_epoch_end(self, epoch, logs=None):
1892    if requests is None:
1893      raise ImportError('RemoteMonitor requires the `requests` library.')
1894    logs = logs or {}
1895    send = {}
1896    send['epoch'] = epoch
1897    for k, v in logs.items():
1898      # np.ndarray and np.generic are not scalar types
1899      # therefore we must unwrap their scalar values and
1900      # pass to the json-serializable dict 'send'
1901      if isinstance(v, (np.ndarray, np.generic)):
1902        send[k] = v.item()
1903      else:
1904        send[k] = v
1905    try:
1906      if self.send_as_json:
1907        requests.post(self.root + self.path, json=send, headers=self.headers)
1908      else:
1909        requests.post(
1910            self.root + self.path, {self.field: json.dumps(send)},
1911            headers=self.headers)
1912    except requests.exceptions.RequestException:
1913      logging.warning('Warning: could not reach RemoteMonitor '
1914                      'root server at ' + str(self.root))
1915
1916
1917@keras_export('keras.callbacks.LearningRateScheduler')
1918class LearningRateScheduler(Callback):
1919  """Learning rate scheduler.
1920
1921  At the beginning of every epoch, this callback gets the updated learning rate
1922  value from the `schedule` function provided at `__init__`, using the current
1923  epoch and the current learning rate, and applies the updated learning rate
1924  to the optimizer.
1925
1926  Args:
1927    schedule: a function that takes an epoch index (integer, indexed from 0)
1928        and current learning rate (float) as inputs and returns a new
1929        learning rate as output (float).
1930    verbose: int. 0: quiet, 1: update messages.
1931
1932  Example:
1933
1934  >>> # This function keeps the initial learning rate for the first ten epochs
1935  >>> # and decreases it exponentially after that.
1936  >>> def scheduler(epoch, lr):
1937  ...   if epoch < 10:
1938  ...     return lr
1939  ...   else:
1940  ...     return lr * tf.math.exp(-0.1)
1941  >>>
1942  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1943  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1944  >>> round(model.optimizer.lr.numpy(), 5)
1945  0.01
1946
1947  >>> callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
1948  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
1949  ...                     epochs=15, callbacks=[callback], verbose=0)
1950  >>> round(model.optimizer.lr.numpy(), 5)
1951  0.00607
1952
1953  """
1954
1955  def __init__(self, schedule, verbose=0):
1956    super(LearningRateScheduler, self).__init__()
1957    self.schedule = schedule
1958    self.verbose = verbose
1959
1960  def on_epoch_begin(self, epoch, logs=None):
1961    if not hasattr(self.model.optimizer, 'lr'):
1962      raise ValueError('Optimizer must have a "lr" attribute.')
1963    try:  # new API
1964      lr = float(backend.get_value(self.model.optimizer.lr))
1965      lr = self.schedule(epoch, lr)
1966    except TypeError:  # Support for old API for backward compatibility
1967      lr = self.schedule(epoch)
1968    if not isinstance(lr, (ops.Tensor, float, np.float32, np.float64)):
1969      raise ValueError('The output of the "schedule" function '
1970                       'should be a float.')
1971    if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:
1972      raise ValueError('The dtype of the Tensor should be float.')
1973    backend.set_value(self.model.optimizer.lr, backend.get_value(lr))
1974    if self.verbose > 0:
1975      print('\nEpoch %05d: LearningRateScheduler setting learning '
1976            'rate to %s.' % (epoch + 1, lr))
1977
1978  def on_epoch_end(self, epoch, logs=None):
1979    logs = logs or {}
1980    logs['lr'] = backend.get_value(self.model.optimizer.lr)
1981
1982
1983def keras_model_summary(name, data, step=None):
1984  """Writes a Keras model as JSON to a Summary.
1985
1986  Writing the Keras model configuration allows the TensorBoard graph plugin to
1987  render a conceptual graph, as opposed to a graph of ops. If the model fails
1988  to serialize as JSON, the failure is ignored and False is returned.
1989
1990  Args:
1991    name: A name for this summary. The summary tag used for TensorBoard will be
1992      this name prefixed by any active name scopes.
1993    data: A Keras Model to write.
1994    step: Explicit `int64`-castable monotonic step value for this summary. If
1995      omitted, this defaults to `tf.summary.experimental.get_step()`, which must
1996      not be None.
1997
1998  Returns:
1999    True on success, or False if no summary was written because no default
2000    summary writer was available.
2001
2002  Raises:
2003    ValueError: if a default writer exists, but no step was provided and
2004      `tf.summary.experimental.get_step()` is None.
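
  For illustration, a minimal usage sketch (assuming `model` is a Keras model
  and `'/tmp/logdir'` is a writable log directory):

  ```python
  writer = tf.summary.create_file_writer('/tmp/logdir')
  with writer.as_default():
    keras_model_summary('keras', model, step=0)
  ```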
2005  """
2006  summary_metadata = summary_pb2.SummaryMetadata()
2007  # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for
2008  # the rationale.
2009  summary_metadata.plugin_data.plugin_name = 'graph_keras_model'
2010  # version number = 1
2011  summary_metadata.plugin_data.content = b'1'
2012
2013  try:
2014    json_string = data.to_json()
2015  except Exception as exc:  # pylint: disable=broad-except
2016    # An exception should not break the user's model code.
2017    logging.warning('Model failed to serialize as JSON. Ignoring... %s', exc)
2018    return False
2019
2020  with summary_ops_v2.summary_scope(name, 'graph_keras_model',
2021                                    [data, step]) as (tag, _):
2022    with ops.device('cpu:0'):
2023      tensor = constant_op.constant(json_string, dtype=dtypes.string)
2024    return summary_ops_v2.write(
2025        tag=tag, tensor=tensor, step=step, metadata=summary_metadata)
2026
2027
2028@keras_export('keras.callbacks.TensorBoard', v1=[])
2029class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
2030  # pylint: disable=line-too-long
2031  """Enable visualizations for TensorBoard.
2032
2033  TensorBoard is a visualization tool provided with TensorFlow.
2034
2035  This callback logs events for TensorBoard, including:
2036
2037  * Metrics summary plots
2038  * Training graph visualization
2039  * Activation histograms
2040  * Sampled profiling
2041
2042  When used in `Model.evaluate`, in addition to epoch summaries, a summary is
2043  written that records evaluation metrics vs `Model.optimizer.iterations`.
2044  The metric names are prefixed with `evaluation`, with
2045  `Model.optimizer.iterations` being the step in the visualized TensorBoard.
2046
2047  If you have installed TensorFlow with pip, you should be able
2048  to launch TensorBoard from the command line:
2049
2050  ```
2051  tensorboard --logdir=path_to_your_logs
2052  ```
2053
2054  You can find more information about TensorBoard
2055  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
2056
2057  Args:
2058      log_dir: the path of the directory where to save the log files to be
2059        parsed by TensorBoard. e.g. log_dir = os.path.join(working_dir, 'logs')
2060        This directory should not be reused by any other callbacks.
2061      histogram_freq: frequency (in epochs) at which to compute activation and
2062        weight histograms for the layers of the model. If set to 0, histograms
2063        won't be computed. Validation data (or split) must be specified for
2064        histogram visualizations.
2065      write_graph: whether to visualize the graph in TensorBoard. The log file
2066        can become quite large when write_graph is set to True.
2067      write_images: whether to write model weights to visualize as image in
2068        TensorBoard.
2069      write_steps_per_second: whether to log the training steps per second into
2070        TensorBoard. This supports both epoch and batch frequency logging.
2071      update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
2072        writes the losses and metrics to TensorBoard after each batch. The same
2073        applies for `'epoch'`. If using an integer, let's say `1000`, the
2074        callback will write the metrics and losses to TensorBoard every 1000
2075        batches. Note that writing too frequently to TensorBoard can slow down
2076        your training.
2077      profile_batch: Profile the batch(es) to sample compute characteristics.
2078        profile_batch must be a non-negative integer or a tuple of integers.
2079        A pair of positive integers signifies a range of batches to profile.
2080        By default, it will profile the second batch. Set profile_batch=0
2081        to disable profiling.
2082      embeddings_freq: frequency (in epochs) at which embedding layers will be
2083        visualized. If set to 0, embeddings won't be visualized.
2084      embeddings_metadata: Dictionary which maps embedding layer names to the
2085        filename of a file in which to save metadata for the embedding layer.
2086        In case the same metadata file is to be
2087        used for all embedding layers, a single filename can be passed.
2088
2089  Examples:
2090
2091  Basic usage:
2092
2093  ```python
2094  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
2095  model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
2096  # Then run the tensorboard command to view the visualizations.
2097  ```
2098
2099  Custom batch-level summaries in a subclassed Model:
2100
2101  ```python
2102  class MyModel(tf.keras.Model):
2103
2104    def build(self, _):
2105      self.dense = tf.keras.layers.Dense(10)
2106
2107    def call(self, x):
2108      outputs = self.dense(x)
2109      tf.summary.histogram('outputs', outputs)
2110      return outputs
2111
2112  model = MyModel()
2113  model.compile('sgd', 'mse')
2114
2115  # Make sure to set `update_freq=N` to log a batch-level summary every N batches.
2116  # In addition to any `tf.summary` contained in `Model.call`, metrics added in
2117  # `Model.compile` will be logged every N batches.
2118  tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
2119  model.fit(x_train, y_train, callbacks=[tb_callback])
2120  ```
2121
2122  Custom batch-level summaries in a Functional API Model:
2123
2124  ```python
2125  def my_summary(x):
2126    tf.summary.histogram('x', x)
2127    return x
2128
2129  inputs = tf.keras.Input(10)
2130  x = tf.keras.layers.Dense(10)(inputs)
2131  outputs = tf.keras.layers.Lambda(my_summary)(x)
2132  model = tf.keras.Model(inputs, outputs)
2133  model.compile('sgd', 'mse')
2134
2135  # Make sure to set `update_freq=N` to log a batch-level summary every N batches.
2136  # In addition to any `tf.summary` contained in `Model.call`, metrics added in
2137  # `Model.compile` will be logged every N batches.
2138  tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
2139  model.fit(x_train, y_train, callbacks=[tb_callback])
2140  ```
2141
2142  Profiling:
2143
2144  ```python
2145  # Profile a single batch, e.g. the 5th batch.
2146  tensorboard_callback = tf.keras.callbacks.TensorBoard(
2147      log_dir='./logs', profile_batch=5)
2148  model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
2149
2150  # Profile a range of batches, e.g. from 10 to 20.
2151  tensorboard_callback = tf.keras.callbacks.TensorBoard(
2152      log_dir='./logs', profile_batch=(10,20))
2153  model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
2154  ```
2155  """
2156
2157  # pylint: enable=line-too-long
2158
2159  def __init__(self,
2160               log_dir='logs',
2161               histogram_freq=0,
2162               write_graph=True,
2163               write_images=False,
2164               write_steps_per_second=False,
2165               update_freq='epoch',
2166               profile_batch=2,
2167               embeddings_freq=0,
2168               embeddings_metadata=None,
2169               **kwargs):
2170    super(TensorBoard, self).__init__()
2171    self._supports_tf_logs = True
2172    self._validate_kwargs(kwargs)
2173
2174    self.log_dir = path_to_string(log_dir)
2175    self.histogram_freq = histogram_freq
2176    self.write_graph = write_graph
2177    self.write_images = write_images
2178    self.write_steps_per_second = write_steps_per_second
2179    self.update_freq = 1 if update_freq == 'batch' else update_freq
2180    self.embeddings_freq = embeddings_freq
2181    self.embeddings_metadata = embeddings_metadata
2182    self._init_profile_batch(profile_batch)
2183    self._global_train_batch = 0
2184    self._previous_epoch_iterations = 0
2185    self._train_accumulated_time = 0
2186    self._batch_start_time = 0
2187
2188    # Lazily initialized in order to avoid creating event files when
2189    # not needed.
2190    self._writers = {}
2191
2192    # Used to restore any existing `SummaryWriter` after training ends.
2193    self._prev_summary_state = []
2194
2195  def _validate_kwargs(self, kwargs):
2196    """Handles arguments that were supported in V1."""
2197    if kwargs.get('write_grads', False):
2198      logging.warning('`write_grads` will be ignored in TensorFlow 2.0 '
2199                      'for the `TensorBoard` Callback.')
2200    if kwargs.get('batch_size', False):
2201      logging.warning('`batch_size` is no longer needed in the '
2202                      '`TensorBoard` Callback and will be ignored '
2203                      'in TensorFlow 2.0.')
2204    if kwargs.get('embeddings_layer_names', False):
2205      logging.warning('`embeddings_layer_names` is not supported in '
2206                      'TensorFlow 2.0. Instead, all `Embedding` layers '
2207                      'will be visualized.')
2208    if kwargs.get('embeddings_data', False):
2209      logging.warning('`embeddings_data` is not supported in TensorFlow '
2210                      '2.0. Instead, all `Embedding` variables will be '
2211                      'visualized.')
2212
2213    unrecognized_kwargs = set(kwargs.keys()) - {
2214        'write_grads', 'embeddings_layer_names', 'embeddings_data', 'batch_size'
2215    }
2216
2217    # Only allow kwargs that were supported in V1.
2218    if unrecognized_kwargs:
2219      raise ValueError('Unrecognized arguments in `TensorBoard` '
2220                       'Callback: ' + str(unrecognized_kwargs))
2221
2222  def set_model(self, model):
2223    """Sets Keras model and writes graph if specified."""
2224    self.model = model
2225    self._log_write_dir = self._get_log_write_dir()
2226
2227    self._train_dir = os.path.join(self._log_write_dir, 'train')
2228    self._train_step = self.model._train_counter  # pylint: disable=protected-access
2229
2230    self._val_dir = os.path.join(self._log_write_dir, 'validation')
2231    self._val_step = self.model._test_counter  # pylint: disable=protected-access
2232
2233    self._writers = {}  # Resets writers.
2234
2235    self._should_write_train_graph = False
2236    if self.write_graph:
2237      self._write_keras_model_summary()
2238      self._should_write_train_graph = True
2239    if self.embeddings_freq:
2240      self._configure_embeddings()
2241
2242  @property
2243  def _train_writer(self):
2244    if 'train' not in self._writers:
2245      self._writers['train'] = summary_ops_v2.create_file_writer_v2(
2246          self._train_dir)
2247    return self._writers['train']
2248
2249  @property
2250  def _val_writer(self):
2251    if 'val' not in self._writers:
2252      self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir)
2253    return self._writers['val']
2254
2255  def _get_log_write_dir(self):
2256    """For multi-worker, only chief should write, others write to '/tmp'."""
2257    return distributed_file_utils.write_dirpath(self.log_dir,
2258                                                self.model.distribute_strategy)
2259
2260  def _delete_tmp_write_dir(self):
2261    """Deletes tmp write directories for multi-worker."""
2262    distributed_file_utils.remove_temp_dirpath(self.log_dir,
2263                                               self.model.distribute_strategy)
2264
2265  def _write_keras_model_train_graph(self):
2266    """Writes Keras model train_function graph to TensorBoard."""
2267    with self._train_writer.as_default():
2268      with summary_ops_v2.record_if(True):
2269        train_fn = self.model.train_tf_function
2270        # If the train_function is a `tf.function`, we can write out a graph
2271        if hasattr(train_fn, 'function_spec'):
2272          summary_ops_v2.graph(train_fn._concrete_stateful_fn.graph)  # pylint: disable=protected-access
2273
2274  def _write_keras_model_summary(self):
2275    """Writes Keras graph network summary to TensorBoard."""
2276    with self._train_writer.as_default():
2277      with summary_ops_v2.record_if(True):
2278        summary_writable = (
2279            self.model._is_graph_network or  # pylint: disable=protected-access
2280            self.model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
2281        if summary_writable:
2282          keras_model_summary('keras', self.model, step=0)
2283
2284  def _configure_embeddings(self):
2285    """Configure the Projector for embeddings."""
2286    # TODO(omalleyt): Add integration tests.
2287    from google.protobuf import text_format
2288    from tensorflow.python.keras.layers import embeddings
2289    from tensorflow.python.keras.protobuf import projector_config_pb2
2290
2291    config = projector_config_pb2.ProjectorConfig()
2292    for layer in self.model.layers:
2293      if isinstance(layer, embeddings.Embedding):
2294        embedding = config.embeddings.add()
2295        # Embeddings are always the first layer, so this naming should be
2296        # consistent in any Keras model's checkpoints.
2297        name = 'layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE'
2298        embedding.tensor_name = name
2299
2300        if self.embeddings_metadata is not None:
2301          if isinstance(self.embeddings_metadata, str):
2302            embedding.metadata_path = self.embeddings_metadata
2303          else:
2304            if layer.name in self.embeddings_metadata.keys():
2305              embedding.metadata_path = self.embeddings_metadata.pop(layer.name)
2306
2307    if self.embeddings_metadata and not isinstance(self.embeddings_metadata,
2308                                                   str):
2309      raise ValueError('Unrecognized `Embedding` layer names passed to '
2310                       '`keras.callbacks.TensorBoard` `embeddings_metadata` '
2311                       'argument: ' + str(self.embeddings_metadata.keys()))
2312
2313    config_pbtxt = text_format.MessageToString(config)
2314    path = os.path.join(self._log_write_dir, 'projector_config.pbtxt')
2315    with gfile.Open(path, 'w') as f:
2316      f.write(config_pbtxt)
2317
2318  def _push_writer(self, writer, step):
2319    """Sets the default writer for custom batch-level summaries."""
2320    if self.update_freq == 'epoch':
2321      return
2322
2323    should_record = lambda: math_ops.equal(step % self.update_freq, 0)
2324    # TODO(b/151339474): Fix deadlock when not using .value() here.
2325    summary_context = (writer.as_default(step.value()),
2326                       summary_ops_v2.record_if(should_record))
2327    self._prev_summary_state.append(summary_context)
2328    summary_context[0].__enter__()
2329    summary_context[1].__enter__()
2330
2331  def _pop_writer(self):
2332    """Pops the current writer."""
2333    if self.update_freq == 'epoch':
2334      return
2335
2336    # See _push_writer for the content of previous_context, which is a pair
2337    # of contexts.
2338    previous_context = self._prev_summary_state.pop()
2339    previous_context[1].__exit__(*sys.exc_info())
2340    previous_context[0].__exit__(*sys.exc_info())
2341
2342  def _close_writers(self):
2343    for writer in self._writers.values():
2344      writer.close()
2345
2346  def _init_profile_batch(self, profile_batch):
2347    """Validates profile_batch and sets the range of batches to profile.
2348
2349    Sets the `_start_batch` and `_stop_batch` attributes, specifying the start
2350    and stop batches to profile. Setting `profile_batch=0` disables profiling.
2351
2352    Args:
2353      profile_batch: The range of batches to profile. Should be a non-negative
2354        integer or a comma-separated string of two positive integers. A pair
2355        of positive integers signifies a range of batches to profile.
2356
2357    Raises:
2358      ValueError: If profile_batch is not an integer or a comma separated pair
2359                  of positive integers.
2360
2361    """
2362    profile_batch_error_message = (
2363        'profile_batch must be a non-negative integer or 2-tuple of positive '
2364        'integers. A pair of positive integers signifies a range of batches '
2365        'to profile. Found: {}'.format(profile_batch))
2366
2367    # Support legacy way of specifying "start,stop" or "start" as str.
2368    if isinstance(profile_batch, str):
2369      profile_batch = str(profile_batch).split(',')
2370      profile_batch = nest.map_structure(int, profile_batch)
2371
2372    if isinstance(profile_batch, int):
2373      self._start_batch = profile_batch
2374      self._stop_batch = profile_batch
2375    elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2:
2376      self._start_batch, self._stop_batch = profile_batch
2377    else:
2378      raise ValueError(profile_batch_error_message)
2379
2380    if self._start_batch < 0 or self._stop_batch < self._start_batch:
2381      raise ValueError(profile_batch_error_message)
2382
2383    # True when the profiler was successfully started by this callback.
2384    # We track the status here to make sure callbacks do not interfere with
2385    # each other. The callback will only stop the profiler it started.
2386    self._profiler_started = False
2387    if self._start_batch > 0:
2388      # Warm up the profiler to improve profiling accuracy.
2389      self._start_profiler(logdir='')
2390      self._stop_profiler(save=False)
2391    # True when a trace is running.
2392    self._is_tracing = False
2393
2394    # Setting `profile_batch=0` disables profiling.
2395    self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
2396
2397  def on_train_begin(self, logs=None):
2398    self._global_train_batch = 0
2399    self._previous_epoch_iterations = 0
2400    self._train_accumulated_time = 0
2401    self._push_writer(self._train_writer, self._train_step)
2402
2403  def on_train_end(self, logs=None):
2404    self._pop_writer()
2405
2406    if self._is_tracing:
2407      self._stop_trace()
2408
2409    self._close_writers()
2410    self._delete_tmp_write_dir()
2411
2412  def on_test_begin(self, logs=None):
2413    self._push_writer(self._val_writer, self._val_step)
2414
2415  def on_test_end(self, logs=None):
2416    if self.model.optimizer and hasattr(self.model.optimizer, 'iterations'):
2417      with summary_ops_v2.record_if(True), self._val_writer.as_default():
2418        for name, value in logs.items():
2419          summary_ops_v2.scalar(
2420              'evaluation_' + name + '_vs_iterations',
2421              value,
2422              step=self.model.optimizer.iterations.read_value())
2423    self._pop_writer()
2424
2425  def _implements_train_batch_hooks(self):
2426    # Only call batch hooks when tracing or write_steps_per_second are enabled
2427    return self._should_trace or self.write_steps_per_second
2428
2429  def on_train_batch_begin(self, batch, logs=None):
2430    self._global_train_batch += 1
2431    if self.write_steps_per_second:
2432      self._batch_start_time = time.time()
2433    if not self._should_trace:
2434      return
2435
2436    if self._global_train_batch == self._start_batch:
2437      self._start_trace()
2438
2439  def on_train_batch_end(self, batch, logs=None):
2440    if self._should_write_train_graph:
2441      self._write_keras_model_train_graph()
2442      self._should_write_train_graph = False
2443    if self.write_steps_per_second:
2444      batch_run_time = time.time() - self._batch_start_time
2445      self._train_accumulated_time += batch_run_time
2446      summary_ops_v2.scalar(
2447          'batch_steps_per_second', 1. / batch_run_time, step=self._train_step)
2448    if not self._should_trace:
2449      return
2450
2451    if self._is_tracing and self._global_train_batch >= self._stop_batch:
2452      self._stop_trace()
2453
2454  def on_epoch_begin(self, epoch, logs=None):
2455    # Resets the per-epoch bookkeeping used for steps-per-second logging.
2456    if self.write_steps_per_second:
2457      self._previous_epoch_iterations = self.model.optimizer.iterations.numpy()
2458      self._train_accumulated_time = 0
2459
2460  def on_epoch_end(self, epoch, logs=None):
2461    """Runs metrics and histogram summaries at epoch end."""
2462    self._log_epoch_metrics(epoch, logs)
2463
2464    if self.histogram_freq and epoch % self.histogram_freq == 0:
2465      self._log_weights(epoch)
2466
2467    if self.embeddings_freq and epoch % self.embeddings_freq == 0:
2468      self._log_embeddings(epoch)
2469
2470  def _start_trace(self):
2471    summary_ops_v2.trace_on(graph=True, profiler=False)
2472    self._start_profiler(logdir=self._train_dir)
2473    self._is_tracing = True
2474
2475  def _stop_trace(self, batch=None):
2476    """Logs the trace graph to TensorBoard."""
2477    if batch is None:
2478      batch = self._stop_batch
2479    with self._train_writer.as_default():
2480      with summary_ops_v2.record_if(True):
2481        # TODO(b/126388999): Remove step info in the summary name.
2482        summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch)
2483    self._stop_profiler()
2484    self._is_tracing = False
2485
2486  def _collect_learning_rate(self, logs):
2487    lr_schedule = getattr(self.model.optimizer, 'lr', None)
2488    if isinstance(lr_schedule, learning_rate_schedule.LearningRateSchedule):
2489      logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations)
2490    return logs
2491
2492  def _compute_steps_per_second(self):
2493    current_iteration = self.model.optimizer.iterations.numpy()
2494    steps_per_second = ((current_iteration - self._previous_epoch_iterations) /
2495                        (self._train_accumulated_time))
2496    return steps_per_second
2497
2498  def _log_epoch_metrics(self, epoch, logs):
2499    """Writes epoch metrics out as scalar summaries.
2500
2501    Args:
2502        epoch: Int. The global step to use for TensorBoard.
2503        logs: Dict. Keys are scalar summary names, values are scalars.
2504    """
2505    if not logs:
2506      return
2507
2508    train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
2509    val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
2510    train_logs = self._collect_learning_rate(train_logs)
2511    if self.write_steps_per_second:
2512      train_logs['steps_per_second'] = self._compute_steps_per_second()
2513
2514    with summary_ops_v2.record_if(True):
2515      if train_logs:
2516        with self._train_writer.as_default():
2517          for name, value in train_logs.items():
2518            summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
2519      if val_logs:
2520        with self._val_writer.as_default():
2521          for name, value in val_logs.items():
2522            name = name[4:]  # Remove 'val_' prefix.
2523            summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
2524
2525  def _log_weights(self, epoch):
2526    """Logs the weights of the Model to TensorBoard."""
2527    with self._train_writer.as_default():
2528      with summary_ops_v2.record_if(True):
2529        for layer in self.model.layers:
2530          for weight in layer.weights:
2531            weight_name = weight.name.replace(':', '_')
2532            summary_ops_v2.histogram(weight_name, weight, step=epoch)
2533            if self.write_images:
2534              self._log_weight_as_image(weight, weight_name, epoch)
2535        self._train_writer.flush()
2536
2537  def _log_weight_as_image(self, weight, weight_name, epoch):
2538    """Logs a weight as a TensorBoard image."""
2539    w_img = array_ops.squeeze(weight)
2540    shape = backend.int_shape(w_img)
2541    if len(shape) == 1:  # Bias case
2542      w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
2543    elif len(shape) == 2:  # Dense layer kernel case
2544      if shape[0] > shape[1]:
2545        w_img = array_ops.transpose(w_img)
2546        shape = backend.int_shape(w_img)
2547      w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
2548    elif len(shape) == 3:  # ConvNet case
2549      if backend.image_data_format() == 'channels_last':
2550        # Switch to channels_first to display every kernel as a separate
2551        # image.
2552        w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
2553        shape = backend.int_shape(w_img)
2554      w_img = array_ops.reshape(w_img, [shape[0], shape[1], shape[2], 1])
2555
2556    shape = backend.int_shape(w_img)
2557    # Not possible to handle 3D convnets etc.
2558    if len(shape) == 4 and shape[-1] in [1, 3, 4]:
2559      summary_ops_v2.image(weight_name, w_img, step=epoch)
2560
2561  def _log_embeddings(self, epoch):
2562    embeddings_ckpt = os.path.join(self._log_write_dir, 'train',
2563                                   'keras_embedding.ckpt-{}'.format(epoch))
2564    self.model.save_weights(embeddings_ckpt)
2565
2566  def _start_profiler(self, logdir):
2567    """Starts the profiler if currently inactive.
2568
2569    Args:
2570      logdir: Directory where profiler results will be saved.
2571    """
2572    if self._profiler_started:
2573      return
2574    try:
2575      profiler.start(logdir=logdir)
2576      self._profiler_started = True
2577    except errors.AlreadyExistsError as e:
2578      # Profiler errors should not be fatal.
2579      logging.error('Failed to start profiler: %s', e.message)
2580
2581  def _stop_profiler(self, save=True):
2582    """Stops the profiler if currently active.
2583
2584    Args:
2585      save: Whether to save the profiler results to TensorBoard.
2586    """
2587    if not self._profiler_started:
2588      return
2589    try:
2590      profiler.stop(save=save)
2591    except errors.UnavailableError as e:
2592      # Profiler errors should not be fatal.
2593      logging.error('Failed to stop profiler: %s', e.message)
2594    finally:
2595      self._profiler_started = False
2596
2597
2598@keras_export('keras.callbacks.ReduceLROnPlateau')
2599class ReduceLROnPlateau(Callback):
2600  """Reduce learning rate when a metric has stopped improving.
2601
2602  Models often benefit from reducing the learning rate by a factor
2603  of 2-10 once learning stagnates. This callback monitors a
2604  quantity and if no improvement is seen for a 'patience' number
2605  of epochs, the learning rate is reduced.
2606
2607  Example:
2608
2609  ```python
2610  reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
2611                                patience=5, min_lr=0.001)
2612  model.fit(X_train, Y_train, callbacks=[reduce_lr])
2613  ```
2614
2615  Args:
2616      monitor: quantity to be monitored.
2617      factor: factor by which the learning rate will be reduced.
2618        `new_lr = lr * factor`.
2619      patience: number of epochs with no improvement after which learning rate
2620        will be reduced.
2621      verbose: int. 0: quiet, 1: update messages.
2622      mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
2623        the learning rate will be reduced when the
2624        quantity monitored has stopped decreasing; in `'max'` mode it will be
2625        reduced when the quantity monitored has stopped increasing; in `'auto'`
2626        mode, the direction is automatically inferred from the name of the
2627        monitored quantity.
2628      min_delta: threshold for measuring the new optimum, to only focus on
2629        significant changes.
2630      cooldown: number of epochs to wait before resuming normal operation after
2631        lr has been reduced.
2632      min_lr: lower bound on the learning rate.
2633  """
2634
2635  def __init__(self,
2636               monitor='val_loss',
2637               factor=0.1,
2638               patience=10,
2639               verbose=0,
2640               mode='auto',
2641               min_delta=1e-4,
2642               cooldown=0,
2643               min_lr=0,
2644               **kwargs):
2645    super(ReduceLROnPlateau, self).__init__()
2646
2647    self.monitor = monitor
2648    if factor >= 1.0:
2649      raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.')
2650    if 'epsilon' in kwargs:
2651      min_delta = kwargs.pop('epsilon')
2652      logging.warning('`epsilon` argument is deprecated and '
2653                      'will be removed, use `min_delta` instead.')
2654    self.factor = factor
2655    self.min_lr = min_lr
2656    self.min_delta = min_delta
2657    self.patience = patience
2658    self.verbose = verbose
2659    self.cooldown = cooldown
2660    self.cooldown_counter = 0  # Cooldown counter.
2661    self.wait = 0
2662    self.best = 0
2663    self.mode = mode
2664    self.monitor_op = None
2665    self._reset()
2666
2667  def _reset(self):
2668    """Resets wait counter and cooldown counter.
2669    """
2670    if self.mode not in ['auto', 'min', 'max']:
2671      logging.warning('Learning rate reduction mode %s is unknown, '
2672                      'fallback to auto mode.', self.mode)
2673      self.mode = 'auto'
2674    if (self.mode == 'min' or
2675        (self.mode == 'auto' and 'acc' not in self.monitor)):
2676      self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
2677      self.best = np.Inf
2678    else:
2679      self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
2680      self.best = -np.Inf
2681    self.cooldown_counter = 0
2682    self.wait = 0
2683
2684  def on_train_begin(self, logs=None):
2685    self._reset()
2686
2687  def on_epoch_end(self, epoch, logs=None):
2688    logs = logs or {}
2689    logs['lr'] = backend.get_value(self.model.optimizer.lr)
2690    current = logs.get(self.monitor)
2691    if current is None:
2692      logging.warning('Learning rate reduction is conditioned on metric `%s` '
2693                      'which is not available. Available metrics are: %s',
2694                      self.monitor, ','.join(list(logs.keys())))
2695
2696    else:
2697      if self.in_cooldown():
2698        self.cooldown_counter -= 1
2699        self.wait = 0
2700
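      # An improvement resets the patience counter; otherwise, outside of any
      # cooldown window, `patience` stagnant epochs trigger a reduction of the
      # learning rate by `factor`, floored at `min_lr`.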
2701      if self.monitor_op(current, self.best):
2702        self.best = current
2703        self.wait = 0
2704      elif not self.in_cooldown():
2705        self.wait += 1
2706        if self.wait >= self.patience:
2707          old_lr = backend.get_value(self.model.optimizer.lr)
2708          if old_lr > np.float32(self.min_lr):
2709            new_lr = old_lr * self.factor
2710            new_lr = max(new_lr, self.min_lr)
2711            backend.set_value(self.model.optimizer.lr, new_lr)
2712            if self.verbose > 0:
2713              print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
2714                    'rate to %s.' % (epoch + 1, new_lr))
2715            self.cooldown_counter = self.cooldown
2716            self.wait = 0
2717
2718  def in_cooldown(self):
2719    return self.cooldown_counter > 0
2720
2721
2722@keras_export('keras.callbacks.CSVLogger')
2723class CSVLogger(Callback):
2724  """Callback that streams epoch results to a CSV file.
2725
2726  Supports all values that can be represented as a string,
2727  including 1D iterables such as `np.ndarray`.
2728
2729  Example:
2730
2731  ```python
2732  csv_logger = CSVLogger('training.log')
2733  model.fit(X_train, Y_train, callbacks=[csv_logger])
2734  ```
2735
2736  Args:
2737      filename: Filename of the CSV file, e.g. `'run/log.csv'`.
2738      separator: String used to separate elements in the CSV file.
2739      append: Boolean. True: append if file exists (useful for continuing
2740          training). False: overwrite existing file.
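
  With `append=True`, successive `fit` calls can keep extending one log file;
  a minimal sketch (the path and separator below are illustrative):

  ```python
  csv_logger = CSVLogger('run/log.csv', separator=';', append=True)
  model.fit(X_train, Y_train, epochs=10, callbacks=[csv_logger])
  ```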
2741  """
2742
2743  def __init__(self, filename, separator=',', append=False):
2744    self.sep = separator
2745    self.filename = path_to_string(filename)
2746    self.append = append
2747    self.writer = None
2748    self.keys = None
2749    self.append_header = True
2750    super(CSVLogger, self).__init__()
2751
2752  def on_train_begin(self, logs=None):
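    # When appending to an existing, non-empty file, skip writing the header
    # row again so the resulting CSV stays parseable.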
2753    if self.append:
2754      if file_io.file_exists_v2(self.filename):
2755        with gfile.GFile(self.filename, 'r') as f:
2756          self.append_header = not bool(len(f.readline()))
2757      mode = 'a'
2758    else:
2759      mode = 'w'
2760    self.csv_file = gfile.GFile(self.filename, mode)
2761
2762  def on_epoch_end(self, epoch, logs=None):
2763    logs = logs or {}
2764
2765    def handle_value(k):
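      # Strings and scalars pass through unchanged; other iterables (e.g. 1-D
      # arrays) are rendered as a quoted list so they occupy a single CSV cell.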
2766      is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
2767      if isinstance(k, str):
2768        return k
2769      elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray:
2770        return '"[%s]"' % (', '.join(map(str, k)))
2771      else:
2772        return k
2773
2774    if self.keys is None:
2775      self.keys = sorted(logs.keys())
2776
2777    if self.model.stop_training:
2778      # We set NA so that csv parsers do not fail for this last epoch.
2779      logs = dict((k, logs[k]) if k in logs else (k, 'NA') for k in self.keys)
2780
2781    if not self.writer:
2782
2783      class CustomDialect(csv.excel):
2784        delimiter = self.sep
2785
2786      fieldnames = ['epoch'] + self.keys
2787
2788      self.writer = csv.DictWriter(
2789          self.csv_file,
2790          fieldnames=fieldnames,
2791          dialect=CustomDialect)
2792      if self.append_header:
2793        self.writer.writeheader()
2794
2795    row_dict = collections.OrderedDict({'epoch': epoch})
2796    row_dict.update((key, handle_value(logs[key])) for key in self.keys)
2797    self.writer.writerow(row_dict)
2798    self.csv_file.flush()
2799
2800  def on_train_end(self, logs=None):
2801    self.csv_file.close()
2802    self.writer = None
2803
2804
2805@keras_export('keras.callbacks.LambdaCallback')
2806class LambdaCallback(Callback):
2807  r"""Callback for creating simple, custom callbacks on-the-fly.
2808
2809  This callback is constructed with anonymous functions that will be called
2810  at the appropriate time (during `Model.{fit | evaluate | predict}`).
2811  Note that the callbacks expect positional arguments, as follows:
2812
2813  - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
2814    `epoch`, `logs`
2815  - `on_batch_begin` and `on_batch_end` expect two positional arguments:
2816    `batch`, `logs`
2817  - `on_train_begin` and `on_train_end` expect one positional argument:
2818    `logs`
2819
2820  Args:
2821      on_epoch_begin: called at the beginning of every epoch.
2822      on_epoch_end: called at the end of every epoch.
2823      on_batch_begin: called at the beginning of every batch.
2824      on_batch_end: called at the end of every batch.
2825      on_train_begin: called at the beginning of model training.
2826      on_train_end: called at the end of model training.
2827
2828  Example:
2829
2830  ```python
2831  # Print the batch number at the beginning of every batch.
2832  batch_print_callback = LambdaCallback(
2833      on_batch_begin=lambda batch, logs: print(batch))
2834
2835  # Stream the epoch loss to a file in JSON format. The file content
2836  # is not well-formed JSON but rather has a JSON object per line.
2837  import json
2838  json_log = open('loss_log.json', mode='wt', buffering=1)
2839  json_logging_callback = LambdaCallback(
2840      on_epoch_end=lambda epoch, logs: json_log.write(
2841          json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
2842      on_train_end=lambda logs: json_log.close()
2843  )
2844
2845  # Terminate some processes after model training has finished.
2846  processes = ...
2847  cleanup_callback = LambdaCallback(
2848      on_train_end=lambda logs: [
2849          p.terminate() for p in processes if p.is_alive()])
2850
2851  model.fit(...,
2852            callbacks=[batch_print_callback,
2853                       json_logging_callback,
2854                       cleanup_callback])
2855  ```
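
  As a further sketch (the `time` import and the `epoch_start` dict are
  illustrative additions, not part of the example above), `on_epoch_begin`
  and `on_epoch_end` can be paired to report how long each epoch takes:

  ```python
  import time

  # Wall-clock start times, keyed by epoch index.
  epoch_start = {}
  timing_callback = LambdaCallback(
      on_epoch_begin=lambda epoch, logs: epoch_start.update(
          {epoch: time.time()}),
      on_epoch_end=lambda epoch, logs: print(
          'Epoch %d took %.1fs' % (epoch, time.time() - epoch_start[epoch])))

  model.fit(..., callbacks=[timing_callback])
  ```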
2856  """
2857
2858  def __init__(self,
2859               on_epoch_begin=None,
2860               on_epoch_end=None,
2861               on_batch_begin=None,
2862               on_batch_end=None,
2863               on_train_begin=None,
2864               on_train_end=None,
2865               **kwargs):
2866    super(LambdaCallback, self).__init__()
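    # Any extra keyword arguments are set as attributes, so other `Callback`
    # hooks (e.g. `on_test_begin`) can be overridden the same way.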
2867    self.__dict__.update(kwargs)
2868    if on_epoch_begin is not None:
2869      self.on_epoch_begin = on_epoch_begin
2870    else:
2871      self.on_epoch_begin = lambda epoch, logs: None
2872    if on_epoch_end is not None:
2873      self.on_epoch_end = on_epoch_end
2874    else:
2875      self.on_epoch_end = lambda epoch, logs: None
2876    if on_batch_begin is not None:
2877      self.on_batch_begin = on_batch_begin
2878    else:
2879      self.on_batch_begin = lambda batch, logs: None
2880    if on_batch_end is not None:
2881      self.on_batch_end = on_batch_end
2882    else:
2883      self.on_batch_end = lambda batch, logs: None
2884    if on_train_begin is not None:
2885      self.on_train_begin = on_train_begin
2886    else:
2887      self.on_train_begin = lambda logs: None
2888    if on_train_end is not None:
2889      self.on_train_end = on_train_end
2890    else:
2891      self.on_train_end = lambda logs: None
2892