xref: /aosp_15_r20/external/pytorch/torch/optim/lbfgs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from typing import Optional, Union

import torch
from torch import Tensor

from .optimizer import Optimizer, ParamsT


__all__ = ["LBFGS"]


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
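    # Illustrative example: for f(x) = (x - 1)^2 sampled at x1=0 (f1=1, g1=-2)
    # and x2=2 (f2=1, g2=2), these formulas give d1 = 0, d2 = 2 and min_pos = 1.0,
    # i.e. the exact minimizer of the parabola.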
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.0


def _strong_wolfe(
    obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
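    # The conditions enforced below are the standard strong Wolfe conditions:
    # sufficient decrease, f(x + t*d) <= f(x) + c1*t*gtd, and the curvature
    # condition, |grad f(x + t*d) . d| <= -c2*gtd (with gtd = g.dot(d) < 0 for
    # a descent direction).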
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)  # type: ignore[possibly-undefined]
    while not done and ls_iter < max_ls:
        # exit if the line-search bracket has become too small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:  # type: ignore[possibly-undefined]
            break

        # compute new trial value
        t = _cubic_interpolate(
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
        )

        # test that we are making sufficient progress:
        # if `t` is too close to a boundary, we mark that we are making
        # insufficient progress, and if
        #   + we made insufficient progress in the last step, or
        #   + `t` is at one of the boundaries,
        # we move `t` to a position `0.1 * len(bracket)` away from the
        # nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[low_pos] = gtd_new

    # return the best point found
    t = bracket[low_pos]  # type: ignore[possibly-undefined]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]
    return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
184    """Implements L-BFGS algorithm.

    Heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory-intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory,
        try reducing the history size, or use a different algorithm.

    Args:
        params (iterable): iterable of parameters to optimize. Parameters must be real.
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
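
    Example (illustrative sketch; ``model``, ``loss_fn``, ``input`` and
    ``target`` are assumed to be defined by the caller):
        >>> optimizer = torch.optim.LBFGS(model.parameters(), lr=1,
        ...                               line_search_fn="strong_wolfe")
        >>> def closure():
        ...     optimizer.zero_grad()
        ...     loss = loss_fn(model(input), target)
        ...     loss.backward()
        ...     return loss
        >>> optimizer.step(closure)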
215    """
216
217    def __init__(
218        self,
219        params: ParamsT,
220        lr: Union[float, Tensor] = 1,
221        max_iter: int = 20,
222        max_eval: Optional[int] = None,
223        tolerance_grad: float = 1e-7,
224        tolerance_change: float = 1e-9,
225        history_size: int = 100,
226        line_search_fn: Optional[str] = None,
227    ):
228        if isinstance(lr, Tensor) and lr.numel() != 1:
229            raise ValueError("Tensor lr must be 1-element")
230        if not 0.0 <= lr:
231            raise ValueError(f"Invalid learning rate: {lr}")
232        if max_eval is None:
233            max_eval = max_iter * 5 // 4
234        defaults = dict(
235            lr=lr,
236            max_iter=max_iter,
237            max_eval=max_eval,
238            tolerance_grad=tolerance_grad,
239            tolerance_change=tolerance_change,
240            history_size=history_size,
241            line_search_fn=line_search_fn,
242        )
243        super().__init__(params, defaults)
244
245        if len(self.param_groups) != 1:
246            raise ValueError(
247                "LBFGS doesn't support per-parameter options " "(parameter groups)"
248            )
249
250        self._params = self.param_groups[0]["params"]
251        self._numel_cache = None
252
253    def _numel(self):
254        if self._numel_cache is None:
255            self._numel_cache = sum(
256                2 * p.numel() if torch.is_complex(p) else p.numel()
257                for p in self._params
258            )
259
260        return self._numel_cache
261
262    def _gather_flat_grad(self):
263        views = []
264        for p in self._params:
265            if p.grad is None:
266                view = p.new(p.numel()).zero_()
267            elif p.grad.is_sparse:
268                view = p.grad.to_dense().view(-1)
269            else:
270                view = p.grad.view(-1)
271            if torch.is_complex(view):
272                view = torch.view_as_real(view).view(-1)
273            views.append(view)
274        return torch.cat(views, 0)
275
276    def _add_grad(self, step_size, update):
277        offset = 0
278        for p in self._params:
279            if torch.is_complex(p):
280                p = torch.view_as_real(p)
281            numel = p.numel()
282            # view as to avoid deprecated pointwise semantics
283            p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
284            offset += numel
285        assert offset == self._numel()
286
287    def _clone_param(self):
288        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]
289
290    def _set_param(self, params_data):
291        for p, pdata in zip(self._params, params_data):
292            p.copy_(pdata)
293
294    def _directional_evaluate(self, closure, x, t, d):
295        self._add_grad(t, d)
296        loss = float(closure())
297        flat_grad = self._gather_flat_grad()
298        self._set_param(x)
299        return loss, flat_grad
300
301    @torch.no_grad()
302    def step(self, closure):
303        """Perform a single optimization step.
304
305        Args:
306            closure (Callable): A closure that reevaluates the model
307                and returns the loss.
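
        .. note::
            The closure should clear the gradients (for example via
            ``optimizer.zero_grad()``), compute the loss, call ``backward()``
            on it and return it; see the illustrative example in the class
            docstring.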
308        """
309        assert len(self.param_groups) == 1
310
311        # Make sure the closure is always called with grad enabled
312        closure = torch.enable_grad()(closure)
313
314        group = self.param_groups[0]
315        lr = group["lr"]
316        max_iter = group["max_iter"]
317        max_eval = group["max_eval"]
318        tolerance_grad = group["tolerance_grad"]
319        tolerance_change = group["tolerance_change"]
320        line_search_fn = group["line_search_fn"]
321        history_size = group["history_size"]
322
323        # NOTE: LBFGS has only global state, but we register it as state for
324        # the first param, because this helps with casting in load_state_dict
325        state = self.state[self._params[0]]
326        state.setdefault("func_evals", 0)
327        state.setdefault("n_iter", 0)
328
329        # evaluate initial f(x) and df/dx
330        orig_loss = closure()
331        loss = float(orig_loss)
332        current_evals = 1
333        state["func_evals"] += 1
334
335        flat_grad = self._gather_flat_grad()
336        opt_cond = flat_grad.abs().max() <= tolerance_grad
337
338        # optimal condition
339        if opt_cond:
340            return orig_loss
341
342        # tensors cached in state (for tracing)
343        d = state.get("d")
344        t = state.get("t")
345        old_dirs = state.get("old_dirs")
346        old_stps = state.get("old_stps")
347        ro = state.get("ro")
348        H_diag = state.get("H_diag")
349        prev_flat_grad = state.get("prev_flat_grad")
350        prev_loss = state.get("prev_loss")
351
352        n_iter = 0
353        # optimize for a max of max_iter iterations
354        while n_iter < max_iter:
            # keep track of the number of iterations
            n_iter += 1
            state["n_iter"] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state["n_iter"] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
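                # only update the history when the curvature condition
                # y^T s > 0 holds (with a small margin), which keeps the
                # implicit inverse Hessian approximation positive definite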
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1.0 / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if "al" not in state:
                    state["al"] = [None] * history_size
                al = state["al"]

                # iteration in L-BFGS loop collapsed to use just one buffer
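                # (classic two-loop recursion: the backward loop computes
                # al[i] = ro[i] * s_i^T q and subtracts al[i] * y_i from q;
                # after scaling by H_diag, the forward loop adds
                # (al[i] - be_i) * s_i with be_i = ro[i] * y_i^T r)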
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
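            # (on the first iteration the step is scaled by min(1, 1/||g||_1)
            # so that the initial move stays bounded regardless of the
            # gradient magnitude)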
            if state["n_iter"] == 1:
                t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd
                    )
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with a fixed step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate the function only if we are not in the last
                    # iteration; in a stochastic setting there is little point
                    # in re-evaluating it here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state["func_evals"] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state["d"] = d
        state["t"] = t
        state["old_dirs"] = old_dirs
        state["old_stps"] = old_stps
        state["ro"] = ro
        state["H_diag"] = H_diag
        state["prev_flat_grad"] = prev_flat_grad
        state["prev_loss"] = prev_loss

        return orig_loss