xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various learning rate decay functions."""
16
17import abc
18import math
19
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import ops
22from tensorflow.python.keras.utils import generic_utils
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import random_ops
27from tensorflow.python.util import nest
28from tensorflow.python.util.tf_export import keras_export
29
30
31@keras_export("keras.optimizers.schedules.LearningRateSchedule")
32class LearningRateSchedule(object):
33  """The learning rate schedule base class.
34
35  You can use a learning rate schedule to modulate how the learning rate
36  of your optimizer changes over time.
37
38  Several built-in learning rate schedules are available, such as
39  `tf.keras.optimizers.schedules.ExponentialDecay` or
40  `tf.keras.optimizers.schedules.PiecewiseConstantDecay`:
41
42  ```python
43  lr_schedule = keras.optimizers.schedules.ExponentialDecay(
44      initial_learning_rate=1e-2,
45      decay_steps=10000,
46      decay_rate=0.9)
47  optimizer = keras.optimizers.SGD(learning_rate=lr_schedule)
48  ```
49
50  A `LearningRateSchedule` instance can be passed in as the `learning_rate`
51  argument of any optimizer.
52
53  To implement your own schedule object, you should implement the `__call__`
54  method, which takes a `step` argument (scalar integer tensor, the
55  current training step count).
56  As with any other Keras object, you can optionally
57  make your schedule serializable by implementing the `get_config`
58  and `from_config` methods.
59
60  Example:
61
62  ```python
63  class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
64
65    def __init__(self, initial_learning_rate):
66      self.initial_learning_rate = initial_learning_rate
67
68    def __call__(self, step):
69      return self.initial_learning_rate / (step + 1)
70
71  optimizer = tf.keras.optimizers.SGD(learning_rate=MyLRSchedule(0.1))
72  ```
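
  As mentioned above, you can also make a custom schedule serializable by
  adding a `get_config` method. A minimal sketch (a fuller version of the
  `MyLRSchedule` example; the inherited `from_config` rebuilds the object
  from the returned dict):

  ```python
  class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, initial_learning_rate):
      self.initial_learning_rate = initial_learning_rate

    def __call__(self, step):
      return self.initial_learning_rate / (step + 1)

    def get_config(self):
      return {"initial_learning_rate": self.initial_learning_rate}
  ```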
73  """
74
75  @abc.abstractmethod
76  def __call__(self, step):
77    raise NotImplementedError("Learning rate schedule must override __call__")
78
79  @abc.abstractmethod
80  def get_config(self):
81    raise NotImplementedError("Learning rate schedule must override get_config")
82
83  @classmethod
84  def from_config(cls, config):
85    """Instantiates a `LearningRateSchedule` from its config.
86
87    Args:
88        config: Output of `get_config()`.
89
90    Returns:
91        A `LearningRateSchedule` instance.
92    """
93    return cls(**config)
94
95
96@keras_export("keras.optimizers.schedules.ExponentialDecay")
97class ExponentialDecay(LearningRateSchedule):
98  """A LearningRateSchedule that uses an exponential decay schedule.
99
100  When training a model, it is often useful to lower the learning rate as
101  the training progresses. This schedule applies an exponential decay function
102  to an optimizer step, given a provided initial learning rate.
103
104  The schedule is a 1-arg callable that produces a decayed learning
105  rate when passed the current optimizer step. This can be useful for changing
106  the learning rate value across different invocations of optimizer functions.
107  It is computed as:
108
109  ```python
110  def decayed_learning_rate(step):
111    return initial_learning_rate * decay_rate ^ (step / decay_steps)
112  ```
113
114  If the argument `staircase` is `True`, then `step / decay_steps` is
115  an integer division and the decayed learning rate follows a
116  staircase function.
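
  For illustration (arbitrary example values; eager execution assumed),
  `decay_rate=0.5` halves the learning rate every `decay_steps` steps:

  ```python
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.5)
  lr_schedule(0).numpy()     # ~0.1
  lr_schedule(1000).numpy()  # ~0.05
  lr_schedule(2000).numpy()  # ~0.025
  ```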
117
118  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
119  as the learning rate.
120  Example: When fitting a Keras model, decay every 100000 steps with a base
121  of 0.96:
122
123  ```python
124  initial_learning_rate = 0.1
125  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
126      initial_learning_rate,
127      decay_steps=100000,
128      decay_rate=0.96,
129      staircase=True)
130
131  model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
132                loss='sparse_categorical_crossentropy',
133                metrics=['accuracy'])
134
135  model.fit(data, labels, epochs=5)
136  ```
137
138  The learning rate schedule is also serializable and deserializable using
139  `tf.keras.optimizers.schedules.serialize` and
140  `tf.keras.optimizers.schedules.deserialize`.
141
142  Returns:
143    A 1-arg callable learning rate schedule that takes the current optimizer
144    step and outputs the decayed learning rate, a scalar `Tensor` of the same
145    type as `initial_learning_rate`.
146  """
147
148  def __init__(
149      self,
150      initial_learning_rate,
151      decay_steps,
152      decay_rate,
153      staircase=False,
154      name=None):
155    """Applies exponential decay to the learning rate.
156
157    Args:
158      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
159        Python number.  The initial learning rate.
160      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
161        Must be positive.  See the decay computation above.
162      decay_rate: A scalar `float32` or `float64` `Tensor` or a
163        Python number.  The decay rate.
164      staircase: Boolean.  If `True`, decay the learning rate at discrete
165        intervals.
166      name: String.  Optional name of the operation.  Defaults to
167        'ExponentialDecay'.
168    """
169    super(ExponentialDecay, self).__init__()
170    self.initial_learning_rate = initial_learning_rate
171    self.decay_steps = decay_steps
172    self.decay_rate = decay_rate
173    self.staircase = staircase
174    self.name = name
175
176  def __call__(self, step):
177    with ops.name_scope_v2(self.name or "ExponentialDecay") as name:
178      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
179          self.initial_learning_rate, name="initial_learning_rate")
180      dtype = initial_learning_rate.dtype
181      decay_steps = math_ops.cast(self.decay_steps, dtype)
182      decay_rate = math_ops.cast(self.decay_rate, dtype)
183
184      global_step_recomp = math_ops.cast(step, dtype)
185      p = global_step_recomp / decay_steps
186      if self.staircase:
187        p = math_ops.floor(p)
188      return math_ops.multiply(
189          initial_learning_rate, math_ops.pow(decay_rate, p), name=name)
190
191  def get_config(self):
192    return {
193        "initial_learning_rate": self.initial_learning_rate,
194        "decay_steps": self.decay_steps,
195        "decay_rate": self.decay_rate,
196        "staircase": self.staircase,
197        "name": self.name
198    }
199
200
201@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay")
202class PiecewiseConstantDecay(LearningRateSchedule):
203  """A LearningRateSchedule that uses a piecewise constant decay schedule.
204
205  The function returns a 1-arg callable to compute the piecewise constant
206  when passed the current optimizer step. This can be useful for changing the
207  learning rate value across different invocations of optimizer functions.
208
209  Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
210    for the next 10000 steps, and 0.1 for any additional steps.
211
212  ```python
213  step = tf.Variable(0, trainable=False)
214  boundaries = [100000, 110000]
215  values = [1.0, 0.5, 0.1]
216  learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay(
217      boundaries, values)
218
219  # Later, whenever we perform an optimization step, we pass in the step.
220  learning_rate = learning_rate_fn(step)
221  ```
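
  Continuing the example above (passing plain integer steps for brevity;
  eager execution assumed), the schedule returns:

  ```python
  learning_rate_fn(100000).numpy()  # 1.0  (step <= boundaries[0])
  learning_rate_fn(100001).numpy()  # 0.5
  learning_rate_fn(110001).numpy()  # 0.1  (step > boundaries[-1])
  ```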
222
223  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
224  as the learning rate. The learning rate schedule is also serializable and
225  deserializable using `tf.keras.optimizers.schedules.serialize` and
226  `tf.keras.optimizers.schedules.deserialize`.
227
228  Returns:
229    A 1-arg callable learning rate schedule that takes the current optimizer
230    step and outputs the decayed learning rate, a scalar `Tensor` of the same
231    type as the boundary tensors.
232
233    The output of the 1-arg function that takes the `step`
234    is `values[0]` when `step <= boundaries[0]`,
235    `values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, ...,
236    and `values[-1]` when `step > boundaries[-1]`.
237  """
238
239  def __init__(
240      self,
241      boundaries,
242      values,
243      name=None):
244    """Piecewise constant from boundaries and interval values.
245
246    Args:
247      boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
248        increasing entries, and with all elements having the same type as the
249        optimizer step.
250      values: A list of `Tensor`s or `float`s or `int`s that specifies the
251        values for the intervals defined by `boundaries`. It should have one
252        more element than `boundaries`, and all elements should have the same
253        type.
254      name: A string. Optional name of the operation. Defaults to
255        'PiecewiseConstant'.
256
257    Raises:
258      ValueError: if the number of elements in the lists does not match.
259    """
260    super(PiecewiseConstantDecay, self).__init__()
261
262    if len(boundaries) != len(values) - 1:
263      raise ValueError(
264          "The length of boundaries should be 1 less than the length of values")
265
266    self.boundaries = boundaries
267    self.values = values
268    self.name = name
269
270  def __call__(self, step):
271    with ops.name_scope_v2(self.name or "PiecewiseConstant"):
272      boundaries = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch,
273                                      nest.flatten(self.boundaries))
274      values = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch,
275                                  nest.flatten(self.values))
276      x_recomp = ops.convert_to_tensor_v2_with_dispatch(step)
277      for i, b in enumerate(boundaries):
278        if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
279          # We cast the boundaries to have the same type as the step
280          b = math_ops.cast(b, x_recomp.dtype.base_dtype)
281          boundaries[i] = b
282      pred_fn_pairs = []
283      pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
284      pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
285      for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
286        # Need to bind v here; can do this with lambda v=v: ...
287        pred = (x_recomp > low) & (x_recomp <= high)
288        pred_fn_pairs.append((pred, lambda v=v: v))
289
290      # The default isn't needed here because our conditions are mutually
291      # exclusive and exhaustive, but tf.case requires it.
292      default = lambda: values[0]
293      return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
294
295  def get_config(self):
296    return {
297        "boundaries": self.boundaries,
298        "values": self.values,
299        "name": self.name
300    }
301
302
303@keras_export("keras.optimizers.schedules.PolynomialDecay")
304class PolynomialDecay(LearningRateSchedule):
305  """A LearningRateSchedule that uses a polynomial decay schedule.
306
307  It is commonly observed that a monotonically decreasing learning rate, whose
308  degree of change is carefully chosen, results in a better performing model.
309  This schedule applies a polynomial decay function to an optimizer step,
310  given a provided `initial_learning_rate`, to reach an `end_learning_rate`
311  in the given `decay_steps`.
312
313  It requires a `step` value to compute the decayed learning rate. You
314  can just pass a TensorFlow variable that you increment at each training
315  step.
316
317  The schedule is a 1-arg callable that produces a decayed learning rate
318  when passed the current optimizer step. This can be useful for changing the
319  learning rate value across different invocations of optimizer functions.
320  It is computed as:
321
322  ```python
323  def decayed_learning_rate(step):
324    step = min(step, decay_steps)
325    return ((initial_learning_rate - end_learning_rate) *
326            (1 - step / decay_steps) ^ (power)
327           ) + end_learning_rate
328  ```
329
330  If `cycle` is `True`, then a multiple of `decay_steps` is used, the first one
331  that is bigger than `step`.
332
333  ```python
334  def decayed_learning_rate(step):
335    decay_steps = decay_steps * ceil(step / decay_steps)
336    return ((initial_learning_rate - end_learning_rate) *
337            (1 - step / decay_steps) ^ (power)
338           ) + end_learning_rate
339  ```
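
  As an illustrative check (arbitrary values; eager execution assumed),
  with `cycle=False` the learning rate reaches `end_learning_rate` at
  `decay_steps` and stays there:

  ```python
  lr = tf.keras.optimizers.schedules.PolynomialDecay(
      0.1, decay_steps=100, end_learning_rate=0.01, power=0.5)
  lr(0).numpy()    # ~0.1
  lr(100).numpy()  # ~0.01
  lr(200).numpy()  # ~0.01  (step is clipped to decay_steps)
  ```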
340
341  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
342  as the learning rate.
343  Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using
344  sqrt (i.e. power=0.5):
345
346  ```python
347  ...
348  starter_learning_rate = 0.1
349  end_learning_rate = 0.01
350  decay_steps = 10000
351  learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
352      starter_learning_rate,
353      decay_steps,
354      end_learning_rate,
355      power=0.5)
356
357  model.compile(optimizer=tf.keras.optimizers.SGD(
358                    learning_rate=learning_rate_fn),
359                loss='sparse_categorical_crossentropy',
360                metrics=['accuracy'])
361
362  model.fit(data, labels, epochs=5)
363  ```
364
365  The learning rate schedule is also serializable and deserializable using
366  `tf.keras.optimizers.schedules.serialize` and
367  `tf.keras.optimizers.schedules.deserialize`.
368
369  Returns:
370    A 1-arg callable learning rate schedule that takes the current optimizer
371    step and outputs the decayed learning rate, a scalar `Tensor` of the same
372    type as `initial_learning_rate`.
373  """
374
375  def __init__(
376      self,
377      initial_learning_rate,
378      decay_steps,
379      end_learning_rate=0.0001,
380      power=1.0,
381      cycle=False,
382      name=None):
383    """Applies a polynomial decay to the learning rate.
384
385    Args:
386      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
387        Python number.  The initial learning rate.
388      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
389        Must be positive.  See the decay computation above.
390      end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
391        Python number.  The minimal end learning rate.
392      power: A scalar `float32` or `float64` `Tensor` or a
393        Python number.  The power of the polynomial. Defaults to linear, 1.0.
394      cycle: A boolean, whether or not it should cycle beyond decay_steps.
395      name: String.  Optional name of the operation. Defaults to
396        'PolynomialDecay'.
397    """
398    super(PolynomialDecay, self).__init__()
399
400    self.initial_learning_rate = initial_learning_rate
401    self.decay_steps = decay_steps
402    self.end_learning_rate = end_learning_rate
403    self.power = power
404    self.cycle = cycle
405    self.name = name
406
407  def __call__(self, step):
408    with ops.name_scope_v2(self.name or "PolynomialDecay") as name:
409      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
410          self.initial_learning_rate, name="initial_learning_rate")
411      dtype = initial_learning_rate.dtype
412      end_learning_rate = math_ops.cast(self.end_learning_rate, dtype)
413      power = math_ops.cast(self.power, dtype)
414
415      global_step_recomp = math_ops.cast(step, dtype)
416      decay_steps_recomp = math_ops.cast(self.decay_steps, dtype)
417      if self.cycle:
418        # Find the first multiple of decay_steps that is bigger than
419        # global_step. If global_step is zero set the multiplier to 1
420        multiplier = array_ops.where_v2(
421            math_ops.equal(global_step_recomp, 0), 1.0,
422            math_ops.ceil(global_step_recomp / self.decay_steps))
423        decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
424      else:
425        # Make sure that the global_step used is not bigger than decay_steps.
426        global_step_recomp = math_ops.minimum(global_step_recomp,
427                                              decay_steps_recomp)
428
429      p = math_ops.divide(global_step_recomp, decay_steps_recomp)
430      return math_ops.add(
431          math_ops.multiply(initial_learning_rate - end_learning_rate,
432                            math_ops.pow(1 - p, power)),
433          end_learning_rate,
434          name=name)
435
436  def get_config(self):
437    return {
438        "initial_learning_rate": self.initial_learning_rate,
439        "decay_steps": self.decay_steps,
440        "end_learning_rate": self.end_learning_rate,
441        "power": self.power,
442        "cycle": self.cycle,
443        "name": self.name
444    }
445
446
447@keras_export("keras.optimizers.schedules.InverseTimeDecay")
448class InverseTimeDecay(LearningRateSchedule):
449  """A LearningRateSchedule that uses an inverse time decay schedule.
450
451  When training a model, it is often useful to lower the learning rate as
452  the training progresses. This schedule applies the inverse decay function
453  to an optimizer step, given a provided initial learning rate.
454  It requires a `step` value to compute the decayed learning rate. You can
455  just pass a TensorFlow variable that you increment at each training step.
456
457  The schedule is a 1-arg callable that produces a decayed learning
458  rate when passed the current optimizer step. This can be useful for changing
459  the learning rate value across different invocations of optimizer functions.
460  It is computed as:
461
462  ```python
463  def decayed_learning_rate(step):
464    return initial_learning_rate / (1 + decay_rate * step / decay_step)
465  ```
466
467  or, if `staircase` is `True`, as:
468
469  ```python
470  def decayed_learning_rate(step):
471    return initial_learning_rate / (1 + decay_rate * floor(step / decay_step))
472  ```
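
  For illustration (arbitrary values; eager execution assumed), with
  `decay_steps=1.0` and `decay_rate=0.5` the rate at step `t` is
  `initial_learning_rate / (1 + 0.5 * t)`:

  ```python
  lr = tf.keras.optimizers.schedules.InverseTimeDecay(
      initial_learning_rate=0.1, decay_steps=1.0, decay_rate=0.5)
  lr(0).numpy()  # ~0.1
  lr(2).numpy()  # ~0.05
  lr(8).numpy()  # ~0.02
  ```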
473
474  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
475  as the learning rate.
476  Example: Fit a Keras model while decaying 1/t with a rate of 0.5:
477
478  ```python
479  ...
480  initial_learning_rate = 0.1
481  decay_steps = 1.0
482  decay_rate = 0.5
483  learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay(
484    initial_learning_rate, decay_steps, decay_rate)
485
486  model.compile(optimizer=tf.keras.optimizers.SGD(
487                    learning_rate=learning_rate_fn),
488                loss='sparse_categorical_crossentropy',
489                metrics=['accuracy'])
490
491  model.fit(data, labels, epochs=5)
492  ```
493
494  Returns:
495    A 1-arg callable learning rate schedule that takes the current optimizer
496    step and outputs the decayed learning rate, a scalar `Tensor` of the same
497    type as `initial_learning_rate`.
498  """
499
500  def __init__(
501      self,
502      initial_learning_rate,
503      decay_steps,
504      decay_rate,
505      staircase=False,
506      name=None):
507    """Applies inverse time decay to the initial learning rate.
508
509    Args:
510      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
511        Python number.  The initial learning rate.
512      decay_steps: How often to apply decay.
513      decay_rate: A Python number.  The decay rate.
514      staircase: Whether to apply decay in a discrete staircase, as opposed to
515        continuous, fashion.
516      name: String.  Optional name of the operation.  Defaults to
517        'InverseTimeDecay'.
518    """
519    super(InverseTimeDecay, self).__init__()
520
521    self.initial_learning_rate = initial_learning_rate
522    self.decay_steps = decay_steps
523    self.decay_rate = decay_rate
524    self.staircase = staircase
525    self.name = name
526
527  def __call__(self, step):
528    with ops.name_scope_v2(self.name or "InverseTimeDecay") as name:
529      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
530          self.initial_learning_rate, name="initial_learning_rate")
531      dtype = initial_learning_rate.dtype
532      decay_steps = math_ops.cast(self.decay_steps, dtype)
533      decay_rate = math_ops.cast(self.decay_rate, dtype)
534
535      global_step_recomp = math_ops.cast(step, dtype)
536      p = global_step_recomp / decay_steps
537      if self.staircase:
538        p = math_ops.floor(p)
539      const = math_ops.cast(constant_op.constant(1), dtype)
540      denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
541      return math_ops.divide(initial_learning_rate, denom, name=name)
542
543  def get_config(self):
544    return {
545        "initial_learning_rate": self.initial_learning_rate,
546        "decay_steps": self.decay_steps,
547        "decay_rate": self.decay_rate,
548        "staircase": self.staircase,
549        "name": self.name
550    }
551
552
553@keras_export("keras.optimizers.schedules.CosineDecay",
554              "keras.experimental.CosineDecay")
555class CosineDecay(LearningRateSchedule):
556  """A LearningRateSchedule that uses a cosine decay schedule.
557
558  See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
559  SGDR: Stochastic Gradient Descent with Warm Restarts.
560
561  When training a model, it is often useful to lower the learning rate as
562  the training progresses. This schedule applies a cosine decay function
563  to an optimizer step, given a provided initial learning rate.
564  It requires a `step` value to compute the decayed learning rate. You can
565  just pass a TensorFlow variable that you increment at each training step.
566
567  The schedule is a 1-arg callable that produces a decayed learning
568  rate when passed the current optimizer step. This can be useful for changing
569  the learning rate value across different invocations of optimizer functions.
570  It is computed as:
571
572  ```python
573  def decayed_learning_rate(step):
574    step = min(step, decay_steps)
575    cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps))
576    decayed = (1 - alpha) * cosine_decay + alpha
577    return initial_learning_rate * decayed
578  ```
579
580  Example usage:
581  ```python
582  decay_steps = 1000
583  lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
584      initial_learning_rate, decay_steps)
585  ```
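
  As a quick sanity check of the endpoints (illustrative values; eager
  execution assumed), with the default `alpha=0.0` the rate decays from
  `initial_learning_rate` at step 0 down to 0 at `decay_steps`:

  ```python
  lr = tf.keras.optimizers.schedules.CosineDecay(0.1, decay_steps=1000)
  lr(0).numpy()     # ~0.1
  lr(500).numpy()   # ~0.05  (cos(pi / 2) == 0, so half the initial rate)
  lr(1000).numpy()  # ~0.0   (the floor is alpha * initial_learning_rate)
  ```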
586
587  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
588  as the learning rate. The learning rate schedule is also serializable and
589  deserializable using `tf.keras.optimizers.schedules.serialize` and
590  `tf.keras.optimizers.schedules.deserialize`.
591
592  Returns:
593    A 1-arg callable learning rate schedule that takes the current optimizer
594    step and outputs the decayed learning rate, a scalar `Tensor` of the same
595    type as `initial_learning_rate`.
596  """
597
598  def __init__(
599      self,
600      initial_learning_rate,
601      decay_steps,
602      alpha=0.0,
603      name=None):
604    """Applies cosine decay to the learning rate.
605
606    Args:
607      initial_learning_rate: A scalar `float32` or `float64` Tensor or a
608        Python number. The initial learning rate.
609      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
610        Number of steps to decay over.
611      alpha: A scalar `float32` or `float64` Tensor or a Python number.
612        Minimum learning rate value as a fraction of initial_learning_rate.
613      name: String. Optional name of the operation.  Defaults to 'CosineDecay'.
614    """
615    super(CosineDecay, self).__init__()
616
617    self.initial_learning_rate = initial_learning_rate
618    self.decay_steps = decay_steps
619    self.alpha = alpha
620    self.name = name
621
622  def __call__(self, step):
623    with ops.name_scope_v2(self.name or "CosineDecay"):
624      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
625          self.initial_learning_rate, name="initial_learning_rate")
626      dtype = initial_learning_rate.dtype
627      decay_steps = math_ops.cast(self.decay_steps, dtype)
628
629      global_step_recomp = math_ops.cast(step, dtype)
630      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
631      completed_fraction = global_step_recomp / decay_steps
632      cosine_decayed = 0.5 * (1.0 + math_ops.cos(
633          constant_op.constant(math.pi) * completed_fraction))
634
635      decayed = (1 - self.alpha) * cosine_decayed + self.alpha
636      return math_ops.multiply(initial_learning_rate, decayed)
637
638  def get_config(self):
639    return {
640        "initial_learning_rate": self.initial_learning_rate,
641        "decay_steps": self.decay_steps,
642        "alpha": self.alpha,
643        "name": self.name
644    }
645
646
647@keras_export("keras.optimizers.schedules.CosineDecayRestarts",
648              "keras.experimental.CosineDecayRestarts")
649class CosineDecayRestarts(LearningRateSchedule):
650  """A LearningRateSchedule that uses a cosine decay schedule with restarts.
651
652  See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
653  SGDR: Stochastic Gradient Descent with Warm Restarts.
654
655  When training a model, it is often useful to lower the learning rate as
656  the training progresses. This schedule applies a cosine decay function with
657  restarts to an optimizer step, given a provided initial learning rate.
658  It requires a `step` value to compute the decayed learning rate. You can
659  just pass a TensorFlow variable that you increment at each training step.
660
661  The schedule is a 1-arg callable that produces a decayed learning
662  rate when passed the current optimizer step. This can be useful for changing
663  the learning rate value across different invocations of optimizer functions.
664
665  The learning rate multiplier first decays
666  from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
667  restart is performed. Each new warm restart runs for `t_mul` times more
668  steps and with `m_mul` times smaller initial learning rate.
669
670  Example usage:
671  ```python
672  first_decay_steps = 1000
673  lr_decayed_fn = (
674    tf.keras.optimizers.schedules.CosineDecayRestarts(
675        initial_learning_rate,
676        first_decay_steps))
677  ```
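
  To make the restart pattern concrete (illustrative values; eager
  execution assumed): with `first_decay_steps=1000`, `t_mul=2.0` and
  `m_mul=0.5`, the first cosine cycle spans steps [0, 1000), the second
  spans [1000, 3000) and starts at half the initial rate, the third spans
  [3000, 7000) and starts at a quarter of it, and so on:

  ```python
  lr = tf.keras.optimizers.schedules.CosineDecayRestarts(
      0.1, first_decay_steps=1000, t_mul=2.0, m_mul=0.5)
  lr(0).numpy()     # ~0.1
  lr(1000).numpy()  # ~0.05   (start of the second cycle)
  lr(3000).numpy()  # ~0.025  (start of the third cycle)
  ```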
678
679  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
680  as the learning rate. The learning rate schedule is also serializable and
681  deserializable using `tf.keras.optimizers.schedules.serialize` and
682  `tf.keras.optimizers.schedules.deserialize`.
683
684  Returns:
685    A 1-arg callable learning rate schedule that takes the current optimizer
686    step and outputs the decayed learning rate, a scalar `Tensor` of the same
687    type as `initial_learning_rate`.
688  """
689
690  def __init__(
691      self,
692      initial_learning_rate,
693      first_decay_steps,
694      t_mul=2.0,
695      m_mul=1.0,
696      alpha=0.0,
697      name=None):
698    """Applies cosine decay with restarts to the learning rate.
699
700    Args:
701      initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
702        number. The initial learning rate.
703      first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python
704        number. Number of steps to decay over.
705      t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
706        Used to derive the number of iterations in the i-th period.
707      m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
708        Used to derive the initial learning rate of the i-th period.
709      alpha: A scalar `float32` or `float64` Tensor or a Python number.
710        Minimum learning rate value as a fraction of the initial_learning_rate.
711      name: String. Optional name of the operation.  Defaults to 'SGDRDecay'.
712    """
713    super(CosineDecayRestarts, self).__init__()
714
715    self.initial_learning_rate = initial_learning_rate
716    self.first_decay_steps = first_decay_steps
717    self._t_mul = t_mul
718    self._m_mul = m_mul
719    self.alpha = alpha
720    self.name = name
721
722  def __call__(self, step):
723    with ops.name_scope_v2(self.name or "SGDRDecay") as name:
724      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
725          self.initial_learning_rate, name="initial_learning_rate")
726      dtype = initial_learning_rate.dtype
727      first_decay_steps = math_ops.cast(self.first_decay_steps, dtype)
728      alpha = math_ops.cast(self.alpha, dtype)
729      t_mul = math_ops.cast(self._t_mul, dtype)
730      m_mul = math_ops.cast(self._m_mul, dtype)
731
732      global_step_recomp = math_ops.cast(step, dtype)
733      completed_fraction = global_step_recomp / first_decay_steps
734
735      def compute_step(completed_fraction, geometric=False):
736        """Helper for `cond` operation."""
737        if geometric:
738          i_restart = math_ops.floor(
739              math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
740              math_ops.log(t_mul))
741
742          sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
743          completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
744
745        else:
746          i_restart = math_ops.floor(completed_fraction)
747          completed_fraction -= i_restart
748
749        return i_restart, completed_fraction
750
751      i_restart, completed_fraction = control_flow_ops.cond(
752          math_ops.equal(t_mul, 1.0),
753          lambda: compute_step(completed_fraction, geometric=False),
754          lambda: compute_step(completed_fraction, geometric=True))
755
756      m_fac = m_mul**i_restart
757      cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
758          constant_op.constant(math.pi) * completed_fraction))
759      decayed = (1 - alpha) * cosine_decayed + alpha
760
761      return math_ops.multiply(initial_learning_rate, decayed, name=name)
762
763  def get_config(self):
764    return {
765        "initial_learning_rate": self.initial_learning_rate,
766        "first_decay_steps": self.first_decay_steps,
767        "t_mul": self._t_mul,
768        "m_mul": self._m_mul,
769        "alpha": self.alpha,
770        "name": self.name
771    }
772
773
774# Note: this code is still used by V1 APIs.
775class LinearCosineDecay(LearningRateSchedule):
776  """A LearningRateSchedule that uses a linear cosine decay schedule.
777
778  See [Bello et al., ICML2017](https://arxiv.org/abs/1709.07417),
779  Neural Optimizer Search with RL.
780
781  For the idea of warm starts here controlled by `num_periods`, see
782  [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
783  SGDR: Stochastic Gradient Descent with Warm Restarts.
784
785  Note that linear cosine decay is more aggressive than cosine decay and
786  larger initial learning rates can typically be used.
787
788  When training a model, it is often recommended to lower the learning rate as
789  the training progresses. This schedule applies a linear cosine decay
790  function to an optimizer step, given a provided initial learning rate.
791  It requires a `step` value to compute the decayed learning rate. You can
792  just pass a TensorFlow variable that you increment at each training step.
793
794  The schedule is a 1-arg callable that produces a decayed learning
795  rate when passed the current optimizer step. This can be useful for changing
796  the learning rate value across different invocations of optimizer functions.
797  It is computed as:
798
799  ```python
800  def decayed_learning_rate(step):
801    step = min(step, decay_steps)
802    linear_decay = (decay_steps - step) / decay_steps
803    cosine_decay = 0.5 * (
804        1 + cos(pi * 2 * num_periods * step / decay_steps))
805    decayed = (alpha + linear_decay) * cosine_decay + beta
806    return initial_learning_rate * decayed
807  ```
808
809  Example usage:
810  ```python
811  decay_steps = 1000
812  lr_decayed_fn = (
813    tf.keras.experimental.LinearCosineDecay(
814      initial_learning_rate, decay_steps))
815  ```
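
  A rough endpoint check with the defaults (`num_periods=0.5`, `alpha=0.0`,
  `beta=0.001`; illustrative values, eager execution assumed, and using the
  same `tf.keras.experimental` path as the example above):

  ```python
  lr = tf.keras.experimental.LinearCosineDecay(0.1, decay_steps=1000)
  lr(0).numpy()     # ~0.1001  -> 0.1 * ((0.0 + 1.0) * 1.0 + 0.001)
  lr(1000).numpy()  # ~0.0001  -> 0.1 * ((0.0 + 0.0) * 0.0 + 0.001)
  ```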
816
817  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
818  as the learning rate. The learning rate schedule is also serializable and
819  deserializable using `tf.keras.optimizers.schedules.serialize` and
820  `tf.keras.optimizers.schedules.deserialize`.
821
822  Returns:
823    A 1-arg callable learning rate schedule that takes the current optimizer
824    step and outputs the decayed learning rate, a scalar `Tensor` of the same
825    type as `initial_learning_rate`.
826  """
827
828  def __init__(
829      self,
830      initial_learning_rate,
831      decay_steps,
832      num_periods=0.5,
833      alpha=0.0,
834      beta=0.001,
835      name=None):
836    """Applies linear cosine decay to the learning rate.
837
838    Args:
839      initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
840        number. The initial learning rate.
841      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
842        Number of steps to decay over.
843      num_periods: Number of periods in the cosine part of the decay.
844        See computation above.
845      alpha: See computation above.
846      beta: See computation above.
847      name: String.  Optional name of the operation.  Defaults to
848        'LinearCosineDecay'.
849    """
850    super(LinearCosineDecay, self).__init__()
851
852    self.initial_learning_rate = initial_learning_rate
853    self.decay_steps = decay_steps
854    self.num_periods = num_periods
855    self.alpha = alpha
856    self.beta = beta
857    self.name = name
858
859  def __call__(self, step):
860    with ops.name_scope_v2(self.name or "LinearCosineDecay") as name:
861      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
862          self.initial_learning_rate, name="initial_learning_rate")
863      dtype = initial_learning_rate.dtype
864      decay_steps = math_ops.cast(self.decay_steps, dtype)
865      num_periods = math_ops.cast(self.num_periods, dtype)
866      alpha = math_ops.cast(self.alpha, dtype)
867      beta = math_ops.cast(self.beta, dtype)
868
869      global_step_recomp = math_ops.cast(step, dtype)
870      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
871      linear_decayed = (decay_steps - global_step_recomp) / decay_steps
872      completed_fraction = global_step_recomp / decay_steps
873      fraction = 2.0 * num_periods * completed_fraction
874      cosine_decayed = 0.5 * (
875          1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
876
877      linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
878      return math_ops.multiply(initial_learning_rate, linear_cosine_decayed,
879                               name=name)
880
881  def get_config(self):
882    return {
883        "initial_learning_rate": self.initial_learning_rate,
884        "decay_steps": self.decay_steps,
885        "num_periods": self.num_periods,
886        "alpha": self.alpha,
887        "beta": self.beta,
888        "name": self.name
889    }
890
891
892# Note: this code is still used by V1 APIs.
893class NoisyLinearCosineDecay(LearningRateSchedule):
894  """A LearningRateSchedule that uses a noisy linear cosine decay schedule.
895
896  See [Bello et al., ICML2017](https://arxiv.org/abs/1709.07417),
897  Neural Optimizer Search with RL.
898
899  For the idea of warm starts here controlled by `num_periods`, see
900  [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
901  SGDR: Stochastic Gradient Descent with Warm Restarts.
902
903  Note that linear cosine decay is more aggressive than cosine decay and
904  larger initial learning rates can typically be used.
905
906  When training a model, it is often recommended to lower the learning rate as
907  the training progresses. This schedule applies a noisy linear cosine decay
908  function to an optimizer step, given a provided initial learning rate.
909  It requires a `step` value to compute the decayed learning rate. You can
910  just pass a TensorFlow variable that you increment at each training step.
911
912  The schedule is a 1-arg callable that produces a decayed learning
913  rate when passed the current optimizer step. This can be useful for changing
914  the learning rate value across different invocations of optimizer functions.
915  It is computed as:
916
917  ```python
918  def decayed_learning_rate(step):
919    step = min(step, decay_steps)
920    linear_decay = (decay_steps - step) / decay_steps
921    cosine_decay = 0.5 * (
922        1 + cos(pi * 2 * num_periods * step / decay_steps))
923    decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
924    return initial_learning_rate * decayed
925  ```
926  where `eps_t` is 0-centered Gaussian noise with variance
927  `initial_variance / (1 + global_step) ** variance_decay`.
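
  For a rough sense of the noise magnitude, the standard deviation of
  `eps_t` implied by the formula above can be computed with a small helper
  (a hypothetical function, shown here only for illustration):

  ```python
  import math

  def noise_std(step, initial_variance=1.0, variance_decay=0.55):
    # Standard deviation of eps_t at a given step (default arguments match
    # this schedule's defaults).
    return math.sqrt(initial_variance / (1.0 + step) ** variance_decay)

  noise_std(0)    # 1.0
  noise_std(999)  # ~0.15
  ```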
928
929  Example usage:
930  ```python
931  decay_steps = 1000
932  lr_decayed_fn = (
933    tf.keras.experimental.NoisyLinearCosineDecay(
934      initial_learning_rate, decay_steps))
935  ```
936
937  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
938  as the learning rate. The learning rate schedule is also serializable and
939  deserializable using `tf.keras.optimizers.schedules.serialize` and
940  `tf.keras.optimizers.schedules.deserialize`.
941
942  Returns:
943    A 1-arg callable learning rate schedule that takes the current optimizer
944    step and outputs the decayed learning rate, a scalar `Tensor` of the same
945    type as `initial_learning_rate`.
946  """
947
948  def __init__(
949      self,
950      initial_learning_rate,
951      decay_steps,
952      initial_variance=1.0,
953      variance_decay=0.55,
954      num_periods=0.5,
955      alpha=0.0,
956      beta=0.001,
957      name=None):
958    """Applies noisy linear cosine decay to the learning rate.
959
960    Args:
961      initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
962        number. The initial learning rate.
963      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
964        Number of steps to decay over.
965      initial_variance: initial variance for the noise. See computation above.
966      variance_decay: decay for the noise's variance. See computation above.
967      num_periods: Number of periods in the cosine part of the decay.
968        See computation above.
969      alpha: See computation above.
970      beta: See computation above.
971      name: String.  Optional name of the operation.  Defaults to
972        'NoisyLinearCosineDecay'.
973    """
974    super(NoisyLinearCosineDecay, self).__init__()
975
976    self.initial_learning_rate = initial_learning_rate
977    self.decay_steps = decay_steps
978    self.initial_variance = initial_variance
979    self.variance_decay = variance_decay
980    self.num_periods = num_periods
981    self.alpha = alpha
982    self.beta = beta
983    self.name = name
984
985  def __call__(self, step):
986    with ops.name_scope_v2(self.name or "NoisyLinearCosineDecay") as name:
987      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
988          self.initial_learning_rate, name="initial_learning_rate")
989      dtype = initial_learning_rate.dtype
990      decay_steps = math_ops.cast(self.decay_steps, dtype)
991      initial_variance = math_ops.cast(self.initial_variance, dtype)
992      variance_decay = math_ops.cast(self.variance_decay, dtype)
993      num_periods = math_ops.cast(self.num_periods, dtype)
994      alpha = math_ops.cast(self.alpha, dtype)
995      beta = math_ops.cast(self.beta, dtype)
996
997      global_step_recomp = math_ops.cast(step, dtype)
998      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
999      linear_decayed = (decay_steps - global_step_recomp) / decay_steps
1000      variance = initial_variance / (
1001          math_ops.pow(1.0 + global_step_recomp, variance_decay))
1002      std = math_ops.sqrt(variance)
1003      noisy_linear_decayed = (
1004          linear_decayed + random_ops.random_normal(
1005              linear_decayed.shape, stddev=std))
1006
1007      completed_fraction = global_step_recomp / decay_steps
1008      fraction = 2.0 * num_periods * completed_fraction
1009      cosine_decayed = 0.5 * (
1010          1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
1011      noisy_linear_cosine_decayed = (
1012          (alpha + noisy_linear_decayed) * cosine_decayed + beta)
1013
1014      return math_ops.multiply(
1015          initial_learning_rate, noisy_linear_cosine_decayed, name=name)
1016
1017  def get_config(self):
1018    return {
1019        "initial_learning_rate": self.initial_learning_rate,
1020        "decay_steps": self.decay_steps,
1021        "initial_variance": self.initial_variance,
1022        "variance_decay": self.variance_decay,
1023        "num_periods": self.num_periods,
1024        "alpha": self.alpha,
1025        "beta": self.beta,
1026        "name": self.name
1027    }
1028
1029
1030@keras_export("keras.optimizers.schedules.serialize")
1031def serialize(learning_rate_schedule):
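  """Serializes a `LearningRateSchedule` into a JSON-compatible dict.

  Args:
    learning_rate_schedule: The `LearningRateSchedule` object to serialize.

  Returns:
    A JSON-serializable dict representing the schedule's class name and
    configuration.
  """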
1032  return generic_utils.serialize_keras_object(learning_rate_schedule)
1033
1034
1035@keras_export("keras.optimizers.schedules.deserialize")
1036def deserialize(config, custom_objects=None):
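  """Instantiates a `LearningRateSchedule` object from a serialized form.

  Args:
    config: The serialized form of the schedule, e.g. the output of
      `serialize`.
    custom_objects: A dictionary mapping names (strings) to custom
      (non-built-in) `LearningRateSchedule` classes to consider during
      deserialization.

  Returns:
    A `LearningRateSchedule` instance.
  """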
1037  return generic_utils.deserialize_keras_object(
1038      config,
1039      module_objects=globals(),
1040      custom_objects=custom_objects,
1041      printable_module_name="decay")
1042