xref: /aosp_15_r20/external/pytorch/torch/autograd/functional.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import List, Tuple
3
4import torch
5from torch._vmap_internals import _vmap
6
7from . import forward_ad as fwAD
8
9
10__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
11
12# Utility functions
13
14
15def _as_tuple_nocheck(x):
16    if isinstance(x, tuple):
17        return x
18    elif isinstance(x, list):
19        return tuple(x)
20    else:
21        return (x,)
22
23
24def _as_tuple(inp, arg_name=None, fn_name=None):
25    # Ensures that inp is a tuple of Tensors
26    # Returns whether or not the original inp was a tuple and the tupled version of the input
27    if arg_name is None and fn_name is None:
28        return _as_tuple_nocheck(inp)
29
30    is_inp_tuple = True
31    if not isinstance(inp, tuple):
32        inp = (inp,)
33        is_inp_tuple = False
34
35    for i, el in enumerate(inp):
36        if not isinstance(el, torch.Tensor):
37            if is_inp_tuple:
38                raise TypeError(
39                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
40                    f" value at index {i} has type {type(el)}."
41                )
42            else:
43                raise TypeError(
44                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
45                    f" given {arg_name} has type {type(el)}."
46                )
47
48    return is_inp_tuple, inp
49
50
51def _tuple_postprocess(res, to_unpack):
52    # Unpacks a potentially nested tuple of Tensors
53    # to_unpack should be a single boolean or a tuple of two booleans.
54    # It is used to:
55    # - invert _as_tuple when res should match the inp given to _as_tuple
56    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
57    if isinstance(to_unpack, tuple):
58        assert len(to_unpack) == 2
59        if not to_unpack[1]:
60            res = tuple(el[0] for el in res)
61        if not to_unpack[0]:
62            res = res[0]
63    else:
64        if not to_unpack:
65            res = res[0]
66    return res
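
# Illustrative sketch (comments only, not executed) of how ``_as_tuple`` and
# ``_tuple_postprocess`` round-trip a single-Tensor argument:
#   >>> is_tuple, tup = _as_tuple(torch.rand(2), "inputs", "vjp")  # is_tuple is False, tup is a 1-tuple
#   >>> _tuple_postprocess(tup, is_tuple)  # unwraps back to the bare Tensor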
67
68
69def _grad_preprocess(inputs, create_graph, need_graph):
70    # Preprocess the inputs to make sure they require gradient
71    # inputs is a tuple of Tensors to preprocess
72    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
73    # need_graph specifies if we internally want gradients to flow back to the Tensors in res
74    # Note that we *always* create a new Tensor object to be able to see the difference between
75    # inputs given as arguments and the same Tensors automatically captured by the user function.
76    # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
77    res = []
78    for inp in inputs:
79        if create_graph and inp.requires_grad:
80            # Create at least a new Tensor object in a differentiable way
81            if not inp.is_sparse:
82                # Use .view_as() to get a shallow copy
83                res.append(inp.view_as(inp))
84            else:
85                # We cannot use view for sparse Tensors so we clone
86                res.append(inp.clone())
87        else:
88            res.append(inp.detach().requires_grad_(need_graph))
89    return tuple(res)
90
91
92def _grad_postprocess(inputs, create_graph):
93    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
94    # request it.
95    if isinstance(inputs[0], torch.Tensor):
96        if not create_graph:
97            return tuple(inp.detach() for inp in inputs)
98        else:
99            return inputs
100    else:
101        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
102
103
104def _validate_v(v, other, is_other_tuple):
105    # This assumes that other is the correct shape, and v should match
106    # Both are assumed to be tuples of Tensors
107    if len(other) != len(v):
108        if is_other_tuple:
109            raise RuntimeError(
110                f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
111            )
112        else:
113            raise RuntimeError("The given v should contain a single Tensor.")
114
115    for idx, (el_v, el_other) in enumerate(zip(v, other)):
116        if el_v.size() != el_other.size():
117            prepend = ""
118            if is_other_tuple:
119                prepend = f"Entry {idx} in "
120            raise RuntimeError(
121                f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}."
122            )
123
124
125def _check_requires_grad(inputs, input_type, strict):
126    # Used to make all the necessary checks to raise nice errors in strict mode.
127    if not strict:
128        return
129
130    if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
131        raise RuntimeError("Invalid input_type to _check_requires_grad")
132    for i, inp in enumerate(inputs):
133        if inp is None:
134            # This can only be reached for grad_inputs.
135            raise RuntimeError(
136                f"The output of the user-provided function is independent of input {i}."
137                " This is not allowed in strict mode."
138            )
139        if not inp.requires_grad:
140            if input_type == "hessian":
141                raise RuntimeError(
142                    f"The hessian of the user-provided function with respect to input {i}"
143                    " is independent of the input. This is not allowed in strict mode."
144                    " You should ensure that your function is thrice differentiable and that"
145                    " the hessian depends on the inputs."
146                )
147            elif input_type == "jacobian":
148                raise RuntimeError(
149                    "While computing the hessian, found that the jacobian of the user-provided"
150                    f" function with respect to input {i} is independent of the input. This is not"
151                    " allowed in strict mode. You should ensure that your function is twice"
152                    " differentiable and that the jacobian depends on the inputs (this would be"
153                    " violated by a linear function for example)."
154                )
155            elif input_type == "grad_inputs":
156                raise RuntimeError(
157                    f"The gradient with respect to input {i} is independent of the inputs of the"
158                    " user-provided function. This is not allowed in strict mode."
159                )
160            else:
161                raise RuntimeError(
162                    f"Output {i} of the user-provided function does not require gradients."
163                    " The outputs must be computed in a differentiable manner from the input"
164                    " when running in strict mode."
165                )
166
167
168def _autograd_grad(
169    outputs,
170    inputs,
171    grad_outputs=None,
172    create_graph=False,
173    retain_graph=None,
174    is_grads_batched=False,
175):
176    # Version of autograd.grad that accepts `None` in outputs and does not compute gradients for them.
177    # This has the extra constraint that inputs has to be a tuple
178    assert isinstance(outputs, tuple)
179    if grad_outputs is None:
180        grad_outputs = (None,) * len(outputs)
181    assert isinstance(grad_outputs, tuple)
182    assert len(outputs) == len(grad_outputs)
183
184    new_outputs: Tuple[torch.Tensor, ...] = ()
185    new_grad_outputs: Tuple[torch.Tensor, ...] = ()
186    for out, grad_out in zip(outputs, grad_outputs):
187        if out is not None and out.requires_grad:
188            new_outputs += (out,)
189            new_grad_outputs += (grad_out,)
190
191    if len(new_outputs) == 0:
192        # No differentiable output, we don't need to call the autograd engine
193        return (None,) * len(inputs)
194    else:
195        return torch.autograd.grad(
196            new_outputs,
197            inputs,
198            new_grad_outputs,
199            allow_unused=True,
200            create_graph=create_graph,
201            retain_graph=retain_graph,
202            is_grads_batched=is_grads_batched,
203        )
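
# Illustrative sketch (comments only, not executed): unlike ``torch.autograd.grad``,
# ``_autograd_grad`` above tolerates outputs that are ``None`` or do not require grad:
#   >>> x = torch.rand(2, requires_grad=True)
#   >>> _autograd_grad((x.sum(), None), (x,), (torch.tensor(1.0), None))
#   (tensor([1., 1.]),)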
204
205
206def _fill_in_zeros(grads, refs, strict, create_graph, stage):
207    # Used to detect None in the grads and depending on the flags, either replace them
208    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
209    # strict and create_graph allow us to detect when it is appropriate to raise an error
210    # stage tells us which backward call we are considering, so we can give a good error message
211    if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
212        raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")
213
214    res: Tuple[torch.Tensor, ...] = ()
215    for i, grads_i in enumerate(grads):
216        if grads_i is None:
217            if strict:
218                if stage == "back":
219                    raise RuntimeError(
220                        "The output of the user-provided function is independent of "
221                        f"input {i}. This is not allowed in strict mode."
222                    )
223                elif stage == "back_trick":
224                    raise RuntimeError(
225                        f"The gradient with respect to the input is independent of entry {i}"
226                        " in the grad_outputs when using the double backward trick to compute"
227                        " forward mode gradients. This is not allowed in strict mode."
228                    )
229                elif stage == "double_back":
230                    raise RuntimeError(
231                        "The jacobian of the user-provided function is independent of "
232                        f"input {i}. This is not allowed in strict mode."
233                    )
234                else:
235                    raise RuntimeError(
236                        "The hessian of the user-provided function is independent of "
237                        f"entry {i} in the grad_jacobian. This is not allowed in strict "
238                        "mode as it prevents using the double backward trick to "
239                        "replace forward mode AD."
240                    )
241
242            grads_i = torch.zeros_like(refs[i])
243        else:
244            if strict and create_graph and not grads_i.requires_grad:
245                if "double" not in stage:
246                    raise RuntimeError(
247                        "The jacobian of the user-provided function is independent of "
248                        f"input {i}. This is not allowed in strict mode when create_graph=True."
249                    )
250                else:
251                    raise RuntimeError(
252                        "The hessian of the user-provided function is independent of "
253                        f"input {i}. This is not allowed in strict mode when create_graph=True."
254                    )
255
256        res += (grads_i,)
257
258    return res
259
260
261# Public API
262
263
264def vjp(func, inputs, v=None, create_graph=False, strict=False):
265    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.
266
267    Args:
268        func (function): a Python function that takes Tensor inputs and returns
269            a tuple of Tensors or a Tensor.
270        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
271        v (tuple of Tensors or Tensor): The vector for which the vector
272            Jacobian product is computed.  Must be the same size as the output
273            of ``func``. This argument is optional when the output of ``func``
274            contains a single element and (if it is not provided) will be set
275            as a Tensor containing a single ``1``.
276        create_graph (bool, optional): If ``True``, both the output and result
277            will be computed in a differentiable way. Note that when ``strict``
278            is ``False``, the result may not require gradients or may be
279            disconnected from the inputs.  Defaults to ``False``.
280        strict (bool, optional): If ``True``, an error will be raised when we
281            detect that there exists an input such that all the outputs are
282            independent of it. If ``False``, we return a Tensor of zeros as the
283            vjp for said inputs, which is the expected mathematical value.
284            Defaults to ``False``.
285
286    Returns:
287        output (tuple): tuple with:
288            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
289
290            vjp (tuple of Tensors or Tensor): result of the dot product with
291            the same shape as the inputs.
292
293    Example:
294
295        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
296        >>> def exp_reducer(x):
297        ...     return x.exp().sum(dim=1)
298        >>> inputs = torch.rand(4, 4)
299        >>> v = torch.ones(4)
300        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
301        >>> vjp(exp_reducer, inputs, v)
302        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
303         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
304                [2.1288, 1.0652, 1.5483, 2.5035],
305                [2.2046, 1.1292, 1.1432, 1.3059],
306                [1.3225, 1.6652, 1.7753, 2.0152]]))
307
308        >>> vjp(exp_reducer, inputs, v, create_graph=True)
309        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
310         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
311                [2.1288, 1.0652, 1.5483, 2.5035],
312                [2.2046, 1.1292, 1.1432, 1.3059],
313                [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))
314
315        >>> def adder(x, y):
316        ...     return 2 * x + 3 * y
317        >>> inputs = (torch.rand(2), torch.rand(2))
318        >>> v = torch.ones(2)
319        >>> vjp(adder, inputs, v)
320        (tensor([2.4225, 2.3340]),
321         (tensor([2., 2.]), tensor([3., 3.])))
322    """
323    with torch.enable_grad():
324        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
325        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
326
327        outputs = func(*inputs)
328        is_outputs_tuple, outputs = _as_tuple(
329            outputs, "outputs of the user-provided function", "vjp"
330        )
331        _check_requires_grad(outputs, "outputs", strict=strict)
332
333        if v is not None:
334            _, v = _as_tuple(v, "v", "vjp")
335            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
336            _validate_v(v, outputs, is_outputs_tuple)
337        else:
338            if len(outputs) != 1 or outputs[0].nelement() != 1:
339                raise RuntimeError(
340                    "The vector v can only be None if the "
341                    "user-provided function returns "
342                    "a single Tensor with a single element."
343                )
344
345    enable_grad = True if create_graph else torch.is_grad_enabled()
346    with torch.set_grad_enabled(enable_grad):
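        # A single reverse-mode pass with v as grad_outputs gives the vjp, i.e. v^T J
        # reshaped to match the inputs.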
347        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
348        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")
349
350    # Cleanup objects and return them to the user
351    outputs = _grad_postprocess(outputs, create_graph)
352    vjp = _grad_postprocess(vjp, create_graph)
353
354    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
355        vjp, is_inputs_tuple
356    )
357
358
359def jvp(func, inputs, v=None, create_graph=False, strict=False):
360    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.
361
362    Args:
363        func (function): a Python function that takes Tensor inputs and returns
364            a tuple of Tensors or a Tensor.
365        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
366        v (tuple of Tensors or Tensor): The vector for which the Jacobian
367            vector product is computed. Must be the same size as the input of
368            ``func``. This argument is optional when the input to ``func``
369            contains a single element and (if it is not provided) will be set
370            as a Tensor containing a single ``1``.
371        create_graph (bool, optional): If ``True``, both the output and result
372            will be computed in a differentiable way. Note that when ``strict``
373            is ``False``, the result may not require gradients or may be
374            disconnected from the inputs.  Defaults to ``False``.
375        strict (bool, optional): If ``True``, an error will be raised when we
376            detect that there exists an input such that all the outputs are
377            independent of it. If ``False``, we return a Tensor of zeros as the
378            jvp for said inputs, which is the expected mathematical value.
379            Defaults to ``False``.
380
381    Returns:
382        output (tuple): tuple with:
383            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
384
385            jvp (tuple of Tensors or Tensor): result of the dot product with
386            the same shape as the output.
387
388    Note:
389        ``autograd.functional.jvp`` computes the jvp by using the backward of
390        the backward (sometimes called the double backwards trick). This is not
391        the most performant way of computing the jvp. Please consider using
392        :func:`torch.func.jvp` or the
393        :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.
394
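        As an illustrative sketch only (shapes chosen arbitrarily here), the
        recommended path looks like:

        >>> # xdoctest: +SKIP
        >>> from torch.func import jvp as func_jvp
        >>> primal, tangent = torch.rand(4, 4), torch.ones(4, 4)
        >>> out, tangent_out = func_jvp(lambda x: x.exp().sum(dim=1), (primal,), (tangent,))
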
395    Example:
396
397        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
398        >>> def exp_reducer(x):
399        ...     return x.exp().sum(dim=1)
400        >>> inputs = torch.rand(4, 4)
401        >>> v = torch.ones(4, 4)
402        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
403        >>> jvp(exp_reducer, inputs, v)
404        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
405         tensor([6.3090, 4.6742, 7.9114, 8.2106]))
406
407        >>> jvp(exp_reducer, inputs, v, create_graph=True)
408        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
409         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))
410
411        >>> def adder(x, y):
412        ...     return 2 * x + 3 * y
413        >>> inputs = (torch.rand(2), torch.rand(2))
414        >>> v = (torch.ones(2), torch.ones(2))
415        >>> jvp(adder, inputs, v)
416        (tensor([2.2399, 2.5005]),
417         tensor([5., 5.]))
418
419    """
420    with torch.enable_grad():
421        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
422        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
423
424        if v is not None:
425            _, v = _as_tuple(v, "v", "jvp")
426            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
427            _validate_v(v, inputs, is_inputs_tuple)
428        else:
429            if len(inputs) != 1 or inputs[0].nelement() != 1:
430                raise RuntimeError(
431                    "The vector v can only be None if the input to "
432                    "the user-provided function is a single Tensor "
433                    "with a single element."
434                )
435
436        outputs = func(*inputs)
437        is_outputs_tuple, outputs = _as_tuple(
438            outputs, "outputs of the user-provided function", "jvp"
439        )
440        _check_requires_grad(outputs, "outputs", strict=strict)
441        # The backward is linear so the value of grad_outputs is not important as
442        # it won't appear in the double backward graph. We only need to ensure that
443        # it does not contain inf or nan.
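        # Hedged sketch of the double backward trick for a single input/output pair
        # (comments only, not executed): with y = f(x) and u = zeros_like(y) requiring grad,
        #   >>> (g,) = torch.autograd.grad(y, x, u, create_graph=True)  # g = J^T u, linear in u
        #   >>> (jvp_out,) = torch.autograd.grad(g, u, v)  # d(J^T u)/du with cotangent v equals J v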
444        grad_outputs = tuple(
445            torch.zeros_like(out, requires_grad=True) for out in outputs
446        )
447
448        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
449        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)
450
451    if create_graph:
452        with torch.enable_grad():
453            grad_res = _autograd_grad(
454                grad_inputs, grad_outputs, v, create_graph=create_graph
455            )
456            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
457    else:
458        grad_res = _autograd_grad(
459            grad_inputs, grad_outputs, v, create_graph=create_graph
460        )
461        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
462
463    # Cleanup objects and return them to the user
464    outputs = _grad_postprocess(outputs, create_graph)
465    jvp = _grad_postprocess(jvp, create_graph)
466
467    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
468        jvp, is_outputs_tuple
469    )
470
471
472def _construct_standard_basis_for(
473    tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
474) -> Tuple[torch.Tensor, ...]:
475    # This function:
476    # - constructs an N=sum(tensor_numels) standard basis, i.e. an NxN identity matrix.
477    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
478    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
479    #   device as the tensor
480    #
481    # For example, with tensor_numels = [1, 2, 1], this function returns:
482    # ( tensor([[1],     tensor([[0, 0],      tensor([[0],
483    #           [0],             [1, 0],              [0],
484    #           [0],             [0, 1],              [0],
485    #           [0]])  ,         [0, 0]])  ,          [1]])  )
486    #
487    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
488    # Precondition: tensors always has at least one element.
489    #
490    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
491    # for context behind this function. All the pre-conditions are guarded for
492    # in torch.autograd.functional.jacobian.
493    assert len(tensors) == len(tensor_numels)
494    assert len(tensors) > 0
495    total_numel = sum(tensor_numels)
496    chunks = tuple(
497        tensor.new_zeros(total_numel, tensor_numel)
498        for tensor, tensor_numel in zip(tensors, tensor_numels)
499    )
500    diag_start_idx = 0
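    # ``Tensor.diagonal`` accepts a negative offset (a diagonal below the main one), so
    # decrementing ``diag_start_idx`` by each chunk's numel writes the ones of chunk k
    # starting at row sum(tensor_numels[:k]).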
501    for chunk, numel in zip(chunks, tensor_numels):
502        chunk.diagonal(diag_start_idx).fill_(1)
503        diag_start_idx -= numel
504    return chunks
505
506
507def _jacfwd(func, inputs, strict=False, vectorize=False):
508    if strict:
509        raise RuntimeError(
510            "torch.autograd.functional.jacobian: `strict=True` "
511            'and `strategy="forward-mode"` are not supported together (yet). '
512            "Please either set `strict=False` or "
513            '`strategy="reverse-mode"`.'
514        )
515    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
516    output_info = []
517
518    if vectorize:
519        # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
520        input_numels = tuple(input.numel() for input in inputs)
521
522        # Step 1: Prepare tangents
523        tangents = _construct_standard_basis_for(inputs, input_numels)
524
525        # Step 2: Compute vmap over computation with dual tensors
526        def jvp(tangents):
527            with fwAD.dual_level():
528                dual_inputs = tuple(
529                    fwAD.make_dual(input, tangent.view_as(input))
530                    for input, tangent in zip(inputs, tangents)
531                )
532                _is_outputs_tuple, dual_outputs = _as_tuple(
533                    func(*dual_inputs), "outputs"
534                )
535                output_info.append(_is_outputs_tuple)
536                jv = []
537                primal_outs = []
538                for dual_out in dual_outputs:
539                    primal, tangent = fwAD.unpack_dual(dual_out)
540                    primal_outs.append(primal)
541                    if tangent is not None:
542                        jv.append(tangent)
543                    else:
544                        jv.append(torch.zeros_like(primal))
545                output_info.append(primal_outs)
546                return tuple(jv)
547
548        outputs_before_split = _vmap(jvp)(tangents)
549        is_outputs_tuple, outputs = output_info
550        # Step 3: for each of the output tangents, split along dim 0
551        jacobian_input_output = []
552        for jac_output_i, output_i in zip(outputs_before_split, outputs):
553            jacobian_output_i_output = []
554            for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs):
555                # We need to transpose the Jacobian because in forward AD, the
556                # batch dimension represents that of the inputs
557                jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape(
558                    (*output_i.shape, *input_j.shape)
559                )  # noqa: C409
560
561                jacobian_output_i_output.append(jacobian_input_i_output_j)
562            jacobian_input_output.append(jacobian_output_i_output)
563
564        # Omit [Step 4] because everything is already transposed w/ forward AD
565        return _tuple_postprocess(
566            jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)
567        )
568    else:
569        raise NotImplementedError(
570            "Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
571            "only implemented for `vectorize=True`."
572        )
573
574
575def jacobian(
576    func,
577    inputs,
578    create_graph=False,
579    strict=False,
580    vectorize=False,
581    strategy="reverse-mode",
582):
583    r"""Compute the Jacobian of a given function.
584
585    Args:
586        func (function): a Python function that takes Tensor inputs and returns
587            a tuple of Tensors or a Tensor.
588        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
589        create_graph (bool, optional): If ``True``, the Jacobian will be
590            computed in a differentiable manner. Note that when ``strict`` is
591            ``False``, the result may not require gradients or may be disconnected
592            from the inputs.  Defaults to ``False``.
593        strict (bool, optional): If ``True``, an error will be raised when we
594            detect that there exists an input such that all the outputs are
595            independent of it. If ``False``, we return a Tensor of zeros as the
596            jacobian for said inputs, which is the expected mathematical value.
597            Defaults to ``False``.
598        vectorize (bool, optional): This feature is experimental.
599            Please consider using :func:`torch.func.jacrev` or
600            :func:`torch.func.jacfwd` instead if you are looking for something
601            less experimental and more performant.
602            When computing the jacobian, usually we invoke
603            ``autograd.grad`` once per row of the jacobian. If this flag is
604            ``True``, we perform only a single ``autograd.grad`` call with
605            ``batched_grad=True`` which uses the vmap prototype feature.
606            Though this should lead to performance improvements in many cases,
607            because this feature is still experimental, there may be performance
608            cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
609            more information.
610        strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
611            determine whether the Jacobian will be computed with forward or reverse
612            mode AD. Currently, ``"forward-mode"`` requires ``vectorize=True``.
613            Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
614            inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
615            prefer to use ``"reverse-mode"``.
616
617    Returns:
618        Jacobian (Tensor or nested tuple of Tensors): if there is a single
619        input and output, this will be a single Tensor containing the
620        Jacobian for the linearized inputs and output. If one of the two is
621        a tuple, then the Jacobian will be a tuple of Tensors. If both of
622        them are tuples, then the Jacobian will be a tuple of tuple of
623        Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
624        ``i``\th output and ``j``\th input and will have as size the
625        concatenation of the sizes of the corresponding output and the
626        corresponding input and will have same dtype and device as the
627        corresponding input. If strategy is ``forward-mode``, the dtype will be
628        that of the output; otherwise, the input.
629
630    Example:
631
632        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
633        >>> def exp_reducer(x):
634        ...     return x.exp().sum(dim=1)
635        >>> inputs = torch.rand(2, 2)
636        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
637        >>> jacobian(exp_reducer, inputs)
638        tensor([[[1.4917, 2.4352],
639                 [0.0000, 0.0000]],
640                [[0.0000, 0.0000],
641                 [2.4369, 2.3799]]])
642
643        >>> jacobian(exp_reducer, inputs, create_graph=True)
644        tensor([[[1.4917, 2.4352],
645                 [0.0000, 0.0000]],
646                [[0.0000, 0.0000],
647                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)
648
649        >>> def exp_adder(x, y):
650        ...     return 2 * x.exp() + 3 * y
651        >>> inputs = (torch.rand(2), torch.rand(2))
652        >>> jacobian(exp_adder, inputs)
653        (tensor([[2.8052, 0.0000],
654                [0.0000, 3.3963]]),
655         tensor([[3., 0.],
656                 [0., 3.]]))
657    """
658    assert strategy in ("forward-mode", "reverse-mode"), (
659        'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
660        'function has more outputs than inputs, "forward-mode" tends to be more performant. '
661        'Otherwise, prefer to use "reverse-mode".'
662    )
663    if strategy == "forward-mode":
664        if create_graph:
665            raise NotImplementedError(
666                "torch.autograd.functional.jacobian: `create_graph=True` "
667                'and `strategy="forward-mode"` are not supported together (yet). '
668                "Please either set `create_graph=False` or "
669                '`strategy="reverse-mode"`.'
670            )
671        return _jacfwd(func, inputs, strict, vectorize)
672
673    with torch.enable_grad():
674        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
675        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
676
677        outputs = func(*inputs)
678        is_outputs_tuple, outputs = _as_tuple(
679            outputs, "outputs of the user-provided function", "jacobian"
680        )
681        _check_requires_grad(outputs, "outputs", strict=strict)
682
683        if vectorize:
684            if strict:
685                raise RuntimeError(
686                    "torch.autograd.functional.jacobian: `strict=True` "
687                    "and `vectorize=True` are not supported together. "
688                    "Please either set `strict=False` or "
689                    "`vectorize=False`."
690                )
691            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
692            #
693            # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
694            # It turns out we can compute the jacobian of this function with a single
695            # call to autograd.grad by using vmap over the correct grad_outputs.
696            #
697            # Firstly, one way to compute the jacobian is to concatenate x**2 and x.sum()
698            # into a single vector with 4 elements. E.g., use g(x) = torch.cat([x**2, x.sum().reshape(1)])
699            #
700            # To get the first row of the jacobian, we call
701            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
702            # To get the 2nd row of the jacobian, we call
703            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
704            # and so on.
705            #
706            # Using vmap, we can vectorize all 4 of these computations into one by
707            # passing the standard basis for R^4 as the grad_output.
708            # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
709            #
710            # Now, how do we compute the jacobian *without stacking the output*?
711            # We can just split the standard basis across the outputs. So to
712            # compute the jacobian of f(x), we'd use
713            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
714            # The grad_outputs looks like the following:
715            # ( torch.tensor([[1, 0, 0],
716            #                 [0, 1, 0],
717            #                 [0, 0, 1],
718            #                 [0, 0, 0]]),
719            #   torch.tensor([[0],
720            #                 [0],
721            #                 [0],
722            #                 [1]]) )
723            #
724            # But we're not done yet!
725            # >>> vmap(partial(autograd.grad, f(x), x))(grad_outputs)
726            # returns a Tensor of shape [4, 3]. We have to remember to split the
727            # jacobian of shape [4, 3] into two:
728            # - one of shape [3, 3] for the first output
729            # - one of shape [   3] for the second output
730
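            # A hedged, concrete sketch of the same idea through the public API
            # (comments only, not executed), with torch.eye(3) playing the role of
            # the standard basis:
            #   >>> x = torch.randn(3, requires_grad=True)
            #   >>> (jac,) = torch.autograd.grad(x**2, x, torch.eye(3), is_grads_batched=True)
            #   >>> jac.shape  # torch.Size([3, 3]), one vjp per basis vector
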
731            # Step 1: Construct grad_outputs by splitting the standard basis
732            output_numels = tuple(output.numel() for output in outputs)
733            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
734            flat_outputs = tuple(output.reshape(-1) for output in outputs)
735
736            # Step 2: Call vmap + autograd.grad
737            def vjp(grad_output):
738                vj = list(
739                    _autograd_grad(
740                        flat_outputs,
741                        inputs,
742                        grad_output,
743                        create_graph=create_graph,
744                        is_grads_batched=True,
745                    )
746                )
747                for el_idx, vj_el in enumerate(vj):
748                    if vj_el is not None:
749                        continue
750                    vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand(
751                        (sum(output_numels),) + inputs[el_idx].shape
752                    )
753                return tuple(vj)
754
755            jacobians_of_flat_output = vjp(grad_outputs)
756
757            # Step 3: The returned jacobian is one big tensor per input. In this step,
758            # we split each Tensor by output.
759            jacobian_input_output = []
760            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
761                jacobian_input_i_output = []
762                for jac, output_j in zip(
763                    jac_input_i.split(output_numels, dim=0), outputs
764                ):
765                    jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
766                    jacobian_input_i_output.append(jacobian_input_i_output_j)
767                jacobian_input_output.append(jacobian_input_i_output)
768
769            # Step 4: Right now, `jacobian` is a List[List[Tensor]].
770            # The outer List corresponds to the number of inputs,
771            # the inner List corresponds to the number of outputs.
772            # We need to exchange the order of these and convert to tuples
773            # before returning.
774            jacobian_output_input = tuple(zip(*jacobian_input_output))
775
776            jacobian_output_input = _grad_postprocess(
777                jacobian_output_input, create_graph
778            )
779            return _tuple_postprocess(
780                jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
781            )
782
783        jacobian: Tuple[torch.Tensor, ...] = ()
784
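        # Non-vectorized fallback: one reverse-mode call per scalar element of each
        # output; row j of the block for output i is d(out.reshape(-1)[j]) / d(inputs).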
785        for i, out in enumerate(outputs):
786            # mypy complains that expression and variable have different types due to the empty list
787            jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
788            for j in range(out.nelement()):
789                vj = _autograd_grad(
790                    (out.reshape(-1)[j],),
791                    inputs,
792                    retain_graph=True,
793                    create_graph=create_graph,
794                )
795
796                for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(
797                    zip(jac_i, vj, inputs)
798                ):
799                    if vj_el is not None:
800                        if strict and create_graph and not vj_el.requires_grad:
801                            msg = (
802                                "The jacobian of the user-provided function is "
803                                f"independent of input {i}. This is not allowed in "
804                                "strict mode when create_graph=True."
805                            )
806                            raise RuntimeError(msg)
807                        jac_i_el.append(vj_el)
808                    else:
809                        if strict:
810                            msg = (
811                                f"Output {i} of the user-provided function is "
812                                f"independent of input {el_idx}. This is not allowed in "
813                                "strict mode."
814                            )
815                            raise RuntimeError(msg)
816                        jac_i_el.append(torch.zeros_like(inp_el))
817
818            jacobian += (
819                tuple(
820                    torch.stack(jac_i_el, dim=0).view(
821                        out.size() + inputs[el_idx].size()  # type: ignore[operator]
822                    )
823                    for (el_idx, jac_i_el) in enumerate(jac_i)
824                ),
825            )
826
827        jacobian = _grad_postprocess(jacobian, create_graph)
828
829        return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))
830
831
832def hessian(
833    func,
834    inputs,
835    create_graph=False,
836    strict=False,
837    vectorize=False,
838    outer_jacobian_strategy="reverse-mode",
839):
840    r"""Compute the Hessian of a given scalar function.
841
842    Args:
843        func (function): a Python function that takes Tensor inputs and returns
844            a Tensor with a single element.
845        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
846        create_graph (bool, optional): If ``True``, the Hessian will be computed in
847            a differentiable manner. Note that when ``strict`` is ``False``, the result may not
848            require gradients or may be disconnected from the inputs.
849            Defaults to ``False``.
850        strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input
851            such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
852            hessian for said inputs, which is the expected mathematical value.
853            Defaults to ``False``.
854        vectorize (bool, optional): This feature is experimental.
855            Please consider using :func:`torch.func.hessian`
856            instead if you are looking for something less experimental and more performant.
857            When computing the hessian, usually we invoke
858            ``autograd.grad`` once per row of the hessian. If this flag is
859            ``True``, we use the vmap prototype feature as the backend to
860            vectorize calls to ``autograd.grad`` so we only invoke it once
861            instead of once per row. This should lead to performance
862            improvements in many use cases, however, due to this feature
863            being incomplete, there may be performance cliffs. Please
864            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
865            to show any performance warnings and file us issues if
866            warnings exist for your use case. Defaults to ``False``.
867        outer_jacobian_strategy (str, optional): The Hessian is computed by
868            computing the Jacobian of a Jacobian. The inner Jacobian is always
869            computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
870            or ``"reverse-mode"`` determines whether the outer Jacobian will be
871            computed with forward or reverse mode AD. Currently, computing the outer
872            Jacobian in ``"forward-mode"`` requires ``vectorize=True``. Defaults
873            to ``"reverse-mode"``.
874
875    Returns:
876        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
877        this will be a single Tensor containing the Hessian for the input.
878        If it is a tuple, then the Hessian will be a tuple of tuples where
879        ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
880        and ``j``\th input with size the sum of the size of the ``i``\th input plus
881        the size of the ``j``\th input. ``Hessian[i][j]`` will have the same
882        dtype and device as the corresponding ``i``\th input.
883
884    Example:
885
886        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
887        >>> def pow_reducer(x):
888        ...     return x.pow(3).sum()
889        >>> inputs = torch.rand(2, 2)
890        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
891        >>> hessian(pow_reducer, inputs)
892        tensor([[[[5.2265, 0.0000],
893                  [0.0000, 0.0000]],
894                 [[0.0000, 4.8221],
895                  [0.0000, 0.0000]]],
896                [[[0.0000, 0.0000],
897                  [1.9456, 0.0000]],
898                 [[0.0000, 0.0000],
899                  [0.0000, 3.2550]]]])
900
901        >>> hessian(pow_reducer, inputs, create_graph=True)
902        tensor([[[[5.2265, 0.0000],
903                  [0.0000, 0.0000]],
904                 [[0.0000, 4.8221],
905                  [0.0000, 0.0000]]],
906                [[[0.0000, 0.0000],
907                  [1.9456, 0.0000]],
908                 [[0.0000, 0.0000],
909                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)
910
911
912        >>> def pow_adder_reducer(x, y):
913        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
914        >>> inputs = (torch.rand(2), torch.rand(2))
915        >>> hessian(pow_adder_reducer, inputs)
916        ((tensor([[4., 0.],
917                  [0., 4.]]),
918          tensor([[0., 0.],
919                  [0., 0.]])),
920         (tensor([[0., 0.],
921                  [0., 0.]]),
922          tensor([[6., 0.],
923                  [0., 6.]])))
924    """
925    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
926    assert outer_jacobian_strategy in (
927        "forward-mode",
928        "reverse-mode",
929    ), 'Expected strategy to be either "forward-mode" or "reverse-mode".'
930
931    def ensure_single_output_function(*inp):
932        out = func(*inp)
933        is_out_tuple, t_out = _as_tuple(
934            out, "outputs of the user-provided function", "hessian"
935        )
936        _check_requires_grad(t_out, "outputs", strict=strict)
937
938        if is_out_tuple or not isinstance(out, torch.Tensor):
939            raise RuntimeError(
940                "The function given to hessian should return a single Tensor"
941            )
942
943        if out.nelement() != 1:
944            raise RuntimeError(
945                "The Tensor returned by the function given to hessian should contain a single element"
946            )
947
948        return out.squeeze()
949
950    def jac_func(*inp):
951        if outer_jacobian_strategy == "forward-mode":
952            # _grad_preprocess requires create_graph=True and the inputs to require grad,
953            # or else the inputs will be detached
954            inp = tuple(t.requires_grad_(True) for t in inp)
955        jac = jacobian(ensure_single_output_function, inp, create_graph=True)
956        _check_requires_grad(jac, "jacobian", strict=strict)
957        return jac
958
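    # The Hessian is computed as the Jacobian of ``jac_func``, i.e. the Jacobian of
    # the gradient of ``func``.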
959    res = jacobian(
960        jac_func,
961        inputs,
962        create_graph=create_graph,
963        strict=strict,
964        vectorize=vectorize,
965        strategy=outer_jacobian_strategy,
966    )
967    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))
968
969
970def vhp(func, inputs, v=None, create_graph=False, strict=False):
971    r"""Compute the dot product between a vector ``v`` and the Hessian of a given scalar function at a specified point.
972
973    Args:
974        func (function): a Python function that takes Tensor inputs and returns
975            a Tensor with a single element.
976        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
977        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
978            product is computed. Must be the same size as the input of
979            ``func``. This argument is optional when ``func``'s input contains
980            a single element and (if it is not provided) will be set as a
981            Tensor containing a single ``1``.
982        create_graph (bool, optional): If ``True``, both the output and result
983            will be computed in a differentiable way. Note that when ``strict``
984            is ``False``, the result may not require gradients or may be
985            disconnected from the inputs.
986            Defaults to ``False``.
987        strict (bool, optional): If ``True``, an error will be raised when we
988            detect that there exists an input such that all the outputs are
989            independent of it. If ``False``, we return a Tensor of zeros as the
990            vhp for said inputs, which is the expected mathematical value.
991            Defaults to ``False``.
992
993    Returns:
994        output (tuple): tuple with:
995            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
996
997            vhp (tuple of Tensors or Tensor): result of the dot product with the
998            same shape as the inputs.
999
1000    Example:
1001
1002        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
1003        >>> def pow_reducer(x):
1004        ...     return x.pow(3).sum()
1005        >>> inputs = torch.rand(2, 2)
1006        >>> v = torch.ones(2, 2)
1007        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
1008        >>> vhp(pow_reducer, inputs, v)
1009        (tensor(0.5591),
1010         tensor([[1.0689, 1.2431],
1011                 [3.0989, 4.4456]]))
1012        >>> vhp(pow_reducer, inputs, v, create_graph=True)
1013        (tensor(0.5591, grad_fn=<SumBackward0>),
1014         tensor([[1.0689, 1.2431],
1015                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))
1016        >>> def pow_adder_reducer(x, y):
1017        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
1018        >>> inputs = (torch.rand(2), torch.rand(2))
1019        >>> v = (torch.zeros(2), torch.ones(2))
1020        >>> vhp(pow_adder_reducer, inputs, v)
1021        (tensor(4.8053),
1022         (tensor([0., 0.]),
1023          tensor([6., 6.])))
1024    """
1025    with torch.enable_grad():
1026        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
1027        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
1028
1029        if v is not None:
1030            _, v = _as_tuple(v, "v", "vhp")
1031            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
1032            _validate_v(v, inputs, is_inputs_tuple)
1033        else:
1034            if len(inputs) != 1 or inputs[0].nelement() != 1:
1035                raise RuntimeError(
1036                    "The vector v can only be None if the input to the user-provided function "
1037                    "is a single Tensor with a single element."
1038                )
1039        outputs = func(*inputs)
1040        is_outputs_tuple, outputs = _as_tuple(
1041            outputs, "outputs of the user-provided function", "vhp"
1042        )
1043        _check_requires_grad(outputs, "outputs", strict=strict)
1044
1045        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
1046            raise RuntimeError(
1047                "The function given to vhp should return a single Tensor"
1048            )
1049
1050        if outputs[0].nelement() != 1:
1051            raise RuntimeError(
1052                "The Tensor returned by the function given to vhp should contain a single element"
1053            )
1054
1055        jac = _autograd_grad(outputs, inputs, create_graph=True)
1056        _check_requires_grad(jac, "jacobian", strict=strict)
1057
1058    enable_grad = True if create_graph else torch.is_grad_enabled()
1059    with torch.set_grad_enabled(enable_grad):
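        # vhp = d(jac . v)/dx, i.e. v^T H (the Hessian of a scalar function is symmetric).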
1060        grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)
1061        vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back")
1062
1063    outputs = _grad_postprocess(outputs, create_graph)
1064    vhp = _grad_postprocess(vhp, create_graph)
1065
1066    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
1067        vhp, is_inputs_tuple
1068    )
1069
1070
1071def hvp(func, inputs, v=None, create_graph=False, strict=False):
1072    r"""Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point.
1073
1074    Args:
1075        func (function): a Python function that takes Tensor inputs and returns
1076            a Tensor with a single element.
1077        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
1078        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
1079            product is computed. Must be the same size as the input of
1080            ``func``. This argument is optional when ``func``'s input contains
1081            a single element and (if it is not provided) will be set as a
1082            Tensor containing a single ``1``.
1083        create_graph (bool, optional): If ``True``, both the output and result will be
1084            computed in a differentiable way. Note that when ``strict`` is
1085            ``False``, the result may not require gradients or may be disconnected
1086            from the inputs.  Defaults to ``False``.
1087        strict (bool, optional): If ``True``, an error will be raised when we
1088            detect that there exists an input such that all the outputs are
1089            independent of it. If ``False``, we return a Tensor of zeros as the
1090            hvp for said inputs, which is the expected mathematical value.
1091            Defaults to ``False``.
1092    Returns:
1093        output (tuple): tuple with:
1094            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
1095
1096            hvp (tuple of Tensors or Tensor): result of the dot product with
1097            the same shape as the inputs.
1098
1099    Example:
1100
1101        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
1102        >>> def pow_reducer(x):
1103        ...     return x.pow(3).sum()
1104        >>> inputs = torch.rand(2, 2)
1105        >>> v = torch.ones(2, 2)
1106        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
1107        >>> hvp(pow_reducer, inputs, v)
1108        (tensor(0.1448),
1109         tensor([[2.0239, 1.6456],
1110                 [2.4988, 1.4310]]))
1111
1112        >>> hvp(pow_reducer, inputs, v, create_graph=True)
1113        (tensor(0.1448, grad_fn=<SumBackward0>),
1114         tensor([[2.0239, 1.6456],
1115                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))
1116
1117
1118        >>> def pow_adder_reducer(x, y):
1119        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
1120        >>> inputs = (torch.rand(2), torch.rand(2))
1121        >>> v = (torch.zeros(2), torch.ones(2))
1122        >>> hvp(pow_adder_reducer, inputs, v)
1123        (tensor(2.3030),
1124         (tensor([0., 0.]),
1125          tensor([6., 6.])))
1126
1127    Note:
1128
1129        This function is significantly slower than `vhp` due to backward mode AD constraints.
1130        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
1131        know that your function satisfies this condition, you should use vhp instead, as it is
1132        much faster with the current implementation.
1133
1134    """
1135    with torch.enable_grad():
1136        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
1137        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
1138
1139        if v is not None:
1140            _, v = _as_tuple(v, "v", "hvp")
1141            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
1142            _validate_v(v, inputs, is_inputs_tuple)
1143        else:
1144            if len(inputs) != 1 or inputs[0].nelement() != 1:
1145                raise RuntimeError(
1146                    "The vector v can only be None if the input to the user-provided function "
1147                    "is a single Tensor with a single element."
1148                )
1149        outputs = func(*inputs)
1150        is_outputs_tuple, outputs = _as_tuple(
1151            outputs, "outputs of the user-provided function", "hvp"
1152        )
1153        _check_requires_grad(outputs, "outputs", strict=strict)
1154
1155        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
1156            raise RuntimeError(
1157                "The function given to hvp should return a single Tensor"
1158            )
1159
1160        if outputs[0].nelement() != 1:
1161            raise RuntimeError(
1162                "The Tensor returned by the function given to hvp should contain a single element"
1163            )
1164
1165        jac = _autograd_grad(outputs, inputs, create_graph=True)
1166        _check_requires_grad(jac, "jacobian", strict=strict)
1167
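        # Double backward trick: ``double_back`` below equals H @ grad_jac and is linear
        # in ``grad_jac``, so differentiating it w.r.t. ``grad_jac`` with cotangent ``v``
        # yields H @ v using only reverse-mode AD.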
1168        grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)
1169
1170        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
1171        _check_requires_grad(double_back, "hessian", strict=strict)
1172
1173    enable_grad = True if create_graph else torch.is_grad_enabled()
1174    with torch.set_grad_enabled(enable_grad):
1175        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
1176        hvp = _fill_in_zeros(
1177            grad_res, inputs, strict, create_graph, "double_back_trick"
1178        )
1179
1180    outputs = _grad_postprocess(outputs, create_graph)
1181    hvp = _grad_postprocess(hvp, create_graph)
1182
1183    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
1184        hvp, is_inputs_tuple
1185    )
1186