# mypy: allow-untyped-defs
from typing import List, Tuple

import torch
from torch._vmap_internals import _vmap

from . import forward_ad as fwAD


__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]

# Utility functions


def _as_tuple_nocheck(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def _as_tuple(inp, arg_name=None, fn_name=None):
    # Ensures that inp is a tuple of Tensors
    # Returns whether or not the original inp was a tuple and the tupled version of the input
    if arg_name is None and fn_name is None:
        return _as_tuple_nocheck(inp)

    is_inp_tuple = True
    if not isinstance(inp, tuple):
        inp = (inp,)
        is_inp_tuple = False

    for i, el in enumerate(inp):
        if not isinstance(el, torch.Tensor):
            if is_inp_tuple:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" value at index {i} has type {type(el)}."
                )
            else:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" given {arg_name} has type {type(el)}."
                )

    return is_inp_tuple, inp


def _tuple_postprocess(res, to_unpack):
    # Unpacks a potentially nested tuple of Tensors
    # to_unpack should be a single boolean or a tuple of two booleans.
    # It is used to:
    # - invert _as_tuple when res should match the inp given to _as_tuple
    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
    if isinstance(to_unpack, tuple):
        assert len(to_unpack) == 2
        if not to_unpack[1]:
            res = tuple(el[0] for el in res)
        if not to_unpack[0]:
            res = res[0]
    else:
        if not to_unpack:
            res = res[0]
    return res


def _grad_preprocess(inputs, create_graph, need_graph):
    # Preprocess the inputs to make sure they require gradient
    # inputs is a tuple of Tensors to preprocess
    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
    # need_graph specifies if we internally want gradients to flow back to the Tensors in res
    # Note that we *always* create a new Tensor object to be able to see the difference between
    # inputs given as arguments and the same Tensors automatically captured by the user function.
    # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
    res = []
    for inp in inputs:
        if create_graph and inp.requires_grad:
            # Create at least a new Tensor object in a differentiable way
            if not inp.is_sparse:
                # Use .view_as() to get a shallow copy
                res.append(inp.view_as(inp))
            else:
                # We cannot use view for sparse Tensors so we clone
                res.append(inp.clone())
        else:
            res.append(inp.detach().requires_grad_(need_graph))
    return tuple(res)
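

# The helper below is an illustrative sketch added for clarity and is not part
# of this module's API; the name ``_example_as_tuple_roundtrip`` is hypothetical.
# Assuming the private helpers above keep their current behavior, it shows how
# ``_as_tuple`` and ``_tuple_postprocess`` round-trip a non-tuple input.
def _example_as_tuple_roundtrip():
    single = torch.rand(2)
    # A bare Tensor is wrapped into a one-element tuple; the flag records that
    # the original input was not a tuple.
    is_tuple, tupled = _as_tuple(single, "inputs", "example")
    assert not is_tuple and isinstance(tupled, tuple)
    # Results computed on the tupled form are unpacked back so that they match
    # the structure of the original input.
    res = tuple(t * 2 for t in tupled)
    unpacked = _tuple_postprocess(res, is_tuple)
    assert isinstance(unpacked, torch.Tensor)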


def _grad_postprocess(inputs, create_graph):
    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
    # request it.
    if isinstance(inputs[0], torch.Tensor):
        if not create_graph:
            return tuple(inp.detach() for inp in inputs)
        else:
            return inputs
    else:
        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)


def _validate_v(v, other, is_other_tuple):
    # This assumes that other is the correct shape, and v should match
    # Both are assumed to be tuples of Tensors
    if len(other) != len(v):
        if is_other_tuple:
            raise RuntimeError(
                f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
            )
        else:
            raise RuntimeError("The given v should contain a single Tensor.")

    for idx, (el_v, el_other) in enumerate(zip(v, other)):
        if el_v.size() != el_other.size():
            prepend = ""
            if is_other_tuple:
                prepend = f"Entry {idx} in "
            raise RuntimeError(
                f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}."
            )


def _check_requires_grad(inputs, input_type, strict):
    # Used to make all the necessary checks to raise nice errors in strict mode.
    if not strict:
        return

    if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
        raise RuntimeError("Invalid input_type to _check_requires_grad")
    for i, inp in enumerate(inputs):
        if inp is None:
            # This can only be reached for grad_inputs.
            raise RuntimeError(
                f"The output of the user-provided function is independent of input {i}."
                " This is not allowed in strict mode."
            )
        if not inp.requires_grad:
            if input_type == "hessian":
                raise RuntimeError(
                    f"The hessian of the user-provided function with respect to input {i}"
                    " is independent of the input. This is not allowed in strict mode."
                    " You should ensure that your function is thrice differentiable and that"
                    " the hessian depends on the inputs."
                )
            elif input_type == "jacobian":
                raise RuntimeError(
                    "While computing the hessian, found that the jacobian of the user-provided"
                    f" function with respect to input {i} is independent of the input. This is not"
                    " allowed in strict mode. You should ensure that your function is twice"
                    " differentiable and that the jacobian depends on the inputs (this would be"
                    " violated by a linear function for example)."
                )
            elif input_type == "grad_inputs":
                raise RuntimeError(
                    f"The gradient with respect to input {i} is independent of the inputs of the"
                    " user-provided function. This is not allowed in strict mode."
                )
            else:
                raise RuntimeError(
                    f"Output {i} of the user-provided function does not require gradients."
                    " The outputs must be computed in a differentiable manner from the input"
                    " when running in strict mode."
                )


def _autograd_grad(
    outputs,
    inputs,
    grad_outputs=None,
    create_graph=False,
    retain_graph=None,
    is_grads_batched=False,
):
    # Version of autograd.grad that accepts `None` in outputs and does not compute gradients for them.
    # This has the extra constraint that inputs has to be a tuple
    assert isinstance(outputs, tuple)
    if grad_outputs is None:
        grad_outputs = (None,) * len(outputs)
    assert isinstance(grad_outputs, tuple)
    assert len(outputs) == len(grad_outputs)

    new_outputs: Tuple[torch.Tensor, ...] = ()
    new_grad_outputs: Tuple[torch.Tensor, ...] = ()
    for out, grad_out in zip(outputs, grad_outputs):
        if out is not None and out.requires_grad:
            new_outputs += (out,)
            new_grad_outputs += (grad_out,)

    if len(new_outputs) == 0:
        # No differentiable output, we don't need to call the autograd engine
        return (None,) * len(inputs)
    else:
        return torch.autograd.grad(
            new_outputs,
            inputs,
            new_grad_outputs,
            allow_unused=True,
            create_graph=create_graph,
            retain_graph=retain_graph,
            is_grads_batched=is_grads_batched,
        )


def _fill_in_zeros(grads, refs, strict, create_graph, stage):
    # Used to detect None in the grads and, depending on the flags, either replace them
    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
    # strict and create_graph allow us to detect when it is appropriate to raise an error
    # stage gives us information about which backward call we consider to give a good error message
    if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
        raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")

    res: Tuple[torch.Tensor, ...] = ()
    for i, grads_i in enumerate(grads):
        if grads_i is None:
            if strict:
                if stage == "back":
                    raise RuntimeError(
                        "The output of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                elif stage == "back_trick":
                    raise RuntimeError(
                        f"The gradient with respect to the input is independent of entry {i}"
                        " in the grad_outputs when using the double backward trick to compute"
                        " forward mode gradients. This is not allowed in strict mode."
                    )
                elif stage == "double_back":
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"entry {i} in the grad_jacobian. This is not allowed in strict "
                        "mode as it prevents us from using the double backward trick to "
                        "replace forward mode AD."
                    )

            grads_i = torch.zeros_like(refs[i])
        else:
            if strict and create_graph and not grads_i.requires_grad:
                if "double" not in stage:
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )

        res += (grads_i,)

    return res


# Public API


def vjp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector
            Jacobian product is computed. Must be the same size as the output
            of ``func``. This argument is optional when the output of ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vjp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vjp (tuple of Tensors or Tensor): result of the dot product with
                the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vjp(exp_reducer, inputs, v)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                 [2.1288, 1.0652, 1.5483, 2.5035],
                 [2.2046, 1.1292, 1.1432, 1.3059],
                 [1.3225, 1.6652, 1.7753, 2.0152]]))

        >>> vjp(exp_reducer, inputs, v, create_graph=True)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                 [2.1288, 1.0652, 1.5483, 2.5035],
                 [2.2046, 1.1292, 1.1432, 1.3059],
                 [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = torch.ones(2)
        >>> vjp(adder, inputs, v)
        (tensor([2.4225, 2.3340]),
         (tensor([2., 2.]), tensor([3., 3.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vjp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if v is not None:
            _, v = _as_tuple(v, "v", "vjp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, outputs, is_outputs_tuple)
        else:
            if len(outputs) != 1 or outputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the "
                    "user-provided function returns "
                    "a single Tensor with a single element."
                )

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    vjp = _grad_postprocess(vjp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vjp, is_inputs_tuple
    )
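

# The helper below is an illustrative sketch, not part of the public API; the
# name ``_example_vjp_vs_autograd_grad`` is hypothetical. It shows that, for a
# single-input function, ``vjp`` computes the same quantity as a plain
# ``torch.autograd.grad`` call with ``grad_outputs=v``.
def _example_vjp_vs_autograd_grad():
    def f(x):
        return x.exp().sum(dim=1)

    x = torch.rand(4, 4)
    v = torch.ones(4)
    _, vjp_result = vjp(f, x, v)

    # Reference computation with the raw autograd API.
    x_ref = x.clone().requires_grad_(True)
    (ref,) = torch.autograd.grad(f(x_ref), x_ref, grad_outputs=v)
    assert torch.allclose(vjp_result, ref)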


def jvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Jacobian
            vector product is computed. Must be the same size as the input of
            ``func``. This argument is optional when the input to ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            jvp (tuple of Tensors or Tensor): result of the dot product with
                the same shape as the output.

    Note:
        ``autograd.functional.jvp`` computes the jvp by using the backward of
        the backward (sometimes called the double backwards trick). This is not
        the most performant way of computing the jvp. Please consider using
        :func:`torch.func.jvp` or the
        :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4, 4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jvp(exp_reducer, inputs, v)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
         tensor([6.3090, 4.6742, 7.9114, 8.2106]))

        >>> jvp(exp_reducer, inputs, v, create_graph=True)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.ones(2), torch.ones(2))
        >>> jvp(adder, inputs, v)
        (tensor([2.2399, 2.5005]),
         tensor([5., 5.]))

    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to "
                    "the user-provided function is a single Tensor "
                    "with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)
        # The backward is linear so the value of grad_outputs is not important as
        # it won't appear in the double backward graph. We only need to ensure that
        # it does not contain inf or nan.
        grad_outputs = tuple(
            torch.zeros_like(out, requires_grad=True) for out in outputs
        )

        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    if create_graph:
        with torch.enable_grad():
            grad_res = _autograd_grad(
                grad_inputs, grad_outputs, v, create_graph=create_graph
            )
            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
    else:
        grad_res = _autograd_grad(
            grad_inputs, grad_outputs, v, create_graph=create_graph
        )
        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    jvp = _grad_postprocess(jvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        jvp, is_outputs_tuple
    )
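

# The helper below is an illustrative sketch, not part of the public API; the
# name ``_example_jvp_vs_jacobian`` is hypothetical. It checks, on a small
# example, that the double-backward-trick jvp above agrees with contracting the
# full Jacobian against ``v``.
def _example_jvp_vs_jacobian():
    def f(x):
        return x.exp().sum(dim=1)

    x = torch.rand(4, 4)
    v = torch.ones(4, 4)
    _, jvp_result = jvp(f, x, v)

    # jac has shape (4, 4, 4): the output dimension first, then the input dims.
    jac = jacobian(f, x)
    expected = (jac * v).sum(dim=(1, 2))
    assert torch.allclose(jvp_result, expected)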


def _construct_standard_basis_for(
    tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
) -> Tuple[torch.Tensor, ...]:
    # This function:
    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],     tensor([[0],
    #           [0],             [1, 0],             [0],
    #           [0],             [0, 1],             [0],
    #           [0]])  ,         [0, 0]])  ,         [1]])  )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
    # for context behind this function. All the pre-conditions are guarded for
    # in torch.autograd.functional.jacobian.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    total_numel = sum(tensor_numels)
    chunks = tuple(
        tensor.new_zeros(total_numel, tensor_numel)
        for tensor, tensor_numel in zip(tensors, tensor_numels)
    )
    diag_start_idx = 0
    for chunk, numel in zip(chunks, tensor_numels):
        chunk.diagonal(diag_start_idx).fill_(1)
        diag_start_idx -= numel
    return chunks


def _jacfwd(func, inputs, strict=False, vectorize=False):
    if strict:
        raise RuntimeError(
            "torch.autograd.functional.jacobian: `strict=True` "
            'and `strategy="forward-mode"` are not supported together (yet). '
            "Please either set `strict=False` or "
            '`strategy="reverse-mode"`.'
        )
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
    output_info = []

    if vectorize:
        # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
        input_numels = tuple(input.numel() for input in inputs)

        # Step 1: Prepare tangents
        tangents = _construct_standard_basis_for(inputs, input_numels)

        # Step 2: Compute vmap over computation with dual tensors
        def jvp(tangents):
            with fwAD.dual_level():
                dual_inputs = tuple(
                    fwAD.make_dual(input, tangent.view_as(input))
                    for input, tangent in zip(inputs, tangents)
                )
                _is_outputs_tuple, dual_outputs = _as_tuple(
                    func(*dual_inputs), "outputs"
                )
                output_info.append(_is_outputs_tuple)
                jv = []
                primal_outs = []
                for dual_out in dual_outputs:
                    primal, tangent = fwAD.unpack_dual(dual_out)
                    primal_outs.append(primal)
                    if tangent is not None:
                        jv.append(tangent)
                    else:
                        jv.append(torch.zeros_like(primal))
                output_info.append(primal_outs)
                return tuple(jv)

        outputs_before_split = _vmap(jvp)(tangents)
        is_outputs_tuple, outputs = output_info
        # Step 3: for each of the output tangents, split along dim 0
        jacobian_input_output = []
        for jac_output_i, output_i in zip(outputs_before_split, outputs):
            jacobian_output_i_output = []
            for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs):
                # We need to transpose the Jacobian because in forward AD, the
                # batch dimension represents that of the inputs
                jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape(
                    (*output_i.shape, *input_j.shape)
                )  # noqa: C409

                jacobian_output_i_output.append(jacobian_input_i_output_j)
            jacobian_input_output.append(jacobian_output_i_output)

        # Omit [Step 4] because everything is already transposed w/ forward AD
        return _tuple_postprocess(
            jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)
        )
    else:
        raise NotImplementedError(
            "Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
            "only implemented for `vectorize=True`."
        )


def jacobian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    strategy="reverse-mode",
):
    r"""Compute the Jacobian of a given function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Jacobian will be
            computed in a differentiable manner. Note that when ``strict`` is
            ``False``, the result can not require gradients or be disconnected
            from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jacobian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.jacrev` or
            :func:`torch.func.jacfwd` instead if you are looking for something
            less experimental and more performant.
            When computing the jacobian, usually we invoke
            ``autograd.grad`` once per row of the jacobian. If this flag is
            ``True``, we perform only a single ``autograd.grad`` call with
            ``batched_grad=True`` which uses the vmap prototype feature.
            Though this should lead to performance improvements in many cases,
            because this feature is still experimental, there may be performance
            cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
            more information.
        strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
            determine whether the Jacobian will be computed with forward or reverse
            mode AD. Currently, ``"forward-mode"`` requires ``vectorize=True``.
            Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
            inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
            prefer to use ``"reverse-mode"``.

    Returns:
        Jacobian (Tensor or nested tuple of Tensors): if there is a single
            input and output, this will be a single Tensor containing the
            Jacobian for the linearized inputs and output. If one of the two is
            a tuple, then the Jacobian will be a tuple of Tensors. If both of
            them are tuples, then the Jacobian will be a tuple of tuple of
            Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
            ``i``\th output and ``j``\th input and will have as size the
            concatenation of the sizes of the corresponding output and the
            corresponding input and will have same dtype and device as the
            corresponding input. If strategy is ``forward-mode``, the dtype will be
            that of the output; otherwise, the input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jacobian(exp_reducer, inputs)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]])

        >>> jacobian(exp_reducer, inputs, create_graph=True)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)

        >>> def exp_adder(x, y):
        ...     return 2 * x.exp() + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> jacobian(exp_adder, inputs)
        (tensor([[2.8052, 0.0000],
                 [0.0000, 3.3963]]),
         tensor([[3., 0.],
                 [0., 3.]]))
    """
    assert strategy in ("forward-mode", "reverse-mode"), (
        'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
        'function has more outputs than inputs, "forward-mode" tends to be more performant. '
        'Otherwise, prefer to use "reverse-mode".'
    )
    if strategy == "forward-mode":
        if create_graph:
            raise NotImplementedError(
                "torch.autograd.functional.jacobian: `create_graph=True` "
                'and `strategy="forward-mode"` are not supported together (yet). '
                "Please either set `create_graph=False` or "
                '`strategy="reverse-mode"`.'
            )
        return _jacfwd(func, inputs, strict, vectorize)

    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jacobian"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if vectorize:
            if strict:
                raise RuntimeError(
                    "torch.autograd.functional.jacobian: `strict=True` "
                    "and `vectorize=True` are not supported together. "
                    "Please either set `strict=False` or "
                    "`vectorize=False`."
                )
            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
            #
            # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
            # It turns out we can compute the jacobian of this function with a single
            # call to autograd.grad by using vmap over the correct grad_outputs.
            #
            # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
            # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
            #
            # To get the first row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
            # To get the 2nd row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
            # and so on.
            #
            # Using vmap, we can vectorize all 4 of these computations into one by
            # passing the standard basis for R^4 as the grad_output.
            # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
            #
            # Now, how do we compute the jacobian *without stacking the output*?
            # We can just split the standard basis across the outputs. So to
            # compute the jacobian of f(x), we'd use
            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
            # The grad_outputs looks like the following:
            # ( torch.tensor([[1, 0, 0],
            #                 [0, 1, 0],
            #                 [0, 0, 1],
            #                 [0, 0, 0]]),
            #   torch.tensor([[0],
            #                 [0],
            #                 [0],
            #                 [1]]) )
            #
            # But we're not done yet!
            # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
            # returns a Tensor of shape [4, 3]. We have to remember to split the
            # jacobian of shape [4, 3] into two:
            # - one of shape [3, 3] for the first output
            # - one of shape [   3] for the second output

            # Step 1: Construct grad_outputs by splitting the standard basis
            output_numels = tuple(output.numel() for output in outputs)
            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
            flat_outputs = tuple(output.reshape(-1) for output in outputs)

            # Step 2: Call vmap + autograd.grad
            def vjp(grad_output):
                vj = list(
                    _autograd_grad(
                        flat_outputs,
                        inputs,
                        grad_output,
                        create_graph=create_graph,
                        is_grads_batched=True,
                    )
                )
                for el_idx, vj_el in enumerate(vj):
                    if vj_el is not None:
                        continue
                    vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand(
                        (sum(output_numels),) + inputs[el_idx].shape
                    )
                return tuple(vj)

            jacobians_of_flat_output = vjp(grad_outputs)

            # Step 3: The returned jacobian is one big tensor per input. In this step,
            # we split each Tensor by output.
            jacobian_input_output = []
            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
                jacobian_input_i_output = []
                for jac, output_j in zip(
                    jac_input_i.split(output_numels, dim=0), outputs
                ):
                    jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
                    jacobian_input_i_output.append(jacobian_input_i_output_j)
                jacobian_input_output.append(jacobian_input_i_output)

            # Step 4: Right now, `jacobian` is a List[List[Tensor]].
            # The outer List corresponds to the number of inputs,
            # the inner List corresponds to the number of outputs.
            # We need to exchange the order of these and convert to tuples
            # before returning.
            jacobian_output_input = tuple(zip(*jacobian_input_output))

            jacobian_output_input = _grad_postprocess(
                jacobian_output_input, create_graph
            )
            return _tuple_postprocess(
                jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
            )

        jacobian: Tuple[torch.Tensor, ...] = ()

        for i, out in enumerate(outputs):
            # mypy complains that expression and variable have different types due to the empty list
            jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
            for j in range(out.nelement()):
                vj = _autograd_grad(
                    (out.reshape(-1)[j],),
                    inputs,
                    retain_graph=True,
                    create_graph=create_graph,
                )

                for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(
                    zip(jac_i, vj, inputs)
                ):
                    if vj_el is not None:
                        if strict and create_graph and not vj_el.requires_grad:
                            msg = (
                                "The jacobian of the user-provided function is "
                                f"independent of input {i}. This is not allowed in "
                                "strict mode when create_graph=True."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(vj_el)
                    else:
                        if strict:
                            msg = (
                                f"Output {i} of the user-provided function is "
                                f"independent of input {el_idx}. This is not allowed in "
                                "strict mode."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(torch.zeros_like(inp_el))

            jacobian += (
                tuple(
                    torch.stack(jac_i_el, dim=0).view(
                        out.size() + inputs[el_idx].size()  # type: ignore[operator]
                    )
                    for (el_idx, jac_i_el) in enumerate(jac_i)
                ),
            )

        jacobian = _grad_postprocess(jacobian, create_graph)

        return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))
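

# The helper below is an illustrative sketch, not part of the public API; the
# name ``_example_vectorized_jacobian_matches`` is hypothetical. It checks the
# claim in the NOTE inside ``jacobian``: the vmap-based ``vectorize=True`` path
# returns the same Jacobian as the default row-by-row computation.
def _example_vectorized_jacobian_matches():
    def f(x):
        return (x**2, x.sum())

    x = torch.randn(3)
    jac_loop = jacobian(f, x)
    jac_vmap = jacobian(f, x, vectorize=True)
    for loop_part, vmap_part in zip(jac_loop, jac_vmap):
        assert torch.allclose(loop_part, vmap_part)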


def hessian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    outer_jacobian_strategy="reverse-mode",
):
    r"""Compute the Hessian of a given scalar function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Hessian will be computed in
            a differentiable manner. Note that when ``strict`` is ``False``, the result can not
            require gradients or be disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input
            such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
            hessian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.hessian`
            instead if you are looking for something less experimental and more performant.
            When computing the hessian, usually we invoke
            ``autograd.grad`` once per row of the hessian. If this flag is
            ``True``, we use the vmap prototype feature as the backend to
            vectorize calls to ``autograd.grad`` so we only invoke it once
            instead of once per row. This should lead to performance
            improvements in many use cases, however, due to this feature
            being incomplete, there may be performance cliffs. Please
            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
            to show any performance warnings and file us issues if
            warnings exist for your use case. Defaults to ``False``.
        outer_jacobian_strategy (str, optional): The Hessian is computed by
            computing the Jacobian of a Jacobian. The inner Jacobian is always
            computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
            or ``"reverse-mode"`` determines whether the outer Jacobian will be
            computed with forward or reverse mode AD. Currently, computing the outer
            Jacobian in ``"forward-mode"`` requires ``vectorize=True``. Defaults
            to ``"reverse-mode"``.

    Returns:
        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
            this will be a single Tensor containing the Hessian for the input.
            If it is a tuple, then the Hessian will be a tuple of tuples where
            ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
            and ``j``\th input with size the sum of the size of the ``i``\th input plus
            the size of the ``j``\th input. ``Hessian[i][j]`` will have the same
            dtype and device as the corresponding ``i``\th input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hessian(pow_reducer, inputs)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]])

        >>> hessian(pow_reducer, inputs, create_graph=True)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)


        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> hessian(pow_adder_reducer, inputs)
        ((tensor([[4., 0.],
                  [0., 4.]]),
          tensor([[0., 0.],
                  [0., 0.]])),
         (tensor([[0., 0.],
                  [0., 0.]]),
          tensor([[6., 0.],
                  [0., 6.]])))
    """
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
    assert outer_jacobian_strategy in (
        "forward-mode",
        "reverse-mode",
    ), 'Expected strategy to be either "forward-mode" or "reverse-mode".'

    def ensure_single_output_function(*inp):
        out = func(*inp)
        is_out_tuple, t_out = _as_tuple(
            out, "outputs of the user-provided function", "hessian"
        )
        _check_requires_grad(t_out, "outputs", strict=strict)

        if is_out_tuple or not isinstance(out, torch.Tensor):
            raise RuntimeError(
                "The function given to hessian should return a single Tensor"
            )

        if out.nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hessian should contain a single element"
            )

        return out.squeeze()

    def jac_func(*inp):
        if outer_jacobian_strategy == "forward-mode":
            # _grad_preprocess requires create_graph=True and input to require_grad
            # or else the input will be detached
            inp = tuple(t.requires_grad_(True) for t in inp)
        jac = jacobian(ensure_single_output_function, inp, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)
        return jac

    res = jacobian(
        jac_func,
        inputs,
        create_graph=create_graph,
        strict=strict,
        vectorize=vectorize,
        strategy=outer_jacobian_strategy,
    )
    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))
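

# The helper below is an illustrative sketch, not part of the public API; the
# name ``_example_hessian_quadratic_form`` is hypothetical. For the quadratic
# form f(x) = x @ A @ x the Hessian is A + A.T, which gives a quick sanity
# check of the Jacobian-of-Jacobian computation above.
def _example_hessian_quadratic_form():
    A = torch.randn(3, 3)

    def quad(x):
        return x @ A @ x

    x = torch.randn(3)
    H = hessian(quad, x)
    assert torch.allclose(H, A + A.T)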


def vhp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between vector ``v`` and the Hessian of a given scalar function at a specified point.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vhp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vhp (tuple of Tensors or Tensor): result of the dot product with the
                same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vhp(pow_reducer, inputs, v)
        (tensor(0.5591),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]]))
        >>> vhp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.5591, grad_fn=<SumBackward0>),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))
        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> vhp(pow_adder_reducer, inputs, v)
        (tensor(4.8053),
         (tensor([0., 0.]),
          tensor([6., 6.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "vhp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )
        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vhp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to vhp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to vhp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)
        vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back")

    outputs = _grad_postprocess(outputs, create_graph)
    vhp = _grad_postprocess(vhp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vhp, is_inputs_tuple
    )


def hvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result will be
            computed in a differentiable way. Note that when ``strict`` is
            ``False``, the result can not require gradients or be disconnected
            from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            hvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            hvp (tuple of Tensors or Tensor): result of the dot product with
                the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hvp(pow_reducer, inputs, v)
        (tensor(0.1448),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]]))

        >>> hvp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.1448, grad_fn=<SumBackward0>),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))


        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> hvp(pow_adder_reducer, inputs, v)
        (tensor(2.3030),
         (tensor([0., 0.]),
          tensor([6., 6.])))

    Note:

        This function is significantly slower than `vhp` due to backward mode AD constraints.
        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
        know that your function satisfies this condition, you should use vhp instead, which is
        much faster with the current implementation.

    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "hvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )
        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "hvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to hvp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hvp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

        grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)

        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
        _check_requires_grad(jac, "hessian", strict=strict)

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
        hvp = _fill_in_zeros(
            grad_res, inputs, strict, create_graph, "double_back_trick"
        )

    outputs = _grad_postprocess(outputs, create_graph)
    hvp = _grad_postprocess(hvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        hvp, is_inputs_tuple
    )
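

# The helper below is an illustrative sketch, not part of the public API; the
# name ``_example_hvp_matches_vhp`` is hypothetical. It demonstrates the note in
# ``hvp``: for a twice continuously differentiable scalar function the Hessian
# is symmetric, so hvp and vhp return the same values (vhp being the cheaper of
# the two with the current implementation).
def _example_hvp_matches_vhp():
    def f(x):
        return x.pow(3).sum()

    x = torch.rand(2, 2)
    v = torch.ones(2, 2)
    _, hvp_res = hvp(f, x, v)
    _, vhp_res = vhp(f, x, v)
    assert torch.allclose(hvp_res, vhp_res)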