# mypy: allow-untyped-defs
from typing import Optional, Union

import torch
from torch import Tensor

from .optimizer import Optimizer, ParamsT


__all__ = ["LBFGS"]


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.0


def _strong_wolfe(
    obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)  # type: ignore[possibly-undefined]
    while not done and ls_iter < max_ls:
        # line-search bracket is too small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:  # type: ignore[possibly-undefined]
            break

        # compute new trial value
        t = _cubic_interpolate(
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
        )

        # test that we are making sufficient progress:
        # if `t` is too close to a boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundaries,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    t = bracket[low_pos]  # type: ignore[possibly-undefined]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]
    return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
    """Implements L-BFGS algorithm.

    Heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory-intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        params (iterable): iterable of parameters to optimize. Parameters must be real.
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
    """

    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1,
        max_iter: int = 20,
        max_eval: Optional[int] = None,
        tolerance_grad: float = 1e-7,
        tolerance_change: float = 1e-9,
        history_size: int = 100,
        line_search_fn: Optional[str] = None,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn,
        )
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError(
                "LBFGS doesn't support per-parameter options (parameter groups)"
            )

        self._params = self.param_groups[0]["params"]
        self._numel_cache = None

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = sum(
                2 * p.numel() if torch.is_complex(p) else p.numel()
                for p in self._params
            )

        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            if torch.is_complex(view):
                view = torch.view_as_real(view).view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            if torch.is_complex(p):
                p = torch.view_as_real(p)
            numel = p.numel()
            # use view_as to avoid deprecated pointwise semantics
            p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Perform a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group["lr"]
        max_iter = group["max_iter"]
        max_eval = group["max_eval"]
        tolerance_grad = group["tolerance_grad"]
        tolerance_change = group["tolerance_change"]
        line_search_fn = group["line_search_fn"]
        history_size = group["history_size"]

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault("func_evals", 0)
        state.setdefault("n_iter", 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state["func_evals"] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get("d")
        t = state.get("t")
        old_dirs = state.get("old_dirs")
        old_stps = state.get("old_stps")
        ro = state.get("ro")
        H_diag = state.get("H_diag")
        prev_flat_grad = state.get("prev_flat_grad")
        prev_loss = state.get("prev_loss")

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state["n_iter"] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state["n_iter"] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1.0 / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if "al" not in state:
                    state["al"] = [None] * history_size
                al = state["al"]

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state["n_iter"] == 1:
                t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd
                    )
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # there is no use in re-evaluating that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state["func_evals"] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state["d"] = d
        state["t"] = t
        state["old_dirs"] = old_dirs
        state["old_stps"] = old_stps
        state["ro"] = ro
        state["H_diag"] = H_diag
        state["prev_flat_grad"] = prev_flat_grad
        state["prev_loss"] = prev_loss

        return orig_loss
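

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the optimizer implementation
# above): a minimal example of the closure-based calling convention that
# LBFGS.step() expects, applied to a toy quadratic objective. The names `x`,
# `target`, and `closure` are hypothetical placeholders chosen for this demo.
if __name__ == "__main__":
    x = torch.full((2,), 5.0, requires_grad=True)
    target = torch.tensor([1.0, -2.0])
    optimizer = LBFGS([x], lr=1, max_iter=20, line_search_fn="strong_wolfe")

    def closure():
        # LBFGS may evaluate the closure several times per step, so it must
        # clear stale gradients and recompute the loss on every call.
        optimizer.zero_grad()
        loss = ((x - target) ** 2).sum()
        loss.backward()
        return loss

    for _ in range(5):
        optimizer.step(closure)
    # after a few steps, x should be close to `target`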